diff --git a/SpecForge-ext/benchmarks/benchmarker/__pycache__/livecodebench.cpython-311.pyc b/SpecForge-ext/benchmarks/benchmarker/__pycache__/livecodebench.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..502c23e2804cfaed07a72201d436f973f35117c1 Binary files /dev/null and b/SpecForge-ext/benchmarks/benchmarker/__pycache__/livecodebench.cpython-311.pyc differ diff --git a/SpecForge-ext/cache/compiled_kernels/34/c34af36gfqnn2ovywuaultc2pol4jyn6io3szgjeuv3uxfzcf3nv.py b/SpecForge-ext/cache/compiled_kernels/34/c34af36gfqnn2ovywuaultc2pol4jyn6io3szgjeuv3uxfzcf3nv.py new file mode 100644 index 0000000000000000000000000000000000000000..769579f1a8494ed5380da0d0ed60ea2ea176e3f5 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/34/c34af36gfqnn2ovywuaultc2pol4jyn6io3szgjeuv3uxfzcf3nv.py @@ -0,0 +1,43 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_sum_2(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + x2 = (xindex % ks1) + x3 = xindex // ks1 + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x2 + x3*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp5, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/37/49508e3b35fb555ab64ad6f410ad33153cf779bf7b9d6de2ca009401cf12419e.best_config b/SpecForge-ext/cache/compiled_kernels/37/49508e3b35fb555ab64ad6f410ad33153cf779bf7b9d6de2ca009401cf12419e.best_config new file mode 100644 index 0000000000000000000000000000000000000000..758b6d2bc873793fbcdbae9195bd9dbe5fd4de6d --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/37/49508e3b35fb555ab64ad6f410ad33153cf779bf7b9d6de2ca009401cf12419e.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 66, "triton_cache_hash": "UQSFYICF6CFQWZOBHCGZ7JZ457GHWVO6RMPN5ABNWOATFMKI6GQA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/37/c37gymepdyiyzp5hh2xt3a5vqmje2frbmyiqgipqpazjx6xcuyyb.py b/SpecForge-ext/cache/compiled_kernels/37/c37gymepdyiyzp5hh2xt3a5vqmje2frbmyiqgipqpazjx6xcuyyb.py new file mode 100644 index 0000000000000000000000000000000000000000..17a4d557fd9f5d2a9da008fff1eebee77930c3c8 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/37/c37gymepdyiyzp5hh2xt3a5vqmje2frbmyiqgipqpazjx6xcuyyb.py @@ -0,0 +1,66 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 67108864}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/4b/c4b4wkdm2d2z4hysjzfo6cyikw75man4bednwbsjwot4lkx7xfzs.py b/SpecForge-ext/cache/compiled_kernels/4b/c4b4wkdm2d2z4hysjzfo6cyikw75man4bednwbsjwot4lkx7xfzs.py new file mode 100644 index 0000000000000000000000000000000000000000..1b524a5fa2979d36e1ae9b29c52669df17b97f86 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/4b/c4b4wkdm2d2z4hysjzfo6cyikw75man4bednwbsjwot4lkx7xfzs.py @@ -0,0 +1,47 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/4g/c4gcdzc7dkmej2ceuy3ivyfjm5wjukkm4mbbdcmc7uaq76svnppo.py b/SpecForge-ext/cache/compiled_kernels/4g/c4gcdzc7dkmej2ceuy3ivyfjm5wjukkm4mbbdcmc7uaq76svnppo.py new file mode 100644 index 0000000000000000000000000000000000000000..5a45ba27c77d20a0680745931dbb8ff9a536261e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/4g/c4gcdzc7dkmej2ceuy3ivyfjm5wjukkm4mbbdcmc7uaq76svnppo.py @@ -0,0 +1,159 @@ +# AOT ID: ['1_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/e4/ce4fv76qvag73sibbo3mhwtavvyq3wneu5xe4faj6ybtsqisdlvr.py +# Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] +# Source node to ATen node mapping: +# target_head => convert_element_type +# target_p => div +# Graph fragment: +# %arg0_1 : Tensor "bf16[8, 2048, 32000][65536000, 32000, 1]cuda:4" = PlaceHolder[target=arg0_1] +# %getitem : Tensor "f32[8, 2048, 1][2048, 1, 16384]cuda:4" = PlaceHolder[target=getitem] +# %getitem_1 : Tensor "f32[8, 2048, 1][2048, 1, 16384]cuda:4" = PlaceHolder[target=getitem_1] +# %convert_element_type : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg0_1, torch.float32), kwargs = {}) +# %prepare_softmax_online_default : [num_users=2] = call_function[target=torch.ops.prims.prepare_softmax_online.default](args = (%convert_element_type, 2), kwargs = {}) +# %sub_tensor : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type, %getitem), kwargs = {}) +# %exp_default : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_tensor,), kwargs = {}) +# %div : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_default, %getitem_1), kwargs = {}) +# return %getitem,%getitem_1,%div +triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0 = async_compile.triton('triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 5242880000}} +) +@triton.jit +def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32) + _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + + _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine( + _tmp3_max, _tmp3_sum, tmp2, False + ) + + _tmp3_max = tl.where(r0_mask, _tmp3_max_next, _tmp3_max) + _tmp3_sum = tl.where(r0_mask, _tmp3_sum_next, _tmp3_sum) + + tmp3, tmp4 = triton_helpers.online_softmax_reduce( + _tmp3_max, _tmp3_sum, 1, False) + tmp3 = tmp3[:, None] + tmp4 = tmp4[:, None] + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp6 = tmp5.to(tl.float32) + tmp7 = tmp6 - tmp3 + tmp8 = libdevice.exp(tmp7) + tmp9 = (tmp8 / tmp4) + tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, = args + args.clear() + assert_size_stride(arg0_1, (8, 2048, 32000), (65536000, 32000, 1)) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf2 = empty_strided_cuda((8, 2048, 32000), (65536000, 32000, 1), torch.float32) + # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax] + stream4 = get_raw_stream(4) + triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0.run(arg0_1, buf2, 16384, 32000, stream=stream4) + del arg0_1 + return (buf2, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((8, 2048, 32000), (65536000, 32000, 1), device='cuda:4', dtype=torch.bfloat16) + fn = lambda: call([arg0_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/4g/c4gr37y26wd4va4drshauwjr3p5l32j5cssih4o5yz3h2g6jkxrz.py b/SpecForge-ext/cache/compiled_kernels/4g/c4gr37y26wd4va4drshauwjr3p5l32j5cssih4o5yz3h2g6jkxrz.py new file mode 100644 index 0000000000000000000000000000000000000000..2639b4f7cec504f747c3bf7249e91c88fe1786c6 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/4g/c4gr37y26wd4va4drshauwjr3p5l32j5cssih4o5yz3h2g6jkxrz.py @@ -0,0 +1,89 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1024, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 16) + x2 = xindex // ks2 + _tmp36 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = (r0_index % 128) + r0_4 = r0_index // 128 + tmp0 = r0_3 + 128*x0 + tmp1 = ks1 + tmp2 = tmp0 < tmp1 + tmp3 = r0_4 + 128*x1 + tmp4 = r0_3 + 128*x0 + tmp5 = tmp3 >= tmp4 + tmp6 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp7 = tmp4 < tmp6 + tmp8 = tmp3 < tmp6 + tmp9 = tmp7 & tmp8 + tmp10 = tmp5 & tmp9 + tmp11 = tl.full([1, 1], False, tl.int1) + tmp12 = tmp11 | tmp10 + tmp13 = tl.full([1, 1], 2048, tl.int64) + tmp14 = tmp4 >= tmp13 + tmp15 = ((r0_3 + 128*x0) % 2048) + tmp16 = tmp15 < tmp6 + tmp17 = tmp14 & tmp16 + tmp18 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp19 = (tmp18 % tmp13) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp13) != 0) if (tmp13).dtype is tl.float32 else tmp13 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp13 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tl.full([1, 1], 0, tl.int64) + tmp29 = tmp27 == tmp28 + tmp30 = tmp17 & tmp29 + tmp31 = tmp12 | tmp30 + tmp32 = tl.full(tmp31.shape, False, tmp31.dtype) + tmp33 = tl.where(tmp2, tmp31, tmp32) + tmp34 = tmp33.to(tl.int64) + tmp35 = tl.broadcast_to(tmp34, [XBLOCK, R0_BLOCK]) + tmp37 = _tmp36 + tmp35 + _tmp36 = tl.where(r0_mask & xmask, tmp37, _tmp36) + tmp36 = tl.sum(_tmp36, 1)[:, None] + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp36 > tmp38 + tmp40 = tl.full([1, 1], 16384, tl.int64) + tmp41 = tmp36 < tmp40 + tmp42 = tmp39 & tmp41 + tmp43 = tmp42.to(tl.int8) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp36 == tmp40 + tmp46 = tmp45.to(tl.int8) + tmp47 = tmp46.to(tl.int32) + tl.store(out_ptr1 + (x5), tmp44, xmask) + tl.store(out_ptr2 + (x5), tmp47, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/4r/c4rogici325xsxgkeljczx3sx57vcsimyzupeu4nvgmivqoqosiz.py b/SpecForge-ext/cache/compiled_kernels/4r/c4rogici325xsxgkeljczx3sx57vcsimyzupeu4nvgmivqoqosiz.py new file mode 100644 index 0000000000000000000000000000000000000000..3d211f32e598f494ac9dbfdc1a8425b4a145e336 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/4r/c4rogici325xsxgkeljczx3sx57vcsimyzupeu4nvgmivqoqosiz.py @@ -0,0 +1,1065 @@ +# AOT ID: ['9_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/jc/cjcezd4fm2g2fppy44lhtzc36sz7bi63sscwdmenwlvu3y4xt7np.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:2" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 262144, 128, 1]cuda:2" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[8, 32, 2048][65536, 2048, 1]cuda:2" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:2, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_3, %primals_5, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 524288, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4194304, 'r0_': 268435456}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 524288 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bh/cbhqle56n7we4b4miasvgh4jqrjbkehmv3legvjui32dka2bilvr.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:2" = PlaceHolder[target=primals_1] +# %primals_3 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:2" = PlaceHolder[target=primals_3] +# %primals_5 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:2" = PlaceHolder[target=primals_5] +# %getitem_1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:2" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:2" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 262144, 128, 1]cuda:2" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:2" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:2" = PlaceHolder[target=getitem_5] +# %primals_9 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:2" = PlaceHolder[target=primals_9] +# %primals_7 : Tensor "i32[8, 1, 16, s72][16*s72, 16*s72, s72, 1]cuda:2" = PlaceHolder[target=primals_7] +# %primals_15 : Tensor "i32[8, 1, s56][s56, s56, 1]cuda:2" = PlaceHolder[target=primals_15] +# %primals_17 : Tensor "i32[8, 1, s84, 16][16*s84, 16*s84, 16, 1]cuda:2" = PlaceHolder[target=primals_17] +# %primals_11 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:2" = PlaceHolder[target=primals_11] +# %primals_13 : Tensor "i32[8, 1, 16, s4][16*s4, 16*s4, s4, 1]cuda:2" = PlaceHolder[target=primals_13] +# %primals_19 : Tensor "i32[8, 1, s99][s99, s99, 1]cuda:2" = PlaceHolder[target=primals_19] +# %primals_21 : Tensor "i32[8, 1, s6, 16][16*s6, 16*s6, 16, 1]cuda:2" = PlaceHolder[target=primals_21] +# %primals_10 : Tensor "i64[8][1]cuda:2" = PlaceHolder[target=primals_10] +# %full_default : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:2, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_3, %primals_5, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_8, primals_6, primals_12, primals_14, primals_16, primals_18, primals_20, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, getitem, getitem_1, tangents_1 = args + args.clear() + s0 = primals_8 + s72 = primals_6 + s4 = primals_12 + s56 = primals_14 + s84 = primals_16 + s99 = primals_18 + s6 = primals_20 + assert_size_stride(primals_1, (8, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_3, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_5, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_7, (8, 1, 16, s72), (16*s72, 16*s72, s72, 1)) + assert_size_stride(primals_9, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (8, ), (1, )) + assert_size_stride(primals_11, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_13, (8, 1, 16, s4), (16*s4, 16*s4, s4, 1)) + assert_size_stride(primals_15, (8, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_17, (8, 1, s84, 16), (16*s84, 16*s84, 16, 1)) + assert_size_stride(primals_19, (8, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_21, (8, 1, s6, 16), (16*s6, 16*s6, 16, 1)) + assert_size_stride(getitem, (8, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(getitem_1, (8, 32, 2048), (65536, 2048, 1)) + assert_size_stride(tangents_1, (8, 32, 2048, 128), (8388608, 262144, 128, 1)) + with torch.cuda._DeviceGuard(2): + torch.cuda.set_device(2) + buf1 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream2 = get_raw_stream(2) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, 524288, 128, stream=stream2) + del getitem + buf3 = empty_strided_cuda((8, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((8, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((8, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream2 = get_raw_stream(2) + triton_tem_fused_zeros_1.run(primals_1, primals_3, primals_5, getitem_1, buf1, tangents_1, buf3, buf4, primals_9, primals_7, primals_15, primals_17, primals_11, primals_13, primals_19, primals_21, primals_10, buf5, s0, s72, s56, s84, 64 + ((127 + s0) // 128), 8, 8, stream=stream2) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_13 + del primals_15 + del primals_17 + del primals_19 + del primals_21 + del primals_3 + del primals_5 + del primals_7 + del primals_9 + del tangents_1 + return (buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_8 = 4096 + primals_6 = 32 + primals_12 = 32 + primals_14 = 32 + primals_16 = 32 + primals_18 = 32 + primals_20 = 32 + primals_1 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + primals_3 = rand_strided((8, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:2', dtype=torch.bfloat16) + primals_5 = rand_strided((8, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:2', dtype=torch.bfloat16) + primals_7 = rand_strided((8, 1, 16, 32), (512, 512, 32, 1), device='cuda:2', dtype=torch.int32) + primals_9 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:2', dtype=torch.int32) + primals_10 = rand_strided((8, ), (1, ), device='cuda:2', dtype=torch.int64) + primals_11 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:2', dtype=torch.int32) + primals_13 = rand_strided((8, 1, 16, 32), (512, 512, 32, 1), device='cuda:2', dtype=torch.int32) + primals_15 = rand_strided((8, 1, 32), (32, 32, 1), device='cuda:2', dtype=torch.int32) + primals_17 = rand_strided((8, 1, 32, 16), (512, 512, 16, 1), device='cuda:2', dtype=torch.int32) + primals_19 = rand_strided((8, 1, 32), (32, 32, 1), device='cuda:2', dtype=torch.int32) + primals_21 = rand_strided((8, 1, 32, 16), (512, 512, 16, 1), device='cuda:2', dtype=torch.int32) + getitem = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + getitem_1 = rand_strided((8, 32, 2048), (65536, 2048, 1), device='cuda:2', dtype=torch.float32) + tangents_1 = rand_strided((8, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:2', dtype=torch.bfloat16) + fn = lambda: call([primals_8, primals_6, primals_12, primals_14, primals_16, primals_18, primals_20, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/4w/669c5a8c8205272d44ea075b78e46cd1bf13f1ebe3d56d5ab422037277c923dc.best_config b/SpecForge-ext/cache/compiled_kernels/4w/669c5a8c8205272d44ea075b78e46cd1bf13f1ebe3d56d5ab422037277c923dc.best_config new file mode 100644 index 0000000000000000000000000000000000000000..42957eb64d71ed9fdefb40d3695d06d9e00d3409 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/4w/669c5a8c8205272d44ea075b78e46cd1bf13f1ebe3d56d5ab422037277c923dc.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "R0_BLOCK": 2048, "num_warps": 16, "num_stages": 1, "configs_hash": "8c03dc2e05d158372838fe4d32248dfba74b467c7576f6e1d3eb472c41b37c80", "found_by_coordesc": false, "time_taken_ms": 213, "triton_cache_hash": "VBVRCEQLKQI4X4GYXD4JC6UEYZT2F7LIKNA2UR4GNVIWAPM6GKFA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/4w/c4wdhwlu6yb3wcwazdnzmgzewiemvznxvrr3525eojupqjldo5pt.py b/SpecForge-ext/cache/compiled_kernels/4w/c4wdhwlu6yb3wcwazdnzmgzewiemvznxvrr3525eojupqjldo5pt.py new file mode 100644 index 0000000000000000000000000000000000000000..d454a9ba1e82a87c0b85cdd2c7c6344801f59d07 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/4w/c4wdhwlu6yb3wcwazdnzmgzewiemvznxvrr3525eojupqjldo5pt.py @@ -0,0 +1,47 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 32768}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/4w/c4ww5pmlr6amerprh7v3ibioh3yvbhemdqsh7gcrlxjnhnpkktrb.py b/SpecForge-ext/cache/compiled_kernels/4w/c4ww5pmlr6amerprh7v3ibioh3yvbhemdqsh7gcrlxjnhnpkktrb.py new file mode 100644 index 0000000000000000000000000000000000000000..3b6dadf8f157a07f11df22e310cbfb25de17a9cf --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/4w/c4ww5pmlr6amerprh7v3ibioh3yvbhemdqsh7gcrlxjnhnpkktrb.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/SpecForge-ext/cache/compiled_kernels/7t/466129ab41abc9f5794b92b332ac4be3dff826e8f59dc7fb522710de7206acdd.best_config b/SpecForge-ext/cache/compiled_kernels/7t/466129ab41abc9f5794b92b332ac4be3dff826e8f59dc7fb522710de7206acdd.best_config new file mode 100644 index 0000000000000000000000000000000000000000..39aa06f1122c6eb2904338d2578102fd0e126a89 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/7t/466129ab41abc9f5794b92b332ac4be3dff826e8f59dc7fb522710de7206acdd.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "b6ac5ef64fddcad8fc8d2c05fa12424871fd9baa5a4158ff38ecebbafb55a4b1", "found_by_coordesc": false, "time_taken_ms": 26, "triton_cache_hash": "G2LU7LIHIOEHQSWVLFBJATACJ76YHM672CUBUDGJGAJUEQVWVOFQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/7t/c7t3uvardqlt6x3sz37tlydghb4rt6mdilzlc7ffz3pehdn5jwdj.py b/SpecForge-ext/cache/compiled_kernels/7t/c7t3uvardqlt6x3sz37tlydghb4rt6mdilzlc7ffz3pehdn5jwdj.py new file mode 100644 index 0000000000000000000000000000000000000000..ecbb72cfd9ec219adfffc6e9798532f20b44781f --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/7t/c7t3uvardqlt6x3sz37tlydghb4rt6mdilzlc7ffz3pehdn5jwdj.py @@ -0,0 +1,49 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 256, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = xindex // ks0 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + x0 + 16*x1 + ks0*r0_2 + 16*ks0*x1), xmask, eviction_policy='evict_last', other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x0 + 16*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp13, xmask) + tl.store(out_ptr3 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp14, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ah/855c4fb51632a42fcf957963b85ead1d6653657da855baf9d7c221cfd3981ad0.best_config b/SpecForge-ext/cache/compiled_kernels/ah/855c4fb51632a42fcf957963b85ead1d6653657da855baf9d7c221cfd3981ad0.best_config new file mode 100644 index 0000000000000000000000000000000000000000..a570e8d663ff6e600f50df05a811c859065ec3c4 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ah/855c4fb51632a42fcf957963b85ead1d6653657da855baf9d7c221cfd3981ad0.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 21, "triton_cache_hash": "Z2RWAHMO7VUWQKIIRA5A46JYV2SEXHWLKREQM7TOP6VGUWDXAYAQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/ah/cah767udo2rzeazh6rycnirtnr5sijiv7nem2l67isu5iyh5pzyj.py b/SpecForge-ext/cache/compiled_kernels/ah/cah767udo2rzeazh6rycnirtnr5sijiv7nem2l67isu5iyh5pzyj.py new file mode 100644 index 0000000000000000000000000000000000000000..cc4f091ccd893d159260ae8bbc4a02b2425e213f --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ah/cah767udo2rzeazh6rycnirtnr5sijiv7nem2l67isu5iyh5pzyj.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4194304}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ak/cak5ufwwwsut5tju7yvwho5uqnabsn2za7nzkoy573tny5kqhtl5.py b/SpecForge-ext/cache/compiled_kernels/ak/cak5ufwwwsut5tju7yvwho5uqnabsn2za7nzkoy573tny5kqhtl5.py new file mode 100644 index 0000000000000000000000000000000000000000..3b8f765cfa4d5e631f1b2ad59febab2c37894751 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ak/cak5ufwwwsut5tju7yvwho5uqnabsn2za7nzkoy573tny5kqhtl5.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/ak/cakglntm3ejviis7qbld6stbcfdrpvbryqpb63fshmmyy46mxbh3.py b/SpecForge-ext/cache/compiled_kernels/ak/cakglntm3ejviis7qbld6stbcfdrpvbryqpb63fshmmyy46mxbh3.py new file mode 100644 index 0000000000000000000000000000000000000000..d98e9df72c1d6da642f7c6b5c2245eb549bd828b --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ak/cakglntm3ejviis7qbld6stbcfdrpvbryqpb63fshmmyy46mxbh3.py @@ -0,0 +1,675 @@ +# AOT ID: ['6_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +from torch._C import _cuda_getCurrentRawStream as get_raw_stream +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/7a/c7avdnhdkg25qkzpvb4jgb3wfrta3u7po7rrnynujrskgetlvslk.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:5" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:5" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:5" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:5" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:5" = PlaceHolder[target=buf1] +# %primals_5 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:5" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:5" = PlaceHolder[target=primals_4] +# %primals_7 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:5" = PlaceHolder[target=primals_7] +# %primals_8 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:5" = PlaceHolder[target=primals_8] +# %primals_6 : Tensor "i64[8][1]cuda:5" = PlaceHolder[target=primals_6] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12 = args + args.clear() + assert_size_stride(primals_1, (8, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_2, (8, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_3, (8, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_4, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_5, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_6, (8, ), (1, )) + assert_size_stride(primals_7, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_8, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_9, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_11, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_12, (8, 1, 16, 16), (256, 256, 16, 1)) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + buf0 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32) + buf1 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32) + buf2 = empty_strided_cuda((8, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream5 = get_raw_stream(5) + triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_5, primals_4, primals_7, primals_8, primals_6, buf2, 16, 8, 32, stream=stream5) + del buf1 + return (buf2, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, buf2, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16) + primals_2 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:5', dtype=torch.bfloat16) + primals_3 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:5', dtype=torch.bfloat16) + primals_4 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:5', dtype=torch.int32) + primals_5 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:5', dtype=torch.int32) + primals_6 = rand_strided((8, ), (1, ), device='cuda:5', dtype=torch.int64) + primals_7 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:5', dtype=torch.int32) + primals_8 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:5', dtype=torch.int32) + primals_9 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:5', dtype=torch.int32) + primals_10 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:5', dtype=torch.int32) + primals_11 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:5', dtype=torch.int32) + primals_12 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:5', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/as/9f962df2938e79169dbf28adc9c67d12118719f3425569a98b11309d3108a638.best_config b/SpecForge-ext/cache/compiled_kernels/as/9f962df2938e79169dbf28adc9c67d12118719f3425569a98b11309d3108a638.best_config new file mode 100644 index 0000000000000000000000000000000000000000..128251849e0d90499e31f76727557122755609e2 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/as/9f962df2938e79169dbf28adc9c67d12118719f3425569a98b11309d3108a638.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 65, "triton_cache_hash": "UQSFYICF6CFQWZOBHCGZ7JZ457GHWVO6RMPN5ABNWOATFMKI6GQA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/as/casevqrknafvhxbpwjozemzmdw3n2vgrctm4s4zdjzqp52cqs6kd.py b/SpecForge-ext/cache/compiled_kernels/as/casevqrknafvhxbpwjozemzmdw3n2vgrctm4s4zdjzqp52cqs6kd.py new file mode 100644 index 0000000000000000000000000000000000000000..cd58fc8d72783ccf8cdd6b52b156eabb4946db87 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/as/casevqrknafvhxbpwjozemzmdw3n2vgrctm4s4zdjzqp52cqs6kd.py @@ -0,0 +1,693 @@ +# AOT ID: ['9_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +from torch._C import _cuda_getCurrentRawStream as get_raw_stream +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/op/cop75xk6fpjjvnvvcusccw4eu3b3i2silh5jxkjylbibzjctamxl.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:2" = PlaceHolder[target=primals_1] +# %primals_3 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:2" = PlaceHolder[target=primals_3] +# %primals_5 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:2" = PlaceHolder[target=primals_5] +# %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:2" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:2" = PlaceHolder[target=buf1] +# %primals_9 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:2" = PlaceHolder[target=primals_9] +# %primals_7 : Tensor "i32[2, 1, 16, s72][16*s72, 16*s72, s72, 1]cuda:2" = PlaceHolder[target=primals_7] +# %primals_11 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:2" = PlaceHolder[target=primals_11] +# %primals_13 : Tensor "i32[2, 1, 16, s4][16*s4, 16*s4, s4, 1]cuda:2" = PlaceHolder[target=primals_13] +# %primals_10 : Tensor "i64[2][1]cuda:2" = PlaceHolder[target=primals_10] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_3, %primals_5, %sdpa_score0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21 = args + args.clear() + s0 = primals_2 + s43 = primals_4 + s72 = primals_6 + s71 = primals_8 + s4 = primals_12 + s56 = primals_14 + s84 = primals_16 + s99 = primals_18 + s6 = primals_20 + assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_3, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_5, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_7, (2, 1, 16, s72), (16*s72, 16*s72, s72, 1)) + assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (2, ), (1, )) + assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_13, (2, 1, 16, s4), (16*s4, 16*s4, s4, 1)) + assert_size_stride(primals_15, (2, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_17, (2, 1, s84, 16), (16*s84, 16*s84, 16, 1)) + assert_size_stride(primals_19, (2, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_21, (2, 1, s6, 16), (16*s6, 16*s6, 16, 1)) + with torch.cuda._DeviceGuard(2): + torch.cuda.set_device(2) + buf0 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + buf2 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream2 = get_raw_stream(2) + triton_tem_fused_0.run(primals_1, primals_3, primals_5, buf0, buf1, primals_9, primals_7, primals_11, primals_13, primals_10, buf2, s0, s72, 16, 2, 32, stream=stream2) + del buf1 + return (buf2, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, buf2, buf0, s0, s72, s4, s56, s84, s99, s6, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16) + primals_2 = 4096 + primals_3 = rand_strided((2, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:2', dtype=torch.bfloat16) + primals_4 = 4096 + primals_5 = rand_strided((2, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:2', dtype=torch.bfloat16) + primals_6 = 32 + primals_7 = rand_strided((2, 1, 16, 32), (512, 512, 32, 1), device='cuda:2', dtype=torch.int32) + primals_8 = 4096 + primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:2', dtype=torch.int32) + primals_10 = rand_strided((2, ), (1, ), device='cuda:2', dtype=torch.int64) + primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:2', dtype=torch.int32) + primals_12 = 32 + primals_13 = rand_strided((2, 1, 16, 32), (512, 512, 32, 1), device='cuda:2', dtype=torch.int32) + primals_14 = 32 + primals_15 = rand_strided((2, 1, 32), (32, 32, 1), device='cuda:2', dtype=torch.int32) + primals_16 = 32 + primals_17 = rand_strided((2, 1, 32, 16), (512, 512, 16, 1), device='cuda:2', dtype=torch.int32) + primals_18 = 32 + primals_19 = rand_strided((2, 1, 32), (32, 32, 1), device='cuda:2', dtype=torch.int32) + primals_20 = 32 + primals_21 = rand_strided((2, 1, 32, 16), (512, 512, 16, 1), device='cuda:2', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/as/casmcbz6icqn6mp2r7jahugidys5xwty64z2p3tfw4s7vlsj2oz2.py b/SpecForge-ext/cache/compiled_kernels/as/casmcbz6icqn6mp2r7jahugidys5xwty64z2p3tfw4s7vlsj2oz2.py new file mode 100644 index 0000000000000000000000000000000000000000..0826ca372abfcf8776ae9691a40102097e45116a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/as/casmcbz6icqn6mp2r7jahugidys5xwty64z2p3tfw4s7vlsj2oz2.py @@ -0,0 +1,66 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 67108864}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/c2/363ecfeae02cf0bc03b4070f8b6a6ac6bcf543c1a19d1c4a53122d5722a2b3dd.best_config b/SpecForge-ext/cache/compiled_kernels/c2/363ecfeae02cf0bc03b4070f8b6a6ac6bcf543c1a19d1c4a53122d5722a2b3dd.best_config new file mode 100644 index 0000000000000000000000000000000000000000..2a790c7f2089f09ca3d8ed3e19bda0fc38542e85 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/c2/363ecfeae02cf0bc03b4070f8b6a6ac6bcf543c1a19d1c4a53122d5722a2b3dd.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "b6ac5ef64fddcad8fc8d2c05fa12424871fd9baa5a4158ff38ecebbafb55a4b1", "found_by_coordesc": false, "time_taken_ms": 40, "triton_cache_hash": "MMGM2ESHRXPRFAROBBDYKTZUJ2JVVKU2TB5DVA3EL4OF2SOELPMQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/c2/cc2qlkbbemfommyywsdbow3sqg7jqf5x5tfkbqjzo2qy6lt36yjr.py b/SpecForge-ext/cache/compiled_kernels/c2/cc2qlkbbemfommyywsdbow3sqg7jqf5x5tfkbqjzo2qy6lt36yjr.py new file mode 100644 index 0000000000000000000000000000000000000000..aa032c96d739ac58bd65c4274b00419156aa2bd0 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/c2/cc2qlkbbemfommyywsdbow3sqg7jqf5x5tfkbqjzo2qy6lt36yjr.py @@ -0,0 +1,86 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 128 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (r0_1 + 16*x0), xmask, other=0.0) + tmp1 = tl.full([1, 1], 0, tl.int64) + tmp2 = tmp0 > tmp1 + tmp3 = tl.full([1, 1], 16384, tl.int64) + tmp4 = tmp0 < tmp3 + tmp5 = tmp2 & tmp4 + tmp6 = tmp5.to(tl.int8) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8.to(tl.int16) + tmp10 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp11 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12, tmp13, = triton_helpers.sort_with_index(tmp10, tmp11, None, 1, stable=True, descending=True) + tmp14 = tmp0 == tmp3 + tmp15 = tmp14.to(tl.int8) + tmp16 = tmp15.to(tl.int32) + tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK]) + tmp18, tmp19, = triton_helpers.sort_with_index(tmp17, tmp11, None, 1, stable=True, descending=True) + tmp20 = tmp7.to(tl.int64) + tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK]) + tmp23 = tl.where(xmask, tmp21, 0) + tmp24 = tl.sum(tmp23, 1)[:, None].to(tl.int64) + tmp25 = tmp16.to(tl.int64) + tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK]) + tmp28 = tl.where(xmask, tmp26, 0) + tmp29 = tl.sum(tmp28, 1)[:, None].to(tl.int64) + tmp30 = tmp24.to(tl.int32) + tmp31 = tmp29.to(tl.int32) + tmp32 = tmp13.to(tl.int64) + tmp33 = tmp32.to(tl.int32) + tmp34 = tmp8 < tmp30 + tmp35 = tl.full([1, 1], 16, tl.int32) + tmp36 = tl.where(tmp34, tmp33, tmp35) + tmp37 = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32) + tmp38 = tmp36 + tmp37 + tmp39 = tmp36 < 0 + tmp40 = tl.where(tmp39, tmp38, tmp36) + tl.device_assert(((0 <= tmp40) & (tmp40 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp40 < 17") + tmp42 = tl.full([1, 1], 1, tl.int32) + tmp43 = tmp19.to(tl.int64) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp8 < tmp31 + tmp46 = tl.where(tmp45, tmp44, tmp35) + tmp47 = tmp46 + tmp37 + tmp48 = tmp46 < 0 + tmp49 = tl.where(tmp48, tmp47, tmp46) + tl.device_assert(((0 <= tmp49) & (tmp49 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 17") + tl.store(out_ptr4 + (x0), tmp30, xmask) + tl.store(out_ptr5 + (x0), tmp31, xmask) + tl.store(out_ptr6 + (r0_1 + 16*x0), tmp33, xmask) + tl.store(out_ptr7 + (tl.broadcast_to(tmp40 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) + tl.store(out_ptr8 + (r0_1 + 16*x0), tmp44, xmask) + tl.store(out_ptr9 + (tl.broadcast_to(tmp49 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/dm/cdma2uevipbm2dd462ztkubtq5uanau5l3oglcw7lhpt4uovlqya.py b/SpecForge-ext/cache/compiled_kernels/dm/cdma2uevipbm2dd462ztkubtq5uanau5l3oglcw7lhpt4uovlqya.py new file mode 100644 index 0000000000000000000000000000000000000000..244ec426501f8c6b551923c0c101d94ff0e721a5 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/dm/cdma2uevipbm2dd462ztkubtq5uanau5l3oglcw7lhpt4uovlqya.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/SpecForge-ext/cache/compiled_kernels/dm/cdmv6ytwvbipl4lagbifdkedszdjny3opgqlnricedg4hfpkxbdo.py b/SpecForge-ext/cache/compiled_kernels/dm/cdmv6ytwvbipl4lagbifdkedszdjny3opgqlnricedg4hfpkxbdo.py new file mode 100644 index 0000000000000000000000000000000000000000..23921b6a3d8dbcf8463f419614d4eba6120d90e9 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/dm/cdmv6ytwvbipl4lagbifdkedszdjny3opgqlnricedg4hfpkxbdo.py @@ -0,0 +1,47 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/du/cduexexwzoejgfo3kafnuhcdb2jpdj5mqnwnijlnqydzf2tfuyoh.py b/SpecForge-ext/cache/compiled_kernels/du/cduexexwzoejgfo3kafnuhcdb2jpdj5mqnwnijlnqydzf2tfuyoh.py new file mode 100644 index 0000000000000000000000000000000000000000..370b596b5114bdd048f099616efbea59bda3d0d3 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/du/cduexexwzoejgfo3kafnuhcdb2jpdj5mqnwnijlnqydzf2tfuyoh.py @@ -0,0 +1,682 @@ +# AOT ID: ['12_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xv/cxvkzhyhna2alntgjzwfekekacjtshos257zi4b5b75eycps5xaj.py +# Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] +# Source node to ATen node mapping: +# dense_mask_2 => full_default_1 +# Graph fragment: +# %full_default_1 : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, %floordiv_3, %add_201], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:5, pin_memory: False}) +# return %index_put +triton_poi_fused_new_zeros_0 = async_compile.triton('triton_poi_fused_new_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hk/chkcqajju5cxlzspkm5ffg5s3lyxlimjbubz6ljwtrx23yf6fnkc.py +# Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] +# Source node to ATen node mapping: +# and_2 => bitwise_and_1 +# and_3 => bitwise_and_2 +# and_4 => bitwise_and_3, view_8 +# b => iota +# batched_outputs_2 => view_9 +# causal_mask => ge_2, view +# dense_mask => convert_element_type_2 +# dense_mask_1 => convert_element_type_5 +# diagnol_mask => eq_24 +# full_blocks => eq_45 +# full_blocks_1 => convert_element_type_1 +# gt => gt +# index => index +# index_1 => index_1 +# index_2 => index_2 +# lt => lt, view_1 +# lt_1 => lt_1, view_2 +# lt_3 => lt_3 +# m => iota_2 +# mask_1 => constant_pad_nd +# mask_2 => view_10 +# mask_3 => permute +# mask_block_sum => sum_1 +# n => iota_3 +# padding_mask => bitwise_and, view_3, view_4 +# padding_mask_1 => lt_2, view_6 +# partial_blocks => bitwise_and_4 +# partial_blocks_1 => convert_element_type +# remainder => remainder +# remainder_1 => remainder_1 +# result_1 => bitwise_or, full_default +# result_2 => bitwise_or_1 +# sub => sub_24, view_7 +# suffix_mask => ge_3 +# Graph fragment: +# %arg2_1 : Tensor "i64[2][1]cuda:5" = PlaceHolder[target=arg2_1] +# %sum_1 : Tensor "i64[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 2*(((s12 + 127)//128))*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:5" = PlaceHolder[target=sum_1] +# %full_default : Tensor "b8[2, 1, 1][1, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:5, pin_memory: False}) +# %iota_2 : Tensor "i64[s12][1]cuda:5"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (%arg0_1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:5, requires_grad: False}) +# %view : Tensor "i64[s12, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [%arg0_1, 1]), kwargs = {}) +# %iota_3 : Tensor "i64[s37][1]cuda:5"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (%arg1_1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:5, requires_grad: False}) +# %ge_2 : Tensor "b8[s12, s37][Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {}) +# %iota : Tensor "i64[2][1]cuda:5"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:5, requires_grad: False}) +# %index : Tensor "i64[2][1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%iota]), kwargs = {}) +# %view_1 : Tensor "i64[2, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [2, 1]), kwargs = {}) +# %lt : Tensor "b8[2, s37][Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {}) +# %view_4 : Tensor "b8[2, 1, s37][Max(1, s37), s37, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [2, 1, %arg1_1]), kwargs = {}) +# %index_1 : Tensor "i64[2][1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%iota]), kwargs = {}) +# %view_2 : Tensor "i64[2, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [2, 1]), kwargs = {}) +# %lt_1 : Tensor "b8[2, s12][Max(1, s12), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {}) +# %view_3 : Tensor "b8[2, s12, 1][Max(1, s12), 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [2, %arg0_1, 1]), kwargs = {}) +# %bitwise_and : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {}) +# %bitwise_and_1 : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_2, %bitwise_and), kwargs = {}) +# %bitwise_or : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {}) +# %ge_3 : Tensor "b8[s37][1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, %arg3_1), kwargs = {}) +# %remainder : Tensor "i64[s37][1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, %arg3_1), kwargs = {}) +# %index_2 : Tensor "i64[2][1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%iota]), kwargs = {}) +# %view_6 : Tensor "i64[2, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [2, 1]), kwargs = {}) +# %lt_2 : Tensor "b8[2, s37][Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {}) +# %bitwise_and_2 : Tensor "b8[2, s37][Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_3, %lt_2), kwargs = {}) +# %view_8 : Tensor "b8[2, 1, s37][Max(1, s37), s37, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [2, 1, %arg1_1]), kwargs = {}) +# %view_7 : Tensor "i64[s12, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [%arg0_1, 1]), kwargs = {}) +# %sub_24 : Tensor "i64[s12, s37][Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {}) +# %remainder_1 : Tensor "i64[s12, s37][Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub_24, %arg3_1), kwargs = {}) +# %eq_24 : Tensor "b8[s12, s37][Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {}) +# %bitwise_and_3 : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq_24), kwargs = {}) +# %bitwise_or_1 : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {}) +# %view_9 : Tensor "b8[2, 1, s12, s37][Max(1, s12)*Max(1, s37), s12*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [2, 1, %arg0_1, %arg1_1]), kwargs = {}) +# %constant_pad_nd : Tensor "b8[2, 1, 128*(((s12 + 127)//128)), 128*(((s37 + 127)//128))][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.constant_pad_nd.default](args = (%expand, [0, %sub_42, 0, %sub_44], 0.0), kwargs = {}) +# %view_10 : Tensor "b8[2, 1, ((s12 + 127)//128), 128, ((s37 + 127)//128), 128][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 128, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%constant_pad_nd, [2, 1, %floordiv_3, 128, %floordiv_2, 128]), kwargs = {}) +# %permute : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128), 128, 128][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), 128, Max(1, 128*(((s37 + 127)//128))), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {}) +# %sum_1 : Tensor "i64[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {}) +# %gt : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {}) +# %lt_3 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %bitwise_and_4 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {}) +# %convert_element_type : Tensor "i8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {}) +# %convert_element_type_2 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {}) +# %eq_45 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %convert_element_type_1 : Tensor "i8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_45, torch.int8), kwargs = {}) +# %convert_element_type_5 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {}) +# return %sum_1,%convert_element_type_2,%convert_element_type_5 +triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1 = async_compile.triton('triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 512, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'ks5': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, ks3, ks4, ks5, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // ks0) % ks1) + x0 = (xindex % ks0) + x2 = xindex // ks4 + _tmp46 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = ks2 + tmp2 = tmp0 < tmp1 + tmp3 = r0_3 + 128*x0 + tmp4 = ks3 + tmp5 = tmp3 < tmp4 + tmp6 = tmp2 & tmp5 + tmp7 = r0_4 + 128*x1 + tmp8 = r0_3 + 128*x0 + tmp9 = tmp7 >= tmp8 + tmp10 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp6 & xmask, eviction_policy='evict_last', other=0.0) + tmp11 = tmp8 < tmp10 + tmp12 = tmp7 < tmp10 + tmp13 = tmp11 & tmp12 + tmp14 = tmp9 & tmp13 + tmp15 = tl.full([1, 1], False, tl.int1) + tmp16 = tmp15 | tmp14 + tmp17 = tl.broadcast_to(ks5, [XBLOCK, R0_BLOCK]) + tmp18 = tmp8 >= tmp17 + tmp19 = (tmp8 % tmp17) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp17) != 0) if (tmp17).dtype is tl.float32 else tmp17 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp17 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tmp27 < tmp10 + tmp29 = tmp18 & tmp28 + tmp30 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp31 = (tmp30 % tmp17) + tmp32 = tmp31 != tmp20 + tmp33 = (libdevice.signbit(tmp31) != 0) if (tmp31).dtype is tl.float32 else tmp31 < 0 + tmp34 = tmp33 != tmp23 + tmp35 = tmp32 & tmp34 + tmp36 = tmp31 + tmp17 + tmp37 = tl.where(tmp35, tmp36, tmp31) + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp37 == tmp38 + tmp40 = tmp29 & tmp39 + tmp41 = tmp16 | tmp40 + tmp42 = tl.full(tmp41.shape, False, tmp41.dtype) + tmp43 = tl.where(tmp6, tmp41, tmp42) + tmp44 = tmp43.to(tl.int64) + tmp45 = tl.broadcast_to(tmp44, [XBLOCK, R0_BLOCK]) + tmp47 = _tmp46 + tmp45 + _tmp46 = tl.where(r0_mask & xmask, tmp47, _tmp46) + tmp46 = tl.sum(_tmp46, 1)[:, None] + tmp48 = tl.full([1, 1], 0, tl.int64) + tmp49 = tmp46 > tmp48 + tmp50 = tl.full([1, 1], 16384, tl.int64) + tmp51 = tmp46 < tmp50 + tmp52 = tmp49 & tmp51 + tmp53 = tmp52.to(tl.int8) + tmp54 = tmp53.to(tl.int32) + tmp55 = tmp46 == tmp50 + tmp56 = tmp55.to(tl.int8) + tmp57 = tmp56.to(tl.int32) + tl.store(out_ptr1 + (x5), tmp54, xmask) + tl.store(out_ptr2 + (x5), tmp57, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/34/c34af36gfqnn2ovywuaultc2pol4jyn6io3szgjeuv3uxfzcf3nv.py +# Topologically Sorted Source Nodes: [num_blocks_in_row, child_3], Original ATen: [aten.sum, aten._to_copy] +# Source node to ATen node mapping: +# child_3 => convert_element_type_3 +# num_blocks_in_row => sum_2 +# Graph fragment: +# %convert_element_type_2 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 2*(((s12 + 127)//128))*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:5" = PlaceHolder[target=convert_element_type_2] +# %sum_2 : Tensor "i64[2, 1, ((s12 + 127)//128)][((s12 + 127)//128), 2*(((s12 + 127)//128)), 1]cuda:5" = PlaceHolder[target=sum_2] +# %sum_2 : Tensor "i64[2, 1, ((s12 + 127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {}) +# %convert_element_type_3 : Tensor "i32[2, 1, ((s12 + 127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:5"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {}) +# return %sum_2,%convert_element_type_3 +triton_red_fused__to_copy_sum_2 = async_compile.triton('triton_red_fused__to_copy_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_sum_2(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + x2 = (xindex % ks1) + x3 = xindex // ks1 + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x2 + x3*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp5, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/7h/c7hhwzbu42q2ic55mujfpppabs5ube44ahuppbgjh35eanxqzare.py +# Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, aten.scalar_tensor, aten.where, aten.view, aten.index_put] +# Source node to ATen node mapping: +# arange_4 => iota_4 +# child_4 => convert_element_type_4 +# col_range => iota_5 +# dense_mask_2 => full_default_1 +# index_mask => lt_4 +# row_indices => unsqueeze +# setitem => full_default_2, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6 +# unsqueeze_1 => unsqueeze_1 +# valid_indices => scalar_tensor, where +# Graph fragment: +# %getitem_1 : Tensor "i64[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 2*Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5" = PlaceHolder[target=getitem_1] +# %convert_element_type_3 : Tensor "i32[2, 1, ((s12 + 127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:5" = PlaceHolder[target=convert_element_type_3] +# %convert_element_type_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5" = PlaceHolder[target=convert_element_type_4] +# %index_put : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), ((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), (((s37 + 127)//128)) + 1, 1]cuda:5" = PlaceHolder[target=index_put] +# %full_default_1 : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, %floordiv_3, %add_201], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:5, pin_memory: False}) +# %iota_7 : Tensor "i64[2][1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:5, requires_grad: False}) +# %unsqueeze_4 : Tensor "i64[2, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {}) +# %unsqueeze_5 : Tensor "i64[2, 1, 1][1, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {}) +# %unsqueeze_6 : Tensor "i64[2, 1, 1, 1][1, 1, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {}) +# %iota_6 : Tensor "i64[1][1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:5, requires_grad: False}) +# %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {}) +# %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {}) +# %iota_4 : Tensor "i32[((s12 + 127)//128)][1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (%floordiv_3,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:5, requires_grad: False}) +# %unsqueeze : Tensor "i32[((s12 + 127)//128), 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {}) +# %iota_5 : Tensor "i32[((s37 + 127)//128)][1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (%floordiv_2,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:5, requires_grad: False}) +# %unsqueeze_1 : Tensor "i32[2, 1, ((s12 + 127)//128), 1][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {}) +# %lt_4 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {}) +# %convert_element_type_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {}) +# %scalar_tensor : Tensor "i32[][]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%floordiv_2,), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:5}) +# %where : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %scalar_tensor), kwargs = {}) +# %full_default_2 : Tensor "i32[2, 1, 1, 1][1, 1, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:5, pin_memory: False}) +# %index_put : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_2), kwargs = {}) +# return %convert_element_type_4,%buf13 +triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3 = async_compile.triton('triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i32', 'out_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3', 'mutated_arg_names': ['out_ptr1'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3(in_ptr0, in_ptr1, out_ptr0, out_ptr1, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + x3 = xindex // ks0 + tmp0 = tl.load(in_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tl.load(in_ptr1 + (x3), xmask, eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tmp3 = x0 + tmp4 = tmp3 < tmp2 + tmp5 = ks0 + tmp6 = tl.where(tmp4, tmp1, tmp5) + tmp7 = 1 + ks0 + tmp8 = tmp6 + tmp7 + tmp9 = tmp6 < 0 + tmp10 = tl.where(tmp9, tmp8, tmp6) + tl.device_assert(((0 <= tmp10) & (tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128)))) | ~(xmask), "index out of bounds: 0 <= tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128))") + tmp12 = tl.full([1], 1, tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) + tl.store(out_ptr1 + (tmp10 + x3 + ks0*x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/is/cisbwn452kdvm56u75a2mwmrdzns6w4vxzuweva24qshuv4gksv2.py +# Topologically Sorted Source Nodes: [batched_outputs_3], Original ATen: [aten.slice, aten.clone] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_4 +# Graph fragment: +# %buf13 : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), ((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), (((s37 + 127)//128)) + 1, 1]cuda:5" = PlaceHolder[target=buf13] +# %slice_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, %floordiv_2), kwargs = {}) +# %clone_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_4,), kwargs = {memory_format: torch.contiguous_format}) +# return %clone_4 +triton_poi_fused_clone_slice_4 = async_compile.triton('triton_poi_fused_clone_slice_4', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr0': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_slice_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_clone_slice_4(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = xindex // ks0 + x2 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + x1 + ks0*x1), xmask, eviction_policy='evict_last') + tl.store(out_ptr0 + (x2), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hr/chrxoz3s6dcccbxa4bhegahvxtofkt5hvfz7hdrybtpjo4ffso64.py +# Topologically Sorted Source Nodes: [batched_outputs_3, transpose, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_4 +# num_blocks_in_row_2 => sum_4 +# q_num_blocks => convert_element_type_8 +# transpose => permute_1 +# Graph fragment: +# %clone_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 1, ((s37 + 127)//128), 1]cuda:5" = PlaceHolder[target=clone_4] +# %sum_4 : Tensor "i64[2, 1, ((s37 + 127)//128)][((s37 + 127)//128), 2*(((s37 + 127)//128)), 1]cuda:5" = PlaceHolder[target=sum_4] +# %slice_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, %floordiv_2), kwargs = {}) +# %clone_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_4,), kwargs = {memory_format: torch.contiguous_format}) +# %permute_1 : Tensor "i32[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {}) +# %sum_4 : Tensor "i64[2, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {}) +# %convert_element_type_8 : Tensor "i32[2, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {}) +# return %sum_4,%convert_element_type_8 +triton_red_fused__to_copy_clone_slice_sum_transpose_5 = async_compile.triton('triton_red_fused__to_copy_clone_slice_sum_transpose_5', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_clone_slice_sum_transpose_5', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_clone_slice_sum_transpose_5(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_2 + ks0*ks1*x1), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp5, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/g2/cg2ims2kqlrojffd3to6cjqkeah4zdvhgfoxnfrkinus4ddrvhhe.py +# Topologically Sorted Source Nodes: [q_indices], Original ATen: [aten._to_copy] +# Source node to ATen node mapping: +# q_indices => clone_6, convert_element_type_9 +# Graph fragment: +# %getitem_5 : Tensor "i64[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:5" = PlaceHolder[target=getitem_5] +# %convert_element_type_9 : Tensor "i32[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {}) +# %clone_6 : Tensor "i32[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format}) +# return %clone_6 +triton_poi_fused__to_copy_6 = async_compile.triton('triton_poi_fused__to_copy_6', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused__to_copy_6(in_ptr0, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + tmp0 = tl.load(in_ptr0 + (x1 + x0*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1 = args + args.clear() + s12 = arg0_1 + s37 = arg1_1 + s21 = arg3_1 + assert_size_stride(arg2_1, (2, ), (1, )) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + buf12 = empty_strided_cuda((2, 1, (127 + s12) // 128, 1 + ((127 + s37) // 128)), (((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), ((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 2*((127 + s12) // 128) + 2*((127 + s12) // 128)*((127 + s37) // 128) + stream5 = get_raw_stream(5) + triton_poi_fused_new_zeros_0.run(buf12, triton_poi_fused_new_zeros_0_xnumel, stream=stream5) + buf21 = empty_strided_cuda((2, 1, (127 + s12) // 128, 1 + ((127 + s37) // 128)), (((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), ((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 2*((127 + s12) // 128) + 2*((127 + s12) // 128)*((127 + s37) // 128) + stream5 = get_raw_stream(5) + triton_poi_fused_new_zeros_0.run(buf21, triton_poi_fused_new_zeros_0_xnumel, stream=stream5) + ps0 = (127 + s37) // 128 + ps1 = (127 + s12) // 128 + ps2 = ((127 + s12) // 128)*((127 + s37) // 128) + buf1 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 2*((127 + s12) // 128)*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + buf5 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 2*((127 + s12) // 128)*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream5 = get_raw_stream(5) + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1.run(arg2_1, buf1, buf5, ps0, ps1, s12, s37, ps2, s21, triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel, 16384, stream=stream5) + del arg2_1 + buf10 = empty_strided_cuda((2, 1, (127 + s12) // 128), (max(1, (127 + s12) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [num_blocks_in_row, child_3], Original ATen: [aten.sum, aten._to_copy] + triton_red_fused__to_copy_sum_2_xnumel = 2*((127 + s12) // 128) + triton_red_fused__to_copy_sum_2_r0_numel = (127 + s37) // 128 + stream5 = get_raw_stream(5) + triton_red_fused__to_copy_sum_2.run(buf1, buf10, ps0, ps1, triton_red_fused__to_copy_sum_2_xnumel, triton_red_fused__to_copy_sum_2_r0_numel, stream=stream5) + buf19 = empty_strided_cuda((2, 1, (127 + s12) // 128), (max(1, (127 + s12) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [num_blocks_in_row_1, child_7], Original ATen: [aten.sum, aten._to_copy] + triton_red_fused__to_copy_sum_2_xnumel = 2*((127 + s12) // 128) + triton_red_fused__to_copy_sum_2_r0_numel = (127 + s37) // 128 + stream5 = get_raw_stream(5) + triton_red_fused__to_copy_sum_2.run(buf5, buf19, ps0, ps1, triton_red_fused__to_copy_sum_2_xnumel, triton_red_fused__to_copy_sum_2_r0_numel, stream=stream5) + # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort] + buf2 = torch.ops.aten.sort.stable(buf1, stable=True, dim=3, descending=True) + del buf1 + buf4 = buf2[1] + assert_size_stride(buf4, (2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 2*max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf4, 16, 'torch.ops.aten.sort.stable') + del buf2 + buf11 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream5 = get_raw_stream(5) + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3.run(buf4, buf10, buf11, buf12, ps0, ps1, ps2, s37, triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel, stream=stream5) + del buf4 + buf14 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 1, (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3], Original ATen: [aten.slice, aten.clone] + triton_poi_fused_clone_slice_4_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream5 = get_raw_stream(5) + triton_poi_fused_clone_slice_4.run(buf12, buf14, ps0, triton_poi_fused_clone_slice_4_xnumel, stream=stream5) + del buf12 + buf32 = empty_strided_cuda((2, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy] + triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel = 2*((127 + s37) // 128) + triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel = (127 + s12) // 128 + stream5 = get_raw_stream(5) + triton_red_fused__to_copy_clone_slice_sum_transpose_5.run(buf14, buf32, ps0, ps1, triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel, triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel, stream=stream5) + # Topologically Sorted Source Nodes: [full_blocks, full_blocks_1, dense_mask_1, col_indices_1], Original ATen: [aten.eq, aten._to_copy, aten.sort] + buf6 = torch.ops.aten.sort.stable(buf5, stable=True, dim=3, descending=True) + del buf5 + buf8 = buf6[1] + assert_size_stride(buf8, (2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 2*max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf8, 16, 'torch.ops.aten.sort.stable') + del buf6 + buf20 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream5 = get_raw_stream(5) + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3.run(buf8, buf19, buf20, buf21, ps0, ps1, ps2, s37, triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel, stream=stream5) + del buf8 + buf23 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 1, (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5], Original ATen: [aten.slice, aten.clone] + triton_poi_fused_clone_slice_4_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream5 = get_raw_stream(5) + triton_poi_fused_clone_slice_4.run(buf21, buf23, ps0, triton_poi_fused_clone_slice_4_xnumel, stream=stream5) + del buf21 + buf29 = empty_strided_cuda((2, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy] + triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel = 2*((127 + s37) // 128) + triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel = (127 + s12) // 128 + stream5 = get_raw_stream(5) + triton_red_fused__to_copy_clone_slice_sum_transpose_5.run(buf23, buf29, ps0, ps1, triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel, triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel, stream=stream5) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort] + buf15 = torch.ops.aten.sort.stable(reinterpret_tensor(buf14, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 0, 1, (127 + s37) // 128), 0), stable=True, dim=3, descending=True) + del buf14 + buf17 = buf15[1] + assert_size_stride(buf17, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 1, max(1, (127 + s37) // 128)), 'torch.ops.aten.sort.stable') + assert_alignment(buf17, 16, 'torch.ops.aten.sort.stable') + del buf15 + buf30 = empty_strided_cuda((2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [q_indices], Original ATen: [aten._to_copy] + triton_poi_fused__to_copy_6_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream5 = get_raw_stream(5) + triton_poi_fused__to_copy_6.run(buf17, buf30, ps1, ps0, ps2, triton_poi_fused__to_copy_6_xnumel, stream=stream5) + del buf17 + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort] + buf24 = torch.ops.aten.sort.stable(reinterpret_tensor(buf23, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 0, 1, (127 + s37) // 128), 0), stable=True, dim=3, descending=True) + del buf23 + buf26 = buf24[1] + assert_size_stride(buf26, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 1, max(1, (127 + s37) // 128)), 'torch.ops.aten.sort.stable') + assert_alignment(buf26, 16, 'torch.ops.aten.sort.stable') + del buf24 + buf27 = empty_strided_cuda((2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [full_q_indices], Original ATen: [aten._to_copy] + triton_poi_fused__to_copy_6_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream5 = get_raw_stream(5) + triton_poi_fused__to_copy_6.run(buf26, buf27, ps1, ps0, ps2, triton_poi_fused__to_copy_6_xnumel, stream=stream5) + del buf26 + return (buf27, buf29, buf30, buf32, buf20, buf19, buf11, buf10, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 1569 + arg1_1 = 1569 + arg2_1 = rand_strided((2, ), (1, ), device='cuda:5', dtype=torch.int64) + arg3_1 = 1569 + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/dz/cdz3io7w5uyfrmfqvmg2kt2ay66qv4ckwtyurhik3frq7fqnk7gm.py b/SpecForge-ext/cache/compiled_kernels/dz/cdz3io7w5uyfrmfqvmg2kt2ay66qv4ckwtyurhik3frq7fqnk7gm.py new file mode 100644 index 0000000000000000000000000000000000000000..2a6f3f4b836a071f915b60defb40fa0e10d7c7c2 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/dz/cdz3io7w5uyfrmfqvmg2kt2ay66qv4ckwtyurhik3frq7fqnk7gm.py @@ -0,0 +1,66 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/dz/f7d5f2184a6f349e4531c61cf67ffbd51fe751bb6902c7e014986bad1a4a9b8f.best_config b/SpecForge-ext/cache/compiled_kernels/dz/f7d5f2184a6f349e4531c61cf67ffbd51fe751bb6902c7e014986bad1a4a9b8f.best_config new file mode 100644 index 0000000000000000000000000000000000000000..cbf4eb5ae8826a07243c88f3ee991df371ea45fb --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/dz/f7d5f2184a6f349e4531c61cf67ffbd51fe751bb6902c7e014986bad1a4a9b8f.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 53, "triton_cache_hash": "UQSFYICF6CFQWZOBHCGZ7JZ457GHWVO6RMPN5ABNWOATFMKI6GQA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/fa/cfac6ze2ka7xqvmyxx4ehmqqczd7mi63mu366jgrbaebsyxjcuna.py b/SpecForge-ext/cache/compiled_kernels/fa/cfac6ze2ka7xqvmyxx4ehmqqczd7mi63mu366jgrbaebsyxjcuna.py new file mode 100644 index 0000000000000000000000000000000000000000..e27d8fc6e2e5e7722cbca44e00a5101e60e4c03a --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/fa/cfac6ze2ka7xqvmyxx4ehmqqczd7mi63mu366jgrbaebsyxjcuna.py @@ -0,0 +1,307 @@ +# AOT ID: ['4_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/27/c274gnr6pjrqx44o2l7ymaeh7yrigwgf3ninh5xcv6vd5wswoduy.py +# Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] +# Source node to ATen node mapping: +# cat => cat +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# mul => mul_24 +# mul_1 => mul_45 +# neg => neg +# q_embed => add_54 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# x1 => slice_1 +# x2 => slice_2 +# Graph fragment: +# %primals_12 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:1" = PlaceHolder[target=primals_12] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:1" = PlaceHolder[target=primals_8] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:1" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:1" = PlaceHolder[target=primals_6] +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_24 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_12, %unsqueeze), kwargs = {}) +# %slice_1 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24, s24*s34, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, 0, %floordiv), kwargs = {}) +# %slice_2 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24, s24*s34, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %neg : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s34*Max(1, s24 - ((s24//2))), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_2,), kwargs = {}) +# %cat : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %slice_1], -1), kwargs = {}) +# %mul_45 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat, %unsqueeze_1), kwargs = {}) +# %add_54 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_24, %mul_45), kwargs = {}) +# return %add_54 +triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 67108864}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/qt/cqtv2hjbuijyx7awch534sanohmqs6reawit6ksar4ud36qn7xhy.py +# Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] +# Source node to ATen node mapping: +# cat_1 => cat_1 +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# k_embed => add_90 +# mul_2 => mul_54 +# mul_3 => mul_75 +# neg_1 => neg_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# x1_1 => slice_3 +# x2_1 => slice_4 +# Graph fragment: +# %primals_13 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:1" = PlaceHolder[target=primals_13] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:1" = PlaceHolder[target=primals_8] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:1" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:1" = PlaceHolder[target=primals_6] +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_54 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_13, %unsqueeze), kwargs = {}) +# %slice_3 : Tensor "bf16[s48, s48, s9, (s24//2)][s24*s48*s9, s24, s24*s48, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_13, 3, 0, %floordiv), kwargs = {}) +# %slice_4 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s24*s48*s9, s24, s24*s48, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_13, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %neg_1 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s48*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s48*Max(1, s24 - ((s24//2))), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_4,), kwargs = {}) +# %cat_1 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %slice_3], -1), kwargs = {}) +# %mul_75 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %unsqueeze_1), kwargs = {}) +# %add_90 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_54, %mul_75), kwargs = {}) +# return %add_90 +triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13 = args + args.clear() + s92 = primals_1 + s24 = primals_2 + s96 = primals_3 + s79 = primals_5 + s9 = primals_7 + s38 = primals_9 + s48 = primals_10 + s34 = primals_11 + assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_8, (1, s9), (s9, 1)) + assert_size_stride(primals_12, (s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1)) + assert_size_stride(primals_13, (s48, s48, s9, s24), (s24*s48*s9, s24, s24*s48, 1)) + with torch.cuda._DeviceGuard(1): + torch.cuda.set_device(1) + ps0 = s24*s34 + buf0 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel = s24*s34*s48*s9 + stream1 = get_raw_stream(1) + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0.run(primals_12, primals_8, primals_4, primals_6, buf0, ps0, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel, stream=stream1) + del primals_12 + ps1 = s24*s48 + buf1 = empty_strided_cuda((s48, s48, s9, s24), (s24*s48*s9, s24, s24*s48, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel = s24*s9*s48*s48 + stream1 = get_raw_stream(1) + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1.run(primals_13, primals_8, primals_4, primals_6, buf1, ps1, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel, stream=stream1) + del primals_13 + return (buf0, buf1, primals_4, primals_6, primals_8, s24, s9, s48, s34, s92, s96, s79, s24 // 2, s24 + (-1)*(s24 // 2), ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 2048 + primals_2 = 128 + primals_3 = 5245440 + primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:1', dtype=torch.bfloat16) + primals_5 = 2048 + primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:1', dtype=torch.bfloat16) + primals_7 = 2048 + primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:1', dtype=torch.int64) + primals_9 = 1 + primals_10 = 8 + primals_11 = 32 + primals_12 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16) + primals_13 = rand_strided((8, 8, 2048, 128), (2097152, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/fa/cfai6qfroimjkp32i57fqulbbxd7ap7nwbhmtwtra7dawieplflr.py b/SpecForge-ext/cache/compiled_kernels/fa/cfai6qfroimjkp32i57fqulbbxd7ap7nwbhmtwtra7dawieplflr.py new file mode 100644 index 0000000000000000000000000000000000000000..adc42d10f7a43f66aaae4b267ba8919d49591535 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/fa/cfai6qfroimjkp32i57fqulbbxd7ap7nwbhmtwtra7dawieplflr.py @@ -0,0 +1,168 @@ +# AOT ID: ['10_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py +# Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul] +# Source node to ATen node mapping: +# getitem_1 => unsqueeze +# position_mask => mul_6 +# target_mask => index +# target_mask_1 => convert_element_type +# target_max_token => argmax +# Graph fragment: +# %arg1_1 : Tensor "bf16[8, s14, 151936][151936*s14, 151936, 1]cuda:0" = PlaceHolder[target=arg1_1] +# %argmax : Tensor "i64[8, s14][s14, 1]cuda:0" = PlaceHolder[target=argmax] +# %arg2_1 : Tensor "b8[151936][1]cuda:0" = PlaceHolder[target=arg2_1] +# %arg3_1 : Tensor "i64[8, s14, 1][s14, 1, 1]cuda:0" = PlaceHolder[target=arg3_1] +# %argmax : Tensor "i64[8, s14][s14, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {}) +# %index : Tensor "b8[8, s14][s14, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%argmax]), kwargs = {}) +# %unsqueeze : Tensor "b8[8, s14, 1][s14, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 2), kwargs = {}) +# %convert_element_type : Tensor "i32[8, s14, 1][s14, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze, torch.int32), kwargs = {}) +# %mul_6 : Tensor "i64[8, s14, 1][s14, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %arg3_1), kwargs = {}) +# return %argmax,%mul_6 +triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0 = async_compile.triton('triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i64', 'r0_numel': 'i64', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0).to(tl.int64) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64) + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64) + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 9223372036854775807, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert(((0 <= tmp6) & (tmp6 < 151936)) | ~(xmask), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), xmask, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1 = args + args.clear() + s24 = arg0_1 + arg1_1_size = arg1_1.size() + s14 = arg1_1_size[1] + assert_size_stride(arg1_1, (8, s14, 151936), (151936*s14, 151936, 1)) + assert_size_stride(arg2_1, (151936, ), (1, )) + assert_size_stride(arg3_1, (8, s14, 1), (s14, 1, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((8, s14), (s14, 1), torch.int64) + buf1 = reinterpret_tensor(buf0, (8, s14, 1), (s14, 1, 1), 0); del buf0 # reuse + # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul] + triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_xnumel = 8*s14 + stream0 = get_raw_stream(0) + triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.run(buf1, arg1_1, arg2_1, arg3_1, triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_xnumel, 151936, stream=stream0) + del arg1_1 + del arg2_1 + del arg3_1 + return (buf1, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 2009 + arg1_1 = rand_strided((8, 2009, 151936), (305239424, 151936, 1), device='cuda:0', dtype=torch.bfloat16) + arg2_1 = rand_strided((151936, ), (1, ), device='cuda:0', dtype=torch.bool) + arg3_1 = rand_strided((8, 2009, 1), (2009, 1, 1), device='cuda:0', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/fa/cfail5nyr4vuktxoags33cssvkjxk2nbmzhswhjwxszpyc4qj4wf.py b/SpecForge-ext/cache/compiled_kernels/fa/cfail5nyr4vuktxoags33cssvkjxk2nbmzhswhjwxszpyc4qj4wf.py new file mode 100644 index 0000000000000000000000000000000000000000..3aab570fff1cbbf5c315351733f335c066b9236d --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/fa/cfail5nyr4vuktxoags33cssvkjxk2nbmzhswhjwxszpyc4qj4wf.py @@ -0,0 +1,675 @@ +# AOT ID: ['6_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +from torch._C import _cuda_getCurrentRawStream as get_raw_stream +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/gp/cgpqg54v7ag6awmgwhlrbbyw5jxsgjo6tuzvo3rt2xzqk6f33df2.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:4" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:4" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:4" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:4" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:4" = PlaceHolder[target=buf1] +# %primals_5 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:4" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:4" = PlaceHolder[target=primals_4] +# %primals_7 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:4" = PlaceHolder[target=primals_7] +# %primals_8 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:4" = PlaceHolder[target=primals_8] +# %primals_6 : Tensor "i64[2][1]cuda:4" = PlaceHolder[target=primals_6] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12 = args + args.clear() + assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_2, (2, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_3, (2, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_4, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_5, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_6, (2, ), (1, )) + assert_size_stride(primals_7, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_8, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (2, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_12, (2, 1, 16, 16), (256, 256, 16, 1)) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf0 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + buf2 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream4 = get_raw_stream(4) + triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_5, primals_4, primals_7, primals_8, primals_6, buf2, 16, 2, 32, stream=stream4) + del buf1 + return (buf2, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, buf2, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + primals_2 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:4', dtype=torch.bfloat16) + primals_3 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:4', dtype=torch.bfloat16) + primals_4 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:4', dtype=torch.int32) + primals_5 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:4', dtype=torch.int32) + primals_6 = rand_strided((2, ), (1, ), device='cuda:4', dtype=torch.int64) + primals_7 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:4', dtype=torch.int32) + primals_8 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:4', dtype=torch.int32) + primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:4', dtype=torch.int32) + primals_10 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:4', dtype=torch.int32) + primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:4', dtype=torch.int32) + primals_12 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:4', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/fa/cfawzdo3q32syzk5d3t3mjridjbalgrkptn5qwko7qnup25mzrum.py b/SpecForge-ext/cache/compiled_kernels/fa/cfawzdo3q32syzk5d3t3mjridjbalgrkptn5qwko7qnup25mzrum.py new file mode 100644 index 0000000000000000000000000000000000000000..e3b09956a87371a78018d11affd863581f5c7672 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/fa/cfawzdo3q32syzk5d3t3mjridjbalgrkptn5qwko7qnup25mzrum.py @@ -0,0 +1,57 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 16384, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i64', 'r0_numel': 'i64', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 16384 + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0).to(tl.int64) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64) + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64) + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 9223372036854775807, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert((0 <= tmp6) & (tmp6 < 151936), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), None, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, None) diff --git a/SpecForge-ext/cache/compiled_kernels/fi/cfiplsvt2q6tbvsfjtg2dd47g7npdwtvk5m3lv4anjbxwgjigkj2.py b/SpecForge-ext/cache/compiled_kernels/fi/cfiplsvt2q6tbvsfjtg2dd47g7npdwtvk5m3lv4anjbxwgjigkj2.py new file mode 100644 index 0000000000000000000000000000000000000000..494e1b1db3f1f7648c11c174ce788decec52d2d0 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/fi/cfiplsvt2q6tbvsfjtg2dd47g7npdwtvk5m3lv4anjbxwgjigkj2.py @@ -0,0 +1,72 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2048, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 32768, 'r0_': 0}} +) +@triton.jit +def triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2048 + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // 16) % 16) + x0 = (xindex % 16) + x2 = xindex // 256 + tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x6 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = r0_3 + 128*x0 + tmp2 = tmp0 >= tmp1 + tmp4 = tmp1 < tmp3 + tmp5 = tmp0 < tmp3 + tmp6 = tmp4 & tmp5 + tmp7 = tmp2 & tmp6 + tmp8 = tl.full([1, 1], False, tl.int1) + tmp9 = tmp8 | tmp7 + tmp10 = tl.full([1, 1], 2048, tl.int64) + tmp11 = tmp1 >= tmp10 + tmp12 = tmp11 & tmp4 + tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp14 = (tmp13 % tmp10) + tmp15 = tl.full([1, 1], 0, tl.int32) + tmp16 = tmp14 != tmp15 + tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0 + tmp19 = tmp17 != tmp18 + tmp20 = tmp16 & tmp19 + tmp21 = tmp14 + tmp10 + tmp22 = tl.where(tmp20, tmp21, tmp14) + tmp23 = tl.full([1, 1], 0, tl.int64) + tmp24 = tmp22 == tmp23 + tmp25 = tmp12 & tmp24 + tmp26 = tmp9 | tmp25 + tmp27 = tmp26.to(tl.int64) + tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK]) + tmp30 = _tmp29 + tmp28 + _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29) + tmp29 = tl.sum(_tmp29, 1)[:, None] + tl.store(out_ptr0 + (x6), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/h6/aa838d40f4d0e483f1277be61c094ff598dd757fa08fb0e455bf7c8a9b79036a.best_config b/SpecForge-ext/cache/compiled_kernels/h6/aa838d40f4d0e483f1277be61c094ff598dd757fa08fb0e455bf7c8a9b79036a.best_config new file mode 100644 index 0000000000000000000000000000000000000000..27b42e93657f0d84b947e2c1d4a7a83b2db38742 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/h6/aa838d40f4d0e483f1277be61c094ff598dd757fa08fb0e455bf7c8a9b79036a.best_config @@ -0,0 +1 @@ +{"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 13, "triton_cache_hash": "6FB7I6IASCIGI3DSKLBL4Q2CXFFWPYWXW7AMHNUUDLPGKUCB3PDA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/ic/cicti66tef7ykscmewrfizq5t5hma2a6k6njneyopvmhy4vmegql.py b/SpecForge-ext/cache/compiled_kernels/ic/cicti66tef7ykscmewrfizq5t5hma2a6k6njneyopvmhy4vmegql.py new file mode 100644 index 0000000000000000000000000000000000000000..2ffe0777032f41505a3b774145eb39a3a8ffa1f4 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ic/cicti66tef7ykscmewrfizq5t5hma2a6k6njneyopvmhy4vmegql.py @@ -0,0 +1,543 @@ +# AOT ID: ['5_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/zc/czc4uswzazabvj7ebt72gzrcg2fgrugi6d7lol4a4jino45fz2ua.py +# Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_2, mask_3, mask_block_sum], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.permute, aten.sum] +# Source node to ATen node mapping: +# and_2 => bitwise_and_1 +# and_3 => bitwise_and_2 +# and_4 => bitwise_and_3, view_8 +# b => iota +# batched_outputs_2 => view_9 +# causal_mask => ge, view +# diagnol_mask => eq +# index => index +# index_1 => index_1 +# index_2 => index_2 +# lt => lt, view_1 +# lt_1 => lt_1, view_2 +# m => iota_2 +# mask_2 => view_10 +# mask_3 => permute +# mask_block_sum => sum_1 +# n => iota_3 +# padding_mask => bitwise_and, view_3, view_4 +# padding_mask_1 => lt_2, view_6 +# remainder => remainder +# remainder_1 => remainder_1 +# result_1 => bitwise_or, full_default +# result_2 => bitwise_or_1 +# sub => sub, view_7 +# suffix_mask => ge_1 +# Graph fragment: +# %arg0_1 : Tensor "i64[2][1]cuda:2" = PlaceHolder[target=arg0_1] +# %full_default : Tensor "b8[2, 1, 1][1, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:2, pin_memory: False}) +# %iota_2 : Tensor "i64[2048][1]cuda:2"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:2, requires_grad: False}) +# %view : Tensor "i64[2048, 1][1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %iota_3 : Tensor "i64[2048][1]cuda:2"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:2, requires_grad: False}) +# %ge : Tensor "b8[2048, 2048][2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {}) +# %iota : Tensor "i64[2][1]cuda:2"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:2, requires_grad: False}) +# %index : Tensor "i64[2][1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {}) +# %view_1 : Tensor "i64[2, 1][1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [2, 1]), kwargs = {}) +# %lt : Tensor "b8[2, 2048][2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {}) +# %view_4 : Tensor "b8[2, 1, 2048][2048, 2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [2, 1, 2048]), kwargs = {}) +# %index_1 : Tensor "i64[2][1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {}) +# %view_2 : Tensor "i64[2, 1][1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [2, 1]), kwargs = {}) +# %lt_1 : Tensor "b8[2, 2048][2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {}) +# %view_3 : Tensor "b8[2, 2048, 1][2048, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [2, 2048, 1]), kwargs = {}) +# %bitwise_and : Tensor "b8[2, 2048, 2048][4194304, 2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {}) +# %bitwise_and_1 : Tensor "b8[2, 2048, 2048][4194304, 2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge, %bitwise_and), kwargs = {}) +# %bitwise_or : Tensor "b8[2, 2048, 2048][4194304, 2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {}) +# %ge_1 : Tensor "b8[2048][1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %remainder : Tensor "i64[2048][1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %index_2 : Tensor "i64[2][1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {}) +# %view_6 : Tensor "i64[2, 1][1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [2, 1]), kwargs = {}) +# %lt_2 : Tensor "b8[2, 2048][2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {}) +# %bitwise_and_2 : Tensor "b8[2, 2048][2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_1, %lt_2), kwargs = {}) +# %view_8 : Tensor "b8[2, 1, 2048][2048, 2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [2, 1, 2048]), kwargs = {}) +# %view_7 : Tensor "i64[2048, 1][1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %sub : Tensor "i64[2048, 2048][2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {}) +# %remainder_1 : Tensor "i64[2048, 2048][2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub, 2048), kwargs = {}) +# %eq : Tensor "b8[2048, 2048][2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {}) +# %bitwise_and_3 : Tensor "b8[2, 2048, 2048][4194304, 2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq), kwargs = {}) +# %bitwise_or_1 : Tensor "b8[2, 2048, 2048][4194304, 2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {}) +# %view_9 : Tensor "b8[2, 1, 2048, 2048][4194304, 4194304, 2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [2, 1, 2048, 2048]), kwargs = {}) +# %view_10 : Tensor "b8[2, 1, 16, 128, 16, 128][4194304, 4194304, 262144, 2048, 128, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand, [2, 1, 16, 128, 16, 128]), kwargs = {}) +# %permute : Tensor "b8[2, 1, 16, 16, 128, 128][4194304, 4194304, 262144, 128, 2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {}) +# %sum_1 : Tensor "i64[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {}) +# return %sum_1 +triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0 = async_compile.triton('triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 512, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8192, 'r0_': 0}} +) +@triton.jit +def triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 512 + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // 16) % 16) + x0 = (xindex % 16) + x2 = xindex // 256 + tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x6 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = r0_3 + 128*x0 + tmp2 = tmp0 >= tmp1 + tmp4 = tmp1 < tmp3 + tmp5 = tmp0 < tmp3 + tmp6 = tmp4 & tmp5 + tmp7 = tmp2 & tmp6 + tmp8 = tl.full([1, 1], False, tl.int1) + tmp9 = tmp8 | tmp7 + tmp10 = tl.full([1, 1], 2048, tl.int64) + tmp11 = tmp1 >= tmp10 + tmp12 = tmp11 & tmp4 + tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp14 = (tmp13 % tmp10) + tmp15 = tl.full([1, 1], 0, tl.int32) + tmp16 = tmp14 != tmp15 + tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0 + tmp19 = tmp17 != tmp18 + tmp20 = tmp16 & tmp19 + tmp21 = tmp14 + tmp10 + tmp22 = tl.where(tmp20, tmp21, tmp14) + tmp23 = tl.full([1, 1], 0, tl.int64) + tmp24 = tmp22 == tmp23 + tmp25 = tmp12 & tmp24 + tmp26 = tmp9 | tmp25 + tmp27 = tmp26.to(tl.int64) + tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK]) + tmp30 = _tmp29 + tmp28 + _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29) + tmp29 = tl.sum(_tmp29, 1)[:, None] + tl.store(out_ptr0 + (x6), tmp29, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/cm/ccmqky4m65yifqjmfuu7vgvpuhwpa4ybaxffiy3mu2e6yzgecghe.py +# Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] +# Source node to ATen node mapping: +# dense_mask_4 => full_default_4 +# Graph fragment: +# %full_default_4 : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:2, pin_memory: False}) +# return %index_put_1 +triton_poi_fused_new_zeros_1 = async_compile.triton('triton_poi_fused_new_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 1024}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4352}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_1(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 544 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/iw/ciwoxk7cuonocxkjitlvfvf5jppmr2duv6vgwzkwaw4xszgcaf5m.py +# Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices, full_blocks, full_blocks_1, dense_mask_1, col_indices_1, dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices, dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort, aten.eq, aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten.scalar_tensor, aten.where, aten.view, aten.index_put] +# Source node to ATen node mapping: +# arange_4 => iota_4 +# arange_6 => iota_8 +# child_3 => convert_element_type_3 +# child_4 => convert_element_type_4 +# child_7 => convert_element_type_6 +# child_8 => convert_element_type_7 +# col_indices => sort +# col_indices_1 => sort_1 +# col_range => iota_5 +# col_range_1 => iota_9 +# dense_mask => convert_element_type_2 +# dense_mask_1 => convert_element_type_5 +# dense_mask_2 => full_default_1 +# dense_mask_4 => full_default_4 +# full_blocks => eq_1 +# full_blocks_1 => convert_element_type_1 +# gt => gt +# index_mask => lt_4 +# index_mask_1 => lt_5 +# lt_3 => lt_3 +# num_blocks_in_row => sum_2 +# num_blocks_in_row_1 => sum_3 +# partial_blocks => bitwise_and_4 +# partial_blocks_1 => convert_element_type +# row_indices => unsqueeze +# row_indices_1 => unsqueeze_7 +# setitem => full_default_3, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6 +# setitem_1 => full_default_6, index_put_1, iota_10, iota_11, unsqueeze_10, unsqueeze_11, unsqueeze_12, unsqueeze_13, unsqueeze_9 +# unsqueeze_1 => unsqueeze_1 +# unsqueeze_3 => unsqueeze_8 +# valid_indices => full_default_2, where +# valid_indices_1 => full_default_5, where_1 +# Graph fragment: +# %sum_1 : Tensor "i64[2, 1, 16, 16][256, 512, 16, 1]cuda:2" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[2, 1, 16][16, 32, 1]cuda:2" = PlaceHolder[target=sum_2] +# %sum_3 : Tensor "i64[2, 1, 16][16, 32, 1]cuda:2" = PlaceHolder[target=sum_3] +# %buf2 : Tensor "i16[2, 1, 16, 16][256, 512, 16, 1]cuda:2" = PlaceHolder[target=buf2] +# %convert_element_type_3 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:2" = PlaceHolder[target=convert_element_type_3] +# %convert_element_type_4 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:2" = PlaceHolder[target=convert_element_type_4] +# %index_put : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:2" = PlaceHolder[target=index_put] +# %buf4 : Tensor "i16[2, 1, 16, 16][256, 512, 16, 1]cuda:2" = PlaceHolder[target=buf4] +# %convert_element_type_6 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:2" = PlaceHolder[target=convert_element_type_6] +# %convert_element_type_7 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:2" = PlaceHolder[target=convert_element_type_7] +# %index_put_1 : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:2" = PlaceHolder[target=index_put_1] +# %gt : Tensor "b8[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {}) +# %lt_3 : Tensor "b8[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %bitwise_and_4 : Tensor "b8[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {}) +# %convert_element_type : Tensor "i8[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {}) +# %convert_element_type_2 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {}) +# %sort : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%convert_element_type_2,), kwargs = {stable: True, descending: True}) +# %eq_1 : Tensor "b8[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %convert_element_type_1 : Tensor "i8[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_1, torch.int8), kwargs = {}) +# %convert_element_type_5 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {}) +# %sort_1 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%convert_element_type_5,), kwargs = {stable: True, descending: True}) +# %full_default_1 : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:2, pin_memory: False}) +# %iota_7 : Tensor "i64[2][1]cuda:2"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:2, requires_grad: False}) +# %unsqueeze_4 : Tensor "i64[2, 1][1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {}) +# %unsqueeze_5 : Tensor "i64[2, 1, 1][1, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {}) +# %unsqueeze_6 : Tensor "i64[2, 1, 1, 1][1, 1, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {}) +# %iota_6 : Tensor "i64[1][1]cuda:2"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:2, requires_grad: False}) +# %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {}) +# %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {}) +# %iota_4 : Tensor "i32[16][1]cuda:2"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:2, requires_grad: False}) +# %unsqueeze : Tensor "i32[16, 1][1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {}) +# %iota_5 : Tensor "i32[16][1]cuda:2"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:2, requires_grad: False}) +# %sum_2 : Tensor "i64[2, 1, 16][16, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {}) +# %convert_element_type_3 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:2"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {}) +# %unsqueeze_1 : Tensor "i32[2, 1, 16, 1][16, 16, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {}) +# %lt_4 : Tensor "b8[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {}) +# %convert_element_type_4 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {}) +# %full_default_2 : Tensor "i32[][]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 16), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:2, pin_memory: False}) +# %where : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %full_default_2), kwargs = {}) +# %full_default_3 : Tensor "i32[2, 1, 1, 1][1, 1, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:2, pin_memory: False}) +# %index_put : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_3), kwargs = {}) +# %full_default_4 : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:2, pin_memory: False}) +# %iota_11 : Tensor "i64[2][1]cuda:2"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:2, requires_grad: False}) +# %unsqueeze_11 : Tensor "i64[2, 1][1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_11, -1), kwargs = {}) +# %unsqueeze_12 : Tensor "i64[2, 1, 1][1, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_11, -1), kwargs = {}) +# %unsqueeze_13 : Tensor "i64[2, 1, 1, 1][1, 1, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_12, -1), kwargs = {}) +# %iota_10 : Tensor "i64[1][1]cuda:2"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:2, requires_grad: False}) +# %unsqueeze_9 : Tensor "i64[1, 1][1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_10, -1), kwargs = {}) +# %unsqueeze_10 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_9, -1), kwargs = {}) +# %iota_8 : Tensor "i32[16][1]cuda:2"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:2, requires_grad: False}) +# %unsqueeze_7 : Tensor "i32[16, 1][1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_8, -1), kwargs = {}) +# %iota_9 : Tensor "i32[16][1]cuda:2"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:2, requires_grad: False}) +# %sum_3 : Tensor "i64[2, 1, 16][16, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_5, [-1]), kwargs = {}) +# %convert_element_type_6 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:2"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_3, torch.int32), kwargs = {}) +# %unsqueeze_8 : Tensor "i32[2, 1, 16, 1][16, 16, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_6, 3), kwargs = {}) +# %lt_5 : Tensor "b8[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_9, %unsqueeze_8), kwargs = {}) +# %convert_element_type_7 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_3, torch.int32), kwargs = {}) +# %full_default_5 : Tensor "i32[][]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 16), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:2, pin_memory: False}) +# %where_1 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_5, %convert_element_type_7, %full_default_5), kwargs = {}) +# %full_default_6 : Tensor "i32[2, 1, 1, 1][1, 1, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:2, pin_memory: False}) +# %index_put_1 : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_4, [%unsqueeze_13, %unsqueeze_10, %unsqueeze_7, %where_1], %full_default_6), kwargs = {}) +# return %buf2,%buf4,%sum_2,%sum_3,%convert_element_type_3,%convert_element_type_6,%convert_element_type_4,%buf9,%convert_element_type_7,%buf16 +triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2 = async_compile.triton('triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 32 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (r0_1 + 16*x0), xmask, other=0.0) + tmp1 = tl.full([1, 1], 0, tl.int64) + tmp2 = tmp0 > tmp1 + tmp3 = tl.full([1, 1], 16384, tl.int64) + tmp4 = tmp0 < tmp3 + tmp5 = tmp2 & tmp4 + tmp6 = tmp5.to(tl.int8) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8.to(tl.int16) + tmp10 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp11 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12, tmp13, = triton_helpers.sort_with_index(tmp10, tmp11, None, 1, stable=True, descending=True) + tmp14 = tmp0 == tmp3 + tmp15 = tmp14.to(tl.int8) + tmp16 = tmp15.to(tl.int32) + tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK]) + tmp18, tmp19, = triton_helpers.sort_with_index(tmp17, tmp11, None, 1, stable=True, descending=True) + tmp20 = tmp7.to(tl.int64) + tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK]) + tmp23 = tl.where(xmask, tmp21, 0) + tmp24 = tl.sum(tmp23, 1)[:, None].to(tl.int64) + tmp25 = tmp16.to(tl.int64) + tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK]) + tmp28 = tl.where(xmask, tmp26, 0) + tmp29 = tl.sum(tmp28, 1)[:, None].to(tl.int64) + tmp30 = tmp24.to(tl.int32) + tmp31 = tmp29.to(tl.int32) + tmp32 = tmp13.to(tl.int64) + tmp33 = tmp32.to(tl.int32) + tmp34 = tmp8 < tmp30 + tmp35 = tl.full([1, 1], 16, tl.int32) + tmp36 = tl.where(tmp34, tmp33, tmp35) + tmp37 = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32) + tmp38 = tmp36 + tmp37 + tmp39 = tmp36 < 0 + tmp40 = tl.where(tmp39, tmp38, tmp36) + tl.device_assert(((0 <= tmp40) & (tmp40 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp40 < 17") + tmp42 = tl.full([1, 1], 1, tl.int32) + tmp43 = tmp19.to(tl.int64) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp8 < tmp31 + tmp46 = tl.where(tmp45, tmp44, tmp35) + tmp47 = tmp46 + tmp37 + tmp48 = tmp46 < 0 + tmp49 = tl.where(tmp48, tmp47, tmp46) + tl.device_assert(((0 <= tmp49) & (tmp49 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 17") + tl.store(out_ptr4 + (x0), tmp30, xmask) + tl.store(out_ptr5 + (x0), tmp31, xmask) + tl.store(out_ptr6 + (r0_1 + 16*x0), tmp33, xmask) + tl.store(out_ptr7 + (tl.broadcast_to(tmp40 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) + tl.store(out_ptr8 + (r0_1 + 16*x0), tmp44, xmask) + tl.store(out_ptr9 + (tl.broadcast_to(tmp49 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bp/cbp4ofim2oujxe6hm47xzugia67k4kofgbgvt7n7d5gd3iux76li.py +# Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_2 +# col_indices_2 => sort_2 +# num_blocks_in_row_2 => sum_4 +# q_indices => clone_6, convert_element_type_9 +# q_num_blocks => convert_element_type_8 +# transpose => permute_1 +# Graph fragment: +# %buf9 : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:2" = PlaceHolder[target=buf9] +# %buf11 : Tensor "i16[2, 1, 16, 16][256, 512, 16, 1]cuda:2" = PlaceHolder[target=buf11] +# %sum_4 : Tensor "i64[2, 1, 16][16, 32, 1]cuda:2" = PlaceHolder[target=sum_4] +# %slice_2 : Tensor "i32[2, 1, 16, 16][272, 272, 17, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, 16), kwargs = {}) +# %clone_4 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_2,), kwargs = {memory_format: torch.contiguous_format}) +# %permute_1 : Tensor "i32[2, 1, 16, 16][256, 256, 1, 16]cuda:2"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {}) +# %sort_2 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%permute_1,), kwargs = {stable: True, descending: True}) +# %convert_element_type_9 : Tensor "i32[2, 1, 16, 16][256, 256, 1, 16]cuda:2"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {}) +# %clone_6 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format}) +# %sum_4 : Tensor "i64[2, 1, 16][16, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {}) +# %convert_element_type_8 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {}) +# return %buf11,%sum_4,%clone_6,%convert_element_type_8 +triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3 = async_compile.triton('triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 256, 'r0_': 4096}} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 32 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % 16) + x1 = xindex // 16 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + 17*r0_2 + 272*x1), xmask, other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x3), tmp13, xmask) + tl.store(out_ptr3 + (x3), tmp14, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, = args + args.clear() + assert_size_stride(arg0_1, (2, ), (1, )) + with torch.cuda._DeviceGuard(2): + torch.cuda.set_device(2) + buf0 = empty_strided_cuda((2, 1, 16, 16), (256, 512, 16, 1), torch.int64) + # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_2, mask_3, mask_block_sum], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.permute, aten.sum] + stream2 = get_raw_stream(2) + triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0.run(arg0_1, buf0, 512, 16384, stream=stream2) + del arg0_1 + buf15 = empty_strided_cuda((2, 1, 16, 17), (272, 272, 17, 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] + stream2 = get_raw_stream(2) + triton_poi_fused_new_zeros_1.run(buf15, 544, stream=stream2) + buf8 = empty_strided_cuda((2, 1, 16, 17), (272, 272, 17, 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] + stream2 = get_raw_stream(2) + triton_poi_fused_new_zeros_1.run(buf8, 544, stream=stream2) + buf6 = empty_strided_cuda((2, 1, 16), (16, 16, 1), torch.int32) + buf13 = empty_strided_cuda((2, 1, 16), (16, 16, 1), torch.int32) + buf7 = empty_strided_cuda((2, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf14 = empty_strided_cuda((2, 1, 16, 16), (256, 256, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices, full_blocks, full_blocks_1, dense_mask_1, col_indices_1, dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices, dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort, aten.eq, aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + stream2 = get_raw_stream(2) + triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.run(buf0, buf6, buf13, buf7, buf8, buf14, buf15, 32, 16, stream=stream2) + del buf0 + buf22 = empty_strided_cuda((2, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf24 = empty_strided_cuda((2, 1, 16), (16, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + stream2 = get_raw_stream(2) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf8, buf22, buf24, 32, 16, stream=stream2) + del buf8 + buf19 = empty_strided_cuda((2, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf21 = empty_strided_cuda((2, 1, 16), (16, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3, full_q_indices, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + stream2 = get_raw_stream(2) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf15, buf19, buf21, 32, 16, stream=stream2) + del buf15 + return (buf19, buf21, buf22, buf24, buf14, buf13, buf7, buf6, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((2, ), (1, ), device='cuda:2', dtype=torch.int64) + fn = lambda: call([arg0_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/ii/ciiz7wynjvqkn6uv5csahwryt5x2d664u4o7ugmepfcsfcniut4v.py b/SpecForge-ext/cache/compiled_kernels/ii/ciiz7wynjvqkn6uv5csahwryt5x2d664u4o7ugmepfcsfcniut4v.py new file mode 100644 index 0000000000000000000000000000000000000000..26525249a17c23eb6f9a038258318a2afe3da6e0 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ii/ciiz7wynjvqkn6uv5csahwryt5x2d664u4o7ugmepfcsfcniut4v.py @@ -0,0 +1,48 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 65536, 'r0_': 524288000}} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = xindex // 2048 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + 65760000*x1), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, None) diff --git a/SpecForge-ext/cache/compiled_kernels/ik/ciksm4jphopwjgs55fbipcxecpw4d643lh76mj27636ryec4e3kg.py b/SpecForge-ext/cache/compiled_kernels/ik/ciksm4jphopwjgs55fbipcxecpw4d643lh76mj27636ryec4e3kg.py new file mode 100644 index 0000000000000000000000000000000000000000..f5819f2f4dc8a830ffbce3c713ff77a3af377ffb --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ik/ciksm4jphopwjgs55fbipcxecpw4d643lh76mj27636ryec4e3kg.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks5 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/is/cisbwn452kdvm56u75a2mwmrdzns6w4vxzuweva24qshuv4gksv2.py b/SpecForge-ext/cache/compiled_kernels/is/cisbwn452kdvm56u75a2mwmrdzns6w4vxzuweva24qshuv4gksv2.py new file mode 100644 index 0000000000000000000000000000000000000000..434b567f2033c56fe0c03880bcb4f8287bc9ab41 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/is/cisbwn452kdvm56u75a2mwmrdzns6w4vxzuweva24qshuv4gksv2.py @@ -0,0 +1,26 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr0': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_slice_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_clone_slice_4(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = xindex // ks0 + x2 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + x1 + ks0*x1), xmask, eviction_policy='evict_last') + tl.store(out_ptr0 + (x2), tmp0, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/is/d02b763bc26b4a862acff11bb1d83ee2ff669b1418d106ae0058cadf26d0f276.best_config b/SpecForge-ext/cache/compiled_kernels/is/d02b763bc26b4a862acff11bb1d83ee2ff669b1418d106ae0058cadf26d0f276.best_config new file mode 100644 index 0000000000000000000000000000000000000000..bccc339c530946640745852b622c73570887ab86 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/is/d02b763bc26b4a862acff11bb1d83ee2ff669b1418d106ae0058cadf26d0f276.best_config @@ -0,0 +1 @@ +{"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "CLTRXNE5MHPP3O5A5W4Z4EQTTZVYMOP5IPJT6N44O6FTBZFXLMNA"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/iy/ciy3jtwq2kqsaaylz6g2uxngpmmalnqcompyd7v6diseejxhwvzs.py b/SpecForge-ext/cache/compiled_kernels/iy/ciy3jtwq2kqsaaylz6g2uxngpmmalnqcompyd7v6diseejxhwvzs.py new file mode 100644 index 0000000000000000000000000000000000000000..2c6112a56ce8e09c265eb8d208b1650ec1c28d51 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/iy/ciy3jtwq2kqsaaylz6g2uxngpmmalnqcompyd7v6diseejxhwvzs.py @@ -0,0 +1,37 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 4096, 'r0_': 32}, + reduction_hint=ReductionHint.OUTER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_mul_sum_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_mul_sum_1(in_ptr0, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 32 + R0_BLOCK: tl.constexpr = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_1), xmask, other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = tl.where(xmask, tmp1, 0) + tmp4 = tl.sum(tmp3, 1)[:, None].to(tl.float32) + tl.store(out_ptr0 + (x0), tmp4, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/lk/clk4cgl52lrdnpqzv6ubpxawah5lw2cyfnmsbuouupfi5emjbchn.py b/SpecForge-ext/cache/compiled_kernels/lk/clk4cgl52lrdnpqzv6ubpxawah5lw2cyfnmsbuouupfi5emjbchn.py new file mode 100644 index 0000000000000000000000000000000000000000..865c44be9c11c379fff92df175991ba1c9c06dfc --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/lk/clk4cgl52lrdnpqzv6ubpxawah5lw2cyfnmsbuouupfi5emjbchn.py @@ -0,0 +1,1083 @@ +# AOT ID: ['13_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ef/cefh7lkkzxkkmdldjmu75mgxgh2oczofby7slgtoagmm5sd6wlvf.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[8, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[8, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:4" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[8, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[8, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, %primals_10], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:4, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, %primals_20, %primals_22, %primals_25, %primals_27, %primals_30, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 524288, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 32) + x2 = xindex // ks1 + x5 = triton_helpers.div_floor_integer(xindex, ks0) + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + x4 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 4096*ks0*x2), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x0 + 128*x5*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hp/chpmrudjesqgjc4u7kzlnbev6u6xezu6edp3jbrg4g2q5z3yue3f.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_2 : Tensor "bf16[8, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:4" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:4" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[8, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[8, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[8, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:4" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[8, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:4" = PlaceHolder[target=getitem_5] +# %primals_13 : Tensor "i32[8, 1, s99][s99, s99, 1]cuda:4" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[8, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:4" = PlaceHolder[target=primals_9] +# %primals_22 : Tensor "i32[8, 1, s56][s56, s56, 1]cuda:4" = PlaceHolder[target=primals_22] +# %primals_25 : Tensor "i32[8, 1, s84, s53][s53*s84, s53*s84, s53, 1]cuda:4" = PlaceHolder[target=primals_25] +# %primals_17 : Tensor "i32[8, 1, s94][s94, s94, 1]cuda:4" = PlaceHolder[target=primals_17] +# %primals_20 : Tensor "i32[8, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:4" = PlaceHolder[target=primals_20] +# %primals_27 : Tensor "i32[8, 1, s100][s100, s100, 1]cuda:4" = PlaceHolder[target=primals_27] +# %primals_30 : Tensor "i32[8, 1, s6, s10][s10*s6, s10*s6, s10, 1]cuda:4" = PlaceHolder[target=primals_30] +# %primals_14 : Tensor "i64[8][1]cuda:4" = PlaceHolder[target=primals_14] +# %full_default : Tensor "f32[8, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, %primals_10], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:4, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, %primals_20, %primals_22, %primals_25, %primals_27, %primals_30, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 8 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks8 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = ks8 + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_10, primals_11, primals_15, primals_7, primals_8, primals_12, primals_16, primals_18, primals_19, primals_21, primals_24, primals_23, primals_26, primals_29, primals_28, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_17, primals_20, primals_22, primals_25, primals_27, primals_30, getitem, getitem_1, tangents_1 = args + args.clear() + s37 = primals_10 + s0 = primals_11 + s75 = primals_15 + s22 = primals_7 + s72 = primals_8 + s99 = primals_12 + s94 = primals_16 + s28 = primals_18 + s4 = primals_19 + s56 = primals_21 + s53 = primals_24 + s84 = primals_23 + s100 = primals_26 + s10 = primals_29 + s6 = primals_28 + assert_size_stride(primals_2, (8, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_6, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_9, (8, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (8, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_14, (8, ), (1, )) + assert_size_stride(primals_17, (8, 1, s94), (s94, s94, 1)) + assert_size_stride(primals_20, (8, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(primals_22, (8, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_25, (8, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(primals_27, (8, 1, s100), (s100, s100, 1)) + assert_size_stride(primals_30, (8, 1, s6, s10), (s10*s6, s10*s6, s10, 1)) + assert_size_stride(getitem, (8, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(getitem_1, (8, 32, s37), (32*max(1, s37), max(1, s37), 1)) + assert_size_stride(tangents_1, (8, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1)) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + ps0 = 32*s37 + buf1 = empty_strided_cuda((8, 32, s37), (32*s37, s37, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + triton_red_fused_zeros_0_xnumel = 256*s37 + stream4 = get_raw_stream(4) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, s37, ps0, triton_red_fused_zeros_0_xnumel, 128, stream=stream4) + del getitem + buf3 = empty_strided_cuda((8, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((8, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((8, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream4 = get_raw_stream(4) + triton_tem_fused_zeros_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_13, primals_9, primals_22, primals_25, primals_17, primals_20, primals_27, primals_30, primals_14, buf5, s37, s0, s99, s22, s72, s56, s53, s84, s75, 4*((127 + s37) // 128) + ((127 + s0) // 128), 8, 8, stream=stream4) + del buf1 + del getitem_1 + del primals_13 + del primals_14 + del primals_17 + del primals_2 + del primals_20 + del primals_22 + del primals_25 + del primals_27 + del primals_30 + del primals_4 + del primals_6 + del primals_9 + del tangents_1 + return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_10 = 1896 + primals_11 = 1896 + primals_15 = 1896 + primals_7 = 15 + primals_8 = 15 + primals_12 = 15 + primals_16 = 15 + primals_18 = 15 + primals_19 = 15 + primals_21 = 15 + primals_24 = 15 + primals_23 = 15 + primals_26 = 15 + primals_29 = 15 + primals_28 = 15 + primals_2 = rand_strided((8, 32, 1896, 128), (7766016, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + primals_4 = rand_strided((8, 8, 1896, 128), (1941504, 242688, 128, 1), device='cuda:4', dtype=torch.bfloat16) + primals_6 = rand_strided((8, 8, 1896, 128), (1941504, 242688, 128, 1), device='cuda:4', dtype=torch.bfloat16) + primals_9 = rand_strided((8, 1, 15, 15), (225, 225, 15, 1), device='cuda:4', dtype=torch.int32) + primals_13 = rand_strided((8, 1, 15), (15, 15, 1), device='cuda:4', dtype=torch.int32) + primals_14 = rand_strided((8, ), (1, ), device='cuda:4', dtype=torch.int64) + primals_17 = rand_strided((8, 1, 15), (15, 15, 1), device='cuda:4', dtype=torch.int32) + primals_20 = rand_strided((8, 1, 15, 15), (225, 225, 15, 1), device='cuda:4', dtype=torch.int32) + primals_22 = rand_strided((8, 1, 15), (15, 15, 1), device='cuda:4', dtype=torch.int32) + primals_25 = rand_strided((8, 1, 15, 15), (225, 225, 15, 1), device='cuda:4', dtype=torch.int32) + primals_27 = rand_strided((8, 1, 15), (15, 15, 1), device='cuda:4', dtype=torch.int32) + primals_30 = rand_strided((8, 1, 15, 15), (225, 225, 15, 1), device='cuda:4', dtype=torch.int32) + getitem = rand_strided((8, 32, 1896, 128), (7766016, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16) + getitem_1 = rand_strided((8, 32, 1896), (60672, 1896, 1), device='cuda:4', dtype=torch.float32) + tangents_1 = rand_strided((8, 32, 1896, 128), (7766016, 242688, 128, 1), device='cuda:4', dtype=torch.bfloat16) + fn = lambda: call([primals_10, primals_11, primals_15, primals_7, primals_8, primals_12, primals_16, primals_18, primals_19, primals_21, primals_24, primals_23, primals_26, primals_29, primals_28, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_17, primals_20, primals_22, primals_25, primals_27, primals_30, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/lp/1e661150415d7fce0f5577d7db35f128089400ce692c8dfdf5e40cb9a867cea5.best_config b/SpecForge-ext/cache/compiled_kernels/lp/1e661150415d7fce0f5577d7db35f128089400ce692c8dfdf5e40cb9a867cea5.best_config new file mode 100644 index 0000000000000000000000000000000000000000..ed4bbafec32134c55e06add8fdbae259cebe3543 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/lp/1e661150415d7fce0f5577d7db35f128089400ce692c8dfdf5e40cb9a867cea5.best_config @@ -0,0 +1 @@ +{"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "6fcabd0411a839b7b5d117b5e6638bd1b5d7bc3379312c678d803859f08278a9", "found_by_coordesc": false, "time_taken_ms": 18, "triton_cache_hash": "EB4J5U2HKNQBLXRWK6B5L6ATOH55AWD3MB7P63KH5AKRGRDZER7A"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/lp/clp43olymjc72eay3ukgvj6r4apcbbbnz3xlli3tafgvidlacqsg.py b/SpecForge-ext/cache/compiled_kernels/lp/clp43olymjc72eay3ukgvj6r4apcbbbnz3xlli3tafgvidlacqsg.py new file mode 100644 index 0000000000000000000000000000000000000000..96f12f9f6c5abf370aa8b987915bd964cf8ae076 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/lp/clp43olymjc72eay3ukgvj6r4apcbbbnz3xlli3tafgvidlacqsg.py @@ -0,0 +1,322 @@ +# AOT ID: ['4_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/pz/cpzw7g6yjflpctcqkzf5osq7m5acrctaysa6th3ox3deinxluypc.py +# Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] +# Source node to ATen node mapping: +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# Graph fragment: +# %tangents_2 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:3" = PlaceHolder[target=tangents_2] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:3" = PlaceHolder[target=primals_8] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:3" = PlaceHolder[target=primals_6] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:3" = PlaceHolder[target=primals_4] +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_84 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze_1), kwargs = {}) +# %slice_5 : Tensor "bf16[s48, s25, s9, s24 - ((s24//2))][s24*s25*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, 0, %add_96), kwargs = {}) +# %slice_6 : Tensor "bf16[s48, s25, s9, (s24//2)][s24*s25*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, %sub_72, %primals_2), kwargs = {}) +# %neg_2 : Tensor "bf16[s48, s25, s9, s24 - ((s24//2))][s25*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_5,), kwargs = {}) +# %full_default : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_13, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:3, pin_memory: False}) +# %slice_scatter_default : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %neg_2, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %slice_scatter_default_1 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %slice_6, 3, 0, %floordiv), kwargs = {}) +# %add_100 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default, %slice_scatter_default_1), kwargs = {}) +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %mul_85 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze), kwargs = {}) +# %add_101 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_100, %mul_85), kwargs = {}) +# return %add_101 +triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4194304}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/dz/cdz3io7w5uyfrmfqvmg2kt2ay66qv4ckwtyurhik3frq7fqnk7gm.py +# Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] +# Source node to ATen node mapping: +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# Graph fragment: +# %tangents_1 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3" = PlaceHolder[target=tangents_1] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:3" = PlaceHolder[target=primals_8] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:3" = PlaceHolder[target=primals_6] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:3" = PlaceHolder[target=primals_4] +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %mul_86 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze_1), kwargs = {}) +# %slice_7 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, 0, %sub_72), kwargs = {}) +# %slice_8 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, %sub_72, %primals_2), kwargs = {}) +# %neg_3 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_7,), kwargs = {}) +# %full_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_11, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:3, pin_memory: False}) +# %slice_scatter_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %neg_3, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %slice_scatter_default_3 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %slice_8, 3, 0, %floordiv), kwargs = {}) +# %add_106 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default_2, %slice_scatter_default_3), kwargs = {}) +# %mul_87 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze), kwargs = {}) +# %add_107 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_106, %mul_87), kwargs = {}) +# return %add_107 +triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_2, primals_7, primals_10, primals_11, primals_13, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2 = args + args.clear() + s24 = primals_2 + s9 = primals_7 + s48 = primals_10 + s34 = primals_11 + s25 = primals_13 + s92 = primals_1 + s96 = primals_3 + s79 = primals_5 + assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_8, (1, s9), (s9, 1)) + assert_size_stride(tangents_1, (s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1)) + assert_size_stride(tangents_2, (s48, s25, s9, s24), (s24*s25*s9, s24*s9, s24, 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + buf0 = empty_strided_cuda((s48, s25, s9, s24), (s24*s25*s9, s24*s9, s24, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel = s24*s25*s48*s9 + stream3 = get_raw_stream(3) + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0.run(tangents_2, primals_8, primals_6, primals_4, buf0, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel, stream=stream3) + del tangents_2 + buf1 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add] + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel = s24*s34*s48*s9 + stream3 = get_raw_stream(3) + triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1.run(tangents_1, primals_8, primals_6, primals_4, buf1, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel, stream=stream3) + del primals_4 + del primals_6 + del primals_8 + del tangents_1 + return (None, None, None, None, None, None, None, None, None, None, None, buf1, None, buf0, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_2 = 128 + primals_7 = 2048 + primals_10 = 2 + primals_11 = 32 + primals_13 = 8 + primals_1 = 2048 + primals_3 = 5245440 + primals_5 = 2048 + floordiv = 64 + add_96 = 64 + primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:3', dtype=torch.bfloat16) + primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:3', dtype=torch.bfloat16) + primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:3', dtype=torch.int64) + tangents_1 = rand_strided((2, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:3', dtype=torch.bfloat16) + tangents_2 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:3', dtype=torch.bfloat16) + fn = lambda: call([primals_2, primals_7, primals_10, primals_11, primals_13, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/lp/clpt6xpoqv3wajdkyviksqw24bkxb47w4kcgihhcyrj553fxcjqs.py b/SpecForge-ext/cache/compiled_kernels/lp/clpt6xpoqv3wajdkyviksqw24bkxb47w4kcgihhcyrj553fxcjqs.py new file mode 100644 index 0000000000000000000000000000000000000000..66874ce0110634730635d4091e2aeebe2e07db87 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/lp/clpt6xpoqv3wajdkyviksqw24bkxb47w4kcgihhcyrj553fxcjqs.py @@ -0,0 +1,50 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 256, 'r0_': 4096}} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 32 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % 16) + x1 = xindex // 16 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + 17*r0_2 + 272*x1), xmask, other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x3), tmp13, xmask) + tl.store(out_ptr3 + (x3), tmp14, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ls/cls3ju4iskgwc7wepn2m46svt5vbvf47ps3tsfw7s37earyzkzz2.py b/SpecForge-ext/cache/compiled_kernels/ls/cls3ju4iskgwc7wepn2m46svt5vbvf47ps3tsfw7s37earyzkzz2.py new file mode 100644 index 0000000000000000000000000000000000000000..46c7c2714871902e1955c017160e22c717ae90c3 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ls/cls3ju4iskgwc7wepn2m46svt5vbvf47ps3tsfw7s37earyzkzz2.py @@ -0,0 +1,62 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*bf16', 'in_ptr1': 'fp64', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mean_mul_pow_rsqrt_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_mean_mul_pow_rsqrt_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tmp1 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp9 = in_ptr1 + tmp6 = ks0 + tmp7 = tmp6.to(tl.float32) + tmp8 = (tmp4 / tmp7) + tmp10 = tmp9.to(tl.float32) + tmp11 = tmp8 + tmp10 + tmp12 = libdevice.rsqrt(tmp11) + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, xmask) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp13 = tl.load(in_ptr2 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp14 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp15 = tmp14.to(tl.float32) + tmp16 = tmp15 * tmp12 + tmp17 = tmp16.to(tl.float32) + tmp18 = tmp13 * tmp17 + tl.store(out_ptr0 + (r0_1 + ks0*x0), tmp18, r0_mask & xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/ng/cnglvt55axgj3x37cqns4hg7zsjeu57rkczufz7vpm5o4rwbf2w7.py b/SpecForge-ext/cache/compiled_kernels/ng/cnglvt55axgj3x37cqns4hg7zsjeu57rkczufz7vpm5o4rwbf2w7.py new file mode 100644 index 0000000000000000000000000000000000000000..647ff6c102d2df49d78bf1b02ae200dc6677b7fc --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/ng/cnglvt55axgj3x37cqns4hg7zsjeu57rkczufz7vpm5o4rwbf2w7.py @@ -0,0 +1,164 @@ +# AOT ID: ['0_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/mw/cmw5mlntlt7o73p24outkvtp73w3ylg6pk6fbqshalpowjpvoh47.py +# Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul] +# Source node to ATen node mapping: +# getitem_1 => unsqueeze +# position_mask => mul +# target_mask => index +# target_mask_1 => convert_element_type +# target_max_token => argmax +# Graph fragment: +# %arg0_1 : Tensor "bf16[2, 2048, 151936][311164928, 151936, 1]cuda:2" = PlaceHolder[target=arg0_1] +# %argmax : Tensor "i64[2, 2048][2048, 1]cuda:2" = PlaceHolder[target=argmax] +# %arg1_1 : Tensor "b8[151936][1]cuda:2" = PlaceHolder[target=arg1_1] +# %arg2_1 : Tensor "i64[2, 2048, 1][2048, 1, 1]cuda:2" = PlaceHolder[target=arg2_1] +# %argmax : Tensor "i64[2, 2048][2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg0_1, -1), kwargs = {}) +# %index : Tensor "b8[2, 2048][2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%argmax]), kwargs = {}) +# %unsqueeze : Tensor "b8[2, 2048, 1][2048, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 2), kwargs = {}) +# %convert_element_type : Tensor "i32[2, 2048, 1][2048, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze, torch.int32), kwargs = {}) +# %mul : Tensor "i64[2, 2048, 1][2048, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %arg2_1), kwargs = {}) +# return %argmax,%mul +triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0 = async_compile.triton('triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 262144}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 151936 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tmp11 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last') + tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32) + tmp4 = tmp2 + tmp3 + tmp5 = tmp2 < 0 + tmp6 = tl.where(tmp5, tmp4, tmp2) + tl.device_assert((0 <= tmp6) & (tmp6 < 151936), "index out of bounds: 0 <= tmp6 < 151936") + tmp8 = tl.load(in_ptr1 + (tmp6), None, eviction_policy='evict_last').to(tl.int1) + tmp9 = tmp8.to(tl.int32) + tmp10 = tmp9.to(tl.int64) + tmp12 = tmp10 * tmp11 + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, None) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1 = args + args.clear() + assert_size_stride(arg0_1, (2, 2048, 151936), (311164928, 151936, 1)) + assert_size_stride(arg1_1, (151936, ), (1, )) + assert_size_stride(arg2_1, (2, 2048, 1), (2048, 1, 1)) + with torch.cuda._DeviceGuard(2): + torch.cuda.set_device(2) + buf0 = empty_strided_cuda((2, 2048), (2048, 1), torch.int64) + buf1 = reinterpret_tensor(buf0, (2, 2048, 1), (2048, 1, 1), 0); del buf0 # reuse + # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul] + stream2 = get_raw_stream(2) + triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.run(buf1, arg0_1, arg1_1, arg2_1, 4096, 151936, stream=stream2) + del arg0_1 + del arg1_1 + del arg2_1 + return (buf1, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((2, 2048, 151936), (311164928, 151936, 1), device='cuda:2', dtype=torch.bfloat16) + arg1_1 = rand_strided((151936, ), (1, ), device='cuda:2', dtype=torch.bool) + arg2_1 = rand_strided((2, 2048, 1), (2048, 1, 1), device='cuda:2', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1, arg2_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/p6/cp66nvdwdzgxajxp2yjtqapnwidpmfnzcyyalh6z5w6f6lf3aoej.py b/SpecForge-ext/cache/compiled_kernels/p6/cp66nvdwdzgxajxp2yjtqapnwidpmfnzcyyalh6z5w6f6lf3aoej.py new file mode 100644 index 0000000000000000000000000000000000000000..93d0e99beb6c75445409d955903ea48b37d5f54e --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/p6/cp66nvdwdzgxajxp2yjtqapnwidpmfnzcyyalh6z5w6f6lf3aoej.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/qa/cqambnamuby4hynvyzhccuoc4f5nkvwpn7yeizvaaaojnmlep42d.py b/SpecForge-ext/cache/compiled_kernels/qa/cqambnamuby4hynvyzhccuoc4f5nkvwpn7yeizvaaaojnmlep42d.py new file mode 100644 index 0000000000000000000000000000000000000000..71ce05b5dcc38dd925f2debd674add3334b232eb --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/qa/cqambnamuby4hynvyzhccuoc4f5nkvwpn7yeizvaaaojnmlep42d.py @@ -0,0 +1,62 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*bf16', 'in_ptr1': 'fp64', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mean_mul_pow_rsqrt_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_mean_mul_pow_rsqrt_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tmp1 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp9 = in_ptr1 + tmp6 = ks0 + tmp7 = tmp6.to(tl.float32) + tmp8 = (tmp4 / tmp7) + tmp10 = tmp9.to(tl.float32) + tmp11 = tmp8 + tmp10 + tmp12 = libdevice.rsqrt(tmp11) + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, xmask) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp13 = tl.load(in_ptr2 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp14 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp15 = tmp14.to(tl.float32) + tmp16 = tmp15 * tmp12 + tmp17 = tmp16.to(tl.float32) + tmp18 = tmp13 * tmp17 + tl.store(out_ptr0 + (r0_1 + ks0*x0), tmp18, r0_mask & xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/qa/cqasclcikvb2uryr7k2gtwdnliae55wql22q6kutfmldlk5e7kks.py b/SpecForge-ext/cache/compiled_kernels/qa/cqasclcikvb2uryr7k2gtwdnliae55wql22q6kutfmldlk5e7kks.py new file mode 100644 index 0000000000000000000000000000000000000000..7c31049b50c34ae362dd1e051c839f8ba4ec7cba --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/qa/cqasclcikvb2uryr7k2gtwdnliae55wql22q6kutfmldlk5e7kks.py @@ -0,0 +1,41 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2, 'r0_': 8192}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8, 'r0_': 131072}} +) +@triton.jit +def triton_red_fused_sum_3(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2 + r0_numel = 8192 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp3 = _tmp2 + tmp1 + _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2) + tmp2 = tl.sum(_tmp2, 1)[:, None] + tl.store(out_ptr0 + (x0), tmp2, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/qd/cqd6lffrumnqrtflwfoqtqs6mvn23l4bxialovx3yvqgximtpflz.py b/SpecForge-ext/cache/compiled_kernels/qd/cqd6lffrumnqrtflwfoqtqs6mvn23l4bxialovx3yvqgximtpflz.py new file mode 100644 index 0000000000000000000000000000000000000000..dc961520117cd0c91b04bd167e63f28db3daf7e8 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/qd/cqd6lffrumnqrtflwfoqtqs6mvn23l4bxialovx3yvqgximtpflz.py @@ -0,0 +1,66 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x3 = xindex + x1 = ((xindex // ks0) % ks1) + tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32) + tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last') + tmp0 = x0 + tmp1 = ks0 // 2 + tmp2 = tmp0 >= tmp1 + tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp5 = tl.broadcast_to(ks2, [XBLOCK]) + tmp6 = tmp4 + tmp5 + tmp7 = tmp4 < 0 + tmp8 = tl.where(tmp7, tmp6, tmp4) + tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2") + tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp11 = tmp3 * tmp10 + tmp12 = -tmp11 + tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype) + tmp14 = tl.where(tmp2, tmp12, tmp13) + tmp15 = 0.0 + tmp16 = tl.where(tmp2, tmp14, tmp15) + tmp17 = tmp0 < tmp1 + tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0) + tmp20 = tl.broadcast_to(ks2, [XBLOCK]) + tmp21 = tmp19 + tmp20 + tmp22 = tmp19 < 0 + tmp23 = tl.where(tmp22, tmp21, tmp19) + tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2") + tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp26 = tmp18 * tmp25 + tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype) + tmp28 = tl.where(tmp17, tmp26, tmp27) + tmp29 = tl.where(tmp17, tmp28, tmp15) + tmp30 = tmp16 + tmp29 + tmp33 = ks3 + tmp34 = tmp32 + tmp33 + tmp35 = tmp32 < 0 + tmp36 = tl.where(tmp35, tmp34, tmp32) + tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3") + tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32) + tmp39 = tmp31 * tmp38 + tmp40 = tmp30 + tmp39 + tl.store(out_ptr0 + (x3), tmp40, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/qd/cqd7l2ktsaxhv4w2pgoiwvrihj6ya2rmzfvnjybryke4aa6nwpjp.py b/SpecForge-ext/cache/compiled_kernels/qd/cqd7l2ktsaxhv4w2pgoiwvrihj6ya2rmzfvnjybryke4aa6nwpjp.py new file mode 100644 index 0000000000000000000000000000000000000000..2a78e1f23130a47c27a5598fd654349447291ac6 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/qd/cqd7l2ktsaxhv4w2pgoiwvrihj6ya2rmzfvnjybryke4aa6nwpjp.py @@ -0,0 +1,62 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 1, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'in_ptr3': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 131072}} +) +@triton.jit +def triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 1 + r0_numel = 4096 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp2 = tmp0 == tmp1 + tmp3 = tmp2.to(tl.int64) + tmp5 = tmp3 * tmp4 + tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK]) + tmp8 = _tmp7 + tmp6 + _tmp7 = tl.where(r0_mask, tmp8, _tmp7) + tmp7 = tl.sum(_tmp7, 1)[:, None] + _tmp11 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_0 = r0_index + tmp9 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0) + tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12 = _tmp11 + tmp10 + _tmp11 = tl.where(r0_mask, tmp12, _tmp11) + tmp11 = tl.sum(_tmp11, 1)[:, None] + tmp13 = tmp7.to(tl.float32) + tmp14 = tmp11.to(tl.float32) + tmp15 = 1e-06 + tmp16 = triton_helpers.maximum(tmp14, tmp15) + tmp17 = (tmp13 / tmp16) + tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp17, None) diff --git a/SpecForge-ext/cache/compiled_kernels/qd/cqdrwhn77m6nklfgiah2c7si5tpazqrvft6wqlk3wwicj66cog5s.py b/SpecForge-ext/cache/compiled_kernels/qd/cqdrwhn77m6nklfgiah2c7si5tpazqrvft6wqlk3wwicj66cog5s.py new file mode 100644 index 0000000000000000000000000000000000000000..702ceeba813e3a5750d3ceff33a3abb21e673750 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/qd/cqdrwhn77m6nklfgiah2c7si5tpazqrvft6wqlk3wwicj66cog5s.py @@ -0,0 +1,307 @@ +# AOT ID: ['4_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/2j/c2jyvidugg4t2zvjimwjrb4yacpc5zz5qifflapqv3x2b34cxuq7.py +# Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] +# Source node to ATen node mapping: +# cat => cat +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# mul => mul_24 +# mul_1 => mul_45 +# neg => neg +# q_embed => add_54 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# x1 => slice_1 +# x2 => slice_2 +# Graph fragment: +# %primals_12 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:3" = PlaceHolder[target=primals_12] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:3" = PlaceHolder[target=primals_8] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:3" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:3" = PlaceHolder[target=primals_6] +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_24 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_12, %unsqueeze), kwargs = {}) +# %slice_1 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24, s24*s34, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, 0, %floordiv), kwargs = {}) +# %slice_2 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24, s24*s34, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %neg : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s34*Max(1, s24 - ((s24//2))), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_2,), kwargs = {}) +# %cat : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %slice_1], -1), kwargs = {}) +# %mul_45 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat, %unsqueeze_1), kwargs = {}) +# %add_54 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_24, %mul_45), kwargs = {}) +# return %add_54 +triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 67108864}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/i7/ci72wbomeqiqrinpf2bqkd3bkzlans6x5wsg36itkz6xlzcsenoc.py +# Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] +# Source node to ATen node mapping: +# cat_1 => cat_1 +# cos => squeeze_1 +# cos_1 => unsqueeze +# getitem => index +# getitem_1 => index_1 +# k_embed => add_90 +# mul_2 => mul_54 +# mul_3 => mul_75 +# neg_1 => neg_1 +# sin => squeeze_3 +# sin_1 => unsqueeze_1 +# squeeze => squeeze +# squeeze_2 => squeeze_2 +# x1_1 => slice_3 +# x2_1 => slice_4 +# Graph fragment: +# %primals_13 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:3" = PlaceHolder[target=primals_13] +# %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:3" = PlaceHolder[target=primals_8] +# %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:3" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:3" = PlaceHolder[target=primals_6] +# %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {}) +# %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {}) +# %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {}) +# %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {}) +# %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {}) +# %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {}) +# %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {}) +# %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {}) +# %mul_54 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_13, %unsqueeze), kwargs = {}) +# %slice_3 : Tensor "bf16[s48, s48, s9, (s24//2)][s24*s48*s9, s24, s24*s48, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_13, 3, 0, %floordiv), kwargs = {}) +# %slice_4 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s24*s48*s9, s24, s24*s48, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_13, 3, %floordiv, 9223372036854775807), kwargs = {}) +# %neg_1 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s48*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s48*Max(1, s24 - ((s24//2))), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_4,), kwargs = {}) +# %cat_1 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %slice_3], -1), kwargs = {}) +# %mul_75 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %unsqueeze_1), kwargs = {}) +# %add_90 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_54, %mul_75), kwargs = {}) +# return %add_90 +triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13 = args + args.clear() + s92 = primals_1 + s24 = primals_2 + s96 = primals_3 + s79 = primals_5 + s9 = primals_7 + s38 = primals_9 + s48 = primals_10 + s34 = primals_11 + assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1)) + assert_size_stride(primals_8, (1, s9), (s9, 1)) + assert_size_stride(primals_12, (s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1)) + assert_size_stride(primals_13, (s48, s48, s9, s24), (s24*s48*s9, s24, s24*s48, 1)) + with torch.cuda._DeviceGuard(3): + torch.cuda.set_device(3) + ps0 = s24*s34 + buf0 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel = s24*s34*s48*s9 + stream3 = get_raw_stream(3) + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0.run(primals_12, primals_8, primals_4, primals_6, buf0, ps0, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel, stream=stream3) + del primals_12 + ps1 = s24*s48 + buf1 = empty_strided_cuda((s48, s48, s9, s24), (s24*s48*s9, s24, s24*s48, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add] + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel = s24*s9*s48*s48 + stream3 = get_raw_stream(3) + triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1.run(primals_13, primals_8, primals_4, primals_6, buf1, ps1, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel, stream=stream3) + del primals_13 + return (buf0, buf1, primals_4, primals_6, primals_8, s24, s9, s48, s34, s92, s96, s79, s24 // 2, s24 + (-1)*(s24 // 2), ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 2048 + primals_2 = 128 + primals_3 = 5245440 + primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:3', dtype=torch.bfloat16) + primals_5 = 2048 + primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:3', dtype=torch.bfloat16) + primals_7 = 2048 + primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:3', dtype=torch.int64) + primals_9 = 1 + primals_10 = 8 + primals_11 = 32 + primals_12 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:3', dtype=torch.bfloat16) + primals_13 = rand_strided((8, 8, 2048, 128), (2097152, 128, 1024, 1), device='cuda:3', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/qi/cqigiagiijtl5qumiw7vx2folxswo5fdvjviiovtnhm2p3m52mfh.py b/SpecForge-ext/cache/compiled_kernels/qi/cqigiagiijtl5qumiw7vx2folxswo5fdvjviiovtnhm2p3m52mfh.py new file mode 100644 index 0000000000000000000000000000000000000000..93b7df7823c85deea3c60c1152ba2ff1957b1428 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/qi/cqigiagiijtl5qumiw7vx2folxswo5fdvjviiovtnhm2p3m52mfh.py @@ -0,0 +1,682 @@ +# AOT ID: ['12_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ha/chay7klvou5rg46wg7dr724knn4epmemt53vvv5ezibh5ws7fezs.py +# Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] +# Source node to ATen node mapping: +# dense_mask_2 => full_default_1 +# Graph fragment: +# %full_default_1 : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, %floordiv_3, %add_201], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:4, pin_memory: False}) +# return %index_put +triton_poi_fused_new_zeros_0 = async_compile.triton('triton_poi_fused_new_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/rg/crg6vc2xcjpmdh6c45caj2rp25etmdgbqamntmrheszjuyjm7p6x.py +# Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] +# Source node to ATen node mapping: +# and_2 => bitwise_and_1 +# and_3 => bitwise_and_2 +# and_4 => bitwise_and_3, view_8 +# b => iota +# batched_outputs_2 => view_9 +# causal_mask => ge_2, view +# dense_mask => convert_element_type_2 +# dense_mask_1 => convert_element_type_5 +# diagnol_mask => eq_24 +# full_blocks => eq_45 +# full_blocks_1 => convert_element_type_1 +# gt => gt +# index => index +# index_1 => index_1 +# index_2 => index_2 +# lt => lt, view_1 +# lt_1 => lt_1, view_2 +# lt_3 => lt_3 +# m => iota_2 +# mask_1 => constant_pad_nd +# mask_2 => view_10 +# mask_3 => permute +# mask_block_sum => sum_1 +# n => iota_3 +# padding_mask => bitwise_and, view_3, view_4 +# padding_mask_1 => lt_2, view_6 +# partial_blocks => bitwise_and_4 +# partial_blocks_1 => convert_element_type +# remainder => remainder +# remainder_1 => remainder_1 +# result_1 => bitwise_or, full_default +# result_2 => bitwise_or_1 +# sub => sub_24, view_7 +# suffix_mask => ge_3 +# Graph fragment: +# %arg2_1 : Tensor "i64[2][1]cuda:4" = PlaceHolder[target=arg2_1] +# %sum_1 : Tensor "i64[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 2*(((s12 + 127)//128))*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:4" = PlaceHolder[target=sum_1] +# %full_default : Tensor "b8[2, 1, 1][1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:4, pin_memory: False}) +# %iota_2 : Tensor "i64[s12][1]cuda:4"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (%arg0_1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# %view : Tensor "i64[s12, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [%arg0_1, 1]), kwargs = {}) +# %iota_3 : Tensor "i64[s37][1]cuda:4"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (%arg1_1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# %ge_2 : Tensor "b8[s12, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {}) +# %iota : Tensor "i64[2][1]cuda:4"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# %index : Tensor "i64[2][1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%iota]), kwargs = {}) +# %view_1 : Tensor "i64[2, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [2, 1]), kwargs = {}) +# %lt : Tensor "b8[2, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {}) +# %view_4 : Tensor "b8[2, 1, s37][Max(1, s37), s37, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [2, 1, %arg1_1]), kwargs = {}) +# %index_1 : Tensor "i64[2][1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%iota]), kwargs = {}) +# %view_2 : Tensor "i64[2, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [2, 1]), kwargs = {}) +# %lt_1 : Tensor "b8[2, s12][Max(1, s12), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {}) +# %view_3 : Tensor "b8[2, s12, 1][Max(1, s12), 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [2, %arg0_1, 1]), kwargs = {}) +# %bitwise_and : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {}) +# %bitwise_and_1 : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_2, %bitwise_and), kwargs = {}) +# %bitwise_or : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {}) +# %ge_3 : Tensor "b8[s37][1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, %arg3_1), kwargs = {}) +# %remainder : Tensor "i64[s37][1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, %arg3_1), kwargs = {}) +# %index_2 : Tensor "i64[2][1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%iota]), kwargs = {}) +# %view_6 : Tensor "i64[2, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [2, 1]), kwargs = {}) +# %lt_2 : Tensor "b8[2, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {}) +# %bitwise_and_2 : Tensor "b8[2, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_3, %lt_2), kwargs = {}) +# %view_8 : Tensor "b8[2, 1, s37][Max(1, s37), s37, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [2, 1, %arg1_1]), kwargs = {}) +# %view_7 : Tensor "i64[s12, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [%arg0_1, 1]), kwargs = {}) +# %sub_24 : Tensor "i64[s12, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {}) +# %remainder_1 : Tensor "i64[s12, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub_24, %arg3_1), kwargs = {}) +# %eq_24 : Tensor "b8[s12, s37][Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {}) +# %bitwise_and_3 : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq_24), kwargs = {}) +# %bitwise_or_1 : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {}) +# %view_9 : Tensor "b8[2, 1, s12, s37][Max(1, s12)*Max(1, s37), s12*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [2, 1, %arg0_1, %arg1_1]), kwargs = {}) +# %constant_pad_nd : Tensor "b8[2, 1, 128*(((s12 + 127)//128)), 128*(((s37 + 127)//128))][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.constant_pad_nd.default](args = (%expand, [0, %sub_42, 0, %sub_44], 0.0), kwargs = {}) +# %view_10 : Tensor "b8[2, 1, ((s12 + 127)//128), 128, ((s37 + 127)//128), 128][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 128, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%constant_pad_nd, [2, 1, %floordiv_3, 128, %floordiv_2, 128]), kwargs = {}) +# %permute : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128), 128, 128][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), 128, Max(1, 128*(((s37 + 127)//128))), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {}) +# %sum_1 : Tensor "i64[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {}) +# %gt : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {}) +# %lt_3 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %bitwise_and_4 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {}) +# %convert_element_type : Tensor "i8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {}) +# %convert_element_type_2 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {}) +# %eq_45 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %convert_element_type_1 : Tensor "i8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_45, torch.int8), kwargs = {}) +# %convert_element_type_5 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {}) +# return %sum_1,%convert_element_type_2,%convert_element_type_5 +triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1 = async_compile.triton('triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 512, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'ks5': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, ks3, ks4, ks5, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // ks0) % ks1) + x0 = (xindex % ks0) + x2 = xindex // ks4 + _tmp46 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = ks2 + tmp2 = tmp0 < tmp1 + tmp3 = r0_3 + 128*x0 + tmp4 = ks3 + tmp5 = tmp3 < tmp4 + tmp6 = tmp2 & tmp5 + tmp7 = r0_4 + 128*x1 + tmp8 = r0_3 + 128*x0 + tmp9 = tmp7 >= tmp8 + tmp10 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp6 & xmask, eviction_policy='evict_last', other=0.0) + tmp11 = tmp8 < tmp10 + tmp12 = tmp7 < tmp10 + tmp13 = tmp11 & tmp12 + tmp14 = tmp9 & tmp13 + tmp15 = tl.full([1, 1], False, tl.int1) + tmp16 = tmp15 | tmp14 + tmp17 = tl.broadcast_to(ks5, [XBLOCK, R0_BLOCK]) + tmp18 = tmp8 >= tmp17 + tmp19 = (tmp8 % tmp17) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp17) != 0) if (tmp17).dtype is tl.float32 else tmp17 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp17 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tmp27 < tmp10 + tmp29 = tmp18 & tmp28 + tmp30 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp31 = (tmp30 % tmp17) + tmp32 = tmp31 != tmp20 + tmp33 = (libdevice.signbit(tmp31) != 0) if (tmp31).dtype is tl.float32 else tmp31 < 0 + tmp34 = tmp33 != tmp23 + tmp35 = tmp32 & tmp34 + tmp36 = tmp31 + tmp17 + tmp37 = tl.where(tmp35, tmp36, tmp31) + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp37 == tmp38 + tmp40 = tmp29 & tmp39 + tmp41 = tmp16 | tmp40 + tmp42 = tl.full(tmp41.shape, False, tmp41.dtype) + tmp43 = tl.where(tmp6, tmp41, tmp42) + tmp44 = tmp43.to(tl.int64) + tmp45 = tl.broadcast_to(tmp44, [XBLOCK, R0_BLOCK]) + tmp47 = _tmp46 + tmp45 + _tmp46 = tl.where(r0_mask & xmask, tmp47, _tmp46) + tmp46 = tl.sum(_tmp46, 1)[:, None] + tmp48 = tl.full([1, 1], 0, tl.int64) + tmp49 = tmp46 > tmp48 + tmp50 = tl.full([1, 1], 16384, tl.int64) + tmp51 = tmp46 < tmp50 + tmp52 = tmp49 & tmp51 + tmp53 = tmp52.to(tl.int8) + tmp54 = tmp53.to(tl.int32) + tmp55 = tmp46 == tmp50 + tmp56 = tmp55.to(tl.int8) + tmp57 = tmp56.to(tl.int32) + tl.store(out_ptr1 + (x5), tmp54, xmask) + tl.store(out_ptr2 + (x5), tmp57, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/nx/cnxyzzrtkb2x53y2inkm35xscgdtoilsp4ahxz3aomx4a2ng4rih.py +# Topologically Sorted Source Nodes: [num_blocks_in_row, child_3], Original ATen: [aten.sum, aten._to_copy] +# Source node to ATen node mapping: +# child_3 => convert_element_type_3 +# num_blocks_in_row => sum_2 +# Graph fragment: +# %convert_element_type_2 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 2*(((s12 + 127)//128))*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:4" = PlaceHolder[target=convert_element_type_2] +# %sum_2 : Tensor "i64[2, 1, ((s12 + 127)//128)][((s12 + 127)//128), 2*(((s12 + 127)//128)), 1]cuda:4" = PlaceHolder[target=sum_2] +# %sum_2 : Tensor "i64[2, 1, ((s12 + 127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {}) +# %convert_element_type_3 : Tensor "i32[2, 1, ((s12 + 127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {}) +# return %sum_2,%convert_element_type_3 +triton_red_fused__to_copy_sum_2 = async_compile.triton('triton_red_fused__to_copy_sum_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_sum_2(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + x2 = (xindex % ks1) + x3 = xindex // ks1 + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x2 + x3*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp5, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/mm/cmm6gxzvdt4w4heysmumz7etupuzsyhhinqrbrb7e7uar4xfr27g.py +# Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, aten.scalar_tensor, aten.where, aten.view, aten.index_put] +# Source node to ATen node mapping: +# arange_4 => iota_4 +# child_4 => convert_element_type_4 +# col_range => iota_5 +# dense_mask_2 => full_default_1 +# index_mask => lt_4 +# row_indices => unsqueeze +# setitem => full_default_2, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6 +# unsqueeze_1 => unsqueeze_1 +# valid_indices => scalar_tensor, where +# Graph fragment: +# %getitem_1 : Tensor "i64[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 2*Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4" = PlaceHolder[target=getitem_1] +# %convert_element_type_3 : Tensor "i32[2, 1, ((s12 + 127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:4" = PlaceHolder[target=convert_element_type_3] +# %convert_element_type_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4" = PlaceHolder[target=convert_element_type_4] +# %index_put : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), ((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), (((s37 + 127)//128)) + 1, 1]cuda:4" = PlaceHolder[target=index_put] +# %full_default_1 : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, %floordiv_3, %add_201], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:4, pin_memory: False}) +# %iota_7 : Tensor "i64[2][1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# %unsqueeze_4 : Tensor "i64[2, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {}) +# %unsqueeze_5 : Tensor "i64[2, 1, 1][1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {}) +# %unsqueeze_6 : Tensor "i64[2, 1, 1, 1][1, 1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {}) +# %iota_6 : Tensor "i64[1][1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {}) +# %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {}) +# %iota_4 : Tensor "i32[((s12 + 127)//128)][1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (%floordiv_3,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:4, requires_grad: False}) +# %unsqueeze : Tensor "i32[((s12 + 127)//128), 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {}) +# %iota_5 : Tensor "i32[((s37 + 127)//128)][1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (%floordiv_2,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:4, requires_grad: False}) +# %unsqueeze_1 : Tensor "i32[2, 1, ((s12 + 127)//128), 1][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {}) +# %lt_4 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {}) +# %convert_element_type_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {}) +# %scalar_tensor : Tensor "i32[][]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%floordiv_2,), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:4}) +# %where : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %scalar_tensor), kwargs = {}) +# %full_default_2 : Tensor "i32[2, 1, 1, 1][1, 1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:4, pin_memory: False}) +# %index_put : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_2), kwargs = {}) +# return %convert_element_type_4,%buf13 +triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3 = async_compile.triton('triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i32', 'out_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3', 'mutated_arg_names': ['out_ptr1'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3(in_ptr0, in_ptr1, out_ptr0, out_ptr1, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + x3 = xindex // ks0 + tmp0 = tl.load(in_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp2 = tl.load(in_ptr1 + (x3), xmask, eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tmp3 = x0 + tmp4 = tmp3 < tmp2 + tmp5 = ks0 + tmp6 = tl.where(tmp4, tmp1, tmp5) + tmp7 = 1 + ks0 + tmp8 = tmp6 + tmp7 + tmp9 = tmp6 < 0 + tmp10 = tl.where(tmp9, tmp8, tmp6) + tl.device_assert(((0 <= tmp10) & (tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128)))) | ~(xmask), "index out of bounds: 0 <= tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128))") + tmp12 = tl.full([1], 1, tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) + tl.store(out_ptr1 + (tmp10 + x3 + ks0*x3), tmp12, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/zi/cziyubt6rcqh2lzixzelkpzlqiok7tnzuj6gxfrcy6icq544ymos.py +# Topologically Sorted Source Nodes: [batched_outputs_3], Original ATen: [aten.slice, aten.clone] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_4 +# Graph fragment: +# %buf13 : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), ((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), (((s37 + 127)//128)) + 1, 1]cuda:4" = PlaceHolder[target=buf13] +# %slice_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, %floordiv_2), kwargs = {}) +# %clone_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_4,), kwargs = {memory_format: torch.contiguous_format}) +# return %clone_4 +triton_poi_fused_clone_slice_4 = async_compile.triton('triton_poi_fused_clone_slice_4', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr0': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_slice_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_clone_slice_4(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = xindex // ks0 + x2 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + x1 + ks0*x1), xmask, eviction_policy='evict_last') + tl.store(out_ptr0 + (x2), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/zr/czrritdkatkiq6gisucy3tyfojxxy2il47j3sgudvr7n3phywdxk.py +# Topologically Sorted Source Nodes: [batched_outputs_3, transpose, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_4 +# num_blocks_in_row_2 => sum_4 +# q_num_blocks => convert_element_type_8 +# transpose => permute_1 +# Graph fragment: +# %clone_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 1, ((s37 + 127)//128), 1]cuda:4" = PlaceHolder[target=clone_4] +# %sum_4 : Tensor "i64[2, 1, ((s37 + 127)//128)][((s37 + 127)//128), 2*(((s37 + 127)//128)), 1]cuda:4" = PlaceHolder[target=sum_4] +# %slice_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, %floordiv_2), kwargs = {}) +# %clone_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_4,), kwargs = {memory_format: torch.contiguous_format}) +# %permute_1 : Tensor "i32[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:4"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {}) +# %sum_4 : Tensor "i64[2, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {}) +# %convert_element_type_8 : Tensor "i32[2, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {}) +# return %sum_4,%convert_element_type_8 +triton_red_fused__to_copy_clone_slice_sum_transpose_5 = async_compile.triton('triton_red_fused__to_copy_clone_slice_sum_transpose_5', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_clone_slice_sum_transpose_5', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_clone_slice_sum_transpose_5(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = xindex // ks0 + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_2 + ks0*ks1*x1), r0_mask & xmask, eviction_policy='evict_last', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp5, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/le/clehbyg4yeecmekf5ytdyw6cumf5bexyvwy4tg4l6tlcizrncjzu.py +# Topologically Sorted Source Nodes: [q_indices], Original ATen: [aten._to_copy] +# Source node to ATen node mapping: +# q_indices => clone_6, convert_element_type_9 +# Graph fragment: +# %getitem_5 : Tensor "i64[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:4" = PlaceHolder[target=getitem_5] +# %convert_element_type_9 : Tensor "i32[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {}) +# %clone_6 : Tensor "i32[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format}) +# return %clone_6 +triton_poi_fused__to_copy_6 = async_compile.triton('triton_poi_fused__to_copy_6', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 512}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused__to_copy_6(in_ptr0, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % ks1) + x2 = xindex // ks2 + tmp0 = tl.load(in_ptr0 + (x1 + x0*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last') + tmp1 = tmp0.to(tl.int32) + tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1, arg2_1, arg3_1 = args + args.clear() + s12 = arg0_1 + s37 = arg1_1 + s21 = arg3_1 + assert_size_stride(arg2_1, (2, ), (1, )) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf12 = empty_strided_cuda((2, 1, (127 + s12) // 128, 1 + ((127 + s37) // 128)), (((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), ((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 2*((127 + s12) // 128) + 2*((127 + s12) // 128)*((127 + s37) // 128) + stream4 = get_raw_stream(4) + triton_poi_fused_new_zeros_0.run(buf12, triton_poi_fused_new_zeros_0_xnumel, stream=stream4) + buf21 = empty_strided_cuda((2, 1, (127 + s12) // 128, 1 + ((127 + s37) // 128)), (((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), ((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 2*((127 + s12) // 128) + 2*((127 + s12) // 128)*((127 + s37) // 128) + stream4 = get_raw_stream(4) + triton_poi_fused_new_zeros_0.run(buf21, triton_poi_fused_new_zeros_0_xnumel, stream=stream4) + ps0 = (127 + s37) // 128 + ps1 = (127 + s12) // 128 + ps2 = ((127 + s12) // 128)*((127 + s37) // 128) + buf1 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 2*((127 + s12) // 128)*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + buf5 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 2*((127 + s12) // 128)*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream4 = get_raw_stream(4) + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1.run(arg2_1, buf1, buf5, ps0, ps1, s12, s37, ps2, s21, triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel, 16384, stream=stream4) + del arg2_1 + buf10 = empty_strided_cuda((2, 1, (127 + s12) // 128), (max(1, (127 + s12) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [num_blocks_in_row, child_3], Original ATen: [aten.sum, aten._to_copy] + triton_red_fused__to_copy_sum_2_xnumel = 2*((127 + s12) // 128) + triton_red_fused__to_copy_sum_2_r0_numel = (127 + s37) // 128 + stream4 = get_raw_stream(4) + triton_red_fused__to_copy_sum_2.run(buf1, buf10, ps0, ps1, triton_red_fused__to_copy_sum_2_xnumel, triton_red_fused__to_copy_sum_2_r0_numel, stream=stream4) + buf19 = empty_strided_cuda((2, 1, (127 + s12) // 128), (max(1, (127 + s12) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [num_blocks_in_row_1, child_7], Original ATen: [aten.sum, aten._to_copy] + triton_red_fused__to_copy_sum_2_xnumel = 2*((127 + s12) // 128) + triton_red_fused__to_copy_sum_2_r0_numel = (127 + s37) // 128 + stream4 = get_raw_stream(4) + triton_red_fused__to_copy_sum_2.run(buf5, buf19, ps0, ps1, triton_red_fused__to_copy_sum_2_xnumel, triton_red_fused__to_copy_sum_2_r0_numel, stream=stream4) + # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort] + buf2 = torch.ops.aten.sort.stable(buf1, stable=True, dim=3, descending=True) + del buf1 + buf4 = buf2[1] + assert_size_stride(buf4, (2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 2*max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf4, 16, 'torch.ops.aten.sort.stable') + del buf2 + buf11 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream4 = get_raw_stream(4) + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3.run(buf4, buf10, buf11, buf12, ps0, ps1, ps2, s37, triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel, stream=stream4) + del buf4 + buf14 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 1, (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3], Original ATen: [aten.slice, aten.clone] + triton_poi_fused_clone_slice_4_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream4 = get_raw_stream(4) + triton_poi_fused_clone_slice_4.run(buf12, buf14, ps0, triton_poi_fused_clone_slice_4_xnumel, stream=stream4) + del buf12 + buf32 = empty_strided_cuda((2, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy] + triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel = 2*((127 + s37) // 128) + triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel = (127 + s12) // 128 + stream4 = get_raw_stream(4) + triton_red_fused__to_copy_clone_slice_sum_transpose_5.run(buf14, buf32, ps0, ps1, triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel, triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel, stream=stream4) + # Topologically Sorted Source Nodes: [full_blocks, full_blocks_1, dense_mask_1, col_indices_1], Original ATen: [aten.eq, aten._to_copy, aten.sort] + buf6 = torch.ops.aten.sort.stable(buf5, stable=True, dim=3, descending=True) + del buf5 + buf8 = buf6[1] + assert_size_stride(buf8, (2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 2*max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf8, 16, 'torch.ops.aten.sort.stable') + del buf6 + buf20 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream4 = get_raw_stream(4) + triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3.run(buf8, buf19, buf20, buf21, ps0, ps1, ps2, s37, triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel, stream=stream4) + del buf8 + buf23 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 1, (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5], Original ATen: [aten.slice, aten.clone] + triton_poi_fused_clone_slice_4_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream4 = get_raw_stream(4) + triton_poi_fused_clone_slice_4.run(buf21, buf23, ps0, triton_poi_fused_clone_slice_4_xnumel, stream=stream4) + del buf21 + buf29 = empty_strided_cuda((2, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy] + triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel = 2*((127 + s37) // 128) + triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel = (127 + s12) // 128 + stream4 = get_raw_stream(4) + triton_red_fused__to_copy_clone_slice_sum_transpose_5.run(buf23, buf29, ps0, ps1, triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel, triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel, stream=stream4) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort] + buf15 = torch.ops.aten.sort.stable(reinterpret_tensor(buf14, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 0, 1, (127 + s37) // 128), 0), stable=True, dim=3, descending=True) + del buf14 + buf17 = buf15[1] + assert_size_stride(buf17, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 1, max(1, (127 + s37) // 128)), 'torch.ops.aten.sort.stable') + assert_alignment(buf17, 16, 'torch.ops.aten.sort.stable') + del buf15 + buf30 = empty_strided_cuda((2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [q_indices], Original ATen: [aten._to_copy] + triton_poi_fused__to_copy_6_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream4 = get_raw_stream(4) + triton_poi_fused__to_copy_6.run(buf17, buf30, ps1, ps0, ps2, triton_poi_fused__to_copy_6_xnumel, stream=stream4) + del buf17 + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort] + buf24 = torch.ops.aten.sort.stable(reinterpret_tensor(buf23, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 0, 1, (127 + s37) // 128), 0), stable=True, dim=3, descending=True) + del buf23 + buf26 = buf24[1] + assert_size_stride(buf26, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 1, max(1, (127 + s37) // 128)), 'torch.ops.aten.sort.stable') + assert_alignment(buf26, 16, 'torch.ops.aten.sort.stable') + del buf24 + buf27 = empty_strided_cuda((2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [full_q_indices], Original ATen: [aten._to_copy] + triton_poi_fused__to_copy_6_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128) + stream4 = get_raw_stream(4) + triton_poi_fused__to_copy_6.run(buf26, buf27, ps1, ps0, ps2, triton_poi_fused__to_copy_6_xnumel, stream=stream4) + del buf26 + return (buf27, buf29, buf30, buf32, buf20, buf19, buf11, buf10, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 1543 + arg1_1 = 1543 + arg2_1 = rand_strided((2, ), (1, ), device='cuda:4', dtype=torch.int64) + arg3_1 = 1543 + fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/qi/cqijrvw4yd4k5npdnrx3v4xdz5kbluaucpendr36k5wzgkkh4oky.py b/SpecForge-ext/cache/compiled_kernels/qi/cqijrvw4yd4k5npdnrx3v4xdz5kbluaucpendr36k5wzgkkh4oky.py new file mode 100644 index 0000000000000000000000000000000000000000..1107385145f00756533a727c1c1d191a34f6c5f5 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/qi/cqijrvw4yd4k5npdnrx3v4xdz5kbluaucpendr36k5wzgkkh4oky.py @@ -0,0 +1,48 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 32768}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 65536, 'r0_': 524288000}} +) +@triton.jit +def triton_red_fused_argmax_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 4096 + r0_numel = 32000 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = xindex // 2048 + _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32) + _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32) + x3 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + 65760000*x1), r0_mask, eviction_policy='evict_first', other=0.0) + tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index( + _tmp2, _tmp2_index, tmp1, rindex + ) + _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2) + _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index) + tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1) + tmp2 = tmp2_idx[:, None] + tl.store(out_ptr0 + (x3), tmp2, None) diff --git a/SpecForge-ext/cache/compiled_kernels/qt/cc56f06d9cec79015636d209f8138e959cda486e20060a24adb69d3b74b7bee6.best_config b/SpecForge-ext/cache/compiled_kernels/qt/cc56f06d9cec79015636d209f8138e959cda486e20060a24adb69d3b74b7bee6.best_config new file mode 100644 index 0000000000000000000000000000000000000000..b9c83cd70cc4f7d46eca037549afe001d843ad6c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/qt/cc56f06d9cec79015636d209f8138e959cda486e20060a24adb69d3b74b7bee6.best_config @@ -0,0 +1 @@ +{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 49, "triton_cache_hash": "Z2RWAHMO7VUWQKIIRA5A46JYV2SEXHWLKREQM7TOP6VGUWDXAYAQ"} \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/qt/cqtv2hjbuijyx7awch534sanohmqs6reawit6ksar4ud36qn7xhy.py b/SpecForge-ext/cache/compiled_kernels/qt/cqtv2hjbuijyx7awch534sanohmqs6reawit6ksar4ud36qn7xhy.py new file mode 100644 index 0000000000000000000000000000000000000000..6bb0ae57ef75b7f15e9ad6e8b927d237fda93237 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/qt/cqtv2hjbuijyx7awch534sanohmqs6reawit6ksar4ud36qn7xhy.py @@ -0,0 +1,56 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 16777216}, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x4 = xindex + x2 = ((xindex // ks0) % ks1) + x0 = (xindex % ks3) + x5 = xindex // ks3 + tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32) + tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last') + tmp2 = ks2 + tmp3 = tmp1 + tmp2 + tmp4 = tmp1 < 0 + tmp5 = tl.where(tmp4, tmp3, tmp1) + tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2") + tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32) + tmp8 = tmp0 * tmp7 + tmp9 = x0 + tmp10 = tl.full([1], 0, tl.int64) + tmp11 = tmp9 >= tmp10 + tmp12 = ks3 + (-1)*(ks3 // 2) + tmp13 = tmp9 < tmp12 + tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp15 = -tmp14 + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) + tmp17 = tl.where(tmp13, tmp15, tmp16) + tmp18 = tmp9 >= tmp12 + tmp19 = ks3 + tmp20 = tmp9 < tmp19 + tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp22 = tl.where(tmp13, tmp17, tmp21) + tmp23 = ks4 + tmp24 = tmp1 + tmp23 + tmp25 = tl.where(tmp4, tmp24, tmp1) + tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4") + tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32) + tmp28 = tmp22 * tmp27 + tmp29 = tmp8 + tmp28 + tl.store(out_ptr0 + (x4), tmp29, xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/s3/cs3yasiv6lrv3gv7zuav4rb3xghwnuf7k2xh3nhv5aukawhau2k2.py b/SpecForge-ext/cache/compiled_kernels/s3/cs3yasiv6lrv3gv7zuav4rb3xghwnuf7k2xh3nhv5aukawhau2k2.py new file mode 100644 index 0000000000000000000000000000000000000000..463643dab32f80e955899cd93a0f72f91df0605c --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/s3/cs3yasiv6lrv3gv7zuav4rb3xghwnuf7k2xh3nhv5aukawhau2k2.py @@ -0,0 +1,552 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/tn/ctn4b3cztt4sougzbjeitojp6weato3jslns5gpsn7gd7lmbw4lr.py b/SpecForge-ext/cache/compiled_kernels/tn/ctn4b3cztt4sougzbjeitojp6weato3jslns5gpsn7gd7lmbw4lr.py new file mode 100644 index 0000000000000000000000000000000000000000..067fc46c9d039229a2b19664f932f23c4c6543ff --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/tn/ctn4b3cztt4sougzbjeitojp6weato3jslns5gpsn7gd7lmbw4lr.py @@ -0,0 +1,543 @@ +# AOT ID: ['5_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/mu/cmujymfqcpdztl6mvthtg7e3oyr4wtaoz6javbq3nk2aj4dhshhs.py +# Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_2, mask_3, mask_block_sum], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.permute, aten.sum] +# Source node to ATen node mapping: +# and_2 => bitwise_and_1 +# and_3 => bitwise_and_2 +# and_4 => bitwise_and_3, view_8 +# b => iota +# batched_outputs_2 => view_9 +# causal_mask => ge, view +# diagnol_mask => eq +# index => index +# index_1 => index_1 +# index_2 => index_2 +# lt => lt, view_1 +# lt_1 => lt_1, view_2 +# m => iota_2 +# mask_2 => view_10 +# mask_3 => permute +# mask_block_sum => sum_1 +# n => iota_3 +# padding_mask => bitwise_and, view_3, view_4 +# padding_mask_1 => lt_2, view_6 +# remainder => remainder +# remainder_1 => remainder_1 +# result_1 => bitwise_or, full_default +# result_2 => bitwise_or_1 +# sub => sub, view_7 +# suffix_mask => ge_1 +# Graph fragment: +# %arg0_1 : Tensor "i64[8][1]cuda:4" = PlaceHolder[target=arg0_1] +# %full_default : Tensor "b8[8, 1, 1][1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:4, pin_memory: False}) +# %iota_2 : Tensor "i64[2048][1]cuda:4"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# %view : Tensor "i64[2048, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %iota_3 : Tensor "i64[2048][1]cuda:4"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# %ge : Tensor "b8[2048, 2048][2048, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {}) +# %iota : Tensor "i64[8][1]cuda:4"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# %index : Tensor "i64[8][1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {}) +# %view_1 : Tensor "i64[8, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [8, 1]), kwargs = {}) +# %lt : Tensor "b8[8, 2048][2048, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {}) +# %view_4 : Tensor "b8[8, 1, 2048][2048, 2048, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [8, 1, 2048]), kwargs = {}) +# %index_1 : Tensor "i64[8][1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {}) +# %view_2 : Tensor "i64[8, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [8, 1]), kwargs = {}) +# %lt_1 : Tensor "b8[8, 2048][2048, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {}) +# %view_3 : Tensor "b8[8, 2048, 1][2048, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [8, 2048, 1]), kwargs = {}) +# %bitwise_and : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {}) +# %bitwise_and_1 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge, %bitwise_and), kwargs = {}) +# %bitwise_or : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {}) +# %ge_1 : Tensor "b8[2048][1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %remainder : Tensor "i64[2048][1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %index_2 : Tensor "i64[8][1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {}) +# %view_6 : Tensor "i64[8, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [8, 1]), kwargs = {}) +# %lt_2 : Tensor "b8[8, 2048][2048, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {}) +# %bitwise_and_2 : Tensor "b8[8, 2048][2048, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_1, %lt_2), kwargs = {}) +# %view_8 : Tensor "b8[8, 1, 2048][2048, 2048, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [8, 1, 2048]), kwargs = {}) +# %view_7 : Tensor "i64[2048, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %sub : Tensor "i64[2048, 2048][2048, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {}) +# %remainder_1 : Tensor "i64[2048, 2048][2048, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub, 2048), kwargs = {}) +# %eq : Tensor "b8[2048, 2048][2048, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {}) +# %bitwise_and_3 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq), kwargs = {}) +# %bitwise_or_1 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {}) +# %view_9 : Tensor "b8[8, 1, 2048, 2048][4194304, 4194304, 2048, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [8, 1, 2048, 2048]), kwargs = {}) +# %view_10 : Tensor "b8[8, 1, 16, 128, 16, 128][4194304, 4194304, 262144, 2048, 128, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand, [8, 1, 16, 128, 16, 128]), kwargs = {}) +# %permute : Tensor "b8[8, 1, 16, 16, 128, 128][4194304, 4194304, 262144, 128, 2048, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {}) +# %sum_1 : Tensor "i64[8, 1, 16, 16][256, 256, 16, 1]cuda:4"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {}) +# return %sum_1 +triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0 = async_compile.triton('triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2048, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 32768, 'r0_': 0}} +) +@triton.jit +def triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2048 + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // 16) % 16) + x0 = (xindex % 16) + x2 = xindex // 256 + tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x6 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = r0_3 + 128*x0 + tmp2 = tmp0 >= tmp1 + tmp4 = tmp1 < tmp3 + tmp5 = tmp0 < tmp3 + tmp6 = tmp4 & tmp5 + tmp7 = tmp2 & tmp6 + tmp8 = tl.full([1, 1], False, tl.int1) + tmp9 = tmp8 | tmp7 + tmp10 = tl.full([1, 1], 2048, tl.int64) + tmp11 = tmp1 >= tmp10 + tmp12 = tmp11 & tmp4 + tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp14 = (tmp13 % tmp10) + tmp15 = tl.full([1, 1], 0, tl.int32) + tmp16 = tmp14 != tmp15 + tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0 + tmp19 = tmp17 != tmp18 + tmp20 = tmp16 & tmp19 + tmp21 = tmp14 + tmp10 + tmp22 = tl.where(tmp20, tmp21, tmp14) + tmp23 = tl.full([1, 1], 0, tl.int64) + tmp24 = tmp22 == tmp23 + tmp25 = tmp12 & tmp24 + tmp26 = tmp9 | tmp25 + tmp27 = tmp26.to(tl.int64) + tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK]) + tmp30 = _tmp29 + tmp28 + _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29) + tmp29 = tl.sum(_tmp29, 1)[:, None] + tl.store(out_ptr0 + (x6), tmp29, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/qr/cqrok7demcrbt3yh6rmj2bttpopxhcu4l237sc3deioytctzun6e.py +# Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] +# Source node to ATen node mapping: +# dense_mask_4 => full_default_4 +# Graph fragment: +# %full_default_4 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:4, pin_memory: False}) +# return %index_put_1 +triton_poi_fused_new_zeros_1 = async_compile.triton('triton_poi_fused_new_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 17408}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_1(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 2176 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/jt/cjtvcwf2usd2xtmomczzp2mogjtphmsmqtta6fceusmpjkttojhx.py +# Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices, full_blocks, full_blocks_1, dense_mask_1, col_indices_1, dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices, dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort, aten.eq, aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten.scalar_tensor, aten.where, aten.view, aten.index_put] +# Source node to ATen node mapping: +# arange_4 => iota_4 +# arange_6 => iota_8 +# child_3 => convert_element_type_3 +# child_4 => convert_element_type_4 +# child_7 => convert_element_type_6 +# child_8 => convert_element_type_7 +# col_indices => sort +# col_indices_1 => sort_1 +# col_range => iota_5 +# col_range_1 => iota_9 +# dense_mask => convert_element_type_2 +# dense_mask_1 => convert_element_type_5 +# dense_mask_2 => full_default_1 +# dense_mask_4 => full_default_4 +# full_blocks => eq_1 +# full_blocks_1 => convert_element_type_1 +# gt => gt +# index_mask => lt_4 +# index_mask_1 => lt_5 +# lt_3 => lt_3 +# num_blocks_in_row => sum_2 +# num_blocks_in_row_1 => sum_3 +# partial_blocks => bitwise_and_4 +# partial_blocks_1 => convert_element_type +# row_indices => unsqueeze +# row_indices_1 => unsqueeze_7 +# setitem => full_default_3, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6 +# setitem_1 => full_default_6, index_put_1, iota_10, iota_11, unsqueeze_10, unsqueeze_11, unsqueeze_12, unsqueeze_13, unsqueeze_9 +# unsqueeze_1 => unsqueeze_1 +# unsqueeze_3 => unsqueeze_8 +# valid_indices => full_default_2, where +# valid_indices_1 => full_default_5, where_1 +# Graph fragment: +# %sum_1 : Tensor "i64[8, 1, 16, 16][256, 2048, 16, 1]cuda:4" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:4" = PlaceHolder[target=sum_2] +# %sum_3 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:4" = PlaceHolder[target=sum_3] +# %buf2 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:4" = PlaceHolder[target=buf2] +# %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:4" = PlaceHolder[target=convert_element_type_3] +# %convert_element_type_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:4" = PlaceHolder[target=convert_element_type_4] +# %index_put : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:4" = PlaceHolder[target=index_put] +# %buf4 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:4" = PlaceHolder[target=buf4] +# %convert_element_type_6 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:4" = PlaceHolder[target=convert_element_type_6] +# %convert_element_type_7 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:4" = PlaceHolder[target=convert_element_type_7] +# %index_put_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:4" = PlaceHolder[target=index_put_1] +# %gt : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {}) +# %lt_3 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %bitwise_and_4 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {}) +# %convert_element_type : Tensor "i8[8, 1, 16, 16][256, 256, 16, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {}) +# %convert_element_type_2 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {}) +# %sort : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%convert_element_type_2,), kwargs = {stable: True, descending: True}) +# %eq_1 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %convert_element_type_1 : Tensor "i8[8, 1, 16, 16][256, 256, 16, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_1, torch.int8), kwargs = {}) +# %convert_element_type_5 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {}) +# %sort_1 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%convert_element_type_5,), kwargs = {stable: True, descending: True}) +# %full_default_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:4, pin_memory: False}) +# %iota_7 : Tensor "i64[8][1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# %unsqueeze_4 : Tensor "i64[8, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {}) +# %unsqueeze_5 : Tensor "i64[8, 1, 1][1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {}) +# %unsqueeze_6 : Tensor "i64[8, 1, 1, 1][1, 1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {}) +# %iota_6 : Tensor "i64[1][1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {}) +# %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {}) +# %iota_4 : Tensor "i32[16][1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:4, requires_grad: False}) +# %unsqueeze : Tensor "i32[16, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {}) +# %iota_5 : Tensor "i32[16][1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:4, requires_grad: False}) +# %sum_2 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {}) +# %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {}) +# %unsqueeze_1 : Tensor "i32[8, 1, 16, 1][16, 16, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {}) +# %lt_4 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {}) +# %convert_element_type_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {}) +# %full_default_2 : Tensor "i32[][]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 16), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:4, pin_memory: False}) +# %where : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %full_default_2), kwargs = {}) +# %full_default_3 : Tensor "i32[8, 1, 1, 1][1, 1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:4, pin_memory: False}) +# %index_put : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_3), kwargs = {}) +# %full_default_4 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:4, pin_memory: False}) +# %iota_11 : Tensor "i64[8][1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# %unsqueeze_11 : Tensor "i64[8, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_11, -1), kwargs = {}) +# %unsqueeze_12 : Tensor "i64[8, 1, 1][1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_11, -1), kwargs = {}) +# %unsqueeze_13 : Tensor "i64[8, 1, 1, 1][1, 1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_12, -1), kwargs = {}) +# %iota_10 : Tensor "i64[1][1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:4, requires_grad: False}) +# %unsqueeze_9 : Tensor "i64[1, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_10, -1), kwargs = {}) +# %unsqueeze_10 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_9, -1), kwargs = {}) +# %iota_8 : Tensor "i32[16][1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:4, requires_grad: False}) +# %unsqueeze_7 : Tensor "i32[16, 1][1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_8, -1), kwargs = {}) +# %iota_9 : Tensor "i32[16][1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:4, requires_grad: False}) +# %sum_3 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_5, [-1]), kwargs = {}) +# %convert_element_type_6 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_3, torch.int32), kwargs = {}) +# %unsqueeze_8 : Tensor "i32[8, 1, 16, 1][16, 16, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_6, 3), kwargs = {}) +# %lt_5 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_9, %unsqueeze_8), kwargs = {}) +# %convert_element_type_7 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_3, torch.int32), kwargs = {}) +# %full_default_5 : Tensor "i32[][]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 16), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:4, pin_memory: False}) +# %where_1 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_5, %convert_element_type_7, %full_default_5), kwargs = {}) +# %full_default_6 : Tensor "i32[8, 1, 1, 1][1, 1, 1, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:4, pin_memory: False}) +# %index_put_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_4, [%unsqueeze_13, %unsqueeze_10, %unsqueeze_7, %where_1], %full_default_6), kwargs = {}) +# return %buf2,%buf4,%sum_2,%sum_3,%convert_element_type_3,%convert_element_type_6,%convert_element_type_4,%buf9,%convert_element_type_7,%buf16 +triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2 = async_compile.triton('triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 128 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (r0_1 + 16*x0), xmask, other=0.0) + tmp1 = tl.full([1, 1], 0, tl.int64) + tmp2 = tmp0 > tmp1 + tmp3 = tl.full([1, 1], 16384, tl.int64) + tmp4 = tmp0 < tmp3 + tmp5 = tmp2 & tmp4 + tmp6 = tmp5.to(tl.int8) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8.to(tl.int16) + tmp10 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp11 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12, tmp13, = triton_helpers.sort_with_index(tmp10, tmp11, None, 1, stable=True, descending=True) + tmp14 = tmp0 == tmp3 + tmp15 = tmp14.to(tl.int8) + tmp16 = tmp15.to(tl.int32) + tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK]) + tmp18, tmp19, = triton_helpers.sort_with_index(tmp17, tmp11, None, 1, stable=True, descending=True) + tmp20 = tmp7.to(tl.int64) + tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK]) + tmp23 = tl.where(xmask, tmp21, 0) + tmp24 = tl.sum(tmp23, 1)[:, None].to(tl.int64) + tmp25 = tmp16.to(tl.int64) + tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK]) + tmp28 = tl.where(xmask, tmp26, 0) + tmp29 = tl.sum(tmp28, 1)[:, None].to(tl.int64) + tmp30 = tmp24.to(tl.int32) + tmp31 = tmp29.to(tl.int32) + tmp32 = tmp13.to(tl.int64) + tmp33 = tmp32.to(tl.int32) + tmp34 = tmp8 < tmp30 + tmp35 = tl.full([1, 1], 16, tl.int32) + tmp36 = tl.where(tmp34, tmp33, tmp35) + tmp37 = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32) + tmp38 = tmp36 + tmp37 + tmp39 = tmp36 < 0 + tmp40 = tl.where(tmp39, tmp38, tmp36) + tl.device_assert(((0 <= tmp40) & (tmp40 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp40 < 17") + tmp42 = tl.full([1, 1], 1, tl.int32) + tmp43 = tmp19.to(tl.int64) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp8 < tmp31 + tmp46 = tl.where(tmp45, tmp44, tmp35) + tmp47 = tmp46 + tmp37 + tmp48 = tmp46 < 0 + tmp49 = tl.where(tmp48, tmp47, tmp46) + tl.device_assert(((0 <= tmp49) & (tmp49 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 17") + tl.store(out_ptr4 + (x0), tmp30, xmask) + tl.store(out_ptr5 + (x0), tmp31, xmask) + tl.store(out_ptr6 + (r0_1 + 16*x0), tmp33, xmask) + tl.store(out_ptr7 + (tl.broadcast_to(tmp40 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) + tl.store(out_ptr8 + (r0_1 + 16*x0), tmp44, xmask) + tl.store(out_ptr9 + (tl.broadcast_to(tmp49 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/4f/c4ft2b47ctfnp5zp5apvq5kvdlqubdrkzxpqndsh5oasyfr4v7y7.py +# Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_2 +# col_indices_2 => sort_2 +# num_blocks_in_row_2 => sum_4 +# q_indices => clone_6, convert_element_type_9 +# q_num_blocks => convert_element_type_8 +# transpose => permute_1 +# Graph fragment: +# %buf9 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:4" = PlaceHolder[target=buf9] +# %buf11 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:4" = PlaceHolder[target=buf11] +# %sum_4 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:4" = PlaceHolder[target=sum_4] +# %slice_2 : Tensor "i32[8, 1, 16, 16][272, 272, 17, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, 16), kwargs = {}) +# %clone_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_2,), kwargs = {memory_format: torch.contiguous_format}) +# %permute_1 : Tensor "i32[8, 1, 16, 16][256, 256, 1, 16]cuda:4"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {}) +# %sort_2 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%permute_1,), kwargs = {stable: True, descending: True}) +# %convert_element_type_9 : Tensor "i32[8, 1, 16, 16][256, 256, 1, 16]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {}) +# %clone_6 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format}) +# %sum_4 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {}) +# %convert_element_type_8 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {}) +# return %buf11,%sum_4,%clone_6,%convert_element_type_8 +triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3 = async_compile.triton('triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1024, 'r0_': 16384}} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 128 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % 16) + x1 = xindex // 16 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + 17*r0_2 + 272*x1), xmask, other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x3), tmp13, xmask) + tl.store(out_ptr3 + (x3), tmp14, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, = args + args.clear() + assert_size_stride(arg0_1, (8, ), (1, )) + with torch.cuda._DeviceGuard(4): + torch.cuda.set_device(4) + buf0 = empty_strided_cuda((8, 1, 16, 16), (256, 2048, 16, 1), torch.int64) + # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_2, mask_3, mask_block_sum], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.permute, aten.sum] + stream4 = get_raw_stream(4) + triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0.run(arg0_1, buf0, 2048, 16384, stream=stream4) + del arg0_1 + buf15 = empty_strided_cuda((8, 1, 16, 17), (272, 272, 17, 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] + stream4 = get_raw_stream(4) + triton_poi_fused_new_zeros_1.run(buf15, 2176, stream=stream4) + buf8 = empty_strided_cuda((8, 1, 16, 17), (272, 272, 17, 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] + stream4 = get_raw_stream(4) + triton_poi_fused_new_zeros_1.run(buf8, 2176, stream=stream4) + buf6 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + buf13 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + buf7 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf14 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices, full_blocks, full_blocks_1, dense_mask_1, col_indices_1, dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices, dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort, aten.eq, aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + stream4 = get_raw_stream(4) + triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.run(buf0, buf6, buf13, buf7, buf8, buf14, buf15, 128, 16, stream=stream4) + del buf0 + buf22 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf24 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + stream4 = get_raw_stream(4) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf8, buf22, buf24, 128, 16, stream=stream4) + del buf8 + buf19 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf21 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3, full_q_indices, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + stream4 = get_raw_stream(4) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf15, buf19, buf21, 128, 16, stream=stream4) + del buf15 + return (buf19, buf21, buf22, buf24, buf14, buf13, buf7, buf6, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((8, ), (1, ), device='cuda:4', dtype=torch.int64) + fn = lambda: call([arg0_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/tv/ctvkjvnixigkqmulzswzofaswixfqysp35ikbuwb2zyuiutvwlm4.py b/SpecForge-ext/cache/compiled_kernels/tv/ctvkjvnixigkqmulzswzofaswixfqysp35ikbuwb2zyuiutvwlm4.py new file mode 100644 index 0000000000000000000000000000000000000000..0fb4ecf26d8719d94c85c75552d089d7697434e1 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/tv/ctvkjvnixigkqmulzswzofaswixfqysp35ikbuwb2zyuiutvwlm4.py @@ -0,0 +1,1051 @@ +# AOT ID: ['6_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/dq/cdqxxevdyssoyut2euw55y27cahqqcgmvyuhdihb4tmner7cfc7f.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:6" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 262144, 128, 1]cuda:6" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[8, 32, 2048][65536, 2048, 1]cuda:6" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:6, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 524288, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4194304, 'r0_': 268435456}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 524288 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/4d/c4dw6ykxgbwk3glacutxkpzwhapvr5oszjet3n4i4q3snumjzm3x.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:6" = PlaceHolder[target=primals_1] +# %primals_2 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:6" = PlaceHolder[target=primals_2] +# %primals_3 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:6" = PlaceHolder[target=primals_3] +# %getitem_1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:6" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:6" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 262144, 128, 1]cuda:6" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:6" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:6" = PlaceHolder[target=getitem_5] +# %primals_5 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:6" = PlaceHolder[target=primals_5] +# %primals_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:6" = PlaceHolder[target=primals_4] +# %primals_9 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:6" = PlaceHolder[target=primals_9] +# %primals_10 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:6" = PlaceHolder[target=primals_10] +# %primals_7 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:6" = PlaceHolder[target=primals_7] +# %primals_8 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:6" = PlaceHolder[target=primals_8] +# %primals_11 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:6" = PlaceHolder[target=primals_11] +# %primals_12 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:6" = PlaceHolder[target=primals_12] +# %primals_6 : Tensor "i64[8][1]cuda:6" = PlaceHolder[target=primals_6] +# %full_default : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:6, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_2, %primals_3, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 8 + KV_LEN = 2048 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 256 + stride_kv_idx_m = 16 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = 16 + stride_q_idx_h = 256 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = 2048 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = True + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, getitem, getitem_1, tangents_1 = args + args.clear() + assert_size_stride(primals_1, (8, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_2, (8, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_3, (8, 8, 2048, 128), (2097152, 262144, 128, 1)) + assert_size_stride(primals_4, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_5, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_6, (8, ), (1, )) + assert_size_stride(primals_7, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_8, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_9, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(primals_11, (8, 1, 16), (16, 16, 1)) + assert_size_stride(primals_12, (8, 1, 16, 16), (256, 256, 16, 1)) + assert_size_stride(getitem, (8, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(getitem_1, (8, 32, 2048), (65536, 2048, 1)) + assert_size_stride(tangents_1, (8, 32, 2048, 128), (8388608, 262144, 128, 1)) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf1 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream6 = get_raw_stream(6) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, 524288, 128, stream=stream6) + del getitem + buf3 = empty_strided_cuda((8, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((8, 8, 2048, 128), (2097152, 262144, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((8, 8, 2048, 128), (2097152, 262144, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream6 = get_raw_stream(6) + triton_tem_fused_zeros_1.run(primals_1, primals_2, primals_3, getitem_1, buf1, tangents_1, buf3, buf4, primals_5, primals_4, primals_9, primals_10, primals_7, primals_8, primals_11, primals_12, primals_6, buf5, 80, 8, 8, stream=stream6) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_12 + del primals_2 + del primals_3 + del primals_4 + del primals_5 + del primals_6 + del primals_7 + del primals_8 + del primals_9 + del tangents_1 + return (buf3, buf5, buf4, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + primals_2 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:6', dtype=torch.bfloat16) + primals_3 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:6', dtype=torch.bfloat16) + primals_4 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:6', dtype=torch.int32) + primals_5 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:6', dtype=torch.int32) + primals_6 = rand_strided((8, ), (1, ), device='cuda:6', dtype=torch.int64) + primals_7 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:6', dtype=torch.int32) + primals_8 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:6', dtype=torch.int32) + primals_9 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:6', dtype=torch.int32) + primals_10 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:6', dtype=torch.int32) + primals_11 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:6', dtype=torch.int32) + primals_12 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:6', dtype=torch.int32) + getitem = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:6', dtype=torch.bfloat16) + getitem_1 = rand_strided((8, 32, 2048), (65536, 2048, 1), device='cuda:6', dtype=torch.float32) + tangents_1 = rand_strided((8, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:6', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/tv/ctvn756awhqumwxc7nh6h7s3zs4orgewod2ud6h3ic4vw5kk5m4e.py b/SpecForge-ext/cache/compiled_kernels/tv/ctvn756awhqumwxc7nh6h7s3zs4orgewod2ud6h3ic4vw5kk5m4e.py new file mode 100644 index 0000000000000000000000000000000000000000..3a119b3e67bb36fdb751cb24449234e3cb6ec6ae --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/tv/ctvn756awhqumwxc7nh6h7s3zs4orgewod2ud6h3ic4vw5kk5m4e.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 8 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 8 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 8 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks8 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = ks8 + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) diff --git a/SpecForge-ext/cache/compiled_kernels/vj/cvjzly27bwmukx4ax55pzamadufonbvzwq44ofqo7zyxiclgqpht.py b/SpecForge-ext/cache/compiled_kernels/vj/cvjzly27bwmukx4ax55pzamadufonbvzwq44ofqo7zyxiclgqpht.py new file mode 100644 index 0000000000000000000000000000000000000000..a7d9c03a1f07b5240a90558653addeac5dd64ddf --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/vj/cvjzly27bwmukx4ax55pzamadufonbvzwq44ofqo7zyxiclgqpht.py @@ -0,0 +1,835 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks5 + stride_q_idx_h = ks6*ks7 + stride_q_idx_n = ks6 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks8 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = ks0 + KV_LEN = ks1 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = ks8 + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) \ No newline at end of file diff --git a/SpecForge-ext/cache/compiled_kernels/vr/cvrhnrmpgyxwu34xleclee3tt4kemoldkj7iam4uciathomirvlc.py b/SpecForge-ext/cache/compiled_kernels/vr/cvrhnrmpgyxwu34xleclee3tt4kemoldkj7iam4uciathomirvlc.py new file mode 100644 index 0000000000000000000000000000000000000000..d3227baa2f9fdaa171d8e09c1c3ba5fa2d4b5aca --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/vr/cvrhnrmpgyxwu34xleclee3tt4kemoldkj7iam4uciathomirvlc.py @@ -0,0 +1,62 @@ + +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 32, 'r0_': 32}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr3'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2(in_ptr0, in_ptr1, out_ptr1, out_ptr2, out_ptr3, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 32 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0), tmp5, xmask) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp6 = tl.load(in_ptr1 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8 < tmp5 + tmp10 = ks0 + tmp11 = tl.where(tmp9, tmp7, tmp10) + tmp12 = 1 + ks0 + tmp13 = tmp11 + tmp12 + tmp14 = tmp11 < 0 + tmp15 = tl.where(tmp14, tmp13, tmp11) + tl.device_assert(((0 <= tmp15) & (tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128)))) | ~(r0_mask & xmask), "index out of bounds: 0 <= tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128))") + tmp17 = tl.full([1, 1], 1, tl.int32) + tl.store(out_ptr2 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp7, r0_mask & xmask) + tl.store(out_ptr3 + (tl.broadcast_to(tmp15 + x0 + ks0*x0, [XBLOCK, R0_BLOCK])), tmp17, r0_mask & xmask) diff --git a/SpecForge-ext/cache/compiled_kernels/x4/cx4luh7vlzc5hmgf2bp7kzhqtfs6yqj2hx4tcpnwimsa4lfyolmy.py b/SpecForge-ext/cache/compiled_kernels/x4/cx4luh7vlzc5hmgf2bp7kzhqtfs6yqj2hx4tcpnwimsa4lfyolmy.py new file mode 100644 index 0000000000000000000000000000000000000000..a288f28cdb672aa6490836f16e0f7acdcd885ad2 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/x4/cx4luh7vlzc5hmgf2bp7kzhqtfs6yqj2hx4tcpnwimsa4lfyolmy.py @@ -0,0 +1,1065 @@ +# AOT ID: ['9_backward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/5y/c5youjzyi3z3ynjz75h25htk6unkxftgtdsn4apk4b37ykfisbjl.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %getitem : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:5" = PlaceHolder[target=getitem] +# %tangents_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 262144, 128, 1]cuda:5" = PlaceHolder[target=tangents_1] +# %buf0 : Tensor "bf16[2, 32, 2048][65536, 2048, 1]cuda:5" = PlaceHolder[target=buf0] +# %full_default : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:5, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_3, %primals_5, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %buf0,%buf1 +triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 131072, 'r0_': 128}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1048576, 'r0_': 67108864}} +) +@triton.jit +def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 131072 + r0_numel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % 2048) + x1 = ((xindex // 2048) % 32) + x2 = xindex // 65536 + x4 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp2 = tmp0 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp6 = tmp4.to(tl.float32) + tmp7 = 0.0 + tmp8 = tmp6 - tmp7 + tl.store(out_ptr1 + (x4), tmp8, None) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/si/csimgsywujh2hsabqlfrkiolvsb67iqr7fvv654a3o76p2cedho4.py +# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] +# Source node to ATen node mapping: +# Graph fragment: +# %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:5" = PlaceHolder[target=primals_1] +# %primals_3 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:5" = PlaceHolder[target=primals_3] +# %primals_5 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:5" = PlaceHolder[target=primals_5] +# %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:5" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:5" = PlaceHolder[target=buf1] +# %tangents_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 262144, 128, 1]cuda:5" = PlaceHolder[target=tangents_1] +# %getitem_3 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:5" = PlaceHolder[target=getitem_3] +# %getitem_5 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:5" = PlaceHolder[target=getitem_5] +# %primals_9 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:5" = PlaceHolder[target=primals_9] +# %primals_7 : Tensor "i32[2, 1, 16, s72][16*s72, 16*s72, s72, 1]cuda:5" = PlaceHolder[target=primals_7] +# %primals_15 : Tensor "i32[2, 1, s56][s56, s56, 1]cuda:5" = PlaceHolder[target=primals_15] +# %primals_17 : Tensor "i32[2, 1, s84, 16][16*s84, 16*s84, 16, 1]cuda:5" = PlaceHolder[target=primals_17] +# %primals_11 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:5" = PlaceHolder[target=primals_11] +# %primals_13 : Tensor "i32[2, 1, 16, s4][16*s4, 16*s4, s4, 1]cuda:5" = PlaceHolder[target=primals_13] +# %primals_19 : Tensor "i32[2, 1, s99][s99, s99, 1]cuda:5" = PlaceHolder[target=primals_19] +# %primals_21 : Tensor "i32[2, 1, s6, 16][16*s6, 16*s6, 16, 1]cuda:5" = PlaceHolder[target=primals_21] +# %primals_10 : Tensor "i64[2][1]cuda:5" = PlaceHolder[target=primals_10] +# %full_default : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:5, pin_memory: False}) +# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_3, %primals_5, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {}) +# return %getitem_4 +triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + DELTA = arg_DELTA + DO = arg_DO + DQ = arg_DQ + DV = arg_DV + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + Q_NUM_BLKS = arg_Q_NUM_BLKS + Q_IDX = arg_Q_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS + FULL_Q_IDX = arg_FULL_Q_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1 + stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1 + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1 + stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1 + + ZQ = 2 + HQ = 32 + HKV = 8 + Q_LEN = 2048 + ZKV = 2 + KV_LEN = ks0 + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0).to(INDEX_DTYPE) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx + off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = 16 + stride_kv_idx_h = 16*ks1 + stride_kv_idx_m = ks1 + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = ks2 + stride_q_idx_h = 16*ks3 + stride_q_idx_n = 16 + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED]) + xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask) + +@triton.jit +def bwd_dq_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary + m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr16 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = tl.full([1], 2048, tl.int32) + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp39 = (ds) + grad_scores = tmp39 + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = 2048 + KV_LEN = ks0 + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + + # The minimum is needed to handle the case where we run with a super large + # SPARSE_BLOCK_SIZE (i.e. no block-mask!) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, +): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + BLOCK_M1 : tl.constexpr = 64 + BLOCK_N1 : tl.constexpr = 128 + BLOCK_M2 : tl.constexpr = 128 + BLOCK_N2 : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary + n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None) + + pre_mod_scores = qkT + tmp40 = (qkT) + post_mod_scores = tmp40 + + + + if not IS_DIVISIBLE: + post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp41 = tl.full([1], False, tl.int1) + tmp42 = (m) + tmp43 = (n) + tmp44 = tmp42 >= tmp43 + tmp45 = tmp43.to(tl.int64) + tmp46 = (off_z) + tmp47 = tl.load(in_ptr16 + tmp46) + tmp48 = tmp45 < tmp47 + tmp49 = tmp42.to(tl.int64) + tmp50 = tmp49 < tmp47 + tmp51 = tmp48 & tmp50 + tmp52 = tmp44 & tmp51 + tmp53 = tmp41 | tmp52 + tmp54 = tl.full([1], 2048, tl.int32) + tmp55 = tmp43 >= tmp54 + tmp56 = (tmp43 % tmp54) + tmp57 = tl.full([1], 0, tl.int32) + tmp58 = tmp56 != tmp57 + tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0 + tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0 + tmp61 = tmp59 != tmp60 + tmp62 = tmp58 & tmp61 + tmp63 = tmp56 + tmp54 + tmp64 = tl.where(tmp62, tmp63, tmp56) + tmp65 = tmp64.to(tl.int64) + tmp66 = tmp65 < tmp47 + tmp67 = tmp55 & tmp66 + tmp68 = tmp43 - tmp42 + tmp69 = (tmp68 % tmp54) + tmp70 = tmp69 != tmp57 + tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0 + tmp72 = tmp71 != tmp60 + tmp73 = tmp70 & tmp72 + tmp74 = tmp69 + tmp54 + tmp75 = tl.where(tmp73, tmp74, tmp69) + tmp76 = tmp75 == tmp57 + tmp77 = tmp67 & tmp76 + tmp78 = tmp53 | tmp77 + mask_mod_output = tmp78 + + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + tmp79 = (dsT) + grad_scores = tmp79 + + + + if not IS_DIVISIBLE: + grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dsT = grad_scores + if not IS_FULL_BLOCKS: + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_8, primals_6, primals_12, primals_14, primals_16, primals_18, primals_20, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, getitem, getitem_1, tangents_1 = args + args.clear() + s0 = primals_8 + s72 = primals_6 + s4 = primals_12 + s56 = primals_14 + s84 = primals_16 + s99 = primals_18 + s6 = primals_20 + assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(primals_3, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_5, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_7, (2, 1, 16, s72), (16*s72, 16*s72, s72, 1)) + assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_10, (2, ), (1, )) + assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1)) + assert_size_stride(primals_13, (2, 1, 16, s4), (16*s4, 16*s4, s4, 1)) + assert_size_stride(primals_15, (2, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_17, (2, 1, s84, 16), (16*s84, 16*s84, 16, 1)) + assert_size_stride(primals_19, (2, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_21, (2, 1, s6, 16), (16*s6, 16*s6, 16, 1)) + assert_size_stride(getitem, (2, 32, 2048, 128), (8388608, 128, 4096, 1)) + assert_size_stride(getitem_1, (2, 32, 2048), (65536, 2048, 1)) + assert_size_stride(tangents_1, (2, 32, 2048, 128), (8388608, 262144, 128, 1)) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream5 = get_raw_stream(5) + triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, 131072, 128, stream=stream5) + del getitem + buf3 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16) + buf4 = empty_strided_cuda((2, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + buf5 = empty_strided_cuda((2, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] + stream5 = get_raw_stream(5) + triton_tem_fused_zeros_1.run(primals_1, primals_3, primals_5, getitem_1, buf1, tangents_1, buf3, buf4, primals_9, primals_7, primals_15, primals_17, primals_11, primals_13, primals_19, primals_21, primals_10, buf5, s0, s72, s56, s84, 64 + ((127 + s0) // 128), 2, 8, stream=stream5) + del buf1 + del getitem_1 + del primals_1 + del primals_10 + del primals_11 + del primals_13 + del primals_15 + del primals_17 + del primals_19 + del primals_21 + del primals_3 + del primals_5 + del primals_7 + del primals_9 + del tangents_1 + return (buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_8 = 4096 + primals_6 = 32 + primals_12 = 32 + primals_14 = 32 + primals_16 = 32 + primals_18 = 32 + primals_20 = 32 + primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16) + primals_3 = rand_strided((2, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:5', dtype=torch.bfloat16) + primals_5 = rand_strided((2, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:5', dtype=torch.bfloat16) + primals_7 = rand_strided((2, 1, 16, 32), (512, 512, 32, 1), device='cuda:5', dtype=torch.int32) + primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:5', dtype=torch.int32) + primals_10 = rand_strided((2, ), (1, ), device='cuda:5', dtype=torch.int64) + primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:5', dtype=torch.int32) + primals_13 = rand_strided((2, 1, 16, 32), (512, 512, 32, 1), device='cuda:5', dtype=torch.int32) + primals_15 = rand_strided((2, 1, 32), (32, 32, 1), device='cuda:5', dtype=torch.int32) + primals_17 = rand_strided((2, 1, 32, 16), (512, 512, 16, 1), device='cuda:5', dtype=torch.int32) + primals_19 = rand_strided((2, 1, 32), (32, 32, 1), device='cuda:5', dtype=torch.int32) + primals_21 = rand_strided((2, 1, 32, 16), (512, 512, 16, 1), device='cuda:5', dtype=torch.int32) + getitem = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16) + getitem_1 = rand_strided((2, 32, 2048), (65536, 2048, 1), device='cuda:5', dtype=torch.float32) + tangents_1 = rand_strided((2, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:5', dtype=torch.bfloat16) + fn = lambda: call([primals_8, primals_6, primals_12, primals_14, primals_16, primals_18, primals_20, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, getitem, getitem_1, tangents_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/x4/cx4zxia2356l4deopgulqomafdhdef4hbb5heqlszh3lfmzh3j4i.py b/SpecForge-ext/cache/compiled_kernels/x4/cx4zxia2356l4deopgulqomafdhdef4hbb5heqlszh3lfmzh3j4i.py new file mode 100644 index 0000000000000000000000000000000000000000..a4b092ba2eed7595e56e76f317bbfc8db9953f1d --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/x4/cx4zxia2356l4deopgulqomafdhdef4hbb5heqlszh3lfmzh3j4i.py @@ -0,0 +1,711 @@ +# AOT ID: ['13_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +from torch._C import _cuda_getCurrentRawStream as get_raw_stream +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/kx/ckxtdzhg3azhdxeooy2uushwzka4sz2hzjpq5dulk2g2jjweqr6b.py +# Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] +# Source node to ATen node mapping: +# flex_attention => flex_attention +# Graph fragment: +# %primals_2 : Tensor "bf16[2, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:0" = PlaceHolder[target=primals_2] +# %primals_4 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:0" = PlaceHolder[target=primals_4] +# %primals_6 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:0" = PlaceHolder[target=primals_6] +# %getitem_1 : Tensor "f32[2, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=getitem_1] +# %buf1 : Tensor "f32[2, 32, s37][32*s37, s37, 1]cuda:0" = PlaceHolder[target=buf1] +# %primals_13 : Tensor "i32[2, 1, s99][s99, s99, 1]cuda:0" = PlaceHolder[target=primals_13] +# %primals_9 : Tensor "i32[2, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:0" = PlaceHolder[target=primals_9] +# %primals_17 : Tensor "i32[2, 1, s94][s94, s94, 1]cuda:0" = PlaceHolder[target=primals_17] +# %primals_20 : Tensor "i32[2, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:0" = PlaceHolder[target=primals_20] +# %primals_14 : Tensor "i64[2][1]cuda:0" = PlaceHolder[target=primals_14] +# %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_2, %primals_4, %primals_6, %sdpa_score0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, %primals_20, %primals_22, %primals_25, %primals_27, %primals_30, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {}) +# return %getitem +triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + +@triton_heuristics.template( + +num_stages=3, +num_warps=8, +triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]}, +inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}}, + +) +@triton.jit +def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5): + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + Q = arg_Q + K = arg_K + V = arg_V + LSE = arg_LSE + MAX = arg_MAX + KV_NUM_BLKS = arg_KV_NUM_BLKS + KV_IDX = arg_KV_IDX + FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS + FULL_KV_IDX = arg_FULL_KV_IDX + + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1 + stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1 + stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1 + + ZQ = 2 + HQ = 32 + Q_LEN = ks0 + ZKV = 2 + KV_LEN = ks1 + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0).to(INDEX_DTYPE) + off_zq = tl.program_id(1).to(INDEX_DTYPE) + off_hq = tl.program_id(2).to(INDEX_DTYPE) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + + SPARSE_Z = 2 + SPARSE_HQ = 1 + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = ks2 + stride_kv_idx_h = ks3*ks4 + stride_kv_idx_m = ks4 + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + # K and V pointers will be passed directly to forward_inner + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + # K and V pointers will be passed directly to forward_inner + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1).to(INDEX_DTYPE) + idx_hq = tl.program_id(2).to(INDEX_DTYPE) + idx_m = offs_m[:, None].to(INDEX_DTYPE) + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE) + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED]) + xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0 + tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask) + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + + if OUTPUT_MAX: + off_hz = off_zq * HQ + off_hq + max_ptrs = MAX + off_hz * Q_LEN + offs_m + if IS_DIVISIBLE: + tl.store(max_ptrs, m_i) + else: + tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN) + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_LEN: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) + + +# Common Imports +@triton.jit +def forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + # -- load k -- + # NB reversed order to since K is transposed + kv_base_offset = kv_start + kv_offset + + # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N] + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N) + k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + tmp0 = (qk) + post_mod_scores = tmp0 + + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + tmp1 = tl.full([1], False, tl.int1) + tmp2 = (m) + tmp3 = (n) + tmp4 = tmp2 >= tmp3 + tmp5 = tmp3.to(tl.int64) + tmp6 = (off_z) + tmp7 = tl.load(in_ptr9 + tmp6) + tmp8 = tmp5 < tmp7 + tmp9 = tmp2.to(tl.int64) + tmp10 = tmp9 < tmp7 + tmp11 = tmp8 & tmp10 + tmp12 = tmp4 & tmp11 + tmp13 = tmp1 | tmp12 + tmp14 = ks5 + tmp15 = tmp3 >= tmp14 + tmp16 = (tmp3 % tmp14) + tmp17 = tl.full([1], 0, tl.int32) + tmp18 = tmp16 != tmp17 + tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0 + tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp21 = tmp19 != tmp20 + tmp22 = tmp18 & tmp21 + tmp23 = tmp16 + tmp14 + tmp24 = tl.where(tmp22, tmp23, tmp16) + tmp25 = tmp24.to(tl.int64) + tmp26 = tmp25 < tmp7 + tmp27 = tmp15 & tmp26 + tmp28 = tmp3 - tmp2 + tmp29 = (tmp28 % tmp14) + tmp30 = tmp29 != tmp17 + tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0 + tmp32 = tmp31 != tmp20 + tmp33 = tmp30 & tmp32 + tmp34 = tmp29 + tmp14 + tmp35 = tl.where(tmp33, tmp34, tmp29) + tmp36 = tmp35 == tmp17 + tmp37 = tmp27 & tmp36 + tmp38 = tmp13 | tmp37 + mask_mod_output = tmp38 + + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + # Calculate offsets for V loading - reuse kv_base_offset from K loading + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + PRESCALE_QK : tl.constexpr = False + ROWS_GUARANTEED_SAFE : tl.constexpr = False + BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False + WRITE_DQ : tl.constexpr = True + OUTPUT_LOGSUMEXP : tl.constexpr = True + OUTPUT_MAX : tl.constexpr = False + FLOAT32_PRECISION : tl.constexpr = 'tf32' + IS_DIVISIBLE : tl.constexpr = False + SM_SCALE : tl.constexpr = 0.08838834764831843 + GQA_SHARED_HEADS : tl.constexpr = 4 + HAS_FULL_BLOCKS : tl.constexpr = True + QK_HEAD_DIM : tl.constexpr = 128 + QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 + V_HEAD_DIM : tl.constexpr = 128 + V_HEAD_DIM_ROUNDED : tl.constexpr = 128 + SAFE_HEAD_DIM : tl.constexpr = True + USE_TMA : tl.constexpr = False + BLOCK_M : tl.constexpr = 128 + BLOCK_N : tl.constexpr = 64 + SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 + SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 + INDEX_DTYPE : tl.constexpr = tl.int32 + + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, + q, K, V, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + # Strides for K and V + stride_kk, stride_kn, stride_vn, stride_vk, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + + + return acc, l_i, m_i +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30 = args + args.clear() + s50 = primals_1 + s0 = primals_3 + s43 = primals_5 + s22 = primals_7 + s72 = primals_8 + s37 = primals_10 + s71 = primals_11 + s99 = primals_12 + s75 = primals_15 + s94 = primals_16 + s28 = primals_18 + s4 = primals_19 + s56 = primals_21 + s84 = primals_23 + s53 = primals_24 + s100 = primals_26 + s6 = primals_28 + s10 = primals_29 + assert_size_stride(primals_2, (2, 32, s37, 128), (4096*s37, 128, 4096, 1)) + assert_size_stride(primals_4, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_6, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1)) + assert_size_stride(primals_9, (2, 1, s22, s72), (s22*s72, s22*s72, s72, 1)) + assert_size_stride(primals_13, (2, 1, s99), (s99, s99, 1)) + assert_size_stride(primals_14, (2, ), (1, )) + assert_size_stride(primals_17, (2, 1, s94), (s94, s94, 1)) + assert_size_stride(primals_20, (2, 1, s28, s4), (s28*s4, s28*s4, s4, 1)) + assert_size_stride(primals_22, (2, 1, s56), (s56, s56, 1)) + assert_size_stride(primals_25, (2, 1, s84, s53), (s53*s84, s53*s84, s53, 1)) + assert_size_stride(primals_27, (2, 1, s100), (s100, s100, 1)) + assert_size_stride(primals_30, (2, 1, s6, s10), (s10*s6, s10*s6, s10, 1)) + with torch.cuda._DeviceGuard(0): + torch.cuda.set_device(0) + buf0 = empty_strided_cuda((2, 32, s37), (32*s37, s37, 1), torch.float32) + buf1 = empty_strided_cuda((2, 32, s37), (32*s37, s37, 1), torch.float32) + buf2 = empty_strided_cuda((2, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [flex_attention], Original ATen: [] + stream0 = get_raw_stream(0) + triton_tem_fused_0.run(primals_2, primals_4, primals_6, buf0, buf1, primals_13, primals_9, primals_17, primals_20, primals_14, buf2, s37, s0, s99, s22, s72, s75, (127 + s37) // 128, 2, 32, stream=stream0) + del buf1 + return (buf2, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_17, primals_20, primals_22, primals_25, primals_27, primals_30, buf2, buf0, s37, s0, s75, s22, s72, s99, s94, s28, s4, s56, s53, s84, s100, s10, s6, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 1130 + primals_2 = rand_strided((2, 32, 1130, 128), (4628480, 128, 4096, 1), device='cuda:0', dtype=torch.bfloat16) + primals_3 = 1130 + primals_4 = rand_strided((2, 8, 1130, 128), (1157120, 144640, 128, 1), device='cuda:0', dtype=torch.bfloat16) + primals_5 = 1130 + primals_6 = rand_strided((2, 8, 1130, 128), (1157120, 144640, 128, 1), device='cuda:0', dtype=torch.bfloat16) + primals_7 = 9 + primals_8 = 9 + primals_9 = rand_strided((2, 1, 9, 9), (81, 81, 9, 1), device='cuda:0', dtype=torch.int32) + primals_10 = 1130 + primals_11 = 1130 + primals_12 = 9 + primals_13 = rand_strided((2, 1, 9), (9, 9, 1), device='cuda:0', dtype=torch.int32) + primals_14 = rand_strided((2, ), (1, ), device='cuda:0', dtype=torch.int64) + primals_15 = 1130 + primals_16 = 9 + primals_17 = rand_strided((2, 1, 9), (9, 9, 1), device='cuda:0', dtype=torch.int32) + primals_18 = 9 + primals_19 = 9 + primals_20 = rand_strided((2, 1, 9, 9), (81, 81, 9, 1), device='cuda:0', dtype=torch.int32) + primals_21 = 9 + primals_22 = rand_strided((2, 1, 9), (9, 9, 1), device='cuda:0', dtype=torch.int32) + primals_23 = 9 + primals_24 = 9 + primals_25 = rand_strided((2, 1, 9, 9), (81, 81, 9, 1), device='cuda:0', dtype=torch.int32) + primals_26 = 9 + primals_27 = rand_strided((2, 1, 9), (9, 9, 1), device='cuda:0', dtype=torch.int32) + primals_28 = 9 + primals_29 = 9 + primals_30 = rand_strided((2, 1, 9, 9), (81, 81, 9, 1), device='cuda:0', dtype=torch.int32) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28, primals_29, primals_30]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/x7/cx73ryddx34y6aoantfwkykhj52gzpwh5bdhuel3zfklqokgauet.py b/SpecForge-ext/cache/compiled_kernels/x7/cx73ryddx34y6aoantfwkykhj52gzpwh5bdhuel3zfklqokgauet.py new file mode 100644 index 0000000000000000000000000000000000000000..44d3dfb700f7c50bfb276e38f8138534485cb0dd --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/x7/cx73ryddx34y6aoantfwkykhj52gzpwh5bdhuel3zfklqokgauet.py @@ -0,0 +1,184 @@ +# AOT ID: ['2_forward'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/pq/cpquoebytalyakh5xbcgqhrouuzyvugczrc2quxh5vnumd4vms3g.py +# Topologically Sorted Source Nodes: [hidden_states, pow_1, variance, rsqrt, hidden_states_1, to_1, mul_1], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] +# Source node to ATen node mapping: +# hidden_states => convert_element_type +# hidden_states_1 => mul_16 +# mul_1 => mul_23 +# pow_1 => pow_1 +# rsqrt => rsqrt +# to_1 => convert_element_type_1 +# variance => mean +# Graph fragment: +# %primals_4 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:5" = PlaceHolder[target=primals_4] +# %buf0 : Tensor "f32[s47, s87, 1][s87, 1, s47*s87]cuda:5" = PlaceHolder[target=buf0] +# %primals_5 : Tensor "f64[][]cpu" = PlaceHolder[target=primals_5] +# %primals_7 : Tensor "bf16[s33][1]cuda:5" = PlaceHolder[target=primals_7] +# %rsqrt : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:5" = PlaceHolder[target=rsqrt] +# %convert_element_type : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_4, torch.float32), kwargs = {}) +# %pow_1 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type, 2), kwargs = {}) +# %mean : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [-1], True), kwargs = {}) +# %convert_element_type_default_1 : Tensor "f32[][]cpu"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_5, torch.float32), kwargs = {}) +# %add_tensor : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean, %convert_element_type_default_1), kwargs = {}) +# %rsqrt : Tensor "f32[s47, s87, 1][s87, 1, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_tensor,), kwargs = {}) +# %mul_16 : Tensor "f32[s47, s87, s33][s33*s87, s33, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %rsqrt), kwargs = {}) +# %convert_element_type_1 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_16, torch.bfloat16), kwargs = {}) +# %mul_23 : Tensor "bf16[s47, s87, s33][s33*s87, s33, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_7, %convert_element_type_1), kwargs = {}) +# return %buf0,%rsqrt,%mul_23 +triton_red_fused__to_copy_mean_mul_pow_rsqrt_0 = async_compile.triton('triton_red_fused__to_copy_mean_mul_pow_rsqrt_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 4096}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*bf16', 'in_ptr1': 'fp64', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mean_mul_pow_rsqrt_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_mean_mul_pow_rsqrt_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp1 = tmp0.to(tl.float32) + tmp2 = tmp1 * tmp1 + tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5 = _tmp4 + tmp3 + _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4) + tmp4 = tl.sum(_tmp4, 1)[:, None] + tmp9 = in_ptr1 + tmp6 = ks0 + tmp7 = tmp6.to(tl.float32) + tmp8 = (tmp4 / tmp7) + tmp10 = tmp9.to(tl.float32) + tmp11 = tmp8 + tmp10 + tmp12 = libdevice.rsqrt(tmp11) + tl.debug_barrier() + tl.store(in_out_ptr0 + (x0), tmp12, xmask) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp13 = tl.load(in_ptr2 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) + tmp14 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) + tmp15 = tmp14.to(tl.float32) + tmp16 = tmp15 * tmp12 + tmp17 = tmp16.to(tl.float32) + tmp18 = tmp13 * tmp17 + tl.store(out_ptr0 + (r0_1 + ks0*x0), tmp18, r0_mask & xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7 = args + args.clear() + s47 = primals_1 + s87 = primals_2 + s33 = primals_3 + s82 = primals_6 + assert_size_stride(primals_4, (s47, s87, s33), (s33*s87, s33, 1)) + assert_size_stride(primals_5, (), ()) + assert_size_stride(primals_7, (s33, ), (1, )) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + buf0 = empty_strided_cuda((s47, s87, 1), (s87, 1, s47*s87), torch.float32) + buf1 = reinterpret_tensor(buf0, (s47, s87, 1), (s87, 1, 1), 0); del buf0 # reuse + buf2 = empty_strided_cuda((s47, s87, s33), (s33*s87, s33, 1), torch.bfloat16) + # Topologically Sorted Source Nodes: [hidden_states, pow_1, variance, rsqrt, hidden_states_1, to_1, mul_1], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul] + triton_red_fused__to_copy_mean_mul_pow_rsqrt_0_xnumel = s47*s87 + stream5 = get_raw_stream(5) + triton_red_fused__to_copy_mean_mul_pow_rsqrt_0.run(buf1, primals_4, primals_5.item(), primals_7, buf2, s33, triton_red_fused__to_copy_mean_mul_pow_rsqrt_0_xnumel, s33, stream=stream5) + del primals_5 + return (buf2, primals_4, primals_7, buf1, s47, s87, s33, s82, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + primals_1 = 2 + primals_2 = 2048 + primals_3 = 4096 + primals_4 = rand_strided((2, 2048, 4096), (8388608, 4096, 1), device='cuda:5', dtype=torch.bfloat16) + primals_5 = rand_strided((), (), device='cpu', dtype=torch.float64) + primals_6 = 840433664 + primals_7 = rand_strided((4096, ), (1, ), device='cuda:5', dtype=torch.bfloat16) + fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/yl/cylmaptaywoo4jot2usnz3p56udutvmucopabqgi6krdvvi5ocao.py b/SpecForge-ext/cache/compiled_kernels/yl/cylmaptaywoo4jot2usnz3p56udutvmucopabqgi6krdvvi5ocao.py new file mode 100644 index 0000000000000000000000000000000000000000..03e51dc8f76421de75f729065de872828e21d199 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/yl/cylmaptaywoo4jot2usnz3p56udutvmucopabqgi6krdvvi5ocao.py @@ -0,0 +1,527 @@ +# AOT ID: ['8_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/wk/cwknascjc5r3t6bclny2fnjjakqaebtac6hrvubwnh2e5yl5qhk3.py +# Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] +# Source node to ATen node mapping: +# dense_mask_2 => full_default_1 +# Graph fragment: +# %full_default_1 : Tensor "i32[8, 1, 16, (((s37 + 127)//128)) + 1][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, %add_166], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:6, pin_memory: False}) +# return %index_put +triton_poi_fused_new_zeros_0 = async_compile.triton('triton_poi_fused_new_zeros_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 8192}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/rl/crljmltx2ijltda364wepx5plo3nzy5cath6ikqwxj7wafhhskhq.py +# Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] +# Source node to ATen node mapping: +# and_2 => bitwise_and_1 +# and_3 => bitwise_and_2 +# and_4 => bitwise_and_3, view_8 +# b => iota +# batched_outputs_2 => view_9 +# causal_mask => ge_1, view +# dense_mask => convert_element_type_2 +# dense_mask_1 => convert_element_type_5 +# diagnol_mask => eq_12 +# full_blocks => eq_24 +# full_blocks_1 => convert_element_type_1 +# gt => gt +# index => index +# index_1 => index_1 +# index_2 => index_2 +# lt => lt, view_1 +# lt_1 => lt_1, view_2 +# lt_3 => lt_3 +# m => iota_2 +# mask_1 => constant_pad_nd +# mask_2 => view_10 +# mask_3 => permute +# mask_block_sum => sum_1 +# n => iota_3 +# padding_mask => bitwise_and, view_3, view_4 +# padding_mask_1 => lt_2, view_6 +# partial_blocks => bitwise_and_4 +# partial_blocks_1 => convert_element_type +# remainder => remainder +# remainder_1 => remainder_1 +# result_1 => bitwise_or, full_default +# result_2 => bitwise_or_1 +# sub => sub_12, view_7 +# suffix_mask => ge_2 +# Graph fragment: +# %arg1_1 : Tensor "i64[8][1]cuda:6" = PlaceHolder[target=arg1_1] +# %sum_1 : Tensor "i64[8, 1, 16, ((s37 + 127)//128)][16*(((s37 + 127)//128)), 128*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:6" = PlaceHolder[target=sum_1] +# %full_default : Tensor "b8[8, 1, 1][1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:6, pin_memory: False}) +# %iota_2 : Tensor "i64[2048][1]cuda:6"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False}) +# %view : Tensor "i64[2048, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %iota_3 : Tensor "i64[s37][1]cuda:6"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (%arg0_1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False}) +# %ge_1 : Tensor "b8[2048, s37][Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {}) +# %iota : Tensor "i64[8][1]cuda:6"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False}) +# %index : Tensor "i64[8][1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%iota]), kwargs = {}) +# %view_1 : Tensor "i64[8, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [8, 1]), kwargs = {}) +# %lt : Tensor "b8[8, s37][Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {}) +# %view_4 : Tensor "b8[8, 1, s37][Max(1, s37), s37, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [8, 1, %arg0_1]), kwargs = {}) +# %index_1 : Tensor "i64[8][1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%iota]), kwargs = {}) +# %view_2 : Tensor "i64[8, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [8, 1]), kwargs = {}) +# %lt_1 : Tensor "b8[8, 2048][2048, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {}) +# %view_3 : Tensor "b8[8, 2048, 1][2048, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [8, 2048, 1]), kwargs = {}) +# %bitwise_and : Tensor "b8[8, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {}) +# %bitwise_and_1 : Tensor "b8[8, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_1, %bitwise_and), kwargs = {}) +# %bitwise_or : Tensor "b8[8, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {}) +# %ge_2 : Tensor "b8[s37][1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %remainder : Tensor "i64[s37][1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %index_2 : Tensor "i64[8][1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%iota]), kwargs = {}) +# %view_6 : Tensor "i64[8, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [8, 1]), kwargs = {}) +# %lt_2 : Tensor "b8[8, s37][Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {}) +# %bitwise_and_2 : Tensor "b8[8, s37][Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_2, %lt_2), kwargs = {}) +# %view_8 : Tensor "b8[8, 1, s37][Max(1, s37), s37, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [8, 1, %arg0_1]), kwargs = {}) +# %view_7 : Tensor "i64[2048, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %sub_12 : Tensor "i64[2048, s37][Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {}) +# %remainder_1 : Tensor "i64[2048, s37][Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub_12, 2048), kwargs = {}) +# %eq_12 : Tensor "b8[2048, s37][Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {}) +# %bitwise_and_3 : Tensor "b8[8, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq_12), kwargs = {}) +# %bitwise_or_1 : Tensor "b8[8, 2048, s37][2048*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {}) +# %view_9 : Tensor "b8[8, 1, 2048, s37][2048*Max(1, s37), 2048*Max(1, s37), Max(1, s37), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [8, 1, 2048, %arg0_1]), kwargs = {}) +# %constant_pad_nd : Tensor "b8[8, 1, 2048, 128*(((s37 + 127)//128))][2048*Max(1, 128*(((s37 + 127)//128))), 2048*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.constant_pad_nd.default](args = (%expand, [0, %sub_23, 0, 0], 0.0), kwargs = {}) +# %view_10 : Tensor "b8[8, 1, 16, 128, ((s37 + 127)//128), 128][2048*Max(1, 128*(((s37 + 127)//128))), 2048*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 128, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%constant_pad_nd, [8, 1, 16, 128, %floordiv_1, 128]), kwargs = {}) +# %permute : Tensor "b8[8, 1, 16, ((s37 + 127)//128), 128, 128][2048*Max(1, 128*(((s37 + 127)//128))), 2048*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), 128, Max(1, 128*(((s37 + 127)//128))), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {}) +# %sum_1 : Tensor "i64[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {}) +# %gt : Tensor "b8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {}) +# %lt_3 : Tensor "b8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %bitwise_and_4 : Tensor "b8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {}) +# %convert_element_type : Tensor "i8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {}) +# %convert_element_type_2 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {}) +# %eq_24 : Tensor "b8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %convert_element_type_1 : Tensor "i8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_24, torch.int8), kwargs = {}) +# %convert_element_type_5 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {}) +# return %sum_1,%convert_element_type_2,%convert_element_type_5 +triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1 = async_compile.triton('triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 4096, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = (xindex % ks0) + x1 = ((xindex // ks0) % 16) + x2 = xindex // ks2 + _tmp36 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x5 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_3 = (r0_index % 128) + r0_4 = r0_index // 128 + tmp0 = r0_3 + 128*x0 + tmp1 = ks1 + tmp2 = tmp0 < tmp1 + tmp3 = r0_4 + 128*x1 + tmp4 = r0_3 + 128*x0 + tmp5 = tmp3 >= tmp4 + tmp6 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0) + tmp7 = tmp4 < tmp6 + tmp8 = tmp3 < tmp6 + tmp9 = tmp7 & tmp8 + tmp10 = tmp5 & tmp9 + tmp11 = tl.full([1, 1], False, tl.int1) + tmp12 = tmp11 | tmp10 + tmp13 = tl.full([1, 1], 2048, tl.int64) + tmp14 = tmp4 >= tmp13 + tmp15 = ((r0_3 + 128*x0) % 2048) + tmp16 = tmp15 < tmp6 + tmp17 = tmp14 & tmp16 + tmp18 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp19 = (tmp18 % tmp13) + tmp20 = tl.full([1, 1], 0, tl.int32) + tmp21 = tmp19 != tmp20 + tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0 + tmp23 = (libdevice.signbit(tmp13) != 0) if (tmp13).dtype is tl.float32 else tmp13 < 0 + tmp24 = tmp22 != tmp23 + tmp25 = tmp21 & tmp24 + tmp26 = tmp19 + tmp13 + tmp27 = tl.where(tmp25, tmp26, tmp19) + tmp28 = tl.full([1, 1], 0, tl.int64) + tmp29 = tmp27 == tmp28 + tmp30 = tmp17 & tmp29 + tmp31 = tmp12 | tmp30 + tmp32 = tl.full(tmp31.shape, False, tmp31.dtype) + tmp33 = tl.where(tmp2, tmp31, tmp32) + tmp34 = tmp33.to(tl.int64) + tmp35 = tl.broadcast_to(tmp34, [XBLOCK, R0_BLOCK]) + tmp37 = _tmp36 + tmp35 + _tmp36 = tl.where(r0_mask & xmask, tmp37, _tmp36) + tmp36 = tl.sum(_tmp36, 1)[:, None] + tmp38 = tl.full([1, 1], 0, tl.int64) + tmp39 = tmp36 > tmp38 + tmp40 = tl.full([1, 1], 16384, tl.int64) + tmp41 = tmp36 < tmp40 + tmp42 = tmp39 & tmp41 + tmp43 = tmp42.to(tl.int8) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp36 == tmp40 + tmp46 = tmp45.to(tl.int8) + tmp47 = tmp46.to(tl.int32) + tl.store(out_ptr1 + (x5), tmp44, xmask) + tl.store(out_ptr2 + (x5), tmp47, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/2x/c2xunts4zntd65pabgkkxg5ylyh7sahfyogzmgljfiljdui4o365.py +# Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten._to_copy, aten.lt, aten.scalar_tensor, aten.where, aten.view, aten.index_put] +# Source node to ATen node mapping: +# arange_4 => iota_4 +# child_3 => convert_element_type_3 +# child_4 => convert_element_type_4 +# col_range => iota_5 +# dense_mask_2 => full_default_1 +# index_mask => lt_4 +# num_blocks_in_row => sum_2 +# row_indices => unsqueeze +# setitem => full_default_2, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6 +# unsqueeze_1 => unsqueeze_1 +# valid_indices => scalar_tensor, where +# Graph fragment: +# %convert_element_type_2 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*(((s37 + 127)//128)), 128*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:6" = PlaceHolder[target=convert_element_type_2] +# %sum_2 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:6" = PlaceHolder[target=sum_2] +# %getitem_1 : Tensor "i64[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 128*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6" = PlaceHolder[target=getitem_1] +# %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:6" = PlaceHolder[target=convert_element_type_3] +# %convert_element_type_4 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6" = PlaceHolder[target=convert_element_type_4] +# %index_put : Tensor "i32[8, 1, 16, (((s37 + 127)//128)) + 1][16*(((s37 + 127)//128)) + 16, 16*(((s37 + 127)//128)) + 16, (((s37 + 127)//128)) + 1, 1]cuda:6" = PlaceHolder[target=index_put] +# %full_default_1 : Tensor "i32[8, 1, 16, (((s37 + 127)//128)) + 1][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, %add_166], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:6, pin_memory: False}) +# %iota_7 : Tensor "i64[8][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False}) +# %unsqueeze_4 : Tensor "i64[8, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {}) +# %unsqueeze_5 : Tensor "i64[8, 1, 1][1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {}) +# %unsqueeze_6 : Tensor "i64[8, 1, 1, 1][1, 1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {}) +# %iota_6 : Tensor "i64[1][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:6, requires_grad: False}) +# %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {}) +# %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {}) +# %iota_4 : Tensor "i32[16][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:6, requires_grad: False}) +# %unsqueeze : Tensor "i32[16, 1][1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {}) +# %iota_5 : Tensor "i32[((s37 + 127)//128)][1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (%floordiv_1,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:6, requires_grad: False}) +# %sum_2 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {}) +# %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:6"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {}) +# %unsqueeze_1 : Tensor "i32[8, 1, 16, 1][16, 16, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {}) +# %lt_4 : Tensor "b8[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {}) +# %convert_element_type_4 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {}) +# %scalar_tensor : Tensor "i32[][]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%floordiv_1,), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:6}) +# %where : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %scalar_tensor), kwargs = {}) +# %full_default_2 : Tensor "i32[8, 1, 1, 1][1, 1, 1, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:6, pin_memory: False}) +# %index_put : Tensor "i32[8, 1, 16, (((s37 + 127)//128)) + 1][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_2), kwargs = {}) +# return %sum_2,%convert_element_type_3,%convert_element_type_4,%buf13 +triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2 = async_compile.triton('triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 128, 'r0_': 32}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr3'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2(in_ptr0, in_ptr1, out_ptr1, out_ptr2, out_ptr3, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 128 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x0 = xindex + _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp1 = tmp0.to(tl.int64) + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) + tmp4 = _tmp3 + tmp2 + _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3) + tmp3 = tl.sum(_tmp3, 1)[:, None] + tmp5 = tmp3.to(tl.int32) + tl.store(out_ptr1 + (x0), tmp5, xmask) + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + tmp6 = tl.load(in_ptr1 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8 < tmp5 + tmp10 = ks0 + tmp11 = tl.where(tmp9, tmp7, tmp10) + tmp12 = 1 + ks0 + tmp13 = tmp11 + tmp12 + tmp14 = tmp11 < 0 + tmp15 = tl.where(tmp14, tmp13, tmp11) + tl.device_assert(((0 <= tmp15) & (tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128)))) | ~(r0_mask & xmask), "index out of bounds: 0 <= tmp15 < 1 + (triton_helpers.div_floor_integer(127 + ks1, 128))") + tmp17 = tl.full([1, 1], 1, tl.int32) + tl.store(out_ptr2 + (r0_1 + x0*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp7, r0_mask & xmask) + tl.store(out_ptr3 + (tl.broadcast_to(tmp15 + x0 + ks0*x0, [XBLOCK, R0_BLOCK])), tmp17, r0_mask & xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/pf/cpfhrnxtrt5f4ofde5sv4zdr625vgwb7w2tj7tviz3j3ruexout3.py +# Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_2 +# col_indices_2 => sort_2 +# num_blocks_in_row_2 => sum_4 +# q_indices => clone_6, convert_element_type_9 +# q_num_blocks => convert_element_type_8 +# transpose => permute_1 +# Graph fragment: +# %buf13 : Tensor "i32[8, 1, 16, (((s37 + 127)//128)) + 1][16*(((s37 + 127)//128)) + 16, 16*(((s37 + 127)//128)) + 16, (((s37 + 127)//128)) + 1, 1]cuda:6" = PlaceHolder[target=buf13] +# %buf15 : Tensor "i16[8, 1, ((s37 + 127)//128), 16][16*(((s37 + 127)//128)), 128*(((s37 + 127)//128)), 16, 1]cuda:6" = PlaceHolder[target=buf15] +# %sum_4 : Tensor "i64[8, 1, ((s37 + 127)//128)][((s37 + 127)//128), 8*(((s37 + 127)//128)), 1]cuda:6" = PlaceHolder[target=sum_4] +# %slice_2 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, (((s37 + 127)//128)) + 1), 16*Max(1, (((s37 + 127)//128)) + 1), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, %floordiv_1), kwargs = {}) +# %clone_4 : Tensor "i32[8, 1, 16, ((s37 + 127)//128)][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_2,), kwargs = {memory_format: torch.contiguous_format}) +# %permute_1 : Tensor "i32[8, 1, ((s37 + 127)//128), 16][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:6"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {}) +# %sort_2 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%permute_1,), kwargs = {stable: True, descending: True}) +# %convert_element_type_9 : Tensor "i32[8, 1, ((s37 + 127)//128), 16][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {}) +# %clone_6 : Tensor "i32[8, 1, ((s37 + 127)//128), 16][16*Max(1, ((s37 + 127)//128)), 16*Max(1, ((s37 + 127)//128)), 16, 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format}) +# %sum_4 : Tensor "i64[8, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {}) +# %convert_element_type_8 : Tensor "i32[8, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:6"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {}) +# return %buf15,%sum_4,%clone_6,%convert_element_type_8 +triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3 = async_compile.triton('triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 256, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr): + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % ks0) + x1 = xindex // ks0 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (r0_2 + x0 + 16*x1 + ks0*r0_2 + 16*ks0*x1), xmask, eviction_policy='evict_last', other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x0 + 16*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp13, xmask) + tl.store(out_ptr3 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp14, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, arg1_1 = args + args.clear() + s37 = arg0_1 + assert_size_stride(arg1_1, (8, ), (1, )) + with torch.cuda._DeviceGuard(6): + torch.cuda.set_device(6) + buf12 = empty_strided_cuda((8, 1, 16, 1 + ((127 + s37) // 128)), (16 + 16*((127 + s37) // 128), 16 + 16*((127 + s37) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 128 + 128*((127 + s37) // 128) + stream6 = get_raw_stream(6) + triton_poi_fused_new_zeros_0.run(buf12, triton_poi_fused_new_zeros_0_xnumel, stream=stream6) + buf19 = empty_strided_cuda((8, 1, 16, 1 + ((127 + s37) // 128)), (16 + 16*((127 + s37) // 128), 16 + 16*((127 + s37) // 128), 1 + ((127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] + triton_poi_fused_new_zeros_0_xnumel = 128 + 128*((127 + s37) // 128) + stream6 = get_raw_stream(6) + triton_poi_fused_new_zeros_0.run(buf19, triton_poi_fused_new_zeros_0_xnumel, stream=stream6) + ps0 = (127 + s37) // 128 + ps1 = 16*((127 + s37) // 128) + buf1 = empty_strided_cuda((8, 1, 16, (127 + s37) // 128), (16*((127 + s37) // 128), 128*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + buf5 = empty_strided_cuda((8, 1, 16, (127 + s37) // 128), (16*((127 + s37) // 128), 128*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32) + # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy] + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel = 128*((127 + s37) // 128) + stream6 = get_raw_stream(6) + triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1.run(arg1_1, buf1, buf5, ps0, s37, ps1, triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel, 16384, stream=stream6) + del arg1_1 + # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort] + buf2 = torch.ops.aten.sort.stable(buf1, stable=True, dim=3, descending=True) + buf4 = buf2[1] + assert_size_stride(buf4, (8, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 128*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf4, 16, 'torch.ops.aten.sort.stable') + del buf2 + buf10 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + buf11 = empty_strided_cuda((8, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten._to_copy, aten.lt, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel = (127 + s37) // 128 + stream6 = get_raw_stream(6) + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2.run(buf1, buf4, buf10, buf11, buf12, ps0, s37, 128, triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel, stream=stream6) + del buf1 + del buf4 + buf26 = empty_strided_cuda((8, 1, (127 + s37) // 128, 16), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), 16, 1), torch.int32) + buf28 = empty_strided_cuda((8, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel = 8*((127 + s37) // 128) + stream6 = get_raw_stream(6) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf12, buf26, buf28, ps0, triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel, 16, stream=stream6) + del buf12 + # Topologically Sorted Source Nodes: [full_blocks, full_blocks_1, dense_mask_1, col_indices_1], Original ATen: [aten.eq, aten._to_copy, aten.sort] + buf6 = torch.ops.aten.sort.stable(buf5, stable=True, dim=3, descending=True) + buf8 = buf6[1] + assert_size_stride(buf8, (8, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 128*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable') + assert_alignment(buf8, 16, 'torch.ops.aten.sort.stable') + del buf6 + buf17 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + buf18 = empty_strided_cuda((8, 1, 16, (127 + s37) // 128), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten._to_copy, aten.lt, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel = (127 + s37) // 128 + stream6 = get_raw_stream(6) + triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2.run(buf5, buf8, buf17, buf18, buf19, ps0, s37, 128, triton_red_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_sum_unsqueeze_view_where_2_r0_numel, stream=stream6) + del buf5 + del buf8 + buf23 = empty_strided_cuda((8, 1, (127 + s37) // 128, 16), (16*max(1, (127 + s37) // 128), 16*max(1, (127 + s37) // 128), 16, 1), torch.int32) + buf25 = empty_strided_cuda((8, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3, full_q_indices, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel = 8*((127 + s37) // 128) + stream6 = get_raw_stream(6) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf19, buf23, buf25, ps0, triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3_xnumel, 16, stream=stream6) + del buf19 + return (buf23, buf25, buf26, buf28, buf18, buf17, buf11, buf10, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = 4096 + arg1_1 = rand_strided((8, ), (1, ), device='cuda:6', dtype=torch.int64) + fn = lambda: call([arg0_1, arg1_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/cache/compiled_kernels/yw/cywfovnmctecguk3uvo7jp2wmfsaqvaugskuxva4igyc5r72lgx7.py b/SpecForge-ext/cache/compiled_kernels/yw/cywfovnmctecguk3uvo7jp2wmfsaqvaugskuxva4igyc5r72lgx7.py new file mode 100644 index 0000000000000000000000000000000000000000..5afc7621e06adb94836ebce42e79ae8e4fd589a5 --- /dev/null +++ b/SpecForge-ext/cache/compiled_kernels/yw/cywfovnmctecguk3uvo7jp2wmfsaqvaugskuxva4igyc5r72lgx7.py @@ -0,0 +1,543 @@ +# AOT ID: ['5_inference'] +from ctypes import c_void_p, c_long, c_int +import torch +import math +import random +import os +import tempfile +from math import inf, nan +from cmath import nanj +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import start_graph, end_graph +from torch._C import _cuda_getCurrentRawStream as get_raw_stream + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +assert_alignment = torch._C._dynamo.guards.assert_alignment +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor +alloc_from_pool = torch.ops.inductor._alloc_from_pool +async_compile = AsyncCompile() +empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/jq/cjq5hv4rnv3k5awzzq6t2f4dupyimqnzm5i36pci6ox5vpquu66l.py +# Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_2, mask_3, mask_block_sum], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.permute, aten.sum] +# Source node to ATen node mapping: +# and_2 => bitwise_and_1 +# and_3 => bitwise_and_2 +# and_4 => bitwise_and_3, view_8 +# b => iota +# batched_outputs_2 => view_9 +# causal_mask => ge, view +# diagnol_mask => eq +# index => index +# index_1 => index_1 +# index_2 => index_2 +# lt => lt, view_1 +# lt_1 => lt_1, view_2 +# m => iota_2 +# mask_2 => view_10 +# mask_3 => permute +# mask_block_sum => sum_1 +# n => iota_3 +# padding_mask => bitwise_and, view_3, view_4 +# padding_mask_1 => lt_2, view_6 +# remainder => remainder +# remainder_1 => remainder_1 +# result_1 => bitwise_or, full_default +# result_2 => bitwise_or_1 +# sub => sub, view_7 +# suffix_mask => ge_1 +# Graph fragment: +# %arg0_1 : Tensor "i64[8][1]cuda:5" = PlaceHolder[target=arg0_1] +# %full_default : Tensor "b8[8, 1, 1][1, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:5, pin_memory: False}) +# %iota_2 : Tensor "i64[2048][1]cuda:5"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:5, requires_grad: False}) +# %view : Tensor "i64[2048, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %iota_3 : Tensor "i64[2048][1]cuda:5"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:5, requires_grad: False}) +# %ge : Tensor "b8[2048, 2048][2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {}) +# %iota : Tensor "i64[8][1]cuda:5"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:5, requires_grad: False}) +# %index : Tensor "i64[8][1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {}) +# %view_1 : Tensor "i64[8, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [8, 1]), kwargs = {}) +# %lt : Tensor "b8[8, 2048][2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {}) +# %view_4 : Tensor "b8[8, 1, 2048][2048, 2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [8, 1, 2048]), kwargs = {}) +# %index_1 : Tensor "i64[8][1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {}) +# %view_2 : Tensor "i64[8, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [8, 1]), kwargs = {}) +# %lt_1 : Tensor "b8[8, 2048][2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {}) +# %view_3 : Tensor "b8[8, 2048, 1][2048, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [8, 2048, 1]), kwargs = {}) +# %bitwise_and : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {}) +# %bitwise_and_1 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge, %bitwise_and), kwargs = {}) +# %bitwise_or : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {}) +# %ge_1 : Tensor "b8[2048][1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %remainder : Tensor "i64[2048][1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, 2048), kwargs = {}) +# %index_2 : Tensor "i64[8][1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {}) +# %view_6 : Tensor "i64[8, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [8, 1]), kwargs = {}) +# %lt_2 : Tensor "b8[8, 2048][2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {}) +# %bitwise_and_2 : Tensor "b8[8, 2048][2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_1, %lt_2), kwargs = {}) +# %view_8 : Tensor "b8[8, 1, 2048][2048, 2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [8, 1, 2048]), kwargs = {}) +# %view_7 : Tensor "i64[2048, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {}) +# %sub : Tensor "i64[2048, 2048][2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {}) +# %remainder_1 : Tensor "i64[2048, 2048][2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub, 2048), kwargs = {}) +# %eq : Tensor "b8[2048, 2048][2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {}) +# %bitwise_and_3 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq), kwargs = {}) +# %bitwise_or_1 : Tensor "b8[8, 2048, 2048][4194304, 2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {}) +# %view_9 : Tensor "b8[8, 1, 2048, 2048][4194304, 4194304, 2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [8, 1, 2048, 2048]), kwargs = {}) +# %view_10 : Tensor "b8[8, 1, 16, 128, 16, 128][4194304, 4194304, 262144, 2048, 128, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand, [8, 1, 16, 128, 16, 128]), kwargs = {}) +# %permute : Tensor "b8[8, 1, 16, 16, 128, 128][4194304, 4194304, 262144, 128, 2048, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {}) +# %sum_1 : Tensor "i64[8, 1, 16, 16][256, 256, 16, 1]cuda:5"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {}) +# return %sum_1 +triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0 = async_compile.triton('triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.reduction( + size_hints={'x': 2048, 'r0_': 16384}, + reduction_hint=ReductionHint.INNER, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 32768, 'r0_': 0}} +) +@triton.jit +def triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + xnumel = 2048 + r0_numel = 16384 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_base = tl.arange(0, R0_BLOCK)[None, :] + rbase = r0_base + x1 = ((xindex // 16) % 16) + x0 = (xindex % 16) + x2 = xindex // 256 + tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last') + _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64) + x6 = xindex + for r0_offset in range(0, r0_numel, R0_BLOCK): + r0_index = r0_offset + r0_base + r0_mask = r0_index < r0_numel + roffset = r0_offset + rindex = r0_index + r0_4 = r0_index // 128 + r0_3 = (r0_index % 128) + tmp0 = r0_4 + 128*x1 + tmp1 = r0_3 + 128*x0 + tmp2 = tmp0 >= tmp1 + tmp4 = tmp1 < tmp3 + tmp5 = tmp0 < tmp3 + tmp6 = tmp4 & tmp5 + tmp7 = tmp2 & tmp6 + tmp8 = tl.full([1, 1], False, tl.int1) + tmp9 = tmp8 | tmp7 + tmp10 = tl.full([1, 1], 2048, tl.int64) + tmp11 = tmp1 >= tmp10 + tmp12 = tmp11 & tmp4 + tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0 + tmp14 = (tmp13 % tmp10) + tmp15 = tl.full([1, 1], 0, tl.int32) + tmp16 = tmp14 != tmp15 + tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0 + tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0 + tmp19 = tmp17 != tmp18 + tmp20 = tmp16 & tmp19 + tmp21 = tmp14 + tmp10 + tmp22 = tl.where(tmp20, tmp21, tmp14) + tmp23 = tl.full([1, 1], 0, tl.int64) + tmp24 = tmp22 == tmp23 + tmp25 = tmp12 & tmp24 + tmp26 = tmp9 | tmp25 + tmp27 = tmp26.to(tl.int64) + tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK]) + tmp30 = _tmp29 + tmp28 + _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29) + tmp29 = tl.sum(_tmp29, 1)[:, None] + tl.store(out_ptr0 + (x6), tmp29, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/fr/cfrczc5cr5uotil5g5x435datuzfao56zz4vsxlh33jteluxhhme.py +# Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] +# Source node to ATen node mapping: +# dense_mask_4 => full_default_4 +# Graph fragment: +# %full_default_4 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:5, pin_memory: False}) +# return %index_put_1 +triton_poi_fused_new_zeros_1 = async_compile.triton('triton_poi_fused_new_zeros_1', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.pointwise( + size_hints={'x': 4096}, + filename=__file__, + triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 17408}}, + min_elem_per_thread=0 +) +@triton.jit +def triton_poi_fused_new_zeros_1(out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 2176 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.full([1], 0, tl.int32) + tl.store(out_ptr0 + (x0), tmp0, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ji/cjij26clq6lcv6c2plwk2zxldtphmt23swyyv2i3vq3ujc4fkjp5.py +# Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices, full_blocks, full_blocks_1, dense_mask_1, col_indices_1, dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices, dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort, aten.eq, aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten.scalar_tensor, aten.where, aten.view, aten.index_put] +# Source node to ATen node mapping: +# arange_4 => iota_4 +# arange_6 => iota_8 +# child_3 => convert_element_type_3 +# child_4 => convert_element_type_4 +# child_7 => convert_element_type_6 +# child_8 => convert_element_type_7 +# col_indices => sort +# col_indices_1 => sort_1 +# col_range => iota_5 +# col_range_1 => iota_9 +# dense_mask => convert_element_type_2 +# dense_mask_1 => convert_element_type_5 +# dense_mask_2 => full_default_1 +# dense_mask_4 => full_default_4 +# full_blocks => eq_1 +# full_blocks_1 => convert_element_type_1 +# gt => gt +# index_mask => lt_4 +# index_mask_1 => lt_5 +# lt_3 => lt_3 +# num_blocks_in_row => sum_2 +# num_blocks_in_row_1 => sum_3 +# partial_blocks => bitwise_and_4 +# partial_blocks_1 => convert_element_type +# row_indices => unsqueeze +# row_indices_1 => unsqueeze_7 +# setitem => full_default_3, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6 +# setitem_1 => full_default_6, index_put_1, iota_10, iota_11, unsqueeze_10, unsqueeze_11, unsqueeze_12, unsqueeze_13, unsqueeze_9 +# unsqueeze_1 => unsqueeze_1 +# unsqueeze_3 => unsqueeze_8 +# valid_indices => full_default_2, where +# valid_indices_1 => full_default_5, where_1 +# Graph fragment: +# %sum_1 : Tensor "i64[8, 1, 16, 16][256, 2048, 16, 1]cuda:5" = PlaceHolder[target=sum_1] +# %sum_2 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:5" = PlaceHolder[target=sum_2] +# %sum_3 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:5" = PlaceHolder[target=sum_3] +# %buf2 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:5" = PlaceHolder[target=buf2] +# %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:5" = PlaceHolder[target=convert_element_type_3] +# %convert_element_type_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:5" = PlaceHolder[target=convert_element_type_4] +# %index_put : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:5" = PlaceHolder[target=index_put] +# %buf4 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:5" = PlaceHolder[target=buf4] +# %convert_element_type_6 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:5" = PlaceHolder[target=convert_element_type_6] +# %convert_element_type_7 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:5" = PlaceHolder[target=convert_element_type_7] +# %index_put_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:5" = PlaceHolder[target=index_put_1] +# %gt : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {}) +# %lt_3 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %bitwise_and_4 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {}) +# %convert_element_type : Tensor "i8[8, 1, 16, 16][256, 256, 16, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {}) +# %convert_element_type_2 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {}) +# %sort : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%convert_element_type_2,), kwargs = {stable: True, descending: True}) +# %eq_1 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {}) +# %convert_element_type_1 : Tensor "i8[8, 1, 16, 16][256, 256, 16, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_1, torch.int8), kwargs = {}) +# %convert_element_type_5 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {}) +# %sort_1 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%convert_element_type_5,), kwargs = {stable: True, descending: True}) +# %full_default_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:5, pin_memory: False}) +# %iota_7 : Tensor "i64[8][1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:5, requires_grad: False}) +# %unsqueeze_4 : Tensor "i64[8, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {}) +# %unsqueeze_5 : Tensor "i64[8, 1, 1][1, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {}) +# %unsqueeze_6 : Tensor "i64[8, 1, 1, 1][1, 1, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {}) +# %iota_6 : Tensor "i64[1][1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:5, requires_grad: False}) +# %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {}) +# %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {}) +# %iota_4 : Tensor "i32[16][1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:5, requires_grad: False}) +# %unsqueeze : Tensor "i32[16, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {}) +# %iota_5 : Tensor "i32[16][1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:5, requires_grad: False}) +# %sum_2 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {}) +# %convert_element_type_3 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {}) +# %unsqueeze_1 : Tensor "i32[8, 1, 16, 1][16, 16, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {}) +# %lt_4 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {}) +# %convert_element_type_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {}) +# %full_default_2 : Tensor "i32[][]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 16), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:5, pin_memory: False}) +# %where : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %full_default_2), kwargs = {}) +# %full_default_3 : Tensor "i32[8, 1, 1, 1][1, 1, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:5, pin_memory: False}) +# %index_put : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_3), kwargs = {}) +# %full_default_4 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:5, pin_memory: False}) +# %iota_11 : Tensor "i64[8][1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:5, requires_grad: False}) +# %unsqueeze_11 : Tensor "i64[8, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_11, -1), kwargs = {}) +# %unsqueeze_12 : Tensor "i64[8, 1, 1][1, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_11, -1), kwargs = {}) +# %unsqueeze_13 : Tensor "i64[8, 1, 1, 1][1, 1, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_12, -1), kwargs = {}) +# %iota_10 : Tensor "i64[1][1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:5, requires_grad: False}) +# %unsqueeze_9 : Tensor "i64[1, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_10, -1), kwargs = {}) +# %unsqueeze_10 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_9, -1), kwargs = {}) +# %iota_8 : Tensor "i32[16][1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:5, requires_grad: False}) +# %unsqueeze_7 : Tensor "i32[16, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_8, -1), kwargs = {}) +# %iota_9 : Tensor "i32[16][1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:5, requires_grad: False}) +# %sum_3 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_5, [-1]), kwargs = {}) +# %convert_element_type_6 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_3, torch.int32), kwargs = {}) +# %unsqueeze_8 : Tensor "i32[8, 1, 16, 1][16, 16, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_6, 3), kwargs = {}) +# %lt_5 : Tensor "b8[8, 1, 16, 16][256, 256, 16, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_9, %unsqueeze_8), kwargs = {}) +# %convert_element_type_7 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:5"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_3, torch.int32), kwargs = {}) +# %full_default_5 : Tensor "i32[][]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 16), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:5, pin_memory: False}) +# %where_1 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_5, %convert_element_type_7, %full_default_5), kwargs = {}) +# %full_default_6 : Tensor "i32[8, 1, 1, 1][1, 1, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:5, pin_memory: False}) +# %index_put_1 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_4, [%unsqueeze_13, %unsqueeze_10, %unsqueeze_7, %where_1], %full_default_6), kwargs = {}) +# return %buf2,%buf4,%sum_2,%sum_3,%convert_element_type_3,%convert_element_type_6,%convert_element_type_4,%buf9,%convert_element_type_7,%buf16 +triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2 = async_compile.triton('triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False} +) +@triton.jit +def triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 128 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_1 = r0_index + x0 = xindex + tmp0 = tl.load(in_ptr0 + (r0_1 + 16*x0), xmask, other=0.0) + tmp1 = tl.full([1, 1], 0, tl.int64) + tmp2 = tmp0 > tmp1 + tmp3 = tl.full([1, 1], 16384, tl.int64) + tmp4 = tmp0 < tmp3 + tmp5 = tmp2 & tmp4 + tmp6 = tmp5.to(tl.int8) + tmp7 = tmp6.to(tl.int32) + tmp8 = r0_1 + tmp9 = tmp8.to(tl.int16) + tmp10 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp11 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK]) + tmp12, tmp13, = triton_helpers.sort_with_index(tmp10, tmp11, None, 1, stable=True, descending=True) + tmp14 = tmp0 == tmp3 + tmp15 = tmp14.to(tl.int8) + tmp16 = tmp15.to(tl.int32) + tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK]) + tmp18, tmp19, = triton_helpers.sort_with_index(tmp17, tmp11, None, 1, stable=True, descending=True) + tmp20 = tmp7.to(tl.int64) + tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK]) + tmp23 = tl.where(xmask, tmp21, 0) + tmp24 = tl.sum(tmp23, 1)[:, None].to(tl.int64) + tmp25 = tmp16.to(tl.int64) + tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK]) + tmp28 = tl.where(xmask, tmp26, 0) + tmp29 = tl.sum(tmp28, 1)[:, None].to(tl.int64) + tmp30 = tmp24.to(tl.int32) + tmp31 = tmp29.to(tl.int32) + tmp32 = tmp13.to(tl.int64) + tmp33 = tmp32.to(tl.int32) + tmp34 = tmp8 < tmp30 + tmp35 = tl.full([1, 1], 16, tl.int32) + tmp36 = tl.where(tmp34, tmp33, tmp35) + tmp37 = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32) + tmp38 = tmp36 + tmp37 + tmp39 = tmp36 < 0 + tmp40 = tl.where(tmp39, tmp38, tmp36) + tl.device_assert(((0 <= tmp40) & (tmp40 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp40 < 17") + tmp42 = tl.full([1, 1], 1, tl.int32) + tmp43 = tmp19.to(tl.int64) + tmp44 = tmp43.to(tl.int32) + tmp45 = tmp8 < tmp31 + tmp46 = tl.where(tmp45, tmp44, tmp35) + tmp47 = tmp46 + tmp37 + tmp48 = tmp46 < 0 + tmp49 = tl.where(tmp48, tmp47, tmp46) + tl.device_assert(((0 <= tmp49) & (tmp49 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 17") + tl.store(out_ptr4 + (x0), tmp30, xmask) + tl.store(out_ptr5 + (x0), tmp31, xmask) + tl.store(out_ptr6 + (r0_1 + 16*x0), tmp33, xmask) + tl.store(out_ptr7 + (tl.broadcast_to(tmp40 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) + tl.store(out_ptr8 + (r0_1 + 16*x0), tmp44, xmask) + tl.store(out_ptr9 + (tl.broadcast_to(tmp49 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask) +''', device_str='cuda') + + +# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/qe/cqeqbkboqdi3uxaw5a5tmxlncw3qggxtaklqqgfqju7a73hndnau.py +# Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] +# Source node to ATen node mapping: +# batched_outputs_3 => clone_4, slice_2 +# col_indices_2 => sort_2 +# num_blocks_in_row_2 => sum_4 +# q_indices => clone_6, convert_element_type_9 +# q_num_blocks => convert_element_type_8 +# transpose => permute_1 +# Graph fragment: +# %buf9 : Tensor "i32[8, 1, 16, 17][272, 272, 17, 1]cuda:5" = PlaceHolder[target=buf9] +# %buf11 : Tensor "i16[8, 1, 16, 16][256, 2048, 16, 1]cuda:5" = PlaceHolder[target=buf11] +# %sum_4 : Tensor "i64[8, 1, 16][16, 128, 1]cuda:5" = PlaceHolder[target=sum_4] +# %slice_2 : Tensor "i32[8, 1, 16, 16][272, 272, 17, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, 16), kwargs = {}) +# %clone_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_2,), kwargs = {memory_format: torch.contiguous_format}) +# %permute_1 : Tensor "i32[8, 1, 16, 16][256, 256, 1, 16]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {}) +# %sort_2 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%permute_1,), kwargs = {stable: True, descending: True}) +# %convert_element_type_9 : Tensor "i32[8, 1, 16, 16][256, 256, 1, 16]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {}) +# %clone_6 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format}) +# %sum_4 : Tensor "i64[8, 1, 16][16, 16, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {}) +# %convert_element_type_8 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {}) +# return %buf11,%sum_4,%clone_6,%convert_element_type_8 +triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3 = async_compile.triton('triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', ''' +import triton +import triton.language as tl + +from torch._inductor.runtime import triton_helpers, triton_heuristics +from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math +from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties +triton_helpers.set_driver_to_gpu() + +@triton_heuristics.persistent_reduction( + size_hints={'x': 128, 'r0_': 16}, + reduction_hint=ReductionHint.DEFAULT, + filename=__file__, + triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, + inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 1024, 'r0_': 16384}} +) +@triton.jit +def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr): + xnumel = 128 + r0_numel = 16 + R0_BLOCK: tl.constexpr = 16 + rnumel = r0_numel + RBLOCK: tl.constexpr = R0_BLOCK + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + r0_index = tl.arange(0, R0_BLOCK)[None, :] + r0_offset = 0 + r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) + roffset = r0_offset + rindex = r0_index + r0_2 = r0_index + x0 = (xindex % 16) + x1 = xindex // 16 + x3 = xindex + tmp0 = tl.load(in_ptr0 + (x0 + 17*r0_2 + 272*x1), xmask, other=0.0) + tmp1 = r0_2 + tmp2 = tmp1.to(tl.int16) + tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) + tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) + tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True) + tmp7 = tmp0.to(tl.int64) + tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]) + tmp10 = tl.where(xmask, tmp8, 0) + tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64) + tmp12 = tmp6.to(tl.int64) + tmp13 = tmp12.to(tl.int32) + tmp14 = tmp11.to(tl.int32) + tl.store(out_ptr2 + (r0_2 + 16*x3), tmp13, xmask) + tl.store(out_ptr3 + (x3), tmp14, xmask) +''', device_str='cuda') + + +async_compile.wait(globals()) +del async_compile + +class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + arg0_1, = args + args.clear() + assert_size_stride(arg0_1, (8, ), (1, )) + with torch.cuda._DeviceGuard(5): + torch.cuda.set_device(5) + buf0 = empty_strided_cuda((8, 1, 16, 16), (256, 2048, 16, 1), torch.int64) + # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_2, mask_3, mask_block_sum], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.permute, aten.sum] + stream5 = get_raw_stream(5) + triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0.run(arg0_1, buf0, 2048, 16384, stream=stream5) + del arg0_1 + buf15 = empty_strided_cuda((8, 1, 16, 17), (272, 272, 17, 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros] + stream5 = get_raw_stream(5) + triton_poi_fused_new_zeros_1.run(buf15, 2176, stream=stream5) + buf8 = empty_strided_cuda((8, 1, 16, 17), (272, 272, 17, 1), torch.int32) + # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros] + stream5 = get_raw_stream(5) + triton_poi_fused_new_zeros_1.run(buf8, 2176, stream=stream5) + buf6 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + buf13 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + buf7 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf14 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices, full_blocks, full_blocks_1, dense_mask_1, col_indices_1, dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices, dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort, aten.eq, aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten.scalar_tensor, aten.where, aten.view, aten.index_put] + stream5 = get_raw_stream(5) + triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.run(buf0, buf6, buf13, buf7, buf8, buf14, buf15, 128, 16, stream=stream5) + del buf0 + buf22 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf24 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + stream5 = get_raw_stream(5) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf8, buf22, buf24, 128, 16, stream=stream5) + del buf8 + buf19 = empty_strided_cuda((8, 1, 16, 16), (256, 256, 16, 1), torch.int32) + buf21 = empty_strided_cuda((8, 1, 16), (16, 16, 1), torch.int32) + # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3, full_q_indices, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum] + stream5 = get_raw_stream(5) + triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf15, buf19, buf21, 128, 16, stream=stream5) + del buf15 + return (buf19, buf21, buf22, buf24, buf14, buf13, buf7, buf6, ) + +runner = Runner(partitions=[]) +call = runner.call +recursively_apply_fns = runner.recursively_apply_fns + + +def benchmark_compiled_module(times=10, repeat=10): + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + arg0_1 = rand_strided((8, ), (1, ), device='cuda:5', dtype=torch.int64) + fn = lambda: call([arg0_1]) + return print_performance(fn, times=times, repeat=repeat) + + +if __name__ == "__main__": + from torch._inductor.wrapper_benchmark import compiled_module_main + compiled_module_main('None', benchmark_compiled_module) diff --git a/SpecForge-ext/specforge/__pycache__/utils.cpython-311.pyc b/SpecForge-ext/specforge/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fbc4d7b0e20568704ff817661a6336eaa5e6e75 Binary files /dev/null and b/SpecForge-ext/specforge/__pycache__/utils.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/core/__pycache__/__init__.cpython-311.pyc b/SpecForge-ext/specforge/core/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffb5dab181b33d147a99b81a2ef500a6a50046db Binary files /dev/null and b/SpecForge-ext/specforge/core/__pycache__/__init__.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/core/__pycache__/dflash.cpython-311.pyc b/SpecForge-ext/specforge/core/__pycache__/dflash.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..abd10b32edfe236fdaa151945d4c4f05844aaf39 Binary files /dev/null and b/SpecForge-ext/specforge/core/__pycache__/dflash.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/core/__pycache__/eagle3.cpython-311.pyc b/SpecForge-ext/specforge/core/__pycache__/eagle3.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4be9ee966f870779ff576ff8f691dc86c6ee9e7 Binary files /dev/null and b/SpecForge-ext/specforge/core/__pycache__/eagle3.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/core/__pycache__/loss.cpython-311.pyc b/SpecForge-ext/specforge/core/__pycache__/loss.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36362b76398a552abdcbc301be88598bb85f0a7e Binary files /dev/null and b/SpecForge-ext/specforge/core/__pycache__/loss.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/core/loss.py b/SpecForge-ext/specforge/core/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..30e7fba7dd49cce3706ec9c740d9e597f0eade47 --- /dev/null +++ b/SpecForge-ext/specforge/core/loss.py @@ -0,0 +1,244 @@ +""" +This file incorporates code from Unsloth licensed under the Apache License, Version 2.0. +See the original Unsloth repository at https://github.com/unslothai/unsloth. +The idea of in-place backward pass is from Liger-Kernel. +See the original Liger-Kernel repository at https://github.com/linkedin/Liger-Kernel. +""" + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +# Reference implementation +@torch.compile(dynamic=None) +def _compute_loss(logits, target_p, position_mask): + logits = logits.float() + out_logp = nn.LogSoftmax(dim=2)(logits) + plogp = target_p * out_logp + loss = -torch.sum(position_mask * plogp, 2).mean() + return loss + + +def _calculate_settings(n): + # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43 + + MAX_FUSED_SIZE = 131072 + BLOCK_SIZE = triton.next_power_of_2(n) + if BLOCK_SIZE > MAX_FUSED_SIZE: + raise RuntimeError( + f"Cannot launch Triton kernel since n = {n} exceeds the recommended Triton blocksize = {MAX_FUSED_SIZE}." + ) + + num_warps = 4 + if BLOCK_SIZE >= 32768: + num_warps = 32 + elif BLOCK_SIZE >= 8192: + num_warps = 16 + elif BLOCK_SIZE >= 2048: + num_warps = 8 + + # AMD GPU (ROCm) + if hasattr(torch.version, "hip") and torch.version.hip is not None: + num_warps //= 2 + + return BLOCK_SIZE, num_warps + + +@triton.jit +def log_softmax_forward_kernel( + logits_ptr, + logits_stride, + target_ptr, + target_stride, + position_mask_ptr, + position_mask_stride, + loss_ptr, + loss_stride, + m_ptr, + d_ptr, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + program_id = tl.program_id(0).to(tl.int64) + logits_ptr += program_id * logits_stride + target_ptr += program_id * target_stride + position_mask_ptr += program_id * position_mask_stride + position_mask = tl.load(position_mask_ptr) + if position_mask == 0: + return + + m = float("-inf") + d = 0.0 + + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_cols + logits_block = tl.load( + logits_ptr + offsets, mask=mask, other=float("-inf") + ).cast(tl.float32) + block_max = tl.max(tl.where(mask, logits_block, float("-inf"))) + m_new = tl.maximum(m, block_max) + d = d * tl.exp(m - m_new) + tl.sum( + tl.where(mask, tl.exp(logits_block - m_new), 0.0) + ) + m = m_new + + loss = 0.0 + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_cols + logits_block = tl.load(logits_ptr + offsets, mask=mask, other=0.0).cast( + tl.float32 + ) + target_block = tl.load(target_ptr + offsets, mask=mask, other=0.0).cast( + tl.float32 + ) + # log-softmax: log(exp(x - max) / sum) = (x - max) - log(sum) + normalized_logits = logits_block - m + log_normalizer = tl.log(d) + log_softmax_logits = normalized_logits - log_normalizer + weighted_log_prob = target_block * log_softmax_logits + loss += tl.sum(tl.where(mask, weighted_log_prob, 0.0)) + + loss_ptr += program_id * loss_stride + m_ptr += program_id + d_ptr += program_id + tl.store(loss_ptr, -loss) + tl.store(m_ptr, m.to(tl.float32)) + tl.store(d_ptr, d.to(tl.float32)) + + +@triton.jit +def log_softmax_backward_kernel( + logits_ptr, + logits_stride, + target_ptr, + target_stride, + position_mask_ptr, + grad_output_ptr, + scaling_factor, + m_ptr, + d_ptr, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + program_id = tl.program_id(0).to(tl.int64) + logits_ptr += program_id * logits_stride + target_ptr += program_id * target_stride + position_mask_ptr += program_id + + position_mask = tl.load(position_mask_ptr) + if position_mask == 0: + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_cols + tl.store(logits_ptr + offsets, 0.0, mask=mask) + return + + m_ptr += program_id + d_ptr += program_id + m = tl.load(m_ptr).to(tl.float32) + d = tl.load(d_ptr).to(tl.float32) + grad_output = tl.load(grad_output_ptr).to(tl.float32) + grad_output = grad_output * scaling_factor + + # First pass: compute sum of (target * grad_output) + target_grad_sum = 0.0 + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_cols + target_block = tl.load(target_ptr + offsets, mask=mask, other=0.0).cast( + tl.float32 + ) + target_grad_sum += tl.sum(tl.where(mask, target_block * grad_output, 0.0)) + + # Second pass: compute log-softmax gradients + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_cols + logits_block = tl.load(logits_ptr + offsets, mask=mask, other=0.0).cast( + tl.float32 + ) + target_block = tl.load(target_ptr + offsets, mask=mask, other=0.0).cast( + tl.float32 + ) + softmax_prob = tl.exp(logits_block - m) / d + normalized_grad = softmax_prob * target_grad_sum + grad_block = -(target_block * grad_output - normalized_grad) + tl.store(logits_ptr + offsets, grad_block.to(tl.float32), mask=mask) + + +class LogSoftmaxLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, logits, target, position_mask): + B, T, V = logits.shape + loss = torch.zeros((B * T, 1), device=logits.device) + logits_flat = logits.contiguous().view(B * T, V) + target_flat = target.contiguous().view(B * T, V) + position_mask_flat = position_mask.contiguous().view(B * T, 1).bool() + grid = (B * T,) + m = torch.zeros((B * T,), device=logits.device, dtype=torch.float32) + d = torch.zeros((B * T,), device=logits.device, dtype=torch.float32) + BLOCK_SIZE, num_warps = _calculate_settings(V) + log_softmax_forward_kernel[grid]( + logits_flat, + logits_flat.stride(0), + target_flat, + target_flat.stride(0), + position_mask_flat, + position_mask_flat.stride(0), + loss, + loss.stride(0), + m, + d, + V, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + ctx.save_for_backward(logits.detach(), target, position_mask, m, d) + return loss.squeeze(1).mean() + + @staticmethod + def backward(ctx, grad_output): + logits, target, position_mask, m, d = ctx.saved_tensors + B, T, V = logits.shape + scaling_factor = 1.0 / (B * T) + logits = logits.contiguous().view(B * T, V) + target = target.contiguous().view(B * T, V) + position_mask = position_mask.contiguous().view(B * T, 1).bool() + grid = (B * T,) + BLOCK_SIZE, num_warps = _calculate_settings(V) + log_softmax_backward_kernel[grid]( + logits, + logits.stride(0), + target, + target.stride(0), + position_mask, + grad_output, + scaling_factor, + m, + d, + V, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + logits = logits.view(B, T, V) + return logits, None, None, None, None + + +if __name__ == "__main__": + device = "cuda" + B, T, V = 1, 1024, 16000 + logits = torch.randn(B, T, V, device=device, requires_grad=True) + logits2 = logits.clone().detach().requires_grad_(True) + target = torch.randn(B, T, V, device=device) + position_mask = torch.randint(0, 2, (B, T, 1), dtype=torch.bool, device=device) + position_mask = torch.ones((B, T, 1), dtype=torch.bool, device=device) + output1 = LogSoftmaxLoss.apply(logits, target, position_mask) + output2 = _compute_loss(logits2, target, position_mask) + torch.testing.assert_close(output1, output2, rtol=1e-4, atol=1e-4) + output1.backward() + output2.backward() + torch.testing.assert_close(logits.grad, logits2.grad, rtol=1e-4, atol=1e-4) diff --git a/SpecForge-ext/specforge/data/__pycache__/parse.cpython-311.pyc b/SpecForge-ext/specforge/data/__pycache__/parse.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c85a78c775329928463a94fc0ddad3b170ecaaa8 Binary files /dev/null and b/SpecForge-ext/specforge/data/__pycache__/parse.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/data/__pycache__/utils.cpython-311.pyc b/SpecForge-ext/specforge/data/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de9ffa9b1edd0e937d6b3750d99ff7d023caee78 Binary files /dev/null and b/SpecForge-ext/specforge/data/__pycache__/utils.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/data/parse.py b/SpecForge-ext/specforge/data/parse.py new file mode 100644 index 0000000000000000000000000000000000000000..073e882a33016d3a76af9999809a19642fa0d6b0 --- /dev/null +++ b/SpecForge-ext/specforge/data/parse.py @@ -0,0 +1,341 @@ +import json +import re +import warnings +from abc import ABC, abstractmethod +from typing import Dict, List, Tuple + +import torch +from transformers import PreTrainedTokenizer + +from .template import ChatTemplate + +__all__ = ["GeneralParser", "HarmonyParser"] + + +class Parser(ABC): + + def __init__(self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate): + self.tokenizer = tokenizer + self.chat_template = chat_template + + @abstractmethod + def parse( + self, conversation: "Conversation", max_length: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Parse the conversation into a list of tensors. + + Args: + conversation: The conversation to parse. + + Returns: + A list of tensors: [input_ids, loss_mask] + """ + + +_harmony_encoding = None + + +class GeneralParser(Parser): + + def __init__(self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate): + super().__init__(tokenizer, chat_template) + self.system_prompt = chat_template.system_prompt + self.user_message_separator = f"{chat_template.end_of_turn_token}" + self.assistant_message_separator = f"{chat_template.assistant_header}" + self.set_assistant_pattern(chat_template) + + def apply_chat_template(self, messages, **kwargs) -> str: + conversation = self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=False, **kwargs + ) + return conversation + + def set_assistant_pattern(self, chat_template: ChatTemplate): + if chat_template.assistant_pattern_type == "longcat": + self.assistant_pattern = ( + re.escape(self.assistant_message_separator) + + r"([\s\S]*?(?:" + + re.escape("[Round ") + + r"\d+" + + re.escape("] USER:") + + "|$))" + ) + else: + self.assistant_pattern = ( + re.escape(self.assistant_message_separator) + + r"([\s\S]*?(?:" + + re.escape(self.chat_template.end_of_turn_token) + + "|$))" + ) + + def parse( + self, + conversation: "Conversation", + max_length: int, + preformatted: bool = False, + train_only_last_turn: bool = False, + **kwargs, + ) -> Dict[str, List[torch.Tensor]]: + if not preformatted: + messages = [] + + if conversation[0]["role"] == "system": + warnings.warn( + f"The first message is from system, we will use the system prompt from the data and ignore the system prompt from the template" + ) + messages.append( + {"role": "system", "content": conversation[0]["content"]} + ) + conversation = conversation[1:] + else: + if self.system_prompt: + messages.append({"role": "system", "content": self.system_prompt}) + + for j, sentence in enumerate(conversation): + role = sentence["role"] + if j == 0: + if role != "user": + warnings.warn( + f"Conversation must start with a 'user' role, but found '{role}'. Conversation truncated." + ) + break + else: + prev_role = conversation[j - 1]["role"] + if role == "tool" and prev_role not in ["assistant", "tool"]: + warnings.warn( + f"A 'tool' message must follow an 'assistant' or 'tool' message, but was preceded by '{prev_role}'. Conversation truncated." + ) + break + if role == "assistant" and prev_role not in ["user", "tool"]: + warnings.warn( + f"An 'assistant' message must follow a 'user' or 'tool' message, but was preceded by '{prev_role}'. Conversation truncated." + ) + break + tool_calls = sentence.get("tool_calls") + if isinstance(tool_calls, str): + try: + sentence["tool_calls"] = json.loads(tool_calls) + except json.JSONDecodeError: + warnings.warn(f"Failed to parse tool_calls JSON: {tool_calls}") + break + messages.append(sentence) + + try: + conversation = self.apply_chat_template(messages, **kwargs) + except (ValueError, TypeError): + # Fallback rendering for tokenizers without built-in chat_template + warnings.warn( + "Tokenizer does not have a chat_template, using fallback rendering." + ) + parts = [] + bos_token = getattr(self.tokenizer, "bos_token", None) + user_header = self.chat_template.user_header or "" + assistant_header = self.chat_template.assistant_header or "" + end_of_turn = self.chat_template.end_of_turn_token or "" + + # Add BOS token at the start + if bos_token: + parts.append(bos_token) + + for msg in messages: + if msg["role"] == "system": + parts.append(msg["content"]) + elif msg["role"] == "user": + parts.append(f"{user_header}{msg['content']}") + elif msg["role"] == "assistant": + parts.append(f"{assistant_header}{msg['content']}{end_of_turn}") + conversation = "".join(parts) + + if not self.tokenizer.pad_token_id: + self.tokenizer.pad_token_id = self.tokenizer.unk_token_id + + # get input_ids + encoding = self.tokenizer( + conversation, + max_length=max_length, + truncation=True, + return_tensors="pt", + add_special_tokens=False, + ) + input_ids = encoding.input_ids[0] + loss_mask = torch.zeros(len(input_ids), dtype=torch.long) + + matches = list(re.finditer(self.assistant_pattern, conversation, re.DOTALL)) + if train_only_last_turn and matches: + matches = [matches[-1]] # Only keep the last match + + for match in matches: + content_start_char = match.start(1) + content_end_char = match.end(1) + + # --- Core Alternative Operation: Calculate Token Index Based on Prefix String Length --- + # Encode the text "assistant start", the length of which is the position of the starting token. + prefix_ids = self.tokenizer.encode( + conversation[:content_start_char], + add_special_tokens=False, + truncation=True, + max_length=max_length, + ) + # Encodes the text "assistant end", the length of which is the position of the end token. + full_ids = self.tokenizer.encode( + conversation[:content_end_char], + add_special_tokens=False, + truncation=True, + max_length=max_length, + ) + + start_token_idx = len(prefix_ids) + end_token_idx = len(full_ids) + + # Handling out-of-bounds errors caused by truncation + actual_start = min(start_token_idx, len(input_ids)) + actual_end = min(end_token_idx, len(input_ids)) + + if actual_start < actual_end: + loss_mask[actual_start:actual_end] = 1 + return input_ids, loss_mask + + +class HarmonyParser(Parser): + def __init__(self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate): + super().__init__(tokenizer, chat_template) + self.reasoning_levels = ["low", "medium", "high"] + self.default_reasoning_level = "low" + + def build_single_turn_prompt( + self, + prompt_text: str, + role: str, + content: str, + ) -> str: + """Embed user message into the required prompt template.""" + if role == "system": + prompt_text = f"<|start|>system<|message|>{content}<|end|>" + elif role == "assistant_reasoning_effort": + prompt_text = f"<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.\nKnowledge cutoff: 2024-06\nCurrent date: 2025-06-28\n\nReasoning: {content.lower()}\n\n# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|>" + elif role == "user": + prompt_text += f"<|start|>user<|message|>{content}<|end|>" + elif role == "assistant_analysis": + prompt_text += ( + f"<|start|>assistant<|channel|>analysis<|message|>{content}<|end|>" + ) + elif role == "assistant_commentary": + prompt_text += ( + f"<|start|>assistant<|channel|>commentary<|message|>{content}<|end|>" + ) + elif role == "assistant_final": + prompt_text += ( + f"<|start|>assistant<|channel|>final<|message|>{content}<|end|>" + ) + else: + raise ValueError(f"Unknown role: {role}") + return prompt_text + + def parse( + self, + conversation: "Conversation", + max_length: int, + preformatted: bool = False, + train_only_last_turn: bool = False, + ) -> List[torch.Tensor]: + # conversation = process_harmony_conversations(conversation) + if not preformatted: + prompt_text = "" + for j, message in enumerate(conversation): + if j == 0 and ( + message["role"] != "system" + or message["role"] != "assistant_reasoning_effort" + ): + prompt_text = self.build_single_turn_prompt( + prompt_text, + "assistant_reasoning_effort", + self.default_reasoning_level, + ) + prompt_text = self.build_single_turn_prompt( + prompt_text, message["role"], message["content"] + ) + conversation = prompt_text + + if not self.tokenizer.pad_token_id: + self.tokenizer.pad_token_id = self.tokenizer.unk_token_id + + encoding = self.tokenizer( + conversation, + return_offsets_mapping=True, + max_length=max_length, + truncation=True, + return_tensors="pt", + add_special_tokens=False, + ) + input_ids = encoding.input_ids[0] + offsets = encoding.offset_mapping[0] + loss_mask = torch.zeros(len(input_ids), dtype=torch.long) + + # Find spans of assistant responses using regex + # We match `<|start|>assistant` and only extract the content following it. + # This continues until `<|start|>user<|message|>` appears, or until the end of the string. + pattern = re.compile( + r"<\|start\|>assistant([\s\S]*?)(?=<\|start\|>user<\|message\|>|$)" + ) + + # Find all matching segments + matches = list(pattern.finditer(conversation)) + if train_only_last_turn and matches: + matches = [matches[-1]] # Only keep the last match + + for match in matches: + # match.start(0) is the start index of the full match (including `<|start|>assistant`) + # match.start(1) is the start index of the first capture group (excluding `<|start|>assistant`) + # match.end(1) is the end index of the content + start_char = match.start(1) + end_char = match.end(1) + + # Map character indices to token indices + for idx, (ts, te) in enumerate(offsets): + # Set mask to 1 only if the token's character range falls entirely within the "content area" + if ts >= start_char and te <= end_char: + loss_mask[idx] = 1 + + return input_ids, loss_mask + + +class ThinkingParser(GeneralParser): + def __init__(self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate): + super().__init__(tokenizer, chat_template) + + def apply_chat_template(self, messages, **kwargs) -> str: + if messages[-1]["role"] == "assistant": + conversation_history = self.tokenizer.apply_chat_template( + messages[:-1], + tokenize=False, + add_generation_prompt=True, + add_special_tokens=False, + **kwargs, + ) + conversation = ( + conversation_history + + messages[-1]["content"] + + self.chat_template.end_of_turn_token + ) + return conversation + else: + raise Exception( + f"The last message is not assistant but {messages[-1]['role']}" + ) + + def parse( + self, + conversation: "Conversation", + max_length: int, + preformatted: bool = False, + train_only_last_turn: bool = False, + **kwargs, + ) -> Dict[str, List[torch.Tensor]]: + if self.chat_template.enable_thinking: + kwargs["enable_thinking"] = True + else: + pass + return super().parse( + conversation, max_length, preformatted, train_only_last_turn, **kwargs + ) diff --git a/SpecForge-ext/specforge/layers/__pycache__/embedding.cpython-311.pyc b/SpecForge-ext/specforge/layers/__pycache__/embedding.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1e05c6124bd35152381bb3e9023c3062c3b5d85 Binary files /dev/null and b/SpecForge-ext/specforge/layers/__pycache__/embedding.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/layers/__pycache__/linear.cpython-311.pyc b/SpecForge-ext/specforge/layers/__pycache__/linear.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..659086f455b8a03c9e728e043c6f4e66d2af1048 Binary files /dev/null and b/SpecForge-ext/specforge/layers/__pycache__/linear.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/layers/ring/__pycache__/__init__.cpython-311.pyc b/SpecForge-ext/specforge/layers/ring/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9da4f8dba5a762d3e8576bcd1e93512157be4e15 Binary files /dev/null and b/SpecForge-ext/specforge/layers/ring/__pycache__/__init__.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/layers/ring/__pycache__/ring_flash_attn.cpython-311.pyc b/SpecForge-ext/specforge/layers/ring/__pycache__/ring_flash_attn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0c0c25533f21da49d727a9c2e5c62436f6cd9a3 Binary files /dev/null and b/SpecForge-ext/specforge/layers/ring/__pycache__/ring_flash_attn.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/layers/ring/__pycache__/utils.cpython-311.pyc b/SpecForge-ext/specforge/layers/ring/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d55a93145ea0aa5a6464e071c3f289687ee066d4 Binary files /dev/null and b/SpecForge-ext/specforge/layers/ring/__pycache__/utils.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/layers/ring/utils.py b/SpecForge-ext/specforge/layers/ring/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..14d6a7817dc9358d64e4b93465e1b1f33d870923 --- /dev/null +++ b/SpecForge-ext/specforge/layers/ring/utils.py @@ -0,0 +1,119 @@ +from typing import Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn.functional as F + +__all__ = ["update_out_and_lse", "RingComm"] + + +@torch.jit.script +def _update_out_and_lse( + out: torch.Tensor, + lse: torch.Tensor, + block_out: torch.Tensor, + block_lse: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + + block_out = block_out.to(torch.float32) + block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + + # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) + # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out + # For additional context and discussion, please refer to: + # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 + out = out - F.sigmoid(block_lse - lse) * (out - block_out) + lse = lse - F.logsigmoid(lse - block_lse) + + return out, lse + + +def update_out_and_lse( + out: Optional[torch.Tensor], + lse: Optional[torch.Tensor], + block_out: torch.Tensor, + block_lse: torch.Tensor, + slice_=None, +) -> Tuple[torch.Tensor, torch.Tensor]: + if out is None: + if slice_ is not None: + raise RuntimeError("first update_out_and_lse should not pass slice_ args") + out = block_out.to(torch.float32) + lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + elif slice_ is not None: + slice_out, slice_lse = out[slice_], lse[slice_] + slice_out, slice_lse = _update_out_and_lse( + slice_out, slice_lse, block_out, block_lse + ) + out[slice_], lse[slice_] = slice_out, slice_lse + else: + out, lse = _update_out_and_lse(out, lse, block_out, block_lse) + return out, lse + + +@torch.jit.script +def flatten_varlen_lse(lse, cu_seqlens): + new_lse = [] + for i in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + new_lse.append(lse[i, :, : end - start]) + return torch.cat(new_lse, dim=1) + + +@torch.jit.script +def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): + num_seq = len(cu_seqlens) - 1 + num_head = lse.shape[-2] + new_lse = torch.empty( + (num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device + ) + for i in range(num_seq): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + new_lse[i, : end - start] = lse[start:end] + return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous() + + +class RingComm: + def __init__(self, process_group: dist.ProcessGroup): + self._process_group = process_group + self._ops = [] + self.rank = dist.get_rank(self._process_group) + self.world_size = dist.get_world_size(self._process_group) + self._reqs = None + + self.send_rank = (self.rank + 1) % self.world_size + self.recv_rank = (self.rank - 1) % self.world_size + + if process_group is not None: + self.send_rank = dist.get_global_rank(self._process_group, self.send_rank) + self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) + + def send_recv( + self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if recv_tensor is None: + res = torch.empty_like(to_send) + # print(f"send_recv: empty_like {to_send.shape}") + else: + res = recv_tensor + + send_op = dist.P2POp( + dist.isend, to_send, self.send_rank, group=self._process_group + ) + recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) + self._ops.append(send_op) + self._ops.append(recv_op) + return res + + def commit(self): + if self._reqs is not None: + raise RuntimeError("commit called twice") + self._reqs = dist.batch_isend_irecv(self._ops) + + def wait(self): + if self._reqs is None: + raise RuntimeError("wait called before commit") + for req in self._reqs: + req.wait() + self._reqs = None + self._ops = [] diff --git a/SpecForge-ext/specforge/modeling/__pycache__/__init__.cpython-311.pyc b/SpecForge-ext/specforge/modeling/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..785cd1ca1e181269ee566a6b6de164426ea84625 Binary files /dev/null and b/SpecForge-ext/specforge/modeling/__pycache__/__init__.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/modeling/__pycache__/_mask_utils.cpython-311.pyc b/SpecForge-ext/specforge/modeling/__pycache__/_mask_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a545f55eff8db69545278e58d703b9b6bfe836af Binary files /dev/null and b/SpecForge-ext/specforge/modeling/__pycache__/_mask_utils.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/modeling/__pycache__/auto.cpython-311.pyc b/SpecForge-ext/specforge/modeling/__pycache__/auto.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a30ea97bd3c1f2de37ef3a4cf8efc42267446c48 Binary files /dev/null and b/SpecForge-ext/specforge/modeling/__pycache__/auto.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/modeling/draft/__init__.py b/SpecForge-ext/specforge/modeling/draft/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bdc7e2f6fa02407dc4e4bab9c4b1e252c10aa62 --- /dev/null +++ b/SpecForge-ext/specforge/modeling/draft/__init__.py @@ -0,0 +1,17 @@ +from .base import Eagle3DraftModel +from .dflash import ( + DFlashDraftModel, + build_target_layer_ids, + extract_context_feature, + sample, +) +from .llama3_eagle import LlamaForCausalLMEagle3 + +__all__ = [ + "Eagle3DraftModel", + "DFlashDraftModel", + "LlamaForCausalLMEagle3", + "build_target_layer_ids", + "extract_context_feature", + "sample", +] diff --git a/SpecForge-ext/specforge/modeling/draft/__pycache__/__init__.cpython-311.pyc b/SpecForge-ext/specforge/modeling/draft/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a4282b0c47902151f0fd7f23884abcb054df350 Binary files /dev/null and b/SpecForge-ext/specforge/modeling/draft/__pycache__/__init__.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/modeling/draft/__pycache__/base.cpython-311.pyc b/SpecForge-ext/specforge/modeling/draft/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01c2e6ec5b19b956365f7c0022d40727d18feb35 Binary files /dev/null and b/SpecForge-ext/specforge/modeling/draft/__pycache__/base.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/modeling/draft/__pycache__/dflash.cpython-311.pyc b/SpecForge-ext/specforge/modeling/draft/__pycache__/dflash.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5e71ff5b72e6c8f68b560500da8ab81ef052735 Binary files /dev/null and b/SpecForge-ext/specforge/modeling/draft/__pycache__/dflash.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/modeling/draft/__pycache__/flex_attention.cpython-311.pyc b/SpecForge-ext/specforge/modeling/draft/__pycache__/flex_attention.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a145ef948c245040e72e75e14be65eea000bf802 Binary files /dev/null and b/SpecForge-ext/specforge/modeling/draft/__pycache__/flex_attention.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/modeling/draft/__pycache__/llama3_eagle.cpython-311.pyc b/SpecForge-ext/specforge/modeling/draft/__pycache__/llama3_eagle.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ff51dcd7a9c25e5343da91ec10f91f849c1141b Binary files /dev/null and b/SpecForge-ext/specforge/modeling/draft/__pycache__/llama3_eagle.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/modeling/draft/base.py b/SpecForge-ext/specforge/modeling/draft/base.py new file mode 100644 index 0000000000000000000000000000000000000000..b5584a759d78a072903e0e76999b1674a62f0a88 --- /dev/null +++ b/SpecForge-ext/specforge/modeling/draft/base.py @@ -0,0 +1,189 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in HuggingFace Transformers. +# Portions of this code are adapted from: +# - https://github.com/EleutherAI/gpt-neox (Apache License 2.0) +# - https://github.com/huggingface/transformers (Apache License 2.0) +# - https://github.com/SafeAILab/EAGLE (Apache License 2.0) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import json +import os +from abc import ABC, abstractmethod +from typing import Optional + +import torch +from huggingface_hub import snapshot_download +from safetensors import safe_open +from transformers.cache_utils import Cache +from transformers.modeling_utils import PreTrainedModel + +from specforge.modeling._mask_utils import _expand_mask, _make_causal_mask + + +class Eagle3DraftModel(PreTrainedModel, ABC): + """ + This is the base class for the Eagle3 draft model implementation. The child class needs to implement + the abstract methods to support training with TTT. + """ + + @abstractmethod + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + """ + Embed the input ids. + """ + + @abstractmethod + def project_hidden_states(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Project the concatenated hidden states from the high, medium and low layers to the target hidden size. + """ + + @abstractmethod + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Compute the logits of the draft model. + """ + + def prepare_decoder_attention_mask( + self, + attention_mask: torch.Tensor, + hidden_states: torch.Tensor, + batch_size: int, + seq_length: int, + past_key_values_length: int, + ) -> torch.Tensor: + """ + Prepare the attention mask of the draft model. + """ + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if seq_length > 1: + combined_attention_mask = _make_causal_mask( + (batch_size, seq_length), + hidden_states.dtype, + device=hidden_states.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, hidden_states.dtype, tgt_len=seq_length + ).to(hidden_states.device) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask + ) + return combined_attention_mask + + @abstractmethod + def backbone( + self, + input_embeds: torch.Tensor, + hidden_states: torch.Tensor, + cache_hidden: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: Optional[Cache] = None, + use_cache: bool = True, + ) -> torch.Tensor: + """ + The backbone of the draft model. + """ + + def freeze_embedding(self) -> None: + """ + Freeze the embeddings of the draft model so that they are not updated during training. + """ + self.embed_tokens.weight.requires_grad = False + + @torch.no_grad() + def load_embedding( + self, model_path: str, embedding_key: str = "model.embed_tokens.weight" + ) -> None: + """ + Load the embedding of the draft model. + + Args: + model_path (str): Path to the target model. Can be either a Hugging Face + repository ID or a local directory path containing the model files. + """ + if os.path.exists(model_path): + # model_path is a local directory + # check if there is file ending with index.json + glob_path = os.path.join(model_path, "*.index.json") + index_json_path = glob.glob(glob_path) + + if len(index_json_path) == 0: + # No index.json found, look for single model file + safetensors_path = os.path.join(model_path, "model.safetensors") + if os.path.exists(safetensors_path): + with safe_open(safetensors_path, framework="pt") as f: + self.embed_tokens.weight.copy_(f.get_tensor(embedding_key)) + return + + pytorch_model_path = os.path.join(model_path, "pytorch_model.bin") + if os.path.exists(pytorch_model_path): + state_dict = torch.load(pytorch_model_path, map_location="cpu") + self.embed_tokens.weight.copy_(state_dict[embedding_key]) + return + + raise FileNotFoundError( + f"No index.json, model.safetensors or pytorch_model.bin found in {model_path}" + ) + if len(index_json_path) > 1: + raise FileNotFoundError( + f"Multiple index.json files found in {model_path}" + ) + index_json_path = index_json_path[0] + + with open(index_json_path, "r") as f: + index_json = json.load(f) + ckpt_file = index_json["weight_map"][embedding_key] + + if ckpt_file.endswith(".safetensors"): + with safe_open( + os.path.join(model_path, ckpt_file), framework="pt" + ) as f: + emb_tokens = f.get_tensor(embedding_key) + else: + state_dict = torch.load(os.path.join(model_path, ckpt_file)) + emb_tokens = state_dict[embedding_key] + self.embed_tokens.weight.copy_(emb_tokens) + else: + # this is the case where model_path is a huggingface repository + # we first need to locate its local cache + local_cache_path = snapshot_download(repo_id=model_path) + self.load_embedding(local_cache_path, embedding_key) + + def load_vocab_mapping(self, file_path: str) -> None: + """ + Load the vocab buffers of the draft model. + + Args: + file_path (str): The path to the vocab mapping file. + """ + assert hasattr(self, "t2d") and hasattr( + self, "d2t" + ), "t2d and d2t buffersare not found in the draft model, please check your draft model implementation" + vocab_mapping = torch.load(file_path) + self.t2d.copy_(vocab_mapping["t2d"]) + self.d2t.copy_(vocab_mapping["d2t"]) + self.vocab_mapping_loaded = True diff --git a/SpecForge-ext/specforge/modeling/draft/dflash.py b/SpecForge-ext/specforge/modeling/draft/dflash.py new file mode 100644 index 0000000000000000000000000000000000000000..6fc5f83cb3469e1db14b432ad3ba381037c90dca --- /dev/null +++ b/SpecForge-ext/specforge/modeling/draft/dflash.py @@ -0,0 +1,378 @@ +from typing import Callable, Optional + +import torch +from torch import nn +from transformers import DynamicCache +from transformers.cache_utils import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.qwen3.modeling_qwen3 import ( + ALL_ATTENTION_FUNCTIONS, + FlashAttentionKwargs, + GradientCheckpointingLayer, + Qwen3Config, + Qwen3MLP, + Qwen3PreTrainedModel, + Qwen3RMSNorm, + Qwen3RotaryEmbedding, + eager_attention_forward, + rotate_half, +) +from typing_extensions import Tuple, Unpack + + +def sample(logits: torch.Tensor, temperature: float = 0.0) -> torch.Tensor: + if temperature < 1e-5: + return torch.argmax(logits, dim=-1) + bsz, seq_len, vocab_size = logits.shape + logits = logits.view(-1, vocab_size) + logits = logits / temperature + probs = torch.softmax(logits, dim=-1) + return torch.multinomial(probs, num_samples=1).view(bsz, seq_len) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_len = q.size(-2) + q_embed = (q * cos[..., -q_len:, :]) + (rotate_half(q) * sin[..., -q_len:, :]) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Qwen3DFlashAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen3Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = False + self.q_proj = nn.Linear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=config.attention_bias, + ) + self.k_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.sliding_window = ( + config.sliding_window + if config.layer_types[layer_idx] == "sliding_attention" + else None + ) + + def forward( + self, + hidden_states: torch.Tensor, + target_hidden: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + bsz, q_len = hidden_states.shape[:-1] + ctx_len = target_hidden.shape[1] + q = self.q_proj(hidden_states) + q = q.view(bsz, q_len, -1, self.head_dim) + q = self.q_norm(q).transpose(1, 2) + k_ctx = self.k_proj(target_hidden) + k_noise = self.k_proj(hidden_states) + v_ctx = self.v_proj(target_hidden) + v_noise = self.v_proj(hidden_states) + k = torch.cat([k_ctx, k_noise], dim=1).view( + bsz, ctx_len + q_len, -1, self.head_dim + ) + v = torch.cat([v_ctx, v_noise], dim=1).view( + bsz, ctx_len + q_len, -1, self.head_dim + ) + k = self.k_norm(k).transpose(1, 2) + v = v.transpose(1, 2) + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb(q, k, cos, sin) + if past_key_values is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + k, v = past_key_values.update(k, v, self.layer_idx, cache_kwargs) + attn_fn: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attn_fn = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attn_fn( + self, + q, + k, + v, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen3DFlashDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Qwen3DFlashAttention(config=config, layer_idx=layer_idx) + self.mlp = Qwen3MLP(config) + self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + target_hidden: Optional[torch.Tensor] = None, + hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + target_hidden=target_hidden, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + )[0] + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +def build_target_layer_ids(num_target_layers: int, num_draft_layers: int): + if num_draft_layers == 1: + return [(num_target_layers // 2)] + start = 1 + end = num_target_layers - 3 + span = end - start + target_layer_ids = [ + int(round(start + (i * span) / (num_draft_layers - 1))) + for i in range(num_draft_layers) + ] + return target_layer_ids + + +def extract_context_feature( + hidden_states: list[torch.Tensor], + layer_ids: Optional[list[int]], +) -> torch.Tensor: + offset = 1 + selected_states = [] + for layer_id in layer_ids: + selected_states.append(hidden_states[layer_id + offset]) + target_hidden = torch.cat(selected_states, dim=-1) + return target_hidden + + +class DFlashDraftModel(Qwen3PreTrainedModel): + config_class = Qwen3Config + _no_split_modules = ["Qwen3DFlashDecoderLayer"] + + def __init__(self, config) -> None: + super().__init__(config) + self.config = config + self.layers = nn.ModuleList( + [ + Qwen3DFlashDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.target_layer_ids = build_target_layer_ids( + config.num_target_layers, config.num_hidden_layers + ) + self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3RotaryEmbedding(config) + self.fc = nn.Linear( + len(self.target_layer_ids) * config.hidden_size, + config.hidden_size, + bias=False, + ) + self.hidden_norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.block_size = config.block_size + self.post_init() + + def forward( + self, + position_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + noise_embedding: Optional[torch.Tensor] = None, + target_hidden: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: bool = False, + **kwargs, + ) -> CausalLMOutputWithPast: + hidden_states = noise_embedding + target_hidden = self.hidden_norm(self.fc(target_hidden)) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + for layer in self.layers: + hidden_states = layer( + hidden_states=hidden_states, + target_hidden=target_hidden, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + use_cache=use_cache, + position_embeddings=position_embeddings, + **kwargs, + ) + return self.norm(hidden_states) + + @torch.inference_mode() + def spec_generate( + self, + target: nn.Module, + input_ids: torch.LongTensor, + mask_token_id: int, + max_new_tokens: int, + stop_token_ids: list[int], + temperature: float, + ): + self.eval() + target.eval() + num_input_tokens = input_ids.shape[1] + max_length = num_input_tokens + max_new_tokens + + block_size = self.block_size + output_ids = torch.full( + (1, max_length + block_size), + mask_token_id, + dtype=torch.long, + device=target.device, + ) + position_ids = torch.arange( + output_ids.shape[1], device=target.device + ).unsqueeze(0) + + past_key_values_target = DynamicCache() + past_key_values_draft = DynamicCache() + + # Prefill stage + output = target( + input_ids, + position_ids=position_ids[:, :num_input_tokens], + past_key_values=past_key_values_target, + use_cache=True, + logits_to_keep=1, + output_hidden_states=True, + ) + + output_ids[:, :num_input_tokens] = input_ids + output_ids[:, num_input_tokens : num_input_tokens + 1] = sample( + output.logits, temperature + ) + target_hidden = extract_context_feature( + output.hidden_states, self.target_layer_ids + ) + + # Decode stage + acceptance_lengths = [] + start = input_ids.shape[1] + while start < max_length: + block_output_ids = output_ids[:, start : start + block_size].clone() + block_position_ids = position_ids[:, start : start + block_size] + noise_embedding = target.model.embed_tokens(block_output_ids) + draft_logits = target.lm_head( + self( + target_hidden=target_hidden, + noise_embedding=noise_embedding, + position_ids=position_ids[ + :, past_key_values_draft.get_seq_length() : start + block_size + ], + past_key_values=past_key_values_draft, + use_cache=True, + is_causal=False, + )[:, -block_size + 1 :, :] + ) + past_key_values_draft.crop(start) + block_output_ids[:, 1:] = sample(draft_logits) + + output = target( + block_output_ids, + position_ids=block_position_ids, + past_key_values=past_key_values_target, + use_cache=True, + output_hidden_states=True, + ) + + posterior = sample(output.logits, temperature) + acceptance_length = ( + (block_output_ids[:, 1:] == posterior[:, :-1]) + .cumprod(dim=1) + .sum(dim=1)[0] + .item() + ) + output_ids[:, start : start + acceptance_length + 1] = block_output_ids[ + :, : acceptance_length + 1 + ] + output_ids[:, start + acceptance_length + 1] = posterior[ + :, acceptance_length + ] + start += acceptance_length + 1 + past_key_values_target.crop(start) + target_hidden = extract_context_feature( + output.hidden_states, self.target_layer_ids + )[:, : acceptance_length + 1, :] + acceptance_lengths.append(acceptance_length + 1) + if stop_token_ids is not None and any( + stop_token_id in output_ids[:, num_input_tokens:] + for stop_token_id in stop_token_ids + ): + break + output_ids = output_ids[:, :max_length] + output_ids = output_ids[:, output_ids[0] != mask_token_id] + if stop_token_ids is not None: + stop_token_ids = torch.tensor(stop_token_ids, device=output_ids.device) + stop_token_indices = torch.isin( + output_ids[0][num_input_tokens:], stop_token_ids + ).nonzero(as_tuple=True)[0] + if stop_token_indices.numel() > 0: + output_ids = output_ids[ + :, : num_input_tokens + stop_token_indices[0] + 1 + ] + + return output_ids diff --git a/SpecForge-ext/specforge/modeling/draft/flex_attention.py b/SpecForge-ext/specforge/modeling/draft/flex_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..50ca5f54dc658106c22d7a8a95553bf346b33525 --- /dev/null +++ b/SpecForge-ext/specforge/modeling/draft/flex_attention.py @@ -0,0 +1,127 @@ +import torch +import torch._dynamo as dynamo +from torch.nn.attention.flex_attention import ( + create_block_mask, + flex_attention, + or_masks, +) +from transformers.utils import is_torchdynamo_compiling + +dynamo.config.recompile_limit = 64 + + +# Reference Implementation https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/flex_attention.py +class WrappedFlexAttention: + """ + We are doing a singleton class so that flex attention is compiled once when it's first called. + """ + + _instance = None + _is_flex_compiled = False + _compiled_flex_attention = None + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + # Create a new instance if one doesn't already exist + cls._instance = super().__new__(cls) + return cls._instance + + @torch.compiler.disable(recursive=False) + def __init__(self): + """ + Initialize or update the singleton instance. + """ + if not self._is_flex_compiled: + # Enable dynamic shapes to handle different input sizes + self._compiled_flex_attention = torch.compile( + flex_attention, + # mode="max-autotune-no-cudagraphs", + ) + self._is_flex_compiled = True + + def __call__(self): + return self._compiled_flex_attention + + +def compile_friendly_flex_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + **kwargs, +) -> torch.Tensor: + # First call initialise singleton wrapper object, second call invokes the object method to return compiled flex attention + # Do not use compiled version if already compiling forward (it raises issues) + flex_attention_compiled = ( + WrappedFlexAttention()() if not is_torchdynamo_compiling() else flex_attention + ) + return flex_attention_compiled( + query, + key, + value, + **kwargs, + ) + + +class WrappedCreateBlockMask: + _instance = None + _is_create_block_mask_compiled = False + _compiled_create_block_mask = None + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + @torch.compiler.disable(recursive=False) + def __init__(self): + if not self._is_create_block_mask_compiled: + self._compiled_create_block_mask = torch.compile(create_block_mask) + self._is_create_block_mask_compiled = True + + def __call__(self): + return self._compiled_create_block_mask + + +def compile_friendly_create_block_mask( + mask_mod, + B, + H, + Q_LEN, + KV_LEN, + device, +): + create_block_mask_compiled = ( + WrappedCreateBlockMask()() + if not is_torchdynamo_compiling() + else create_block_mask + ) + return create_block_mask_compiled( + mask_mod, + B, + H, + Q_LEN, + KV_LEN, + device, + ) + + +def generate_eagle3_mask( + seq_lengths: torch.Tensor, Q_LEN: int, KV_LEN: int, lck: int = 0 +): + + def causal_mask(b, h, q_idx, kv_idx): + # Causal will keep shrinking by 1 diagnol due to appended suffix + # Shirnk the causal by diagnol + causal_mask = q_idx >= kv_idx + padding_mask = (kv_idx < seq_lengths[b]) & (q_idx < seq_lengths[b]) + return causal_mask & padding_mask + + def suffix_mask(b, h, q_idx, kv_idx): + suffix_mask = kv_idx >= Q_LEN + padding_mask = kv_idx % Q_LEN < seq_lengths[b] + diagnol_mask = (kv_idx - q_idx) % Q_LEN == 0 + return suffix_mask & padding_mask & diagnol_mask + + mask_mod = or_masks(causal_mask, suffix_mask) + mask_mod.__name__ = f"eagle3_mask_Q_{Q_LEN}_KV_{KV_LEN}_lck_{lck}" + return mask_mod diff --git a/SpecForge-ext/specforge/modeling/draft/llama3_eagle.py b/SpecForge-ext/specforge/modeling/draft/llama3_eagle.py new file mode 100644 index 0000000000000000000000000000000000000000..552a3cf86e64916b53efffcccdb450d50257fbcc --- /dev/null +++ b/SpecForge-ext/specforge/modeling/draft/llama3_eagle.py @@ -0,0 +1,1448 @@ +import math +import warnings +from typing import List, Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.attention.flex_attention import create_block_mask, flex_attention +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache +from transformers.models.llama.configuration_llama import LlamaConfig +from yunchang.comm import SeqAllToAll4D + +from specforge.modeling.draft.flex_attention import ( + compile_friendly_create_block_mask, + compile_friendly_flex_attention, + generate_eagle3_mask, +) +from specforge.utils import print_with_rank + +from ...distributed import get_sp_ring_group, get_sp_ulysses_group +from ...layers.ring import ring_flash_attn_func +from .base import Eagle3DraftModel + +try: + from flash_attn import flash_attn_func +except ImportError: + warnings.warn( + "flash_attn is not found, falling back to flex_attention. " + "Please install flash_attn if you want to use the flash attention backend." + ) + flash_attn_func = None + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [ + torch.zeros( + tgt_len, past_key_values_length, dtype=dtype, device=device + ), + mask, + ], + dim=-1, + ) + return mask[None, None, :, :].expand( + bsz, 1, tgt_len, tgt_len + past_key_values_length + ) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(dtype).min + ) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +@torch.compile(dynamic=True) +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + mrope_section = mrope_section * 2 + cos = torch.cat( + [m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1 + ).unsqueeze(unsqueeze_dim) + sin = torch.cat( + [m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1 + ).unsqueeze(unsqueeze_dim) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length +): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ).to(inputs_embeds.device) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + +class LlamaRotaryEmbedding(torch.nn.Module): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=None, + low_freq_factor=None, + high_freq_factor=None, + orig_max_position=None, + ): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + # Llama3 style rotary embedding frequency scaling + if all( + v is not None + for v in [ + scaling_factor, + low_freq_factor, + high_freq_factor, + orig_max_position, + ] + ): + print_with_rank( + f"Using Llama3 style rotary embedding with scaling_factor={scaling_factor}, low_freq_factor={low_freq_factor}, high_freq_factor={high_freq_factor}, orig_max_position={orig_max_position}" + ) + self.scaling_factor = scaling_factor + self.low_freq_factor = low_freq_factor + self.high_freq_factor = high_freq_factor + self.orig_max_position = orig_max_position + + low_freq_wavelen = orig_max_position / low_freq_factor + high_freq_wavelen = orig_max_position / high_freq_factor + wave_len = 2 * math.pi / inv_freq + + if low_freq_factor != high_freq_factor: + smooth = (orig_max_position / wave_len - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + else: + smooth = 0 + + new_freqs = torch.where( + wave_len < high_freq_wavelen, + inv_freq, + torch.where( + wave_len > low_freq_wavelen, + inv_freq / self.scaling_factor, + (1 - smooth) * inv_freq / self.scaling_factor + smooth * inv_freq, + ), + ) + inv_freq = new_freqs + + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings + 20, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False + ) + + @torch.compile(dynamic=True) + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len and seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False + ) + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False + ) + + +class LlamaMutiRotaryEmbedding(LlamaRotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + super().__init__(dim, max_position_embeddings, base, device) + self.scaling_factor = scaling_factor + + def forward(self, x, position_ids): + # In contrast to other models, Qwen2_5_VL has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) + inv_freq_expanded = ( + self.inv_freq[None, None, :, None] + .float() + .expand(3, position_ids.shape[1], -1, 1) + ) + position_ids_expanded = position_ids[ + :, :, None, : + ].float() # shape (3, bs, 1, positions) + + device_type = ( + x.device.type + if isinstance(x.device.type, str) and x.device.type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = ( + inv_freq_expanded.float() @ position_ids_expanded.float() + ).transpose(2, 3) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.scaling_factor + sin = emb.sin() * self.scaling_factor + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Inverse dim formula to find dim based on number of rotations +def yarn_find_correction_dim( + num_rotations, dim, base=10000, max_position_embeddings=2048 +): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +# Find dim range bounds based on rotations +def yarn_find_correction_range( + low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 +): + low = math.floor( + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def yarn_linear_ramp_mask(min_val, max_val, dim): + if min_val == max_val: + max_val += 0.001 # Prevent singularity + linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / ( + max_val - min_val + ) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +class LlamaYarnRotaryEmbedding(LlamaRotaryEmbedding): + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + original_max_position_embeddings=4096, + beta_fast=32, + beta_slow=1, + mscale=1, + mscale_all_dim=0, + ): + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + dim = self.dim + + freq_extra = 1.0 / ( + self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + freq_inter = 1.0 / ( + self.scaling_factor + * self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to( + device=device, dtype=torch.float32 + ) + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(seq_len, device=device, dtype=torch.float32) + + freqs = torch.outer(t, inv_freq) + + _mscale = float( + yarn_get_mscale(self.scaling_factor, self.mscale) + / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", + (emb.cos() * _mscale)[None, None, :, :].to(dtype), + persistent=False, + ) + self.register_buffer( + "sin_cached", + (emb.sin() * _mscale)[None, None, :, :].to(dtype), + persistent=False, + ) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + if hasattr(config, "head_dim"): + self.head_dim = config.head_dim + else: + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + + self.q_proj = nn.Linear( + self.hidden_size * 2, self.num_heads * self.head_dim, bias=False + ) + self.k_proj = nn.Linear( + self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False + ) + self.v_proj = nn.Linear( + self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=False + ) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=getattr(self.config, "rope_theta", 10000), + ) + else: + rope_scaling = self.config.rope_scaling + + def rope_get(key, default=None): + if isinstance(rope_scaling, dict): + return rope_scaling.get(key, default) + return getattr(rope_scaling, key, default) + + scaling_type = rope_get("rope_type", rope_get("type")) + scaling_factor = rope_get("factor") + + if scaling_type == "linear": + if scaling_factor is None: + raise ValueError( + "Linear RoPE scaling requires 'factor' in rope_scaling config." + ) + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + ) + elif scaling_type == "dynamic": + if scaling_factor is None: + raise ValueError( + "Dynamic RoPE scaling requires 'factor' in rope_scaling config." + ) + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + ) + elif scaling_type == "llama3": + # for nv type + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=getattr(self.config, "rope_theta", 10000), + scaling_factor=( + scaling_factor if scaling_factor is not None else 1.0 + ), + low_freq_factor=rope_get("low_freq_factor"), + high_freq_factor=rope_get("high_freq_factor"), + orig_max_position=rope_get("original_max_position_embeddings"), + ) + elif scaling_type == "mrope": + self.rotary_emb = LlamaMutiRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings + ) + elif scaling_type == "yarn": + self.rotary_emb = LlamaYarnRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + original_max_position_embeddings=rope_get( + "original_max_position_embeddings" + ), + scaling_factor=scaling_factor, + beta_fast=rope_get("beta_fast"), + beta_slow=rope_get("beta_slow"), + mscale=rope_get("mscale"), + mscale_all_dim=rope_get("mscale_all_dim"), + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + cache_hidden: Optional[List[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + if cache_hidden is None: + if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding): + cos, sin = self.rotary_emb(query_states, position_ids) + cos, sin = cos.to(query_states.device), sin.to(query_states.device) + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + self.config.rope_scaling["mrope_section"], + ) + else: + cos, sin = self.rotary_emb(query_states, seq_len=q_len) + cos, sin = cos.to(query_states.device), sin.to(query_states.device) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + is_causal=attention_mask is None, + dropout_p=0.0, + ) + + else: + lck = len(cache_hidden[0]) + if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding): + cos, sin = self.rotary_emb(query_states, position_ids + lck) + cos, sin = cos.to(query_states.device), sin.to(query_states.device) + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + self.config.rope_scaling["mrope_section"], + ) + else: + cos, sin = self.rotary_emb(query_states, seq_len=q_len + lck) + cos, sin = cos.to(query_states.device), sin.to(query_states.device) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + lck + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + cache_hidden[0] = cache_hidden[0] + [key_states] + cache_hidden[1] = cache_hidden[1] + [value_states] + + cache_k = cache_hidden[0] + cache_v = cache_hidden[1] + + k0 = cache_k[0] + v0 = cache_v[0] + + # causal + attn_weights = torch.matmul(query_states, k0.transpose(2, 3)) / math.sqrt( + self.head_dim + ) + lck = len(cache_k) + + attn_weights = attn_weights + attention_mask + + for i in range(1, lck): + ki = cache_k[i] + qi = query_states + kiq = ki + + attn_weightsi = (qi * kiq).sum(-1) / math.sqrt(self.head_dim) + attn_weights = torch.cat( + (attn_weights, attn_weightsi[..., None]), dim=-1 + ) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights0 = attn_weights[..., :q_len] + + attn_output = torch.matmul(attn_weights0, v0) + + for i in range(1, lck): + vi = cache_v[i] + attn_weightsi = attn_weights[..., q_len + i - 1] + attn_outputi = attn_weightsi[..., None] * vi + attn_output = attn_output + attn_outputi + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.head_dim * self.num_heads) + + attn_output = self.o_proj(attn_output) + + return attn_output + + +class LlamaFlexAttention(LlamaAttention): + """ + Attention layer implemented with flex attention. We keep the parameters consistent with LlamaAttention. + The used parameters are: + - hidden_states: input hidden states + - attention_mask: attention mask not expanded, straight from data loader. + - position_ids: position ids + - past_key_values: dynamic cache used for storing past key and value states. + """ + + def forward( + self, + hidden_states: torch.Tensor, + cache_hidden: Optional[List[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + lck = past_seen_tokens // q_len + if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding): + cos, sin = self.rotary_emb(query_states, position_ids + lck) + cos, sin = cos.to(query_states.device), sin.to(query_states.device) + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + self.config.rope_scaling["mrope_section"], + ) + else: + cos, sin = self.rotary_emb(query_states, seq_len=q_len + lck) + cos, sin = cos.to(query_states.device), sin.to(query_states.device) + # Keep positions ids aligned when padding so the KV cache is unaffected. + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + lck + ) + + cache_position: torch.Tensor = torch.arange( + past_seen_tokens, past_seen_tokens + q_len, device=hidden_states.device + ) + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + + key_cache, value_cache = past_key_values.update( + key_states, + value_states, + layer_idx=0, # TODO: support multiple layers + cache_kwargs=cache_kwargs, + ) + + seq_lengths = attention_mask.sum(dim=-1) + # Shrink the attention mask to align with the padding to the right. + # This is equivalent to the shrinking logic in eagle3.py + seq_lengths -= lck + # TODO: Remove the usage of uncompiled create_block_mask after + # https://github.com/pytorch/pytorch/issues/160018 + if q_len <= 128: + create_block_mask_func = create_block_mask + flex_attention_func = flex_attention + else: + create_block_mask_func = compile_friendly_create_block_mask + flex_attention_func = compile_friendly_flex_attention + + block_mask = create_block_mask_func( + mask_mod=generate_eagle3_mask( + seq_lengths=seq_lengths, + Q_LEN=q_len, + KV_LEN=key_cache.shape[-2], + lck=lck, + ), + B=bsz, + H=1, # Rely on broadcast + Q_LEN=q_len, + KV_LEN=key_cache.shape[-2], + device=query_states.device, + ) + attn_output = flex_attention_func( + query=query_states, + key=key_cache.contiguous(), + value=value_cache.contiguous(), + block_mask=block_mask, + enable_gqa=True, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.head_dim * self.num_heads) + attn_output = self.o_proj(attn_output) + return attn_output + + +class LlamaFlashAttention(LlamaAttention): + """ + Attention layer implemented with flash attention. We keep the parameters consistent with LlamaAttention. + The used parameters are: + - hidden_states: input hidden states + - position_ids: position ids + - cache_hidden: manual cache used for storing past key and value states + """ + + def forward( + self, + hidden_states: torch.Tensor, + cache_hidden: Optional[List[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ) + + lck = 0 if cache_hidden is None else len(cache_hidden[0]) + if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding): + cos, sin = self.rotary_emb(query_states, position_ids + lck) + cos, sin = cos.to(query_states.device), sin.to(query_states.device) + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + self.config.rope_scaling["mrope_section"], + unsqueeze_dim=2, + ) + else: + cos, sin = self.rotary_emb(query_states, seq_len=q_len + lck) + cos, sin = cos.to(query_states.device), sin.to(query_states.device) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + lck, unsqueeze_dim=2 + ) + + if cache_hidden is not None: + cache_hidden[0] = cache_hidden[0] + [key_states] + cache_hidden[1] = cache_hidden[1] + [value_states] + + cache_k = cache_hidden[0] + cache_v = cache_hidden[1] + else: + cache_k = [key_states] + cache_v = [value_states] + + k0 = cache_k[0] + v0 = cache_v[0] + + assert ( + flash_attn_func is not None + ), "flash_attn is not installed, please install flash_attn if you want to use the flash attention backend" + attn_output, lse, _ = flash_attn_func( + query_states, + k0, + v0, + dropout_p=0.0, + softmax_scale=1.0 / math.sqrt(self.head_dim), + causal=True, + return_attn_probs=True, + ) + lse = lse.transpose(1, 2) + + lck = len(cache_k) + if lck > 1: + q_shape_expanded = ( + bsz, + q_len, + self.num_key_value_heads, + self.num_key_value_groups, + self.head_dim, + ) + attn_outputs = [attn_output.view(q_shape_expanded)] + lses = [lse.view(q_shape_expanded[:-1])] + + for i in range(1, lck): + ki = cache_k[i].unsqueeze(-2) + qi = query_states.view(q_shape_expanded) + vi = cache_v[i].unsqueeze(-2) + + attn_outputs.append(vi) + lses.append((qi * ki).sum(-1) / math.sqrt(self.head_dim)) + + lse = torch.logsumexp(torch.stack(lses, dim=-1), dim=-1) + attn_output = sum( + attn_outputi * torch.exp(lsei - lse).unsqueeze(-1) + for attn_outputi, lsei in zip(attn_outputs, lses) + ) + # lse is fp32, downcast attn_output back + attn_output = attn_output.to(self.o_proj.weight.dtype) + + attn_output = attn_output.reshape(bsz, q_len, self.head_dim * self.num_heads) + + attn_output = self.o_proj(attn_output) + + return attn_output + + +class LlamaUSPFlashAttention(LlamaAttention): + """ + LlamaUSPFlashAttention with Trainable Ring Attention & Correct Eagle3 Branch Merging. + """ + + def __init__(self, config): + super().__init__(config) + assert ( + dist.is_initialized() + ), f"LlamaUSPAttention requires torch.distributed; call init_distributed first." + self.ring_pg = get_sp_ring_group() + self.ulysses_pg = get_sp_ulysses_group() + self.sp_ring_degree = torch.distributed.get_world_size(self.ring_pg) + self.sp_ulysses_degree = torch.distributed.get_world_size(self.ulysses_pg) + self.ring_rank = torch.distributed.get_rank(self.ring_pg) + + self.scatter_idx = 2 + self.gather_idx = 1 + self.use_sync = False + + def forward( + self, + hidden_states: torch.Tensor, + cache_hidden: Optional[List[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + + bsz, q_len, _ = hidden_states.size() + local_q_len = q_len + + # ============================================================= + # 1. Projections & Ulysses Scatter + # ============================================================= + query_states = self.q_proj(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + query_states = SeqAllToAll4D.apply( + self.ulysses_pg, + query_states, + self.scatter_idx, + self.gather_idx, + self.use_sync, + ) + + key_states = self.k_proj(hidden_states) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ) + key_states = SeqAllToAll4D.apply( + self.ulysses_pg, + key_states, + self.scatter_idx, + self.gather_idx, + self.use_sync, + ) + + value_states = self.v_proj(hidden_states) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ) + value_states = SeqAllToAll4D.apply( + self.ulysses_pg, + value_states, + self.scatter_idx, + self.gather_idx, + self.use_sync, + ) + + current_q_len = query_states.shape[1] + local_num_heads = query_states.shape[2] + + # Global length calculation (for RoPE) + global_q_len = q_len * self.sp_ring_degree * self.sp_ulysses_degree + + # ============================================================= + # 2. RoPE & Cache Management + # ============================================================= + if self.sp_ring_degree > 1: + if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding): + position_ids = position_ids.chunk(self.sp_ring_degree, dim=2)[ + self.ring_rank + ].clone() + else: + position_ids = position_ids.chunk(self.sp_ring_degree, dim=1)[ + self.ring_rank + ].clone() + + lck = 0 if cache_hidden is None else len(cache_hidden[0]) + + if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding): + cos, sin = self.rotary_emb(query_states, position_ids + lck) + cos, sin = cos.to(query_states.device), sin.to(query_states.device) + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + self.config.rope_scaling["mrope_section"], + unsqueeze_dim=2, + ) + else: + cos, sin = self.rotary_emb(query_states, seq_len=global_q_len + lck) + cos, sin = cos.to(query_states.device), sin.to(query_states.device) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + lck, unsqueeze_dim=2 + ) + + # Update Cache (Eagle3 Logic: Cache is a list of tensors for tree branches) + if cache_hidden is not None: + cache_hidden[0] = cache_hidden[0] + [key_states] + cache_hidden[1] = cache_hidden[1] + [value_states] + cache_k = cache_hidden[0] + cache_v = cache_hidden[1] + else: + cache_k = [key_states] + cache_v = [value_states] + + # ============================================================= + # 3. Hybrid Attention Computation + # ============================================================= + + # 3.1 Main Sequence (Ring Attention) + out_ring, lse_ring, _ = ring_flash_attn_func( + query_states, + cache_k[0], + cache_v[0], + dropout_p=0.0, + softmax_scale=1.0 / math.sqrt(self.head_dim), + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + group=self.ring_pg, + ) + + if lse_ring.dim() == 3 and lse_ring.shape[1] == local_num_heads: + acc_lse = lse_ring.transpose(1, 2).contiguous() # -> [B, S, H] + else: + acc_lse = lse_ring + + assert ( + acc_lse.shape[1] == current_q_len + ), f"LSE seq_len {acc_lse.shape[1]} mismatch with Query seq_len {current_q_len}" + + acc_out = out_ring + + # 3.2 Extras Branches (Eagle3 Point-wise Update) + if len(cache_k) > 1: + num_kv_heads_local = cache_k[0].shape[2] + local_groups = local_num_heads // num_kv_heads_local + + q_shape_expanded = ( + bsz, + current_q_len, + num_kv_heads_local, + local_groups, + self.head_dim, + ) + qi_reshaped = query_states.view(q_shape_expanded) # [B, S, KV, G, D] + + for i in range(1, len(cache_k)): + ki = cache_k[i] # [B, S, KV, D] + vi = cache_v[i] # [B, S, KV, D] + + ki_expanded = ki.unsqueeze(-2) # [B, S, KV, 1, D] + + # Dot Product: [B, S, KV, G] + score_i = (qi_reshaped * ki_expanded).sum(-1) / math.sqrt(self.head_dim) + + # Flatten back to [B, S, H_local] + step_lse = score_i.view(bsz, current_q_len, -1) + + vi_expanded = vi.unsqueeze(-2) + step_out = vi_expanded.expand(q_shape_expanded).reshape(acc_out.shape) + + # Online Softmax Update + new_lse = torch.logaddexp(acc_lse, step_lse) + + acc_out = acc_out * torch.exp(acc_lse - new_lse).unsqueeze( + -1 + ) + step_out * torch.exp(step_lse - new_lse).unsqueeze(-1) + + acc_lse = new_lse + + attn_output = acc_out.to(query_states.dtype) + + # ============================================================= + # 4. Ulysses Gather & Output Projection + # ============================================================= + attn_output = SeqAllToAll4D.apply( + self.ulysses_pg, + attn_output, + self.gather_idx, # Scatter idx: 1 (Seq) + self.scatter_idx, # Gather idx: 2 (Heads) + self.use_sync, + ) + + attn_output = attn_output.reshape( + bsz, local_q_len, self.head_dim * self.num_heads + ) + attn_output = self.o_proj(attn_output) + + return attn_output + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [ + F.linear(x, gate_proj_slices[i]) + for i in range(self.config.pretraining_tp) + ], + dim=-1, + ) + up_proj = torch.cat( + [ + F.linear(x, up_proj_slices[i]) + for i in range(self.config.pretraining_tp) + ], + dim=-1, + ) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) + for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + @torch.compile(dynamic=True) + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config, attention_backend: str = "sdpa"): + super().__init__() + self.hidden_size = config.hidden_size + + if attention_backend == "sdpa": + self.self_attn = LlamaAttention(config=config) + elif attention_backend == "flex_attention": + print_with_rank("Using flex attention on draft model training!") + self.self_attn = LlamaFlexAttention(config=config) + elif attention_backend == "fa": + self.self_attn = LlamaFlashAttention(config=config) + elif attention_backend == "usp": + self.self_attn = LlamaUSPFlashAttention(config=config) + else: + raise ValueError(f"Unknown attention backend {attention_backend}") + + self.attention_backend = attention_backend + self.mlp = LlamaMLP(config) + # self.fc = nn.Linear(config.hidden_size * 2, config.hidden_size) + self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + # if self.index!=0: + + self.post_attention_layernorm = LlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + input_emb: torch.Tensor, + hidden_states: torch.Tensor, + cache_hidden: List[List[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_values (`Cache`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.hidden_norm(hidden_states) + input_emb = self.input_layernorm(input_emb) + + hidden_states = torch.cat((input_emb, hidden_states), dim=-1) + # Self Attention + hidden_states = self.self_attn( + cache_hidden=cache_hidden, + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + # outputs = (hidden_states, return_hidden) + return hidden_states + + +class LlamaForCausalLMEagle3(Eagle3DraftModel): + + config_class = LlamaConfig + + def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None: + super().__init__(config) + self.config = config + self.quant_config = quant_config + + self.vocab_size = config.vocab_size + self.draft_vocab_size = config.draft_vocab_size + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, config.pad_token_id + ) + self.midlayer = LlamaDecoderLayer(config, attention_backend=attention_backend) + + if hasattr(config, "target_hidden_size"): + self.fc = torch.nn.Linear( + config.target_hidden_size * 3, config.hidden_size, bias=False + ) + else: + self.fc = torch.nn.Linear( + config.hidden_size * 3, config.hidden_size, bias=False + ) + + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.lm_head = nn.Linear( + config.hidden_size, config.draft_vocab_size, bias=False + ) + + # create vocab buffers + t2d = torch.ones(self.vocab_size, dtype=torch.bool) + d2t = torch.zeros(self.draft_vocab_size, dtype=torch.int64) + self.register_buffer("t2d", t2d) + self.register_buffer("d2t", d2t) + + def forward( + self, + hidden_states: torch.Tensor, + inputs_embeds: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ttt_length: int = 1, + ): + """ + Arguments: + hidden_states (`torch.FloatTensor`): input to the layer, cat low, mid high hidden_states of shape `(batch, seq_len, hidden_states * 3)` + input_ids (`torch.LongTensor`): input ids of shape `(batch, seq_len)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + position_ids (`torch.LongTensor`, *optional*): position ids of shape `(batch, seq_len)` + """ + if ttt_length == 1: + print_with_rank("using ttt_length 1, no need to cache hidden states") + cache_hidden = None + else: + print_with_rank(f"using ttt_length {ttt_length}, caching hidden states") + cache_hidden = [[], []] + + batch_size, seq_length, _ = hidden_states.size() + + # make position ids + device = hidden_states.device + position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + + # make attention mask + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length), dtype=torch.bool, device=hidden_states.device + ) + attention_mask = prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), hidden_states, 0 + ) + + # fc + hidden_states = self.fc(hidden_states) + hidden_states = self.midlayer( + input_emb=inputs_embeds, + hidden_states=hidden_states, + cache_hidden=cache_hidden, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=None, + output_attentions=False, + use_cache=False, + ) + + # norm + hidden_states = self.norm(hidden_states) + + return hidden_states + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def project_hidden_states(self, hidden_states: torch.Tensor) -> torch.Tensor: + # eagle 3 requires hidden states from 3 layers + assert hidden_states.size(-1) == self.config.hidden_size * 3 + return self.fc(hidden_states) + + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + norm_hidden_states = self.norm(hidden_states) + return self.lm_head(norm_hidden_states) + + def backbone( + self, + input_embeds: torch.Tensor, + hidden_states: torch.Tensor, + cache_hidden: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: Optional[Cache] = None, + use_cache: bool = True, + ) -> torch.Tensor: + return self.midlayer( + input_emb=input_embeds, + hidden_states=hidden_states, + cache_hidden=cache_hidden, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=False, + use_cache=False, + ) diff --git a/SpecForge-ext/specforge/modeling/target/__pycache__/__init__.cpython-311.pyc b/SpecForge-ext/specforge/modeling/target/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6d3661887a1aed088e14c17ab1469b744c46eb6 Binary files /dev/null and b/SpecForge-ext/specforge/modeling/target/__pycache__/__init__.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/modeling/target/__pycache__/eagle3_target_model.cpython-311.pyc b/SpecForge-ext/specforge/modeling/target/__pycache__/eagle3_target_model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..148aa399861a01ef809c1290b7daafd25e9b6cfe Binary files /dev/null and b/SpecForge-ext/specforge/modeling/target/__pycache__/eagle3_target_model.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/modeling/target/__pycache__/target_head.cpython-311.pyc b/SpecForge-ext/specforge/modeling/target/__pycache__/target_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f794323c4db92ae183e99bc08a611541d4260077 Binary files /dev/null and b/SpecForge-ext/specforge/modeling/target/__pycache__/target_head.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/modeling/target/custom_backend/__init__.py b/SpecForge-ext/specforge/modeling/target/custom_backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5465d15a8e84c788c43e5df709c33f4efb0bd43d --- /dev/null +++ b/SpecForge-ext/specforge/modeling/target/custom_backend/__init__.py @@ -0,0 +1,17 @@ +from .gpt_oss import GptOssForCausalLM +from .llama import LlamaForCausalLM +from .llama4 import Llama4ForCausalLM +from .phi3 import Phi3ForCausalLM +from .qwen2 import Qwen2ForCausalLM +from .qwen3 import Qwen3ForCausalLM +from .qwen3_moe import Qwen3MoeForCausalLM + +__all__ = [ + "GptOssForCausalLM", + "LlamaForCausalLM", + "Llama4ForCausalLM", + "Phi3ForCausalLM", + "Qwen2ForCausalLM", + "Qwen3ForCausalLM", + "Qwen3MoeForCausalLM", +] diff --git a/SpecForge-ext/specforge/modeling/target/custom_backend/__pycache__/__init__.cpython-311.pyc b/SpecForge-ext/specforge/modeling/target/custom_backend/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bad5f4036c68cdc7604de67783ece3bd53db440 Binary files /dev/null and b/SpecForge-ext/specforge/modeling/target/custom_backend/__pycache__/__init__.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/modeling/target/custom_backend/__pycache__/gpt_oss.cpython-311.pyc b/SpecForge-ext/specforge/modeling/target/custom_backend/__pycache__/gpt_oss.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97916e0b4389e0dff8e7d87567703ab175a34d6f Binary files /dev/null and b/SpecForge-ext/specforge/modeling/target/custom_backend/__pycache__/gpt_oss.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/modeling/target/custom_backend/__pycache__/llama.cpython-311.pyc b/SpecForge-ext/specforge/modeling/target/custom_backend/__pycache__/llama.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4baf0015b622d5c21044bc6d1e7ba144b0f9f20 Binary files /dev/null and b/SpecForge-ext/specforge/modeling/target/custom_backend/__pycache__/llama.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/modeling/target/custom_backend/__pycache__/llama4.cpython-311.pyc b/SpecForge-ext/specforge/modeling/target/custom_backend/__pycache__/llama4.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e63d819ce7879b260af329efd077ccc63c7b431 Binary files /dev/null and b/SpecForge-ext/specforge/modeling/target/custom_backend/__pycache__/llama4.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/modeling/target/custom_backend/__pycache__/phi3.cpython-311.pyc b/SpecForge-ext/specforge/modeling/target/custom_backend/__pycache__/phi3.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..357e98859ebedb41f9420606f931c46e6ff5086a Binary files /dev/null and b/SpecForge-ext/specforge/modeling/target/custom_backend/__pycache__/phi3.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/modeling/target/custom_backend/__pycache__/qwen2.cpython-311.pyc b/SpecForge-ext/specforge/modeling/target/custom_backend/__pycache__/qwen2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6576ff9ea8d77beb756025bfd11a4298bdeb6822 Binary files /dev/null and b/SpecForge-ext/specforge/modeling/target/custom_backend/__pycache__/qwen2.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/modeling/target/custom_backend/__pycache__/qwen3.cpython-311.pyc b/SpecForge-ext/specforge/modeling/target/custom_backend/__pycache__/qwen3.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27c6bb6a6ef32f54334923b5d3538b669b74cee1 Binary files /dev/null and b/SpecForge-ext/specforge/modeling/target/custom_backend/__pycache__/qwen3.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/modeling/target/custom_backend/__pycache__/qwen3_moe.cpython-311.pyc b/SpecForge-ext/specforge/modeling/target/custom_backend/__pycache__/qwen3_moe.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51650aaa3800ff18e168cccfdc096b7c8845e612 Binary files /dev/null and b/SpecForge-ext/specforge/modeling/target/custom_backend/__pycache__/qwen3_moe.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/modeling/target/custom_backend/gpt_oss.py b/SpecForge-ext/specforge/modeling/target/custom_backend/gpt_oss.py new file mode 100644 index 0000000000000000000000000000000000000000..b3b4a79723f48dd2b3e7db077e030f88288a3145 --- /dev/null +++ b/SpecForge-ext/specforge/modeling/target/custom_backend/gpt_oss.py @@ -0,0 +1,879 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, List, Optional, Union + +import torch +import torch.distributed as dist +from torch import nn +from torch.nn import functional as F +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.integrations.hub_kernels import use_kernel_forward_from_hub +from transformers.masking_utils import ( + create_causal_mask, + create_sliding_window_causal_mask, +) +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig +from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple +from transformers.utils.generic import check_model_inputs + +from specforge.distributed import get_tp_group, shard_tensor +from specforge.layers import ( + ColumnParallelLinear, + ParallelLMHead, + RowParallelLinear, + VocabParallelEmbedding, +) + + +class GptOssExperts(nn.Module): + def __init__(self, config): + super().__init__() + self.intermediate_size = config.intermediate_size + self.num_experts = config.num_local_experts + self.hidden_size = config.hidden_size + self.expert_dim = self.intermediate_size + + # apply tp + self.tp_group = get_tp_group() + self.tp_size = dist.get_world_size(self.tp_group) + self.expert_dim_per_shard = self.expert_dim // self.tp_size + self.gate_up_proj = nn.Parameter( + torch.empty( + self.num_experts, self.hidden_size, 2 * self.expert_dim_per_shard + ) + ) + self.gate_up_proj_bias = nn.Parameter( + torch.empty(self.num_experts, 2 * self.expert_dim_per_shard) + ) + self.down_proj = nn.Parameter( + torch.empty((self.num_experts, self.expert_dim_per_shard, self.hidden_size)) + ) + self.down_proj_bias = nn.Parameter( + torch.empty(self.num_experts, self.hidden_size) + ) + + self.alpha = 1.702 + self.limit = 7.0 + + self._register_load_state_dict_pre_hook(self.shard_state_dict) + + def shard_state_dict(self, state_dict, *args): + if "down_proj" in state_dict: + # columnwise splitting + value = state_dict["down_proj"] + state_dict["down_proj"] = shard_tensor(value, self.tp_group, 1) + + if "down_proj_bias" in state_dict: + value = state_dict["down_proj_bias"] + if dist.get_rank(self.tp_group) != 0: + value.zero_() + + if "gate_up_proj_bias" in state_dict: + value = state_dict["gate_up_proj_bias"] + state_dict["gate_up_proj_bias"] = shard_tensor(value, self.tp_group, 1) + + if "gate_up_proj" in state_dict: + value = state_dict["gate_up_proj"] + gate, up = value[..., ::2], value[..., 1::2] + gate = shard_tensor(gate, self.tp_group, 2) + up = shard_tensor(up, self.tp_group, 2) + new_value = torch.zeros_like(self.gate_up_proj, device=value.device) + new_value[..., ::2] = gate + new_value[..., 1::2] = up + state_dict["gate_up_proj"] = new_value + + def forward( + self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None + ) -> torch.Tensor: + """ + When training is is more efficient to just loop over the experts and compute the output for each expert + as otherwise the memory would explode. + + For inference we can sacrifice some memory and compute the output for all experts at once. By repeating the inputs. + + Args: + hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size) + selected_experts (torch.Tensor): (batch_size * token_num, top_k) + routing_weights (torch.Tensor): (batch_size * token_num, num_experts) + Returns: + torch.Tensor + """ + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape( + -1, self.hidden_size + ) # (num_tokens, hidden_size) + num_experts = routing_weights.shape[1] + if self.training: + next_states = torch.zeros_like( + hidden_states, dtype=hidden_states.dtype, device=hidden_states.device + ) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot( + router_indices, num_classes=num_experts + ) + expert_mask = expert_mask.permute(2, 1, 0) + # we sum on the top_k and on the sequence lenght to get which experts + # are hit this time around + expert_hitted = torch.greater( + expert_mask.sum(dim=(-1, -2)), 0 + ).nonzero() + for expert_idx in expert_hitted[:]: + with torch.no_grad(): + _, token_idx = torch.where(expert_mask[expert_idx[0]]) + current_state = hidden_states[token_idx] + gate_up = ( + current_state @ self.gate_up_proj[expert_idx] + + self.gate_up_proj_bias[expert_idx] + ) + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + gate = gate.clamp(min=None, max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + glu = gate * torch.sigmoid(gate * self.alpha) + gated_output = (up + 1) * glu + out = ( + gated_output @ self.down_proj[expert_idx] + + self.down_proj_bias[expert_idx] + ) + weighted_output = out[0] * routing_weights[token_idx, expert_idx, None] + next_states.index_add_( + 0, token_idx, weighted_output.to(hidden_states.dtype) + ) + next_states = next_states.view(batch_size, -1, self.hidden_size) + else: + hidden_states = hidden_states.repeat(num_experts, 1) + hidden_states = hidden_states.view(num_experts, -1, self.hidden_size) + gate_up = ( + torch.bmm(hidden_states, self.gate_up_proj) + + self.gate_up_proj_bias[..., None, :] + ) + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + gate = gate.clamp(min=None, max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + glu = gate * torch.sigmoid(gate * self.alpha) + next_states = torch.bmm(((up + 1) * glu), self.down_proj) + next_states = next_states + self.down_proj_bias[..., None, :] + next_states = next_states.view( + num_experts, batch_size, -1, self.hidden_size + ) + next_states = ( + next_states + * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[ + ..., None + ] + ) + dist.all_reduce(next_states, op=dist.ReduceOp.SUM, group=self.tp_group) + + next_states = next_states.sum(dim=0) + return next_states + + +class GptOssTopKRouter(nn.Module): + def __init__(self, config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) + self.bias = nn.Parameter(torch.empty(self.num_experts)) + + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear( + hidden_states, self.weight, self.bias + ) # (seq_len, num_experts) + router_top_value, router_indices = torch.topk( + router_logits, self.top_k, dim=-1 + ) # (seq_len, top_k) + router_top_value = torch.nn.functional.softmax( + router_top_value, dim=1, dtype=router_top_value.dtype + ) + router_scores = torch.zeros_like(router_logits).scatter_( + 1, router_indices, router_top_value + ) + return router_scores, router_indices + + +@use_kernel_forward_from_hub("MegaBlocksMoeMLP") +class GptOssMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.router = GptOssTopKRouter(config) + self.experts = GptOssExperts(config) + + def forward(self, hidden_states): + router_scores, router_indices = self.router( + hidden_states + ) # (num_experts, seq_len) + routed_out = self.experts( + hidden_states, router_indices=router_indices, routing_weights=router_scores + ) + return routed_out, router_scores + + +class GptOssRotaryEmbedding(nn.Module): + def __init__(self, config: GptOssConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get( + "rope_type", config.rope_scaling.get("type") + ) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = ( + self.inv_freq[None, :, None] + .float() + .expand(position_ids.shape[0], -1, 1) + .to(x.device) + ) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = ( + x.device.type + if isinstance(x.device.type, str) and x.device.type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = ( + inv_freq_expanded.float() @ position_ids_expanded.float() + ).transpose(1, 2) + emb = freqs + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(x.dtype), sin.to(x.dtype) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> torch.Tensor: + first_half, second_half = torch.chunk(x, 2, dim=-1) + first_ = first_half * cos - second_half * sin + second_ = second_half * cos + first_half * sin + return torch.cat((first_, second_), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = _apply_rotary_emb(q, cos, sin) + k_embed = _apply_rotary_emb(k, cos, sin) + return q_embed, k_embed + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + sinks = module.sinks.reshape(1, -1, 1, 1).expand( + query.shape[0], -1, query.shape[-2], -1 + ) + combined_logits = torch.cat([attn_weights, sinks], dim=-1) + + # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16 + # when training with bsz>1 we clamp max values. + + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) + scores = probs[..., :-1] # we drop the sink here + attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class GptOssAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: GptOssConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + # self.q_proj = nn.Linear( + # config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + # ) + # self.k_proj = nn.Linear( + # config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + # ) + # self.v_proj = nn.Linear( + # config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + # ) + # self.o_proj = nn.Linear( + # config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + # ) + # self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + # self.sinks = nn.Parameter(torch.empty(config.num_attention_heads)) + + self.tp_group = get_tp_group() + self.tp_size = dist.get_world_size(self.tp_group) + self.q_proj = ColumnParallelLinear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=config.attention_bias, + ) + self.k_proj = ColumnParallelLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = ColumnParallelLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = RowParallelLinear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + self.num_attention_heads_per_shard = config.num_attention_heads // self.tp_size + self.sliding_window = ( + config.sliding_window + if config.layer_types[layer_idx] == "sliding_attention" + else None + ) + self.sinks = nn.Parameter(torch.empty(self.num_attention_heads_per_shard)) + + self._register_load_state_dict_pre_hook(self.shard_state_dict) + + def shard_state_dict(self, state_dict, *args): + if "sinks" in state_dict: + value = state_dict["sinks"] + state_dict["sinks"] = shard_tensor(value, self.tp_group, 0) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_value is not None: + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation + ] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + s_aux=self.sinks, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + dist.all_reduce(attn_output, op=dist.ReduceOp.SUM, group=self.tp_group) + return attn_output, attn_weights + + +class GptOssDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: GptOssConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = GptOssAttention(config=config, layer_idx=layer_idx) + self.mlp = GptOssMLP(config) + self.input_layernorm = GptOssRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = GptOssRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.attention_type = config.layer_types[layer_idx] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores + hidden_states = residual + hidden_states + return hidden_states + + +@auto_docstring +class GptOssPreTrainedModel(PreTrainedModel): + config: GptOssConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["GptOssDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = False + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = {} + _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"] + _supports_flash_attention = False + _supports_flex_attention = False + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Parameter): + module.data.normal_(mean=0.0, std=std) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, GptOssRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, GptOssExperts): + module.gate_up_proj.data.normal_(mean=0.0, std=std) + module.gate_up_proj_bias.data.zero_() + module.down_proj.data.normal_(mean=0.0, std=std) + module.down_proj_bias.data.zero_() + elif isinstance(module, GptOssAttention): + module.sinks.data.normal_(mean=0.0, std=std) + elif isinstance(module, GptOssTopKRouter): + module.weight.data.normal_(mean=0.0, std=std) + module.bias.data.normal_(mean=0.0, std=std) + + +@auto_docstring +class GptOssModel(GptOssPreTrainedModel): + _no_split_modules = ["GptOssDecoderLayer"] + + def __init__(self, config: GptOssConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + GptOssDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = GptOssRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + layers_to_output_hidden_states: Optional[List[int]] = kwargs.pop( + "layers_to_output_hidden_states", None + ) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + } + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + all_hidden_states = () + for idx, decoder_layer in enumerate(self.layers): + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_value=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + if ( + layers_to_output_hidden_states is None + or idx in layers_to_output_hidden_states + ): + all_hidden_states += (hidden_states,) + + hidden_states = self.norm(hidden_states) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + ) + + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. + """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat( + [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0 + ) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // ( + batch_size * sequence_length + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand( + (num_hidden_layers, batch_size, sequence_length, top_k, num_experts) + ) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum( + expert_mask.float() * expert_attention_mask, dim=0 + ) / torch.sum(expert_attention_mask, dim=0) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum( + routing_weights * router_per_expert_attention_mask, dim=0 + ) / torch.sum(router_per_expert_attention_mask, dim=0) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +@auto_docstring +class GptOssForCausalLM(GptOssPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = GptOssModel(config) + self.vocab_size = config.vocab_size + self.lm_head = ParallelLMHead(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + + # Initialize weights and apply final processing + self.post_init() + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, GptOssForCausalLM + + >>> model = GptOssForCausalLM.from_pretrained("mistralai/GptOss-8x7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/GptOss-8x7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_router_logits=output_router_logits, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head(hidden_states[:, slice_indices, :], gather_output=True) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to( + loss.device + ) # make sure to reside in the same device + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +__all__ = ["GptOssForCausalLM", "GptOssModel", "GptOssPreTrainedModel"] diff --git a/SpecForge-ext/specforge/modeling/target/custom_backend/llama.py b/SpecForge-ext/specforge/modeling/target/custom_backend/llama.py new file mode 100644 index 0000000000000000000000000000000000000000..04a3f6c9bd40b684e5d287ddf4477ea50cfa68c8 --- /dev/null +++ b/SpecForge-ext/specforge/modeling/target/custom_backend/llama.py @@ -0,0 +1,460 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, List, Optional, Union + +import torch +import torch.distributed as dist +from torch import nn +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.masking_utils import create_causal_mask +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import ( + LlamaRMSNorm, + LlamaRotaryEmbedding, + apply_rotary_pos_emb, + eager_attention_forward, +) +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs, logging +from transformers.utils.generic import check_model_inputs + +from specforge.distributed import get_tp_group +from specforge.layers import ( + ColumnParallelLinear, + ParallelLMHead, + RowParallelLinear, + VocabParallelEmbedding, +) + +logger = logging.get_logger(__name__) + + +class TensorParallelLlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + # self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + # self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + # self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + + self.tp_group = get_tp_group() + self.gate_proj = ColumnParallelLinear( + self.hidden_size, self.intermediate_size, bias=config.mlp_bias + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, self.intermediate_size, bias=config.mlp_bias + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, self.hidden_size, bias=config.mlp_bias + ) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + dist.all_reduce(down_proj, op=dist.ReduceOp.SUM, group=self.tp_group) + return down_proj + + +class TensorParallelLlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + # self.q_proj = nn.Linear( + # config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + # ) + # self.k_proj = nn.Linear( + # config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + # ) + # self.v_proj = nn.Linear( + # config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + # ) + # self.o_proj = nn.Linear( + # config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + # ) + + # distributed linear layers + self.tp_group = get_tp_group() + self.q_proj = ColumnParallelLinear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=config.attention_bias, + ) + self.k_proj = ColumnParallelLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = ColumnParallelLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = RowParallelLinear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation + ] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + dist.all_reduce(attn_output, op=dist.ReduceOp.SUM, group=self.tp_group) + return attn_output, attn_weights + + +class TensorParallelLlamaDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = TensorParallelLlamaAttention( + config=config, layer_idx=layer_idx + ) + + self.mlp = TensorParallelLlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class LlamaPreTrainedModel(PreTrainedModel): + config: LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["TensorParallelLlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = {} + + +class LlamaModel(LlamaPreTrainedModel): + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + TensorParallelLlamaDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + layers_to_output_hidden_states: Optional[List[int]] = kwargs.pop( + "layers_to_output_hidden_states", None + ) + + if inputs_embeds is None: + inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position: torch.Tensor = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + all_hidden_states = () + for idx, decoder_layer in enumerate(self.layers): + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + if ( + layers_to_output_hidden_states is None + or idx in layers_to_output_hidden_states + ): + all_hidden_states += (hidden_states,) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + ) + + +class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + + # distributed the lm head + self.lm_head = ParallelLMHead(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head(hidden_states[:, slice_indices, :], gather_output=True) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + ) + + +__all__ = [ + "LlamaForCausalLM", + "LlamaModel", +] diff --git a/SpecForge-ext/specforge/modeling/target/custom_backend/llama4.py b/SpecForge-ext/specforge/modeling/target/custom_backend/llama4.py new file mode 100644 index 0000000000000000000000000000000000000000..22f807daed1f6a1b1535745afb95a4feee7e3d0b --- /dev/null +++ b/SpecForge-ext/specforge/modeling/target/custom_backend/llama4.py @@ -0,0 +1,613 @@ +# coding=utf-8 +# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, List, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.integrations.hub_kernels import use_kernel_forward_from_hub +from transformers.masking_utils import create_causal_mask, create_chunked_causal_mask +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.models.llama4.configuration_llama4 import ( + Llama4Config, + Llama4TextConfig, +) +from transformers.models.llama4.modeling_llama4 import ( + Llama4Router, + Llama4TextL2Norm, + Llama4TextRMSNorm, + Llama4TextRotaryEmbedding, + Llama4VisionModel, + apply_rotary_emb, + eager_attention_forward, +) +from transformers.processing_utils import Unpack +from transformers.utils import ( + TransformersKwargs, + auto_docstring, + can_return_tuple, + logging, +) +from transformers.utils.deprecation import deprecate_kwarg +from transformers.utils.generic import check_model_inputs + +# [MODIFIED] Import from transformers library +from specforge.distributed import get_tp_group, shard_tensor +from specforge.layers import ( + ColumnParallelLinear, + ParallelLMHead, + RowParallelLinear, + VocabParallelEmbedding, +) + +logger = logging.get_logger(__name__) + + +class Llama4TextExperts(nn.Module): + def __init__(self, config: Llama4TextConfig): + super().__init__() + self.num_experts = config.num_local_experts + self.intermediate_size = config.intermediate_size + self.hidden_size = config.hidden_size + self.expert_dim = self.intermediate_size + + self.tp_group = get_tp_group() + self.tp_size = dist.get_world_size(self.tp_group) + self.expert_dim_per_shard = self.expert_dim // self.tp_size + self.gate_up_proj = nn.Parameter( + torch.empty( + self.num_experts, self.hidden_size, 2 * self.expert_dim_per_shard + ) + ) + self.down_proj = nn.Parameter( + torch.empty((self.num_experts, self.expert_dim_per_shard, self.hidden_size)) + ) + self.act_fn = ACT2FN[config.hidden_act] + + # deal with weight loading and sharding + self._register_load_state_dict_pre_hook(self.shard_state_dict) + + def shard_state_dict(self, state_dict, *args): + if "down_proj" in state_dict: + value = state_dict["down_proj"] + state_dict["down_proj"] = shard_tensor(value, self.tp_group, 1) + + if "gate_up_proj" in state_dict: + value = state_dict["gate_up_proj"] + gate, up = value.chunk(2, dim=-1) + gate = shard_tensor(gate, self.tp_group, -1) + up = shard_tensor(up, self.tp_group, -1) + value = torch.cat((gate, up), dim=-1) + state_dict["gate_up_proj"] = value + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + This should really not be run on a single machine, as we are reaching compute bound: + - the inputs are expected to be "sorted" per expert already. + - the weights are viewed with another dim, to match num_expert, 1, shape * num_tokens, shape + + Args: + hidden_states (torch.Tensor): (batch_size * token_num, hidden_size) + selected_experts (torch.Tensor): (batch_size * token_num, top_k) + routing_weights (torch.Tensor): (batch_size * token_num, top_k) + Returns: + torch.Tensor + """ + hidden_states = hidden_states.view( + self.gate_up_proj.shape[0], -1, self.hidden_size + ) + gate_up = torch.bmm(hidden_states, self.gate_up_proj) + gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors + next_states = torch.bmm((up * self.act_fn(gate)), self.down_proj) + dist.all_reduce(next_states, op=dist.ReduceOp.SUM, group=self.tp_group) + next_states = next_states.view(-1, self.hidden_size) + return next_states + + +class Llama4TextMLP(nn.Module): + def __init__(self, config, intermediate_size=None): + super().__init__() + + if intermediate_size is None: + intermediate_size = config.intermediate_size + + self.config = config + self.tp_group = get_tp_group() + self.gate_proj = ColumnParallelLinear( + config.hidden_size, intermediate_size, bias=False + ) + self.up_proj = ColumnParallelLinear( + config.hidden_size, intermediate_size, bias=False + ) + self.down_proj = RowParallelLinear( + intermediate_size, config.hidden_size, bias=False + ) + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.activation_fn(self.gate_proj(x)) * self.up_proj(x) + out = self.down_proj(down_proj) + dist.all_reduce(out, op=dist.ReduceOp.SUM, group=self.tp_group) + return out + + +class Llama4TextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Llama4TextConfig, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_attention_heads = config.num_attention_heads + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.num_key_value_heads = config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attn_scale = config.attn_scale + self.floor_scale = config.floor_scale + self.attn_temperature_tuning = config.attn_temperature_tuning + self.attention_dropout = config.attention_dropout + self.is_causal = True + self.use_rope = config.no_rope_layers[layer_idx] + + self.tp_group = get_tp_group() + self.q_proj = ColumnParallelLinear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=config.attention_bias, + ) + self.k_proj = ColumnParallelLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = ColumnParallelLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = RowParallelLinear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + if self.config.use_qk_norm and self.use_rope: + self.qk_norm = Llama4TextL2Norm(config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape) + key_states = self.k_proj(hidden_states).view(*input_shape, -1, self.head_dim) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + if self.use_rope: # the 16E model skips rope for long context on certain layers + query_states, key_states = apply_rotary_emb( + query_states, key_states, position_embeddings.to(query_states.device) + ) + + if hasattr(self, "qk_norm"): # the 128E model does not use qk_norm + query_states = self.qk_norm(query_states) + key_states = self.qk_norm(key_states) + + # Use temperature tuning from https://huggingface.co/papers/2501.19399) to NoROPE layers + if self.attn_temperature_tuning and not self.use_rope: + attn_scales = ( + torch.log1p( + torch.floor((cache_position.float() + 1.0) / self.floor_scale) + ) + * self.attn_scale + + 1.0 + ) + attn_scales = attn_scales.view((1, input_shape[-1], 1, 1)).expand( + (*input_shape, 1, 1) + ) # batch size > 1 + query_states = (query_states * attn_scales).to(query_states.dtype) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation + ] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + dist.all_reduce(attn_output, op=dist.ReduceOp.SUM, group=self.tp_group) + return attn_output, attn_weights + + +@use_kernel_forward_from_hub("Llama4TextMoe") +class Llama4TextMoe(nn.Module): + def __init__(self, config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.hidden_dim = config.hidden_size + self.num_experts = config.num_local_experts + self.experts = Llama4TextExperts(config) + self.router = Llama4Router(config) + self.shared_expert = Llama4TextMLP(config) + + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_scores, router_logits = self.router(hidden_states) + routed_in = hidden_states.repeat(router_scores.shape[1], 1) + routed_in = routed_in * router_scores.transpose(0, 1).reshape(-1, 1) + routed_out = self.experts(routed_in) + out = self.shared_expert(hidden_states) + out.add_( + routed_out.reshape(router_scores.shape[1], -1, routed_out.shape[-1]).sum( + dim=0 + ) + ) + return out, router_logits + + +class Llama4TextDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_idx): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.attention_type = config.layer_types[layer_idx] + self.self_attn = Llama4TextAttention(config, layer_idx) + self.is_moe_layer = layer_idx in config.moe_layers + if self.is_moe_layer: # the 128E model interleaves dense / sparse + self.feed_forward = Llama4TextMoe(config) + else: + self.feed_forward = Llama4TextMLP( + config, intermediate_size=config.intermediate_size_mlp + ) + + self.input_layernorm = Llama4TextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = Llama4TextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[ + torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + attention_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + attention_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + if self.is_moe_layer: + hidden_states, _ = hidden_states + hidden_states = residual + hidden_states.view(residual.shape) + return hidden_states + + +@auto_docstring +class Llama4PreTrainedModel(PreTrainedModel): + config: Llama4Config + supports_gradient_checkpointing = True + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = False + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.text_config.initializer_range + ) + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, Llama4TextRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, Llama4TextExperts): + module.gate_up_proj.data.normal_(mean=0.0, std=std) + module.down_proj.data.normal_(mean=0.0, std=std) + elif isinstance(module, Llama4VisionModel): + module.class_embedding.data.normal_(std=module.scale) + module.positional_embedding_vlm.data.normal_(std=module.scale) + + +@auto_docstring +class Llama4TextModel(Llama4PreTrainedModel): + _no_split_modules = ["Llama4TextDecoderLayer"] + base_model_prefix = "model" + config: Llama4TextConfig + _can_record_outputs = {} + + def __init__(self, config: Llama4TextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + Llama4TextDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Llama4TextRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @check_model_inputs + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + layers_to_output_hidden_states: Optional[List[int]] = kwargs.pop( + "layers_to_output_hidden_states", None + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens( + input_ids.to(self.embed_tokens.weight.device) + ) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "chunked_attention": create_chunked_causal_mask(**mask_kwargs), + } + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + freq_cis = self.rotary_emb(hidden_states, position_ids) + + all_hidden_states = () + for idx, decoder_layer in enumerate(self.layers): + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=freq_cis, + **kwargs, + ) + if ( + layers_to_output_hidden_states is None + or idx in layers_to_output_hidden_states + ): + all_hidden_states += (hidden_states,) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + ) + + +class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin): + _no_split_modules = ["Llama4TextDecoderLayer"] + base_model_prefix = "language_model" + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + config: Llama4TextConfig + + def __init__(self, config: Llama4TextConfig): + super().__init__(config) + self.model = Llama4TextModel(config) + self.vocab_size = config.vocab_size + self.lm_head = ParallelLMHead(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, Llama4ForCausalLM + + >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head(hidden_states[:, slice_indices, :], gather_output=True) + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/SpecForge-ext/specforge/modeling/target/custom_backend/phi3.py b/SpecForge-ext/specforge/modeling/target/custom_backend/phi3.py new file mode 100644 index 0000000000000000000000000000000000000000..2515701f90f8c58cd164fc3e345549877212f379 --- /dev/null +++ b/SpecForge-ext/specforge/modeling/target/custom_backend/phi3.py @@ -0,0 +1,495 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Callable, List, Optional, Union + +import torch +import torch.distributed as dist +from torch import nn +from transformers import Phi3Config +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.masking_utils import ( + create_causal_mask, + create_sliding_window_causal_mask, +) +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.models.phi3.modeling_phi3 import ( + Phi3RMSNorm, + Phi3RotaryEmbedding, + apply_rotary_pos_emb, + eager_attention_forward, +) +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple +from transformers.utils.deprecation import deprecate_kwarg +from transformers.utils.generic import check_model_inputs + +from specforge.distributed import get_tp_group +from specforge.layers import ( + ColumnParallelLinear, + ParallelLMHead, + RowParallelLinear, + VocabParallelEmbedding, +) + + +class Phi3MLP(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + + # Add TP support + self.tp_group = get_tp_group() + + self.gate_up_proj = ColumnParallelLinear( + config.hidden_size, + 2 * config.intermediate_size, + bias=False, + layout_type="gate_up", + ) + self.down_proj = RowParallelLinear( + config.intermediate_size, config.hidden_size, bias=False + ) + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + up_states = self.gate_up_proj(hidden_states) + + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + + down_proj = self.down_proj(up_states) + # Add all_reduce for TP + dist.all_reduce(down_proj, op=dist.ReduceOp.SUM, group=self.tp_group) + return down_proj + + +class Phi3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.num_key_value_heads = config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + # Add TP support + self.tp_group = get_tp_group() + tp_size = dist.get_world_size(self.tp_group) + + # Adjust head counts for TP + self.num_attention_heads_per_rank = config.num_attention_heads // tp_size + self.num_key_value_heads_per_rank = config.num_key_value_heads // tp_size + + # ColumnParallel splits the full QKV output across ranks + op_size = config.num_attention_heads * self.head_dim + 2 * ( + config.num_key_value_heads * self.head_dim + ) + self.o_proj = RowParallelLinear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=False + ) + self.qkv_proj = ColumnParallelLinear( + config.hidden_size, op_size, bias=False, layout_type="merged_qkv" + ) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + qkv = self.qkv_proj(hidden_states) + query_pos = self.num_attention_heads_per_rank * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[ + ..., + query_pos : query_pos + self.num_key_value_heads_per_rank * self.head_dim, + ] + value_states = qkv[ + ..., query_pos + self.num_key_value_heads_per_rank * self.head_dim : + ] + + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation + ] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + # Add all_reduce for TP + dist.all_reduce(attn_output, op=dist.ReduceOp.SUM, group=self.tp_group) + return attn_output, attn_weights + + +class Phi3DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Phi3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Phi3Attention(config=config, layer_idx=layer_idx) + self.mlp = Phi3MLP(config) + self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Phi3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.config = config + self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) + self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[ + torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + self.resid_attn_dropout( + hidden_states + ) # main diff with Llama + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.resid_mlp_dropout( + hidden_states + ) # main diff with Llama + return hidden_states + + +@auto_docstring +class Phi3PreTrainedModel(PreTrainedModel): + config: Phi3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Phi3DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = {} + _version = "0.0.5" + + +@auto_docstring +class Phi3Model(Phi3PreTrainedModel): + def __init__(self, config: Phi3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + Phi3DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Phi3RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + layers_to_output_hidden_states: Optional[List[int]] = kwargs.pop( + "layers_to_output_hidden_states", None + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + mask_function = ( + create_causal_mask + if self.config.sliding_window is None + else create_sliding_window_causal_mask + ) + causal_mask = mask_function( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + all_hidden_states = () + for idx, decoder_layer in enumerate(self.layers): + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + if ( + layers_to_output_hidden_states is None + or idx in layers_to_output_hidden_states + ): + all_hidden_states += (hidden_states,) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + ) + + +@auto_docstring +class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Phi3Model(config) + self.vocab_size = config.vocab_size + + # Use ColumnParallelLinear for lm_head + self.lm_head = ParallelLMHead(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, Phi3ForCausalLM + + >>> model = Phi3ForCausalLM.from_pretrained("meta-phi3/Phi3-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-phi3/Phi3-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head(hidden_states[:, slice_indices, :], gather_output=True) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + logits_to_keep=None, + **kwargs, + ): + # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the + # process + + # When the first time input length reached long and short factor switching point, enforce re-compute cache + # It will cause downside of slower at this single token position, however, better than current failure. + if ( + past_key_values + and self.config.rope_scaling + and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1 + ): + past_length = cache_position[0] + if past_length <= self.config.original_max_position_embeddings: + past_key_values = None + + model_inputs = super().prepare_inputs_for_generation( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + **kwargs, + ) + return model_inputs diff --git a/SpecForge-ext/specforge/modeling/target/custom_backend/qwen2.py b/SpecForge-ext/specforge/modeling/target/custom_backend/qwen2.py new file mode 100644 index 0000000000000000000000000000000000000000..c7ea42f95b4ca6b28bc17584b616b909703f3293 --- /dev/null +++ b/SpecForge-ext/specforge/modeling/target/custom_backend/qwen2.py @@ -0,0 +1,829 @@ +# coding=utf-8 +# Copyright 2025 The Qwen2 and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.masking_utils import ( + create_causal_mask, + create_sliding_window_causal_mask, +) +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config +from transformers.models.qwen2.modeling_qwen2 import ( + Qwen2RMSNorm, + Qwen2RotaryEmbedding, + apply_rotary_pos_emb, + eager_attention_forward, +) +from transformers.processing_utils import Unpack +from transformers.utils import ( + TransformersKwargs, + auto_docstring, + can_return_tuple, + logging, +) + +# [MODIFIED] Import from distributed library +from specforge.distributed import get_tp_group +from specforge.layers import ( + ColumnParallelLinear, + ParallelLMHead, + RowParallelLinear, + VocabParallelEmbedding, +) + +logger = logging.get_logger(__name__) + + +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + # distributed linear layers + self.tp_group = get_tp_group() + self.gate_proj = ColumnParallelLinear( + self.hidden_size, self.intermediate_size, bias=False + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, self.intermediate_size, bias=False + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, self.hidden_size, bias=False + ) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + dist.all_reduce(down_proj, op=dist.ReduceOp.SUM, group=self.tp_group) + return down_proj + + +class Qwen2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + # distributed linear layers + self.tp_group = get_tp_group() + self.q_proj = ColumnParallelLinear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=True, + ) + self.k_proj = ColumnParallelLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=True, + ) + self.v_proj = ColumnParallelLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=True, + ) + self.o_proj = RowParallelLinear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=False, + ) + + self.sliding_window = ( + config.sliding_window + if config.layer_types[layer_idx] == "sliding_attention" + else None + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation + ] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, # main diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + dist.all_reduce(attn_output, op=dist.ReduceOp.SUM, group=self.tp_group) + return attn_output, attn_weights + + +class Qwen2DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.attention_type = config.layer_types[layer_idx] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[ + torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +@auto_docstring +class Qwen2PreTrainedModel(PreTrainedModel): + config_class = Qwen2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Qwen2RMSNorm): + module.weight.data.fill_(1.0) + + +@auto_docstring +class Qwen2Model(Qwen2PreTrainedModel): + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + Qwen2DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2RotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.has_sliding_layers = "sliding_attention" in self.config.layer_types + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + layers_to_output_hidden_states = flash_attn_kwargs.pop( + "layers_to_output_hidden_states", None + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError( + "The `past_key_values` should be either a `Cache` object or `None`." + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + # The sliding window alternating layers are not always activated depending on the config + if self.has_sliding_layers: + causal_mask_mapping["sliding_attention"] = ( + create_sliding_window_causal_mask(**mask_kwargs) + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_hidden_states: + if ( + layers_to_output_hidden_states is None + or idx in layers_to_output_hidden_states + ): + all_hidden_states += (hidden_states,) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +@auto_docstring +class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Qwen2Model(config) + self.vocab_size = config.vocab_size + + # distributed the lm head + self.lm_head = ParallelLMHead(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen2ForCausalLM + + >>> model = Qwen2ForCausalLM.from_pretrained("meta-qwen2/Qwen2-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-qwen2/Qwen2-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + layers_to_output_hidden_states = kwargs.pop( + "layers_to_output_hidden_states", None + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + layers_to_output_hidden_states=layers_to_output_hidden_states, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head(hidden_states[:, slice_indices, :], gather_output=True) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The Qwen2 Model transformer with a sequence classification head on top (linear layer). + + [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """ +) +class Qwen2ForSequenceClassification(Qwen2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> SequenceClassifierOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + transformer_outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + hidden_states = transformer_outputs.last_hidden_state + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError( + "Cannot handle batch sizes > 1 if no padding token is defined." + ) + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to( + logits.device, torch.int32 + ) + token_indices = torch.arange( + input_ids.shape[-1], device=logits.device, dtype=torch.int32 + ) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[ + torch.arange(batch_size, device=logits.device), last_non_pad_token + ] + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + pooled_logits=pooled_logits, + config=self.config, + ) + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@auto_docstring +class Qwen2ForTokenClassification(Qwen2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen2Model(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> TokenClassifierOutput: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + sequence_output = outputs.last_hidden_state + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel): + base_model_prefix = "transformer" + + def __init__(self, config): + super().__init__(config) + self.transformer = Qwen2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + **kwargs, + ) -> QuestionAnsweringModelOutput: + outputs: BaseModelOutputWithPast = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + sequence_output = outputs.last_hidden_state + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function( + start_logits, end_logits, start_positions, end_positions, **kwargs + ) + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "Qwen2PreTrainedModel", + "Qwen2Model", + "Qwen2ForCausalLM", + "Qwen2ForSequenceClassification", + "Qwen2ForTokenClassification", + "Qwen2ForQuestionAnswering", +] diff --git a/SpecForge-ext/specforge/modeling/target/custom_backend/qwen3.py b/SpecForge-ext/specforge/modeling/target/custom_backend/qwen3.py new file mode 100644 index 0000000000000000000000000000000000000000..1b0df91f03a3fd74be205cc685ad864f73fd35e8 --- /dev/null +++ b/SpecForge-ext/specforge/modeling/target/custom_backend/qwen3.py @@ -0,0 +1,606 @@ +# coding=utf-8 +# Copyright 2025 Qwen Team and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from transformers import Qwen3Config +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.masking_utils import ( + create_causal_mask, + create_sliding_window_causal_mask, +) +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.models.qwen3.modeling_qwen3 import ( + Qwen3RMSNorm, + apply_rotary_pos_emb, + eager_attention_forward, +) +from transformers.processing_utils import Unpack +from transformers.utils import auto_docstring, can_return_tuple, logging + +from specforge.distributed import get_tp_group +from specforge.layers import ( + ColumnParallelLinear, + ParallelLMHead, + RowParallelLinear, + VocabParallelEmbedding, +) + +logger = logging.get_logger(__name__) + + +class Qwen3MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + # Add TP support + self.tp_group = get_tp_group() + + self.gate_proj = ColumnParallelLinear( + self.hidden_size, self.intermediate_size, bias=False + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, self.intermediate_size, bias=False + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, self.hidden_size, bias=False + ) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # Add all_reduce for TP + dist.all_reduce(down_proj, op=dist.ReduceOp.SUM, group=self.tp_group) + return down_proj + + +class Qwen3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen3Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.total_num_kv_heads = config.num_key_value_heads + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + # Add TP support + self.tp_group = get_tp_group() + + self.q_proj = ColumnParallelLinear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=config.attention_bias, + ) + self.k_proj = ColumnParallelLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = ColumnParallelLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = RowParallelLinear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + self.q_norm = Qwen3RMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # unlike olmo, only on the head dim! + self.k_norm = Qwen3RMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # thus post q_norm does not need reshape + # Sliding window logic is kept as is, assuming it's handled in config.layer_types + self.sliding_window = ( + config.sliding_window + if config.layer_types[layer_idx] == "sliding_attention" + else None + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm( + self.q_proj(hidden_states).view(hidden_shape) + ).transpose(1, 2) + key_states = self.k_norm( + self.k_proj(hidden_states).view(hidden_shape) + ).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation + ] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + # Add all_reduce for TP + dist.all_reduce(attn_output, op=dist.ReduceOp.SUM, group=self.tp_group) + return attn_output, attn_weights + + +class Qwen3DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx) + + self.mlp = Qwen3MLP(config) + self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.attention_type = config.layer_types[layer_idx] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + return outputs + + +class Qwen3RotaryEmbedding(nn.Module): + def __init__(self, config: Qwen3Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get( + "rope_type", config.rope_scaling.get("type") + ) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = ( + self.inv_freq[None, :, None] + .float() + .expand(position_ids.shape[0], -1, 1) + .to(x.device) + ) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = ( + x.device.type + if isinstance(x.device.type, str) and x.device.type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = ( + inv_freq_expanded.float() @ position_ids_expanded.float() + ).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@auto_docstring +class Qwen3PreTrainedModel(PreTrainedModel): + config_class = Qwen3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen3DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Qwen3RMSNorm): + module.weight.data.fill_(1.0) + + +@auto_docstring +class Qwen3Model(Qwen3PreTrainedModel): + def __init__(self, config: Qwen3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + Qwen3DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3RotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.has_sliding_layers = "sliding_attention" in self.config.layer_types + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + layers_to_output_hidden_states = flash_attn_kwargs.pop( + "layers_to_output_hidden_states", None + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + # The sliding window alternating layers are not always activated depending on the config + if self.has_sliding_layers: + causal_mask_mapping["sliding_attention"] = ( + create_sliding_window_causal_mask(**mask_kwargs) + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_hidden_states: + if ( + layers_to_output_hidden_states is None + or idx in layers_to_output_hidden_states + ): + all_hidden_states += (hidden_states,) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +@auto_docstring +class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Qwen3Model(config) + self.vocab_size = config.vocab_size + + # Use ColumnParallelLinear for lm_head + self.lm_head = ParallelLMHead(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen3ForCausalLM + + >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head(hidden_states[:, slice_indices, :], gather_output=True) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/SpecForge-ext/specforge/modeling/target/custom_backend/qwen3_moe.py b/SpecForge-ext/specforge/modeling/target/custom_backend/qwen3_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..61f1880f6d3f92ab112388bb7ec991e8535f8600 --- /dev/null +++ b/SpecForge-ext/specforge/modeling/target/custom_backend/qwen3_moe.py @@ -0,0 +1,889 @@ +# coding=utf-8 +# Copyright 2025 Qwen Team and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from transformers import Qwen3MoeConfig +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.integrations import use_kernel_forward_from_hub +from transformers.masking_utils import ( + create_causal_mask, + create_sliding_window_causal_mask, +) +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.models.qwen3_moe.modeling_qwen3_moe import ( + apply_rotary_pos_emb, + eager_attention_forward, +) +from transformers.processing_utils import Unpack +from transformers.utils import auto_docstring, can_return_tuple, logging + +from specforge.distributed import get_tp_group +from specforge.layers import ( + ColumnParallelLinear, + ParallelLMHead, + RowParallelLinear, + VocabParallelEmbedding, +) + +logger = logging.get_logger(__name__) + + +class Qwen3MoeAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen3MoeConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + # Add TP support and head calculations + self.tp_group = get_tp_group() + self.tp_size = ( + dist.get_world_size(self.tp_group) if self.tp_group is not None else 1 + ) + self.tp_rank = dist.get_rank(self.tp_group) if self.tp_group is not None else 0 + + # Calculate head distribution for TP + self.total_num_heads = config.num_attention_heads + self.total_num_kv_heads = config.num_key_value_heads + self.num_heads = ( + self.total_num_heads // self.tp_size + ) # this is the number heads per rank + + # Handle KV head replication when tp_size > total_num_kv_heads + if self.tp_size > self.total_num_kv_heads: + # In replication mode, each rank gets 1 KV head (replicated across groups) + self.num_kv_heads = 1 + self.num_kv_head_replicas = self.tp_size // self.total_num_kv_heads + self.num_key_value_groups = ( + self.num_heads // self.num_kv_heads + ) # this is size for expanding kv for gqa + self.kv_head_replicas = True + else: + self.num_kv_heads = self.total_num_kv_heads + self.num_kv_head_replicas = 1 + self.num_key_value_groups = config.num_attention_heads // self.num_kv_heads + self.kv_head_replicas = False + + self.q_proj = ColumnParallelLinear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=config.attention_bias, + ) + self.k_proj = ColumnParallelLinear( + config.hidden_size, + self.num_kv_heads * self.head_dim, + bias=config.attention_bias, + kv_head_replicas=self.kv_head_replicas, + kv_head_idx=self.tp_rank // self.num_kv_head_replicas, + total_num_kv_heads=config.num_key_value_heads, + ) + self.v_proj = ColumnParallelLinear( + config.hidden_size, + self.num_kv_heads * self.head_dim, + bias=config.attention_bias, + kv_head_replicas=self.kv_head_replicas, + kv_head_idx=self.tp_rank // self.num_kv_head_replicas, + total_num_kv_heads=config.num_key_value_heads, + ) + self.o_proj = RowParallelLinear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + + self.q_norm = Qwen3MoeRMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # unlike olmo, only on the head dim! + self.k_norm = Qwen3MoeRMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # thus post q_norm does not need reshape + self.sliding_window = getattr(config, "sliding_window", None) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm( + self.q_proj(hidden_states).view(hidden_shape) + ).transpose(1, 2) + key_states = self.k_norm( + self.k_proj(hidden_states).view(hidden_shape) + ).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation + ] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + # Add all_reduce for TP + dist.all_reduce(attn_output, op=dist.ReduceOp.SUM, group=self.tp_group) + return attn_output, attn_weights + + +class Qwen3MoeMLP(nn.Module): + def __init__(self, config, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = ( + intermediate_size + if intermediate_size is not None + else config.intermediate_size + ) + + # Add TP support + self.tp_group = get_tp_group() + self.gate_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + bias=False, + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, self.intermediate_size, bias=False + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, self.hidden_size, bias=False + ) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # Add all_reduce for TP + dist.all_reduce(down_proj, op=dist.ReduceOp.SUM, group=self.tp_group) + return down_proj + + +class Qwen3MoeSparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + + # gating + self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.experts = nn.ModuleList( + [ + Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size) + for _ in range(self.num_experts) + ] + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) + if self.norm_topk_prob: # only diff with mixtral sparse moe block! + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot( + selected_experts, num_classes=self.num_experts + ).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hitted: + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = ( + expert_layer(current_state) * routing_weights[top_x, idx, None] + ) + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype) + ) + final_hidden_states = final_hidden_states.reshape( + batch_size, sequence_length, hidden_dim + ) + return final_hidden_states, router_logits + + +@use_kernel_forward_from_hub("RMSNorm") +class Qwen3MoeRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen3MoeRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen3MoeDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen3MoeConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Qwen3MoeAttention(config, layer_idx) + + if (layer_idx not in config.mlp_only_layers) and ( + config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0 + ): + self.mlp = Qwen3MoeSparseMoeBlock(config) + else: + self.mlp = Qwen3MoeMLP(config, intermediate_size=config.intermediate_size) + + self.input_layernorm = Qwen3MoeRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = Qwen3MoeRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[ + torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, + and should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states = self.mlp(hidden_states) + if isinstance(hidden_states, tuple): + hidden_states, router_logits = hidden_states + else: + router_logits = None + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +class Qwen3MoeRotaryEmbedding(nn.Module): + def __init__(self, config: Qwen3MoeConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get( + "rope_type", config.rope_scaling.get("type") + ) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = ( + self.inv_freq[None, :, None] + .float() + .expand(position_ids.shape[0], -1, 1) + .to(x.device) + ) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = ( + x.device.type + if isinstance(x.device.type, str) and x.device.type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = ( + inv_freq_expanded.float() @ position_ids_expanded.float() + ).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@auto_docstring +class Qwen3MoePreTrainedModel(PreTrainedModel): + config_class = Qwen3MoeConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen3MoeDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Qwen3MoeRMSNorm): + module.weight.data.fill_(1.0) + + +@auto_docstring +class Qwen3MoeModel(Qwen3MoePreTrainedModel): + def __init__(self, config: Qwen3MoeConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + Qwen3MoeDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> MoeModelOutputWithPast: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + layers_to_output_hidden_states = flash_attn_kwargs.pop( + "layers_to_output_hidden_states", None + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + mask_function = ( + create_causal_mask + if self.config.sliding_window is None + else create_sliding_window_causal_mask + ) + causal_mask = mask_function( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + + for idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_hidden_states: + if ( + layers_to_output_hidden_states is None + or idx in layers_to_output_hidden_states + ): + all_hidden_states += (hidden_states,) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. + """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat( + [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0 + ) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // ( + batch_size * sequence_length + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand( + (num_hidden_layers, batch_size, sequence_length, top_k, num_experts) + ) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum( + expert_mask.float() * expert_attention_mask, dim=0 + ) / torch.sum(expert_attention_mask, dim=0) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum( + routing_weights * router_per_expert_attention_mask, dim=0 + ) / torch.sum(router_per_expert_attention_mask, dim=0) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +@auto_docstring +class Qwen3MoeForCausalLM(Qwen3MoePreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Qwen3MoeModel(config) + self.vocab_size = config.vocab_size + + # Use ColumnParallelLinear for lm_head + self.lm_head = ParallelLMHead(config.hidden_size, config.vocab_size, bias=False) + + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_experts + self.num_experts_per_tok = config.num_experts_per_tok + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM + + >>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head(hidden_states[:, slice_indices, :], gather_output=True) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None and aux_loss != 0: + loss += self.router_aux_loss_coef * aux_loss.to( + loss.device + ) # make sure to reside in the same device + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) diff --git a/SpecForge-ext/specforge/modeling/target/dflash_target_model.py b/SpecForge-ext/specforge/modeling/target/dflash_target_model.py new file mode 100644 index 0000000000000000000000000000000000000000..25ad5d79560a3a5a4616f866592ae94d2153fea4 --- /dev/null +++ b/SpecForge-ext/specforge/modeling/target/dflash_target_model.py @@ -0,0 +1,342 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import List, Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.managers.schedule_batch import Req, ScheduleBatch +from sglang.srt.managers.scheduler import Scheduler +from sglang.srt.mem_cache.cache_init_params import CacheInitParams +from sglang.srt.mem_cache.radix_cache import RadixCache +from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch +from sglang.srt.sampling.sampling_params import SamplingParams +from sglang.srt.server_args import ServerArgs +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.utils import require_mlp_sync, require_mlp_tp_gather +from transformers import AutoModelForCausalLM + +from specforge.distributed import get_tp_group +from specforge.utils import padding + +from .sglang_backend import SGLangRunner + + +@dataclass +class DFlashTargetOutput: + hidden_states: torch.Tensor # [batch, seq_len, hidden_size] + input_ids: torch.Tensor # [batch, seq_len] + attention_mask: torch.Tensor # [batch, seq_len] + loss_mask: torch.Tensor # [batch, seq_len] + + +class DFlashTargetModel(ABC): + """ + Abstract base class for DFlash target model backend. + """ + + def __init__(self): + self.capture_layer_ids = None + + @classmethod + @abstractmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + torch_dtype: torch.dtype = None, + device: str = None, + cache_dir: Optional[str] = None, + **kwargs, + ) -> "DFlashTargetModel": + """Initialize the target model backend.""" + + @abstractmethod + def generate_dflash_data( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + ) -> DFlashTargetOutput: + """Generate context hidden states for DFlash training.""" + + def set_capture_layers(self, layer_ids: List[int]) -> None: + """Set which layers' hidden states to capture.""" + self.capture_layer_ids = layer_ids + + +class SGLangDFlashTargetModel(DFlashTargetModel): + def __init__(self, model_runner: SGLangRunner): + super().__init__() + self.model_runner = model_runner + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + torch_dtype: torch.dtype = None, + device: str = None, + cache_dir: Optional[str] = None, + trust_remote_code: bool = False, + **kwargs, + ) -> "SGLangDFlashTargetModel": + tp_size = dist.get_world_size(get_tp_group()) + server_args = ServerArgs( + model_path=pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + dtype=torch_dtype, + enable_return_hidden_states=True, # Critical for DFlash + disable_cuda_graph=True, + tp_size=tp_size, + pp_size=1, + **kwargs, + ) + + tp_rank = dist.get_rank(get_tp_group()) + moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size) + model_config = ModelConfig.from_server_args(server_args) + + model_runner = SGLangRunner( + model_config=model_config, + mem_fraction_static=server_args.mem_fraction_static, + gpu_id=torch.cuda.current_device(), + tp_rank=dist.get_rank(get_tp_group()), + tp_size=server_args.tp_size, + moe_ep_rank=moe_ep_rank, + moe_ep_size=server_args.ep_size, + pp_rank=0, + pp_size=1, + server_args=server_args, + nccl_port=None, + ) + return cls(model_runner) + + def set_capture_layers(self, layer_ids: List[int]) -> None: + super().set_capture_layers(layer_ids) + # Note: We need to ensure SGLang supports custom capture layers. + # Eagle3 implementation uses `set_eagle3_layers_to_capture`. + # For DFlash, we might need to rely on `output_hidden_states=True` returning all layers + # and then filtering, OR implementing `set_custom_layers_to_capture` in SGLang patch. + # Assuming we can use the same mechanism or general mechanism if available. + # If SGLang doesn't support selective capture easily, we might get all and select later. + # But for memory efficiency, selective capture is better. + + # Checking Eagle3 implementation again: it calls `model.set_eagle3_layers_to_capture`. + # This implies SGLang model wrapper has this method patched. + # We will try to use a similar approach or assume we get full hidden states. + + # For now, let's assume we capture what's needed. + if hasattr(self.model_runner.model, "set_eagle3_layers_to_capture"): + self.model_runner.model.set_eagle3_layers_to_capture(layer_ids) + + @torch.no_grad + def _extend(self, reqs): + # Similar to Eagle3 _extend but simplified for just hidden states + cache_params = CacheInitParams( + disable=False, + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator, + page_size=self.model_runner.server_args.page_size, + ) + tree_cache = RadixCache(cache_params) + + batch = ScheduleBatch.init_new( + reqs=reqs, + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator, + tree_cache=tree_cache, + model_config=self.model_runner.model_config, + enable_overlap=False, + spec_algorithm=SpeculativeAlgorithm.NONE, + ) + batch.prepare_for_extend() + + if require_mlp_sync(self.model_runner.server_args): + Scheduler.prepare_mlp_sync_batch_raw( + batch, + dp_size=self.model_runner.server_args.dp_size, + attn_tp_size=1, + tp_group=self.model_runner.tp_group, + get_idle_batch=None, + disable_cuda_graph=self.model_runner.server_args.disable_cuda_graph, + spec_algorithm=SpeculativeAlgorithm.NONE, + speculative_num_draft_tokens=None, + require_mlp_tp_gather=require_mlp_tp_gather( + self.model_runner.server_args + ), + disable_overlap_schedule=self.model_runner.server_args.disable_overlap_schedule, + offload_tags=set(), + ) + + model_worker_batch = batch.get_model_worker_batch() + forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) + forward_batch.capture_hidden_mode = CaptureHiddenMode.FULL + + output, _ = self.model_runner.forward(forward_batch) + + # Eagle3 output has aux_hidden_states. + # We need to check what SGLang returns. Typically it returns 'hidden_states' or 'aux_hidden_states'. + # Assuming it aligns with Eagle3 patch. + + input_lens = [len(req.origin_input_ids) for req in reqs] + + # Split per request + if ( + hasattr(output, "aux_hidden_states") + and output.aux_hidden_states is not None + ): + hidden_states_list = torch.split( + output.aux_hidden_states, input_lens, dim=0 + ) + elif hasattr(output, "hidden_states") and output.hidden_states is not None: + hidden_states_list = torch.split(output.hidden_states, input_lens, dim=0) + else: + raise ValueError("SGLang output does not contain hidden states.") + + self.model_runner.req_to_token_pool.clear() + self.model_runner.token_to_kv_pool_allocator.clear() + + return hidden_states_list + + @torch.no_grad() + def generate_dflash_data( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + ) -> DFlashTargetOutput: + sampling_params = SamplingParams(temperature=0, max_new_tokens=1) + reqs, data_cache = [], [] + + if isinstance(input_ids, torch.Tensor): + input_ids_list = torch.split(input_ids, 1, dim=0) + attn_mask_list = torch.split(attention_mask, 1, dim=0) + loss_mask_list = torch.split(loss_mask, 1, dim=0) + + for idx, (curr_ids, curr_attn, curr_loss) in enumerate( + zip(input_ids_list, attn_mask_list, loss_mask_list) + ): + req = Req( + rid=str(idx), + origin_input_text="", + origin_input_ids=curr_ids.view(-1).tolist(), + sampling_params=sampling_params, + ) + req.fill_ids = req.origin_input_ids + req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) + data_cache.append((curr_ids, curr_attn, curr_loss)) + reqs.append(req) + + hidden_states_list = self._extend(reqs) + + # Stack back to batch + hidden_states = torch.cat([h.unsqueeze(0) for h in hidden_states_list], dim=0) + input_ids = torch.cat([d[0] for d in data_cache], dim=0) + attention_mask = torch.cat([d[1] for d in data_cache], dim=0) + loss_mask = torch.cat([d[2] for d in data_cache], dim=0) + + # Padding might be needed if batching varied lengths (but usually fixed length training) + hidden_states = padding(hidden_states, left=False) + input_ids = padding(input_ids, left=False) + + return DFlashTargetOutput( + hidden_states=hidden_states, + input_ids=input_ids, + attention_mask=attention_mask, + loss_mask=loss_mask, + ) + + +class HFDFlashTargetModel(DFlashTargetModel): + def __init__(self, model: nn.Module): + super().__init__() + self.model = model + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + torch_dtype: torch.dtype = None, + device: str = None, + cache_dir: Optional[str] = None, + **kwargs, + ) -> "HFDFlashTargetModel": + + target_model = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path, + torch_dtype=torch_dtype, + cache_dir=cache_dir, + output_hidden_states=True, + trust_remote_code=True, + **kwargs, + ).eval() + + if device: + target_model = target_model.to(device) + + return cls(target_model) + + @torch.no_grad() + def generate_dflash_data( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + ) -> DFlashTargetOutput: + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Extract selected layers + # outputs.hidden_states is a tuple of (L+1) tensors + # Indices in self.capture_layer_ids correspond to 0-based index of transformer layers. + # outputs.hidden_states[0] is embedding output (usually). + # Typically hidden_states[i+1] is output of layer i. + + offset = 1 + selected = [] + if self.capture_layer_ids is not None: + for idx in self.capture_layer_ids: + selected.append(outputs.hidden_states[idx + offset]) + hidden_states = torch.cat(selected, dim=-1) + else: + # Fallback if no layers specified (maybe return last?) + hidden_states = outputs.hidden_states[-1] + + return DFlashTargetOutput( + hidden_states=hidden_states, + input_ids=input_ids, + attention_mask=attention_mask, + loss_mask=loss_mask, + ) + + +def get_dflash_target_model( + pretrained_model_name_or_path: str, + backend: str = "sglang", + torch_dtype: torch.dtype = None, + device: str = None, + cache_dir: Optional[str] = None, + **kwargs, +) -> DFlashTargetModel: + if backend == "sglang": + return SGLangDFlashTargetModel.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, + torch_dtype=torch_dtype, + device=device, + cache_dir=cache_dir, + **kwargs, + ) + elif backend == "hf": + return HFDFlashTargetModel.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, + torch_dtype=torch_dtype, + device=device, + cache_dir=cache_dir, + **kwargs, + ) + else: + raise ValueError(f"Invalid backend: {backend}") diff --git a/SpecForge-ext/specforge/modeling/target/eagle3_target_model.py b/SpecForge-ext/specforge/modeling/target/eagle3_target_model.py new file mode 100644 index 0000000000000000000000000000000000000000..23554414868bb0e36cd6c1e6728bdf671813b03c --- /dev/null +++ b/SpecForge-ext/specforge/modeling/target/eagle3_target_model.py @@ -0,0 +1,857 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import sglang.srt.managers.mm_utils as mm_utils +import torch +import torch.distributed as dist +import torch.nn as nn +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.layers.rotary_embedding import MRotaryEmbedding +from sglang.srt.managers.mm_utils import ( + MultiModalityDataPaddingPatternMultimodalTokens, + init_mm_embedding_cache, +) +from sglang.srt.managers.schedule_batch import ( + Modality, + MultimodalDataItem, + MultimodalInputs, + Req, + ScheduleBatch, +) +from sglang.srt.managers.scheduler import Scheduler +from sglang.srt.mem_cache.cache_init_params import CacheInitParams +from sglang.srt.mem_cache.radix_cache import RadixCache +from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch +from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor +from sglang.srt.sampling.sampling_params import SamplingParams +from sglang.srt.server_args import ServerArgs +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.utils import require_mlp_sync, require_mlp_tp_gather +from transformers import AutoModelForCausalLM + +from specforge.distributed import get_tp_device_mesh, get_tp_group +from specforge.utils import padding + +from .sglang_backend import SGLangRunner, wrap_eagle3_logits_processors_in_module +from .sglang_backend.utils import LogitsProcessorForEAGLE3 + + +@dataclass +class Eagle3TargetOutput: + hidden_states: torch.Tensor + target: torch.Tensor + loss_mask: torch.Tensor + input_ids: torch.Tensor + attention_mask: torch.Tensor + last_hidden_states: Optional[torch.Tensor] = None + + +class Eagle3TargetModel(ABC): + """ + This offers a layer of abstraction for the target model backend. The user can choose different backends to suit their needs: + 1. SGLang backend: for the mainstream model support with the fastest inference speed + 2. HuggingFace backend: for models that are not supported by SGLang but can be loaded by HuggingFace. + 3. Custom backend: for models with customized architecture and inference plan. + """ + + def __init__(self): + self.aux_hidden_states_layers = None + + @classmethod + @abstractmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + torch_dtype: torch.dtype = None, + device: str = None, + cache_dir: Optional[str] = None, + **kwargs, + ) -> "Eagle3TargetModel": + """ + Initialize the target model backend from a pretrained model path. + """ + + @abstractmethod + def generate_eagle3_data( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + ) -> Eagle3TargetOutput: + """ + Generate the eagle3 data from the target model. + """ + + def set_aux_hidden_states_layers( + self, aux_hidden_states_layers: Optional[List[int]] = None + ) -> None: + """ + Set the layers to capture the aux hidden states from the target model outputs. + """ + if aux_hidden_states_layers is None: + if hasattr(self.model.config, "num_hidden_layers"): + num_layers = self.model.config.num_hidden_layers + else: + raise ValueError( + f"Failed to set aux hidden states layers as model config {self.model.config} does not have num_hidden_layers" + ) + aux_hidden_states_layers = [ + 1, + num_layers // 2 - 1, + num_layers - 4, + ] + self.aux_hidden_states_layers = aux_hidden_states_layers + assert ( + len(self.aux_hidden_states_layers) == 3 + ), "aux_hidden_states_layers is expected to be 3 layers for EAGLE3" + + +class HFEagle3TargetModel(Eagle3TargetModel): + + def __init__(self, model: nn.Module): + super().__init__() + self.model = model + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + torch_dtype: torch.dtype = None, + device: str = None, + cache_dir: Optional[str] = None, + **kwargs, + ) -> "HFEagle3TargetModel": + """ + Initialize the HuggingFace target model backend from a pretrained model path. + """ + tp_size = get_tp_group().size() + + if tp_size > 1: + device_kwargs = { + "tp_plan": "auto", + "tp_size": tp_size, + "device_mesh": get_tp_device_mesh(), + } + else: + device_kwargs = { + "device_map": device, + } + + target_model = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path, + torch_dtype=torch_dtype, + cache_dir=cache_dir, + **device_kwargs, + **kwargs, + ) + return cls(target_model) + + def _get_transformer_layers(self): + """ + Helper to find the module list containing the transformer layers. + Adapts to common architectures (Llama, Qwen, Mistral, OPT, etc.) + """ + if hasattr(self.model, "model") and hasattr(self.model.model, "layers"): + return self.model.model.layers + elif hasattr(self.model, "layers"): + return self.model.layers + elif hasattr(self.model, "transformer") and hasattr( + self.model.transformer, "h" + ): + return self.model.transformer.h + else: + raise ValueError( + "Could not locate transformer layers in the model architecture to register hooks." + ) + + @torch.no_grad() + def generate_eagle3_data( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + ) -> Eagle3TargetOutput: + """ + Optimized HF backend: + Instead of returning all hidden states (memory heavy), we use forward hooks + to capture only the specific layers required by Eagle3. + """ + captured_states = {} + handles = [] + + def get_hook(layer_idx): + def hook(module, input, output): + # HF outputs for layers are usually tuples (hidden_states, present_key_value, ...) + # We only need the hidden_states (first element) + if isinstance(output, tuple): + hidden = output[0] + else: + hidden = output + captured_states[layer_idx] = hidden + + return hook + + # Locate the transformer layers ModuleList + layers = self._get_transformer_layers() + + target_indices = self.aux_hidden_states_layers + + # Register hooks + for idx in target_indices: + # Ensure index is within bounds + if 0 <= idx < len(layers): + handles.append(layers[idx].register_forward_hook(get_hook(idx))) + else: + raise ValueError( + f"Layer index {idx} out of bounds for model with {len(layers)} layers." + ) + + try: + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=False, + output_attentions=False, + output_router_logits=False, + use_cache=False, + ) + target = outputs.logits + finally: + # Always remove hooks to prevent memory leaks or side effects on subsequent calls + for handle in handles: + handle.remove() + + # Verify we captured everything + if len(captured_states) != 3: + raise RuntimeError( + f"Expected to capture 3 layers, but captured {len(captured_states)}" + ) + + # Extract in the correct order + hidden_states0 = captured_states[target_indices[0]] + hidden_states1 = captured_states[target_indices[1]] + hidden_states2 = captured_states[target_indices[2]] + + hidden_states = torch.cat( + (hidden_states0, hidden_states1, hidden_states2), dim=-1 + ) + + # apply pading + target = outputs.logits + target = padding(target, left=False) + input_ids = padding(input_ids, left=False) + loss_mask = loss_mask[..., None].to(target.device) + + return Eagle3TargetOutput( + hidden_states=hidden_states, + target=target, + loss_mask=loss_mask, + input_ids=input_ids, + attention_mask=attention_mask, + ) + + +class SGLangEagle3TargetModel(Eagle3TargetModel): + + def __init__(self, model_runner: SGLangRunner, hf_config=None): + super().__init__() + self.model_runner = model_runner + self.hf_config = hf_config + + # VLM-specific attributes (initialized from hf_config if available) + self._init_vlm_attributes() + + def _init_vlm_attributes(self): + """Initialize VLM-specific attributes from hf_config for models like Qwen2.5-VL""" + if self.hf_config is None: + self.is_vlm = False + return + + # Check if this is a VLM model by looking for vision_config + self.is_vlm = hasattr(self.hf_config, "vision_config") + + if not self.is_vlm: + return + + init_mm_embedding_cache(1024 * 1024 * 512) + # Model type (e.g., "qwen2_5_vl", "qwen2_vl") + self.model_type = getattr(self.hf_config, "model_type", None) + + # Vision config attributes + vision_config = self.hf_config.vision_config + self.spatial_merge_size = getattr(vision_config, "spatial_merge_size", 2) + self.tokens_per_second = getattr(vision_config, "tokens_per_second", None) + + # Special token IDs from hf_config + self.image_token_id = getattr(self.hf_config, "image_token_id", None) + self.video_token_id = getattr(self.hf_config, "video_token_id", None) + self.vision_start_token_id = getattr( + self.hf_config, "vision_start_token_id", None + ) + self.vision_end_token_id = getattr(self.hf_config, "vision_end_token_id", None) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + torch_dtype: torch.dtype = None, + device: str = None, + cache_dir: Optional[str] = None, + trust_remote_code: bool = False, + **kwargs, + ) -> "SGLangEagle3TargetModel": + tp_size = dist.get_world_size(get_tp_group()) + server_args = ServerArgs( + model_path=pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + dtype=torch_dtype, + enable_return_hidden_states=True, + disable_cuda_graph=True, # we use piecewise cuda graph for prefill instead + tp_size=tp_size, + pp_size=1, + **kwargs, + ) + + tp_rank = dist.get_rank(get_tp_group()) + moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size) + model_config = ModelConfig.from_server_args(server_args) + model_runner = SGLangRunner( + model_config=model_config, + mem_fraction_static=server_args.mem_fraction_static, + gpu_id=torch.cuda.current_device(), + tp_rank=dist.get_rank(get_tp_group()), + tp_size=server_args.tp_size, + moe_ep_rank=moe_ep_rank, + moe_ep_size=server_args.ep_size, + pp_rank=0, + pp_size=1, + server_args=server_args, + nccl_port=None, + ) + wrap_eagle3_logits_processors_in_module( + model_runner.model, return_full_logits=False + ) + + # Get hf_config from model_config for VLM attributes + hf_config = getattr(model_config, "hf_config", None) + + return cls(model_runner, hf_config=hf_config) + + def set_aux_hidden_states_layers( + self, aux_hidden_states_layers: Optional[List[int]] = None + ) -> None: + self.model_runner.model.set_eagle3_layers_to_capture(aux_hidden_states_layers) + + @torch.no_grad + def _extend( + self, + reqs, + capture_aux_hidden_states: bool = True, + return_last_hidden_states: bool = False, + return_logits: bool = False, + ): + # set the logits processor for the model runner + for name, module in self.model_runner.model.named_modules(): + if isinstance(module, LogitsProcessorForEAGLE3): + module.return_last_hidden_states = return_last_hidden_states + module.return_logits = return_logits + + cache_params = CacheInitParams( + disable=False, + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator, + page_size=self.model_runner.server_args.page_size, + ) + tree_cache = RadixCache(cache_params) + + batch = ScheduleBatch.init_new( + reqs=reqs, + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator, + tree_cache=tree_cache, + model_config=self.model_runner.model_config, + enable_overlap=False, + spec_algorithm=SpeculativeAlgorithm.NONE, + ) + batch.prepare_for_extend() + self._maybe_prepare_mlp_sync_batch(batch) + model_worker_batch = batch.get_model_worker_batch() + forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) + forward_batch.capture_hidden_mode = CaptureHiddenMode.FULL + eagle3_output, _ = self.model_runner.forward(forward_batch) + + aux_hidden_states_list = None + input_lens = [len(req.origin_input_ids) for req in reqs] + + if return_logits: + logits = torch.split(eagle3_output.logits, input_lens, dim=0) + else: + logits = [None] * len(reqs) + + if capture_aux_hidden_states: + aux_hidden_states_list = torch.split( + eagle3_output.aux_hidden_states, input_lens, dim=0 + ) + else: + aux_hidden_states_list = [None] * len(reqs) + + if return_last_hidden_states: + last_hidden_states = torch.split( + eagle3_output.last_hidden_states, input_lens, dim=0 + ) + else: + last_hidden_states = [None] * len(reqs) + + # TODO: can we not clear? + self.model_runner.req_to_token_pool.clear() + self.model_runner.token_to_kv_pool_allocator.clear() + return logits, aux_hidden_states_list, last_hidden_states + + def _maybe_prepare_mlp_sync_batch(self, batch: ScheduleBatch): + if require_mlp_sync(self.model_runner.server_args): + Scheduler.prepare_mlp_sync_batch_raw( + batch, + dp_size=self.model_runner.server_args.dp_size, + attn_tp_size=1, + tp_group=self.model_runner.tp_group, + get_idle_batch=None, + disable_cuda_graph=self.model_runner.server_args.disable_cuda_graph, + spec_algorithm=SpeculativeAlgorithm.NONE, + speculative_num_draft_tokens=None, + require_mlp_tp_gather=require_mlp_tp_gather( + self.model_runner.server_args + ), + disable_overlap_schedule=self.model_runner.server_args.disable_overlap_schedule, + offload_tags=set(), + ) + + def extend( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + return_last_hidden_states: bool = False, + return_logits: bool = True, + ): + sampling_params = SamplingParams(temperature=0, max_new_tokens=1, top_k=1) + reqs, data_cache = [], [] + + if isinstance(input_ids, torch.Tensor): + input_ids = torch.split(input_ids, 1, dim=0) + attention_mask = torch.split(attention_mask, 1, dim=0) + loss_mask = torch.split(loss_mask, 1, dim=0) + + for idx, (input_id_, attention_mask_, loss_mask_) in enumerate( + zip( + input_ids, + attention_mask, + loss_mask, + ) + ): + req = Req( + rid=str(idx), + origin_input_text="", + origin_input_ids=input_id_.view(-1).tolist(), + sampling_params=sampling_params, + ) + req.fill_ids = req.origin_input_ids + req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) + req.logprob_start_len = len(req.origin_input_ids) - 1 + data_cache.append([input_id_, attention_mask_, loss_mask_]) + reqs.append(req) + + logits_list, aux_hidden_states_list, last_hidden_states_list = self._extend( + reqs, + capture_aux_hidden_states=True, + return_last_hidden_states=return_last_hidden_states, + return_logits=return_logits, + ) + + return data_cache, logits_list, aux_hidden_states_list, last_hidden_states_list + + def get_rope_index( + self, + input_ids: torch.Tensor, + image_grid_thw: Optional[torch.Tensor] = None, + video_grid_thw: Optional[torch.Tensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Get M-RoPE position indices for VLM models like Qwen2.5-VL. + + This is a wrapper around MRotaryEmbedding.get_rope_index that uses + the VLM-specific attributes initialized from hf_config. + + Args: + input_ids: (batch_size, seq_len) input token IDs + image_grid_thw: (num_images, 3) image grid dimensions (t, h, w) + video_grid_thw: (num_videos, 3) video grid dimensions (t, h, w) + second_per_grid_ts: Optional temporal information for videos + attention_mask: (batch_size, seq_len) attention mask + + Returns: + position_ids: (3, batch_size, seq_len) M-RoPE position IDs + rope_deltas: Optional position deltas for incremental decoding + """ + if not self.is_vlm: + raise ValueError("get_rope_index is only available for VLM models") + + from sglang.srt.layers.rotary_embedding import MRotaryEmbedding + + position_ids, rope_deltas = MRotaryEmbedding.get_rope_index( + spatial_merge_size=self.spatial_merge_size, + image_token_id=self.image_token_id, + video_token_id=self.video_token_id, + vision_start_token_id=self.vision_start_token_id, + model_type=self.model_type, + input_ids=input_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + attention_mask=attention_mask, + tokens_per_second=self.tokens_per_second, + ) + + return position_ids, rope_deltas + + def extend_vlm( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + return_last_hidden_states: bool = False, + return_logits: bool = True, + pixel_values: Optional[List[torch.Tensor]] = None, + image_grid_thw: Optional[List[torch.Tensor]] = None, + ): + """ + Args: + input_ids: (batch_size, seq_len) or List of (1, seq_len) tensors + attention_mask: (batch_size, seq_len) or List of (1, seq_len) tensors + loss_mask: (batch_size, seq_len) or List of (1, seq_len) tensors + pixel_values: List of pixel_values tensors, one per sample in batch + image_grid_thw: List of image_grid_thw tensors, one per sample in batch + """ + mm_utils.embedding_cache.clear() + sampling_params = SamplingParams(temperature=0, max_new_tokens=1, top_k=1) + reqs, data_cache = [], [] + + # Split tensors if needed + if isinstance(input_ids, torch.Tensor): + batch_size = input_ids.shape[0] + input_ids = torch.split(input_ids, 1, dim=0) + attention_mask = torch.split(attention_mask, 1, dim=0) + loss_mask = torch.split(loss_mask, 1, dim=0) + else: + batch_size = len(input_ids) + # Process image_grid_thw - convert to list if needed + if image_grid_thw is None: + image_grid_thw = [None] * batch_size + elif not isinstance(image_grid_thw, (list, tuple)): + image_grid_thw = [image_grid_thw] + + # pixel_values is a single 2D tensor (total_patches, patch_dim) for Qwen2.5-VL + # We need to track offset and slice it based on image_grid_thw for each sample + pixel_values_offset = 0 # Track current offset in pixel_values + + for idx, (input_id_, attention_mask_, loss_mask_, image_grid_thw_) in enumerate( + zip( + input_ids, + attention_mask, + loss_mask, + image_grid_thw, + ) + ): + # Compute num_patches for this sample from image_grid_thw_ + # image_grid_thw_: (num_images, 3) where each row is (t, h, w) + if image_grid_thw_ is not None: + # Ensure image_grid_thw_ is 2D: (num_images, 3) + if image_grid_thw_.dim() == 1: + image_grid_thw_ = image_grid_thw_.unsqueeze(0) # (3,) -> (1, 3) + elif image_grid_thw_.dim() == 0: + raise ValueError( + f"image_grid_thw_ is 0-dim tensor, expected at least 1D. Value: {image_grid_thw_}" + ) + + # Calculate num_patches for this sample: sum(t * h * w) for all images + num_patches = ( + ( + image_grid_thw_[:, 0] + * image_grid_thw_[:, 1] + * image_grid_thw_[:, 2] + ) + .sum() + .item() + ) + num_patches = int(num_patches) + + # Slice pixel_values for this sample + pixel_value_ = pixel_values[ + pixel_values_offset : pixel_values_offset + num_patches + ] + pixel_values_offset += num_patches + else: + pixel_value_ = None + num_patches = 0 + + # Compute mrope positions for VLM models (e.g., Qwen2.5-VL) + input_id_flat = input_id_.view(-1) + + # Count image tokens + num_img_tokens = (input_id_flat == self.image_token_id).sum().item() + # print(f"[extend_vlm] num_img_tokens in input_ids: {num_img_tokens}") + + mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index( + spatial_merge_size=self.spatial_merge_size, + image_token_id=self.image_token_id, + video_token_id=self.video_token_id, + vision_start_token_id=self.vision_start_token_id, + model_type=self.model_type, + input_ids=input_id_flat.unsqueeze(0), + image_grid_thw=( + image_grid_thw_.cpu() if image_grid_thw_ is not None else None + ), + tokens_per_second=self.tokens_per_second, + ) + + offset = BaseMultimodalProcessor.get_mm_items_offset( + input_id_flat, self.image_token_id + ) + mm_item = MultimodalDataItem( + modality=Modality.IMAGE, + feature=pixel_value_, # torch.Tensor: (num_patches, patch_dim) + pad_value=self.image_token_id, # Required for placeholder tensor creation + offsets=offset, # List of (start, end) tuples + ) + mm_item.set("image_grid_thw", image_grid_thw_.cpu()) + mm_item.set_pad_value() + mm_inputs = MultimodalInputs( + mm_items=[mm_item], + im_token_id=self.image_token_id, + im_start_id=self.vision_start_token_id, + im_end_id=self.vision_end_token_id, + mrope_positions=( + mrope_positions.squeeze(1) if mrope_positions is not None else None + ), + mrope_position_delta=mrope_position_delta, + ) + pattern = MultiModalityDataPaddingPatternMultimodalTokens() + input_id_list = pattern.pad_input_tokens( + input_id_.view(-1).tolist(), mm_inputs + ) + req = Req( + rid=str(idx), + origin_input_text="", + origin_input_ids=input_id_list, + sampling_params=sampling_params, + ) + req.fill_ids = req.origin_input_ids + req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) + req.logprob_start_len = len(req.origin_input_ids) - 1 + req.multimodal_inputs = mm_inputs + data_cache.append([input_id_, attention_mask_, loss_mask_]) + reqs.append(req) + + logits_list, aux_hidden_states_list, last_hidden_states_list = self._extend( + reqs, + capture_aux_hidden_states=True, + return_last_hidden_states=return_last_hidden_states, + return_logits=return_logits, + ) + + return data_cache, logits_list, aux_hidden_states_list, last_hidden_states_list + + @torch.no_grad() + def generate_eagle3_data( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + pixel_values: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.Tensor] = None, + is_vlm: bool = False, + ) -> Eagle3TargetOutput: + """ + return: + data_for_draft: List[Dict[str, torch.Tensor]] of draft_batch_size, draft_micro_batch_size = 1 + - input_ids: (1, seq_len) + - attention_mask: (1, seq_len) + - loss_mask: (1, seq_len) + - target: (1, seq_len, vocab_size) or (1, seq_len, hidden_size) + - hidden_states: (1, seq_len, hidden_size) + - pixel_values: (patch_len, patch_width) + - image_grid_thw (batch_size, 3) + """ + if is_vlm: + data_cache, logits_list, aux_hidden_states_list, last_hidden_states_list = ( + self.extend_vlm( + input_ids, + attention_mask, + loss_mask, + return_last_hidden_states=False, + return_logits=True, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + ) + else: + data_cache, logits_list, aux_hidden_states_list, last_hidden_states_list = ( + self.extend( + input_ids, + attention_mask, + loss_mask, + return_last_hidden_states=False, + return_logits=True, + ) + ) + aux_hidden_states_out = [] + target_out = [] + loss_mask_out = [] + input_ids_out = [] + last_hidden_states_out = [] + + for idx, (data, logits, aux_hidden_states, last_hidden_states) in enumerate( + zip( + data_cache, logits_list, aux_hidden_states_list, last_hidden_states_list + ) + ): + aux_hidden_states_out.append(aux_hidden_states.unsqueeze(0)) + loss_mask_out.append(data[2]) + input_ids_out.append(data[0]) + + # when generating hidden states for offline training, we don't compute logits and only keep the last_hidden_states + # when training online, we don't keep the last_hidden_states and only keep the logits + if logits is not None: + target_out.append(logits.unsqueeze(0)) + else: + target_out.append(None) + + if last_hidden_states is not None: + last_hidden_states_out.append(last_hidden_states.unsqueeze(0)) + else: + last_hidden_states_out.append(None) + + aux_hidden_states_out = torch.cat(aux_hidden_states_out, dim=0) + + loss_mask_out = torch.cat(loss_mask_out, dim=0) + input_ids_out = torch.cat(input_ids_out, dim=0) + + if target_out[0] is not None: + target_out = torch.cat(target_out, dim=0) + else: + target_out = None + + if last_hidden_states_out[0] is not None: + last_hidden_states_out = torch.cat(last_hidden_states_out, dim=0) + else: + last_hidden_states_out = None + + target_out = padding(target_out, left=False) + input_ids_out = padding(input_ids_out, left=False) + loss_mask_out = loss_mask_out[..., None] + + return Eagle3TargetOutput( + hidden_states=aux_hidden_states_out, + target=target_out, + loss_mask=loss_mask_out, + input_ids=input_ids_out, + attention_mask=attention_mask, + last_hidden_states=last_hidden_states_out, + ) + + +class CustomEagle3TargetModel(Eagle3TargetModel): + + def __init__(self, model: nn.Module): + super().__init__() + self.model = model + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + torch_dtype: torch.dtype = None, + device: str = None, + cache_dir: Optional[str] = None, + **kwargs, + ) -> "CustomEagle3TargetModel": + from specforge.modeling.auto import AutoDistributedTargetModel + + target_model = AutoDistributedTargetModel.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, + torch_dtype=torch_dtype, + cache_dir=cache_dir, + device=device, + **kwargs, + ) + return cls(target_model) + + @torch.no_grad() + def generate_eagle3_data( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + ) -> Eagle3TargetOutput: + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + layers_to_output_hidden_states=self.aux_hidden_states_layers, + use_cache=False, + ) + + # For custom backends, the model implementation is responsible for only + # returning the requested layers in `outputs.hidden_states`. + hidden_states = torch.cat(outputs.hidden_states, dim=-1) + + target = outputs.logits + target = padding(target, left=False) + input_ids = padding(input_ids, left=False) + loss_mask = loss_mask[..., None].to(target.device) + + return Eagle3TargetOutput( + hidden_states=hidden_states, + target=target, + loss_mask=loss_mask, + input_ids=input_ids, + attention_mask=attention_mask, + ) + + +def get_eagle3_target_model( + pretrained_model_name_or_path: str, + backend: str = "sglang", + torch_dtype: torch.dtype = None, + device: str = None, + cache_dir: Optional[str] = None, + **kwargs, +) -> Eagle3TargetModel: + if backend == "sglang": + return SGLangEagle3TargetModel.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, + torch_dtype=torch_dtype, + device=device, + cache_dir=cache_dir, + **kwargs, + ) + elif backend == "hf": + return HFEagle3TargetModel.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, + torch_dtype=torch_dtype, + device=device, + cache_dir=cache_dir, + **kwargs, + ) + elif backend == "custom": + return CustomEagle3TargetModel.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, + torch_dtype=torch_dtype, + device=device, + cache_dir=cache_dir, + **kwargs, + ) + else: + raise ValueError(f"Invalid backend: {backend}") diff --git a/SpecForge-ext/specforge/modeling/target/sglang_backend/__init__.py b/SpecForge-ext/specforge/modeling/target/sglang_backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0e02ab7b3950bf2405141a61a89c071f42a9a2a7 --- /dev/null +++ b/SpecForge-ext/specforge/modeling/target/sglang_backend/__init__.py @@ -0,0 +1,4 @@ +from .model_runner import SGLangRunner +from .utils import wrap_eagle3_logits_processors_in_module + +__all__ = ["SGLangRunner", "wrap_eagle3_logits_processors_in_module"] diff --git a/SpecForge-ext/specforge/modeling/target/sglang_backend/__pycache__/__init__.cpython-311.pyc b/SpecForge-ext/specforge/modeling/target/sglang_backend/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5eb77aa69d7bda51eb95f23dd7c87483655bf1b2 Binary files /dev/null and b/SpecForge-ext/specforge/modeling/target/sglang_backend/__pycache__/__init__.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/modeling/target/sglang_backend/__pycache__/model_runner.cpython-311.pyc b/SpecForge-ext/specforge/modeling/target/sglang_backend/__pycache__/model_runner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fbb3a9cfaf6c2a748bee5e08c9b899c2e9b30f5 Binary files /dev/null and b/SpecForge-ext/specforge/modeling/target/sglang_backend/__pycache__/model_runner.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/modeling/target/sglang_backend/__pycache__/patch.cpython-311.pyc b/SpecForge-ext/specforge/modeling/target/sglang_backend/__pycache__/patch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88aa70711e5e2cece2355bd16a082a5f70bbe063 Binary files /dev/null and b/SpecForge-ext/specforge/modeling/target/sglang_backend/__pycache__/patch.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/modeling/target/sglang_backend/__pycache__/utils.cpython-311.pyc b/SpecForge-ext/specforge/modeling/target/sglang_backend/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..850c7ef4f633ad4b8888f4e8a5b642e48a76f557 Binary files /dev/null and b/SpecForge-ext/specforge/modeling/target/sglang_backend/__pycache__/utils.cpython-311.pyc differ diff --git a/SpecForge-ext/specforge/modeling/target/sglang_backend/model_runner.py b/SpecForge-ext/specforge/modeling/target/sglang_backend/model_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..7b86a30747dd3c26007483dd7d436ee86d33dd1b --- /dev/null +++ b/SpecForge-ext/specforge/modeling/target/sglang_backend/model_runner.py @@ -0,0 +1,159 @@ +import logging +import os + +import torch +from sglang.srt.distributed import ( + get_pp_group, + get_tp_group, + get_world_group, + set_custom_all_reduce, + set_mscclpp_all_reduce, + set_torch_symm_mem_all_reduce, +) +from sglang.srt.layers.dp_attention import ( + get_attention_tp_group, + initialize_dp_attention, +) +from sglang.srt.model_executor.model_runner import ModelRunner +from sglang.srt.utils import ( + cpu_has_amx_support, + get_available_gpu_memory, + get_bool_env_var, + is_hip, + is_npu, + monkey_patch_p2p_access_check, +) + +from .patch import ( + init_distributed_environment, + initialize_dp_attention, + initialize_model_parallel, +) + +_is_hip = is_hip() +_is_npu = is_npu() +_is_cpu_amx_available = cpu_has_amx_support() + +# Use a small KV cache pool size for tests in CI +SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None) + +# Detect stragger ranks in model loading +UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300 + +logger = logging.getLogger(__name__) + + +class SGLangRunner(ModelRunner): + + def init_torch_distributed(self): + logger.info("Init torch distributed begin.") + + try: + torch.get_device_module(self.device).set_device(self.gpu_id) + except Exception: + logger.warning( + f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}" + ) + raise + + if self.device == "cuda": + if self.server_args.elastic_ep_backend == "mooncake": + backend = "mooncake" + if self.server_args.mooncake_ib_device: + mooncake_ib_device = self.server_args.mooncake_ib_device.split(",") + try: + from mooncake import ep as mooncake_ep + + mooncake_ep.set_device_filter(mooncake_ib_device) + except: + pass # A warning will be raised in `init_distributed_environment` + else: + backend = "nccl" + elif self.device == "xpu": + backend = "xccl" + elif self.device == "hpu": + backend = "hccl" + elif self.device == "cpu": + backend = "gloo" + elif self.device == "npu": + backend = "hccl" + + before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id) + if not self.server_args.enable_p2p_check: + monkey_patch_p2p_access_check() + + if self.server_args.dist_init_addr: + dist_init_method = f"tcp://{self.server_args.dist_init_addr}" + else: + dist_init_method = f"tcp://127.0.0.1:{self.dist_port}" + set_custom_all_reduce(not self.server_args.disable_custom_all_reduce) + set_mscclpp_all_reduce(self.server_args.enable_mscclpp) + set_torch_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem) + + if not self.is_draft_worker: + if self.device == "cpu": + if _is_cpu_amx_available: + # Bind OpenMP threads to CPU cores + torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid) + + # Set local size to hint SGLang to use shared memory based AllReduce + os.environ["LOCAL_SIZE"] = str(self.tp_size) + torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank) + + @torch.library.register_fake("sgl_kernel::shm_allgather") + def _(data, dim): + return torch.cat([data] * self.tp_size, dim=dim) + + else: + logger.warning( + "init_cpu_threads_env and shared memory based AllReduce is disabled since intel amx backend is not available" + ) + + # Only initialize the distributed environment on the target model worker. + init_distributed_environment( + backend=backend, + world_size=self.tp_size * self.pp_size, + rank=self.tp_size * self.pp_rank + self.tp_rank, + local_rank=self.gpu_id, + ) + initialize_model_parallel( + tensor_model_parallel_size=self.tp_size, + pipeline_model_parallel_size=self.pp_size, + expert_model_parallel_size=self.moe_ep_size, + duplicate_tp_group=self.server_args.enable_pdmux, + torch_compile=self.server_args.enable_piecewise_cuda_graph, + ) + initialize_dp_attention( + server_args=self.server_args, + model_config=self.model_config, + ) + + min_per_gpu_memory = get_available_gpu_memory( + self.device, + self.gpu_id, + distributed=get_world_group().world_size > 1, + cpu_group=get_world_group().cpu_group, + ) + self.tp_group = get_tp_group() + self.pp_group = get_pp_group() + self.attention_tp_group = get_attention_tp_group() + + # Check memory for tensor parallelism + local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id) + if self.tp_size > 1 and not self.is_draft_worker: + if min_per_gpu_memory < local_gpu_memory * 0.9: + if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"): + logger.warning( + "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. " + f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}" + ) + else: + raise ValueError( + "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. " + f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}" + ) + + logger.info( + f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB" + ) + return min_per_gpu_memory diff --git a/SpecForge-ext/specforge/modeling/target/sglang_backend/patch.py b/SpecForge-ext/specforge/modeling/target/sglang_backend/patch.py new file mode 100644 index 0000000000000000000000000000000000000000..b48ed611f1d33ee8d198ad7e3cbdc0c58ed920dc --- /dev/null +++ b/SpecForge-ext/specforge/modeling/target/sglang_backend/patch.py @@ -0,0 +1,294 @@ +import logging +from typing import Optional + +import sglang.srt.distributed.parallel_state as parallel_state +import torch +import torch.distributed as dist +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.distributed import init_model_parallel_group +from sglang.srt.distributed.parallel_state import GroupCoordinator +from sglang.srt.layers.dp_attention import ( + _DpGatheredBufferWrapper, + compute_dp_attention_local_info, + compute_dp_attention_world_info, +) +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import get_bool_env_var + +from specforge.distributed import get_tp_group as get_specforge_tp_group + +logger = logging.getLogger(__name__) + + +def init_distributed_environment( + world_size: int = -1, + rank: int = -1, + local_rank: int = -1, + backend: str = "nccl", +): + logger.debug( + "world_size=%d rank=%d backend=%s", + world_size, + rank, + backend, + ) + assert ( + torch.distributed.is_initialized() + ), "distributed environment should be initialized first" + + tp_group = get_specforge_tp_group() + world_size = dist.get_world_size() + tp_size = dist.get_world_size(tp_group) + num_tp_groups = world_size // tp_size + tp_ranks = [] + for i in range(num_tp_groups): + tp_ranks.append(list(range(i * tp_size, (i + 1) * tp_size))) + + parallel_state._WORLD = GroupCoordinator( + group_ranks=tp_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + use_pynccl=False, + use_pymscclpp=False, + use_custom_allreduce=False, + use_torch_symm_mem_all_reduce=False, + use_hpu_communicator=False, + use_xpu_communicator=False, + use_npu_communicator=False, + group_name="world", + ) + # we destroy the newly created world group and replace it + # with the existing tp group from specforge to save CUDA memory + group_to_destroy = parallel_state._WORLD.device_group + parallel_state._WORLD.device_group = tp_group + dist.destroy_process_group(group_to_destroy) + + +def initialize_model_parallel( + tensor_model_parallel_size: int = 1, + expert_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, + duplicate_tp_group: bool = False, + torch_compile: Optional[bool] = None, +) -> None: + """ + Initialize model parallel groups. + + Arguments: + tensor_model_parallel_size: number of GPUs used for tensor model + parallelism. + pipeline_model_parallel_size: number of GPUs used for pipeline model + parallelism. + + Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. The present function will + create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: + 4 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + 2 pipeline model-parallel groups: + [g0, g2, g4, g6], [g1, g3, g5, g7] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = parallel_state._WORLD.world_size + backend = backend or dist.get_backend(parallel_state._WORLD.device_group) + + if world_size != tensor_model_parallel_size * pipeline_model_parallel_size: + raise RuntimeError( + f"world_size ({world_size}) is not equal to " + f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " + f"pipeline_model_parallel_size ({pipeline_model_parallel_size})" + ) + + # Build the tensor model-parallel groups. + num_tensor_model_parallel_groups: int = ( + dist.get_world_size() // tensor_model_parallel_size + ) + assert ( + parallel_state._TP is None + ), "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + ranks = list( + range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) + ) + group_ranks.append(ranks) + + # message queue broadcaster is only used in tensor model parallel group + parallel_state._TP = init_model_parallel_group( + group_ranks, + parallel_state._WORLD.local_rank, + backend, + use_message_queue_broadcaster=get_bool_env_var( + "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true" + ), + group_name="tp", + pynccl_use_current_stream=duplicate_tp_group, + torch_compile=torch_compile, + ) + + if duplicate_tp_group: + assert ( + parallel_state._PDMUX_PREFILL_TP_GROUP is None + ), "tensor model parallel group for PD-Multiplexing Prefill is already initialized" + assert ( + parallel_state._PDMUX_PREFILL_TP_GROUP is None + ), "tensor model parallel group for PD-Multiplexing Prefill is already initialized" + parallel_state._PDMUX_PREFILL_TP_GROUP = init_model_parallel_group( + group_ranks, + parallel_state._WORLD.local_rank, + backend, + use_message_queue_broadcaster=get_bool_env_var( + "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true" + ), + group_name="pdmux_prefill_tp", + pynccl_use_current_stream=True, + torch_compile=torch_compile, + ) + parallel_state._TP.pynccl_comm.disabled = False + parallel_state._PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False + + moe_ep_size = expert_model_parallel_size + + moe_tp_size = tensor_model_parallel_size // moe_ep_size + assert ( + parallel_state._MOE_EP is None + ), "expert model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + for j in range(moe_tp_size): + st = i * tensor_model_parallel_size + j + en = (i + 1) * tensor_model_parallel_size + j + ranks = list(range(st, en, moe_tp_size)) + group_ranks.append(ranks) + + parallel_state._MOE_EP = init_model_parallel_group( + group_ranks, + parallel_state._WORLD.local_rank, + backend, + use_custom_allreduce=False, + group_name="moe_ep", + ) + + assert ( + parallel_state._MOE_TP is None + ), "moe tensor model parallel group is already initialized" + if moe_ep_size == 1: + parallel_state._MOE_TP = parallel_state._TP + else: + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + for j in range(moe_ep_size): + st = i * tensor_model_parallel_size + j * moe_tp_size + en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size + ranks = list(range(st, en)) + group_ranks.append(ranks) + parallel_state._MOE_TP = init_model_parallel_group( + group_ranks, + parallel_state._WORLD.local_rank, + backend, + use_custom_allreduce=False, + group_name="moe_tp", + ) + + # Build the pipeline model-parallel groups. + num_pipeline_model_parallel_groups: int = ( + dist.get_world_size() // pipeline_model_parallel_size + ) + assert ( + parallel_state._PP is None + ), "pipeline model parallel group is already initialized" + group_ranks = [] + for i in range(num_pipeline_model_parallel_groups): + ranks = list( + range(i, dist.get_world_size(), num_pipeline_model_parallel_groups) + ) + group_ranks.append(ranks) + # pipeline parallel does not need custom allreduce + parallel_state._PP = init_model_parallel_group( + group_ranks, + parallel_state._WORLD.local_rank, + backend, + use_custom_allreduce=False, + group_name="pp", + ) + + +def initialize_dp_attention( + server_args: ServerArgs, + model_config: ModelConfig, +): + import sglang.srt.layers.dp_attention as dp_attention + from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP + + enable_dp_attention = server_args.enable_dp_attention + tp_size = server_args.tp_size + dp_size = server_args.dp_size + moe_dense_tp_size = server_args.moe_dense_tp_size + pp_size = server_args.pp_size + + tp_rank = parallel_state.get_tensor_model_parallel_rank() + + dp_attention._ENABLE_DP_ATTENTION_FLAG = enable_dp_attention + + ( + dp_attention._ATTN_TP_RANK, + dp_attention._ATTN_TP_SIZE, + dp_attention._ATTN_DP_RANK, + ) = compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size) + _, _, dp_attention._LOCAL_ATTN_DP_RANK = compute_dp_attention_local_info( + enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size + ) + + if enable_dp_attention: + dp_attention._ATTN_DP_SIZE = dp_size + if moe_dense_tp_size is None: + dp_attention._LOCAL_ATTN_DP_SIZE = dp_attention._ATTN_DP_SIZE + else: + dp_attention._LOCAL_ATTN_DP_SIZE = max( + 1, dp_size // (tp_size // moe_dense_tp_size) + ) + else: + dp_attention._ATTN_DP_SIZE = 1 + dp_attention._LOCAL_ATTN_DP_SIZE = 1 + + tp_group = parallel_state.get_tp_group() + num_model_parallel_groups = dist.get_world_size() // (pp_size * tp_size) + mp_size = pp_size * tp_size + group_ranks = [] + + for i in range(num_model_parallel_groups): + ranks = [ + list(range(head, head + dp_attention._ATTN_TP_SIZE)) + for head in range( + mp_size * i, mp_size * (i + 1), dp_attention._ATTN_TP_SIZE + ) + ] + group_ranks.extend(ranks) + + dp_attention._ATTN_TP_GROUP = GroupCoordinator( + group_ranks, + tp_group.local_rank, + torch.distributed.get_backend(tp_group.device_group), + use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP, + use_pymscclpp=False, + use_custom_allreduce=False, + use_torch_symm_mem_all_reduce=False, + use_hpu_communicator=False, + use_xpu_communicator=False, + use_npu_communicator=False, + group_name="attention_tp", + ) + # print(f"{parallel_state._ATTN_TP_GROUP=}") + + _DpGatheredBufferWrapper.set_metadata( + hidden_size=model_config.hidden_size, + dtype=model_config.dtype, + device=torch.device(server_args.device), + ) diff --git a/SpecForge-ext/specforge/modeling/target/sglang_backend/utils.py b/SpecForge-ext/specforge/modeling/target/sglang_backend/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..441adbb5808c0614789ac0e3c033dd53788a0e37 --- /dev/null +++ b/SpecForge-ext/specforge/modeling/target/sglang_backend/utils.py @@ -0,0 +1,165 @@ +""" +This file contains the wrapper for the SGL model. +""" + +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import torch.nn as nn +from sglang.srt.layers.logits_processor import ( + LogitsMetadata, + LogitsProcessor, + LogitsProcessorOutput, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.server_args import get_global_server_args + + +@dataclass +class ReplacedLogitsProcessorEagle3Output: + """ + A dataclass to store the logits and aux hidden states needed for EAGLE3. + """ + + logits: torch.Tensor + aux_hidden_states: torch.Tensor + last_hidden_states: Optional[torch.Tensor] = None + + +def replaced_logits_processor_forward_for_eagle3( + self, + input_ids, + hidden_states, + lm_head, + logits_metadata: Union[LogitsMetadata, ForwardBatch], + aux_hidden_states: Optional[torch.Tensor] = None, + return_last_hidden_states: bool = False, + return_logits: bool = False, +) -> LogitsProcessorOutput: + """ + This is a modified forward function for the SGLang's logits processor, adapted from https://github.com/sgl-project/sglang/blob/v0.5.4/python/sglang/srt/layers/logits_processor.py. + The modification is to return the logits and aux hidden states instead of the last hidden states. + """ + + if isinstance(logits_metadata, ForwardBatch): + logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata) + + # Check if multi-item scoring is enabled via server args (only for prefill-only requests) + multi_item_delimiter = get_global_server_args().multi_item_scoring_delimiter + if multi_item_delimiter is not None and logits_metadata.is_prefill_only: + return self.compute_logprobs_for_multi_item_scoring( + input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter + ) + + # Get the last hidden states and last logits for the next token prediction + if ( + logits_metadata.forward_mode.is_decode_or_idle() + or logits_metadata.forward_mode.is_target_verify() + or logits_metadata.forward_mode.is_draft_extend_v2() + ): + pruned_states = hidden_states + if aux_hidden_states is not None: + aux_pruned_states = [hidden for hidden in aux_hidden_states] + sample_indices = None + input_logprob_indices = None + else: + raise RuntimeError( + f"The modified logits processor is not supported for this forward mode: {logits_metadata.forward_mode}" + ) + + if return_last_hidden_states: + last_hidden_states = pruned_states + else: + last_hidden_states = None + + if return_logits: + # Compute logits for both input and sampled tokens. + logits = self._get_logits(pruned_states, lm_head, logits_metadata) + else: + logits = None + + # get the aux hidden states + hidden_states_to_store: Optional[torch.Tensor] = None + if logits_metadata.capture_hidden_mode.need_capture(): + if logits_metadata.capture_hidden_mode.is_full(): + if aux_hidden_states is not None: + aux_hidden_states = torch.cat(aux_hidden_states, dim=-1) + hidden_states_to_store = aux_hidden_states + else: + hidden_states_to_store = hidden_states + elif logits_metadata.capture_hidden_mode.is_last(): + # Get the last token hidden states. If sample_indices is None, + # pruned states only contain the last tokens already. + if aux_hidden_states is not None: + aux_pruned_states = torch.cat(aux_pruned_states, dim=-1) + hidden_states_to_store = ( + aux_pruned_states[sample_indices] + if sample_indices is not None + else aux_pruned_states + ) + else: + hidden_states_to_store = ( + pruned_states[sample_indices] + if sample_indices is not None + else pruned_states + ) + else: + assert False, "Should never reach" + + assert ( + not logits_metadata.extend_return_logprob + ), "extend_return_logprob is not supported" + # Decode mode or extend mode without return_logprob. + return ReplacedLogitsProcessorEagle3Output( + logits=logits, + aux_hidden_states=hidden_states_to_store, + last_hidden_states=last_hidden_states, + ) + + +class LogitsProcessorForEAGLE3(torch.nn.Module): + def __init__( + self, + logits_processor: LogitsProcessor, + return_last_hidden_states: bool = False, + return_logits: bool = False, + ): + super().__init__() + self.logits_processor = logits_processor + self.return_last_hidden_states = return_last_hidden_states + self.return_logits = return_logits + + def forward( + self, + input_ids, + hidden_states, + lm_head, + logits_metadata, + aux_hidden_states: Optional[torch.Tensor] = None, + ) -> LogitsProcessorOutput: + logits_metadata.forward_mode = ForwardMode.DECODE + ret = replaced_logits_processor_forward_for_eagle3( + self.logits_processor, + input_ids, + hidden_states, + lm_head, + logits_metadata, + aux_hidden_states, + self.return_last_hidden_states, + self.return_logits, + ) + return ret + + +def wrap_eagle3_logits_processors_in_module( + module: nn.Module, return_full_logits: bool = False +): + """ + This function will wrap the SGLang's original logits processor with the modified one for EAGLE3. + """ + for name, submodule in module.named_modules(): + if isinstance(submodule, LogitsProcessor): + wrapped = LogitsProcessorForEAGLE3(submodule, return_full_logits) + setattr(module, name, wrapped) + print(f"wrapped {name} with LogitsProcessorForEAGLE3") diff --git a/SpecForge-ext/specforge/modeling/target/target_head.py b/SpecForge-ext/specforge/modeling/target/target_head.py new file mode 100644 index 0000000000000000000000000000000000000000..7231117cefa1903e733ead6305642318422e64b9 --- /dev/null +++ b/SpecForge-ext/specforge/modeling/target/target_head.py @@ -0,0 +1,92 @@ +import glob +import json +import os +from typing import Optional + +import torch +import torch.nn as nn +from huggingface_hub import snapshot_download +from safetensors import safe_open +from transformers import AutoConfig + +from specforge.utils import padding + + +class TargetHead(nn.Module): + def __init__(self, model_path, trust_remote_code: bool = False): + super().__init__() + self.config = AutoConfig.from_pretrained( + model_path, trust_remote_code=trust_remote_code + ) + self.fc = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False) + + @classmethod + def from_pretrained( + cls, + model_path, + lm_head_key: str = "lm_head.weight", + cache_dir: Optional[str] = None, + trust_remote_code: bool = False, + ) -> "TargetHead": + target_head = cls(model_path, trust_remote_code=trust_remote_code) + target_head.load_weights( + model_path=model_path, + lm_head_key=lm_head_key, + cache_dir=cache_dir, + ) + target_head.freeze_weights() + target_head = target_head.eval().cuda().to(torch.bfloat16) + return target_head + + @torch.no_grad() + def load_weights( + self, + model_path, + lm_head_key: str = "lm_head.weight", + cache_dir: Optional[str] = None, + ): + if os.path.exists(model_path): + self.model_path = model_path + else: + self.model_path = snapshot_download(repo_id=model_path) + + # model_path is a local directory + # check if there is file ending with index.json + glob_path = os.path.join(self.model_path, "*.index.json") + index_json_path = glob.glob(glob_path) + + if len(index_json_path) == 0: + raise FileNotFoundError(f"No index.json file found in {self.model_path}") + if len(index_json_path) > 1: + raise FileNotFoundError( + f"Multiple index.json files found in {self.model_path}" + ) + index_json_path = index_json_path[0] + + with open(index_json_path, "r") as f: + index_json = json.load(f) + ckpt_file = index_json["weight_map"][lm_head_key] + + if ckpt_file.endswith(".safetensors"): + with safe_open( + os.path.join(self.model_path, ckpt_file), framework="pt" + ) as f: + lm_head = f.get_tensor(lm_head_key) + else: + state_dict = torch.load(os.path.join(self.model_path, ckpt_file)) + lm_head = state_dict[lm_head_key] + self.fc.weight.copy_(lm_head) + + def freeze_weights(self): + for param in self.fc.parameters(): + param.requires_grad = False + + def forward(self, hidden_states): + return self.fc(hidden_states) + + def preprocess(self, input_ids, target, loss_mask): + # apply pading + target = padding(target, left=False) + input_ids = padding(input_ids, left=False) + loss_mask = loss_mask[..., None] + return input_ids, target, loss_mask diff --git a/SpecForge-ext/specforge/modeling/target/target_utils.py b/SpecForge-ext/specforge/modeling/target/target_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9377753fce2f1391fc14abf2f47a8634c57f0eee --- /dev/null +++ b/SpecForge-ext/specforge/modeling/target/target_utils.py @@ -0,0 +1,134 @@ +import glob +import json +import os +from typing import Optional + +import torch +import torch.nn as nn +from huggingface_hub import snapshot_download +from safetensors import safe_open +from transformers import AutoConfig + + +class TargetEmbeddingsAndHead(nn.Module): + """ + Efficiently loads only the embedding layer and lm_head from a pretrained model. + Avoids loading the full model into memory. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id + ) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + @classmethod + def from_pretrained( + cls, + model_path: str, + embed_key: str = "model.embed_tokens.weight", + lm_head_key: str = "lm_head.weight", + cache_dir: Optional[str] = None, + device: str = "cuda", + dtype: torch.dtype = torch.bfloat16, + trust_remote_code: bool = False, + ) -> "TargetEmbeddingsAndHead": + + # 1. Load Config + config = AutoConfig.from_pretrained( + model_path, cache_dir=cache_dir, trust_remote_code=trust_remote_code + ) + instance = cls(config) + + # 2. Resolve Model Path (Handle Hub) + local_model_path = model_path + if not os.path.exists(local_model_path): + try: + local_model_path = snapshot_download( + repo_id=model_path, cache_dir=cache_dir + ) + except: + pass # Maybe it's a local path that looks like a repo ID but doesn't exist? + + # 3. Load Weights Efficiently + instance._load_weights(local_model_path, embed_key, lm_head_key) + + # 4. Move to Device & Freeze + instance.to(device=device, dtype=dtype) + instance.eval() + instance.requires_grad_(False) + + return instance + + def _load_weights(self, model_path: str, embed_key: str, lm_head_key: str): + # Locate index.json + index_files = glob.glob(os.path.join(model_path, "*.index.json")) + + weight_map = {} + if index_files: + # Sharded Checkpoint + with open(index_files[0], "r") as f: + index = json.load(f) + + # Find which file contains our keys + weight_map = index.get("weight_map", {}) + files_to_load = {} + + if embed_key in weight_map: + files_to_load[embed_key] = weight_map[embed_key] + else: + # Fallback: sometimes keys are prefixed differently? + print( + f"Warning: {embed_key} not found in weight_map. Keys available: {list(weight_map.keys())[:5]}..." + ) + + if lm_head_key in weight_map: + files_to_load[lm_head_key] = weight_map[lm_head_key] + + # Load specific files + for key, filename in files_to_load.items(): + file_path = os.path.join(model_path, filename) + self._load_key_from_file(file_path, key) + + else: + # Non-sharded Checkpoint (single file) + # Try finding .safetensors or .bin + safetensors = glob.glob(os.path.join(model_path, "*.safetensors")) + bins = glob.glob(os.path.join(model_path, "*.bin")) + + target_file = None + if safetensors: + target_file = safetensors[0] + elif bins: + target_file = bins[0] + + if target_file: + self._load_key_from_file(target_file, embed_key) + self._load_key_from_file(target_file, lm_head_key) + else: + raise FileNotFoundError(f"No checkpoint file found in {model_path}") + + def _load_key_from_file(self, file_path: str, key: str): + tensor = None + if file_path.endswith(".safetensors"): + with safe_open(file_path, framework="pt") as f: + if key in f.keys(): + tensor = f.get_tensor(key) + else: + # torch.load loads full dict, less efficient but works + state_dict = torch.load(file_path, map_location="cpu") + if key in state_dict: + tensor = state_dict[key] + del state_dict # Free immediately + + if tensor is not None: + if key.endswith("embed_tokens.weight"): + self.embed_tokens.weight.data.copy_(tensor) + print(f"Loaded embedding weights from {file_path}") + elif key.endswith("lm_head.weight"): + self.lm_head.weight.data.copy_(tensor) + print(f"Loaded lm_head weights from {file_path}") + else: + print(f"Warning: Key {key} not found in {file_path}")