Lekr0 committed
Commit eae7bce · verified · 1 Parent(s): d522318

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. SpecForge-ext/benchmarks/benchmarker/__pycache__/livecodebench.cpython-311.pyc +0 -0
  2. SpecForge-ext/cache/compiled_kernels/34/c34af36gfqnn2ovywuaultc2pol4jyn6io3szgjeuv3uxfzcf3nv.py +43 -0
  3. SpecForge-ext/cache/compiled_kernels/37/49508e3b35fb555ab64ad6f410ad33153cf779bf7b9d6de2ca009401cf12419e.best_config +1 -0
  4. SpecForge-ext/cache/compiled_kernels/37/c37gymepdyiyzp5hh2xt3a5vqmje2frbmyiqgipqpazjx6xcuyyb.py +66 -0
  5. SpecForge-ext/cache/compiled_kernels/4b/c4b4wkdm2d2z4hysjzfo6cyikw75man4bednwbsjwot4lkx7xfzs.py +47 -0
  6. SpecForge-ext/cache/compiled_kernels/4g/c4gcdzc7dkmej2ceuy3ivyfjm5wjukkm4mbbdcmc7uaq76svnppo.py +159 -0
  7. SpecForge-ext/cache/compiled_kernels/4g/c4gr37y26wd4va4drshauwjr3p5l32j5cssih4o5yz3h2g6jkxrz.py +89 -0
  8. SpecForge-ext/cache/compiled_kernels/4r/c4rogici325xsxgkeljczx3sx57vcsimyzupeu4nvgmivqoqosiz.py +1065 -0
  9. SpecForge-ext/cache/compiled_kernels/4w/669c5a8c8205272d44ea075b78e46cd1bf13f1ebe3d56d5ab422037277c923dc.best_config +1 -0
  10. SpecForge-ext/cache/compiled_kernels/4w/c4wdhwlu6yb3wcwazdnzmgzewiemvznxvrr3525eojupqjldo5pt.py +47 -0
  11. SpecForge-ext/cache/compiled_kernels/4w/c4ww5pmlr6amerprh7v3ibioh3yvbhemdqsh7gcrlxjnhnpkktrb.py +835 -0
  12. SpecForge-ext/cache/compiled_kernels/7t/466129ab41abc9f5794b92b332ac4be3dff826e8f59dc7fb522710de7206acdd.best_config +1 -0
  13. SpecForge-ext/cache/compiled_kernels/7t/c7t3uvardqlt6x3sz37tlydghb4rt6mdilzlc7ffz3pehdn5jwdj.py +49 -0
  14. SpecForge-ext/cache/compiled_kernels/ah/855c4fb51632a42fcf957963b85ead1d6653657da855baf9d7c221cfd3981ad0.best_config +1 -0
  15. SpecForge-ext/cache/compiled_kernels/ah/cah767udo2rzeazh6rycnirtnr5sijiv7nem2l67isu5iyh5pzyj.py +56 -0
  16. SpecForge-ext/cache/compiled_kernels/ak/cak5ufwwwsut5tju7yvwho5uqnabsn2za7nzkoy573tny5kqhtl5.py +552 -0
  17. SpecForge-ext/cache/compiled_kernels/ak/cakglntm3ejviis7qbld6stbcfdrpvbryqpb63fshmmyy46mxbh3.py +675 -0
  18. SpecForge-ext/cache/compiled_kernels/as/9f962df2938e79169dbf28adc9c67d12118719f3425569a98b11309d3108a638.best_config +1 -0
  19. SpecForge-ext/cache/compiled_kernels/as/casevqrknafvhxbpwjozemzmdw3n2vgrctm4s4zdjzqp52cqs6kd.py +693 -0
  20. SpecForge-ext/cache/compiled_kernels/as/casmcbz6icqn6mp2r7jahugidys5xwty64z2p3tfw4s7vlsj2oz2.py +66 -0
  21. SpecForge-ext/cache/compiled_kernels/c2/363ecfeae02cf0bc03b4070f8b6a6ac6bcf543c1a19d1c4a53122d5722a2b3dd.best_config +1 -0
  22. SpecForge-ext/cache/compiled_kernels/c2/cc2qlkbbemfommyywsdbow3sqg7jqf5x5tfkbqjzo2qy6lt36yjr.py +86 -0
  23. SpecForge-ext/cache/compiled_kernels/dm/cdma2uevipbm2dd462ztkubtq5uanau5l3oglcw7lhpt4uovlqya.py +835 -0
  24. SpecForge-ext/cache/compiled_kernels/dm/cdmv6ytwvbipl4lagbifdkedszdjny3opgqlnricedg4hfpkxbdo.py +47 -0
  25. SpecForge-ext/cache/compiled_kernels/du/cduexexwzoejgfo3kafnuhcdb2jpdj5mqnwnijlnqydzf2tfuyoh.py +682 -0
  26. SpecForge-ext/cache/compiled_kernels/dz/cdz3io7w5uyfrmfqvmg2kt2ay66qv4ckwtyurhik3frq7fqnk7gm.py +66 -0
  27. SpecForge-ext/cache/compiled_kernels/dz/f7d5f2184a6f349e4531c61cf67ffbd51fe751bb6902c7e014986bad1a4a9b8f.best_config +1 -0
  28. SpecForge-ext/cache/compiled_kernels/fa/cfac6ze2ka7xqvmyxx4ehmqqczd7mi63mu366jgrbaebsyxjcuna.py +307 -0
  29. SpecForge-ext/cache/compiled_kernels/fa/cfai6qfroimjkp32i57fqulbbxd7ap7nwbhmtwtra7dawieplflr.py +168 -0
  30. SpecForge-ext/cache/compiled_kernels/fa/cfail5nyr4vuktxoags33cssvkjxk2nbmzhswhjwxszpyc4qj4wf.py +675 -0
  31. SpecForge-ext/cache/compiled_kernels/fa/cfawzdo3q32syzk5d3t3mjridjbalgrkptn5qwko7qnup25mzrum.py +57 -0
  32. SpecForge-ext/cache/compiled_kernels/fi/cfiplsvt2q6tbvsfjtg2dd47g7npdwtvk5m3lv4anjbxwgjigkj2.py +72 -0
  33. SpecForge-ext/cache/compiled_kernels/h6/aa838d40f4d0e483f1277be61c094ff598dd757fa08fb0e455bf7c8a9b79036a.best_config +1 -0
  34. SpecForge-ext/cache/compiled_kernels/ic/cicti66tef7ykscmewrfizq5t5hma2a6k6njneyopvmhy4vmegql.py +543 -0
  35. SpecForge-ext/cache/compiled_kernels/ii/ciiz7wynjvqkn6uv5csahwryt5x2d664u4o7ugmepfcsfcniut4v.py +48 -0
  36. SpecForge-ext/cache/compiled_kernels/ik/ciksm4jphopwjgs55fbipcxecpw4d643lh76mj27636ryec4e3kg.py +552 -0
  37. SpecForge-ext/cache/compiled_kernels/is/cisbwn452kdvm56u75a2mwmrdzns6w4vxzuweva24qshuv4gksv2.py +26 -0
  38. SpecForge-ext/cache/compiled_kernels/is/d02b763bc26b4a862acff11bb1d83ee2ff669b1418d106ae0058cadf26d0f276.best_config +1 -0
  39. SpecForge-ext/cache/compiled_kernels/iy/ciy3jtwq2kqsaaylz6g2uxngpmmalnqcompyd7v6diseejxhwvzs.py +37 -0
  40. SpecForge-ext/cache/compiled_kernels/lk/clk4cgl52lrdnpqzv6ubpxawah5lw2cyfnmsbuouupfi5emjbchn.py +1083 -0
  41. SpecForge-ext/cache/compiled_kernels/lp/1e661150415d7fce0f5577d7db35f128089400ce692c8dfdf5e40cb9a867cea5.best_config +1 -0
  42. SpecForge-ext/cache/compiled_kernels/lp/clp43olymjc72eay3ukgvj6r4apcbbbnz3xlli3tafgvidlacqsg.py +322 -0
  43. SpecForge-ext/cache/compiled_kernels/lp/clpt6xpoqv3wajdkyviksqw24bkxb47w4kcgihhcyrj553fxcjqs.py +50 -0
  44. SpecForge-ext/cache/compiled_kernels/ls/cls3ju4iskgwc7wepn2m46svt5vbvf47ps3tsfw7s37earyzkzz2.py +62 -0
  45. SpecForge-ext/cache/compiled_kernels/ng/cnglvt55axgj3x37cqns4hg7zsjeu57rkczufz7vpm5o4rwbf2w7.py +164 -0
  46. SpecForge-ext/cache/compiled_kernels/p6/cp66nvdwdzgxajxp2yjtqapnwidpmfnzcyyalh6z5w6f6lf3aoej.py +56 -0
  47. SpecForge-ext/cache/compiled_kernels/qa/cqambnamuby4hynvyzhccuoc4f5nkvwpn7yeizvaaaojnmlep42d.py +62 -0
  48. SpecForge-ext/cache/compiled_kernels/qa/cqasclcikvb2uryr7k2gtwdnliae55wql22q6kutfmldlk5e7kks.py +41 -0
  49. SpecForge-ext/cache/compiled_kernels/qd/cqd6lffrumnqrtflwfoqtqs6mvn23l4bxialovx3yvqgximtpflz.py +66 -0
  50. SpecForge-ext/cache/compiled_kernels/qd/cqd7l2ktsaxhv4w2pgoiwvrihj6ya2rmzfvnjybryke4aa6nwpjp.py +62 -0
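
The listing above reflects the layout of Inductor's on-disk kernel cache: files are sharded into two-character directories, with generated Triton modules stored as `.py` files and autotuning results as sibling `.best_config` files. A hedged sketch for walking that layout (the root path comes from this commit; the pairing convention is an observation, not a documented API):

```python
from pathlib import Path

# Walk the compiled-kernel cache committed here and report, per shard
# directory, the generated kernel modules and autotune records.
root = Path("SpecForge-ext/cache/compiled_kernels")
for shard in sorted(p for p in root.iterdir() if p.is_dir()):
    kernels = [f.name for f in shard.glob("*.py")]
    configs = [f.name for f in shard.glob("*.best_config")]
    print(f"{shard.name}: {len(kernels)} kernel(s), {len(configs)} config(s)")
```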
SpecForge-ext/benchmarks/benchmarker/__pycache__/livecodebench.cpython-311.pyc ADDED
Binary file (2.92 kB)
 
SpecForge-ext/cache/compiled_kernels/34/c34af36gfqnn2ovywuaultc2pol4jyn6io3szgjeuv3uxfzcf3nv.py ADDED
@@ -0,0 +1,43 @@
+
+import triton
+import triton.language as tl
+
+from torch._inductor.runtime import triton_helpers, triton_heuristics
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+triton_helpers.set_driver_to_gpu()
+
+@triton_heuristics.reduction(
+    size_hints={'x': 32, 'r0_': 16},
+    reduction_hint=ReductionHint.INNER,
+    filename=__file__,
+    triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
+    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+)
+@triton.jit
+def triton_red_fused__to_copy_sum_2(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
+    rnumel = r0_numel
+    RBLOCK: tl.constexpr = R0_BLOCK
+    xoffset = tl.program_id(0) * XBLOCK
+    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+    xmask = xindex < xnumel
+    r0_base = tl.arange(0, R0_BLOCK)[None, :]
+    rbase = r0_base
+    x0 = xindex
+    _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
+    for r0_offset in range(0, r0_numel, R0_BLOCK):
+        r0_index = r0_offset + r0_base
+        r0_mask = r0_index < r0_numel
+        roffset = r0_offset
+        rindex = r0_index
+        r0_1 = r0_index
+        tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
+        tmp1 = tmp0.to(tl.int64)
+        tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
+        tmp4 = _tmp3 + tmp2
+        _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3)
+    tmp3 = tl.sum(_tmp3, 1)[:, None]
+    x2 = (xindex % ks1)
+    x3 = xindex // ks1
+    tmp5 = tmp3.to(tl.int32)
+    tl.store(out_ptr1 + (x2 + x3*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp5, xmask)
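
For orientation: `triton_red_fused__to_copy_sum_2` loads int32 values, accumulates them in an int64 register tile (`_tmp3`) to avoid overflow, reduces along the inner `r0_` dimension, and narrows back to int32 on store. A hedged eager-mode sketch of the same computation, assuming a 2-D int32 input reduced over its last axis:

```python
import torch

# Eager equivalent (sketch) of the widened integer row-sum above:
# int32 -> int64 accumulate -> int32, matching the kernel's casts.
def fused_to_copy_sum(x: torch.Tensor) -> torch.Tensor:
    assert x.dtype == torch.int32
    return x.to(torch.int64).sum(dim=-1).to(torch.int32)
```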
SpecForge-ext/cache/compiled_kernels/37/49508e3b35fb555ab64ad6f410ad33153cf779bf7b9d6de2ca009401cf12419e.best_config ADDED
@@ -0,0 +1 @@
+{"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 66, "triton_cache_hash": "UQSFYICF6CFQWZOBHCGZ7JZ457GHWVO6RMPN5ABNWOATFMKI6GQA"}
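
The `.best_config` file above is a plain JSON record of the launch configuration the autotuner selected for the sibling kernel (block size, warps, stages, plus cache hashes and tuning time). Since it is ordinary JSON, it can be inspected directly; the path below is the one from this commit:

```python
import json
from pathlib import Path

# Read the cached autotuning decision for one kernel.
cfg_path = Path("SpecForge-ext/cache/compiled_kernels/37/"
                "49508e3b35fb555ab64ad6f410ad33153cf779bf7b9d6de2ca009401cf12419e.best_config")
cfg = json.loads(cfg_path.read_text())
print(cfg["XBLOCK"], cfg["num_warps"], cfg["time_taken_ms"])  # 512 8 66
```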
SpecForge-ext/cache/compiled_kernels/37/c37gymepdyiyzp5hh2xt3a5vqmje2frbmyiqgipqpazjx6xcuyyb.py ADDED
@@ -0,0 +1,66 @@
+
+import triton
+import triton.language as tl
+
+from torch._inductor.runtime import triton_helpers, triton_heuristics
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+triton_helpers.set_driver_to_gpu()
+
+@triton_heuristics.pointwise(
+    size_hints={'x': 67108864},
+    filename=__file__,
+    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
+    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+    min_elem_per_thread=0
+)
+@triton.jit
+def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr):
+    xoffset = tl.program_id(0) * XBLOCK
+    xindex = xoffset + tl.arange(0, XBLOCK)[:]
+    xmask = xindex < xnumel
+    x0 = (xindex % ks0)
+    x3 = xindex
+    x1 = ((xindex // ks0) % ks1)
+    tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32)
+    tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last')
+    tmp0 = x0
+    tmp1 = ks0 // 2
+    tmp2 = tmp0 >= tmp1
+    tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
+    tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0)
+    tmp5 = tl.broadcast_to(ks2, [XBLOCK])
+    tmp6 = tmp4 + tmp5
+    tmp7 = tmp4 < 0
+    tmp8 = tl.where(tmp7, tmp6, tmp4)
+    tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2")
+    tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
+    tmp11 = tmp3 * tmp10
+    tmp12 = -tmp11
+    tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
+    tmp14 = tl.where(tmp2, tmp12, tmp13)
+    tmp15 = 0.0
+    tmp16 = tl.where(tmp2, tmp14, tmp15)
+    tmp17 = tmp0 < tmp1
+    tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
+    tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0)
+    tmp20 = tl.broadcast_to(ks2, [XBLOCK])
+    tmp21 = tmp19 + tmp20
+    tmp22 = tmp19 < 0
+    tmp23 = tl.where(tmp22, tmp21, tmp19)
+    tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2")
+    tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
+    tmp26 = tmp18 * tmp25
+    tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype)
+    tmp28 = tl.where(tmp17, tmp26, tmp27)
+    tmp29 = tl.where(tmp17, tmp28, tmp15)
+    tmp30 = tmp16 + tmp29
+    tmp33 = ks3
+    tmp34 = tmp32 + tmp33
+    tmp35 = tmp32 < 0
+    tmp36 = tl.where(tmp35, tmp34, tmp32)
+    tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3")
+    tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32)
+    tmp39 = tmp31 * tmp38
+    tmp40 = tmp30 + tmp39
+    tl.store(out_ptr0 + (x3), tmp40, xmask)
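
The fused ops in this kernel's name (add, index, mul, neg, slice_backward), together with the `ks0 // 2` half-split, the negated products, and the position-indexed gathers from two lookup tables, match the classic rotate-half rotary-embedding pattern. Reading the kernel as the backward of that op is an assumption; for reference, a sketch of the forward form whose autograd would produce such a fusion:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim in half (the kernel's ks0 // 2 boundary),
    # negate one half, and swap.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x, cos, sin, position_ids):
    # Gather per-position rows, as the kernel does via in_ptr1 indices.
    return x * cos[position_ids] + rotate_half(x) * sin[position_ids]
```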
SpecForge-ext/cache/compiled_kernels/4b/c4b4wkdm2d2z4hysjzfo6cyikw75man4bednwbsjwot4lkx7xfzs.py ADDED
@@ -0,0 +1,47 @@
+
+import triton
+import triton.language as tl
+
+from torch._inductor.runtime import triton_helpers, triton_heuristics
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+triton_helpers.set_driver_to_gpu()
+
+@triton_heuristics.reduction(
+    size_hints={'x': 4096, 'r0_': 32768},
+    reduction_hint=ReductionHint.INNER,
+    filename=__file__,
+    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
+    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+)
+@triton.jit
+def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
+    r0_numel = 32000
+    rnumel = r0_numel
+    RBLOCK: tl.constexpr = R0_BLOCK
+    xoffset = tl.program_id(0) * XBLOCK
+    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+    xmask = xindex < xnumel
+    r0_base = tl.arange(0, R0_BLOCK)[None, :]
+    rbase = r0_base
+    x0 = (xindex % ks0)
+    x1 = xindex // ks0
+    _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
+    _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
+    x3 = xindex
+    for r0_offset in range(0, r0_numel, R0_BLOCK):
+        r0_index = r0_offset + r0_base
+        r0_mask = r0_index < r0_numel
+        roffset = r0_offset
+        rindex = r0_index
+        r0_2 = r0_index
+        tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
+        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
+        _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
+            _tmp2, _tmp2_index, tmp1, rindex
+        )
+        _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2)
+        _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index)
+    tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
+    tmp2 = tmp2_idx[:, None]
+    tl.store(out_ptr0 + (x3), tmp2, xmask)
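
`triton_red_fused_argmax_1` is a tiled argmax over a hard-coded width of 32000 (a 32k vocabulary): each loop pass carries a running (max value, index) pair via `triton_helpers.maximum_with_index`, and the final `max_with_index` collapses it to one index per row. A hedged eager-mode equivalent:

```python
import torch

# Sketch: greedy token selection over fp32 logits with vocab size 32000,
# which is what the running (value, index) reduction above computes.
def greedy_pick(logits: torch.Tensor) -> torch.Tensor:
    assert logits.shape[-1] == 32000
    return logits.argmax(dim=-1)
```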
SpecForge-ext/cache/compiled_kernels/4g/c4gcdzc7dkmej2ceuy3ivyfjm5wjukkm4mbbdcmc7uaq76svnppo.py ADDED
@@ -0,0 +1,159 @@
+# AOT ID: ['1_inference']
+from ctypes import c_void_p, c_long, c_int
+import torch
+import math
+import random
+import os
+import tempfile
+from math import inf, nan
+from cmath import nanj
+from torch._inductor.hooks import run_intermediate_hooks
+from torch._inductor.utils import maybe_profile
+from torch._inductor.codegen.memory_planning import _align as align
+from torch import device, empty_strided
+from torch._inductor.async_compile import AsyncCompile
+from torch._inductor.select_algorithm import extern_kernels
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
+from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+
+aten = torch.ops.aten
+inductor_ops = torch.ops.inductor
+_quantized = torch.ops._quantized
+assert_size_stride = torch._C._dynamo.guards.assert_size_stride
+assert_alignment = torch._C._dynamo.guards.assert_alignment
+empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
+empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
+empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
+empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
+empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
+reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
+alloc_from_pool = torch.ops.inductor._alloc_from_pool
+async_compile = AsyncCompile()
+empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
+
+
+# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/e4/ce4fv76qvag73sibbo3mhwtavvyq3wneu5xe4faj6ybtsqisdlvr.py
+# Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax]
+# Source node to ATen node mapping:
+#   target_head => convert_element_type
+#   target_p => div
+# Graph fragment:
+#   %arg0_1 : Tensor "bf16[8, 2048, 32000][65536000, 32000, 1]cuda:4" = PlaceHolder[target=arg0_1]
+#   %getitem : Tensor "f32[8, 2048, 1][2048, 1, 16384]cuda:4" = PlaceHolder[target=getitem]
+#   %getitem_1 : Tensor "f32[8, 2048, 1][2048, 1, 16384]cuda:4" = PlaceHolder[target=getitem_1]
+#   %convert_element_type : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:4"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg0_1, torch.float32), kwargs = {})
+#   %prepare_softmax_online_default : [num_users=2] = call_function[target=torch.ops.prims.prepare_softmax_online.default](args = (%convert_element_type, 2), kwargs = {})
+#   %sub_tensor : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%convert_element_type, %getitem), kwargs = {})
+#   %exp_default : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_tensor,), kwargs = {})
+#   %div : Tensor "f32[8, 2048, 32000][65536000, 32000, 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_default, %getitem_1), kwargs = {})
+#   return %getitem,%getitem_1,%div
+triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0 = async_compile.triton('triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', '''
+import triton
+import triton.language as tl
+
+from torch._inductor.runtime import triton_helpers, triton_heuristics
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+triton_helpers.set_driver_to_gpu()
+
+@triton_heuristics.reduction(
+    size_hints={'x': 16384, 'r0_': 32768},
+    reduction_hint=ReductionHint.INNER,
+    filename=__file__,
+    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
+    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 5242880000}}
+)
+@triton.jit
+def triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0(in_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
+    xnumel = 16384
+    r0_numel = 32000
+    rnumel = r0_numel
+    RBLOCK: tl.constexpr = R0_BLOCK
+    xoffset = tl.program_id(0) * XBLOCK
+    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
+    r0_base = tl.arange(0, R0_BLOCK)[None, :]
+    rbase = r0_base
+    x0 = xindex
+    _tmp3_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32)
+    _tmp3_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
+    for r0_offset in range(0, r0_numel, R0_BLOCK):
+        r0_index = r0_offset + r0_base
+        r0_mask = r0_index < r0_numel
+        roffset = r0_offset
+        rindex = r0_index
+        r0_1 = r0_index
+        tmp0 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
+        tmp1 = tmp0.to(tl.float32)
+        tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
+
+        _tmp3_max_next, _tmp3_sum_next = triton_helpers.online_softmax_combine(
+            _tmp3_max, _tmp3_sum, tmp2, False
+        )
+
+        _tmp3_max = tl.where(r0_mask, _tmp3_max_next, _tmp3_max)
+        _tmp3_sum = tl.where(r0_mask, _tmp3_sum_next, _tmp3_sum)
+
+    tmp3, tmp4 = triton_helpers.online_softmax_reduce(
+        _tmp3_max, _tmp3_sum, 1, False)
+    tmp3 = tmp3[:, None]
+    tmp4 = tmp4[:, None]
+    for r0_offset in range(0, r0_numel, R0_BLOCK):
+        r0_index = r0_offset + r0_base
+        r0_mask = r0_index < r0_numel
+        roffset = r0_offset
+        rindex = r0_index
+        r0_1 = r0_index
+        tmp5 = tl.load(in_ptr0 + (r0_1 + 32000*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
+        tmp6 = tmp5.to(tl.float32)
+        tmp7 = tmp6 - tmp3
+        tmp8 = libdevice.exp(tmp7)
+        tmp9 = (tmp8 / tmp4)
+        tl.store(out_ptr2 + (r0_1 + 32000*x0), tmp9, r0_mask)
+''', device_str='cuda')
+
+
+async_compile.wait(globals())
+del async_compile
+
+class Runner:
+    def __init__(self, partitions):
+        self.partitions = partitions
+
+    def recursively_apply_fns(self, fns):
+        new_callables = []
+        for fn, c in zip(fns, self.partitions):
+            new_callables.append(fn(c))
+        self.partitions = new_callables
+
+    def call(self, args):
+        arg0_1, = args
+        args.clear()
+        assert_size_stride(arg0_1, (8, 2048, 32000), (65536000, 32000, 1))
+        with torch.cuda._DeviceGuard(4):
+            torch.cuda.set_device(4)
+            buf2 = empty_strided_cuda((8, 2048, 32000), (65536000, 32000, 1), torch.float32)
+            # Topologically Sorted Source Nodes: [target_head, target_p], Original ATen: [aten._to_copy, prims.prepare_softmax_online, aten.sub, aten.exp, aten._softmax]
+            stream4 = get_raw_stream(4)
+            triton_red_fused__softmax__to_copy_exp_prepare_softmax_online_sub_0.run(arg0_1, buf2, 16384, 32000, stream=stream4)
+            del arg0_1
+        return (buf2, )
+
+runner = Runner(partitions=[])
+call = runner.call
+recursively_apply_fns = runner.recursively_apply_fns
+
+
+def benchmark_compiled_module(times=10, repeat=10):
+    from torch._dynamo.testing import rand_strided
+    from torch._inductor.utils import print_performance
+    arg0_1 = rand_strided((8, 2048, 32000), (65536000, 32000, 1), device='cuda:4', dtype=torch.bfloat16)
+    fn = lambda: call([arg0_1])
+    return print_performance(fn, times=times, repeat=repeat)
+
+
+if __name__ == "__main__":
+    from torch._inductor.wrapper_benchmark import compiled_module_main
+    compiled_module_main('None', benchmark_compiled_module)
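
Per its graph-fragment comments, this AOT module upcasts bf16 logits of shape (8, 2048, 32000) to fp32 and computes a softmax over the vocabulary dimension, fusing the max and sum passes with Inductor's online-softmax helpers. A hedged eager-mode sketch of what `call` returns:

```python
import torch

# target_p, per the source-node mapping above: fp32 softmax of bf16 logits.
def target_p(target_head: torch.Tensor) -> torch.Tensor:
    return torch.softmax(target_head.to(torch.float32), dim=-1)

# Equivalent invocation through the compiled wrapper (shape and device
# are fixed by the module's guards):
#   logits = torch.randn(8, 2048, 32000, device="cuda:4", dtype=torch.bfloat16)
#   (probs,) = call([logits])
```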
SpecForge-ext/cache/compiled_kernels/4g/c4gr37y26wd4va4drshauwjr3p5l32j5cssih4o5yz3h2g6jkxrz.py ADDED
@@ -0,0 +1,89 @@
+
+import triton
+import triton.language as tl
+
+from torch._inductor.runtime import triton_helpers, triton_heuristics
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+triton_helpers.set_driver_to_gpu()
+
+@triton_heuristics.reduction(
+    size_hints={'x': 1024, 'r0_': 16384},
+    reduction_hint=ReductionHint.INNER,
+    filename=__file__,
+    triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
+    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+)
+@triton.jit
+def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
+    r0_numel = 16384
+    rnumel = r0_numel
+    RBLOCK: tl.constexpr = R0_BLOCK
+    xoffset = tl.program_id(0) * XBLOCK
+    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+    xmask = xindex < xnumel
+    r0_base = tl.arange(0, R0_BLOCK)[None, :]
+    rbase = r0_base
+    x0 = (xindex % ks0)
+    x1 = ((xindex // ks0) % 16)
+    x2 = xindex // ks2
+    _tmp36 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
+    x5 = xindex
+    for r0_offset in range(0, r0_numel, R0_BLOCK):
+        r0_index = r0_offset + r0_base
+        r0_mask = r0_index < r0_numel
+        roffset = r0_offset
+        rindex = r0_index
+        r0_3 = (r0_index % 128)
+        r0_4 = r0_index // 128
+        tmp0 = r0_3 + 128*x0
+        tmp1 = ks1
+        tmp2 = tmp0 < tmp1
+        tmp3 = r0_4 + 128*x1
+        tmp4 = r0_3 + 128*x0
+        tmp5 = tmp3 >= tmp4
+        tmp6 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp2 & xmask, eviction_policy='evict_last', other=0.0)
+        tmp7 = tmp4 < tmp6
+        tmp8 = tmp3 < tmp6
+        tmp9 = tmp7 & tmp8
+        tmp10 = tmp5 & tmp9
+        tmp11 = tl.full([1, 1], False, tl.int1)
+        tmp12 = tmp11 | tmp10
+        tmp13 = tl.full([1, 1], 2048, tl.int64)
+        tmp14 = tmp4 >= tmp13
+        tmp15 = ((r0_3 + 128*x0) % 2048)
+        tmp16 = tmp15 < tmp6
+        tmp17 = tmp14 & tmp16
+        tmp18 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0
+        tmp19 = (tmp18 % tmp13)
+        tmp20 = tl.full([1, 1], 0, tl.int32)
+        tmp21 = tmp19 != tmp20
+        tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0
+        tmp23 = (libdevice.signbit(tmp13) != 0) if (tmp13).dtype is tl.float32 else tmp13 < 0
+        tmp24 = tmp22 != tmp23
+        tmp25 = tmp21 & tmp24
+        tmp26 = tmp19 + tmp13
+        tmp27 = tl.where(tmp25, tmp26, tmp19)
+        tmp28 = tl.full([1, 1], 0, tl.int64)
+        tmp29 = tmp27 == tmp28
+        tmp30 = tmp17 & tmp29
+        tmp31 = tmp12 | tmp30
+        tmp32 = tl.full(tmp31.shape, False, tmp31.dtype)
+        tmp33 = tl.where(tmp2, tmp31, tmp32)
+        tmp34 = tmp33.to(tl.int64)
+        tmp35 = tl.broadcast_to(tmp34, [XBLOCK, R0_BLOCK])
+        tmp37 = _tmp36 + tmp35
+        _tmp36 = tl.where(r0_mask & xmask, tmp37, _tmp36)
+    tmp36 = tl.sum(_tmp36, 1)[:, None]
+    tmp38 = tl.full([1, 1], 0, tl.int64)
+    tmp39 = tmp36 > tmp38
+    tmp40 = tl.full([1, 1], 16384, tl.int64)
+    tmp41 = tmp36 < tmp40
+    tmp42 = tmp39 & tmp41
+    tmp43 = tmp42.to(tl.int8)
+    tmp44 = tmp43.to(tl.int32)
+    tmp45 = tmp36 == tmp40
+    tmp46 = tmp45.to(tl.int8)
+    tmp47 = tmp46.to(tl.int32)
+    tl.store(out_ptr1 + (x5), tmp44, xmask)
+    tl.store(out_ptr2 + (x5), tmp47, xmask)
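
One detail worth flagging in the kernel above: the `tmp19`..`tmp27` sequence converts Triton's truncated (C-style) integer remainder into the floored modulo that Python/PyTorch semantics require: if the remainder is nonzero and its sign differs from the divisor's, add the divisor once. A hedged Python sketch of that fix-up:

```python
import math

def floored_mod(a: int, b: int) -> int:
    r = int(math.fmod(a, b))  # truncated remainder, like Triton's %
    if r != 0 and (r < 0) != (b < 0):
        r += b                # shift into the divisor's sign
    return r

assert floored_mod(-3, 2048) == -3 % 2048 == 2045
```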
SpecForge-ext/cache/compiled_kernels/4r/c4rogici325xsxgkeljczx3sx57vcsimyzupeu4nvgmivqoqosiz.py ADDED
@@ -0,0 +1,1065 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['9_backward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/jc/cjcezd4fm2g2fppy44lhtzc36sz7bi63sscwdmenwlvu3y4xt7np.py
38
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
39
+ # Source node to ATen node mapping:
40
+ # Graph fragment:
41
+ # %getitem : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:2" = PlaceHolder[target=getitem]
42
+ # %tangents_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 262144, 128, 1]cuda:2" = PlaceHolder[target=tangents_1]
43
+ # %buf0 : Tensor "bf16[8, 32, 2048][65536, 2048, 1]cuda:2" = PlaceHolder[target=buf0]
44
+ # %full_default : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:2, pin_memory: False})
45
+ # %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_3, %primals_5, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {})
46
+ # return %buf0,%buf1
47
+ triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', '''
48
+ import triton
49
+ import triton.language as tl
50
+
51
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
52
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
53
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
54
+ triton_helpers.set_driver_to_gpu()
55
+
56
+ @triton_heuristics.reduction(
57
+ size_hints={'x': 524288, 'r0_': 128},
58
+ reduction_hint=ReductionHint.DEFAULT,
59
+ filename=__file__,
60
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
61
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4194304, 'r0_': 268435456}}
62
+ )
63
+ @triton.jit
64
+ def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
65
+ xnumel = 524288
66
+ r0_numel = 128
67
+ rnumel = r0_numel
68
+ RBLOCK: tl.constexpr = R0_BLOCK
69
+ xoffset = tl.program_id(0) * XBLOCK
70
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
71
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
72
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
73
+ rbase = r0_base
74
+ x0 = (xindex % 2048)
75
+ x1 = ((xindex // 2048) % 32)
76
+ x2 = xindex // 65536
77
+ x4 = xindex
78
+ _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
79
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
80
+ r0_index = r0_offset + r0_base
81
+ r0_mask = r0_index < r0_numel
82
+ roffset = r0_offset
83
+ rindex = r0_index
84
+ r0_3 = r0_index
85
+ tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 8388608*x2), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
86
+ tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x4), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
87
+ tmp2 = tmp0 * tmp1
88
+ tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
89
+ tmp5 = _tmp4 + tmp3
90
+ _tmp4 = tl.where(r0_mask, tmp5, _tmp4)
91
+ tmp4 = tl.sum(_tmp4, 1)[:, None]
92
+ tmp6 = tmp4.to(tl.float32)
93
+ tmp7 = 0.0
94
+ tmp8 = tmp6 - tmp7
95
+ tl.store(out_ptr1 + (x4), tmp8, None)
96
+ ''', device_str='cuda')
97
+
98
+
99
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bh/cbhqle56n7we4b4miasvgh4jqrjbkehmv3legvjui32dka2bilvr.py
100
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
101
+ # Source node to ATen node mapping:
102
+ # Graph fragment:
103
+ # %primals_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:2" = PlaceHolder[target=primals_1]
104
+ # %primals_3 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:2" = PlaceHolder[target=primals_3]
105
+ # %primals_5 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:2" = PlaceHolder[target=primals_5]
106
+ # %getitem_1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:2" = PlaceHolder[target=getitem_1]
107
+ # %buf1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:2" = PlaceHolder[target=buf1]
108
+ # %tangents_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 262144, 128, 1]cuda:2" = PlaceHolder[target=tangents_1]
109
+ # %getitem_3 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:2" = PlaceHolder[target=getitem_3]
110
+ # %getitem_5 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:2" = PlaceHolder[target=getitem_5]
111
+ # %primals_9 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:2" = PlaceHolder[target=primals_9]
112
+ # %primals_7 : Tensor "i32[8, 1, 16, s72][16*s72, 16*s72, s72, 1]cuda:2" = PlaceHolder[target=primals_7]
113
+ # %primals_15 : Tensor "i32[8, 1, s56][s56, s56, 1]cuda:2" = PlaceHolder[target=primals_15]
114
+ # %primals_17 : Tensor "i32[8, 1, s84, 16][16*s84, 16*s84, 16, 1]cuda:2" = PlaceHolder[target=primals_17]
115
+ # %primals_11 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:2" = PlaceHolder[target=primals_11]
116
+ # %primals_13 : Tensor "i32[8, 1, 16, s4][16*s4, 16*s4, s4, 1]cuda:2" = PlaceHolder[target=primals_13]
117
+ # %primals_19 : Tensor "i32[8, 1, s99][s99, s99, 1]cuda:2" = PlaceHolder[target=primals_19]
118
+ # %primals_21 : Tensor "i32[8, 1, s6, 16][16*s6, 16*s6, 16, 1]cuda:2" = PlaceHolder[target=primals_21]
119
+ # %primals_10 : Tensor "i64[8][1]cuda:2" = PlaceHolder[target=primals_10]
120
+ # %full_default : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, 2048], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:2, pin_memory: False})
121
+ # %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_1, %primals_3, %primals_5, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {})
122
+ # return %getitem_4
123
+ triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', '''
124
+ import triton
125
+ import triton.language as tl
126
+
127
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
128
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
129
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
130
+
131
+ @triton_heuristics.template(
132
+
133
+ num_stages=3,
134
+ num_warps=8,
135
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
136
+ inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
137
+
138
+ )
139
+ @triton.jit
140
+ def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3):
141
+ PRESCALE_QK : tl.constexpr = False
142
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
143
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
144
+ WRITE_DQ : tl.constexpr = True
145
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
146
+ OUTPUT_MAX : tl.constexpr = False
147
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
148
+ IS_DIVISIBLE : tl.constexpr = False
149
+ SM_SCALE : tl.constexpr = 0.08838834764831843
150
+ GQA_SHARED_HEADS : tl.constexpr = 4
151
+ HAS_FULL_BLOCKS : tl.constexpr = True
152
+ QK_HEAD_DIM : tl.constexpr = 128
153
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
154
+ V_HEAD_DIM : tl.constexpr = 128
155
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
156
+ SAFE_HEAD_DIM : tl.constexpr = True
157
+ BLOCK_M1 : tl.constexpr = 64
158
+ BLOCK_N1 : tl.constexpr = 128
159
+ BLOCK_M2 : tl.constexpr = 128
160
+ BLOCK_N2 : tl.constexpr = 64
161
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
162
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
163
+ INDEX_DTYPE : tl.constexpr = tl.int32
164
+ Q = arg_Q
165
+ K = arg_K
166
+ V = arg_V
167
+ LSE = arg_LSE
168
+ DELTA = arg_DELTA
169
+ DO = arg_DO
170
+ DQ = arg_DQ
171
+ DV = arg_DV
172
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
173
+ KV_IDX = arg_KV_IDX
174
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
175
+ Q_IDX = arg_Q_IDX
176
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
177
+ FULL_KV_IDX = arg_FULL_KV_IDX
178
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
179
+ FULL_Q_IDX = arg_FULL_Q_IDX
180
+
181
+ # Sub notation for this kernel:
182
+ #
183
+ # Q: Query, K: Key, V: Value
184
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
185
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
186
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
187
+ # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
188
+ # inductor codegen
189
+ # M: Number of queries, N: Number of keys/values
190
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
191
+ # V_HEAD_DIM: The dimension of the value embeddings
192
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
193
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
194
+ # (Modifiable) Performance tuning options
195
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
196
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
197
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
198
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
199
+ #
200
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
201
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
202
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
203
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
204
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
205
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
206
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
207
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
208
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
209
+
210
+ # The below are kernel options that can be applied for certain score_mods,
211
+ # or involve a numerics vs. perf tradeoff
212
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
213
+ # about 20% more numerical error, but slightly faster.
214
+
215
+ # Define strides of inputs
216
+ stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1
217
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks0, 128*ks0, 128, 1
218
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks0, 128*ks0, 128, 1
219
+ stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1
220
+
221
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1
222
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks0, 128*ks0, 128, 1
223
+
224
+ ZQ = 8
225
+ HQ = 32
226
+ HKV = 8
227
+ Q_LEN = 2048
228
+ ZKV = 8
229
+ KV_LEN = ks0
230
+
231
+ MATMUL_PRECISION = Q.dtype.element_ty
232
+
233
+ pid = tl.program_id(0).to(INDEX_DTYPE)
234
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
235
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
236
+
237
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
238
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
239
+ off_zkv = off_zq % ZKV # kv batch idx
240
+
241
+ SPARSE_Z = 8
242
+ SPARSE_HQ = 1
243
+
244
+ sparse_idx_z = off_zq % SPARSE_Z
245
+
246
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
247
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
248
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
249
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
250
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
251
+
252
+ # offset K, V, DV pointers for batch/kv-head
253
+ K += k_adj
254
+ V += v_adj
255
+ DV += dv_adj
256
+
257
+ RCP_LN2 = 1.44269504
258
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
259
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
260
+
261
+ if pid >= NUM_KV_BLOCKS:
262
+ off_pid = pid - NUM_KV_BLOCKS
263
+ # THIS BLOCK DOES DQ
264
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
265
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
266
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
267
+ start_m2_block = off_pid % NUM_Q_BLOCKS
268
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
269
+ stride_kv_num_blks_h = 16
270
+ stride_kv_idx_h = 16*ks1
271
+ stride_kv_idx_m = ks1
272
+
273
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
274
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
275
+
276
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
277
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
278
+
279
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
280
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
281
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
282
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
283
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
284
+
285
+ Q2 = Q + q_adj2
286
+ DO2 = DO + do_adj2
287
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
288
+ # if Q is broadcasted)
289
+ DQ2 = DQ + dq_adj2
290
+ LSE2 = LSE + off_chz2
291
+ DELTA2 = DELTA + off_chz2
292
+
293
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
294
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
295
+
296
+ start_m2 = start_m2_block * BLOCK_M2
297
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
298
+
299
+ # load Q and do: they stay in SRAM throughout the inner loop.
300
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
301
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
302
+
303
+ if PRESCALE_QK:
304
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
305
+
306
+ if IS_DIVISIBLE:
307
+ Di = tl.load(DELTA2 + offs_m2)
308
+ lse = tl.load(LSE2 + offs_m2)
309
+ else:
310
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
311
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
312
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
313
+ lse = lse[:, None]
314
+
315
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
316
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
317
+ kv_indices = KV_IDX + sparse_kv_idx_offset
318
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
319
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
320
+
321
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
322
+ dq = bwd_dq_inner(
323
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
324
+ K, V,
325
+ dq, q, do, Di, lse,
326
+ off_zq, off_hq2, offs_m2, offs_n2,
327
+ stride_kn, stride_kd, stride_vn, stride_vd,
328
+ kv_indices, sparse_kv_num_blocks,
329
+ MATMUL_PRECISION,
330
+ IS_FULL_BLOCKS=False,
331
+ )
332
+
333
+ if HAS_FULL_BLOCKS:
334
+ # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
335
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
336
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
337
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
338
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
339
+
340
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
341
+ dq = bwd_dq_inner(
342
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
343
+ K, V,
344
+ dq, q, do, Di, lse,
345
+ off_zq, off_hq2, offs_m2, offs_n2,
346
+ stride_kn, stride_kd, stride_vn, stride_vd,
347
+ kv_indices, sparse_kv_num_blocks,
348
+ MATMUL_PRECISION,
349
+ IS_FULL_BLOCKS=True,
350
+ )
351
+
352
+ # Write back dQ.
353
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
354
+ dq *= SM_SCALE
355
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
356
+ tl.store(dq_ptrs, dq)
357
+ else:
358
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
359
+ else:
360
+ # THIS BLOCK DOES DK & DV
361
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
362
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
363
+
364
+ pid_mask = pid // SPARSE_KV_MULTIPLE
365
+
366
+ stride_q_num_blks_h = ks2
367
+ stride_q_idx_h = 16*ks3
368
+ stride_q_idx_n = 16
369
+
370
+
371
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
372
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
373
+
374
+ start_n1 = pid * BLOCK_N1
375
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
376
+
377
+ # load K and V: they stay in SRAM throughout the inner loop.
378
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
379
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
380
+
381
+ if PRESCALE_QK:
382
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
383
+
384
+ for off_g in range(0, GQA_SHARED_HEADS):
385
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
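+ # GQA note: the GQA_SHARED_HEADS query heads off_hkv*GQA_SHARED_HEADS ..
+ # off_hkv*GQA_SHARED_HEADS + GQA_SHARED_HEADS - 1 all attend through kv head
+ # off_hkv, so this loop accumulates each of their dK/dV contributions.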
386
+
387
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset per query head.
388
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
389
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
390
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
391
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
392
+
393
+ Q1 = Q + q_adj1
394
+ DO1 = DO + do_adj1
395
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
396
+ # if Q is broadcasted)
397
+ LSE1 = LSE + off_chz1
398
+ DELTA1 = DELTA + off_chz1
399
+
400
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
401
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
402
+
403
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
404
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
405
+
406
+ # ~~~~~~~~~~~~~~~ partially unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
407
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
408
+ q_indices = Q_IDX + sparse_q_idx_offset
409
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
410
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
411
+
412
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
413
+ dk, dv = bwd_dkdv_inner(
414
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
415
+ Q1, DO1, DELTA1, LSE1,
416
+ dk, dv, k, v,
417
+ off_zq, off_hq1, offs_n1, offs_m1,
418
+ stride_qm, stride_qd, stride_dom, stride_dod,
419
+ q_indices, sparse_q_num_blocks,
420
+ MATMUL_PRECISION,
421
+ IS_FULL_BLOCKS=False,
422
+ )
423
+
424
+
425
+ if HAS_FULL_BLOCKS:
426
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
427
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
428
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
429
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
430
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
431
+
432
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
433
+ dk, dv = bwd_dkdv_inner(
434
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
435
+ Q1, DO1, DELTA1, LSE1,
436
+ dk, dv, k, v,
437
+ off_zq, off_hq1, offs_n1, offs_m1,
438
+ stride_qm, stride_qd, stride_dom, stride_dod,
439
+ q_indices, sparse_q_num_blocks,
440
+ MATMUL_PRECISION,
441
+ IS_FULL_BLOCKS=True,
442
+ )
443
+
444
+ # Write back dV and dK.
445
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
446
+
447
+ index_n = offs_n1[:, None]
448
+ index_k = offs_k[None, :]
449
+ index_v = offs_v[None, :]
450
+
451
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
452
+ tl.store(dv_ptrs, dv)
453
+ else:
454
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
455
+
456
+ dk *= SM_SCALE
457
+
458
+ if SAFE_HEAD_DIM:
459
+ mask = index_n < KV_LEN
460
+ else:
461
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
462
+
463
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, QK_HEAD_DIM]
464
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, QK_HEAD_DIM]
465
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
466
+ xindex = index_k + 128*index_n + 128*off_hkv*ks0 + 1024*off_zq*ks0
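+ # (editorial gloss) xindex linearizes [off_zq, off_hkv, index_n, index_k]
+ # into a contiguous [Bq, Hkv=8, KV_LEN=ks0, 128] buffer with strides
+ # (1024*ks0, 128*ks0, 128, 1), i.e. the broadcasted-dk layout asserted below.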
467
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
468
+
469
+ @triton.jit
470
+ def bwd_dq_inner(
471
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
472
+ K, V, # pointers
473
+ dq, q, do, Di, lse,
474
+ off_z, off_hq, offs_m2, offs_n2,
475
+ stride_kn, stride_kd, stride_vn, stride_vd,
476
+ kv_indices, sparse_kv_num_blocks,
477
+ MATMUL_PRECISION,
478
+ IS_FULL_BLOCKS,
479
+ ):
480
+ PRESCALE_QK : tl.constexpr = False
481
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
482
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
483
+ WRITE_DQ : tl.constexpr = True
484
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
485
+ OUTPUT_MAX : tl.constexpr = False
486
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
487
+ IS_DIVISIBLE : tl.constexpr = False
488
+ SM_SCALE : tl.constexpr = 0.08838834764831843
489
+ GQA_SHARED_HEADS : tl.constexpr = 4
490
+ HAS_FULL_BLOCKS : tl.constexpr = True
491
+ QK_HEAD_DIM : tl.constexpr = 128
492
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
493
+ V_HEAD_DIM : tl.constexpr = 128
494
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
495
+ SAFE_HEAD_DIM : tl.constexpr = True
496
+ BLOCK_M1 : tl.constexpr = 64
497
+ BLOCK_N1 : tl.constexpr = 128
498
+ BLOCK_M2 : tl.constexpr = 128
499
+ BLOCK_N2 : tl.constexpr = 64
500
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
501
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
502
+ INDEX_DTYPE : tl.constexpr = tl.int32
503
+
504
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
505
+ RCP_LN2: tl.constexpr = 1.44269504
506
+ Q_LEN = 2048
507
+ KV_LEN = ks0
508
+
509
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
510
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
511
+
512
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
513
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
514
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
515
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
516
+
517
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
518
+
519
+ for start_n in range(0, hi):
520
+ dq = bwd_dq_block_mn(
521
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
522
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
523
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
524
+ stride_kn, stride_kd, stride_vn, stride_vd,
525
+ kv_indices, sparse_kv_num_blocks,
526
+ MATMUL_PRECISION, RCP_LN2,
527
+ IS_FULL_BLOCKS,
528
+ )
529
+
530
+ # Increment pointers.
531
+ offset = get_offset_for_next_block(
532
+ start_n, kv_indices, sparse_kv_num_blocks,
533
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
534
+ )
535
+
536
+ kT_ptrs += offset * stride_kn
537
+ vT_ptrs += offset * stride_vn
538
+
539
+ offs_n2 += offset
540
+
541
+ return dq
542
+
543
+
544
+ @triton.jit
545
+ def bwd_dq_block_mn(
546
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
547
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
548
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
549
+ stride_kn, stride_kd, stride_vn, stride_vd,
550
+ kv_indices, sparse_kv_num_blocks,
551
+ MATMUL_PRECISION, RCP_LN2,
552
+ IS_FULL_BLOCKS,
553
+ ):
554
+ PRESCALE_QK : tl.constexpr = False
555
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
556
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
557
+ WRITE_DQ : tl.constexpr = True
558
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
559
+ OUTPUT_MAX : tl.constexpr = False
560
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
561
+ IS_DIVISIBLE : tl.constexpr = False
562
+ SM_SCALE : tl.constexpr = 0.08838834764831843
563
+ GQA_SHARED_HEADS : tl.constexpr = 4
564
+ HAS_FULL_BLOCKS : tl.constexpr = True
565
+ QK_HEAD_DIM : tl.constexpr = 128
566
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
567
+ V_HEAD_DIM : tl.constexpr = 128
568
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
569
+ SAFE_HEAD_DIM : tl.constexpr = True
570
+ BLOCK_M1 : tl.constexpr = 64
571
+ BLOCK_N1 : tl.constexpr = 128
572
+ BLOCK_M2 : tl.constexpr = 128
573
+ BLOCK_N2 : tl.constexpr = 64
574
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
575
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
576
+ INDEX_DTYPE : tl.constexpr = tl.int32
577
+
578
+
579
+ # NB reversed order since K is transposed
580
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
581
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
582
+ if not PRESCALE_QK:
583
+ qk *= SM_SCALE
584
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
585
+ pre_mod_scores = qk
586
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
587
+ # The boundary check is done in the outer loop, but because we iterate across the N dim here,
588
+ # the M indices can read out of bounds for PIDs spanning the Q_LEN boundary
589
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
590
+
591
+ tmp0 = (qk)
592
+ post_mod_scores = tmp0
593
+
594
+
595
+
596
+
597
+ if not IS_DIVISIBLE:
598
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
599
+
600
+ if not IS_FULL_BLOCKS:
601
+ tmp1 = tl.full([1], False, tl.int1)
602
+ tmp2 = (m)
603
+ tmp3 = (n)
604
+ tmp4 = tmp2 >= tmp3
605
+ tmp5 = tmp3.to(tl.int64)
606
+ tmp6 = (off_z)
607
+ tmp7 = tl.load(in_ptr16 + tmp6)
608
+ tmp8 = tmp5 < tmp7
609
+ tmp9 = tmp2.to(tl.int64)
610
+ tmp10 = tmp9 < tmp7
611
+ tmp11 = tmp8 & tmp10
612
+ tmp12 = tmp4 & tmp11
613
+ tmp13 = tmp1 | tmp12
614
+ tmp14 = tl.full([1], 2048, tl.int32)
615
+ tmp15 = tmp3 >= tmp14
616
+ tmp16 = (tmp3 % tmp14)
617
+ tmp17 = tl.full([1], 0, tl.int32)
618
+ tmp18 = tmp16 != tmp17
619
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
620
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
621
+ tmp21 = tmp19 != tmp20
622
+ tmp22 = tmp18 & tmp21
623
+ tmp23 = tmp16 + tmp14
624
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
625
+ tmp25 = tmp24.to(tl.int64)
626
+ tmp26 = tmp25 < tmp7
627
+ tmp27 = tmp15 & tmp26
628
+ tmp28 = tmp3 - tmp2
629
+ tmp29 = (tmp28 % tmp14)
630
+ tmp30 = tmp29 != tmp17
631
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
632
+ tmp32 = tmp31 != tmp20
633
+ tmp33 = tmp30 & tmp32
634
+ tmp34 = tmp29 + tmp14
635
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
636
+ tmp36 = tmp35 == tmp17
637
+ tmp37 = tmp27 & tmp36
638
+ tmp38 = tmp13 | tmp37
639
+ mask_mod_output = tmp38
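+ # (editorial gloss, reconstructed from the traced mask_mod chain above; treat
+ # as an approximation) with doc_len = tl.load(in_ptr16 + off_z) this reduces to:
+ #   (m >= n and n < doc_len and m < doc_len)                        # causal part
+ #   or (n >= 2048 and (n % 2048) < doc_len and (n - m) % 2048 == 0)
+ # where % is the Python-style (sign-corrected) modulo computed above.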
640
+
641
+
642
+ # apply mask for partial masked block
643
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
644
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
645
+ if not PRESCALE_QK:
646
+ post_mod_scores *= RCP_LN2
647
+ p = tl.math.exp2(post_mod_scores - lse)
648
+ # Compute dP and dS.
649
+ # NB reversed order since V is transposed
650
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
651
+
652
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
653
+ ds = p * (dp - Di[:, None])
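+ # Attention-backward identity: with P = softmax of the (scaled) scores,
+ # dS_ij = P_ij * (dP_ij - Delta_i), where Delta_i = sum(OUT * DO, axis=-1)
+ # is the precomputed DELTA described in the kernel's header comments.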
654
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
655
+ tmp39 = (ds)
656
+ grad_scores = tmp39
657
+
658
+
659
+ if not IS_DIVISIBLE:
660
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
661
+
662
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
663
+ if WRITE_DQ:
664
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
665
+
666
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
667
+ ds = grad_scores
668
+
669
+ if not IS_FULL_BLOCKS:
670
+ # (grads) apply mask for partially unmasked block
671
+ ds = tl.where(mask_mod_output, ds, 0.0)
672
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
673
+ ds = ds.to(MATMUL_PRECISION)
674
+ # Compute dQ.
675
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
676
+
677
+ return dq
678
+
679
+
680
+ @triton.jit
681
+ def bwd_dkdv_inner(
682
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
683
+ Q, DO, DELTA, LSE, # pointers
684
+ dk, dv, k, v,
685
+ off_z, off_hq, offs_n1, offs_m1,
686
+ stride_qm, stride_qd, stride_dom, stride_dod,
687
+ q_indices, sparse_q_num_blocks,
688
+ MATMUL_PRECISION,
689
+ IS_FULL_BLOCKS,
690
+ ):
691
+ PRESCALE_QK : tl.constexpr = False
692
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
693
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
694
+ WRITE_DQ : tl.constexpr = True
695
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
696
+ OUTPUT_MAX : tl.constexpr = False
697
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
698
+ IS_DIVISIBLE : tl.constexpr = False
699
+ SM_SCALE : tl.constexpr = 0.08838834764831843
700
+ GQA_SHARED_HEADS : tl.constexpr = 4
701
+ HAS_FULL_BLOCKS : tl.constexpr = True
702
+ QK_HEAD_DIM : tl.constexpr = 128
703
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
704
+ V_HEAD_DIM : tl.constexpr = 128
705
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
706
+ SAFE_HEAD_DIM : tl.constexpr = True
707
+ BLOCK_M1 : tl.constexpr = 64
708
+ BLOCK_N1 : tl.constexpr = 128
709
+ BLOCK_M2 : tl.constexpr = 128
710
+ BLOCK_N2 : tl.constexpr = 64
711
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
712
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
713
+ INDEX_DTYPE : tl.constexpr = tl.int32
714
+
715
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
716
+ RCP_LN2: tl.constexpr = 1.44269504
717
+ Q_LEN = 2048
718
+ KV_LEN = ks0
719
+
720
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
721
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
722
+
723
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
724
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
725
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
726
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
727
+
728
+ # The minimum is needed to handle the case where we run with a super large
729
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
730
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
731
+
732
+ for start_m in range(0, hi):
733
+ dk, dv = bwd_dkdv_block_mn(
734
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
735
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
736
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
737
+ stride_qm, stride_qd, stride_dom, stride_dod,
738
+ q_indices, sparse_q_num_blocks,
739
+ MATMUL_PRECISION, RCP_LN2,
740
+ IS_FULL_BLOCKS,
741
+ )
742
+ # Increment pointers.
743
+ offset = get_offset_for_next_block(
744
+ start_m, q_indices, sparse_q_num_blocks,
745
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
746
+ )
747
+
748
+ qT_ptrs += offset * stride_qm
749
+ do_ptrs += offset * stride_dom
750
+ offs_m1 += offset
751
+
752
+ return dk, dv
753
+
754
+
755
+ @triton.jit
756
+ def bwd_dkdv_block_mn(
757
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3,
758
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
759
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
760
+ stride_qm, stride_qd, stride_dom, stride_dod,
761
+ q_indices, sparse_q_num_blocks,
762
+ MATMUL_PRECISION, RCP_LN2,
763
+ IS_FULL_BLOCKS,
764
+ ):
765
+ PRESCALE_QK : tl.constexpr = False
766
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
767
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
768
+ WRITE_DQ : tl.constexpr = True
769
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
770
+ OUTPUT_MAX : tl.constexpr = False
771
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
772
+ IS_DIVISIBLE : tl.constexpr = False
773
+ SM_SCALE : tl.constexpr = 0.08838834764831843
774
+ GQA_SHARED_HEADS : tl.constexpr = 4
775
+ HAS_FULL_BLOCKS : tl.constexpr = True
776
+ QK_HEAD_DIM : tl.constexpr = 128
777
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
778
+ V_HEAD_DIM : tl.constexpr = 128
779
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
780
+ SAFE_HEAD_DIM : tl.constexpr = True
781
+ BLOCK_M1 : tl.constexpr = 64
782
+ BLOCK_N1 : tl.constexpr = 128
783
+ BLOCK_M2 : tl.constexpr = 128
784
+ BLOCK_N2 : tl.constexpr = 64
785
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
786
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
787
+ INDEX_DTYPE : tl.constexpr = tl.int32
788
+
789
+
790
+ # NB reversed order since Q is transposed
791
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
792
+ # Load LSE before computing qk to reduce pipeline stall.
793
+ if IS_DIVISIBLE:
794
+ lse = tl.load(LSE + offs_m1)
795
+ else:
796
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
797
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
798
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
799
+ if not PRESCALE_QK:
800
+ qkT *= SM_SCALE
801
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
802
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
803
+ # The boundary check is done in the outer loop, but because we iterate across the M dim here,
804
+ # the n indices can read out of bounds for PIDs spanning the KV_LEN boundary
805
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
806
+
807
+ pre_mod_scores = qkT
808
+ tmp40 = (qkT)
809
+ post_mod_scores = tmp40
810
+
811
+
812
+
813
+ if not IS_DIVISIBLE:
814
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
815
+
816
+ if not IS_FULL_BLOCKS:
817
+ tmp41 = tl.full([1], False, tl.int1)
818
+ tmp42 = (m)
819
+ tmp43 = (n)
820
+ tmp44 = tmp42 >= tmp43
821
+ tmp45 = tmp43.to(tl.int64)
822
+ tmp46 = (off_z)
823
+ tmp47 = tl.load(in_ptr16 + tmp46)
824
+ tmp48 = tmp45 < tmp47
825
+ tmp49 = tmp42.to(tl.int64)
826
+ tmp50 = tmp49 < tmp47
827
+ tmp51 = tmp48 & tmp50
828
+ tmp52 = tmp44 & tmp51
829
+ tmp53 = tmp41 | tmp52
830
+ tmp54 = tl.full([1], 2048, tl.int32)
831
+ tmp55 = tmp43 >= tmp54
832
+ tmp56 = (tmp43 % tmp54)
833
+ tmp57 = tl.full([1], 0, tl.int32)
834
+ tmp58 = tmp56 != tmp57
835
+ tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
836
+ tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
837
+ tmp61 = tmp59 != tmp60
838
+ tmp62 = tmp58 & tmp61
839
+ tmp63 = tmp56 + tmp54
840
+ tmp64 = tl.where(tmp62, tmp63, tmp56)
841
+ tmp65 = tmp64.to(tl.int64)
842
+ tmp66 = tmp65 < tmp47
843
+ tmp67 = tmp55 & tmp66
844
+ tmp68 = tmp43 - tmp42
845
+ tmp69 = (tmp68 % tmp54)
846
+ tmp70 = tmp69 != tmp57
847
+ tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
848
+ tmp72 = tmp71 != tmp60
849
+ tmp73 = tmp70 & tmp72
850
+ tmp74 = tmp69 + tmp54
851
+ tmp75 = tl.where(tmp73, tmp74, tmp69)
852
+ tmp76 = tmp75 == tmp57
853
+ tmp77 = tmp67 & tmp76
854
+ tmp78 = tmp53 | tmp77
855
+ mask_mod_output = tmp78
856
+
857
+ # (grads) apply mask for fully masked block
858
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
859
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
860
+ if not PRESCALE_QK:
861
+ post_mod_scores *= RCP_LN2
862
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
863
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
864
+ # Compute dV.
865
+ ppT = pT
866
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
867
+ if IS_DIVISIBLE:
868
+ Di = tl.load(DELTA + offs_m1)
869
+ else:
870
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
871
+ # Compute dP and dS.
872
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
873
+ dsT = pT * (dpT - Di[None, :])
874
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
875
+ tmp79 = (dsT)
876
+ grad_scores = tmp79
877
+
878
+
879
+
880
+ if not IS_DIVISIBLE:
881
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
882
+
883
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
884
+ if not WRITE_DQ:
885
+ idx_b = off_z
886
+ idx_h = off_hq
887
+ idx_m = m
888
+ idx_n = n
889
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
890
+
891
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
892
+ dsT = grad_scores
893
+ if not IS_FULL_BLOCKS:
894
+ # (grads) apply mask for partially unmasked block
895
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
896
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
897
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
898
+
899
+ return dk, dv
900
+
901
+ # Utility triton funcs
902
+ @triton.jit
903
+ def get_offset_for_next_block(
904
+ loop_iter, col_indices, total_blocks,
905
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
906
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
907
+ ):
908
+ if BLOCKS_ARE_CONTIGUOUS:
909
+ return BLOCK
910
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
911
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
912
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
913
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
914
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
915
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
916
+ return offset
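+ # Worked example (host-side intuition only): with SPARSE_BLOCK=128, BLOCK=64,
+ # SPARSE_BLOCK_MULTIPLE=2 and col_indices=[0, 3, ...]:
+ # loop_iter=0 stays inside sparse block 0, so offset = BLOCK = 64;
+ # loop_iter=1 crosses into sparse block 3, so offset = (3-0)*128 - 1*64 = 320.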
917
+
918
+ @triton.jit
919
+ def get_bounded_indices(indices, max_len=None):
920
+ return indices % max_len if max_len is not None else indices
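+ # Note: out-of-range indices are wrapped via modulo so downstream pointer
+ # arithmetic stays in bounds; validity masking happens at the load/store sites.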
921
+
922
+ @triton.jit
923
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
924
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
925
+ return tl.load(block_ptr)
926
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
927
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
928
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
929
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
930
+ else:
931
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
932
+
933
+ @triton.jit
934
+ def load_checked_2d(
935
+ ptr,
936
+ offs_m,
937
+ offs_n,
938
+ stride_m,
939
+ stride_n,
940
+ IS_DIVISIBLE_M: tl.constexpr,
941
+ IS_DIVISIBLE_N: tl.constexpr,
942
+ M_LEN: tl.constexpr,
943
+ N_LEN: tl.constexpr,
944
+ ):
945
+ # Calculate final pointer if strides are provided
946
+ if stride_m is not None and stride_n is not None:
947
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
948
+
949
+ # Handle all masking cases
950
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
951
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
952
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
953
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
954
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
955
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
956
+ else: # Both divisible
957
+ return tl.load(ptr)
958
+ ''', device_str='cuda')
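+
+ # (editorial sketch, not part of the generated module) A minimal NumPy
+ # reference for the masking semantics of load_checked_2d above: out-of-bounds
+ # rows/columns are zero-filled rather than faulting. Assumes a 2D array and
+ # unit strides; the helper name is hypothetical.
+ def _load_checked_2d_reference(arr, offs_m, offs_n):
+     import numpy as np
+     m = np.asarray(offs_m)[:, None]
+     n = np.asarray(offs_n)[None, :]
+     valid = (m < arr.shape[0]) & (n < arr.shape[1])
+     out = np.zeros((len(offs_m), len(offs_n)), dtype=arr.dtype)
+     mm, nn = np.broadcast_arrays(m, n)
+     out[valid] = arr[mm[valid], nn[valid]]  # gather only in-bounds elements
+     return out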
959
+
960
+
961
+ async_compile.wait(globals())
962
+ del async_compile
963
+
964
+ class Runner:
965
+ def __init__(self, partitions):
966
+ self.partitions = partitions
967
+
968
+ def recursively_apply_fns(self, fns):
969
+ new_callables = []
970
+ for fn, c in zip(fns, self.partitions):
971
+ new_callables.append(fn(c))
972
+ self.partitions = new_callables
973
+
974
+ def call(self, args):
975
+ primals_8, primals_6, primals_12, primals_14, primals_16, primals_18, primals_20, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, getitem, getitem_1, tangents_1 = args
976
+ args.clear()
977
+ s0 = primals_8
978
+ s72 = primals_6
979
+ s4 = primals_12
980
+ s56 = primals_14
981
+ s84 = primals_16
982
+ s99 = primals_18
983
+ s6 = primals_20
984
+ assert_size_stride(primals_1, (8, 32, 2048, 128), (8388608, 128, 4096, 1))
985
+ assert_size_stride(primals_3, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1))
986
+ assert_size_stride(primals_5, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1))
987
+ assert_size_stride(primals_7, (8, 1, 16, s72), (16*s72, 16*s72, s72, 1))
988
+ assert_size_stride(primals_9, (8, 1, 16), (16, 16, 1))
989
+ assert_size_stride(primals_10, (8, ), (1, ))
990
+ assert_size_stride(primals_11, (8, 1, 16), (16, 16, 1))
991
+ assert_size_stride(primals_13, (8, 1, 16, s4), (16*s4, 16*s4, s4, 1))
992
+ assert_size_stride(primals_15, (8, 1, s56), (s56, s56, 1))
993
+ assert_size_stride(primals_17, (8, 1, s84, 16), (16*s84, 16*s84, 16, 1))
994
+ assert_size_stride(primals_19, (8, 1, s99), (s99, s99, 1))
995
+ assert_size_stride(primals_21, (8, 1, s6, 16), (16*s6, 16*s6, 16, 1))
996
+ assert_size_stride(getitem, (8, 32, 2048, 128), (8388608, 128, 4096, 1))
997
+ assert_size_stride(getitem_1, (8, 32, 2048), (65536, 2048, 1))
998
+ assert_size_stride(tangents_1, (8, 32, 2048, 128), (8388608, 262144, 128, 1))
999
+ with torch.cuda._DeviceGuard(2):
1000
+ torch.cuda.set_device(2)
1001
+ buf1 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32)
1002
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
1003
+ stream2 = get_raw_stream(2)
1004
+ triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, 524288, 128, stream=stream2)
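+ # (editorial note, inferred from shapes: getitem is OUT, tangents_1 is DO,
+ # buf1 is (8, 32, 2048) reduced over the 128-dim) this reduction plausibly
+ # precomputes DELTA = sum(OUT * DO, axis=-1) for the backward template below;
+ # the kernel body lives elsewhere in this cache.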
1005
+ del getitem
1006
+ buf3 = empty_strided_cuda((8, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16)
1007
+ buf4 = empty_strided_cuda((8, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16)
1008
+ buf5 = empty_strided_cuda((8, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16)
1009
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
1010
+ stream2 = get_raw_stream(2)
1011
+ triton_tem_fused_zeros_1.run(primals_1, primals_3, primals_5, getitem_1, buf1, tangents_1, buf3, buf4, primals_9, primals_7, primals_15, primals_17, primals_11, primals_13, primals_19, primals_21, primals_10, buf5, s0, s72, s56, s84, 64 + ((127 + s0) // 128), 8, 8, stream=stream2)
1012
+ del buf1
1013
+ del getitem_1
1014
+ del primals_1
1015
+ del primals_10
1016
+ del primals_11
1017
+ del primals_13
1018
+ del primals_15
1019
+ del primals_17
1020
+ del primals_19
1021
+ del primals_21
1022
+ del primals_3
1023
+ del primals_5
1024
+ del primals_7
1025
+ del primals_9
1026
+ del tangents_1
1027
+ return (buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, )
1028
+
1029
+ runner = Runner(partitions=[])
1030
+ call = runner.call
1031
+ recursively_apply_fns = runner.recursively_apply_fns
1032
+
1033
+
1034
+ def benchmark_compiled_module(times=10, repeat=10):
1035
+ from torch._dynamo.testing import rand_strided
1036
+ from torch._inductor.utils import print_performance
1037
+ primals_8 = 4096
1038
+ primals_6 = 32
1039
+ primals_12 = 32
1040
+ primals_14 = 32
1041
+ primals_16 = 32
1042
+ primals_18 = 32
1043
+ primals_20 = 32
1044
+ primals_1 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16)
1045
+ primals_3 = rand_strided((8, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:2', dtype=torch.bfloat16)
1046
+ primals_5 = rand_strided((8, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:2', dtype=torch.bfloat16)
1047
+ primals_7 = rand_strided((8, 1, 16, 32), (512, 512, 32, 1), device='cuda:2', dtype=torch.int32)
1048
+ primals_9 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:2', dtype=torch.int32)
1049
+ primals_10 = rand_strided((8, ), (1, ), device='cuda:2', dtype=torch.int64)
1050
+ primals_11 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:2', dtype=torch.int32)
1051
+ primals_13 = rand_strided((8, 1, 16, 32), (512, 512, 32, 1), device='cuda:2', dtype=torch.int32)
1052
+ primals_15 = rand_strided((8, 1, 32), (32, 32, 1), device='cuda:2', dtype=torch.int32)
1053
+ primals_17 = rand_strided((8, 1, 32, 16), (512, 512, 16, 1), device='cuda:2', dtype=torch.int32)
1054
+ primals_19 = rand_strided((8, 1, 32), (32, 32, 1), device='cuda:2', dtype=torch.int32)
1055
+ primals_21 = rand_strided((8, 1, 32, 16), (512, 512, 16, 1), device='cuda:2', dtype=torch.int32)
1056
+ getitem = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16)
1057
+ getitem_1 = rand_strided((8, 32, 2048), (65536, 2048, 1), device='cuda:2', dtype=torch.float32)
1058
+ tangents_1 = rand_strided((8, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:2', dtype=torch.bfloat16)
1059
+ fn = lambda: call([primals_8, primals_6, primals_12, primals_14, primals_16, primals_18, primals_20, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, getitem, getitem_1, tangents_1])
1060
+ return print_performance(fn, times=times, repeat=repeat)
1061
+
1062
+
1063
+ if __name__ == "__main__":
1064
+ from torch._inductor.wrapper_benchmark import compiled_module_main
1065
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/4w/669c5a8c8205272d44ea075b78e46cd1bf13f1ebe3d56d5ab422037277c923dc.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 1, "R0_BLOCK": 2048, "num_warps": 16, "num_stages": 1, "configs_hash": "8c03dc2e05d158372838fe4d32248dfba74b467c7576f6e1d3eb472c41b37c80", "found_by_coordesc": false, "time_taken_ms": 213, "triton_cache_hash": "VBVRCEQLKQI4X4GYXD4JC6UEYZT2F7LIKNA2UR4GNVIWAPM6GKFA"}
SpecForge-ext/cache/compiled_kernels/4w/c4wdhwlu6yb3wcwazdnzmgzewiemvznxvrr3525eojupqjldo5pt.py ADDED
@@ -0,0 +1,47 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 16384, 'r0_': 32768},
12
+ reduction_hint=ReductionHint.DEFAULT,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ r0_numel = 32000
20
+ rnumel = r0_numel
21
+ RBLOCK: tl.constexpr = R0_BLOCK
22
+ xoffset = tl.program_id(0) * XBLOCK
23
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
24
+ xmask = xindex < xnumel
25
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
26
+ rbase = r0_base
27
+ x0 = (xindex % ks0)
28
+ x1 = xindex // ks0
29
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
30
+ _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
31
+ x3 = xindex
32
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
33
+ r0_index = r0_offset + r0_base
34
+ r0_mask = r0_index < r0_numel
35
+ roffset = r0_offset
36
+ rindex = r0_index
37
+ r0_2 = r0_index
38
+ tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
39
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
40
+ _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
41
+ _tmp2, _tmp2_index, tmp1, rindex
42
+ )
43
+ _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2)
44
+ _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index)
45
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
46
+ tmp2 = tmp2_idx[:, None]
47
+ tl.store(out_ptr0 + (x3), tmp2, xmask)
SpecForge-ext/cache/compiled_kernels/4w/c4ww5pmlr6amerprh7v3ibioh3yvbhemdqsh7gcrlxjnhnpkktrb.py ADDED
@@ -0,0 +1,835 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = True
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
87
+
88
+ # The below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
91
+ # about 20% more numerical error, but slightly faster.
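+ # (editorial note) the "change of base": exp(x * SM_SCALE) ==
+ # exp2(x * SM_SCALE * RCP_LN2) with RCP_LN2 = 1/ln(2) ~= 1.44269504, so
+ # prescaling q up front lets the kernel use the cheaper exp2 throughout.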
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1
101
+
102
+ ZQ = 2
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = 2048
106
+ ZKV = 2
107
+ KV_LEN = 2048
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 2
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
147
+ stride_kv_num_blks_h = 16
148
+ stride_kv_idx_h = 256
149
+ stride_kv_idx_m = 16
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset per query head.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ partially unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = 16
245
+ stride_q_idx_h = 256
246
+ stride_q_idx_n = 16
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
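+ # GQA note: the GQA_SHARED_HEADS query heads off_hkv*GQA_SHARED_HEADS ..
+ # off_hkv*GQA_SHARED_HEADS + GQA_SHARED_HEADS - 1 all attend through kv head
+ # off_hkv, so this loop accumulates each of their dK/dV contributions.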
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset per query head.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ partially unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, QK_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, QK_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq
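+ # (editorial gloss) xindex linearizes [off_zq, off_hkv, index_n, index_k]
+ # into a contiguous [*, Hkv=8, KV_LEN=2048, 128] buffer with strides
+ # (2097152, 262144, 128, 1), i.e. the broadcasted-dk layout asserted below.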
345
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
346
+
347
+ @triton.jit
348
+ def bwd_dq_inner(
349
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
350
+ K, V, # pointers
351
+ dq, q, do, Di, lse,
352
+ off_z, off_hq, offs_m2, offs_n2,
353
+ stride_kn, stride_kd, stride_vn, stride_vd,
354
+ kv_indices, sparse_kv_num_blocks,
355
+ MATMUL_PRECISION,
356
+ IS_FULL_BLOCKS,
357
+ ):
358
+ PRESCALE_QK : tl.constexpr = False
359
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
360
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
361
+ WRITE_DQ : tl.constexpr = True
362
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
363
+ OUTPUT_MAX : tl.constexpr = False
364
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
365
+ IS_DIVISIBLE : tl.constexpr = True
366
+ SM_SCALE : tl.constexpr = 0.08838834764831843
367
+ GQA_SHARED_HEADS : tl.constexpr = 4
368
+ HAS_FULL_BLOCKS : tl.constexpr = True
369
+ QK_HEAD_DIM : tl.constexpr = 128
370
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
371
+ V_HEAD_DIM : tl.constexpr = 128
372
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
373
+ SAFE_HEAD_DIM : tl.constexpr = True
374
+ BLOCK_M1 : tl.constexpr = 64
375
+ BLOCK_N1 : tl.constexpr = 128
376
+ BLOCK_M2 : tl.constexpr = 128
377
+ BLOCK_N2 : tl.constexpr = 64
378
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
379
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
380
+ INDEX_DTYPE : tl.constexpr = tl.int32
381
+
382
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
383
+ RCP_LN2: tl.constexpr = 1.44269504
384
+ Q_LEN = 2048
385
+ KV_LEN = 2048
386
+
387
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
388
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
389
+
390
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
391
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
392
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
393
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
394
+
395
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
396
+
397
+ for start_n in range(0, hi):
398
+ dq = bwd_dq_block_mn(
399
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
400
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
401
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
402
+ stride_kn, stride_kd, stride_vn, stride_vd,
403
+ kv_indices, sparse_kv_num_blocks,
404
+ MATMUL_PRECISION, RCP_LN2,
405
+ IS_FULL_BLOCKS,
406
+ )
407
+
408
+ # Increment pointers.
409
+ offset = get_offset_for_next_block(
410
+ start_n, kv_indices, sparse_kv_num_blocks,
411
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
412
+ )
413
+
414
+ kT_ptrs += offset * stride_kn
415
+ vT_ptrs += offset * stride_vn
416
+
417
+ offs_n2 += offset
418
+
419
+ return dq
420
+
421
+
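For reference, the loop bound `hi` computed above can be checked with plain Python using the constants baked into this kernel; the `sparse_kv_num_blocks` value below is hypothetical and only illustrates the arithmetic:

SPARSE_KV_BLOCK_SIZE, BLOCK_N2, KV_LEN = 128, 64, 2048
SPARSE_KV_MULTIPLE = SPARSE_KV_BLOCK_SIZE // BLOCK_N2   # 2 BLOCK_N2 tiles per sparse block
sparse_kv_num_blocks = 5                                # hypothetical block-mask density
hi = min(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, max(KV_LEN // BLOCK_N2, 1))  # -> 10 tiles visited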
422
+ @triton.jit
+ def bwd_dq_block_mn(
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
+ stride_kn, stride_kd, stride_vn, stride_vd,
+ kv_indices, sparse_kv_num_blocks,
+ MATMUL_PRECISION, RCP_LN2,
+ IS_FULL_BLOCKS,
+ ):
+ PRESCALE_QK : tl.constexpr = False
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
+ WRITE_DQ : tl.constexpr = True
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
+ OUTPUT_MAX : tl.constexpr = False
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
+ IS_DIVISIBLE : tl.constexpr = True
+ SM_SCALE : tl.constexpr = 0.08838834764831843
+ GQA_SHARED_HEADS : tl.constexpr = 4
+ HAS_FULL_BLOCKS : tl.constexpr = True
+ QK_HEAD_DIM : tl.constexpr = 128
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
+ V_HEAD_DIM : tl.constexpr = 128
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
+ SAFE_HEAD_DIM : tl.constexpr = True
+ BLOCK_M1 : tl.constexpr = 64
+ BLOCK_N1 : tl.constexpr = 128
+ BLOCK_M2 : tl.constexpr = 128
+ BLOCK_N2 : tl.constexpr = 64
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
+ INDEX_DTYPE : tl.constexpr = tl.int32
+
+
+ # NB reversed order since K is transposed
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
+ if not PRESCALE_QK:
+ qk *= SM_SCALE
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
+ pre_mod_scores = qk
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
+ # The boundary check is done in the outer loop, but since we're iterating across the N dim here,
+ # the m reads can go out of bounds for PIDs spanning the Q_LEN boundary
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
+
+ tmp0 = (qk)
+ post_mod_scores = tmp0
+
+
+ if not IS_DIVISIBLE:
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
+
+ if not IS_FULL_BLOCKS:
+ tmp1 = tl.full([1], False, tl.int1)
+ tmp2 = (m)
+ tmp3 = (n)
+ tmp4 = tmp2 >= tmp3
+ tmp5 = tmp3.to(tl.int64)
+ tmp6 = (off_z)
+ tmp7 = tl.load(in_ptr16 + tmp6)
+ tmp8 = tmp5 < tmp7
+ tmp9 = tmp2.to(tl.int64)
+ tmp10 = tmp9 < tmp7
+ tmp11 = tmp8 & tmp10
+ tmp12 = tmp4 & tmp11
+ tmp13 = tmp1 | tmp12
+ tmp14 = tl.full([1], 2048, tl.int32)
+ tmp15 = tmp3 >= tmp14
+ tmp16 = (tmp3 % tmp14)
+ tmp17 = tl.full([1], 0, tl.int32)
+ tmp18 = tmp16 != tmp17
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
+ tmp21 = tmp19 != tmp20
+ tmp22 = tmp18 & tmp21
+ tmp23 = tmp16 + tmp14
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
+ tmp25 = tmp24.to(tl.int64)
+ tmp26 = tmp25 < tmp7
+ tmp27 = tmp15 & tmp26
+ tmp28 = tmp3 - tmp2
+ tmp29 = (tmp28 % tmp14)
+ tmp30 = tmp29 != tmp17
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
+ tmp32 = tmp31 != tmp20
+ tmp33 = tmp30 & tmp32
+ tmp34 = tmp29 + tmp14
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
+ tmp36 = tmp35 == tmp17
+ tmp37 = tmp27 & tmp36
+ tmp38 = tmp13 | tmp37
+ mask_mod_output = tmp38
+
+
+ # apply mask for partially masked block
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ if not PRESCALE_QK:
+ post_mod_scores *= RCP_LN2
+ p = tl.math.exp2(post_mod_scores - lse)
+ # Compute dP and dS.
+ # NB reversed order since V is transposed
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
+
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
+ ds = p * (dp - Di[:, None])
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
+ tmp39 = (ds)
+ grad_scores = tmp39
+
+
+ if not IS_DIVISIBLE:
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
+
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
+ if WRITE_DQ:
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
+
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ ds = grad_scores
+
+ if not IS_FULL_BLOCKS:
+ # (grads) apply mask for partially unmasked block
+ ds = tl.where(mask_mod_output, ds, 0.0)
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ ds = ds.to(MATMUL_PRECISION)
+ # Compute dQ.
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
+
+ return dq
+
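The dS identity used above (ds = p * (dp - Di), with Di precomputed as the rowwise sum of dO * O) can be verified against autograd with a small standalone PyTorch check; this is an illustrative reference, not part of the generated kernel, and it omits SM_SCALE:

import torch
q = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
k = torch.randn(6, 8, dtype=torch.float64)
v = torch.randn(6, 8, dtype=torch.float64)
p = torch.softmax(q @ k.T, dim=-1)
out = p @ v
do = torch.randn_like(out)
di = (do * out).sum(-1, keepdim=True)   # Di (delta), as precomputed by the preprocess kernel
ds = p * (do @ v.T - di)                # matches ds = p * (dp - Di[:, None]) above
dq_ref = torch.autograd.grad(out, q, do)[0]
assert torch.allclose(ds @ k, dq_ref)   # dq accumulates ds @ K, as in the final tl.dot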
558
+ @triton.jit
+ def bwd_dkdv_inner(
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
+ Q, DO, DELTA, LSE, # pointers
+ dk, dv, k, v,
+ off_z, off_hq, offs_n1, offs_m1,
+ stride_qm, stride_qd, stride_dom, stride_dod,
+ q_indices, sparse_q_num_blocks,
+ MATMUL_PRECISION,
+ IS_FULL_BLOCKS,
+ ):
+ PRESCALE_QK : tl.constexpr = False
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
+ WRITE_DQ : tl.constexpr = True
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
+ OUTPUT_MAX : tl.constexpr = False
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
+ IS_DIVISIBLE : tl.constexpr = True
+ SM_SCALE : tl.constexpr = 0.08838834764831843
+ GQA_SHARED_HEADS : tl.constexpr = 4
+ HAS_FULL_BLOCKS : tl.constexpr = True
+ QK_HEAD_DIM : tl.constexpr = 128
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
+ V_HEAD_DIM : tl.constexpr = 128
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
+ SAFE_HEAD_DIM : tl.constexpr = True
+ BLOCK_M1 : tl.constexpr = 64
+ BLOCK_N1 : tl.constexpr = 128
+ BLOCK_M2 : tl.constexpr = 128
+ BLOCK_N2 : tl.constexpr = 64
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
+ INDEX_DTYPE : tl.constexpr = tl.int32
+
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
+ RCP_LN2: tl.constexpr = 1.44269504
+ Q_LEN = 2048
+ KV_LEN = 2048
+
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
+
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
+
+ # The minimum is needed to handle the case where we run with a super large
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
+
+ for start_m in range(0, hi):
+ dk, dv = bwd_dkdv_block_mn(
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
+ stride_qm, stride_qd, stride_dom, stride_dod,
+ q_indices, sparse_q_num_blocks,
+ MATMUL_PRECISION, RCP_LN2,
+ IS_FULL_BLOCKS,
+ )
+ # Increment pointers.
+ offset = get_offset_for_next_block(
+ start_m, q_indices, sparse_q_num_blocks,
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
+ )
+
+ qT_ptrs += offset * stride_qm
+ do_ptrs += offset * stride_dom
+ offs_m1 += offset
+
+ return dk, dv
+
633
+ @triton.jit
+ def bwd_dkdv_block_mn(
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
+ stride_qm, stride_qd, stride_dom, stride_dod,
+ q_indices, sparse_q_num_blocks,
+ MATMUL_PRECISION, RCP_LN2,
+ IS_FULL_BLOCKS,
+ ):
+ PRESCALE_QK : tl.constexpr = False
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
+ WRITE_DQ : tl.constexpr = True
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
+ OUTPUT_MAX : tl.constexpr = False
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
+ IS_DIVISIBLE : tl.constexpr = True
+ SM_SCALE : tl.constexpr = 0.08838834764831843
+ GQA_SHARED_HEADS : tl.constexpr = 4
+ HAS_FULL_BLOCKS : tl.constexpr = True
+ QK_HEAD_DIM : tl.constexpr = 128
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
+ V_HEAD_DIM : tl.constexpr = 128
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
+ SAFE_HEAD_DIM : tl.constexpr = True
+ BLOCK_M1 : tl.constexpr = 64
+ BLOCK_N1 : tl.constexpr = 128
+ BLOCK_M2 : tl.constexpr = 128
+ BLOCK_N2 : tl.constexpr = 64
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
+ INDEX_DTYPE : tl.constexpr = tl.int32
+
+
+ # NB reversed order since Q is transposed
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
+ # Load LSE before computing qk to reduce pipeline stall.
+ if IS_DIVISIBLE:
+ lse = tl.load(LSE + offs_m1)
+ else:
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
+ if not PRESCALE_QK:
+ qkT *= SM_SCALE
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
+ # The boundary check is done in the outer loop, but since we're iterating across the M dim here,
+ # the n reads can go out of bounds for PIDs spanning the KV_LEN boundary
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
+
+ pre_mod_scores = qkT
+ tmp40 = (qkT)
+ post_mod_scores = tmp40
+
+
+ if not IS_DIVISIBLE:
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
+
+ if not IS_FULL_BLOCKS:
+ tmp41 = tl.full([1], False, tl.int1)
+ tmp42 = (m)
+ tmp43 = (n)
+ tmp44 = tmp42 >= tmp43
+ tmp45 = tmp43.to(tl.int64)
+ tmp46 = (off_z)
+ tmp47 = tl.load(in_ptr16 + tmp46)
+ tmp48 = tmp45 < tmp47
+ tmp49 = tmp42.to(tl.int64)
+ tmp50 = tmp49 < tmp47
+ tmp51 = tmp48 & tmp50
+ tmp52 = tmp44 & tmp51
+ tmp53 = tmp41 | tmp52
+ tmp54 = tl.full([1], 2048, tl.int32)
+ tmp55 = tmp43 >= tmp54
+ tmp56 = (tmp43 % tmp54)
+ tmp57 = tl.full([1], 0, tl.int32)
+ tmp58 = tmp56 != tmp57
+ tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
+ tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
+ tmp61 = tmp59 != tmp60
+ tmp62 = tmp58 & tmp61
+ tmp63 = tmp56 + tmp54
+ tmp64 = tl.where(tmp62, tmp63, tmp56)
+ tmp65 = tmp64.to(tl.int64)
+ tmp66 = tmp65 < tmp47
+ tmp67 = tmp55 & tmp66
+ tmp68 = tmp43 - tmp42
+ tmp69 = (tmp68 % tmp54)
+ tmp70 = tmp69 != tmp57
+ tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
+ tmp72 = tmp71 != tmp60
+ tmp73 = tmp70 & tmp72
+ tmp74 = tmp69 + tmp54
+ tmp75 = tl.where(tmp73, tmp74, tmp69)
+ tmp76 = tmp75 == tmp57
+ tmp77 = tmp67 & tmp76
+ tmp78 = tmp53 | tmp77
+ mask_mod_output = tmp78
+
+ # (grads) apply mask for fully masked block
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ if not PRESCALE_QK:
+ post_mod_scores *= RCP_LN2
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
+ # Compute dV.
+ ppT = pT
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
+ if IS_DIVISIBLE:
+ Di = tl.load(DELTA + offs_m1)
+ else:
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
+ # Compute dP and dS.
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
+ dsT = pT * (dpT - Di[None, :])
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
+ tmp79 = (dsT)
+ grad_scores = tmp79
+
+
+ if not IS_DIVISIBLE:
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
+
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
+ if not WRITE_DQ:
+ idx_b = off_z
+ idx_h = off_hq
+ idx_m = m
+ idx_n = n
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
+
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ dsT = grad_scores
+ if not IS_FULL_BLOCKS:
+ # (grads) apply mask for partially unmasked block
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
+
+ return dk, dv
+
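The dV side of the block above follows the same pattern; a minimal standalone sketch of the accumulation shapes (illustrative only, masking and scaling omitted):

import torch
p = torch.softmax(torch.randn(4, 6, dtype=torch.float64), dim=-1)  # [Q_tile, KV_tile]
do = torch.randn(4, 8, dtype=torch.float64)                        # [Q_tile, V_HEAD_DIM]
dv = p.T @ do     # matches dv += tl.dot(ppT, do): pT is [KV_tile, Q_tile]
assert dv.shape == (6, 8)
# dK mirrors dQ on the key side: dsT = pT * (dpT - Di[None, :]), then dk += dsT @ Q.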
779
+ # Utility triton funcs
+ @triton.jit
+ def get_offset_for_next_block(
+ loop_iter, col_indices, total_blocks,
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
+ ):
+ if BLOCKS_ARE_CONTIGUOUS:
+ return BLOCK
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
+ return offset
+
+ @triton.jit
+ def get_bounded_indices(indices, max_len=None):
+ return indices % max_len if max_len is not None else indices
+
+ @triton.jit
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
+ return tl.load(block_ptr)
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
+ else:
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
+
+ @triton.jit
+ def load_checked_2d(
+ ptr,
+ offs_m,
+ offs_n,
+ stride_m,
+ stride_n,
+ IS_DIVISIBLE_M: tl.constexpr,
+ IS_DIVISIBLE_N: tl.constexpr,
+ M_LEN: tl.constexpr,
+ N_LEN: tl.constexpr,
+ ):
+ # Calculate final pointer if strides are provided
+ if stride_m is not None and stride_n is not None:
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
+
+ # Handle all masking cases
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
+ else: # Both divisible
+ return tl.load(ptr)
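What load_checked_2d guarantees can be stated in a few lines of plain NumPy: out-of-range rows or columns read as zero instead of faulting. This scalar reference is illustrative only and ignores strides:

import numpy as np
def load_checked_2d_ref(a, offs_m, offs_n, m_len, n_len):
    # zero-padded gather: positions past m_len/n_len contribute 0.0
    out = np.zeros((len(offs_m), len(offs_n)), dtype=a.dtype)
    for i, mi in enumerate(offs_m):
        for j, nj in enumerate(offs_n):
            if mi < m_len and nj < n_len:
                out[i, j] = a[mi, nj]
    return out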
SpecForge-ext/cache/compiled_kernels/7t/466129ab41abc9f5794b92b332ac4be3dff826e8f59dc7fb522710de7206acdd.best_config ADDED
@@ -0,0 +1 @@
+ {"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "b6ac5ef64fddcad8fc8d2c05fa12424871fd9baa5a4158ff38ecebbafb55a4b1", "found_by_coordesc": false, "time_taken_ms": 26, "triton_cache_hash": "G2LU7LIHIOEHQSWVLFBJATACJ76YHM672CUBUDGJGAJUEQVWVOFQ"}
SpecForge-ext/cache/compiled_kernels/7t/c7t3uvardqlt6x3sz37tlydghb4rt6mdilzlc7ffz3pehdn5jwdj.py ADDED
@@ -0,0 +1,49 @@
+
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.persistent_reduction(
+ size_hints={'x': 256, 'r0_': 16},
+ reduction_hint=ReductionHint.DEFAULT,
+ filename=__file__,
+ triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+ )
+ @triton.jit
+ def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
+ r0_numel = 16
+ R0_BLOCK: tl.constexpr = 16
+ rnumel = r0_numel
+ RBLOCK: tl.constexpr = R0_BLOCK
+ xoffset = tl.program_id(0) * XBLOCK
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+ xmask = xindex < xnumel
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
+ r0_offset = 0
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
+ roffset = r0_offset
+ rindex = r0_index
+ r0_2 = r0_index
+ x0 = (xindex % ks0)
+ x1 = xindex // ks0
+ x3 = xindex
+ tmp0 = tl.load(in_ptr0 + (r0_2 + x0 + 16*x1 + ks0*r0_2 + 16*ks0*x1), xmask, eviction_policy='evict_last', other=0.0)
+ tmp1 = r0_2
+ tmp2 = tmp1.to(tl.int16)
+ tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
+ tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
+ tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True)
+ tmp7 = tmp0.to(tl.int64)
+ tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK])
+ tmp10 = tl.where(xmask, tmp8, 0)
+ tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64)
+ tmp12 = tmp6.to(tl.int64)
+ tmp13 = tmp12.to(tl.int32)
+ tmp14 = tmp11.to(tl.int32)
+ tl.store(out_ptr2 + (r0_2 + 16*x0 + 16*x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp13, xmask)
+ tl.store(out_ptr3 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp14, xmask)
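At the tensor level, this fused kernel amounts to a stable descending sort (keeping indices) plus a sum over a trailing dimension of 16. A hedged torch-level equivalent, ignoring the exact transposed memory layout and with hypothetical shapes:

import torch
x = torch.randint(0, 100, (256, 16), dtype=torch.int32)  # hypothetical [xnumel, 16] view of in_ptr0
order = torch.argsort(x, dim=-1, descending=True, stable=True).to(torch.int32)  # ~ out_ptr2
totals = x.sum(dim=-1, dtype=torch.int64).to(torch.int32)                        # ~ out_ptr3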
SpecForge-ext/cache/compiled_kernels/ah/855c4fb51632a42fcf957963b85ead1d6653657da855baf9d7c221cfd3981ad0.best_config ADDED
@@ -0,0 +1 @@
+ {"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 21, "triton_cache_hash": "Z2RWAHMO7VUWQKIIRA5A46JYV2SEXHWLKREQM7TOP6VGUWDXAYAQ"}
SpecForge-ext/cache/compiled_kernels/ah/cah767udo2rzeazh6rycnirtnr5sijiv7nem2l67isu5iyh5pzyj.py ADDED
@@ -0,0 +1,56 @@
+
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.pointwise(
+ size_hints={'x': 4194304},
+ filename=__file__,
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+ min_elem_per_thread=0
+ )
+ @triton.jit
+ def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr):
+ xoffset = tl.program_id(0) * XBLOCK
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
+ xmask = xindex < xnumel
+ x4 = xindex
+ x2 = ((xindex // ks0) % ks1)
+ x0 = (xindex % ks3)
+ x5 = xindex // ks3
+ tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
+ tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last')
+ tmp2 = ks2
+ tmp3 = tmp1 + tmp2
+ tmp4 = tmp1 < 0
+ tmp5 = tl.where(tmp4, tmp3, tmp1)
+ tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2")
+ tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32)
+ tmp8 = tmp0 * tmp7
+ tmp9 = x0
+ tmp10 = tl.full([1], 0, tl.int64)
+ tmp11 = tmp9 >= tmp10
+ tmp12 = ks3 + (-1)*(ks3 // 2)
+ tmp13 = tmp9 < tmp12
+ tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
+ tmp15 = -tmp14
+ tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
+ tmp17 = tl.where(tmp13, tmp15, tmp16)
+ tmp18 = tmp9 >= tmp12
+ tmp19 = ks3
+ tmp20 = tmp9 < tmp19
+ tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
+ tmp22 = tl.where(tmp13, tmp17, tmp21)
+ tmp23 = ks4
+ tmp24 = tmp1 + tmp23
+ tmp25 = tl.where(tmp4, tmp24, tmp1)
+ tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4")
+ tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32)
+ tmp28 = tmp22 * tmp27
+ tmp29 = tmp8 + tmp28
+ tl.store(out_ptr0 + (x4), tmp29, xmask)
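The pattern in this pointwise kernel (gather cos/sin by a position index with negative-index wrap, then out = x*cos + rotate_half(x)*sin) is the usual rotary-embedding application. A hedged torch sketch, assuming an even head dim so the two slices are halves:

import torch
def rotate_half(x):
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

def apply_rope(x, cos_table, sin_table, pos):
    # negative positions wrap, as the kernel's add/where on the index does
    cos, sin = cos_table[pos], sin_table[pos]
    return x * cos + rotate_half(x) * sin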
SpecForge-ext/cache/compiled_kernels/ak/cak5ufwwwsut5tju7yvwho5uqnabsn2za7nzkoy573tny5kqhtl5.py ADDED
@@ -0,0 +1,552 @@
+
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+
+ @triton_heuristics.template(
+
+ num_stages=3,
+ num_warps=8,
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
+
+ )
+ @triton.jit
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1):
+ PRESCALE_QK : tl.constexpr = False
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
+ WRITE_DQ : tl.constexpr = True
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
+ OUTPUT_MAX : tl.constexpr = False
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
+ IS_DIVISIBLE : tl.constexpr = False
+ SM_SCALE : tl.constexpr = 0.08838834764831843
+ GQA_SHARED_HEADS : tl.constexpr = 4
+ HAS_FULL_BLOCKS : tl.constexpr = True
+ QK_HEAD_DIM : tl.constexpr = 128
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
+ V_HEAD_DIM : tl.constexpr = 128
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
+ SAFE_HEAD_DIM : tl.constexpr = True
+ USE_TMA : tl.constexpr = False
+ BLOCK_M : tl.constexpr = 128
+ BLOCK_N : tl.constexpr = 64
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
+ INDEX_DTYPE : tl.constexpr = tl.int32
+ Q = arg_Q
+ K = arg_K
+ V = arg_V
+ LSE = arg_LSE
+ MAX = arg_MAX
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
+ KV_IDX = arg_KV_IDX
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
+ FULL_KV_IDX = arg_FULL_KV_IDX
+
+ # Sub notation for this kernel:
+ #
+ # Q: Query, K: Key, V: Value
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
+ # V_HEAD_DIM: The dimension of the value embeddings
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
+ #
+ # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
+ #
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
+ #
+ # (Modifiable) Performance tuning options
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
+
+ # The below are kernel options that can be applied for certain score_mods,
+ # or involve a numerics vs. perf tradeoff
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
+ # about 20% more numerical error, but slightly faster.
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
+ # is not masked out? If so, we can skip an extra safety check
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
+ # contiguous? If so, we don't need to do an indirect jump for every block
+
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
+
+ # Define strides of inputs
+ stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks0, 128*ks0, 128, 1
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks0, 128*ks0, 128, 1
+
+ ZQ = 8
+ HQ = 32
+ Q_LEN = 2048
+ ZKV = 8
+ KV_LEN = ks0
+
+ MATMUL_PRECISION = Q.dtype.element_ty
+
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
+
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
+ off_zkv = off_zq % ZKV
+ off_hkv = off_hq // GQA_SHARED_HEADS
+ off_g = off_hq % GQA_SHARED_HEADS
+
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
+
+ Q = Q + q_offset
+ K = K + k_offset
+ V = V + v_offset
+
+ # Setting up the TMA descriptors for Q, K, V
+ desc_q = None
+ desc_k = None
+ desc_v = None
+
+ SPARSE_Z = 8
+ SPARSE_HQ = 1
+
+ sparse_idx_z = off_zq % SPARSE_Z
+ sparse_idx_hq = off_hq % SPARSE_HQ
+
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
+
+ stride_kv_num_blks_h = 16
+ stride_kv_idx_h = 16*ks1
+ stride_kv_idx_m = ks1
+
+ # initialize pointer to m and l
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
+
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
+
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m  # noqa: B950
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
+
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ # We don't know anything "special" about these blocks, so we need to apply
+ # both score_mod and mask_mod to it
+ kv_indices = KV_IDX + sparse_kv_idx_offset
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
+
+ # K and V pointers will be passed directly to forward_inner
+
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
+
+ acc, l_i, m_i = forward_inner(
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
+ q, K, V,
+ desc_k, desc_v, Q_LEN, KV_LEN,
+ acc, l_i, m_i,
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
+ kv_start,
+ kv_indices, kv_num_blocks,
+ 0, block_n_end,
+ MATMUL_PRECISION,
+ stride_kk, stride_kn, stride_vn, stride_vk,
+ IS_FULL_BLOCKS=False,
+ )
+
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ # We know these blocks are guaranteed to be "full", so we don't need to
+ # apply mask_mod to them - only score_mod
+ if HAS_FULL_BLOCKS:
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
+ # K and V pointers will be passed directly to forward_inner
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
+
+ acc, l_i, m_i = forward_inner(
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
+ q, K, V,
+ desc_k, desc_v, Q_LEN, KV_LEN,
+ acc, l_i, m_i,
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
+ kv_start,
+ kv_indices, kv_num_blocks,
+ 0, block_n_end,
+ MATMUL_PRECISION,
+ stride_kk, stride_kn, stride_vn, stride_vk,
+ IS_FULL_BLOCKS=True,
+ )
+
+
+ # [Note] Handle fully masked out rows:
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
+ # We set Li to 1.0, which results in lse/out = 0.0 after the log(l_i) + m_i (0.0) step
+ l_i = tl.where(l_i == 0.0, 1, l_i)
+
+ acc = acc / l_i[:, None]
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
+
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
+
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
+ xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask)
+
+ if OUTPUT_LOGSUMEXP:
+ off_hz = off_zq * HQ + off_hq
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
+ lse = m_i + tl.math.log2(l_i)
+ if IS_DIVISIBLE:
+ tl.store(l_ptrs, lse)
+ else:
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
+
+ if OUTPUT_MAX:
+ off_hz = off_zq * HQ + off_hq
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
+ if IS_DIVISIBLE:
+ tl.store(max_ptrs, m_i)
+ else:
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
+
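The per-tile running state (m_i, l_i, acc) that this kernel finalizes above follows the standard online-softmax recurrence. A hedged, natural-base sketch of one step (the kernel uses exp2 with RCP_LN2, and additionally guards fully masked rows where m_ij stays -inf):

import torch
def online_softmax_step(acc, l_i, m_i, scores, v):
    # scores: [M, N] tile of (masked, scaled) logits; v: [N, D]
    m_ij = torch.maximum(m_i, scores.max(-1).values)
    alpha = torch.exp(m_i - m_ij)                 # rescale factor for prior state
    p = torch.exp(scores - m_ij[:, None])
    l_i = l_i * alpha + p.sum(-1)
    acc = acc * alpha[:, None] + p @ v
    return acc, l_i, m_ij
# Final output is acc / l_i[:, None]; lse = m_i + log(l_i), matching the stores above.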
238
+ # Utility triton funcs
+ @triton.jit
+ def get_offset_for_next_block(
+ loop_iter, col_indices, total_blocks,
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
+ ):
+ if BLOCKS_ARE_CONTIGUOUS:
+ return BLOCK
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
+ return offset
+
+ @triton.jit
+ def get_bounded_indices(indices, max_len=None):
+ return indices % max_len if max_len is not None else indices
+
+ @triton.jit
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
+ return tl.load(block_ptr)
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
+ else:
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
+
+ @triton.jit
+ def load_checked_2d(
+ ptr,
+ offs_m,
+ offs_n,
+ stride_m,
+ stride_n,
+ IS_DIVISIBLE_M: tl.constexpr,
+ IS_DIVISIBLE_N: tl.constexpr,
+ M_LEN: tl.constexpr,
+ N_LEN: tl.constexpr,
+ ):
+ # Calculate final pointer if strides are provided
+ if stride_m is not None and stride_n is not None:
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
+
+ # Handle all masking cases
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
+ else: # Both divisible
+ return tl.load(ptr)
+
297
+ # Common Imports
+ @triton.jit
+ def forward_block_mn(
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
+ # accumulated values
+ acc, l_i, m_i,
+ # Offsets
+ off_z, off_h, offs_m, offs_n,
+ # Offsets needed for TMA loads
+ kv_start,
+ kv_offset,
+ MATMUL_PRECISION, RCP_LN2,
+ # Strides for K and V
+ stride_kk, stride_kn, stride_vn, stride_vk,
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
+
+ ):
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
+ PRESCALE_QK : tl.constexpr = False
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
+ WRITE_DQ : tl.constexpr = True
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
+ OUTPUT_MAX : tl.constexpr = False
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
+ IS_DIVISIBLE : tl.constexpr = False
+ SM_SCALE : tl.constexpr = 0.08838834764831843
+ GQA_SHARED_HEADS : tl.constexpr = 4
+ HAS_FULL_BLOCKS : tl.constexpr = True
+ QK_HEAD_DIM : tl.constexpr = 128
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
+ V_HEAD_DIM : tl.constexpr = 128
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
+ SAFE_HEAD_DIM : tl.constexpr = True
+ USE_TMA : tl.constexpr = False
+ BLOCK_M : tl.constexpr = 128
+ BLOCK_N : tl.constexpr = 64
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
+ INDEX_DTYPE : tl.constexpr = tl.int32
+
+
+ # -- load k --
+ # NB reversed order since K is transposed
+ kv_base_offset = kv_start + kv_offset
+
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
+
+ k = tl.trans(k)
+ # -- compute qk ---
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
+ if not PRESCALE_QK:
+ qk *= SM_SCALE
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
+ # which is larger than the actual number of elements. To avoid accessing memory out of bounds,
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
+
+ tmp0 = (qk)
+ post_mod_scores = tmp0
+
+
+ if CHECK_BLOCK_BOUNDARY:
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
+
+ if not IS_FULL_BLOCKS:
+ tmp1 = tl.full([1], False, tl.int1)
+ tmp2 = (m)
+ tmp3 = (n)
+ tmp4 = tmp2 >= tmp3
+ tmp5 = tmp3.to(tl.int64)
+ tmp6 = (off_z)
+ tmp7 = tl.load(in_ptr9 + tmp6)
+ tmp8 = tmp5 < tmp7
+ tmp9 = tmp2.to(tl.int64)
+ tmp10 = tmp9 < tmp7
+ tmp11 = tmp8 & tmp10
+ tmp12 = tmp4 & tmp11
+ tmp13 = tmp1 | tmp12
+ tmp14 = tl.full([1], 2048, tl.int32)
+ tmp15 = tmp3 >= tmp14
+ tmp16 = (tmp3 % tmp14)
+ tmp17 = tl.full([1], 0, tl.int32)
+ tmp18 = tmp16 != tmp17
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
+ tmp21 = tmp19 != tmp20
+ tmp22 = tmp18 & tmp21
+ tmp23 = tmp16 + tmp14
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
+ tmp25 = tmp24.to(tl.int64)
+ tmp26 = tmp25 < tmp7
+ tmp27 = tmp15 & tmp26
+ tmp28 = tmp3 - tmp2
+ tmp29 = (tmp28 % tmp14)
+ tmp30 = tmp29 != tmp17
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
+ tmp32 = tmp31 != tmp20
+ tmp33 = tmp30 & tmp32
+ tmp34 = tmp29 + tmp14
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
+ tmp36 = tmp35 == tmp17
+ tmp37 = tmp27 & tmp36
+ tmp38 = tmp13 | tmp37
+ mask_mod_output = tmp38
+
+
+ if CHECK_BLOCK_BOUNDARY:
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
+ # apply mask for partially unmasked blocks
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
+
+ if not PRESCALE_QK:
+ post_mod_scores *= RCP_LN2
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ # -- compute scaling constant ---
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
+ if not ROWS_GUARANTEED_SAFE:
+ masked_out_rows = (m_ij == float("-inf"))
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
+ else:
+ m_ij_masked = m_ij
+
+ alpha = tl.math.exp2(m_i - m_ij_masked)
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
+
+ # NB: l_i update is pulled up here since it's a bit faster
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
+ # m_ij
+ l_i = l_i * alpha + tl.sum(p, 1)
+ # # -- scale and update acc --
+ acc = acc * alpha[:, None]
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
+
+ # -- update m_i
+ m_i = m_ij
+
+ return acc, l_i, m_i
+
447
+ @triton.jit
+ def forward_inner(
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
+ q, K, V,
+ desc_k, desc_v, Q_LEN, KV_LEN,
+ # accumulated values
+ acc, l_i, m_i,
+ # Offsets used as inputs to score_mod & mask_mod
+ # of size [BLOCK_M, BLOCK_N] or scalar.
+ off_z, off_h, offs_m, offs_n,
+ # Offsets needed for TMA loads
+ kv_start,
+ # blocksparse data
+ kv_indices, kv_num_blocks,
+ # start kv and end kv block
+ block_n_start, block_n_end,
+ MATMUL_PRECISION,
+ # Strides for K and V
+ stride_kk, stride_kn, stride_vn, stride_vk,
+ IS_FULL_BLOCKS,
+ ):
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
+ PRESCALE_QK : tl.constexpr = False
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
+ WRITE_DQ : tl.constexpr = True
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
+ OUTPUT_MAX : tl.constexpr = False
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
+ IS_DIVISIBLE : tl.constexpr = False
+ SM_SCALE : tl.constexpr = 0.08838834764831843
+ GQA_SHARED_HEADS : tl.constexpr = 4
+ HAS_FULL_BLOCKS : tl.constexpr = True
+ QK_HEAD_DIM : tl.constexpr = 128
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
+ V_HEAD_DIM : tl.constexpr = 128
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
+ SAFE_HEAD_DIM : tl.constexpr = True
+ USE_TMA : tl.constexpr = False
+ BLOCK_M : tl.constexpr = 128
+ BLOCK_N : tl.constexpr = 64
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
+ INDEX_DTYPE : tl.constexpr = tl.int32
+
+
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
+ RCP_LN2: tl.constexpr = 1.44269504
+
+ if PRESCALE_QK:
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
+
+ kv_offset = 0
+
+ # loop over k, v and update accumulator until block_n_end
+ for start_n in range(block_n_start, block_n_end):
+ # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
+ if IS_DIVISIBLE:
+ acc, l_i, m_i = forward_block_mn(
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
+ # accumulated values
+ acc, l_i, m_i,
+ # Offsets
+ off_z, off_h, offs_m, offs_n,
+ # Offsets needed for TMA loads
+ kv_start,
+ kv_offset,
+ MATMUL_PRECISION, RCP_LN2,
+ # Strides for K and V
+ stride_kk, stride_kn, stride_vn, stride_vk,
+ IS_FULL_BLOCKS,
+ )
+ else:
+ # Benchmarks show that even if we apply mod & mask to every block for a non-divisible seqlen,
+ # it's on par with or slightly faster than applying them only to the last block in fwd.
+ # However, we choose a different strategy for bwd, where we apply mod & mask
+ # only to the last block because it's much faster there.
+ acc, l_i, m_i = forward_block_mn(
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
+ # accumulated values
+ acc, l_i, m_i,
+ # Offsets
+ off_z, off_h, offs_m, offs_n,
+ # Offsets needed for TMA loads
+ kv_start,
+ kv_offset,
+ MATMUL_PRECISION, RCP_LN2,
+ # Strides for K and V
+ stride_kk, stride_kn, stride_vn, stride_vk,
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
+ )
+
+
+ offset = get_offset_for_next_block(
+ start_n, kv_indices, kv_num_blocks,
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
+ )
+
+ offs_n = offs_n + offset
+ kv_offset += offset
+
+ return acc, l_i, m_i
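This template is the compiled body of torch.ops.higher_order.flex_attention. For orientation, the user-level API it lowers from looks roughly like the sketch below; the causal mask_mod is only an illustration, since the actual mask baked into these kernels (which also reads per-batch lengths via in_ptr9) is compiler-generated:

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal, B=8, H=1, Q_LEN=2048, KV_LEN=2048, device="cuda")
q = torch.randn(8, 32, 2048, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(8, 8, 2048, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(8, 8, 2048, 128, device="cuda", dtype=torch.bfloat16)
out = flex_attention(q, k, v, block_mask=block_mask, enable_gqa=True)  # GQA_SHARED_HEADS = 32 // 8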
SpecForge-ext/cache/compiled_kernels/ak/cakglntm3ejviis7qbld6stbcfdrpvbryqpb63fshmmyy46mxbh3.py ADDED
@@ -0,0 +1,675 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['6_forward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
17
+ import triton
18
+ import triton.language as tl
19
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
20
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
21
+
22
+ aten = torch.ops.aten
23
+ inductor_ops = torch.ops.inductor
24
+ _quantized = torch.ops._quantized
25
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
26
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
27
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
28
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
29
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
30
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
31
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
32
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
33
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
34
+ async_compile = AsyncCompile()
35
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
36
+
37
+
38
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/7a/c7avdnhdkg25qkzpvb4jgb3wfrta3u7po7rrnynujrskgetlvslk.py
39
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
40
+ # Source node to ATen node mapping:
41
+ # flex_attention => flex_attention
42
+ # Graph fragment:
43
+ # %primals_1 : Tensor "bf16[8, 32, 2048, 128][8388608, 128, 4096, 1]cuda:5" = PlaceHolder[target=primals_1]
44
+ # %primals_2 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:5" = PlaceHolder[target=primals_2]
45
+ # %primals_3 : Tensor "bf16[8, 8, 2048, 128][2097152, 262144, 128, 1]cuda:5" = PlaceHolder[target=primals_3]
46
+ # %getitem_1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:5" = PlaceHolder[target=getitem_1]
47
+ # %buf1 : Tensor "f32[8, 32, 2048][65536, 2048, 1]cuda:5" = PlaceHolder[target=buf1]
48
+ # %primals_5 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:5" = PlaceHolder[target=primals_5]
49
+ # %primals_4 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:5" = PlaceHolder[target=primals_4]
50
+ # %primals_7 : Tensor "i32[8, 1, 16][16, 16, 1]cuda:5" = PlaceHolder[target=primals_7]
51
+ # %primals_8 : Tensor "i32[8, 1, 16, 16][256, 256, 16, 1]cuda:5" = PlaceHolder[target=primals_8]
52
+ # %primals_6 : Tensor "i64[8][1]cuda:5" = PlaceHolder[target=primals_6]
53
+ # %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {})
54
+ # return %getitem
55
+ triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
56
+ import triton
57
+ import triton.language as tl
58
+
59
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
60
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
61
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
62
+
63
+ @triton_heuristics.template(
64
+
65
+ num_stages=3,
66
+ num_warps=8,
67
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
68
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
69
+
70
+ )
71
+ @triton.jit
72
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0):
73
+ PRESCALE_QK : tl.constexpr = False
74
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
75
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
76
+ WRITE_DQ : tl.constexpr = True
77
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
78
+ OUTPUT_MAX : tl.constexpr = False
79
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
80
+ IS_DIVISIBLE : tl.constexpr = True
81
+ SM_SCALE : tl.constexpr = 0.08838834764831843
82
+ GQA_SHARED_HEADS : tl.constexpr = 4
83
+ HAS_FULL_BLOCKS : tl.constexpr = True
84
+ QK_HEAD_DIM : tl.constexpr = 128
85
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
86
+ V_HEAD_DIM : tl.constexpr = 128
87
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
88
+ SAFE_HEAD_DIM : tl.constexpr = True
89
+ USE_TMA : tl.constexpr = False
90
+ BLOCK_M : tl.constexpr = 128
91
+ BLOCK_N : tl.constexpr = 64
92
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
93
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
94
+ INDEX_DTYPE : tl.constexpr = tl.int32
95
+ Q = arg_Q
96
+ K = arg_K
97
+ V = arg_V
98
+ LSE = arg_LSE
99
+ MAX = arg_MAX
100
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
101
+ KV_IDX = arg_KV_IDX
102
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
103
+ FULL_KV_IDX = arg_FULL_KV_IDX
104
+
105
+ # Sub notation for this kernel:
106
+ #
107
+ # Q: Query, K: Key, V: Value
108
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
109
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
110
+ # V_HEAD_DIM: The dimension of the value embeddings
111
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
112
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
113
+ #
114
+ # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
115
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
116
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
117
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
118
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
119
+ #
120
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
121
+ #
122
+ # (Modifiable) Performance tuning options
123
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
124
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
125
+
126
+ # The below are kernel options that can be applied for certain score_mods,
127
+ # or involve a numerics vs. perf tradeoff
128
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
129
+ # about 20% more numerical error, but slightly faster.
130
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
131
+ # is not masked out? If so, we can skip an extra safety check
132
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
133
+ # contiguous? If so, we don't need to do an indirect jump for every block
134
+
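+ # Illustrative example (comments only, so the generated source is unchanged):
+ # with SPARSE_KV_BLOCK_SIZE=128 and KV_LEN=2048 there are 16 sparse KV blocks
+ # per q-block row. For a causal-style layout, q-block row 3 might have
+ # FULL_KV_NUM_BLKS[..., 3] == 3 (blocks 0..2 fully visible, mask_mod skipped)
+ # and KV_NUM_BLKS[..., 3] == 1 with KV_IDX[..., 3, 0] == 3 (the diagonal
+ # block, which still needs mask_mod).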
135
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
136
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
137
+
138
+ # Define strides of inputs
139
+ stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1
140
+ stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1
141
+ stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1
142
+
143
+ ZQ = 8
144
+ HQ = 32
145
+ Q_LEN = 2048
146
+ ZKV = 8
147
+ KV_LEN = 2048
148
+
149
+ MATMUL_PRECISION = Q.dtype.element_ty
150
+
151
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
152
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
153
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
154
+
155
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
156
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
157
+ off_zkv = off_zq % ZKV
158
+ off_hkv = off_hq // GQA_SHARED_HEADS
159
+ off_g = off_hq % GQA_SHARED_HEADS
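+ # e.g. with GQA_SHARED_HEADS=4 and HQ=32: query heads 0-3 all read KV head 0
+ # (off_hkv = off_hq // 4), heads 4-7 read KV head 1, ..., heads 28-31 read KV head 7.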
160
+
161
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
162
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
163
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
164
+
165
+ Q = Q + q_offset
166
+ K = K + k_offset
167
+ V = V + v_offset
168
+
169
+ # Setting up the TMA descriptors for Q, K, V
170
+ desc_q = None
171
+ desc_k = None
172
+ desc_v = None
173
+
174
+ SPARSE_Z = 8
175
+ SPARSE_HQ = 1
176
+
177
+ sparse_idx_z = off_zq % SPARSE_Z
178
+ sparse_idx_hq = off_hq % SPARSE_HQ
179
+
180
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
181
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
182
+
183
+ stride_kv_num_blks_h = 16
184
+ stride_kv_idx_h = 256
185
+ stride_kv_idx_m = 16
186
+
187
+ # initialize pointer to m and l
188
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
189
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
190
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
191
+
192
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
193
+
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
196
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
197
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
198
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
199
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
200
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
201
+
202
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
203
+ # We don't know anything "special" about these blocks, so we need to apply
+ # both score_mod and mask_mod to them
205
+ kv_indices = KV_IDX + sparse_kv_idx_offset
206
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
207
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
208
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
209
+
210
+
211
+ # K and V pointers will be passed directly to forward_inner
212
+
213
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
214
+
215
+
216
+ acc, l_i, m_i = forward_inner(
217
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
218
+ q, K, V,
219
+ desc_k, desc_v, Q_LEN, KV_LEN,
220
+ acc, l_i, m_i,
221
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
222
+ kv_start,
223
+ kv_indices, kv_num_blocks,
224
+ 0, block_n_end,
225
+ MATMUL_PRECISION,
226
+ stride_kk, stride_kn, stride_vn, stride_vk,
227
+ IS_FULL_BLOCKS=False,
228
+ )
229
+
230
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
231
+ # We know these blocks are guaranteed to be "full", so we don't need to
232
+ # apply mask_mod to them - only score_mod
233
+ if HAS_FULL_BLOCKS:
234
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
235
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
236
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
237
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
238
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
239
+ # K and V pointers will be passed directly to forward_inner
240
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
241
+
242
+ acc, l_i, m_i = forward_inner(
243
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
244
+ q, K, V,
245
+ desc_k, desc_v, Q_LEN, KV_LEN,
246
+ acc, l_i, m_i,
247
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
248
+ kv_start,
249
+ kv_indices, kv_num_blocks,
250
+ 0, block_n_end,
251
+ MATMUL_PRECISION,
252
+ stride_kk, stride_kn, stride_vn, stride_vk,
253
+ IS_FULL_BLOCKS=True,
254
+ )
255
+
256
+
257
+ # [Note] Handle fully masked out rows:
258
+ # l_i will be sum(e^(-inf)) == 0.0 for masked-out rows, and m_i will be -inf.
+ # We set l_i to 1.0, which results in lse/out = 0.0 after the log(l_i) + m_i (0.0) step
260
+ l_i = tl.where(l_i == 0.0, 1, l_i)
261
+
262
+ acc = acc / l_i[:, None]
263
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
264
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
265
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
266
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
267
+
268
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
269
+
270
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
271
+ xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq
272
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask)
273
+
274
+ if OUTPUT_LOGSUMEXP:
275
+ off_hz = off_zq * HQ + off_hq
276
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
277
+ lse = m_i + tl.math.log2(l_i)
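+ # Annotation: because the scores were pre-scaled by RCP_LN2, this lse is the
+ # base-2 logsumexp; the natural-log value is lse * ln(2).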
278
+ if IS_DIVISIBLE:
279
+ tl.store(l_ptrs, lse)
280
+ else:
281
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
282
+
283
+ if OUTPUT_MAX:
284
+ off_hz = off_zq * HQ + off_hq
285
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
286
+ if IS_DIVISIBLE:
287
+ tl.store(max_ptrs, m_i)
288
+ else:
289
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
290
+
291
+
292
+ # Utility triton funcs
293
+ @triton.jit
294
+ def get_offset_for_next_block(
295
+ loop_iter, col_indices, total_blocks,
296
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
297
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
298
+ ):
299
+ if BLOCKS_ARE_CONTIGUOUS:
300
+ return BLOCK
301
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
302
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
303
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
304
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
305
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
306
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
307
+ return offset
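+ # Worked example (comments only): with SPARSE_BLOCK=128 and BLOCK=64,
+ # SPARSE_BLOCK_MULTIPLE=2, so every other iteration stays inside the current
+ # sparse block (offset=64). When the sparse block is exhausted and e.g.
+ # cur_block=3 jumps to next_block=7, offset = (7-3)*128 - 64 = 448,
+ # i.e. from position 3*128+64 straight to 7*128.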
308
+
309
+ @triton.jit
310
+ def get_bounded_indices(indices, max_len=None):
311
+ return indices % max_len if max_len is not None else indices
312
+
313
+ @triton.jit
314
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
315
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
316
+ return tl.load(block_ptr)
317
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
318
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
319
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
320
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
321
+ else:
322
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
323
+
324
+ @triton.jit
325
+ def load_checked_2d(
326
+ ptr,
327
+ offs_m,
328
+ offs_n,
329
+ stride_m,
330
+ stride_n,
331
+ IS_DIVISIBLE_M: tl.constexpr,
332
+ IS_DIVISIBLE_N: tl.constexpr,
333
+ M_LEN: tl.constexpr,
334
+ N_LEN: tl.constexpr,
335
+ ):
336
+ # Calculate final pointer if strides are provided
337
+ if stride_m is not None and stride_n is not None:
338
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
339
+
340
+ # Handle all masking cases
341
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
342
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
343
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
344
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
345
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
346
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
347
+ else: # Both divisible
348
+ return tl.load(ptr)
349
+
350
+
351
+ # Common Imports
352
+ @triton.jit
353
+ def forward_block_mn(
354
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
355
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
356
+ # accumulated values
357
+ acc, l_i, m_i,
358
+ # Offsets
359
+ off_z, off_h, offs_m, offs_n,
360
+ # Offsets needed for TMA loads
361
+ kv_start,
362
+ kv_offset,
363
+ MATMUL_PRECISION, RCP_LN2,
364
+ # Strides for K and V
365
+ stride_kk, stride_kn, stride_vn, stride_vk,
366
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
367
+
368
+ ):
369
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
370
+ PRESCALE_QK : tl.constexpr = False
371
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
372
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
373
+ WRITE_DQ : tl.constexpr = True
374
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
375
+ OUTPUT_MAX : tl.constexpr = False
376
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
377
+ IS_DIVISIBLE : tl.constexpr = True
378
+ SM_SCALE : tl.constexpr = 0.08838834764831843
379
+ GQA_SHARED_HEADS : tl.constexpr = 4
380
+ HAS_FULL_BLOCKS : tl.constexpr = True
381
+ QK_HEAD_DIM : tl.constexpr = 128
382
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
383
+ V_HEAD_DIM : tl.constexpr = 128
384
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
385
+ SAFE_HEAD_DIM : tl.constexpr = True
386
+ USE_TMA : tl.constexpr = False
387
+ BLOCK_M : tl.constexpr = 128
388
+ BLOCK_N : tl.constexpr = 64
389
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
390
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
391
+ INDEX_DTYPE : tl.constexpr = tl.int32
392
+
393
+
394
+ # -- load k --
395
+ # NB: reversed order since K is transposed
396
+ kv_base_offset = kv_start + kv_offset
397
+
398
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
399
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
400
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
401
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
402
+
403
+ k = tl.trans(k)
404
+ # -- compute qk ---
405
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
406
+ if not PRESCALE_QK:
407
+ qk *= SM_SCALE
408
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
409
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
410
+ # which is larger than the actual number of elements. To avoid out-of-bounds memory accesses,
411
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
412
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
413
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
414
+
415
+ tmp0 = (qk)
416
+ post_mod_scores = tmp0
417
+
418
+
419
+ if CHECK_BLOCK_BOUNDARY:
420
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
421
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
422
+
423
+ if not IS_FULL_BLOCKS:
424
+ tmp1 = tl.full([1], False, tl.int1)
425
+ tmp2 = (m)
426
+ tmp3 = (n)
427
+ tmp4 = tmp2 >= tmp3
428
+ tmp5 = tmp3.to(tl.int64)
429
+ tmp6 = (off_z)
430
+ tmp7 = tl.load(in_ptr9 + tmp6)
431
+ tmp8 = tmp5 < tmp7
432
+ tmp9 = tmp2.to(tl.int64)
433
+ tmp10 = tmp9 < tmp7
434
+ tmp11 = tmp8 & tmp10
435
+ tmp12 = tmp4 & tmp11
436
+ tmp13 = tmp1 | tmp12
437
+ tmp14 = tl.full([1], 2048, tl.int32)
438
+ tmp15 = tmp3 >= tmp14
439
+ tmp16 = (tmp3 % tmp14)
440
+ tmp17 = tl.full([1], 0, tl.int32)
441
+ tmp18 = tmp16 != tmp17
442
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
443
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
444
+ tmp21 = tmp19 != tmp20
445
+ tmp22 = tmp18 & tmp21
446
+ tmp23 = tmp16 + tmp14
447
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
448
+ tmp25 = tmp24.to(tl.int64)
449
+ tmp26 = tmp25 < tmp7
450
+ tmp27 = tmp15 & tmp26
451
+ tmp28 = tmp3 - tmp2
452
+ tmp29 = (tmp28 % tmp14)
453
+ tmp30 = tmp29 != tmp17
454
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
455
+ tmp32 = tmp31 != tmp20
456
+ tmp33 = tmp30 & tmp32
457
+ tmp34 = tmp29 + tmp14
458
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
459
+ tmp36 = tmp35 == tmp17
460
+ tmp37 = tmp27 & tmp36
461
+ tmp38 = tmp13 | tmp37
462
+ mask_mod_output = tmp38
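+ # Annotation (derived from the chain above): mask_mod_output is
+ #   (m >= n, with both m and n below the per-batch length in_ptr9[off_z])
+ #   OR (n >= 2048 and (n mod 2048) < that length and (n - m) % 2048 == 0),
+ # i.e. causal attention within the first 2048 positions plus a periodic
+ # diagonal for KV positions past the 2048 boundary.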
463
+
464
+
465
+ if CHECK_BLOCK_BOUNDARY:
466
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
467
+ # apply mask for partially unmasked blocks
468
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
469
+
470
+ if not PRESCALE_QK:
471
+ post_mod_scores *= RCP_LN2
472
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
473
+
474
+ # -- compute scaling constant ---
475
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
476
+ if not ROWS_GUARANTEED_SAFE:
477
+ masked_out_rows = (m_ij == float("-inf"))
478
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
479
+ else:
480
+ m_ij_masked = m_ij
481
+
482
+ alpha = tl.math.exp2(m_i - m_ij_masked)
483
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
484
+
485
+ # NB: l_i update is pulled up here since it's a bit faster
486
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
487
+ # m_ij
488
+ l_i = l_i * alpha + tl.sum(p, 1)
489
+ # -- scale and update acc --
490
+ acc = acc * alpha[:, None]
491
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
492
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
493
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
494
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
495
+
496
+ # -- update m_i
497
+ m_i = m_ij
498
+
499
+ return acc, l_i, m_i
500
+
501
+ @triton.jit
502
+ def forward_inner(
503
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
504
+ q, K, V,
505
+ desc_k, desc_v, Q_LEN, KV_LEN,
506
+ # accumulated values
507
+ acc, l_i, m_i,
508
+ # Offsets used as inputs to score_mod & mask_mod
509
+ # of size [BLOCK_M, BLOCK_N] or scalar.
510
+ off_z, off_h, offs_m, offs_n,
511
+ # Offsets needed for TMA loads
512
+ kv_start,
513
+ # blocksparse data
514
+ kv_indices, kv_num_blocks,
515
+ # start kv and end kv block
516
+ block_n_start, block_n_end,
517
+ MATMUL_PRECISION,
518
+ # Strides for K and V
519
+ stride_kk, stride_kn, stride_vn, stride_vk,
520
+ IS_FULL_BLOCKS,
521
+ ):
522
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
523
+ PRESCALE_QK : tl.constexpr = False
524
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
525
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
526
+ WRITE_DQ : tl.constexpr = True
527
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
528
+ OUTPUT_MAX : tl.constexpr = False
529
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
530
+ IS_DIVISIBLE : tl.constexpr = True
531
+ SM_SCALE : tl.constexpr = 0.08838834764831843
532
+ GQA_SHARED_HEADS : tl.constexpr = 4
533
+ HAS_FULL_BLOCKS : tl.constexpr = True
534
+ QK_HEAD_DIM : tl.constexpr = 128
535
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
536
+ V_HEAD_DIM : tl.constexpr = 128
537
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
538
+ SAFE_HEAD_DIM : tl.constexpr = True
539
+ USE_TMA : tl.constexpr = False
540
+ BLOCK_M : tl.constexpr = 128
541
+ BLOCK_N : tl.constexpr = 64
542
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
543
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
544
+ INDEX_DTYPE : tl.constexpr = tl.int32
545
+
546
+
547
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
548
+ RCP_LN2: tl.constexpr = 1.44269504
549
+
550
+ if PRESCALE_QK:
551
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
552
+
553
+ kv_offset = 0
554
+
555
+ # loop over k, v and update accumulator until block_n_end
556
+ for start_n in range(block_n_start, block_n_end):
557
+ # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
558
+ if IS_DIVISIBLE:
559
+ acc, l_i, m_i = forward_block_mn(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
561
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
562
+ # accumulated values
563
+ acc, l_i, m_i,
564
+ # Offsets
565
+ off_z, off_h, offs_m, offs_n,
566
+ # Offsets needed for TMA loads
567
+ kv_start,
568
+ kv_offset,
569
+ MATMUL_PRECISION, RCP_LN2,
570
+ # Strides for K and V
571
+ stride_kk, stride_kn, stride_vn, stride_vk,
572
+ IS_FULL_BLOCKS,
573
+ )
574
+ else:
575
+ # Benchmarks show that even when we apply mod & mask to every block for a non-divisible seqlen,
+ # it's on par with or slightly faster than applying them only to the last block in fwd.
+ # However, we choose a different strategy for bwd, where we only apply mod & mask
+ # to the last block, because doing so is a lot faster.
579
+ acc, l_i, m_i = forward_block_mn(
580
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
581
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
582
+ # accumulated values
583
+ acc, l_i, m_i,
584
+ # Offsets
585
+ off_z, off_h, offs_m, offs_n,
586
+ # Offsets needed for TMA loads
587
+ kv_start,
588
+ kv_offset,
589
+ MATMUL_PRECISION, RCP_LN2,
590
+ # Strides for K and V
591
+ stride_kk, stride_kn, stride_vn, stride_vk,
592
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
593
+ )
594
+
595
+
596
+
597
+ offset = get_offset_for_next_block(
598
+ start_n, kv_indices, kv_num_blocks,
599
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
600
+ )
601
+
602
+ offs_n = offs_n + offset
603
+ kv_offset += offset
604
+
605
+
606
+ return acc, l_i, m_i
607
+ ''', device_str='cuda')
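+ # Reference sketch (hypothetical helper, not part of the Inductor-generated
+ # module): the online-softmax recurrence that forward_block_mn implements,
+ # written in plain Python. m_i tracks the running row max, l_i the running
+ # normalizer; the kernel additionally rescales its acc by the same alpha.
+ def _online_softmax_reference(score_blocks):
+     m_i = float("-inf")
+     l_i = 0.0
+     for block in score_blocks:  # each block: an iterable of scores for one row
+         m_ij = max(m_i, max(block))
+         alpha = 2.0 ** (m_i - m_ij)  # rescales previously accumulated mass
+         l_i = l_i * alpha + sum(2.0 ** (s - m_ij) for s in block)
+         m_i = m_ij
+     return m_i, l_i  # lse = m_i + log2(l_i), as stored by the kernel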
608
+
609
+
610
+ async_compile.wait(globals())
611
+ del async_compile
612
+
613
+ class Runner:
614
+ def __init__(self, partitions):
615
+ self.partitions = partitions
616
+
617
+ def recursively_apply_fns(self, fns):
618
+ new_callables = []
619
+ for fn, c in zip(fns, self.partitions):
620
+ new_callables.append(fn(c))
621
+ self.partitions = new_callables
622
+
623
+ def call(self, args):
624
+ primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12 = args
625
+ args.clear()
626
+ assert_size_stride(primals_1, (8, 32, 2048, 128), (8388608, 128, 4096, 1))
627
+ assert_size_stride(primals_2, (8, 8, 2048, 128), (2097152, 262144, 128, 1))
628
+ assert_size_stride(primals_3, (8, 8, 2048, 128), (2097152, 262144, 128, 1))
629
+ assert_size_stride(primals_4, (8, 1, 16, 16), (256, 256, 16, 1))
630
+ assert_size_stride(primals_5, (8, 1, 16), (16, 16, 1))
631
+ assert_size_stride(primals_6, (8, ), (1, ))
632
+ assert_size_stride(primals_7, (8, 1, 16), (16, 16, 1))
633
+ assert_size_stride(primals_8, (8, 1, 16, 16), (256, 256, 16, 1))
634
+ assert_size_stride(primals_9, (8, 1, 16), (16, 16, 1))
635
+ assert_size_stride(primals_10, (8, 1, 16, 16), (256, 256, 16, 1))
636
+ assert_size_stride(primals_11, (8, 1, 16), (16, 16, 1))
637
+ assert_size_stride(primals_12, (8, 1, 16, 16), (256, 256, 16, 1))
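+ # The guards above pin the exact shapes/strides this kernel was specialized
+ # for; mismatched inputs fail fast here rather than reading out of bounds
+ # inside the Triton kernel.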
638
+ with torch.cuda._DeviceGuard(5):
639
+ torch.cuda.set_device(5)
640
+ buf0 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32)
641
+ buf1 = empty_strided_cuda((8, 32, 2048), (65536, 2048, 1), torch.float32)
642
+ buf2 = empty_strided_cuda((8, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16)
643
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
644
+ stream5 = get_raw_stream(5)
645
+ triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_5, primals_4, primals_7, primals_8, primals_6, buf2, 16, 8, 32, stream=stream5)
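+ # Grid is (16, 8, 32) = (2048/BLOCK_M q-blocks, ZQ, HQ); buf0 receives the
+ # logsumexp, while buf1 (arg_MAX) is never written because OUTPUT_MAX=False,
+ # which is why it can be deleted immediately below.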
646
+ del buf1
647
+ return (buf2, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, buf2, buf0, )
648
+
649
+ runner = Runner(partitions=[])
650
+ call = runner.call
651
+ recursively_apply_fns = runner.recursively_apply_fns
652
+
653
+
654
+ def benchmark_compiled_module(times=10, repeat=10):
655
+ from torch._dynamo.testing import rand_strided
656
+ from torch._inductor.utils import print_performance
657
+ primals_1 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:5', dtype=torch.bfloat16)
658
+ primals_2 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:5', dtype=torch.bfloat16)
659
+ primals_3 = rand_strided((8, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:5', dtype=torch.bfloat16)
660
+ primals_4 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:5', dtype=torch.int32)
661
+ primals_5 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:5', dtype=torch.int32)
662
+ primals_6 = rand_strided((8, ), (1, ), device='cuda:5', dtype=torch.int64)
663
+ primals_7 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:5', dtype=torch.int32)
664
+ primals_8 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:5', dtype=torch.int32)
665
+ primals_9 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:5', dtype=torch.int32)
666
+ primals_10 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:5', dtype=torch.int32)
667
+ primals_11 = rand_strided((8, 1, 16), (16, 16, 1), device='cuda:5', dtype=torch.int32)
668
+ primals_12 = rand_strided((8, 1, 16, 16), (256, 256, 16, 1), device='cuda:5', dtype=torch.int32)
669
+ fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12])
670
+ return print_performance(fn, times=times, repeat=repeat)
671
+
672
+
673
+ if __name__ == "__main__":
674
+ from torch._inductor.wrapper_benchmark import compiled_module_main
675
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/as/9f962df2938e79169dbf28adc9c67d12118719f3425569a98b11309d3108a638.best_config ADDED
@@ -0,0 +1 @@
+ {"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 65, "triton_cache_hash": "UQSFYICF6CFQWZOBHCGZ7JZ457GHWVO6RMPN5ABNWOATFMKI6GQA"}
SpecForge-ext/cache/compiled_kernels/as/casevqrknafvhxbpwjozemzmdw3n2vgrctm4s4zdjzqp52cqs6kd.py ADDED
@@ -0,0 +1,693 @@
+ # AOT ID: ['9_forward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
17
+ import triton
18
+ import triton.language as tl
19
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
20
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
21
+
22
+ aten = torch.ops.aten
23
+ inductor_ops = torch.ops.inductor
24
+ _quantized = torch.ops._quantized
25
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
26
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
27
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
28
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
29
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
30
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
31
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
32
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
33
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
34
+ async_compile = AsyncCompile()
35
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
36
+
37
+
38
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/op/cop75xk6fpjjvnvvcusccw4eu3b3i2silh5jxkjylbibzjctamxl.py
39
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
40
+ # Source node to ATen node mapping:
41
+ # flex_attention => flex_attention
42
+ # Graph fragment:
43
+ # %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:2" = PlaceHolder[target=primals_1]
44
+ # %primals_3 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:2" = PlaceHolder[target=primals_3]
45
+ # %primals_5 : Tensor "bf16[2, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:2" = PlaceHolder[target=primals_5]
46
+ # %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:2" = PlaceHolder[target=getitem_1]
47
+ # %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:2" = PlaceHolder[target=buf1]
48
+ # %primals_9 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:2" = PlaceHolder[target=primals_9]
49
+ # %primals_7 : Tensor "i32[2, 1, 16, s72][16*s72, 16*s72, s72, 1]cuda:2" = PlaceHolder[target=primals_7]
50
+ # %primals_11 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:2" = PlaceHolder[target=primals_11]
51
+ # %primals_13 : Tensor "i32[2, 1, 16, s4][16*s4, 16*s4, s4, 1]cuda:2" = PlaceHolder[target=primals_13]
52
+ # %primals_10 : Tensor "i64[2][1]cuda:2" = PlaceHolder[target=primals_10]
53
+ # %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_3, %primals_5, %sdpa_score0, (2048, %primals_8, %primals_9, %primals_7, %primals_11, %primals_13, %primals_15, %primals_17, %primals_19, %primals_21, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_10,)), kwargs = {})
54
+ # return %getitem
55
+ triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
56
+ import triton
57
+ import triton.language as tl
58
+
59
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
60
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
61
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
62
+
63
+ @triton_heuristics.template(
64
+
65
+ num_stages=3,
66
+ num_warps=8,
67
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
68
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
69
+
70
+ )
71
+ @triton.jit
72
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1):
73
+ PRESCALE_QK : tl.constexpr = False
74
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
75
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
76
+ WRITE_DQ : tl.constexpr = True
77
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
78
+ OUTPUT_MAX : tl.constexpr = False
79
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
80
+ IS_DIVISIBLE : tl.constexpr = False
81
+ SM_SCALE : tl.constexpr = 0.08838834764831843
82
+ GQA_SHARED_HEADS : tl.constexpr = 4
83
+ HAS_FULL_BLOCKS : tl.constexpr = True
84
+ QK_HEAD_DIM : tl.constexpr = 128
85
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
86
+ V_HEAD_DIM : tl.constexpr = 128
87
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
88
+ SAFE_HEAD_DIM : tl.constexpr = True
89
+ USE_TMA : tl.constexpr = False
90
+ BLOCK_M : tl.constexpr = 128
91
+ BLOCK_N : tl.constexpr = 64
92
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
93
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
94
+ INDEX_DTYPE : tl.constexpr = tl.int32
95
+ Q = arg_Q
96
+ K = arg_K
97
+ V = arg_V
98
+ LSE = arg_LSE
99
+ MAX = arg_MAX
100
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
101
+ KV_IDX = arg_KV_IDX
102
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
103
+ FULL_KV_IDX = arg_FULL_KV_IDX
104
+
105
+ # Sub notation for this kernel:
106
+ #
107
+ # Q: Query, K: Key, V: Value
108
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
109
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
110
+ # V_HEAD_DIM: The dimension of the value embeddings
111
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
112
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
113
+ #
114
+ # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
115
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
116
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
117
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
118
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
119
+ #
120
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
121
+ #
122
+ # (Modifiable) Performance tuning options
123
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
124
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
125
+
126
+ # The below are kernel options that can be applied for certain score_mods,
127
+ # or involve a numerics vs. perf tradeoff
128
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
129
+ # about 20% more numerical error, but slightly faster.
130
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
131
+ # is not masked out? If so, we can skip an extra safety check
132
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
133
+ # contiguous? If so, we don't need to do an indirect jump for every block
134
+
135
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
136
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
137
+
138
+ # Define strides of inputs
139
+ stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1
140
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks0, 128*ks0, 128, 1
141
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks0, 128*ks0, 128, 1
142
+
143
+ ZQ = 2
144
+ HQ = 32
145
+ Q_LEN = 2048
146
+ ZKV = 2
147
+ KV_LEN = ks0
148
+
149
+ MATMUL_PRECISION = Q.dtype.element_ty
150
+
151
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
152
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
153
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
154
+
155
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
156
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
157
+ off_zkv = off_zq % ZKV
158
+ off_hkv = off_hq // GQA_SHARED_HEADS
159
+ off_g = off_hq % GQA_SHARED_HEADS
160
+
161
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
162
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
163
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
164
+
165
+ Q = Q + q_offset
166
+ K = K + k_offset
167
+ V = V + v_offset
168
+
169
+ # Setting up the TMA descriptors for Q, K, V
170
+ desc_q = None
171
+ desc_k = None
172
+ desc_v = None
173
+
174
+ SPARSE_Z = 2
175
+ SPARSE_HQ = 1
176
+
177
+ sparse_idx_z = off_zq % SPARSE_Z
178
+ sparse_idx_hq = off_hq % SPARSE_HQ
179
+
180
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
181
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
182
+
183
+ stride_kv_num_blks_h = 16
184
+ stride_kv_idx_h = 16*ks1
185
+ stride_kv_idx_m = ks1
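+ # Annotation: ks1 is the dynamic number of KV-block columns per KV_IDX row
+ # (the wrapper below binds it to s72); the static-shape variant hard-codes 16 here.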
186
+
187
+ # initialize pointer to m and l
188
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
189
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
190
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
191
+
192
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
193
+
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
196
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
197
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
198
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
199
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
200
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
201
+
202
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
203
+ # We don't know anything "special" about these blocks, so we need to apply
+ # both score_mod and mask_mod to them
205
+ kv_indices = KV_IDX + sparse_kv_idx_offset
206
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
207
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
208
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
209
+
210
+
211
+ # K and V pointers will be passed directly to forward_inner
212
+
213
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
214
+
215
+
216
+ acc, l_i, m_i = forward_inner(
217
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
218
+ q, K, V,
219
+ desc_k, desc_v, Q_LEN, KV_LEN,
220
+ acc, l_i, m_i,
221
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
222
+ kv_start,
223
+ kv_indices, kv_num_blocks,
224
+ 0, block_n_end,
225
+ MATMUL_PRECISION,
226
+ stride_kk, stride_kn, stride_vn, stride_vk,
227
+ IS_FULL_BLOCKS=False,
228
+ )
229
+
230
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
231
+ # We know these blocks are guaranteed to be "full", so we don't need to
232
+ # apply mask_mod to them - only score_mod
233
+ if HAS_FULL_BLOCKS:
234
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
235
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
236
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
237
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
238
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
239
+ # K and V pointers will be passed directly to forward_inner
240
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
241
+
242
+ acc, l_i, m_i = forward_inner(
243
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
244
+ q, K, V,
245
+ desc_k, desc_v, Q_LEN, KV_LEN,
246
+ acc, l_i, m_i,
247
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
248
+ kv_start,
249
+ kv_indices, kv_num_blocks,
250
+ 0, block_n_end,
251
+ MATMUL_PRECISION,
252
+ stride_kk, stride_kn, stride_vn, stride_vk,
253
+ IS_FULL_BLOCKS=True,
254
+ )
255
+
256
+
257
+ # [Note] Handle fully masked out rows:
258
+ # l_i will be sum(e^(-inf)) == 0.0 for masked-out rows, and m_i will be -inf.
+ # We set l_i to 1.0, which results in lse/out = 0.0 after the log(l_i) + m_i (0.0) step
260
+ l_i = tl.where(l_i == 0.0, 1, l_i)
261
+
262
+ acc = acc / l_i[:, None]
263
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
264
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
265
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
266
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
267
+
268
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
269
+
270
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
271
+ xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq
272
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask)
273
+
274
+ if OUTPUT_LOGSUMEXP:
275
+ off_hz = off_zq * HQ + off_hq
276
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
277
+ lse = m_i + tl.math.log2(l_i)
278
+ if IS_DIVISIBLE:
279
+ tl.store(l_ptrs, lse)
280
+ else:
281
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
282
+
283
+ if OUTPUT_MAX:
284
+ off_hz = off_zq * HQ + off_hq
285
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
286
+ if IS_DIVISIBLE:
287
+ tl.store(max_ptrs, m_i)
288
+ else:
289
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
290
+
291
+
292
+ # Utility triton funcs
293
+ @triton.jit
294
+ def get_offset_for_next_block(
295
+ loop_iter, col_indices, total_blocks,
296
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
297
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
298
+ ):
299
+ if BLOCKS_ARE_CONTIGUOUS:
300
+ return BLOCK
301
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
302
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
303
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
304
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
305
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
306
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
307
+ return offset
308
+
309
+ @triton.jit
310
+ def get_bounded_indices(indices, max_len=None):
311
+ return indices % max_len if max_len is not None else indices
312
+
313
+ @triton.jit
314
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
315
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
316
+ return tl.load(block_ptr)
317
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
318
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
319
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
320
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
321
+ else:
322
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
323
+
324
+ @triton.jit
325
+ def load_checked_2d(
326
+ ptr,
327
+ offs_m,
328
+ offs_n,
329
+ stride_m,
330
+ stride_n,
331
+ IS_DIVISIBLE_M: tl.constexpr,
332
+ IS_DIVISIBLE_N: tl.constexpr,
333
+ M_LEN: tl.constexpr,
334
+ N_LEN: tl.constexpr,
335
+ ):
336
+ # Calculate final pointer if strides are provided
337
+ if stride_m is not None and stride_n is not None:
338
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
339
+
340
+ # Handle all masking cases
341
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
342
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
343
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
344
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
345
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
346
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
347
+ else: # Both divisible
348
+ return tl.load(ptr)
349
+
350
+
351
+ # Common Imports
352
+ @triton.jit
353
+ def forward_block_mn(
354
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
355
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
356
+ # accumulated values
357
+ acc, l_i, m_i,
358
+ # Offsets
359
+ off_z, off_h, offs_m, offs_n,
360
+ # Offsets needed for TMA loads
361
+ kv_start,
362
+ kv_offset,
363
+ MATMUL_PRECISION, RCP_LN2,
364
+ # Strides for K and V
365
+ stride_kk, stride_kn, stride_vn, stride_vk,
366
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
367
+
368
+ ):
369
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
370
+ PRESCALE_QK : tl.constexpr = False
371
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
372
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
373
+ WRITE_DQ : tl.constexpr = True
374
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
375
+ OUTPUT_MAX : tl.constexpr = False
376
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
377
+ IS_DIVISIBLE : tl.constexpr = False
378
+ SM_SCALE : tl.constexpr = 0.08838834764831843
379
+ GQA_SHARED_HEADS : tl.constexpr = 4
380
+ HAS_FULL_BLOCKS : tl.constexpr = True
381
+ QK_HEAD_DIM : tl.constexpr = 128
382
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
383
+ V_HEAD_DIM : tl.constexpr = 128
384
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
385
+ SAFE_HEAD_DIM : tl.constexpr = True
386
+ USE_TMA : tl.constexpr = False
387
+ BLOCK_M : tl.constexpr = 128
388
+ BLOCK_N : tl.constexpr = 64
389
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
390
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
391
+ INDEX_DTYPE : tl.constexpr = tl.int32
392
+
393
+
394
+ # -- load k --
395
+ # NB: reversed order since K is transposed
396
+ kv_base_offset = kv_start + kv_offset
397
+
398
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
399
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
400
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
401
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
402
+
403
+ k = tl.trans(k)
404
+ # -- compute qk ---
405
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
406
+ if not PRESCALE_QK:
407
+ qk *= SM_SCALE
408
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
409
+ # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
410
+ # which is larger than the actual number of elements. To avoid out-of-bounds memory accesses,
411
+ # we need to mask out the elements that are out of Q_LEN & KV_LEN.
412
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
413
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
414
+
415
+ tmp0 = (qk)
416
+ post_mod_scores = tmp0
417
+
418
+
419
+ if CHECK_BLOCK_BOUNDARY:
420
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
421
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
422
+
423
+ if not IS_FULL_BLOCKS:
424
+ tmp1 = tl.full([1], False, tl.int1)
425
+ tmp2 = (m)
426
+ tmp3 = (n)
427
+ tmp4 = tmp2 >= tmp3
428
+ tmp5 = tmp3.to(tl.int64)
429
+ tmp6 = (off_z)
430
+ tmp7 = tl.load(in_ptr9 + tmp6)
431
+ tmp8 = tmp5 < tmp7
432
+ tmp9 = tmp2.to(tl.int64)
433
+ tmp10 = tmp9 < tmp7
434
+ tmp11 = tmp8 & tmp10
435
+ tmp12 = tmp4 & tmp11
436
+ tmp13 = tmp1 | tmp12
437
+ tmp14 = tl.full([1], 2048, tl.int32)
438
+ tmp15 = tmp3 >= tmp14
439
+ tmp16 = (tmp3 % tmp14)
440
+ tmp17 = tl.full([1], 0, tl.int32)
441
+ tmp18 = tmp16 != tmp17
442
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
443
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
444
+ tmp21 = tmp19 != tmp20
445
+ tmp22 = tmp18 & tmp21
446
+ tmp23 = tmp16 + tmp14
447
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
448
+ tmp25 = tmp24.to(tl.int64)
449
+ tmp26 = tmp25 < tmp7
450
+ tmp27 = tmp15 & tmp26
451
+ tmp28 = tmp3 - tmp2
452
+ tmp29 = (tmp28 % tmp14)
453
+ tmp30 = tmp29 != tmp17
454
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
455
+ tmp32 = tmp31 != tmp20
456
+ tmp33 = tmp30 & tmp32
457
+ tmp34 = tmp29 + tmp14
458
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
459
+ tmp36 = tmp35 == tmp17
460
+ tmp37 = tmp27 & tmp36
461
+ tmp38 = tmp13 | tmp37
462
+ mask_mod_output = tmp38
463
+
464
+
465
+ if CHECK_BLOCK_BOUNDARY:
466
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
467
+ # apply mask for partially unmasked blocks
468
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
469
+
470
+ if not PRESCALE_QK:
471
+ post_mod_scores *= RCP_LN2
472
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
473
+
474
+ # -- compute scaling constant ---
475
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
476
+ if not ROWS_GUARANTEED_SAFE:
477
+ masked_out_rows = (m_ij == float("-inf"))
478
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
479
+ else:
480
+ m_ij_masked = m_ij
481
+
482
+ alpha = tl.math.exp2(m_i - m_ij_masked)
483
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
484
+
485
+ # NB: l_i update is pulled up here since it's a bit faster
486
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
487
+ # m_ij
488
+ l_i = l_i * alpha + tl.sum(p, 1)
489
+ # -- scale and update acc --
490
+ acc = acc * alpha[:, None]
491
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
492
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
493
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
494
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
495
+
496
+ # -- update m_i
497
+ m_i = m_ij
498
+
499
+ return acc, l_i, m_i
500
+
501
+ @triton.jit
502
+ def forward_inner(
503
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
504
+ q, K, V,
505
+ desc_k, desc_v, Q_LEN, KV_LEN,
506
+ # accumulated values
507
+ acc, l_i, m_i,
508
+ # Offsets used as inputs to score_mod & mask_mod
509
+ # of size [BLOCK_M, BLOCK_N] or scalar.
510
+ off_z, off_h, offs_m, offs_n,
511
+ # Offsets needed for TMA loads
512
+ kv_start,
513
+ # blocksparse data
514
+ kv_indices, kv_num_blocks,
515
+ # start kv and end kv block
516
+ block_n_start, block_n_end,
517
+ MATMUL_PRECISION,
518
+ # Strides for K and V
519
+ stride_kk, stride_kn, stride_vn, stride_vk,
520
+ IS_FULL_BLOCKS,
521
+ ):
522
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
523
+ PRESCALE_QK : tl.constexpr = False
524
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
525
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
526
+ WRITE_DQ : tl.constexpr = True
527
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
528
+ OUTPUT_MAX : tl.constexpr = False
529
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
530
+ IS_DIVISIBLE : tl.constexpr = False
531
+ SM_SCALE : tl.constexpr = 0.08838834764831843
532
+ GQA_SHARED_HEADS : tl.constexpr = 4
533
+ HAS_FULL_BLOCKS : tl.constexpr = True
534
+ QK_HEAD_DIM : tl.constexpr = 128
535
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
536
+ V_HEAD_DIM : tl.constexpr = 128
537
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
538
+ SAFE_HEAD_DIM : tl.constexpr = True
539
+ USE_TMA : tl.constexpr = False
540
+ BLOCK_M : tl.constexpr = 128
541
+ BLOCK_N : tl.constexpr = 64
542
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
543
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
544
+ INDEX_DTYPE : tl.constexpr = tl.int32
545
+
546
+
547
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
548
+ RCP_LN2: tl.constexpr = 1.44269504
549
+
550
+ if PRESCALE_QK:
551
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
552
+
553
+ kv_offset = 0
554
+
555
+ # loop over k, v and update accumulator until block_n_end
556
+ for start_n in range(block_n_start, block_n_end):
557
+ # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
558
+ if IS_DIVISIBLE:
559
+ acc, l_i, m_i = forward_block_mn(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
561
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
562
+ # accumulated values
563
+ acc, l_i, m_i,
564
+ # Offsets
565
+ off_z, off_h, offs_m, offs_n,
566
+ # Offsets needed for TMA loads
567
+ kv_start,
568
+ kv_offset,
569
+ MATMUL_PRECISION, RCP_LN2,
570
+ # Strides for K and V
571
+ stride_kk, stride_kn, stride_vn, stride_vk,
572
+ IS_FULL_BLOCKS,
573
+ )
574
+ else:
575
+ # Benchmarks show that even when we apply mod & mask to every block for a non-divisible seqlen,
+ # it's on par with or slightly faster than applying them only to the last block in fwd.
+ # However, we choose a different strategy for bwd, where we only apply mod & mask
+ # to the last block, because doing so is a lot faster.
579
+ acc, l_i, m_i = forward_block_mn(
580
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1,
581
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
582
+ # accumulated values
583
+ acc, l_i, m_i,
584
+ # Offsets
585
+ off_z, off_h, offs_m, offs_n,
586
+ # Offsets needed for TMA loads
587
+ kv_start,
588
+ kv_offset,
589
+ MATMUL_PRECISION, RCP_LN2,
590
+ # Strides for K and V
591
+ stride_kk, stride_kn, stride_vn, stride_vk,
592
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
593
+ )
594
+
595
+
596
+
597
+ offset = get_offset_for_next_block(
598
+ start_n, kv_indices, kv_num_blocks,
599
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
600
+ )
601
+
602
+ offs_n = offs_n + offset
603
+ kv_offset += offset
604
+
605
+
606
+ return acc, l_i, m_i
607
+ ''', device_str='cuda')
608
+
609
+
610
+ async_compile.wait(globals())
611
+ del async_compile
612
+
613
+ class Runner:
614
+ def __init__(self, partitions):
615
+ self.partitions = partitions
616
+
617
+ def recursively_apply_fns(self, fns):
618
+ new_callables = []
619
+ for fn, c in zip(fns, self.partitions):
620
+ new_callables.append(fn(c))
621
+ self.partitions = new_callables
622
+
623
+ def call(self, args):
624
+ primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21 = args
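+ # primals_2/4/6/8/12/14/16/18/20 are dynamic-shape ints (bound to s0, s43, s72, ... below);
+ # the remaining primals are CUDA tensors validated by the assert_size_stride calls.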
625
+ args.clear()
626
+ s0 = primals_2
627
+ s43 = primals_4
628
+ s72 = primals_6
629
+ s71 = primals_8
630
+ s4 = primals_12
631
+ s56 = primals_14
632
+ s84 = primals_16
633
+ s99 = primals_18
634
+ s6 = primals_20
635
+ assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1))
636
+ assert_size_stride(primals_3, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1))
637
+ assert_size_stride(primals_5, (2, 8, s0, 128), (1024*s0, 128*s0, 128, 1))
638
+ assert_size_stride(primals_7, (2, 1, 16, s72), (16*s72, 16*s72, s72, 1))
639
+ assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1))
640
+ assert_size_stride(primals_10, (2, ), (1, ))
641
+ assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1))
642
+ assert_size_stride(primals_13, (2, 1, 16, s4), (16*s4, 16*s4, s4, 1))
643
+ assert_size_stride(primals_15, (2, 1, s56), (s56, s56, 1))
644
+ assert_size_stride(primals_17, (2, 1, s84, 16), (16*s84, 16*s84, 16, 1))
645
+ assert_size_stride(primals_19, (2, 1, s99), (s99, s99, 1))
646
+ assert_size_stride(primals_21, (2, 1, s6, 16), (16*s6, 16*s6, 16, 1))
647
+ with torch.cuda._DeviceGuard(2):
648
+ torch.cuda.set_device(2)
649
+ buf0 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32)
650
+ buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32)
651
+ buf2 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16)
652
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
653
+ stream2 = get_raw_stream(2)
654
+ triton_tem_fused_0.run(primals_1, primals_3, primals_5, buf0, buf1, primals_9, primals_7, primals_11, primals_13, primals_10, buf2, s0, s72, 16, 2, 32, stream=stream2)
655
+ del buf1
656
+ return (buf2, primals_1, primals_3, primals_5, primals_7, primals_9, primals_10, primals_11, primals_13, primals_15, primals_17, primals_19, primals_21, buf2, buf0, s0, s72, s4, s56, s84, s99, s6, )
657
+
658
+ runner = Runner(partitions=[])
659
+ call = runner.call
660
+ recursively_apply_fns = runner.recursively_apply_fns
661
+
662
+
663
+ def benchmark_compiled_module(times=10, repeat=10):
664
+ from torch._dynamo.testing import rand_strided
665
+ from torch._inductor.utils import print_performance
666
+ primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:2', dtype=torch.bfloat16)
667
+ primals_2 = 4096
668
+ primals_3 = rand_strided((2, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:2', dtype=torch.bfloat16)
669
+ primals_4 = 4096
670
+ primals_5 = rand_strided((2, 8, 4096, 128), (4194304, 524288, 128, 1), device='cuda:2', dtype=torch.bfloat16)
671
+ primals_6 = 32
672
+ primals_7 = rand_strided((2, 1, 16, 32), (512, 512, 32, 1), device='cuda:2', dtype=torch.int32)
673
+ primals_8 = 4096
674
+ primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:2', dtype=torch.int32)
675
+ primals_10 = rand_strided((2, ), (1, ), device='cuda:2', dtype=torch.int64)
676
+ primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:2', dtype=torch.int32)
677
+ primals_12 = 32
678
+ primals_13 = rand_strided((2, 1, 16, 32), (512, 512, 32, 1), device='cuda:2', dtype=torch.int32)
679
+ primals_14 = 32
680
+ primals_15 = rand_strided((2, 1, 32), (32, 32, 1), device='cuda:2', dtype=torch.int32)
681
+ primals_16 = 32
682
+ primals_17 = rand_strided((2, 1, 32, 16), (512, 512, 16, 1), device='cuda:2', dtype=torch.int32)
683
+ primals_18 = 32
684
+ primals_19 = rand_strided((2, 1, 32), (32, 32, 1), device='cuda:2', dtype=torch.int32)
685
+ primals_20 = 32
686
+ primals_21 = rand_strided((2, 1, 32, 16), (512, 512, 16, 1), device='cuda:2', dtype=torch.int32)
687
+ fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21])
688
+ return print_performance(fn, times=times, repeat=repeat)
689
+
690
+
691
+ if __name__ == "__main__":
692
+ from torch._inductor.wrapper_benchmark import compiled_module_main
693
+ compiled_module_main('None', benchmark_compiled_module)
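+ # Running this cache file directly benchmarks the compiled graph with the rand_strided
+ # inputs built in benchmark_compiled_module above.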
SpecForge-ext/cache/compiled_kernels/as/casmcbz6icqn6mp2r7jahugidys5xwty64z2p3tfw4s7vlsj2oz2.py ADDED
@@ -0,0 +1,66 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 67108864},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr):
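+ # Each program covers XBLOCK contiguous output elements; x0 indexes the last dim (size ks0),
+ # and x1 selects an index from in_ptr1 used to gather ks0-wide rows of in_ptr2/in_ptr3.
+ # The branches on x0 vs. ks0 // 2 negate and swap the two halves, consistent with the
+ # backward of a rotate-half (RoPE-style) multiply, inferred here from the fused-op name.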
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x0 = (xindex % ks0)
23
+ x3 = xindex
24
+ x1 = ((xindex // ks0) % ks1)
25
+ tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32)
26
+ tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last')
27
+ tmp0 = x0
28
+ tmp1 = ks0 // 2
29
+ tmp2 = tmp0 >= tmp1
30
+ tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
31
+ tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0)
32
+ tmp5 = tl.broadcast_to(ks2, [XBLOCK])
33
+ tmp6 = tmp4 + tmp5
34
+ tmp7 = tmp4 < 0
35
+ tmp8 = tl.where(tmp7, tmp6, tmp4)
36
+ tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2")
37
+ tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
38
+ tmp11 = tmp3 * tmp10
39
+ tmp12 = -tmp11
40
+ tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
41
+ tmp14 = tl.where(tmp2, tmp12, tmp13)
42
+ tmp15 = 0.0
43
+ tmp16 = tl.where(tmp2, tmp14, tmp15)
44
+ tmp17 = tmp0 < tmp1
45
+ tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
46
+ tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0)
47
+ tmp20 = tl.broadcast_to(ks2, [XBLOCK])
48
+ tmp21 = tmp19 + tmp20
49
+ tmp22 = tmp19 < 0
50
+ tmp23 = tl.where(tmp22, tmp21, tmp19)
51
+ tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2")
52
+ tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
53
+ tmp26 = tmp18 * tmp25
54
+ tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype)
55
+ tmp28 = tl.where(tmp17, tmp26, tmp27)
56
+ tmp29 = tl.where(tmp17, tmp28, tmp15)
57
+ tmp30 = tmp16 + tmp29
58
+ tmp33 = ks3
59
+ tmp34 = tmp32 + tmp33
60
+ tmp35 = tmp32 < 0
61
+ tmp36 = tl.where(tmp35, tmp34, tmp32)
62
+ tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3")
63
+ tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32)
64
+ tmp39 = tmp31 * tmp38
65
+ tmp40 = tmp30 + tmp39
66
+ tl.store(out_ptr0 + (x3), tmp40, xmask)
SpecForge-ext/cache/compiled_kernels/c2/363ecfeae02cf0bc03b4070f8b6a6ac6bcf543c1a19d1c4a53122d5722a2b3dd.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "b6ac5ef64fddcad8fc8d2c05fa12424871fd9baa5a4158ff38ecebbafb55a4b1", "found_by_coordesc": false, "time_taken_ms": 40, "triton_cache_hash": "MMGM2ESHRXPRFAROBBDYKTZUJ2JVVKU2TB5DVA3EL4OF2SOELPMQ"}
SpecForge-ext/cache/compiled_kernels/c2/cc2qlkbbemfommyywsdbow3sqg7jqf5x5tfkbqjzo2qy6lt36yjr.py ADDED
@@ -0,0 +1,86 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.persistent_reduction(
11
+ size_hints={'x': 128, 'r0_': 16},
12
+ reduction_hint=ReductionHint.DEFAULT,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr):
19
+ xnumel = 128
20
+ r0_numel = 16
21
+ R0_BLOCK: tl.constexpr = 16
22
+ rnumel = r0_numel
23
+ RBLOCK: tl.constexpr = R0_BLOCK
24
+ xoffset = tl.program_id(0) * XBLOCK
25
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
26
+ xmask = xindex < xnumel
27
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
28
+ r0_offset = 0
29
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
30
+ roffset = r0_offset
31
+ rindex = r0_index
32
+ r0_1 = r0_index
33
+ x0 = xindex
34
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 16*x0), xmask, other=0.0)
35
+ tmp1 = tl.full([1, 1], 0, tl.int64)
36
+ tmp2 = tmp0 > tmp1
37
+ tmp3 = tl.full([1, 1], 16384, tl.int64)
38
+ tmp4 = tmp0 < tmp3
39
+ tmp5 = tmp2 & tmp4
40
+ tmp6 = tmp5.to(tl.int8)
41
+ tmp7 = tmp6.to(tl.int32)
42
+ tmp8 = r0_1
43
+ tmp9 = tmp8.to(tl.int16)
44
+ tmp10 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK])
45
+ tmp11 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
46
+ tmp12, tmp13, = triton_helpers.sort_with_index(tmp10, tmp11, None, 1, stable=True, descending=True)
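+ # The stable descending sort of the 0/1 flags carries the original positions (tmp11)
+ # along, packing the indices of in-range entries at the front of each row.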
47
+ tmp14 = tmp0 == tmp3
48
+ tmp15 = tmp14.to(tl.int8)
49
+ tmp16 = tmp15.to(tl.int32)
50
+ tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK])
51
+ tmp18, tmp19, = triton_helpers.sort_with_index(tmp17, tmp11, None, 1, stable=True, descending=True)
52
+ tmp20 = tmp7.to(tl.int64)
53
+ tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK])
54
+ tmp23 = tl.where(xmask, tmp21, 0)
55
+ tmp24 = tl.sum(tmp23, 1)[:, None].to(tl.int64)
56
+ tmp25 = tmp16.to(tl.int64)
57
+ tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK])
58
+ tmp28 = tl.where(xmask, tmp26, 0)
59
+ tmp29 = tl.sum(tmp28, 1)[:, None].to(tl.int64)
60
+ tmp30 = tmp24.to(tl.int32)
61
+ tmp31 = tmp29.to(tl.int32)
62
+ tmp32 = tmp13.to(tl.int64)
63
+ tmp33 = tmp32.to(tl.int32)
64
+ tmp34 = tmp8 < tmp30
65
+ tmp35 = tl.full([1, 1], 16, tl.int32)
66
+ tmp36 = tl.where(tmp34, tmp33, tmp35)
67
+ tmp37 = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32)
68
+ tmp38 = tmp36 + tmp37
69
+ tmp39 = tmp36 < 0
70
+ tmp40 = tl.where(tmp39, tmp38, tmp36)
71
+ tl.device_assert(((0 <= tmp40) & (tmp40 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp40 < 17")
72
+ tmp42 = tl.full([1, 1], 1, tl.int32)
73
+ tmp43 = tmp19.to(tl.int64)
74
+ tmp44 = tmp43.to(tl.int32)
75
+ tmp45 = tmp8 < tmp31
76
+ tmp46 = tl.where(tmp45, tmp44, tmp35)
77
+ tmp47 = tmp46 + tmp37
78
+ tmp48 = tmp46 < 0
79
+ tmp49 = tl.where(tmp48, tmp47, tmp46)
80
+ tl.device_assert(((0 <= tmp49) & (tmp49 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 17")
81
+ tl.store(out_ptr4 + (x0), tmp30, xmask)
82
+ tl.store(out_ptr5 + (x0), tmp31, xmask)
83
+ tl.store(out_ptr6 + (r0_1 + 16*x0), tmp33, xmask)
84
+ tl.store(out_ptr7 + (tl.broadcast_to(tmp40 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask)
85
+ tl.store(out_ptr8 + (r0_1 + 16*x0), tmp44, xmask)
86
+ tl.store(out_ptr9 + (tl.broadcast_to(tmp49 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask)
SpecForge-ext/cache/compiled_kernels/dm/cdma2uevipbm2dd462ztkubtq5uanau5l3oglcw7lhpt4uovlqya.py ADDED
@@ -0,0 +1,835 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = True
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ BLOCK_M1 : tl.constexpr = 64
36
+ BLOCK_N1 : tl.constexpr = 128
37
+ BLOCK_M2 : tl.constexpr = 128
38
+ BLOCK_N2 : tl.constexpr = 64
39
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
40
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
41
+ INDEX_DTYPE : tl.constexpr = tl.int32
42
+ Q = arg_Q
43
+ K = arg_K
44
+ V = arg_V
45
+ LSE = arg_LSE
46
+ DELTA = arg_DELTA
47
+ DO = arg_DO
48
+ DQ = arg_DQ
49
+ DV = arg_DV
50
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
51
+ KV_IDX = arg_KV_IDX
52
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
53
+ Q_IDX = arg_Q_IDX
54
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
55
+ FULL_KV_IDX = arg_FULL_KV_IDX
56
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
57
+ FULL_Q_IDX = arg_FULL_Q_IDX
58
+
59
+ # Sub notation for this kernel:
60
+ #
61
+ # Q: Query, K: Key, V: Value
62
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
63
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
64
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
65
+ # DK: Derivative of Key, which is written to via the store_output call due to some limitations with
66
+ # inductor codegen
67
+ # M: Number of queries, N: Number of keys/values
68
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
69
+ # V_HEAD_DIM: The dimension of the value embeddings
70
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
71
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
72
+ # (Modifiable) Performance tuning options
73
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
74
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
75
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
76
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
77
+ #
78
+ # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
79
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
80
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
81
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each kv block.
82
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each kv block.
83
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
84
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
85
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each kv block.
86
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each kv block.
87
+
88
+ # Below are kernel options that can be applied for certain score_mods,
89
+ # or involve a numerics vs. perf tradeoff
90
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and the change of base. This has
91
+ # about 20% more numerical error, but is slightly faster.
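+ # For reference, the math the dot products below implement (flash-attention-style backward):
+ #   S  = Q @ K^T * SM_SCALE
+ #   P  = exp2(S * RCP_LN2 - LSE)        (LSE as loaded here is in base-2 units)
+ #   dV = P^T @ DO
+ #   dP = DO @ V^T
+ #   dS = P * (dP - DELTA),  DELTA = sum(OUT * DO, axis=-1)
+ #   dQ = (dS @ K) * SM_SCALE
+ #   dK = (dS^T @ Q) * SM_SCALE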
92
+
93
+ # Define strides of inputs
94
+ stride_qz, stride_qh, stride_qm, stride_qd = 8388608, 128, 4096, 1
95
+ stride_kz, stride_kh, stride_kn, stride_kd = 2097152, 262144, 128, 1
96
+ stride_vz, stride_vh, stride_vn, stride_vd = 2097152, 262144, 128, 1
97
+ stride_doz, stride_doh, stride_dom, stride_dod = 8388608, 262144, 128, 1
98
+
99
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 8388608, 128, 4096, 1
100
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 2097152, 262144, 128, 1
101
+
102
+ ZQ = 8
103
+ HQ = 32
104
+ HKV = 8
105
+ Q_LEN = 2048
106
+ ZKV = 8
107
+ KV_LEN = 2048
108
+
109
+ MATMUL_PRECISION = Q.dtype.element_ty
110
+
111
+ pid = tl.program_id(0).to(INDEX_DTYPE)
112
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
113
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
114
+
115
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
116
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
117
+ off_zkv = off_zq % ZKV # kv batch idx
118
+
119
+ SPARSE_Z = 8
120
+ SPARSE_HQ = 1
121
+
122
+ sparse_idx_z = off_zq % SPARSE_Z
123
+
124
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
125
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
126
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
127
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
128
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
129
+
130
+ # offset K, V, DV pointers for batch/kv-head
131
+ K += k_adj
132
+ V += v_adj
133
+ DV += dv_adj
134
+
135
+ RCP_LN2 = 1.44269504
136
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
137
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
138
+
139
+ if pid >= NUM_KV_BLOCKS:
140
+ off_pid = pid - NUM_KV_BLOCKS
141
+ # THIS BLOCK DOES DQ
142
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
143
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
144
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
145
+ start_m2_block = off_pid % NUM_Q_BLOCKS
146
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
147
+ stride_kv_num_blks_h = 16
148
+ stride_kv_idx_h = 256
149
+ stride_kv_idx_m = 16
150
+
151
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
152
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
153
+
154
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
155
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
156
+
157
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
158
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
159
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
160
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
161
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
162
+
163
+ Q2 = Q + q_adj2
164
+ DO2 = DO + do_adj2
165
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
166
+ # if Q is broadcasted)
167
+ DQ2 = DQ + dq_adj2
168
+ LSE2 = LSE + off_chz2
169
+ DELTA2 = DELTA + off_chz2
170
+
171
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
172
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
173
+
174
+ start_m2 = start_m2_block * BLOCK_M2
175
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
176
+
177
+ # load Q and do: they stay in SRAM throughout the inner loop.
178
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
179
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
180
+
181
+ if PRESCALE_QK:
182
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
183
+
184
+ if IS_DIVISIBLE:
185
+ Di = tl.load(DELTA2 + offs_m2)
186
+ lse = tl.load(LSE2 + offs_m2)
187
+ else:
188
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
189
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
190
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
191
+ lse = lse[:, None]
192
+
193
+ # ~~~~~~~~~~~ blocks that may need masking ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
194
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
195
+ kv_indices = KV_IDX + sparse_kv_idx_offset
196
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
197
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
198
+
199
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
200
+ dq = bwd_dq_inner(
201
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
202
+ K, V,
203
+ dq, q, do, Di, lse,
204
+ off_zq, off_hq2, offs_m2, offs_n2,
205
+ stride_kn, stride_kd, stride_vn, stride_vd,
206
+ kv_indices, sparse_kv_num_blocks,
207
+ MATMUL_PRECISION,
208
+ IS_FULL_BLOCKS=False,
209
+ )
210
+
211
+ if HAS_FULL_BLOCKS:
212
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
213
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
214
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
215
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
216
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
217
+
218
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
219
+ dq = bwd_dq_inner(
220
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
221
+ K, V,
222
+ dq, q, do, Di, lse,
223
+ off_zq, off_hq2, offs_m2, offs_n2,
224
+ stride_kn, stride_kd, stride_vn, stride_vd,
225
+ kv_indices, sparse_kv_num_blocks,
226
+ MATMUL_PRECISION,
227
+ IS_FULL_BLOCKS=True,
228
+ )
229
+
230
+ # Write back dQ.
231
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
232
+ dq *= SM_SCALE
233
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
234
+ tl.store(dq_ptrs, dq)
235
+ else:
236
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
237
+ else:
238
+ # THIS BLOCK DOES DK & DV
239
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
240
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
241
+
242
+ pid_mask = pid // SPARSE_KV_MULTIPLE
243
+
244
+ stride_q_num_blks_h = 16
245
+ stride_q_idx_h = 256
246
+ stride_q_idx_n = 16
247
+
248
+
249
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
250
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
251
+
252
+ start_n1 = pid * BLOCK_N1
253
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
254
+
255
+ # load K and V: they stay in SRAM throughout the inner loop.
256
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
257
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
258
+
259
+ if PRESCALE_QK:
260
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
261
+
262
+ for off_g in range(0, GQA_SHARED_HEADS):
263
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
264
+
265
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
266
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
267
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
268
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
269
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
270
+
271
+ Q1 = Q + q_adj1
272
+ DO1 = DO + do_adj1
273
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
274
+ # if Q is broadcasted)
275
+ LSE1 = LSE + off_chz1
276
+ DELTA1 = DELTA + off_chz1
277
+
278
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
279
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
280
+
281
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
282
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
283
+
284
+ # ~~~~~~~~~~~~~~~ blocks that may need masking ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
285
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
286
+ q_indices = Q_IDX + sparse_q_idx_offset
287
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
288
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
289
+
290
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
291
+ dk, dv = bwd_dkdv_inner(
292
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
293
+ Q1, DO1, DELTA1, LSE1,
294
+ dk, dv, k, v,
295
+ off_zq, off_hq1, offs_n1, offs_m1,
296
+ stride_qm, stride_qd, stride_dom, stride_dod,
297
+ q_indices, sparse_q_num_blocks,
298
+ MATMUL_PRECISION,
299
+ IS_FULL_BLOCKS=False,
300
+ )
301
+
302
+
303
+ if HAS_FULL_BLOCKS:
304
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
305
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
306
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
307
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
308
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
309
+
310
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
311
+ dk, dv = bwd_dkdv_inner(
312
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
313
+ Q1, DO1, DELTA1, LSE1,
314
+ dk, dv, k, v,
315
+ off_zq, off_hq1, offs_n1, offs_m1,
316
+ stride_qm, stride_qd, stride_dom, stride_dod,
317
+ q_indices, sparse_q_num_blocks,
318
+ MATMUL_PRECISION,
319
+ IS_FULL_BLOCKS=True,
320
+ )
321
+
322
+ # Write back dV and dK.
323
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
324
+
325
+ index_n = offs_n1[:, None]
326
+ index_k = offs_k[None, :]
327
+ index_v = offs_v[None, :]
328
+
329
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
330
+ tl.store(dv_ptrs, dv)
331
+ else:
332
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
333
+
334
+ dk *= SM_SCALE
335
+
336
+ if SAFE_HEAD_DIM:
337
+ mask = index_n < KV_LEN
338
+ else:
339
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
340
+
341
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, QK_HEAD_DIM]
342
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, QK_HEAD_DIM]
343
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
344
+ xindex = index_k + 128*index_n + 262144*off_hkv + 2097152*off_zq
345
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
346
+
347
+ @triton.jit
348
+ def bwd_dq_inner(
349
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
350
+ K, V, # pointers
351
+ dq, q, do, Di, lse,
352
+ off_z, off_hq, offs_m2, offs_n2,
353
+ stride_kn, stride_kd, stride_vn, stride_vd,
354
+ kv_indices, sparse_kv_num_blocks,
355
+ MATMUL_PRECISION,
356
+ IS_FULL_BLOCKS,
357
+ ):
358
+ PRESCALE_QK : tl.constexpr = False
359
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
360
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
361
+ WRITE_DQ : tl.constexpr = True
362
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
363
+ OUTPUT_MAX : tl.constexpr = False
364
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
365
+ IS_DIVISIBLE : tl.constexpr = True
366
+ SM_SCALE : tl.constexpr = 0.08838834764831843
367
+ GQA_SHARED_HEADS : tl.constexpr = 4
368
+ HAS_FULL_BLOCKS : tl.constexpr = True
369
+ QK_HEAD_DIM : tl.constexpr = 128
370
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
371
+ V_HEAD_DIM : tl.constexpr = 128
372
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
373
+ SAFE_HEAD_DIM : tl.constexpr = True
374
+ BLOCK_M1 : tl.constexpr = 64
375
+ BLOCK_N1 : tl.constexpr = 128
376
+ BLOCK_M2 : tl.constexpr = 128
377
+ BLOCK_N2 : tl.constexpr = 64
378
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
379
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
380
+ INDEX_DTYPE : tl.constexpr = tl.int32
381
+
382
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
383
+ RCP_LN2: tl.constexpr = 1.44269504
384
+ Q_LEN = 2048
385
+ KV_LEN = 2048
386
+
387
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
388
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
389
+
390
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
391
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
392
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
393
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
394
+
395
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
396
+
397
+ for start_n in range(0, hi):
398
+ dq = bwd_dq_block_mn(
399
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
400
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
401
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
402
+ stride_kn, stride_kd, stride_vn, stride_vd,
403
+ kv_indices, sparse_kv_num_blocks,
404
+ MATMUL_PRECISION, RCP_LN2,
405
+ IS_FULL_BLOCKS,
406
+ )
407
+
408
+ # Increment pointers.
409
+ offset = get_offset_for_next_block(
410
+ start_n, kv_indices, sparse_kv_num_blocks,
411
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
412
+ )
413
+
414
+ kT_ptrs += offset * stride_kn
415
+ vT_ptrs += offset * stride_vn
416
+
417
+ offs_n2 += offset
418
+
419
+ return dq
420
+
421
+
422
+ @triton.jit
423
+ def bwd_dq_block_mn(
424
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
425
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
426
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
427
+ stride_kn, stride_kd, stride_vn, stride_vd,
428
+ kv_indices, sparse_kv_num_blocks,
429
+ MATMUL_PRECISION, RCP_LN2,
430
+ IS_FULL_BLOCKS,
431
+ ):
432
+ PRESCALE_QK : tl.constexpr = False
433
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
434
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
435
+ WRITE_DQ : tl.constexpr = True
436
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
437
+ OUTPUT_MAX : tl.constexpr = False
438
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
439
+ IS_DIVISIBLE : tl.constexpr = True
440
+ SM_SCALE : tl.constexpr = 0.08838834764831843
441
+ GQA_SHARED_HEADS : tl.constexpr = 4
442
+ HAS_FULL_BLOCKS : tl.constexpr = True
443
+ QK_HEAD_DIM : tl.constexpr = 128
444
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
445
+ V_HEAD_DIM : tl.constexpr = 128
446
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
447
+ SAFE_HEAD_DIM : tl.constexpr = True
448
+ BLOCK_M1 : tl.constexpr = 64
449
+ BLOCK_N1 : tl.constexpr = 128
450
+ BLOCK_M2 : tl.constexpr = 128
451
+ BLOCK_N2 : tl.constexpr = 64
452
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
453
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
454
+ INDEX_DTYPE : tl.constexpr = tl.int32
455
+
456
+
457
+ # NB reversed order since K is transposed
458
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
459
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
460
+ if not PRESCALE_QK:
461
+ qk *= SM_SCALE
462
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
463
+ pre_mod_scores = qk
464
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
465
+ # The boundary check is done in the outer loop, but since we're iterating across the N dim
466
+ # it's possible here that the M indices read out of bounds for PIDs spanning the Q_LEN boundary
467
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
468
+
469
+ tmp0 = (qk)
470
+ post_mod_scores = tmp0
471
+
472
+
473
+
474
+
475
+ if not IS_DIVISIBLE:
476
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
477
+
478
+ if not IS_FULL_BLOCKS:
479
+ tmp1 = tl.full([1], False, tl.int1)
480
+ tmp2 = (m)
481
+ tmp3 = (n)
482
+ tmp4 = tmp2 >= tmp3
483
+ tmp5 = tmp3.to(tl.int64)
484
+ tmp6 = (off_z)
485
+ tmp7 = tl.load(in_ptr16 + tmp6)
486
+ tmp8 = tmp5 < tmp7
487
+ tmp9 = tmp2.to(tl.int64)
488
+ tmp10 = tmp9 < tmp7
489
+ tmp11 = tmp8 & tmp10
490
+ tmp12 = tmp4 & tmp11
491
+ tmp13 = tmp1 | tmp12
492
+ tmp14 = tl.full([1], 2048, tl.int32)
493
+ tmp15 = tmp3 >= tmp14
494
+ tmp16 = (tmp3 % tmp14)
495
+ tmp17 = tl.full([1], 0, tl.int32)
496
+ tmp18 = tmp16 != tmp17
497
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
498
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
499
+ tmp21 = tmp19 != tmp20
500
+ tmp22 = tmp18 & tmp21
501
+ tmp23 = tmp16 + tmp14
502
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
503
+ tmp25 = tmp24.to(tl.int64)
504
+ tmp26 = tmp25 < tmp7
505
+ tmp27 = tmp15 & tmp26
506
+ tmp28 = tmp3 - tmp2
507
+ tmp29 = (tmp28 % tmp14)
508
+ tmp30 = tmp29 != tmp17
509
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
510
+ tmp32 = tmp31 != tmp20
511
+ tmp33 = tmp30 & tmp32
512
+ tmp34 = tmp29 + tmp14
513
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
514
+ tmp36 = tmp35 == tmp17
515
+ tmp37 = tmp27 & tmp36
516
+ tmp38 = tmp13 | tmp37
517
+ mask_mod_output = tmp38
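+ # Inlined mask_mod, read off the predicates above: allow causal positions (m >= n) where
+ # both m and n lie below the per-batch length in_ptr16[off_z], OR columns past 2048 whose
+ # wrapped index (n % 2048) is in range and aligned with the row ((n - m) % 2048 == 0).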
518
+
519
+
520
+ # apply mask for partial masked block
521
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
522
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
523
+ if not PRESCALE_QK:
524
+ post_mod_scores *= RCP_LN2
525
+ p = tl.math.exp2(post_mod_scores - lse)
526
+ # Compute dP and dS.
527
+ # NB reversed order since V is transposed
528
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
529
+
530
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
531
+ ds = p * (dp - Di[:, None])
532
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
533
+ tmp39 = (ds)
534
+ grad_scores = tmp39
535
+
536
+
537
+ if not IS_DIVISIBLE:
538
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
539
+
540
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
541
+ if WRITE_DQ:
542
+ scatter_mask = (offs_m2[:, None] < Q_LEN) & (offs_n2[None, :] < KV_LEN)
543
+
544
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
545
+ ds = grad_scores
546
+
547
+ if not IS_FULL_BLOCKS:
548
+ # (grads) apply mask for partially unmasked block
549
+ ds = tl.where(mask_mod_output, ds, 0.0)
550
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
551
+ ds = ds.to(MATMUL_PRECISION)
552
+ # Compute dQ.
553
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
554
+
555
+ return dq
556
+
557
+
558
+ @triton.jit
559
+ def bwd_dkdv_inner(
560
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
561
+ Q, DO, DELTA, LSE, # pointers
562
+ dk, dv, k, v,
563
+ off_z, off_hq, offs_n1, offs_m1,
564
+ stride_qm, stride_qd, stride_dom, stride_dod,
565
+ q_indices, sparse_q_num_blocks,
566
+ MATMUL_PRECISION,
567
+ IS_FULL_BLOCKS,
568
+ ):
569
+ PRESCALE_QK : tl.constexpr = False
570
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
571
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
572
+ WRITE_DQ : tl.constexpr = True
573
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
574
+ OUTPUT_MAX : tl.constexpr = False
575
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
576
+ IS_DIVISIBLE : tl.constexpr = True
577
+ SM_SCALE : tl.constexpr = 0.08838834764831843
578
+ GQA_SHARED_HEADS : tl.constexpr = 4
579
+ HAS_FULL_BLOCKS : tl.constexpr = True
580
+ QK_HEAD_DIM : tl.constexpr = 128
581
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
582
+ V_HEAD_DIM : tl.constexpr = 128
583
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
584
+ SAFE_HEAD_DIM : tl.constexpr = True
585
+ BLOCK_M1 : tl.constexpr = 64
586
+ BLOCK_N1 : tl.constexpr = 128
587
+ BLOCK_M2 : tl.constexpr = 128
588
+ BLOCK_N2 : tl.constexpr = 64
589
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
590
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
591
+ INDEX_DTYPE : tl.constexpr = tl.int32
592
+
593
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
594
+ RCP_LN2: tl.constexpr = 1.44269504
595
+ Q_LEN = 2048
596
+ KV_LEN = 2048
597
+
598
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
599
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
600
+
601
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
602
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
603
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
604
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
605
+
606
+ # The minimum is needed to handle the case where we run with a super large
607
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
608
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
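+ # e.g. with Q_LEN=2048 and BLOCK_M1=64 the loop is capped at cdiv(2048, 64) = 32
+ # iterations regardless of what the sparse metadata reports.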
609
+
610
+ for start_m in range(0, hi):
611
+ dk, dv = bwd_dkdv_block_mn(
612
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
613
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
614
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
615
+ stride_qm, stride_qd, stride_dom, stride_dod,
616
+ q_indices, sparse_q_num_blocks,
617
+ MATMUL_PRECISION, RCP_LN2,
618
+ IS_FULL_BLOCKS,
619
+ )
620
+ # Increment pointers.
621
+ offset = get_offset_for_next_block(
622
+ start_m, q_indices, sparse_q_num_blocks,
623
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
624
+ )
625
+
626
+ qT_ptrs += offset * stride_qm
627
+ do_ptrs += offset * stride_dom
628
+ offs_m1 += offset
629
+
630
+ return dk, dv
631
+
632
+
633
+ @triton.jit
634
+ def bwd_dkdv_block_mn(
635
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
636
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
637
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
638
+ stride_qm, stride_qd, stride_dom, stride_dod,
639
+ q_indices, sparse_q_num_blocks,
640
+ MATMUL_PRECISION, RCP_LN2,
641
+ IS_FULL_BLOCKS,
642
+ ):
643
+ PRESCALE_QK : tl.constexpr = False
644
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
645
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
646
+ WRITE_DQ : tl.constexpr = True
647
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
648
+ OUTPUT_MAX : tl.constexpr = False
649
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
650
+ IS_DIVISIBLE : tl.constexpr = True
651
+ SM_SCALE : tl.constexpr = 0.08838834764831843
652
+ GQA_SHARED_HEADS : tl.constexpr = 4
653
+ HAS_FULL_BLOCKS : tl.constexpr = True
654
+ QK_HEAD_DIM : tl.constexpr = 128
655
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
656
+ V_HEAD_DIM : tl.constexpr = 128
657
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
658
+ SAFE_HEAD_DIM : tl.constexpr = True
659
+ BLOCK_M1 : tl.constexpr = 64
660
+ BLOCK_N1 : tl.constexpr = 128
661
+ BLOCK_M2 : tl.constexpr = 128
662
+ BLOCK_N2 : tl.constexpr = 64
663
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
664
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
665
+ INDEX_DTYPE : tl.constexpr = tl.int32
666
+
667
+
668
+ # NB reversed order since Q is transposed
669
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
670
+ # Load LSE before computing qk to reduce pipeline stall.
671
+ if IS_DIVISIBLE:
672
+ lse = tl.load(LSE + offs_m1)
673
+ else:
674
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
675
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
676
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
677
+ if not PRESCALE_QK:
678
+ qkT *= SM_SCALE
679
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
680
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
681
+ # The boundary check is done in the outer loop, but since we're iterating across the M dim
682
+ # it's possible here that the n indices read out of bounds for PIDs spanning the KV_LEN boundary
683
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
684
+
685
+ pre_mod_scores = qkT
686
+ tmp40 = (qkT)
687
+ post_mod_scores = tmp40
688
+
689
+
690
+
691
+ if not IS_DIVISIBLE:
692
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
693
+
694
+ if not IS_FULL_BLOCKS:
695
+ tmp41 = tl.full([1], False, tl.int1)
696
+ tmp42 = (m)
697
+ tmp43 = (n)
698
+ tmp44 = tmp42 >= tmp43
699
+ tmp45 = tmp43.to(tl.int64)
700
+ tmp46 = (off_z)
701
+ tmp47 = tl.load(in_ptr16 + tmp46)
702
+ tmp48 = tmp45 < tmp47
703
+ tmp49 = tmp42.to(tl.int64)
704
+ tmp50 = tmp49 < tmp47
705
+ tmp51 = tmp48 & tmp50
706
+ tmp52 = tmp44 & tmp51
707
+ tmp53 = tmp41 | tmp52
708
+ tmp54 = tl.full([1], 2048, tl.int32)
709
+ tmp55 = tmp43 >= tmp54
710
+ tmp56 = (tmp43 % tmp54)
711
+ tmp57 = tl.full([1], 0, tl.int32)
712
+ tmp58 = tmp56 != tmp57
713
+ tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
714
+ tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
715
+ tmp61 = tmp59 != tmp60
716
+ tmp62 = tmp58 & tmp61
717
+ tmp63 = tmp56 + tmp54
718
+ tmp64 = tl.where(tmp62, tmp63, tmp56)
719
+ tmp65 = tmp64.to(tl.int64)
720
+ tmp66 = tmp65 < tmp47
721
+ tmp67 = tmp55 & tmp66
722
+ tmp68 = tmp43 - tmp42
723
+ tmp69 = (tmp68 % tmp54)
724
+ tmp70 = tmp69 != tmp57
725
+ tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
726
+ tmp72 = tmp71 != tmp60
727
+ tmp73 = tmp70 & tmp72
728
+ tmp74 = tmp69 + tmp54
729
+ tmp75 = tl.where(tmp73, tmp74, tmp69)
730
+ tmp76 = tmp75 == tmp57
731
+ tmp77 = tmp67 & tmp76
732
+ tmp78 = tmp53 | tmp77
733
+ mask_mod_output = tmp78
734
+
735
+ # (grads) apply mask for fully masked block
736
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
737
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
738
+ if not PRESCALE_QK:
739
+ post_mod_scores *= RCP_LN2
740
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
741
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
742
+ # Compute dV.
743
+ ppT = pT
744
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
745
+ if IS_DIVISIBLE:
746
+ Di = tl.load(DELTA + offs_m1)
747
+ else:
748
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
749
+ # Compute dP and dS.
750
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
751
+ dsT = pT * (dpT - Di[None, :])
752
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
753
+ tmp79 = (dsT)
754
+ grad_scores = tmp79
755
+
756
+
757
+
758
+ if not IS_DIVISIBLE:
759
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
760
+
761
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
762
+ if not WRITE_DQ:
763
+ idx_b = off_z
764
+ idx_h = off_hq
765
+ idx_m = m
766
+ idx_n = n
767
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
768
+
769
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
770
+ dsT = grad_scores
771
+ if not IS_FULL_BLOCKS:
772
+ # (grads) apply mask for partially unmasked block
773
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
774
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
775
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
776
+
777
+ return dk, dv
778
+
779
+ # Utility triton funcs
780
+ @triton.jit
781
+ def get_offset_for_next_block(
782
+ loop_iter, col_indices, total_blocks,
783
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
784
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
785
+ ):
786
+ if BLOCKS_ARE_CONTIGUOUS:
787
+ return BLOCK
788
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
789
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
790
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
791
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
792
+ jump_to_block = (next_block - cur_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
793
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
794
+ return offset
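+ # Worked example: SPARSE_BLOCK=128 with BLOCK=64 gives SPARSE_BLOCK_MULTIPLE=2. While
+ # inside a sparse block the offset is just BLOCK (64); on the iteration that finishes a
+ # sparse block ((loop_iter + 1) % 2 == 0) it becomes (next_block - cur_block) * 128 - 64,
+ # landing at the start of the next indexed sparse block.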
795
+
796
+ @triton.jit
797
+ def get_bounded_indices(indices, max_len=None):
798
+ return indices % max_len if max_len is not None else indices
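+ # Wrapping with % keeps speculative loads in bounds; lanes at or past max_len are
+ # discarded later by masking.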
799
+
800
+ @triton.jit
801
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
802
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
803
+ return tl.load(block_ptr)
804
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
805
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
806
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
807
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
808
+ else:
809
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
810
+
811
+ @triton.jit
812
+ def load_checked_2d(
813
+ ptr,
814
+ offs_m,
815
+ offs_n,
816
+ stride_m,
817
+ stride_n,
818
+ IS_DIVISIBLE_M: tl.constexpr,
819
+ IS_DIVISIBLE_N: tl.constexpr,
820
+ M_LEN: tl.constexpr,
821
+ N_LEN: tl.constexpr,
822
+ ):
823
+ # Calculate final pointer if strides are provided
824
+ if stride_m is not None and stride_n is not None:
825
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
826
+
827
+ # Handle all masking cases
828
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
829
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
830
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
831
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
832
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
833
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
834
+ else: # Both divisible
835
+ return tl.load(ptr)
SpecForge-ext/cache/compiled_kernels/dm/cdmv6ytwvbipl4lagbifdkedszdjny3opgqlnricedg4hfpkxbdo.py ADDED
@@ -0,0 +1,47 @@
+
+import triton
+import triton.language as tl
+
+from torch._inductor.runtime import triton_helpers, triton_heuristics
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+triton_helpers.set_driver_to_gpu()
+
+@triton_heuristics.reduction(
+    size_hints={'x': 4096, 'r0_': 32768},
+    reduction_hint=ReductionHint.INNER,
+    filename=__file__,
+    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(1,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
+    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+)
+@triton.jit
+def triton_red_fused_argmax_1(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
+    r0_numel = 32000
+    rnumel = r0_numel
+    RBLOCK: tl.constexpr = R0_BLOCK
+    xoffset = tl.program_id(0) * XBLOCK
+    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+    xmask = xindex < xnumel
+    r0_base = tl.arange(0, R0_BLOCK)[None, :]
+    rbase = r0_base
+    x0 = (xindex % ks0)
+    x1 = xindex // ks0
+    _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
+    _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
+    x3 = xindex
+    for r0_offset in range(0, r0_numel, R0_BLOCK):
+        r0_index = r0_offset + r0_base
+        r0_mask = r0_index < r0_numel
+        roffset = r0_offset
+        rindex = r0_index
+        r0_2 = r0_index
+        tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + ks1*x1), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
+        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
+        _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
+            _tmp2, _tmp2_index, tmp1, rindex
+        )
+        _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2)
+        _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index)
+    tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
+    tmp2 = tmp2_idx[:, None]
+    tl.store(out_ptr0 + (x3), tmp2, xmask)
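This kernel is Inductor's compiled form of an argmax reduction whose inner dimension is hard-coded to 32000 (a common LLM vocabulary size; note the kernel overwrites `r0_numel` with that constant and tracks a running max plus its index). A hedged eager-mode sketch of the operation being compiled, assuming that provenance:

```python
import torch

@torch.compile(dynamic=True)
def greedy_pick(logits):
    # Reduce the last (vocab) axis, tracking the max value and its position;
    # only the i64 index survives, matching out_ptr0 in the kernel above.
    return torch.argmax(logits, dim=-1)

logits = torch.randn(4, 8, 32000, device="cuda")   # shapes illustrative
tokens = greedy_pick(logits)                       # -> i64[4, 8]
```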
SpecForge-ext/cache/compiled_kernels/du/cduexexwzoejgfo3kafnuhcdb2jpdj5mqnwnijlnqydzf2tfuyoh.py ADDED
@@ -0,0 +1,682 @@
+# AOT ID: ['12_inference']
+from ctypes import c_void_p, c_long, c_int
+import torch
+import math
+import random
+import os
+import tempfile
+from math import inf, nan
+from cmath import nanj
+from torch._inductor.hooks import run_intermediate_hooks
+from torch._inductor.utils import maybe_profile
+from torch._inductor.codegen.memory_planning import _align as align
+from torch import device, empty_strided
+from torch._inductor.async_compile import AsyncCompile
+from torch._inductor.select_algorithm import extern_kernels
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
+from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+
+aten = torch.ops.aten
+inductor_ops = torch.ops.inductor
+_quantized = torch.ops._quantized
+assert_size_stride = torch._C._dynamo.guards.assert_size_stride
+assert_alignment = torch._C._dynamo.guards.assert_alignment
+empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
+empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
+empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
+empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
+empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
+reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
+alloc_from_pool = torch.ops.inductor._alloc_from_pool
+async_compile = AsyncCompile()
+empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
+
+
+# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/xv/cxvkzhyhna2alntgjzwfekekacjtshos257zi4b5b75eycps5xaj.py
+# Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros]
+# Source node to ATen node mapping:
+#   dense_mask_2 => full_default_1
+# Graph fragment:
+#   %full_default_1 : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, %floordiv_3, %add_201], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:5, pin_memory: False})
+#   return %index_put
+triton_poi_fused_new_zeros_0 = async_compile.triton('triton_poi_fused_new_zeros_0', '''
+import triton
+import triton.language as tl
+
+from torch._inductor.runtime import triton_helpers, triton_heuristics
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+triton_helpers.set_driver_to_gpu()
+
+@triton_heuristics.pointwise(
+    size_hints={'x': 512},
+    filename=__file__,
+    triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]]}]},
+    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+    min_elem_per_thread=0
+)
+@triton.jit
+def triton_poi_fused_new_zeros_0(out_ptr0, xnumel, XBLOCK : tl.constexpr):
+    xoffset = tl.program_id(0) * XBLOCK
+    xindex = xoffset + tl.arange(0, XBLOCK)[:]
+    xmask = xindex < xnumel
+    x0 = xindex
+    tmp0 = tl.full([1], 0, tl.int32)
+    tl.store(out_ptr0 + (x0), tmp0, xmask)
+''', device_str='cuda')
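Per the graph fragment, this first kernel only materializes the zero-filled block-mask buffer (`aten.full.default`). A hedged eager-mode equivalent with the symbolic sizes written out (`s12` and `s37` are the dynamic query/KV lengths; the concrete values are illustrative):

```python
import torch

s12, s37 = 300, 500                    # illustrative dynamic sequence lengths
q_blocks = (s12 + 127) // 128          # %floordiv_3
kv_blocks = (s37 + 127) // 128
# One spare column (kv_blocks + 1) serves as a scatter sentinel later on.
dense_mask_2 = torch.full((2, 1, q_blocks, kv_blocks + 1), 0,
                          dtype=torch.int32, device="cuda")
```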
+
+
+# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hk/chkcqajju5cxlzspkm5ffg5s3lyxlimjbubz6ljwtrx23yf6fnkc.py
+# Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy]
+# Source node to ATen node mapping:
+#   and_2 => bitwise_and_1
+#   and_3 => bitwise_and_2
+#   and_4 => bitwise_and_3, view_8
+#   b => iota
+#   batched_outputs_2 => view_9
+#   causal_mask => ge_2, view
+#   dense_mask => convert_element_type_2
+#   dense_mask_1 => convert_element_type_5
+#   diagnol_mask => eq_24
+#   full_blocks => eq_45
+#   full_blocks_1 => convert_element_type_1
+#   gt => gt
+#   index => index
+#   index_1 => index_1
+#   index_2 => index_2
+#   lt => lt, view_1
+#   lt_1 => lt_1, view_2
+#   lt_3 => lt_3
+#   m => iota_2
+#   mask_1 => constant_pad_nd
+#   mask_2 => view_10
+#   mask_3 => permute
+#   mask_block_sum => sum_1
+#   n => iota_3
+#   padding_mask => bitwise_and, view_3, view_4
+#   padding_mask_1 => lt_2, view_6
+#   partial_blocks => bitwise_and_4
+#   partial_blocks_1 => convert_element_type
+#   remainder => remainder
+#   remainder_1 => remainder_1
+#   result_1 => bitwise_or, full_default
+#   result_2 => bitwise_or_1
+#   sub => sub_24, view_7
+#   suffix_mask => ge_3
+# Graph fragment:
+#   %arg2_1 : Tensor "i64[2][1]cuda:5" = PlaceHolder[target=arg2_1]
+#   %sum_1 : Tensor "i64[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 2*(((s12 + 127)//128))*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:5" = PlaceHolder[target=sum_1]
+#   %full_default : Tensor "b8[2, 1, 1][1, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:5, pin_memory: False})
+#   %iota_2 : Tensor "i64[s12][1]cuda:5"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (%arg0_1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:5, requires_grad: False})
+#   %view : Tensor "i64[s12, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [%arg0_1, 1]), kwargs = {})
+#   %iota_3 : Tensor "i64[s37][1]cuda:5"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (%arg1_1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:5, requires_grad: False})
+#   %ge_2 : Tensor "b8[s12, s37][Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {})
+#   %iota : Tensor "i64[2][1]cuda:5"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:5, requires_grad: False})
+#   %index : Tensor "i64[2][1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%iota]), kwargs = {})
+#   %view_1 : Tensor "i64[2, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [2, 1]), kwargs = {})
+#   %lt : Tensor "b8[2, s37][Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {})
+#   %view_4 : Tensor "b8[2, 1, s37][Max(1, s37), s37, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [2, 1, %arg1_1]), kwargs = {})
+#   %index_1 : Tensor "i64[2][1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%iota]), kwargs = {})
+#   %view_2 : Tensor "i64[2, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [2, 1]), kwargs = {})
+#   %lt_1 : Tensor "b8[2, s12][Max(1, s12), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {})
+#   %view_3 : Tensor "b8[2, s12, 1][Max(1, s12), 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [2, %arg0_1, 1]), kwargs = {})
+#   %bitwise_and : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {})
+#   %bitwise_and_1 : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_2, %bitwise_and), kwargs = {})
+#   %bitwise_or : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {})
+#   %ge_3 : Tensor "b8[s37][1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, %arg3_1), kwargs = {})
+#   %remainder : Tensor "i64[s37][1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, %arg3_1), kwargs = {})
+#   %index_2 : Tensor "i64[2][1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%iota]), kwargs = {})
+#   %view_6 : Tensor "i64[2, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [2, 1]), kwargs = {})
+#   %lt_2 : Tensor "b8[2, s37][Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {})
+#   %bitwise_and_2 : Tensor "b8[2, s37][Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_3, %lt_2), kwargs = {})
+#   %view_8 : Tensor "b8[2, 1, s37][Max(1, s37), s37, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [2, 1, %arg1_1]), kwargs = {})
+#   %view_7 : Tensor "i64[s12, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [%arg0_1, 1]), kwargs = {})
+#   %sub_24 : Tensor "i64[s12, s37][Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {})
+#   %remainder_1 : Tensor "i64[s12, s37][Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub_24, %arg3_1), kwargs = {})
+#   %eq_24 : Tensor "b8[s12, s37][Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {})
+#   %bitwise_and_3 : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq_24), kwargs = {})
+#   %bitwise_or_1 : Tensor "b8[2, s12, s37][Max(1, s12)*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {})
+#   %view_9 : Tensor "b8[2, 1, s12, s37][Max(1, s12)*Max(1, s37), s12*Max(1, s37), Max(1, s37), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [2, 1, %arg0_1, %arg1_1]), kwargs = {})
+#   %constant_pad_nd : Tensor "b8[2, 1, 128*(((s12 + 127)//128)), 128*(((s37 + 127)//128))][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.constant_pad_nd.default](args = (%expand, [0, %sub_42, 0, %sub_44], 0.0), kwargs = {})
+#   %view_10 : Tensor "b8[2, 1, ((s12 + 127)//128), 128, ((s37 + 127)//128), 128][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s37 + 127)//128))), 128, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%constant_pad_nd, [2, 1, %floordiv_3, 128, %floordiv_2, 128]), kwargs = {})
+#   %permute : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128), 128, 128][Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), Max(1, 128*(((s12 + 127)//128)))*Max(1, 128*(((s37 + 127)//128))), 128*Max(1, 128*(((s37 + 127)//128))), 128, Max(1, 128*(((s37 + 127)//128))), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {})
+#   %sum_1 : Tensor "i64[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {})
+#   %gt : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
+#   %lt_3 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {})
+#   %bitwise_and_4 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {})
+#   %convert_element_type : Tensor "i8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {})
+#   %convert_element_type_2 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {})
+#   %eq_45 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {})
+#   %convert_element_type_1 : Tensor "i8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_45, torch.int8), kwargs = {})
+#   %convert_element_type_5 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {})
+#   return %sum_1,%convert_element_type_2,%convert_element_type_5
+triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1 = async_compile.triton('triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', '''
+import triton
+import triton.language as tl
+
+from torch._inductor.runtime import triton_helpers, triton_heuristics
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+triton_helpers.set_driver_to_gpu()
+
+@triton_heuristics.reduction(
+    size_hints={'x': 512, 'r0_': 16384},
+    reduction_hint=ReductionHint.INNER,
+    filename=__file__,
+    triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr1': '*i32', 'out_ptr2': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'ks5': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
+    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+)
+@triton.jit
+def triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1(in_ptr0, out_ptr1, out_ptr2, ks0, ks1, ks2, ks3, ks4, ks5, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
+    r0_numel = 16384
+    rnumel = r0_numel
+    RBLOCK: tl.constexpr = R0_BLOCK
+    xoffset = tl.program_id(0) * XBLOCK
+    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+    xmask = xindex < xnumel
+    r0_base = tl.arange(0, R0_BLOCK)[None, :]
+    rbase = r0_base
+    x1 = ((xindex // ks0) % ks1)
+    x0 = (xindex % ks0)
+    x2 = xindex // ks4
+    _tmp46 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
+    x5 = xindex
+    for r0_offset in range(0, r0_numel, R0_BLOCK):
+        r0_index = r0_offset + r0_base
+        r0_mask = r0_index < r0_numel
+        roffset = r0_offset
+        rindex = r0_index
+        r0_4 = r0_index // 128
+        r0_3 = (r0_index % 128)
+        tmp0 = r0_4 + 128*x1
+        tmp1 = ks2
+        tmp2 = tmp0 < tmp1
+        tmp3 = r0_3 + 128*x0
+        tmp4 = ks3
+        tmp5 = tmp3 < tmp4
+        tmp6 = tmp2 & tmp5
+        tmp7 = r0_4 + 128*x1
+        tmp8 = r0_3 + 128*x0
+        tmp9 = tmp7 >= tmp8
+        tmp10 = tl.load(in_ptr0 + (tl.broadcast_to(x2, [XBLOCK, R0_BLOCK])), r0_mask & tmp6 & xmask, eviction_policy='evict_last', other=0.0)
+        tmp11 = tmp8 < tmp10
+        tmp12 = tmp7 < tmp10
+        tmp13 = tmp11 & tmp12
+        tmp14 = tmp9 & tmp13
+        tmp15 = tl.full([1, 1], False, tl.int1)
+        tmp16 = tmp15 | tmp14
+        tmp17 = tl.broadcast_to(ks5, [XBLOCK, R0_BLOCK])
+        tmp18 = tmp8 >= tmp17
+        tmp19 = (tmp8 % tmp17)
+        tmp20 = tl.full([1, 1], 0, tl.int32)
+        tmp21 = tmp19 != tmp20
+        tmp22 = (libdevice.signbit(tmp19) != 0) if (tmp19).dtype is tl.float32 else tmp19 < 0
+        tmp23 = (libdevice.signbit(tmp17) != 0) if (tmp17).dtype is tl.float32 else tmp17 < 0
+        tmp24 = tmp22 != tmp23
+        tmp25 = tmp21 & tmp24
+        tmp26 = tmp19 + tmp17
+        tmp27 = tl.where(tmp25, tmp26, tmp19)
+        tmp28 = tmp27 < tmp10
+        tmp29 = tmp18 & tmp28
+        tmp30 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0
+        tmp31 = (tmp30 % tmp17)
+        tmp32 = tmp31 != tmp20
+        tmp33 = (libdevice.signbit(tmp31) != 0) if (tmp31).dtype is tl.float32 else tmp31 < 0
+        tmp34 = tmp33 != tmp23
+        tmp35 = tmp32 & tmp34
+        tmp36 = tmp31 + tmp17
+        tmp37 = tl.where(tmp35, tmp36, tmp31)
+        tmp38 = tl.full([1, 1], 0, tl.int64)
+        tmp39 = tmp37 == tmp38
+        tmp40 = tmp29 & tmp39
+        tmp41 = tmp16 | tmp40
+        tmp42 = tl.full(tmp41.shape, False, tmp41.dtype)
+        tmp43 = tl.where(tmp6, tmp41, tmp42)
+        tmp44 = tmp43.to(tl.int64)
+        tmp45 = tl.broadcast_to(tmp44, [XBLOCK, R0_BLOCK])
+        tmp47 = _tmp46 + tmp45
+        _tmp46 = tl.where(r0_mask & xmask, tmp47, _tmp46)
+    tmp46 = tl.sum(_tmp46, 1)[:, None]
+    tmp48 = tl.full([1, 1], 0, tl.int64)
+    tmp49 = tmp46 > tmp48
+    tmp50 = tl.full([1, 1], 16384, tl.int64)
+    tmp51 = tmp46 < tmp50
+    tmp52 = tmp49 & tmp51
+    tmp53 = tmp52.to(tl.int8)
+    tmp54 = tmp53.to(tl.int32)
+    tmp55 = tmp46 == tmp50
+    tmp56 = tmp55.to(tl.int8)
+    tmp57 = tmp56.to(tl.int32)
+    tl.store(out_ptr1 + (x5), tmp54, xmask)
+    tl.store(out_ptr2 + (x5), tmp57, xmask)
+''', device_str='cuda')
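The fused reduction above builds a causal/padding/suffix attention mask on the fly, sums each 128x128 tile of it, and classifies every tile: `partial_blocks` where 0 < sum < 16384 and `full_blocks` where sum == 16384 (16384 = 128 * 128). Below is a hedged eager reconstruction of the block-classification half, following the graph-fragment comments; the mask construction itself is omitted and a random stand-in is used instead:

```python
import torch
import torch.nn.functional as F

def classify_blocks(mask: torch.Tensor, BLOCK: int = 128):
    B, H, Q, KV = mask.shape                              # b8[2, 1, s12, s37]
    m = F.pad(mask, (0, -KV % BLOCK, 0, -Q % BLOCK))      # constant_pad_nd -> mask_1
    qb, kb = m.shape[2] // BLOCK, m.shape[3] // BLOCK
    sums = (m.view(B, H, qb, BLOCK, kb, BLOCK)            # mask_2
             .permute(0, 1, 2, 4, 3, 5)                   # mask_3
             .sum(dim=(-2, -1)))                          # mask_block_sum, i64
    partial = ((sums > 0) & (sums < BLOCK * BLOCK)).to(torch.int32)   # dense_mask
    full = (sums == BLOCK * BLOCK).to(torch.int32)                    # dense_mask_1
    return sums, partial, full

mask = torch.rand(2, 1, 300, 500, device="cuda") > 0.5    # stand-in for the fused mask
sums, partial, full = classify_blocks(mask)
```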
+
+
+# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/34/c34af36gfqnn2ovywuaultc2pol4jyn6io3szgjeuv3uxfzcf3nv.py
+# Topologically Sorted Source Nodes: [num_blocks_in_row, child_3], Original ATen: [aten.sum, aten._to_copy]
+# Source node to ATen node mapping:
+#   child_3 => convert_element_type_3
+#   num_blocks_in_row => sum_2
+# Graph fragment:
+#   %convert_element_type_2 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 2*(((s12 + 127)//128))*(((s37 + 127)//128)), ((s37 + 127)//128), 1]cuda:5" = PlaceHolder[target=convert_element_type_2]
+#   %sum_2 : Tensor "i64[2, 1, ((s12 + 127)//128)][((s12 + 127)//128), 2*(((s12 + 127)//128)), 1]cuda:5" = PlaceHolder[target=sum_2]
+#   %sum_2 : Tensor "i64[2, 1, ((s12 + 127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {})
+#   %convert_element_type_3 : Tensor "i32[2, 1, ((s12 + 127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:5"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {})
+#   return %sum_2,%convert_element_type_3
+triton_red_fused__to_copy_sum_2 = async_compile.triton('triton_red_fused__to_copy_sum_2', '''
+import triton
+import triton.language as tl
+
+from torch._inductor.runtime import triton_helpers, triton_heuristics
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+triton_helpers.set_driver_to_gpu()
+
+@triton_heuristics.reduction(
+    size_hints={'x': 32, 'r0_': 16},
+    reduction_hint=ReductionHint.INNER,
+    filename=__file__,
+    triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
+    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+)
+@triton.jit
+def triton_red_fused__to_copy_sum_2(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
+    rnumel = r0_numel
+    RBLOCK: tl.constexpr = R0_BLOCK
+    xoffset = tl.program_id(0) * XBLOCK
+    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+    xmask = xindex < xnumel
+    r0_base = tl.arange(0, R0_BLOCK)[None, :]
+    rbase = r0_base
+    x0 = xindex
+    _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
+    for r0_offset in range(0, r0_numel, R0_BLOCK):
+        r0_index = r0_offset + r0_base
+        r0_mask = r0_index < r0_numel
+        roffset = r0_offset
+        rindex = r0_index
+        r0_1 = r0_index
+        tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
+        tmp1 = tmp0.to(tl.int64)
+        tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
+        tmp4 = _tmp3 + tmp2
+        _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3)
+    tmp3 = tl.sum(_tmp3, 1)[:, None]
+    x2 = (xindex % ks1)
+    x3 = xindex // ks1
+    tmp5 = tmp3.to(tl.int32)
+    tl.store(out_ptr1 + (x2 + x3*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp5, xmask)
+''', device_str='cuda')
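This kernel reduces the partial-block mask along the KV-block axis to get a per-query-row block count, then casts to int32. A hedged eager equivalent, reusing `partial` from the sketch above as `dense_mask` (names follow the source-node comments):

```python
num_blocks_in_row = partial.sum(dim=-1)        # aten.sum.dim_IntList(..., [-1]) -> i64
child_3 = num_blocks_in_row.to(torch.int32)    # convert_element_type_3
```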
+
+
+# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/7h/c7hhwzbu42q2ic55mujfpppabs5ube44ahuppbgjh35eanxqzare.py
+# Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, aten.scalar_tensor, aten.where, aten.view, aten.index_put]
+# Source node to ATen node mapping:
+#   arange_4 => iota_4
+#   child_4 => convert_element_type_4
+#   col_range => iota_5
+#   dense_mask_2 => full_default_1
+#   index_mask => lt_4
+#   row_indices => unsqueeze
+#   setitem => full_default_2, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6
+#   unsqueeze_1 => unsqueeze_1
+#   valid_indices => scalar_tensor, where
+# Graph fragment:
+#   %getitem_1 : Tensor "i64[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 2*Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5" = PlaceHolder[target=getitem_1]
+#   %convert_element_type_3 : Tensor "i32[2, 1, ((s12 + 127)//128)][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:5" = PlaceHolder[target=convert_element_type_3]
+#   %convert_element_type_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5" = PlaceHolder[target=convert_element_type_4]
+#   %index_put : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), ((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), (((s37 + 127)//128)) + 1, 1]cuda:5" = PlaceHolder[target=index_put]
+#   %full_default_1 : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, %floordiv_3, %add_201], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:5, pin_memory: False})
+#   %iota_7 : Tensor "i64[2][1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:5, requires_grad: False})
+#   %unsqueeze_4 : Tensor "i64[2, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {})
+#   %unsqueeze_5 : Tensor "i64[2, 1, 1][1, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {})
+#   %unsqueeze_6 : Tensor "i64[2, 1, 1, 1][1, 1, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {})
+#   %iota_6 : Tensor "i64[1][1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:5, requires_grad: False})
+#   %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {})
+#   %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {})
+#   %iota_4 : Tensor "i32[((s12 + 127)//128)][1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (%floordiv_3,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:5, requires_grad: False})
+#   %unsqueeze : Tensor "i32[((s12 + 127)//128), 1][1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {})
+#   %iota_5 : Tensor "i32[((s37 + 127)//128)][1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (%floordiv_2,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:5, requires_grad: False})
+#   %unsqueeze_1 : Tensor "i32[2, 1, ((s12 + 127)//128), 1][Max(1, ((s12 + 127)//128)), Max(1, ((s12 + 127)//128)), 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {})
+#   %lt_4 : Tensor "b8[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {})
+#   %convert_element_type_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {})
+#   %scalar_tensor : Tensor "i32[][]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%floordiv_2,), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:5})
+#   %where : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %scalar_tensor), kwargs = {})
+#   %full_default_2 : Tensor "i32[2, 1, 1, 1][1, 1, 1, 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:5, pin_memory: False})
+#   %index_put : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_2), kwargs = {})
+#   return %convert_element_type_4,%buf13
+triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3 = async_compile.triton('triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3', '''
+import triton
+import triton.language as tl
+
+from torch._inductor.runtime import triton_helpers, triton_heuristics
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+triton_helpers.set_driver_to_gpu()
+
+@triton_heuristics.pointwise(
+    size_hints={'x': 512},
+    filename=__file__,
+    triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i32', 'out_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
+    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3', 'mutated_arg_names': ['out_ptr1'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+    min_elem_per_thread=0
+)
+@triton.jit
+def triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3(in_ptr0, in_ptr1, out_ptr0, out_ptr1, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr):
+    xoffset = tl.program_id(0) * XBLOCK
+    xindex = xoffset + tl.arange(0, XBLOCK)[:]
+    xmask = xindex < xnumel
+    x0 = (xindex % ks0)
+    x1 = ((xindex // ks0) % ks1)
+    x2 = xindex // ks2
+    x3 = xindex // ks0
+    tmp0 = tl.load(in_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last')
+    tmp2 = tl.load(in_ptr1 + (x3), xmask, eviction_policy='evict_last')
+    tmp1 = tmp0.to(tl.int32)
+    tmp3 = x0
+    tmp4 = tmp3 < tmp2
+    tmp5 = ks0
+    tmp6 = tl.where(tmp4, tmp1, tmp5)
+    tmp7 = 1 + ks0
+    tmp8 = tmp6 + tmp7
+    tmp9 = tmp6 < 0
+    tmp10 = tl.where(tmp9, tmp8, tmp6)
+    tl.device_assert(((0 <= tmp10) & (tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128)))) | ~(xmask), "index out of bounds: 0 <= tmp10 < 1 + (triton_helpers.div_floor_integer(127 + ks3, 128))")
+    tmp12 = tl.full([1], 1, tl.int32)
+    tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask)
+    tl.store(out_ptr1 + (tmp10 + x3 + ks0*x3), tmp12, xmask)
+''', device_str='cuda')
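This kernel implements the `index_put` step from the graph fragment: sorted block indices beyond each row's count are redirected to the sentinel column, and a 1 is written at every resulting index in the padded buffer. A hedged eager sketch (tensor names follow the source-node comments; shapes illustrative):

```python
import torch

B, qb, kb = 2, 3, 4
sorted_idx = torch.argsort(torch.rand(B, 1, qb, kb, device="cuda"),
                           dim=-1, descending=True)           # stand-in for buf4
counts = torch.randint(0, kb + 1, (B, 1, qb), device="cuda")  # stand-in for child_3
col_range = torch.arange(kb, device="cuda")
# where(index_mask, child_4, num_kv_blocks): overflow slots hit the sentinel column.
valid_indices = torch.where(col_range < counts[..., None], sorted_idx, kb)
padded = torch.zeros(B, 1, qb, kb + 1, dtype=torch.int32, device="cuda")  # dense_mask_2
padded.scatter_(-1, valid_indices, 1)                         # index_put_(..., 1)
```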
+
+
+# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/is/cisbwn452kdvm56u75a2mwmrdzns6w4vxzuweva24qshuv4gksv2.py
+# Topologically Sorted Source Nodes: [batched_outputs_3], Original ATen: [aten.slice, aten.clone]
+# Source node to ATen node mapping:
+#   batched_outputs_3 => clone_4, slice_4
+# Graph fragment:
+#   %buf13 : Tensor "i32[2, 1, ((s12 + 127)//128), (((s37 + 127)//128)) + 1][((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), ((((s37 + 127)//128)) + 1)*(((s12 + 127)//128)), (((s37 + 127)//128)) + 1, 1]cuda:5" = PlaceHolder[target=buf13]
+#   %slice_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, %floordiv_2), kwargs = {})
+#   %clone_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_4,), kwargs = {memory_format: torch.contiguous_format})
+#   return %clone_4
+triton_poi_fused_clone_slice_4 = async_compile.triton('triton_poi_fused_clone_slice_4', '''
+import triton
+import triton.language as tl
+
+from torch._inductor.runtime import triton_helpers, triton_heuristics
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+triton_helpers.set_driver_to_gpu()
+
+@triton_heuristics.pointwise(
+    size_hints={'x': 512},
+    filename=__file__,
+    triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr0': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
+    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_slice_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+    min_elem_per_thread=0
+)
+@triton.jit
+def triton_poi_fused_clone_slice_4(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
+    xoffset = tl.program_id(0) * XBLOCK
+    xindex = xoffset + tl.arange(0, XBLOCK)[:]
+    xmask = xindex < xnumel
+    x0 = (xindex % ks0)
+    x1 = xindex // ks0
+    x2 = xindex
+    tmp0 = tl.load(in_ptr0 + (x0 + x1 + ks0*x1), xmask, eviction_policy='evict_last')
+    tl.store(out_ptr0 + (x2), tmp0, xmask)
+''', device_str='cuda')
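The clone/slice kernel then drops the sentinel column again and compacts the result. A hedged eager equivalent, continuing the sketch above:

```python
batched_outputs_3 = padded[..., :kb].contiguous()   # slice_4 + clone_4, i32
```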
+
+
+# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hr/chrxoz3s6dcccbxa4bhegahvxtofkt5hvfz7hdrybtpjo4ffso64.py
+# Topologically Sorted Source Nodes: [batched_outputs_3, transpose, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy]
+# Source node to ATen node mapping:
+#   batched_outputs_3 => clone_4, slice_4
+#   num_blocks_in_row_2 => sum_4
+#   q_num_blocks => convert_element_type_8
+#   transpose => permute_1
+# Graph fragment:
+#   %clone_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][(((s12 + 127)//128))*(((s37 + 127)//128)), 1, ((s37 + 127)//128), 1]cuda:5" = PlaceHolder[target=clone_4]
+#   %sum_4 : Tensor "i64[2, 1, ((s37 + 127)//128)][((s37 + 127)//128), 2*(((s37 + 127)//128)), 1]cuda:5" = PlaceHolder[target=sum_4]
+#   %slice_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1)*Max(1, ((s12 + 127)//128)), Max(1, (((s37 + 127)//128)) + 1), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, %floordiv_2), kwargs = {})
+#   %clone_4 : Tensor "i32[2, 1, ((s12 + 127)//128), ((s37 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_4,), kwargs = {memory_format: torch.contiguous_format})
+#   %permute_1 : Tensor "i32[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:5"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {})
+#   %sum_4 : Tensor "i64[2, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {})
+#   %convert_element_type_8 : Tensor "i32[2, 1, ((s37 + 127)//128)][Max(1, ((s37 + 127)//128)), Max(1, ((s37 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {})
+#   return %sum_4,%convert_element_type_8
+triton_red_fused__to_copy_clone_slice_sum_transpose_5 = async_compile.triton('triton_red_fused__to_copy_clone_slice_sum_transpose_5', '''
+import triton
+import triton.language as tl
+
+from torch._inductor.runtime import triton_helpers, triton_heuristics
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+triton_helpers.set_driver_to_gpu()
+
+@triton_heuristics.reduction(
+    size_hints={'x': 32, 'r0_': 16},
+    reduction_hint=ReductionHint.DEFAULT,
+    filename=__file__,
+    triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr1': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
+    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_clone_slice_sum_transpose_5', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+)
+@triton.jit
+def triton_red_fused__to_copy_clone_slice_sum_transpose_5(in_ptr0, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
+    rnumel = r0_numel
+    RBLOCK: tl.constexpr = R0_BLOCK
+    xoffset = tl.program_id(0) * XBLOCK
+    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+    xmask = xindex < xnumel
+    r0_base = tl.arange(0, R0_BLOCK)[None, :]
+    rbase = r0_base
+    x0 = (xindex % ks0)
+    x1 = xindex // ks0
+    _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
+    x3 = xindex
+    for r0_offset in range(0, r0_numel, R0_BLOCK):
+        r0_index = r0_offset + r0_base
+        r0_mask = r0_index < r0_numel
+        roffset = r0_offset
+        rindex = r0_index
+        r0_2 = r0_index
+        tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_2 + ks0*ks1*x1), r0_mask & xmask, eviction_policy='evict_last', other=0.0)
+        tmp1 = tmp0.to(tl.int64)
+        tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
+        tmp4 = _tmp3 + tmp2
+        _tmp3 = tl.where(r0_mask & xmask, tmp4, _tmp3)
+    tmp3 = tl.sum(_tmp3, 1)[:, None]
+    tmp5 = tmp3.to(tl.int32)
+    tl.store(out_ptr1 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), tmp5, xmask)
+''', device_str='cuda')
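This reduction transposes the per-row block table and sums along the query-block axis, yielding how many query blocks touch each KV block. A hedged eager equivalent, continuing the sketch:

```python
transpose = batched_outputs_3.transpose(-1, -2)        # permute [0, 1, 3, 2]
q_num_blocks = transpose.sum(dim=-1).to(torch.int32)   # i32[2, 1, kv_blocks]
```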
+
+
+# kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/g2/cg2ims2kqlrojffd3to6cjqkeah4zdvhgfoxnfrkinus4ddrvhhe.py
+# Topologically Sorted Source Nodes: [q_indices], Original ATen: [aten._to_copy]
+# Source node to ATen node mapping:
+#   q_indices => clone_6, convert_element_type_9
+# Graph fragment:
+#   %getitem_5 : Tensor "i64[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:5" = PlaceHolder[target=getitem_5]
+#   %convert_element_type_9 : Tensor "i32[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), 1, Max(1, ((s37 + 127)//128))]cuda:5"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {})
+#   %clone_6 : Tensor "i32[2, 1, ((s37 + 127)//128), ((s12 + 127)//128)][Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128))*Max(1, ((s37 + 127)//128)), Max(1, ((s12 + 127)//128)), 1]cuda:5"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format})
+#   return %clone_6
+triton_poi_fused__to_copy_6 = async_compile.triton('triton_poi_fused__to_copy_6', '''
+import triton
+import triton.language as tl
+
+from torch._inductor.runtime import triton_helpers, triton_heuristics
+from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+triton_helpers.set_driver_to_gpu()
+
+@triton_heuristics.pointwise(
+    size_hints={'x': 512},
+    filename=__file__,
+    triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i32', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
+    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+    min_elem_per_thread=0
+)
+@triton.jit
+def triton_poi_fused__to_copy_6(in_ptr0, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr):
+    xoffset = tl.program_id(0) * XBLOCK
+    xindex = xoffset + tl.arange(0, XBLOCK)[:]
+    xmask = xindex < xnumel
+    x0 = (xindex % ks0)
+    x1 = ((xindex // ks0) % ks1)
+    x2 = xindex // ks2
+    tmp0 = tl.load(in_ptr0 + (x1 + x0*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), xmask, eviction_policy='evict_last')
+    tmp1 = tmp0.to(tl.int32)
+    tl.store(out_ptr0 + (x0 + x1*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))) + x2*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))*((1) * ((1) >= (ks1)) + (ks1) * ((ks1) > (1)))), tmp1, xmask)
+''', device_str='cuda')
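This last pointwise kernel casts the transposed i64 sort indices (`getitem_5`) to a contiguous i32 tensor. A rough hedged one-liner, with `sorted_idx` standing in for the sort output as in the earlier sketch:

```python
idx_t = sorted_idx.transpose(-1, -2)             # rough stand-in for getitem_5
q_indices = idx_t.to(torch.int32).contiguous()   # convert_element_type_9 + clone_6
```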
530
+
531
+
532
+ async_compile.wait(globals())
533
+ del async_compile
534
+
535
+ class Runner:
536
+ def __init__(self, partitions):
537
+ self.partitions = partitions
538
+
539
+ def recursively_apply_fns(self, fns):
540
+ new_callables = []
541
+ for fn, c in zip(fns, self.partitions):
542
+ new_callables.append(fn(c))
543
+ self.partitions = new_callables
544
+
545
+ def call(self, args):
546
+ arg0_1, arg1_1, arg2_1, arg3_1 = args
547
+ args.clear()
548
+ s12 = arg0_1
549
+ s37 = arg1_1
550
+ s21 = arg3_1
551
+ assert_size_stride(arg2_1, (2, ), (1, ))
552
+ with torch.cuda._DeviceGuard(5):
553
+ torch.cuda.set_device(5)
554
+ buf12 = empty_strided_cuda((2, 1, (127 + s12) // 128, 1 + ((127 + s37) // 128)), (((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), ((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), 1 + ((127 + s37) // 128), 1), torch.int32)
555
+ # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros]
556
+ triton_poi_fused_new_zeros_0_xnumel = 2*((127 + s12) // 128) + 2*((127 + s12) // 128)*((127 + s37) // 128)
557
+ stream5 = get_raw_stream(5)
558
+ triton_poi_fused_new_zeros_0.run(buf12, triton_poi_fused_new_zeros_0_xnumel, stream=stream5)
559
+ buf21 = empty_strided_cuda((2, 1, (127 + s12) // 128, 1 + ((127 + s37) // 128)), (((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), ((127 + s12) // 128)*((127 + s37) // 128) + ((127 + s12) // 128), 1 + ((127 + s37) // 128), 1), torch.int32)
560
+ # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros]
561
+ triton_poi_fused_new_zeros_0_xnumel = 2*((127 + s12) // 128) + 2*((127 + s12) // 128)*((127 + s37) // 128)
562
+ stream5 = get_raw_stream(5)
563
+ triton_poi_fused_new_zeros_0.run(buf21, triton_poi_fused_new_zeros_0_xnumel, stream=stream5)
564
+ ps0 = (127 + s37) // 128
565
+ ps1 = (127 + s12) // 128
566
+ ps2 = ((127 + s12) // 128)*((127 + s37) // 128)
567
+ buf1 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 2*((127 + s12) // 128)*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32)
568
+ buf5 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 2*((127 + s12) // 128)*((127 + s37) // 128), (127 + s37) // 128, 1), torch.int32)
569
+ # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_1, mask_2, mask_3, mask_block_sum, gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, full_blocks, full_blocks_1, dense_mask_1], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.constant_pad_nd, aten.permute, aten.sum, aten.gt, aten._to_copy]
570
+ triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128)
571
+ stream5 = get_raw_stream(5)
572
+ triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1.run(arg2_1, buf1, buf5, ps0, ps1, s12, s37, ps2, s21, triton_red_fused__to_copy_arange_bitwise_and_bitwise_or_constant_pad_nd_eq_ge_gt_index_lt_permute_remainder_sub_sum_view_1_xnumel, 16384, stream=stream5)
573
+ del arg2_1
574
+ buf10 = empty_strided_cuda((2, 1, (127 + s12) // 128), (max(1, (127 + s12) // 128), max(1, (127 + s12) // 128), 1), torch.int32)
575
+ # Topologically Sorted Source Nodes: [num_blocks_in_row, child_3], Original ATen: [aten.sum, aten._to_copy]
576
+ triton_red_fused__to_copy_sum_2_xnumel = 2*((127 + s12) // 128)
577
+ triton_red_fused__to_copy_sum_2_r0_numel = (127 + s37) // 128
578
+ stream5 = get_raw_stream(5)
579
+ triton_red_fused__to_copy_sum_2.run(buf1, buf10, ps0, ps1, triton_red_fused__to_copy_sum_2_xnumel, triton_red_fused__to_copy_sum_2_r0_numel, stream=stream5)
580
+ buf19 = empty_strided_cuda((2, 1, (127 + s12) // 128), (max(1, (127 + s12) // 128), max(1, (127 + s12) // 128), 1), torch.int32)
581
+ # Topologically Sorted Source Nodes: [num_blocks_in_row_1, child_7], Original ATen: [aten.sum, aten._to_copy]
582
+ triton_red_fused__to_copy_sum_2_xnumel = 2*((127 + s12) // 128)
583
+ triton_red_fused__to_copy_sum_2_r0_numel = (127 + s37) // 128
584
+ stream5 = get_raw_stream(5)
585
+ triton_red_fused__to_copy_sum_2.run(buf5, buf19, ps0, ps1, triton_red_fused__to_copy_sum_2_xnumel, triton_red_fused__to_copy_sum_2_r0_numel, stream=stream5)
586
+ # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort]
587
+ buf2 = torch.ops.aten.sort.stable(buf1, stable=True, dim=3, descending=True)
588
+ del buf1
589
+ buf4 = buf2[1]
590
+ assert_size_stride(buf4, (2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 2*max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable')
591
+ assert_alignment(buf4, 16, 'torch.ops.aten.sort.stable')
592
+ del buf2
593
+ buf11 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32)
594
+ # Topologically Sorted Source Nodes: [dense_mask_2, setitem, arange_4, row_indices, col_range, unsqueeze_1, index_mask, child_4, valid_indices], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, aten.scalar_tensor, aten.where, aten.view, aten.index_put]
595
+ triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128)
596
+ stream5 = get_raw_stream(5)
597
+ triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3.run(buf4, buf10, buf11, buf12, ps0, ps1, ps2, s37, triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel, stream=stream5)
598
+ del buf4
599
+ buf14 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 1, (127 + s37) // 128, 1), torch.int32)
600
+ # Topologically Sorted Source Nodes: [batched_outputs_3], Original ATen: [aten.slice, aten.clone]
601
+ triton_poi_fused_clone_slice_4_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128)
602
+ stream5 = get_raw_stream(5)
603
+ triton_poi_fused_clone_slice_4.run(buf12, buf14, ps0, triton_poi_fused_clone_slice_4_xnumel, stream=stream5)
604
+ del buf12
605
+ buf32 = empty_strided_cuda((2, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32)
606
+ # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy]
607
+ triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel = 2*((127 + s37) // 128)
608
+ triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel = (127 + s12) // 128
609
+ stream5 = get_raw_stream(5)
610
+ triton_red_fused__to_copy_clone_slice_sum_transpose_5.run(buf14, buf32, ps0, ps1, triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel, triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel, stream=stream5)
611
+ # Topologically Sorted Source Nodes: [full_blocks, full_blocks_1, dense_mask_1, col_indices_1], Original ATen: [aten.eq, aten._to_copy, aten.sort]
612
+ buf6 = torch.ops.aten.sort.stable(buf5, stable=True, dim=3, descending=True)
613
+ del buf5
614
+ buf8 = buf6[1]
615
+ assert_size_stride(buf8, (2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 2*max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), 'torch.ops.aten.sort.stable')
616
+ assert_alignment(buf8, 16, 'torch.ops.aten.sort.stable')
617
+ del buf6
618
+ buf20 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32)
619
+ # Topologically Sorted Source Nodes: [dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.new_zeros, aten.arange, aten.unsqueeze, aten.lt, aten._to_copy, aten.scalar_tensor, aten.where, aten.view, aten.index_put]
620
+ triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128)
621
+ stream5 = get_raw_stream(5)
622
+ triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3.run(buf8, buf19, buf20, buf21, ps0, ps1, ps2, s37, triton_poi_fused__to_copy_arange_index_put_lt_new_zeros_scalar_tensor_unsqueeze_view_where_3_xnumel, stream=stream5)
623
+ del buf8
624
+ buf23 = empty_strided_cuda((2, 1, (127 + s12) // 128, (127 + s37) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 1, (127 + s37) // 128, 1), torch.int32)
625
+ # Topologically Sorted Source Nodes: [batched_outputs_5], Original ATen: [aten.slice, aten.clone]
626
+ triton_poi_fused_clone_slice_4_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128)
627
+ stream5 = get_raw_stream(5)
628
+ triton_poi_fused_clone_slice_4.run(buf21, buf23, ps0, triton_poi_fused_clone_slice_4_xnumel, stream=stream5)
629
+ del buf21
630
+ buf29 = empty_strided_cuda((2, 1, (127 + s37) // 128), (max(1, (127 + s37) // 128), max(1, (127 + s37) // 128), 1), torch.int32)
631
+ # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sum, aten._to_copy]
632
+ triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel = 2*((127 + s37) // 128)
633
+ triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel = (127 + s12) // 128
634
+ stream5 = get_raw_stream(5)
635
+ triton_red_fused__to_copy_clone_slice_sum_transpose_5.run(buf23, buf29, ps0, ps1, triton_red_fused__to_copy_clone_slice_sum_transpose_5_xnumel, triton_red_fused__to_copy_clone_slice_sum_transpose_5_r0_numel, stream=stream5)
636
+ # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort]
637
+ buf15 = torch.ops.aten.sort.stable(reinterpret_tensor(buf14, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 0, 1, (127 + s37) // 128), 0), stable=True, dim=3, descending=True)
638
+ del buf14
639
+ buf17 = buf15[1]
640
+ assert_size_stride(buf17, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 1, max(1, (127 + s37) // 128)), 'torch.ops.aten.sort.stable')
641
+ assert_alignment(buf17, 16, 'torch.ops.aten.sort.stable')
642
+ del buf15
643
+ buf30 = empty_strided_cuda((2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128), 1), torch.int32)
644
+ # Topologically Sorted Source Nodes: [q_indices], Original ATen: [aten._to_copy]
645
+ triton_poi_fused__to_copy_6_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128)
646
+ stream5 = get_raw_stream(5)
647
+ triton_poi_fused__to_copy_6.run(buf17, buf30, ps1, ps0, ps2, triton_poi_fused__to_copy_6_xnumel, stream=stream5)
648
+ del buf17
649
+ # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort]
650
+ buf24 = torch.ops.aten.sort.stable(reinterpret_tensor(buf23, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (((127 + s12) // 128)*((127 + s37) // 128), 0, 1, (127 + s37) // 128), 0), stable=True, dim=3, descending=True)
651
+ del buf23
652
+ buf26 = buf24[1]
653
+ assert_size_stride(buf26, (2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), 1, max(1, (127 + s37) // 128)), 'torch.ops.aten.sort.stable')
654
+ assert_alignment(buf26, 16, 'torch.ops.aten.sort.stable')
655
+ del buf24
656
+ buf27 = empty_strided_cuda((2, 1, (127 + s37) // 128, (127 + s12) // 128), (max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128)*max(1, (127 + s37) // 128), max(1, (127 + s12) // 128), 1), torch.int32)
657
+ # Topologically Sorted Source Nodes: [full_q_indices], Original ATen: [aten._to_copy]
658
+ triton_poi_fused__to_copy_6_xnumel = 2*((127 + s12) // 128)*((127 + s37) // 128)
659
+ stream5 = get_raw_stream(5)
660
+ triton_poi_fused__to_copy_6.run(buf26, buf27, ps1, ps0, ps2, triton_poi_fused__to_copy_6_xnumel, stream=stream5)
661
+ del buf26
662
+ return (buf27, buf29, buf30, buf32, buf20, buf19, buf11, buf10, )
663
+
664
+ runner = Runner(partitions=[])
665
+ call = runner.call
666
+ recursively_apply_fns = runner.recursively_apply_fns
667
+
668
+
669
+ def benchmark_compiled_module(times=10, repeat=10):
670
+ from torch._dynamo.testing import rand_strided
671
+ from torch._inductor.utils import print_performance
672
+ arg0_1 = 1569
673
+ arg1_1 = 1569
674
+ arg2_1 = rand_strided((2, ), (1, ), device='cuda:5', dtype=torch.int64)
675
+ arg3_1 = 1569
676
+ fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1])
677
+ return print_performance(fn, times=times, repeat=repeat)
678
+
679
+
680
+ if __name__ == "__main__":
681
+ from torch._inductor.wrapper_benchmark import compiled_module_main
682
+ compiled_module_main('None', benchmark_compiled_module)
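The wrapper above assembles the per-row block counts and sorted block indices that flex_attention's BlockMask consumes (kv/q num_blocks and indices plus their full-block variants). A minimal eager-mode sketch of the same dense-to-block-sparse conversion, assuming a boolean block mask dense_mask of shape [B, 1, Q_BLOCKS, KV_BLOCKS]; every name below is illustrative rather than taken from the generated module:

import torch

def dense_to_block_sparse(dense_mask: torch.Tensor):
    # Count the active KV blocks in each query-block row
    # (cf. num_blocks_in_row / q_num_blocks above).
    num_blocks = dense_mask.sum(dim=-1, dtype=torch.int32)
    # A stable descending sort packs the indices of active blocks to the
    # front, mirroring the torch.ops.aten.sort.stable(..., descending=True)
    # calls in the wrapper.
    col_indices = torch.argsort(
        dense_mask.to(torch.int32), dim=-1, descending=True, stable=True
    ).to(torch.int32)
    return num_blocks, col_indices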
SpecForge-ext/cache/compiled_kernels/dz/cdz3io7w5uyfrmfqvmg2kt2ay66qv4ckwtyurhik3frq7fqnk7gm.py ADDED
@@ -0,0 +1,66 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 16777216},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x0 = (xindex % ks0)
23
+ x3 = xindex
24
+ x1 = ((xindex // ks0) % ks1)
25
+ tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32)
26
+ tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last')
27
+ tmp0 = x0
28
+ tmp1 = ks0 // 2
29
+ tmp2 = tmp0 >= tmp1
30
+ tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
31
+ tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0)
32
+ tmp5 = tl.broadcast_to(ks2, [XBLOCK])
33
+ tmp6 = tmp4 + tmp5
34
+ tmp7 = tmp4 < 0
35
+ tmp8 = tl.where(tmp7, tmp6, tmp4)
36
+ tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2")
37
+ tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
38
+ tmp11 = tmp3 * tmp10
39
+ tmp12 = -tmp11
40
+ tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
41
+ tmp14 = tl.where(tmp2, tmp12, tmp13)
42
+ tmp15 = 0.0
43
+ tmp16 = tl.where(tmp2, tmp14, tmp15)
44
+ tmp17 = tmp0 < tmp1
45
+ tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
46
+ tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0)
47
+ tmp20 = tl.broadcast_to(ks2, [XBLOCK])
48
+ tmp21 = tmp19 + tmp20
49
+ tmp22 = tmp19 < 0
50
+ tmp23 = tl.where(tmp22, tmp21, tmp19)
51
+ tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2")
52
+ tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
53
+ tmp26 = tmp18 * tmp25
54
+ tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype)
55
+ tmp28 = tl.where(tmp17, tmp26, tmp27)
56
+ tmp29 = tl.where(tmp17, tmp28, tmp15)
57
+ tmp30 = tmp16 + tmp29
58
+ tmp33 = ks3
59
+ tmp34 = tmp32 + tmp33
60
+ tmp35 = tmp32 < 0
61
+ tmp36 = tl.where(tmp35, tmp34, tmp32)
62
+ tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3")
63
+ tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32)
64
+ tmp39 = tmp31 * tmp38
65
+ tmp40 = tmp30 + tmp39
66
+ tl.store(out_ptr0 + (x3), tmp40, xmask)
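A hedged reading of the kernel above: the indexed cos/sin gathers combined with mul, neg, and the two half-width shifted loads match the backward pass of a rotate-half rotary embedding. A minimal eager sketch under that assumption; all names are illustrative:

import torch

def rope_backward(grad_out, cos, sin, position_ids):
    # Re-gather the cos/sin rows that were selected in the forward pass.
    cos_g, sin_g = cos[position_ids], sin[position_ids]
    # Forward was x * cos + rotate_half(x) * sin with
    # rotate_half(x) = cat(-x2, x1), so the adjoint of rotate_half
    # applied to a gradient g is cat(g2, -g1).
    g = grad_out * sin_g
    g1, g2 = g.chunk(2, dim=-1)
    return grad_out * cos_g + torch.cat((g2, -g1), dim=-1)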
SpecForge-ext/cache/compiled_kernels/dz/f7d5f2184a6f349e4531c61cf67ffbd51fe751bb6902c7e014986bad1a4a9b8f.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 512, "num_warps": 8, "num_stages": 1, "configs_hash": "3ca5c3e34d35093f3c9ab2829a9faeebad5e61c4ca13d5ed6053d7b71ce60d5a", "found_by_coordesc": false, "time_taken_ms": 53, "triton_cache_hash": "UQSFYICF6CFQWZOBHCGZ7JZ457GHWVO6RMPN5ABNWOATFMKI6GQA"}
SpecForge-ext/cache/compiled_kernels/fa/cfac6ze2ka7xqvmyxx4ehmqqczd7mi63mu366jgrbaebsyxjcuna.py ADDED
@@ -0,0 +1,307 @@
1
+ # AOT ID: ['4_forward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/27/c274gnr6pjrqx44o2l7ymaeh7yrigwgf3ninh5xcv6vd5wswoduy.py
38
+ # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add]
39
+ # Source node to ATen node mapping:
40
+ # cat => cat
41
+ # cos => squeeze_1
42
+ # cos_1 => unsqueeze
43
+ # getitem => index
44
+ # getitem_1 => index_1
45
+ # mul => mul_24
46
+ # mul_1 => mul_45
47
+ # neg => neg
48
+ # q_embed => add_54
49
+ # sin => squeeze_3
50
+ # sin_1 => unsqueeze_1
51
+ # squeeze => squeeze
52
+ # squeeze_2 => squeeze_2
53
+ # x1 => slice_1
54
+ # x2 => slice_2
55
+ # Graph fragment:
56
+ # %primals_12 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:1" = PlaceHolder[target=primals_12]
57
+ # %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:1" = PlaceHolder[target=primals_8]
58
+ # %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:1" = PlaceHolder[target=primals_4]
59
+ # %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:1" = PlaceHolder[target=primals_6]
60
+ # %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {})
61
+ # %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {})
62
+ # %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {})
63
+ # %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {})
64
+ # %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {})
65
+ # %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {})
66
+ # %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {})
67
+ # %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {})
68
+ # %mul_24 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_12, %unsqueeze), kwargs = {})
69
+ # %slice_1 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24, s24*s34, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, 0, %floordiv), kwargs = {})
70
+ # %slice_2 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24, s24*s34, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_12, 3, %floordiv, 9223372036854775807), kwargs = {})
71
+ # %neg : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s34*Max(1, s24 - ((s24//2))), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_2,), kwargs = {})
72
+ # %cat : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %slice_1], -1), kwargs = {})
73
+ # %mul_45 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat, %unsqueeze_1), kwargs = {})
74
+ # %add_54 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24, s24*s34, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_24, %mul_45), kwargs = {})
75
+ # return %add_54
76
+ triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', '''
77
+ import triton
78
+ import triton.language as tl
79
+
80
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
81
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
82
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
83
+ triton_helpers.set_driver_to_gpu()
84
+
85
+ @triton_heuristics.pointwise(
86
+ size_hints={'x': 67108864},
87
+ filename=__file__,
88
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
89
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
90
+ min_elem_per_thread=0
91
+ )
92
+ @triton.jit
93
+ def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr):
94
+ xoffset = tl.program_id(0) * XBLOCK
95
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
96
+ xmask = xindex < xnumel
97
+ x4 = xindex
98
+ x2 = ((xindex // ks0) % ks1)
99
+ x0 = (xindex % ks3)
100
+ x5 = xindex // ks3
101
+ tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
102
+ tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last')
103
+ tmp2 = ks2
104
+ tmp3 = tmp1 + tmp2
105
+ tmp4 = tmp1 < 0
106
+ tmp5 = tl.where(tmp4, tmp3, tmp1)
107
+ tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2")
108
+ tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32)
109
+ tmp8 = tmp0 * tmp7
110
+ tmp9 = x0
111
+ tmp10 = tl.full([1], 0, tl.int64)
112
+ tmp11 = tmp9 >= tmp10
113
+ tmp12 = ks3 + (-1)*(ks3 // 2)
114
+ tmp13 = tmp9 < tmp12
115
+ tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
116
+ tmp15 = -tmp14
117
+ tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
118
+ tmp17 = tl.where(tmp13, tmp15, tmp16)
119
+ tmp18 = tmp9 >= tmp12
120
+ tmp19 = ks3
121
+ tmp20 = tmp9 < tmp19
122
+ tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
123
+ tmp22 = tl.where(tmp13, tmp17, tmp21)
124
+ tmp23 = ks4
125
+ tmp24 = tmp1 + tmp23
126
+ tmp25 = tl.where(tmp4, tmp24, tmp1)
127
+ tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4")
128
+ tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32)
129
+ tmp28 = tmp22 * tmp27
130
+ tmp29 = tmp8 + tmp28
131
+ tl.store(out_ptr0 + (x4), tmp29, xmask)
132
+ ''', device_str='cuda')
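+ # The ATen graph above is the standard rotate-half RoPE application. A hedged
+ # eager sketch of the same computation, with illustrative names, kept in
+ # comment form:
+ #
+ #     def rotate_half(x):
+ #         x1, x2 = x.chunk(2, dim=-1)
+ #         return torch.cat((-x2, x1), dim=-1)
+ #
+ #     cos_sel = cos.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)
+ #     sin_sel = sin.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)
+ #     q_embed = q * cos_sel + rotate_half(q) * sin_sel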
133
+
134
+
135
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/qt/cqtv2hjbuijyx7awch534sanohmqs6reawit6ksar4ud36qn7xhy.py
136
+ # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add]
137
+ # Source node to ATen node mapping:
138
+ # cat_1 => cat_1
139
+ # cos => squeeze_1
140
+ # cos_1 => unsqueeze
141
+ # getitem => index
142
+ # getitem_1 => index_1
143
+ # k_embed => add_90
144
+ # mul_2 => mul_54
145
+ # mul_3 => mul_75
146
+ # neg_1 => neg_1
147
+ # sin => squeeze_3
148
+ # sin_1 => unsqueeze_1
149
+ # squeeze => squeeze
150
+ # squeeze_2 => squeeze_2
151
+ # x1_1 => slice_3
152
+ # x2_1 => slice_4
153
+ # Graph fragment:
154
+ # %primals_13 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:1" = PlaceHolder[target=primals_13]
155
+ # %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:1" = PlaceHolder[target=primals_8]
156
+ # %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:1" = PlaceHolder[target=primals_4]
157
+ # %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:1" = PlaceHolder[target=primals_6]
158
+ # %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {})
159
+ # %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {})
160
+ # %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {})
161
+ # %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {})
162
+ # %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {})
163
+ # %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {})
164
+ # %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {})
165
+ # %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {})
166
+ # %mul_54 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_13, %unsqueeze), kwargs = {})
167
+ # %slice_3 : Tensor "bf16[s48, s48, s9, (s24//2)][s24*s48*s9, s24, s24*s48, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_13, 3, 0, %floordiv), kwargs = {})
168
+ # %slice_4 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s24*s48*s9, s24, s24*s48, 1]cuda:1"[num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%primals_13, 3, %floordiv, 9223372036854775807), kwargs = {})
169
+ # %neg_1 : Tensor "bf16[s48, s48, s9, s24 - ((s24//2))][s48*s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), s48*Max(1, s24 - ((s24//2))), 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_4,), kwargs = {})
170
+ # %cat_1 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %slice_3], -1), kwargs = {})
171
+ # %mul_75 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24*s9, s24, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %unsqueeze_1), kwargs = {})
172
+ # %add_90 : Tensor "bf16[s48, s48, s9, s24][s24*s48*s9, s24, s24*s48, 1]cuda:1"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_54, %mul_75), kwargs = {})
173
+ # return %add_90
174
+ triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', '''
175
+ import triton
176
+ import triton.language as tl
177
+
178
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
179
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
180
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
181
+ triton_helpers.set_driver_to_gpu()
182
+
183
+ @triton_heuristics.pointwise(
184
+ size_hints={'x': 16777216},
185
+ filename=__file__,
186
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
187
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
188
+ min_elem_per_thread=0
189
+ )
190
+ @triton.jit
191
+ def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr):
192
+ xoffset = tl.program_id(0) * XBLOCK
193
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
194
+ xmask = xindex < xnumel
195
+ x4 = xindex
196
+ x2 = ((xindex // ks0) % ks1)
197
+ x0 = (xindex % ks3)
198
+ x5 = xindex // ks3
199
+ tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
200
+ tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last')
201
+ tmp2 = ks2
202
+ tmp3 = tmp1 + tmp2
203
+ tmp4 = tmp1 < 0
204
+ tmp5 = tl.where(tmp4, tmp3, tmp1)
205
+ tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2")
206
+ tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32)
207
+ tmp8 = tmp0 * tmp7
208
+ tmp9 = x0
209
+ tmp10 = tl.full([1], 0, tl.int64)
210
+ tmp11 = tmp9 >= tmp10
211
+ tmp12 = ks3 + (-1)*(ks3 // 2)
212
+ tmp13 = tmp9 < tmp12
213
+ tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
214
+ tmp15 = -tmp14
215
+ tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
216
+ tmp17 = tl.where(tmp13, tmp15, tmp16)
217
+ tmp18 = tmp9 >= tmp12
218
+ tmp19 = ks3
219
+ tmp20 = tmp9 < tmp19
220
+ tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
221
+ tmp22 = tl.where(tmp13, tmp17, tmp21)
222
+ tmp23 = ks4
223
+ tmp24 = tmp1 + tmp23
224
+ tmp25 = tl.where(tmp4, tmp24, tmp1)
225
+ tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4")
226
+ tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32)
227
+ tmp28 = tmp22 * tmp27
228
+ tmp29 = tmp8 + tmp28
229
+ tl.store(out_ptr0 + (x4), tmp29, xmask)
230
+ ''', device_str='cuda')
231
+
232
+
233
+ async_compile.wait(globals())
234
+ del async_compile
235
+
236
+ class Runner:
237
+ def __init__(self, partitions):
238
+ self.partitions = partitions
239
+
240
+ def recursively_apply_fns(self, fns):
241
+ new_callables = []
242
+ for fn, c in zip(fns, self.partitions):
243
+ new_callables.append(fn(c))
244
+ self.partitions = new_callables
245
+
246
+ def call(self, args):
247
+ primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13 = args
248
+ args.clear()
249
+ s92 = primals_1
250
+ s24 = primals_2
251
+ s96 = primals_3
252
+ s79 = primals_5
253
+ s9 = primals_7
254
+ s38 = primals_9
255
+ s48 = primals_10
256
+ s34 = primals_11
257
+ assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1))
258
+ assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1))
259
+ assert_size_stride(primals_8, (1, s9), (s9, 1))
260
+ assert_size_stride(primals_12, (s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1))
261
+ assert_size_stride(primals_13, (s48, s48, s9, s24), (s24*s48*s9, s24, s24*s48, 1))
262
+ with torch.cuda._DeviceGuard(1):
263
+ torch.cuda.set_device(1)
264
+ ps0 = s24*s34
265
+ buf0 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24, s24*s34, 1), torch.bfloat16)
266
+ # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul, x1, x2, neg, cat, mul_1, q_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add]
267
+ triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel = s24*s34*s48*s9
268
+ stream1 = get_raw_stream(1)
269
+ triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0.run(primals_12, primals_8, primals_4, primals_6, buf0, ps0, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_0_xnumel, stream=stream1)
270
+ del primals_12
271
+ ps1 = s24*s48
272
+ buf1 = empty_strided_cuda((s48, s48, s9, s24), (s24*s48*s9, s24, s24*s48, 1), torch.bfloat16)
273
+ # Topologically Sorted Source Nodes: [squeeze, cos, squeeze_2, sin, getitem, cos_1, getitem_1, sin_1, mul_2, x1_1, x2_1, neg_1, cat_1, mul_3, k_embed], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.cat, aten.add]
274
+ triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel = s24*s9*s48*s48
275
+ stream1 = get_raw_stream(1)
276
+ triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1.run(primals_13, primals_8, primals_4, primals_6, buf1, ps1, s9, s92, s24, s79, triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1_xnumel, stream=stream1)
277
+ del primals_13
278
+ return (buf0, buf1, primals_4, primals_6, primals_8, s24, s9, s48, s34, s92, s96, s79, s24 // 2, s24 + (-1)*(s24 // 2), )
279
+
280
+ runner = Runner(partitions=[])
281
+ call = runner.call
282
+ recursively_apply_fns = runner.recursively_apply_fns
283
+
284
+
285
+ def benchmark_compiled_module(times=10, repeat=10):
286
+ from torch._dynamo.testing import rand_strided
287
+ from torch._inductor.utils import print_performance
288
+ primals_1 = 2048
289
+ primals_2 = 128
290
+ primals_3 = 5245440
291
+ primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:1', dtype=torch.bfloat16)
292
+ primals_5 = 2048
293
+ primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:1', dtype=torch.bfloat16)
294
+ primals_7 = 2048
295
+ primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:1', dtype=torch.int64)
296
+ primals_9 = 1
297
+ primals_10 = 8
298
+ primals_11 = 32
299
+ primals_12 = rand_strided((8, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:1', dtype=torch.bfloat16)
300
+ primals_13 = rand_strided((8, 8, 2048, 128), (2097152, 128, 1024, 1), device='cuda:1', dtype=torch.bfloat16)
301
+ fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13])
302
+ return print_performance(fn, times=times, repeat=repeat)
303
+
304
+
305
+ if __name__ == "__main__":
306
+ from torch._inductor.wrapper_benchmark import compiled_module_main
307
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/fa/cfai6qfroimjkp32i57fqulbbxd7ap7nwbhmtwtra7dawieplflr.py ADDED
@@ -0,0 +1,168 @@
1
+ # AOT ID: ['10_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/3e/c3ecvp7qn6xg4f7m37cdxroadkir5ch6rsa6n7ihcyk3xps5ffo6.py
38
+ # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul]
39
+ # Source node to ATen node mapping:
40
+ # getitem_1 => unsqueeze
41
+ # position_mask => mul_6
42
+ # target_mask => index
43
+ # target_mask_1 => convert_element_type
44
+ # target_max_token => argmax
45
+ # Graph fragment:
46
+ # %arg1_1 : Tensor "bf16[8, s14, 151936][151936*s14, 151936, 1]cuda:0" = PlaceHolder[target=arg1_1]
47
+ # %argmax : Tensor "i64[8, s14][s14, 1]cuda:0" = PlaceHolder[target=argmax]
48
+ # %arg2_1 : Tensor "b8[151936][1]cuda:0" = PlaceHolder[target=arg2_1]
49
+ # %arg3_1 : Tensor "i64[8, s14, 1][s14, 1, 1]cuda:0" = PlaceHolder[target=arg3_1]
50
+ # %argmax : Tensor "i64[8, s14][s14, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg1_1, -1), kwargs = {})
51
+ # %index : Tensor "b8[8, s14][s14, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%argmax]), kwargs = {})
52
+ # %unsqueeze : Tensor "b8[8, s14, 1][s14, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 2), kwargs = {})
53
+ # %convert_element_type : Tensor "i32[8, s14, 1][s14, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze, torch.int32), kwargs = {})
54
+ # %mul_6 : Tensor "i64[8, s14, 1][s14, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %arg3_1), kwargs = {})
55
+ # return %argmax,%mul_6
56
+ triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0 = async_compile.triton('triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', '''
57
+ import triton
58
+ import triton.language as tl
59
+
60
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
61
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
62
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
63
+ triton_helpers.set_driver_to_gpu()
64
+
65
+ @triton_heuristics.reduction(
66
+ size_hints={'x': 16384, 'r0_': 262144},
67
+ reduction_hint=ReductionHint.INNER,
68
+ filename=__file__,
69
+ triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i64', 'r0_numel': 'i64', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
70
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
71
+ )
72
+ @triton.jit
73
+ def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
74
+ r0_numel = 151936
75
+ rnumel = r0_numel
76
+ RBLOCK: tl.constexpr = R0_BLOCK
77
+ xoffset = tl.program_id(0).to(tl.int64) * XBLOCK
78
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64)
79
+ xmask = xindex < xnumel
80
+ r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64)
81
+ rbase = r0_base
82
+ x0 = xindex
83
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
84
+ _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 9223372036854775807, tl.int64)
85
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
86
+ r0_index = r0_offset + r0_base
87
+ r0_mask = r0_index < r0_numel
88
+ roffset = r0_offset
89
+ rindex = r0_index
90
+ r0_1 = r0_index
91
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
92
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
93
+ _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
94
+ _tmp2, _tmp2_index, tmp1, rindex
95
+ )
96
+ _tmp2 = tl.where(r0_mask & xmask, _tmp2_next, _tmp2)
97
+ _tmp2_index = tl.where(r0_mask & xmask, _tmp2_index_next, _tmp2_index)
98
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
99
+ tmp2 = tmp2_idx[:, None]
100
+ tmp11 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last')
101
+ tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32)
102
+ tmp4 = tmp2 + tmp3
103
+ tmp5 = tmp2 < 0
104
+ tmp6 = tl.where(tmp5, tmp4, tmp2)
105
+ tl.device_assert(((0 <= tmp6) & (tmp6 < 151936)) | ~(xmask), "index out of bounds: 0 <= tmp6 < 151936")
106
+ tmp8 = tl.load(in_ptr1 + (tmp6), xmask, eviction_policy='evict_last').to(tl.int1)
107
+ tmp9 = tmp8.to(tl.int32)
108
+ tmp10 = tmp9.to(tl.int64)
109
+ tmp12 = tmp10 * tmp11
110
+ tl.debug_barrier()
111
+ tl.store(in_out_ptr0 + (x0), tmp12, xmask)
112
+ ''', device_str='cuda')
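+ # Hedged eager sketch of the fused graph above (illustrative names):
+ #
+ #     target_max_token = logits.argmax(dim=-1)            # [8, s14]
+ #     target_mask = vocab_mask[target_max_token]          # bool [8, s14]
+ #     position_mask = target_mask.unsqueeze(2).int() * positions  # [8, s14, 1]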
113
+
114
+
115
+ async_compile.wait(globals())
116
+ del async_compile
117
+
118
+ class Runner:
119
+ def __init__(self, partitions):
120
+ self.partitions = partitions
121
+
122
+ def recursively_apply_fns(self, fns):
123
+ new_callables = []
124
+ for fn, c in zip(fns, self.partitions):
125
+ new_callables.append(fn(c))
126
+ self.partitions = new_callables
127
+
128
+ def call(self, args):
129
+ arg0_1, arg1_1, arg2_1, arg3_1 = args
130
+ args.clear()
131
+ s24 = arg0_1
132
+ arg1_1_size = arg1_1.size()
133
+ s14 = arg1_1_size[1]
134
+ assert_size_stride(arg1_1, (8, s14, 151936), (151936*s14, 151936, 1))
135
+ assert_size_stride(arg2_1, (151936, ), (1, ))
136
+ assert_size_stride(arg3_1, (8, s14, 1), (s14, 1, 1))
137
+ with torch.cuda._DeviceGuard(0):
138
+ torch.cuda.set_device(0)
139
+ buf0 = empty_strided_cuda((8, s14), (s14, 1), torch.int64)
140
+ buf1 = reinterpret_tensor(buf0, (8, s14, 1), (s14, 1, 1), 0); del buf0 # reuse
141
+ # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul]
142
+ triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_xnumel = 8*s14
143
+ stream0 = get_raw_stream(0)
144
+ triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.run(buf1, arg1_1, arg2_1, arg3_1, triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0_xnumel, 151936, stream=stream0)
145
+ del arg1_1
146
+ del arg2_1
147
+ del arg3_1
148
+ return (buf1, )
149
+
150
+ runner = Runner(partitions=[])
151
+ call = runner.call
152
+ recursively_apply_fns = runner.recursively_apply_fns
153
+
154
+
155
+ def benchmark_compiled_module(times=10, repeat=10):
156
+ from torch._dynamo.testing import rand_strided
157
+ from torch._inductor.utils import print_performance
158
+ arg0_1 = 2009
159
+ arg1_1 = rand_strided((8, 2009, 151936), (305239424, 151936, 1), device='cuda:0', dtype=torch.bfloat16)
160
+ arg2_1 = rand_strided((151936, ), (1, ), device='cuda:0', dtype=torch.bool)
161
+ arg3_1 = rand_strided((8, 2009, 1), (2009, 1, 1), device='cuda:0', dtype=torch.int64)
162
+ fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1])
163
+ return print_performance(fn, times=times, repeat=repeat)
164
+
165
+
166
+ if __name__ == "__main__":
167
+ from torch._inductor.wrapper_benchmark import compiled_module_main
168
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/fa/cfail5nyr4vuktxoags33cssvkjxk2nbmzhswhjwxszpyc4qj4wf.py ADDED
@@ -0,0 +1,675 @@
1
+ # AOT ID: ['6_forward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
17
+ import triton
18
+ import triton.language as tl
19
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
20
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
21
+
22
+ aten = torch.ops.aten
23
+ inductor_ops = torch.ops.inductor
24
+ _quantized = torch.ops._quantized
25
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
26
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
27
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
28
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
29
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
30
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
31
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
32
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
33
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
34
+ async_compile = AsyncCompile()
35
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
36
+
37
+
38
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/gp/cgpqg54v7ag6awmgwhlrbbyw5jxsgjo6tuzvo3rt2xzqk6f33df2.py
39
+ # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
40
+ # Source node to ATen node mapping:
41
+ # flex_attention => flex_attention
42
+ # Graph fragment:
43
+ # %primals_1 : Tensor "bf16[2, 32, 2048, 128][8388608, 128, 4096, 1]cuda:4" = PlaceHolder[target=primals_1]
44
+ # %primals_2 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:4" = PlaceHolder[target=primals_2]
45
+ # %primals_3 : Tensor "bf16[2, 8, 2048, 128][2097152, 262144, 128, 1]cuda:4" = PlaceHolder[target=primals_3]
46
+ # %getitem_1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:4" = PlaceHolder[target=getitem_1]
47
+ # %buf1 : Tensor "f32[2, 32, 2048][65536, 2048, 1]cuda:4" = PlaceHolder[target=buf1]
48
+ # %primals_5 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:4" = PlaceHolder[target=primals_5]
49
+ # %primals_4 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:4" = PlaceHolder[target=primals_4]
50
+ # %primals_7 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:4" = PlaceHolder[target=primals_7]
51
+ # %primals_8 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:4" = PlaceHolder[target=primals_8]
52
+ # %primals_6 : Tensor "i64[2][1]cuda:4" = PlaceHolder[target=primals_6]
53
+ # %flex_attention : [num_users=2] = call_function[target=torch.ops.higher_order.flex_attention](args = (%primals_1, %primals_2, %primals_3, %sdpa_score0, (2048, 2048, %primals_5, %primals_4, %primals_7, %primals_8, %primals_9, %primals_10, %primals_11, %primals_12, 128, 128, %sdpa_mask0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_6,)), kwargs = {})
54
+ # return %getitem
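+ # Hedged reading: this template lowers the torch.ops.higher_order.flex_attention
+ # call above with GQA (32 query heads sharing 8 KV heads), a block-sparse
+ # BlockMask built from the num_blocks/indices tensors, and
+ # scale 0.08838834764831843 = 1/sqrt(128), i.e. 1/sqrt(head_dim).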
55
+ triton_tem_fused_0 = async_compile.triton('triton_tem_fused_0', '''
56
+ import triton
57
+ import triton.language as tl
58
+
59
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
60
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
61
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
62
+
63
+ @triton_heuristics.template(
64
+
65
+ num_stages=3,
66
+ num_warps=8,
67
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
68
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
69
+
+ )
+ @triton.jit
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0):
+     PRESCALE_QK : tl.constexpr = False
+     ROWS_GUARANTEED_SAFE : tl.constexpr = False
+     BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
+     WRITE_DQ : tl.constexpr = True
+     OUTPUT_LOGSUMEXP : tl.constexpr = True
+     OUTPUT_MAX : tl.constexpr = False
+     FLOAT32_PRECISION : tl.constexpr = 'tf32'
+     IS_DIVISIBLE : tl.constexpr = True
+     SM_SCALE : tl.constexpr = 0.08838834764831843
+     GQA_SHARED_HEADS : tl.constexpr = 4
+     HAS_FULL_BLOCKS : tl.constexpr = True
+     QK_HEAD_DIM : tl.constexpr = 128
+     QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
+     V_HEAD_DIM : tl.constexpr = 128
+     V_HEAD_DIM_ROUNDED : tl.constexpr = 128
+     SAFE_HEAD_DIM : tl.constexpr = True
+     USE_TMA : tl.constexpr = False
+     BLOCK_M : tl.constexpr = 128
+     BLOCK_N : tl.constexpr = 64
+     SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
+     SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
+     INDEX_DTYPE : tl.constexpr = tl.int32
+     Q = arg_Q
+     K = arg_K
+     V = arg_V
+     LSE = arg_LSE
+     MAX = arg_MAX
+     KV_NUM_BLKS = arg_KV_NUM_BLKS
+     KV_IDX = arg_KV_IDX
+     FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
+     FULL_KV_IDX = arg_FULL_KV_IDX
+
+     # Sub notation for this kernel:
+     #
+     # Q: Query, K: Key, V: Value
+     # M: Number of queries, N: Number of keys/values, D: Model dimension
+     # QK_HEAD_DIM: The dimension of the query and key embeddings
+     # V_HEAD_DIM: The dimension of the value embeddings
+     # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
+     # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
+     #
+     # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
+     # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
+     # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
+     # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
+     # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
+     #
+     # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
+     #
+     # (Modifiable) Performance tuning options
+     # BLOCK_M: The thread block size across the seqlen dim of Q.
+     # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
+
+     # The below are kernel options that can be applied for certain score_mods,
+     # or involve a numerics vs. perf tradeoff
+     # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
+     # about 20% more numerical error, but slightly faster.
+     # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
+     # is not masked out? If so, we can skip an extra safety check
+     # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
+     # contiguous? If so, we don't need to do an indirect jump for every block
+
+     tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
+     tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
+
+     # Define strides of inputs
+     stride_qz, stride_qh, stride_qm, stride_qk = 8388608, 128, 4096, 1
+     stride_kz, stride_kh, stride_kn, stride_kk = 2097152, 262144, 128, 1
+     stride_vz, stride_vh, stride_vn, stride_vk = 2097152, 262144, 128, 1
+
+     ZQ = 2
+     HQ = 32
+     Q_LEN = 2048
+     ZKV = 2
+     KV_LEN = 2048
+
+     MATMUL_PRECISION = Q.dtype.element_ty
+
+     q_start = tl.program_id(0).to(INDEX_DTYPE)
+     off_zq = tl.program_id(1).to(INDEX_DTYPE)
+     off_hq = tl.program_id(2).to(INDEX_DTYPE)
+
+     # We support two cases for the batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
+     # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
+     off_zkv = off_zq % ZKV
+     off_hkv = off_hq // GQA_SHARED_HEADS
+     off_g = off_hq % GQA_SHARED_HEADS
+
+     q_offset = off_zq * stride_qz + off_hq * stride_qh
+     k_offset = off_zkv * stride_kz + off_hkv * stride_kh
+     v_offset = off_zkv * stride_vz + off_hkv * stride_vh
+
+     Q = Q + q_offset
+     K = K + k_offset
+     V = V + v_offset
+
+     # Setting up the TMA descriptors for Q, K, V
+     desc_q = None
+     desc_k = None
+     desc_v = None
+
+     SPARSE_Z = 2
+     SPARSE_HQ = 1
+
+     sparse_idx_z = off_zq % SPARSE_Z
+     sparse_idx_hq = off_hq % SPARSE_HQ
+
+     SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
+     SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
+
+     stride_kv_num_blks_h = 16
+     stride_kv_idx_h = 256
+     stride_kv_idx_m = 16
+
+     # initialize pointer to m and l
+     m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+     l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+     acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
+
+     offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
+
+     # KV_IDX and KV_NUM_BLKS are always contiguous.
+     sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
+     sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
+     sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m  # noqa: B950
+     offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
+     offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
+     q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
+
+     # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+     # We don't know anything "special" about these blocks, so we need to apply
+     # both score_mod and mask_mod to them
+     kv_indices = KV_IDX + sparse_kv_idx_offset
+     kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE  # first kv block we're loading
+     kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
+     block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
+
+     # K and V pointers will be passed directly to forward_inner
+
+     offs_n = kv_start + tl.arange(0, BLOCK_N)
+
+     acc, l_i, m_i = forward_inner(
+         arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
+         q, K, V,
+         desc_k, desc_v, Q_LEN, KV_LEN,
+         acc, l_i, m_i,
+         off_zq, off_hq, offs_m[:, None], offs_n[None, :],
+         kv_start,
+         kv_indices, kv_num_blocks,
+         0, block_n_end,
+         MATMUL_PRECISION,
+         stride_kk, stride_kn, stride_vn, stride_vk,
+         IS_FULL_BLOCKS=False,
+     )
+
+     # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+     # We know these blocks are guaranteed to be "full", so we don't need to
+     # apply mask_mod to them - only score_mod
+     if HAS_FULL_BLOCKS:
+         # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
+         kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
+         kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE  # first kv block we're loading
+         kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
+         block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
+         # K and V pointers will be passed directly to forward_inner
+         offs_n = kv_start + tl.arange(0, BLOCK_N)
+
+         acc, l_i, m_i = forward_inner(
+             arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
+             q, K, V,
+             desc_k, desc_v, Q_LEN, KV_LEN,
+             acc, l_i, m_i,
+             off_zq, off_hq, offs_m[:, None], offs_n[None, :],
+             kv_start,
+             kv_indices, kv_num_blocks,
+             0, block_n_end,
+             MATMUL_PRECISION,
+             stride_kk, stride_kn, stride_vn, stride_vk,
+             IS_FULL_BLOCKS=True,
+         )
+
+     # [Note] Handle fully masked out rows:
+     # l_i will be sum(e^(-inf)) == 0.0 for masked-out rows, and m_i will be -inf.
+     # We set l_i to 1.0, which yields lse/out = 0.0 after the log(l_i) + m_i step.
+     l_i = tl.where(l_i == 0.0, 1, l_i)
+
+     acc = acc / l_i[:, None]
+     idx_zq = tl.program_id(1).to(INDEX_DTYPE)
+     idx_hq = tl.program_id(2).to(INDEX_DTYPE)
+     idx_m = offs_m[:, None].to(INDEX_DTYPE)
+     idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
+
+     mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
+
+     tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
+     xindex = idx_d + 128*idx_m + 262144*idx_hq + 8388608*idx_zq
+     tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 8388608*idx_zq, acc.shape)), acc, mask)
+
+     if OUTPUT_LOGSUMEXP:
+         off_hz = off_zq * HQ + off_hq
+         l_ptrs = LSE + off_hz * Q_LEN + offs_m
+         lse = m_i + tl.math.log2(l_i)
+         if IS_DIVISIBLE:
+             tl.store(l_ptrs, lse)
+         else:
+             tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
+
+     if OUTPUT_MAX:
+         off_hz = off_zq * HQ + off_hq
+         max_ptrs = MAX + off_hz * Q_LEN + offs_m
+         if IS_DIVISIBLE:
+             tl.store(max_ptrs, m_i)
+         else:
+             tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
+
+
+ # Utility triton funcs
+ @triton.jit
+ def get_offset_for_next_block(
+     loop_iter, col_indices, total_blocks,
+     SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
+     BLOCKS_ARE_CONTIGUOUS: tl.constexpr
+ ):
+     if BLOCKS_ARE_CONTIGUOUS:
+         return BLOCK
+     cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
+     cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
+     next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
+     needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
+     jump_to_block = (next_block - cur_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
+     offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
+     return offset
+
+ @triton.jit
+ def get_bounded_indices(indices, max_len=None):
+     return indices % max_len if max_len is not None else indices
+
+ @triton.jit
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
+     if IS_DIVISIBLE and SAFE_HEAD_DIM:
+         return tl.load(block_ptr)
+     elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
+         return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
+     elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
+         return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
+     else:
+         return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
+
+ @triton.jit
+ def load_checked_2d(
+     ptr,
+     offs_m,
+     offs_n,
+     stride_m,
+     stride_n,
+     IS_DIVISIBLE_M: tl.constexpr,
+     IS_DIVISIBLE_N: tl.constexpr,
+     M_LEN: tl.constexpr,
+     N_LEN: tl.constexpr,
+ ):
+     # Calculate final pointer if strides are provided
+     if stride_m is not None and stride_n is not None:
+         ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
+
+     # Handle all masking cases
+     if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
+         return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
+     elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
+         return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
+     elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
+         return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
+     else:  # Both divisible
+         return tl.load(ptr)
+
+
+ # Common Imports
+ @triton.jit
+ def forward_block_mn(
+     arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
+     q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
+     # accumulated values
+     acc, l_i, m_i,
+     # Offsets
+     off_z, off_h, offs_m, offs_n,
+     # Offsets needed for TMA loads
+     kv_start,
+     kv_offset,
+     MATMUL_PRECISION, RCP_LN2,
+     # Strides for K and V
+     stride_kk, stride_kn, stride_vn, stride_vk,
+     IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
+ ):
+     # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
+     PRESCALE_QK : tl.constexpr = False
+     ROWS_GUARANTEED_SAFE : tl.constexpr = False
+     BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
+     WRITE_DQ : tl.constexpr = True
+     OUTPUT_LOGSUMEXP : tl.constexpr = True
+     OUTPUT_MAX : tl.constexpr = False
+     FLOAT32_PRECISION : tl.constexpr = 'tf32'
+     IS_DIVISIBLE : tl.constexpr = True
+     SM_SCALE : tl.constexpr = 0.08838834764831843
+     GQA_SHARED_HEADS : tl.constexpr = 4
+     HAS_FULL_BLOCKS : tl.constexpr = True
+     QK_HEAD_DIM : tl.constexpr = 128
+     QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
+     V_HEAD_DIM : tl.constexpr = 128
+     V_HEAD_DIM_ROUNDED : tl.constexpr = 128
+     SAFE_HEAD_DIM : tl.constexpr = True
+     USE_TMA : tl.constexpr = False
+     BLOCK_M : tl.constexpr = 128
+     BLOCK_N : tl.constexpr = 64
+     SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
+     SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
+     INDEX_DTYPE : tl.constexpr = tl.int32
+
+     # -- load k --
+     # NB: reversed order since K is transposed
+     kv_base_offset = kv_start + kv_offset
+
+     # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
+     offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
+     offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
+     k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
+
+     k = tl.trans(k)
+     # -- compute qk ---
+     qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION)  # TODO: use cuda matmul when q_len <= 2.
+     if not PRESCALE_QK:
+         qk *= SM_SCALE
+     # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
+     # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
+     # which is larger than the actual number of elements. To avoid out-of-bounds memory accesses,
+     # we need to mask out the elements that are out of Q_LEN & KV_LEN.
+     m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
+     n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
+
+     tmp0 = (qk)
+     post_mod_scores = tmp0
+
+     if CHECK_BLOCK_BOUNDARY:
+         # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
+         post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
+
+     if not IS_FULL_BLOCKS:
+         tmp1 = tl.full([1], False, tl.int1)
+         tmp2 = (m)
+         tmp3 = (n)
+         tmp4 = tmp2 >= tmp3
+         tmp5 = tmp3.to(tl.int64)
+         tmp6 = (off_z)
+         tmp7 = tl.load(in_ptr9 + tmp6)
+         tmp8 = tmp5 < tmp7
+         tmp9 = tmp2.to(tl.int64)
+         tmp10 = tmp9 < tmp7
+         tmp11 = tmp8 & tmp10
+         tmp12 = tmp4 & tmp11
+         tmp13 = tmp1 | tmp12
+         tmp14 = tl.full([1], 2048, tl.int32)
+         tmp15 = tmp3 >= tmp14
+         tmp16 = (tmp3 % tmp14)
+         tmp17 = tl.full([1], 0, tl.int32)
+         tmp18 = tmp16 != tmp17
+         tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
+         tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
+         tmp21 = tmp19 != tmp20
+         tmp22 = tmp18 & tmp21
+         tmp23 = tmp16 + tmp14
+         tmp24 = tl.where(tmp22, tmp23, tmp16)
+         tmp25 = tmp24.to(tl.int64)
+         tmp26 = tmp25 < tmp7
+         tmp27 = tmp15 & tmp26
+         tmp28 = tmp3 - tmp2
+         tmp29 = (tmp28 % tmp14)
+         tmp30 = tmp29 != tmp17
+         tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
+         tmp32 = tmp31 != tmp20
+         tmp33 = tmp30 & tmp32
+         tmp34 = tmp29 + tmp14
+         tmp35 = tl.where(tmp33, tmp34, tmp29)
+         tmp36 = tmp35 == tmp17
+         tmp37 = tmp27 & tmp36
+         tmp38 = tmp13 | tmp37
+         mask_mod_output = tmp38
+
+         if CHECK_BLOCK_BOUNDARY:
+             mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
+         # apply mask for partially unmasked blocks
+         post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
+
+     if not PRESCALE_QK:
+         post_mod_scores *= RCP_LN2
+     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+     # -- compute scaling constant ---
+     m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
+     if not ROWS_GUARANTEED_SAFE:
+         masked_out_rows = (m_ij == float("-inf"))
+         m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
+     else:
+         m_ij_masked = m_ij
+
+     alpha = tl.math.exp2(m_i - m_ij_masked)
+     p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
+
+     # NB: l_i update is pulled up here since it's a bit faster
+     # NB: For headdim=256, it's faster to move it back down to after m_i =
+     # m_ij
+     l_i = l_i * alpha + tl.sum(p, 1)
+     # -- scale and update acc --
+     acc = acc * alpha[:, None]
+     # Calculate offsets for V loading - reuse kv_base_offset from K loading
+     offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
+     v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
+     acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
+
+     # -- update m_i
+     m_i = m_ij
+
+     return acc, l_i, m_i
+
+ @triton.jit
+ def forward_inner(
+     arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
+     q, K, V,
+     desc_k, desc_v, Q_LEN, KV_LEN,
+     # accumulated values
+     acc, l_i, m_i,
+     # Offsets used as inputs to score_mod & mask_mod
+     # of size [BLOCK_M, BLOCK_N] or scalar.
+     off_z, off_h, offs_m, offs_n,
+     # Offsets needed for TMA loads
+     kv_start,
+     # blocksparse data
+     kv_indices, kv_num_blocks,
+     # start kv and end kv block
+     block_n_start, block_n_end,
+     MATMUL_PRECISION,
+     # Strides for K and V
+     stride_kk, stride_kn, stride_vn, stride_vk,
+     IS_FULL_BLOCKS,
+ ):
+     # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
+     PRESCALE_QK : tl.constexpr = False
+     ROWS_GUARANTEED_SAFE : tl.constexpr = False
+     BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
+     WRITE_DQ : tl.constexpr = True
+     OUTPUT_LOGSUMEXP : tl.constexpr = True
+     OUTPUT_MAX : tl.constexpr = False
+     FLOAT32_PRECISION : tl.constexpr = 'tf32'
+     IS_DIVISIBLE : tl.constexpr = True
+     SM_SCALE : tl.constexpr = 0.08838834764831843
+     GQA_SHARED_HEADS : tl.constexpr = 4
+     HAS_FULL_BLOCKS : tl.constexpr = True
+     QK_HEAD_DIM : tl.constexpr = 128
+     QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
+     V_HEAD_DIM : tl.constexpr = 128
+     V_HEAD_DIM_ROUNDED : tl.constexpr = 128
+     SAFE_HEAD_DIM : tl.constexpr = True
+     USE_TMA : tl.constexpr = False
+     BLOCK_M : tl.constexpr = 128
+     BLOCK_N : tl.constexpr = 64
+     SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
+     SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
+     INDEX_DTYPE : tl.constexpr = tl.int32
+
+     SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
+     RCP_LN2: tl.constexpr = 1.44269504
+
+     if PRESCALE_QK:
+         q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
+
+     kv_offset = 0
+
+     # loop over k, v and update accumulator until block_n_end
+     for start_n in range(block_n_start, block_n_end):
+         # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
+         if IS_DIVISIBLE:
+             acc, l_i, m_i = forward_block_mn(
+                 arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
+                 q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
+                 # accumulated values
+                 acc, l_i, m_i,
+                 # Offsets
+                 off_z, off_h, offs_m, offs_n,
+                 # Offsets needed for TMA loads
+                 kv_start,
+                 kv_offset,
+                 MATMUL_PRECISION, RCP_LN2,
+                 # Strides for K and V
+                 stride_kk, stride_kn, stride_vn, stride_vk,
+                 IS_FULL_BLOCKS,
+             )
+         else:
+             # Benchmarks show that even if we apply mod & mask to every block for a non-divisible seqlen,
+             # it's on par with or slightly faster than applying them only to the last block in fwd.
+             # However, we choose a different strategy for bwd, where we only apply mod & mask
+             # to the last block because it's much faster there.
+             acc, l_i, m_i = forward_block_mn(
+                 arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0,
+                 q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
+                 # accumulated values
+                 acc, l_i, m_i,
+                 # Offsets
+                 off_z, off_h, offs_m, offs_n,
+                 # Offsets needed for TMA loads
+                 kv_start,
+                 kv_offset,
+                 MATMUL_PRECISION, RCP_LN2,
+                 # Strides for K and V
+                 stride_kk, stride_kn, stride_vn, stride_vk,
+                 IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
+             )
+
+         offset = get_offset_for_next_block(
+             start_n, kv_indices, kv_num_blocks,
+             SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
+         )
+
+         offs_n = offs_n + offset
+         kv_offset += offset
+
+     return acc, l_i, m_i
+ ''', device_str='cuda')
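
Editorial note: forward_inner / forward_block_mn above implement the streaming ("online") softmax used by FlashAttention-style kernels -- a running row max m_i and row sum l_i let one KV block be folded in at a time without materializing the full score matrix. Below is a minimal sketch of that update rule, assuming plain NumPy; the kernel itself works in base 2 (exp2 with RCP_LN2) and adds block-sparse skipping, but the rescaling algebra is the same. All names in the sketch are invented for illustration.

    import numpy as np

    def online_softmax_reference(q, k, v, block_n=64):
        # Assumes each row sees at least one finite score (cf. ROWS_GUARANTEED_SAFE).
        m = np.full(q.shape[0], -np.inf)          # running row max (m_i)
        l = np.zeros(q.shape[0])                  # running row sum (l_i)
        acc = np.zeros((q.shape[0], v.shape[1]))  # unnormalized output (acc)
        for s in range(0, k.shape[0], block_n):
            scores = q @ k[s:s + block_n].T
            m_new = np.maximum(m, scores.max(axis=1))
            alpha = np.exp(m - m_new)             # rescale previous state (alpha)
            p = np.exp(scores - m_new[:, None])
            l = l * alpha + p.sum(axis=1)
            acc = acc * alpha[:, None] + p @ v[s:s + block_n]
            m = m_new
        return acc / l[:, None]                   # matches acc / l_i[:, None] above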
+
+
+ async_compile.wait(globals())
+ del async_compile
+
+ class Runner:
+     def __init__(self, partitions):
+         self.partitions = partitions
+
+     def recursively_apply_fns(self, fns):
+         new_callables = []
+         for fn, c in zip(fns, self.partitions):
+             new_callables.append(fn(c))
+         self.partitions = new_callables
+
+     def call(self, args):
+         primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12 = args
+         args.clear()
+         assert_size_stride(primals_1, (2, 32, 2048, 128), (8388608, 128, 4096, 1))
+         assert_size_stride(primals_2, (2, 8, 2048, 128), (2097152, 262144, 128, 1))
+         assert_size_stride(primals_3, (2, 8, 2048, 128), (2097152, 262144, 128, 1))
+         assert_size_stride(primals_4, (2, 1, 16, 16), (256, 256, 16, 1))
+         assert_size_stride(primals_5, (2, 1, 16), (16, 16, 1))
+         assert_size_stride(primals_6, (2, ), (1, ))
+         assert_size_stride(primals_7, (2, 1, 16), (16, 16, 1))
+         assert_size_stride(primals_8, (2, 1, 16, 16), (256, 256, 16, 1))
+         assert_size_stride(primals_9, (2, 1, 16), (16, 16, 1))
+         assert_size_stride(primals_10, (2, 1, 16, 16), (256, 256, 16, 1))
+         assert_size_stride(primals_11, (2, 1, 16), (16, 16, 1))
+         assert_size_stride(primals_12, (2, 1, 16, 16), (256, 256, 16, 1))
+         with torch.cuda._DeviceGuard(4):
+             torch.cuda.set_device(4)
+             buf0 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32)
+             buf1 = empty_strided_cuda((2, 32, 2048), (65536, 2048, 1), torch.float32)
+             buf2 = empty_strided_cuda((2, 32, 2048, 128), (8388608, 128, 4096, 1), torch.bfloat16)
+             # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
+             stream4 = get_raw_stream(4)
+             triton_tem_fused_0.run(primals_1, primals_2, primals_3, buf0, buf1, primals_5, primals_4, primals_7, primals_8, primals_6, buf2, 16, 2, 32, stream=stream4)
+         del buf1
+         return (buf2, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, buf2, buf0, )
+
+ runner = Runner(partitions=[])
+ call = runner.call
+ recursively_apply_fns = runner.recursively_apply_fns
+
+
+ def benchmark_compiled_module(times=10, repeat=10):
+     from torch._dynamo.testing import rand_strided
+     from torch._inductor.utils import print_performance
+     primals_1 = rand_strided((2, 32, 2048, 128), (8388608, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16)
+     primals_2 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:4', dtype=torch.bfloat16)
+     primals_3 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:4', dtype=torch.bfloat16)
+     primals_4 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:4', dtype=torch.int32)
+     primals_5 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:4', dtype=torch.int32)
+     primals_6 = rand_strided((2, ), (1, ), device='cuda:4', dtype=torch.int64)
+     primals_7 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:4', dtype=torch.int32)
+     primals_8 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:4', dtype=torch.int32)
+     primals_9 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:4', dtype=torch.int32)
+     primals_10 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:4', dtype=torch.int32)
+     primals_11 = rand_strided((2, 1, 16), (16, 16, 1), device='cuda:4', dtype=torch.int32)
+     primals_12 = rand_strided((2, 1, 16, 16), (256, 256, 16, 1), device='cuda:4', dtype=torch.int32)
+     fn = lambda: call([primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12])
+     return print_performance(fn, times=times, repeat=repeat)
+
+
+ if __name__ == "__main__":
+     from torch._inductor.wrapper_benchmark import compiled_module_main
+     compiled_module_main('None', benchmark_compiled_module)
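
Editorial note: the sparse metadata consumed by the kernel above (KV_NUM_BLKS/KV_IDX and the FULL_* variants) records, per query-block row of the block mask, how many KV blocks must be visited and which ones; forward_inner then walks that list, taking the masked path (mask_mod applied) for partial blocks and the unmasked path for full ones. A hedged sketch of the traversal and of the GQA head mapping, with all names invented for illustration:

    GQA_SHARED_HEADS = 4        # 32 query heads / 8 KV heads, as in the kernel above
    SPARSE_KV_BLOCK_SIZE = 128

    def kv_ranges_for_q_block(kv_num_blks, kv_idx, q_blk):
        """Yield [start, end) KV element ranges one query block must process.

        kv_num_blks: int32 [num_q_blocks]            -- count of live KV blocks per row
        kv_idx:      int32 [num_q_blocks, max_blks]  -- their block indices
        """
        for j in range(int(kv_num_blks[q_blk])):
            start = int(kv_idx[q_blk, j]) * SPARSE_KV_BLOCK_SIZE
            yield start, start + SPARSE_KV_BLOCK_SIZE

    # GQA: query head off_hq reads KV head off_hq // GQA_SHARED_HEADS.
    assert [hq // GQA_SHARED_HEADS for hq in range(8)] == [0, 0, 0, 0, 1, 1, 1, 1]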
SpecForge-ext/cache/compiled_kernels/fa/cfawzdo3q32syzk5d3t3mjridjbalgrkptn5qwko7qnup25mzrum.py ADDED
@@ -0,0 +1,57 @@
+
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.reduction(
+     size_hints={'x': 16384, 'r0_': 262144},
+     reduction_hint=ReductionHint.INNER,
+     filename=__file__,
+ triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i64', 'r0_numel': 'i64', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
+ )
+ @triton.jit
+ def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
+     xnumel = 16384
+     r0_numel = 151936
+     rnumel = r0_numel
+     RBLOCK: tl.constexpr = R0_BLOCK
+     xoffset = tl.program_id(0).to(tl.int64) * XBLOCK
+     xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64)
+     xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
+     r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64)
+     rbase = r0_base
+     x0 = xindex
+     _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
+     _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 9223372036854775807, tl.int64)
+     for r0_offset in range(0, r0_numel, R0_BLOCK):
+         r0_index = r0_offset + r0_base
+         r0_mask = r0_index < r0_numel
+         roffset = r0_offset
+         rindex = r0_index
+         r0_1 = r0_index
+         tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
+         tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
+         _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
+             _tmp2, _tmp2_index, tmp1, rindex
+         )
+         _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2)
+         _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index)
+     tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
+     tmp2 = tmp2_idx[:, None]
+     tmp11 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last')
+     tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32)
+     tmp4 = tmp2 + tmp3
+     tmp5 = tmp2 < 0
+     tmp6 = tl.where(tmp5, tmp4, tmp2)
+     tl.device_assert((0 <= tmp6) & (tmp6 < 151936), "index out of bounds: 0 <= tmp6 < 151936")
+     tmp8 = tl.load(in_ptr1 + (tmp6), None, eviction_policy='evict_last').to(tl.int1)
+     tmp9 = tmp8.to(tl.int32)
+     tmp10 = tmp9.to(tl.int64)
+     tmp12 = tmp10 * tmp11
+     tl.debug_barrier()
+     tl.store(in_out_ptr0 + (x0), tmp12, None)
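
Editorial note: read back at the PyTorch level, the reduction above fuses a row-wise argmax over a 151936-wide vocabulary with a table lookup at the argmax position and an elementwise multiply. Roughly the eager equivalent is the sketch below; shapes are taken from the hard-coded xnumel/r0_numel and all names are invented for illustration:

    import torch

    def fused_argmax_index_mul(logits, flag_table, scale):
        # logits:     [16384, 151936] bf16   (in_ptr0)
        # flag_table: [151936] bool          (in_ptr1)
        # scale:      [16384] int64          (in_ptr2)
        idx = logits.argmax(dim=-1)                # the maximum_with_index reduction
        picked = flag_table[idx].to(torch.int64)   # gather at the argmax (tmp8..tmp10)
        return picked * scale                      # tmp12, stored into in_out_ptr0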
SpecForge-ext/cache/compiled_kernels/fi/cfiplsvt2q6tbvsfjtg2dd47g7npdwtvk5m3lv4anjbxwgjigkj2.py ADDED
@@ -0,0 +1,72 @@
+
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.reduction(
+     size_hints={'x': 2048, 'r0_': 16384},
+     reduction_hint=ReductionHint.INNER,
+     filename=__file__,
+ triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 32768, 'r0_': 0}}
+ )
+ @triton.jit
+ def triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
+     xnumel = 2048
+     r0_numel = 16384
+     rnumel = r0_numel
+     RBLOCK: tl.constexpr = R0_BLOCK
+     xoffset = tl.program_id(0) * XBLOCK
+     xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+     xmask = xindex < xnumel
+     r0_base = tl.arange(0, R0_BLOCK)[None, :]
+     rbase = r0_base
+     x1 = ((xindex // 16) % 16)
+     x0 = (xindex % 16)
+     x2 = xindex // 256
+     tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
+     _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
+     x6 = xindex
+     for r0_offset in range(0, r0_numel, R0_BLOCK):
+         r0_index = r0_offset + r0_base
+         r0_mask = r0_index < r0_numel
+         roffset = r0_offset
+         rindex = r0_index
+         r0_4 = r0_index // 128
+         r0_3 = (r0_index % 128)
+         tmp0 = r0_4 + 128*x1
+         tmp1 = r0_3 + 128*x0
+         tmp2 = tmp0 >= tmp1
+         tmp4 = tmp1 < tmp3
+         tmp5 = tmp0 < tmp3
+         tmp6 = tmp4 & tmp5
+         tmp7 = tmp2 & tmp6
+         tmp8 = tl.full([1, 1], False, tl.int1)
+         tmp9 = tmp8 | tmp7
+         tmp10 = tl.full([1, 1], 2048, tl.int64)
+         tmp11 = tmp1 >= tmp10
+         tmp12 = tmp11 & tmp4
+         tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0
+         tmp14 = (tmp13 % tmp10)
+         tmp15 = tl.full([1, 1], 0, tl.int32)
+         tmp16 = tmp14 != tmp15
+         tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
+         tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0
+         tmp19 = tmp17 != tmp18
+         tmp20 = tmp16 & tmp19
+         tmp21 = tmp14 + tmp10
+         tmp22 = tl.where(tmp20, tmp21, tmp14)
+         tmp23 = tl.full([1, 1], 0, tl.int64)
+         tmp24 = tmp22 == tmp23
+         tmp25 = tmp12 & tmp24
+         tmp26 = tmp9 | tmp25
+         tmp27 = tmp26.to(tl.int64)
+         tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK])
+         tmp30 = _tmp29 + tmp28
+         _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29)
+     tmp29 = tl.sum(_tmp29, 1)[:, None]
+     tl.store(out_ptr0 + (x6), tmp29, xmask)
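
Editorial note: this reduction is the block-counting pass behind FlexAttention's BlockMask construction -- for each 128x128 tile of a 2048x2048 mask it counts allowed positions, so later code can classify tiles as empty, partial (0 < count < 16384), or full (count == 16384). A rough eager sketch of the dominant causal-plus-padding term follows; the generated code also folds in a wrap-around suffix/diagonal term from the traced mask_mod, and the names here are illustrative:

    import torch

    def block_density(seq_lens, S=2048, B=128):
        q = torch.arange(S).view(S, 1)               # query index (role of tmp0)
        k = torch.arange(S).view(1, S)               # key index (role of tmp1)
        out = []
        for L in seq_lens.tolist():                  # one length per batch element
            mask = (q >= k) & (q < L) & (k < L)      # causal AND both positions unpadded
            tiles = mask.view(S // B, B, S // B, B)  # [16, 128, 16, 128]
            out.append(tiles.permute(0, 2, 1, 3).sum(dim=(-2, -1)))  # per-tile count
        return torch.stack(out).unsqueeze(1)         # [batch, 1, 16, 16], int64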
SpecForge-ext/cache/compiled_kernels/h6/aa838d40f4d0e483f1277be61c094ff598dd757fa08fb0e455bf7c8a9b79036a.best_config ADDED
@@ -0,0 +1 @@
+ {"XBLOCK": 256, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 13, "triton_cache_hash": "6FB7I6IASCIGI3DSKLBL4Q2CXFFWPYWXW7AMHNUUDLPGKUCB3PDA"}
SpecForge-ext/cache/compiled_kernels/ic/cicti66tef7ykscmewrfizq5t5hma2a6k6njneyopvmhy4vmegql.py ADDED
@@ -0,0 +1,543 @@
+ # AOT ID: ['5_inference']
+ from ctypes import c_void_p, c_long, c_int
+ import torch
+ import math
+ import random
+ import os
+ import tempfile
+ from math import inf, nan
+ from cmath import nanj
+ from torch._inductor.hooks import run_intermediate_hooks
+ from torch._inductor.utils import maybe_profile
+ from torch._inductor.codegen.memory_planning import _align as align
+ from torch import device, empty_strided
+ from torch._inductor.async_compile import AsyncCompile
+ from torch._inductor.select_algorithm import extern_kernels
+ import triton
+ import triton.language as tl
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+
+ aten = torch.ops.aten
+ inductor_ops = torch.ops.inductor
+ _quantized = torch.ops._quantized
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
+ async_compile = AsyncCompile()
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
+
+
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/zc/czc4uswzazabvj7ebt72gzrcg2fgrugi6d7lol4a4jino45fz2ua.py
+ # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_2, mask_3, mask_block_sum], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.permute, aten.sum]
+ # Source node to ATen node mapping:
+ #   and_2 => bitwise_and_1
+ #   and_3 => bitwise_and_2
+ #   and_4 => bitwise_and_3, view_8
+ #   b => iota
+ #   batched_outputs_2 => view_9
+ #   causal_mask => ge, view
+ #   diagnol_mask => eq
+ #   index => index
+ #   index_1 => index_1
+ #   index_2 => index_2
+ #   lt => lt, view_1
+ #   lt_1 => lt_1, view_2
+ #   m => iota_2
+ #   mask_2 => view_10
+ #   mask_3 => permute
+ #   mask_block_sum => sum_1
+ #   n => iota_3
+ #   padding_mask => bitwise_and, view_3, view_4
+ #   padding_mask_1 => lt_2, view_6
+ #   remainder => remainder
+ #   remainder_1 => remainder_1
+ #   result_1 => bitwise_or, full_default
+ #   result_2 => bitwise_or_1
+ #   sub => sub, view_7
+ #   suffix_mask => ge_1
+ # Graph fragment:
+ #   %arg0_1 : Tensor "i64[2][1]cuda:2" = PlaceHolder[target=arg0_1]
+ #   %full_default : Tensor "b8[2, 1, 1][1, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1], False), kwargs = {dtype: torch.bool, layout: torch.strided, device: cuda:2, pin_memory: False})
+ #   %iota_2 : Tensor "i64[2048][1]cuda:2"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:2, requires_grad: False})
+ #   %view : Tensor "i64[2048, 1][1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {})
+ #   %iota_3 : Tensor "i64[2048][1]cuda:2"[num_users=5] = call_function[target=torch.ops.prims.iota.default](args = (2048,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:2, requires_grad: False})
+ #   %ge : Tensor "b8[2048, 2048][2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.ge.Tensor](args = (%view, %iota_3), kwargs = {})
+ #   %iota : Tensor "i64[2][1]cuda:2"[num_users=3] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:2, requires_grad: False})
+ #   %index : Tensor "i64[2][1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {})
+ #   %view_1 : Tensor "i64[2, 1][1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index, [2, 1]), kwargs = {})
+ #   %lt : Tensor "b8[2, 2048][2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_3, %view_1), kwargs = {})
+ #   %view_4 : Tensor "b8[2, 1, 2048][2048, 2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt, [2, 1, 2048]), kwargs = {})
+ #   %index_1 : Tensor "i64[2][1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {})
+ #   %view_2 : Tensor "i64[2, 1][1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_1, [2, 1]), kwargs = {})
+ #   %lt_1 : Tensor "b8[2, 2048][2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_2, %view_2), kwargs = {})
+ #   %view_3 : Tensor "b8[2, 2048, 1][2048, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%lt_1, [2, 2048, 1]), kwargs = {})
+ #   %bitwise_and : Tensor "b8[2, 2048, 2048][4194304, 2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_4, %view_3), kwargs = {})
+ #   %bitwise_and_1 : Tensor "b8[2, 2048, 2048][4194304, 2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge, %bitwise_and), kwargs = {})
+ #   %bitwise_or : Tensor "b8[2, 2048, 2048][4194304, 2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%full_default, %bitwise_and_1), kwargs = {})
+ #   %ge_1 : Tensor "b8[2048][1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%iota_3, 2048), kwargs = {})
+ #   %remainder : Tensor "i64[2048][1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%iota_3, 2048), kwargs = {})
+ #   %index_2 : Tensor "i64[2][1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%iota]), kwargs = {})
+ #   %view_6 : Tensor "i64[2, 1][1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%index_2, [2, 1]), kwargs = {})
+ #   %lt_2 : Tensor "b8[2, 2048][2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%remainder, %view_6), kwargs = {})
+ #   %bitwise_and_2 : Tensor "b8[2, 2048][2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_1, %lt_2), kwargs = {})
+ #   %view_8 : Tensor "b8[2, 1, 2048][2048, 2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_and_2, [2, 1, 2048]), kwargs = {})
+ #   %view_7 : Tensor "i64[2048, 1][1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%iota_2, [2048, 1]), kwargs = {})
+ #   %sub : Tensor "i64[2048, 2048][2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%iota_3, %view_7), kwargs = {})
+ #   %remainder_1 : Tensor "i64[2048, 2048][2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.remainder.Scalar](args = (%sub, 2048), kwargs = {})
+ #   %eq : Tensor "b8[2048, 2048][2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%remainder_1, 0), kwargs = {})
+ #   %bitwise_and_3 : Tensor "b8[2, 2048, 2048][4194304, 2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%view_8, %eq), kwargs = {})
+ #   %bitwise_or_1 : Tensor "b8[2, 2048, 2048][4194304, 2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_or, %bitwise_and_3), kwargs = {})
+ #   %view_9 : Tensor "b8[2, 1, 2048, 2048][4194304, 4194304, 2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bitwise_or_1, [2, 1, 2048, 2048]), kwargs = {})
+ #   %view_10 : Tensor "b8[2, 1, 16, 128, 16, 128][4194304, 4194304, 262144, 2048, 128, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%expand, [2, 1, 16, 128, 16, 128]), kwargs = {})
+ #   %permute : Tensor "b8[2, 1, 16, 16, 128, 128][4194304, 4194304, 262144, 128, 2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 1, 2, 4, 3, 5]), kwargs = {})
+ #   %sum_1 : Tensor "i64[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=3] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute, [-2, -1]), kwargs = {})
+ # return %sum_1
+ triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0 = async_compile.triton('triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', '''
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.reduction(
+     size_hints={'x': 512, 'r0_': 16384},
+     reduction_hint=ReductionHint.INNER,
+     filename=__file__,
+ triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8192, 'r0_': 0}}
+ )
+ @triton.jit
+ def triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
+     xnumel = 512
+     r0_numel = 16384
+     rnumel = r0_numel
+     RBLOCK: tl.constexpr = R0_BLOCK
+     xoffset = tl.program_id(0) * XBLOCK
+     xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+     xmask = xindex < xnumel
+     r0_base = tl.arange(0, R0_BLOCK)[None, :]
+     rbase = r0_base
+     x1 = ((xindex // 16) % 16)
+     x0 = (xindex % 16)
+     x2 = xindex // 256
+     tmp3 = tl.load(in_ptr0 + (x2), xmask, eviction_policy='evict_last')
+     _tmp29 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
+     x6 = xindex
+     for r0_offset in range(0, r0_numel, R0_BLOCK):
+         r0_index = r0_offset + r0_base
+         r0_mask = r0_index < r0_numel
+         roffset = r0_offset
+         rindex = r0_index
+         r0_4 = r0_index // 128
+         r0_3 = (r0_index % 128)
+         tmp0 = r0_4 + 128*x1
+         tmp1 = r0_3 + 128*x0
+         tmp2 = tmp0 >= tmp1
+         tmp4 = tmp1 < tmp3
+         tmp5 = tmp0 < tmp3
+         tmp6 = tmp4 & tmp5
+         tmp7 = tmp2 & tmp6
+         tmp8 = tl.full([1, 1], False, tl.int1)
+         tmp9 = tmp8 | tmp7
+         tmp10 = tl.full([1, 1], 2048, tl.int64)
+         tmp11 = tmp1 >= tmp10
+         tmp12 = tmp11 & tmp4
+         tmp13 = r0_3 + ((-1)*r0_4) + ((-128)*x1) + 128*x0
+         tmp14 = (tmp13 % tmp10)
+         tmp15 = tl.full([1, 1], 0, tl.int32)
+         tmp16 = tmp14 != tmp15
+         tmp17 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
+         tmp18 = (libdevice.signbit(tmp10) != 0) if (tmp10).dtype is tl.float32 else tmp10 < 0
+         tmp19 = tmp17 != tmp18
+         tmp20 = tmp16 & tmp19
+         tmp21 = tmp14 + tmp10
+         tmp22 = tl.where(tmp20, tmp21, tmp14)
+         tmp23 = tl.full([1, 1], 0, tl.int64)
+         tmp24 = tmp22 == tmp23
+         tmp25 = tmp12 & tmp24
+         tmp26 = tmp9 | tmp25
+         tmp27 = tmp26.to(tl.int64)
+         tmp28 = tl.broadcast_to(tmp27, [XBLOCK, R0_BLOCK])
+         tmp30 = _tmp29 + tmp28
+         _tmp29 = tl.where(r0_mask & xmask, tmp30, _tmp29)
+     tmp29 = tl.sum(_tmp29, 1)[:, None]
+     tl.store(out_ptr0 + (x6), tmp29, xmask)
+ ''', device_str='cuda')
+
+
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/cm/ccmqky4m65yifqjmfuu7vgvpuhwpa4ybaxffiy3mu2e6yzgecghe.py
+ # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros]
+ # Source node to ATen node mapping:
+ #   dense_mask_4 => full_default_4
+ # Graph fragment:
+ #   %full_default_4 : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:2, pin_memory: False})
+ # return %index_put_1
+ triton_poi_fused_new_zeros_1 = async_compile.triton('triton_poi_fused_new_zeros_1', '''
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.pointwise(
+     size_hints={'x': 1024},
+     filename=__file__,
+ triton_meta={'signature': {'out_ptr0': '*i32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_new_zeros_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 4352}},
+     min_elem_per_thread=0
+ )
+ @triton.jit
+ def triton_poi_fused_new_zeros_1(out_ptr0, xnumel, XBLOCK : tl.constexpr):
+     xnumel = 544
+     xoffset = tl.program_id(0) * XBLOCK
+     xindex = xoffset + tl.arange(0, XBLOCK)[:]
+     xmask = xindex < xnumel
+     x0 = xindex
+     tmp0 = tl.full([1], 0, tl.int32)
+     tl.store(out_ptr0 + (x0), tmp0, xmask)
+ ''', device_str='cuda')
+
+
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/iw/ciwoxk7cuonocxkjitlvfvf5jppmr2duv6vgwzkwaw4xszgcaf5m.py
+ # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices, full_blocks, full_blocks_1, dense_mask_1, col_indices_1, dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices, dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort, aten.eq, aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten.scalar_tensor, aten.where, aten.view, aten.index_put]
+ # Source node to ATen node mapping:
+ #   arange_4 => iota_4
+ #   arange_6 => iota_8
+ #   child_3 => convert_element_type_3
+ #   child_4 => convert_element_type_4
+ #   child_7 => convert_element_type_6
+ #   child_8 => convert_element_type_7
+ #   col_indices => sort
+ #   col_indices_1 => sort_1
+ #   col_range => iota_5
+ #   col_range_1 => iota_9
+ #   dense_mask => convert_element_type_2
+ #   dense_mask_1 => convert_element_type_5
+ #   dense_mask_2 => full_default_1
+ #   dense_mask_4 => full_default_4
+ #   full_blocks => eq_1
+ #   full_blocks_1 => convert_element_type_1
+ #   gt => gt
+ #   index_mask => lt_4
+ #   index_mask_1 => lt_5
+ #   lt_3 => lt_3
+ #   num_blocks_in_row => sum_2
+ #   num_blocks_in_row_1 => sum_3
+ #   partial_blocks => bitwise_and_4
+ #   partial_blocks_1 => convert_element_type
+ #   row_indices => unsqueeze
+ #   row_indices_1 => unsqueeze_7
+ #   setitem => full_default_3, index_put, iota_6, iota_7, unsqueeze_2, unsqueeze_3, unsqueeze_4, unsqueeze_5, unsqueeze_6
+ #   setitem_1 => full_default_6, index_put_1, iota_10, iota_11, unsqueeze_10, unsqueeze_11, unsqueeze_12, unsqueeze_13, unsqueeze_9
+ #   unsqueeze_1 => unsqueeze_1
+ #   unsqueeze_3 => unsqueeze_8
+ #   valid_indices => full_default_2, where
+ #   valid_indices_1 => full_default_5, where_1
+ # Graph fragment:
+ #   %sum_1 : Tensor "i64[2, 1, 16, 16][256, 512, 16, 1]cuda:2" = PlaceHolder[target=sum_1]
+ #   %sum_2 : Tensor "i64[2, 1, 16][16, 32, 1]cuda:2" = PlaceHolder[target=sum_2]
+ #   %sum_3 : Tensor "i64[2, 1, 16][16, 32, 1]cuda:2" = PlaceHolder[target=sum_3]
+ #   %buf2 : Tensor "i16[2, 1, 16, 16][256, 512, 16, 1]cuda:2" = PlaceHolder[target=buf2]
+ #   %convert_element_type_3 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:2" = PlaceHolder[target=convert_element_type_3]
+ #   %convert_element_type_4 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:2" = PlaceHolder[target=convert_element_type_4]
+ #   %index_put : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:2" = PlaceHolder[target=index_put]
+ #   %buf4 : Tensor "i16[2, 1, 16, 16][256, 512, 16, 1]cuda:2" = PlaceHolder[target=buf4]
+ #   %convert_element_type_6 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:2" = PlaceHolder[target=convert_element_type_6]
+ #   %convert_element_type_7 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:2" = PlaceHolder[target=convert_element_type_7]
+ #   %index_put_1 : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:2" = PlaceHolder[target=index_put_1]
+ #   %gt : Tensor "b8[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
+ #   %lt_3 : Tensor "b8[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%sum_1, 16384), kwargs = {})
+ #   %bitwise_and_4 : Tensor "b8[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%gt, %lt_3), kwargs = {})
+ #   %convert_element_type : Tensor "i8[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%bitwise_and_4, torch.int8), kwargs = {})
+ #   %convert_element_type_2 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type, torch.int32), kwargs = {})
+ #   %sort : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%convert_element_type_2,), kwargs = {stable: True, descending: True})
+ #   %eq_1 : Tensor "b8[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%sum_1, 16384), kwargs = {})
+ #   %convert_element_type_1 : Tensor "i8[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%eq_1, torch.int8), kwargs = {})
+ #   %convert_element_type_5 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_1, torch.int32), kwargs = {})
+ #   %sort_1 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%convert_element_type_5,), kwargs = {stable: True, descending: True})
+ #   %full_default_1 : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:2, pin_memory: False})
+ #   %iota_7 : Tensor "i64[2][1]cuda:2"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:2, requires_grad: False})
+ #   %unsqueeze_4 : Tensor "i64[2, 1][1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_7, -1), kwargs = {})
+ #   %unsqueeze_5 : Tensor "i64[2, 1, 1][1, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_4, -1), kwargs = {})
+ #   %unsqueeze_6 : Tensor "i64[2, 1, 1, 1][1, 1, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_5, -1), kwargs = {})
+ #   %iota_6 : Tensor "i64[1][1]cuda:2"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:2, requires_grad: False})
+ #   %unsqueeze_2 : Tensor "i64[1, 1][1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_6, -1), kwargs = {})
+ #   %unsqueeze_3 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_2, -1), kwargs = {})
+ #   %iota_4 : Tensor "i32[16][1]cuda:2"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:2, requires_grad: False})
+ #   %unsqueeze : Tensor "i32[16, 1][1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_4, -1), kwargs = {})
+ #   %iota_5 : Tensor "i32[16][1]cuda:2"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:2, requires_grad: False})
+ #   %sum_2 : Tensor "i64[2, 1, 16][16, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_2, [-1]), kwargs = {})
+ #   %convert_element_type_3 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:2"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.int32), kwargs = {})
+ #   %unsqueeze_1 : Tensor "i32[2, 1, 16, 1][16, 16, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_3, 3), kwargs = {})
+ #   %lt_4 : Tensor "b8[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_5, %unsqueeze_1), kwargs = {})
+ #   %convert_element_type_4 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_1, torch.int32), kwargs = {})
+ #   %full_default_2 : Tensor "i32[][]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 16), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:2, pin_memory: False})
+ #   %where : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_4, %convert_element_type_4, %full_default_2), kwargs = {})
+ #   %full_default_3 : Tensor "i32[2, 1, 1, 1][1, 1, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:2, pin_memory: False})
+ #   %index_put : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_1, [%unsqueeze_6, %unsqueeze_3, %unsqueeze, %where], %full_default_3), kwargs = {})
+ #   %full_default_4 : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 16, 17], 0), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:2, pin_memory: False})
+ #   %iota_11 : Tensor "i64[2][1]cuda:2"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (2,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:2, requires_grad: False})
+ #   %unsqueeze_11 : Tensor "i64[2, 1][1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_11, -1), kwargs = {})
+ #   %unsqueeze_12 : Tensor "i64[2, 1, 1][1, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_11, -1), kwargs = {})
+ #   %unsqueeze_13 : Tensor "i64[2, 1, 1, 1][1, 1, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_12, -1), kwargs = {})
+ #   %iota_10 : Tensor "i64[1][1]cuda:2"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (1,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:2, requires_grad: False})
295
+ # %unsqueeze_9 : Tensor "i64[1, 1][1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_10, -1), kwargs = {})
296
+ # %unsqueeze_10 : Tensor "i64[1, 1, 1][1, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_9, -1), kwargs = {})
297
+ # %iota_8 : Tensor "i32[16][1]cuda:2"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:2, requires_grad: False})
298
+ # %unsqueeze_7 : Tensor "i32[16, 1][1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%iota_8, -1), kwargs = {})
299
+ # %iota_9 : Tensor "i32[16][1]cuda:2"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (16,), kwargs = {start: 0, step: 1, dtype: torch.int32, device: cuda:2, requires_grad: False})
300
+ # %sum_3 : Tensor "i64[2, 1, 16][16, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%convert_element_type_5, [-1]), kwargs = {})
301
+ # %convert_element_type_6 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:2"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_3, torch.int32), kwargs = {})
302
+ # %unsqueeze_8 : Tensor "i32[2, 1, 16, 1][16, 16, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_6, 3), kwargs = {})
303
+ # %lt_5 : Tensor "b8[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.lt.Tensor](args = (%iota_9, %unsqueeze_8), kwargs = {})
304
+ # %convert_element_type_7 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_3, torch.int32), kwargs = {})
305
+ # %full_default_5 : Tensor "i32[][]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 16), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:2, pin_memory: False})
306
+ # %where_1 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%lt_5, %convert_element_type_7, %full_default_5), kwargs = {})
307
+ # %full_default_6 : Tensor "i32[2, 1, 1, 1][1, 1, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 1, 1, 1], 1), kwargs = {dtype: torch.int32, layout: torch.strided, device: cuda:2, pin_memory: False})
308
+ # %index_put_1 : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_4, [%unsqueeze_13, %unsqueeze_10, %unsqueeze_7, %where_1], %full_default_6), kwargs = {})
309
+ # return %buf2,%buf4,%sum_2,%sum_3,%convert_element_type_3,%convert_element_type_6,%convert_element_type_4,%buf9,%convert_element_type_7,%buf16
310
+ triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2 = async_compile.triton('triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', '''
311
+ import triton
312
+ import triton.language as tl
313
+
314
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
315
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
316
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
317
+ triton_helpers.set_driver_to_gpu()
318
+
319
+ @triton_heuristics.persistent_reduction(
320
+ size_hints={'x': 32, 'r0_': 16},
321
+ reduction_hint=ReductionHint.DEFAULT,
322
+ filename=__file__,
323
+ triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr4': '*i32', 'out_ptr5': '*i32', 'out_ptr6': '*i32', 'out_ptr7': '*i32', 'out_ptr8': '*i32', 'out_ptr9': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
324
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2', 'mutated_arg_names': ['out_ptr7', 'out_ptr9'], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
325
+ )
326
+ @triton.jit
327
+ def triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2(in_ptr0, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, out_ptr9, xnumel, r0_numel, XBLOCK : tl.constexpr):
328
+ xnumel = 32
329
+ r0_numel = 16
330
+ R0_BLOCK: tl.constexpr = 16
331
+ rnumel = r0_numel
332
+ RBLOCK: tl.constexpr = R0_BLOCK
333
+ xoffset = tl.program_id(0) * XBLOCK
334
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
335
+ xmask = xindex < xnumel
336
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
337
+ r0_offset = 0
338
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
339
+ roffset = r0_offset
340
+ rindex = r0_index
341
+ r0_1 = r0_index
342
+ x0 = xindex
343
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 16*x0), xmask, other=0.0)
344
+ tmp1 = tl.full([1, 1], 0, tl.int64)
345
+ tmp2 = tmp0 > tmp1
346
+ tmp3 = tl.full([1, 1], 16384, tl.int64)
347
+ tmp4 = tmp0 < tmp3
348
+ tmp5 = tmp2 & tmp4
349
+ tmp6 = tmp5.to(tl.int8)
350
+ tmp7 = tmp6.to(tl.int32)
351
+ tmp8 = r0_1
352
+ tmp9 = tmp8.to(tl.int16)
353
+ tmp10 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK])
354
+ tmp11 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
355
+ tmp12, tmp13, = triton_helpers.sort_with_index(tmp10, tmp11, None, 1, stable=True, descending=True)
356
+ tmp14 = tmp0 == tmp3
357
+ tmp15 = tmp14.to(tl.int8)
358
+ tmp16 = tmp15.to(tl.int32)
359
+ tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK])
360
+ tmp18, tmp19, = triton_helpers.sort_with_index(tmp17, tmp11, None, 1, stable=True, descending=True)
361
+ tmp20 = tmp7.to(tl.int64)
362
+ tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK])
363
+ tmp23 = tl.where(xmask, tmp21, 0)
364
+ tmp24 = tl.sum(tmp23, 1)[:, None].to(tl.int64)
365
+ tmp25 = tmp16.to(tl.int64)
366
+ tmp26 = tl.broadcast_to(tmp25, [XBLOCK, R0_BLOCK])
367
+ tmp28 = tl.where(xmask, tmp26, 0)
368
+ tmp29 = tl.sum(tmp28, 1)[:, None].to(tl.int64)
369
+ tmp30 = tmp24.to(tl.int32)
370
+ tmp31 = tmp29.to(tl.int32)
371
+ tmp32 = tmp13.to(tl.int64)
372
+ tmp33 = tmp32.to(tl.int32)
373
+ tmp34 = tmp8 < tmp30
374
+ tmp35 = tl.full([1, 1], 16, tl.int32)
375
+ tmp36 = tl.where(tmp34, tmp33, tmp35)
376
+ tmp37 = tl.full([XBLOCK, R0_BLOCK], 17, tl.int32)
377
+ tmp38 = tmp36 + tmp37
378
+ tmp39 = tmp36 < 0
379
+ tmp40 = tl.where(tmp39, tmp38, tmp36)
380
+ tl.device_assert(((0 <= tmp40) & (tmp40 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp40 < 17")
381
+ tmp42 = tl.full([1, 1], 1, tl.int32)
382
+ tmp43 = tmp19.to(tl.int64)
383
+ tmp44 = tmp43.to(tl.int32)
384
+ tmp45 = tmp8 < tmp31
385
+ tmp46 = tl.where(tmp45, tmp44, tmp35)
386
+ tmp47 = tmp46 + tmp37
387
+ tmp48 = tmp46 < 0
388
+ tmp49 = tl.where(tmp48, tmp47, tmp46)
389
+ tl.device_assert(((0 <= tmp49) & (tmp49 < 17)) | ~(xmask), "index out of bounds: 0 <= tmp49 < 17")
390
+ tl.store(out_ptr4 + (x0), tmp30, xmask)
391
+ tl.store(out_ptr5 + (x0), tmp31, xmask)
392
+ tl.store(out_ptr6 + (r0_1 + 16*x0), tmp33, xmask)
393
+ tl.store(out_ptr7 + (tl.broadcast_to(tmp40 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask)
394
+ tl.store(out_ptr8 + (r0_1 + 16*x0), tmp44, xmask)
395
+ tl.store(out_ptr9 + (tl.broadcast_to(tmp49 + 17*x0, [XBLOCK, R0_BLOCK])), tmp42, xmask)
396
+ ''', device_str='cuda')
397
+
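For reference, the fused kernel above implements the block-mask index construction in one pass per row of 16 KV blocks: classify each block as partial (0 < popcount < 16384) or full (popcount == 16384), compact the active block ids with a stable descending sort, count them, and scatter 1s into a padded 17-column table whose last column absorbs inactive blocks. A minimal eager-mode sketch of the classify/sort/count part, with illustrative names that are not taken from this diff:

import torch

def build_block_indices(block_sum: torch.Tensor, full_count: int = 16384):
    # block_sum: per-block popcounts, shape (2, 1, 16, 16), int64
    partial = ((block_sum > 0) & (block_sum < full_count)).to(torch.int32)
    full = (block_sum == full_count).to(torch.int32)

    def compact(mask: torch.Tensor):
        # A stable descending sort moves the 1-entries (active blocks) to the
        # front while preserving their column order, so the sort indices are
        # the compacted block ids and the row sum is the active-block count.
        _, idx = mask.sort(dim=-1, descending=True, stable=True)
        return idx.to(torch.int32), mask.sum(dim=-1, dtype=torch.int32)

    return compact(partial), compact(full)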
398
+
399
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/bp/cbp4ofim2oujxe6hm47xzugia67k4kofgbgvt7n7d5gd3iux76li.py
400
+ # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum]
401
+ # Source node to ATen node mapping:
402
+ # batched_outputs_3 => clone_4, slice_2
403
+ # col_indices_2 => sort_2
404
+ # num_blocks_in_row_2 => sum_4
405
+ # q_indices => clone_6, convert_element_type_9
406
+ # q_num_blocks => convert_element_type_8
407
+ # transpose => permute_1
408
+ # Graph fragment:
409
+ # %buf9 : Tensor "i32[2, 1, 16, 17][272, 272, 17, 1]cuda:2" = PlaceHolder[target=buf9]
410
+ # %buf11 : Tensor "i16[2, 1, 16, 16][256, 512, 16, 1]cuda:2" = PlaceHolder[target=buf11]
411
+ # %sum_4 : Tensor "i64[2, 1, 16][16, 32, 1]cuda:2" = PlaceHolder[target=sum_4]
412
+ # %slice_2 : Tensor "i32[2, 1, 16, 16][272, 272, 17, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%index_put, 3, 0, 16), kwargs = {})
413
+ # %clone_4 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_2,), kwargs = {memory_format: torch.contiguous_format})
414
+ # %permute_1 : Tensor "i32[2, 1, 16, 16][256, 256, 1, 16]cuda:2"[num_users=2] = call_function[target=torch.ops.aten.permute.default](args = (%clone_4, [0, 1, 3, 2]), kwargs = {})
415
+ # %sort_2 : [num_users=1] = call_function[target=torch.ops.aten.sort.stable](args = (%permute_1,), kwargs = {stable: True, descending: True})
416
+ # %convert_element_type_9 : Tensor "i32[2, 1, 16, 16][256, 256, 1, 16]cuda:2"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_5, torch.int32), kwargs = {})
417
+ # %clone_6 : Tensor "i32[2, 1, 16, 16][256, 256, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_9,), kwargs = {memory_format: torch.contiguous_format})
418
+ # %sum_4 : Tensor "i64[2, 1, 16][16, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%permute_1, [-1]), kwargs = {})
419
+ # %convert_element_type_8 : Tensor "i32[2, 1, 16][16, 16, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_4, torch.int32), kwargs = {})
420
+ # return %buf11,%sum_4,%clone_6,%convert_element_type_8
421
+ triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3 = async_compile.triton('triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', '''
422
+ import triton
423
+ import triton.language as tl
424
+
425
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
426
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
427
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
428
+ triton_helpers.set_driver_to_gpu()
429
+
430
+ @triton_heuristics.persistent_reduction(
431
+ size_hints={'x': 32, 'r0_': 16},
432
+ reduction_hint=ReductionHint.DEFAULT,
433
+ filename=__file__,
434
+ triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
435
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 256, 'r0_': 4096}}
436
+ )
437
+ @triton.jit
438
+ def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr):
439
+ xnumel = 32
440
+ r0_numel = 16
441
+ R0_BLOCK: tl.constexpr = 16
442
+ rnumel = r0_numel
443
+ RBLOCK: tl.constexpr = R0_BLOCK
444
+ xoffset = tl.program_id(0) * XBLOCK
445
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
446
+ xmask = xindex < xnumel
447
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
448
+ r0_offset = 0
449
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
450
+ roffset = r0_offset
451
+ rindex = r0_index
452
+ r0_2 = r0_index
453
+ x0 = (xindex % 16)
454
+ x1 = xindex // 16
455
+ x3 = xindex
456
+ tmp0 = tl.load(in_ptr0 + (x0 + 17*r0_2 + 272*x1), xmask, other=0.0)
457
+ tmp1 = r0_2
458
+ tmp2 = tmp1.to(tl.int16)
459
+ tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
460
+ tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
461
+ tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True)
462
+ tmp7 = tmp0.to(tl.int64)
463
+ tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK])
464
+ tmp10 = tl.where(xmask, tmp8, 0)
465
+ tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64)
466
+ tmp12 = tmp6.to(tl.int64)
467
+ tmp13 = tmp12.to(tl.int32)
468
+ tmp14 = tmp11.to(tl.int32)
469
+ tl.store(out_ptr2 + (r0_2 + 16*x3), tmp13, xmask)
470
+ tl.store(out_ptr3 + (x3), tmp14, xmask)
471
+ ''', device_str='cuda')
472
+
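The second sort kernel derives the q-side (transposed) view of the same mask: it reads the 17-column scatter table built above in transposed order, drops the overflow column, and repeats the stable descending sort plus row sum. A hedged eager sketch, with dense_kv as a stand-in name for the table:

import torch

def q_side_indices(dense_kv: torch.Tensor):
    # dense_kv: (2, 1, 16, 17) int32 table; column 16 is the overflow slot
    m = dense_kv[..., :16].transpose(-1, -2).contiguous()
    _, q_idx = m.sort(dim=-1, descending=True, stable=True)
    return q_idx.to(torch.int32), m.sum(dim=-1, dtype=torch.int32)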
473
+
474
+ async_compile.wait(globals())
475
+ del async_compile
476
+
477
+ class Runner:
478
+ def __init__(self, partitions):
479
+ self.partitions = partitions
480
+
481
+ def recursively_apply_fns(self, fns):
482
+ new_callables = []
483
+ for fn, c in zip(fns, self.partitions):
484
+ new_callables.append(fn(c))
485
+ self.partitions = new_callables
486
+
487
+ def call(self, args):
488
+ arg0_1, = args
489
+ args.clear()
490
+ assert_size_stride(arg0_1, (2, ), (1, ))
491
+ with torch.cuda._DeviceGuard(2):
492
+ torch.cuda.set_device(2)
493
+ buf0 = empty_strided_cuda((2, 1, 16, 16), (256, 512, 16, 1), torch.int64)
494
+ # Topologically Sorted Source Nodes: [result_1, m, causal_mask, n, b, index, lt, padding_mask, index_1, lt_1, and_2, suffix_mask, remainder, index_2, padding_mask_1, and_3, and_4, sub, remainder_1, diagnol_mask, result_2, batched_outputs_2, mask_2, mask_3, mask_block_sum], Original ATen: [aten.view, aten.arange, aten.ge, aten.index, aten.lt, aten.bitwise_and, aten.bitwise_or, aten.remainder, aten.sub, aten.eq, aten.permute, aten.sum]
495
+ stream2 = get_raw_stream(2)
496
+ triton_red_fused_arange_bitwise_and_bitwise_or_eq_ge_index_lt_permute_remainder_sub_sum_view_0.run(arg0_1, buf0, 512, 16384, stream=stream2)
497
+ del arg0_1
498
+ buf15 = empty_strided_cuda((2, 1, 16, 17), (272, 272, 17, 1), torch.int32)
499
+ # Topologically Sorted Source Nodes: [dense_mask_4], Original ATen: [aten.new_zeros]
500
+ stream2 = get_raw_stream(2)
501
+ triton_poi_fused_new_zeros_1.run(buf15, 544, stream=stream2)
502
+ buf8 = empty_strided_cuda((2, 1, 16, 17), (272, 272, 17, 1), torch.int32)
503
+ # Topologically Sorted Source Nodes: [dense_mask_2], Original ATen: [aten.new_zeros]
504
+ stream2 = get_raw_stream(2)
505
+ triton_poi_fused_new_zeros_1.run(buf8, 544, stream=stream2)
506
+ buf6 = empty_strided_cuda((2, 1, 16), (16, 16, 1), torch.int32)
507
+ buf13 = empty_strided_cuda((2, 1, 16), (16, 16, 1), torch.int32)
508
+ buf7 = empty_strided_cuda((2, 1, 16, 16), (256, 256, 16, 1), torch.int32)
509
+ buf14 = empty_strided_cuda((2, 1, 16, 16), (256, 256, 16, 1), torch.int32)
510
+ # Topologically Sorted Source Nodes: [gt, lt_3, partial_blocks, partial_blocks_1, dense_mask, col_indices, full_blocks, full_blocks_1, dense_mask_1, col_indices_1, dense_mask_2, setitem, arange_4, row_indices, col_range, num_blocks_in_row, child_3, unsqueeze_1, index_mask, child_4, valid_indices, dense_mask_4, setitem_1, arange_6, row_indices_1, col_range_1, num_blocks_in_row_1, child_7, unsqueeze_3, index_mask_1, child_8, valid_indices_1], Original ATen: [aten.gt, aten.lt, aten.bitwise_and, aten._to_copy, aten.sort, aten.eq, aten.new_zeros, aten.arange, aten.unsqueeze, aten.sum, aten.scalar_tensor, aten.where, aten.view, aten.index_put]
511
+ stream2 = get_raw_stream(2)
512
+ triton_per_fused__to_copy_arange_bitwise_and_eq_gt_index_put_lt_new_zeros_scalar_tensor_sort_sum_unsqueeze_view_where_2.run(buf0, buf6, buf13, buf7, buf8, buf14, buf15, 32, 16, stream=stream2)
513
+ del buf0
514
+ buf22 = empty_strided_cuda((2, 1, 16, 16), (256, 256, 16, 1), torch.int32)
515
+ buf24 = empty_strided_cuda((2, 1, 16), (16, 16, 1), torch.int32)
516
+ # Topologically Sorted Source Nodes: [batched_outputs_3, transpose, col_indices_2, q_indices, num_blocks_in_row_2, q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum]
517
+ stream2 = get_raw_stream(2)
518
+ triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf8, buf22, buf24, 32, 16, stream=stream2)
519
+ del buf8
520
+ buf19 = empty_strided_cuda((2, 1, 16, 16), (256, 256, 16, 1), torch.int32)
521
+ buf21 = empty_strided_cuda((2, 1, 16), (16, 16, 1), torch.int32)
522
+ # Topologically Sorted Source Nodes: [batched_outputs_5, transpose_1, col_indices_3, full_q_indices, num_blocks_in_row_3, full_q_num_blocks], Original ATen: [aten.slice, aten.clone, aten.transpose, aten.sort, aten._to_copy, aten.sum]
523
+ stream2 = get_raw_stream(2)
524
+ triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3.run(buf15, buf19, buf21, 32, 16, stream=stream2)
525
+ del buf15
526
+ return (buf19, buf21, buf22, buf24, buf14, buf13, buf7, buf6, )
527
+
528
+ runner = Runner(partitions=[])
529
+ call = runner.call
530
+ recursively_apply_fns = runner.recursively_apply_fns
531
+
532
+
533
+ def benchmark_compiled_module(times=10, repeat=10):
534
+ from torch._dynamo.testing import rand_strided
535
+ from torch._inductor.utils import print_performance
536
+ arg0_1 = rand_strided((2, ), (1, ), device='cuda:2', dtype=torch.int64)
537
+ fn = lambda: call([arg0_1])
538
+ return print_performance(fn, times=times, repeat=repeat)
539
+
540
+
541
+ if __name__ == "__main__":
542
+ from torch._inductor.wrapper_benchmark import compiled_module_main
543
+ compiled_module_main('None', benchmark_compiled_module)
SpecForge-ext/cache/compiled_kernels/ii/ciiz7wynjvqkn6uv5csahwryt5x2d664u4o7ugmepfcsfcniut4v.py ADDED
@@ -0,0 +1,48 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 4096, 'r0_': 32768},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=6, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmax_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 65536, 'r0_': 524288000}}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused_argmax_1(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ xnumel = 4096
20
+ r0_numel = 32000
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
26
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
27
+ rbase = r0_base
28
+ x0 = (xindex % 2048)
29
+ x1 = xindex // 2048
30
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
31
+ _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
32
+ x3 = xindex
33
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
34
+ r0_index = r0_offset + r0_base
35
+ r0_mask = r0_index < r0_numel
36
+ roffset = r0_offset
37
+ rindex = r0_index
38
+ r0_2 = r0_index
39
+ tmp0 = tl.load(in_ptr0 + (r0_2 + 32000*x0 + 65760000*x1), r0_mask, eviction_policy='evict_first', other=0.0)
40
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
41
+ _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
42
+ _tmp2, _tmp2_index, tmp1, rindex
43
+ )
44
+ _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2)
45
+ _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index)
46
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
47
+ tmp2 = tmp2_idx[:, None]
48
+ tl.store(out_ptr0 + (x3), tmp2, None)
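The reduction above is a blocked argmax over a 32000-wide vocabulary: it streams R0_BLOCK-sized chunks and carries a running (max, index) pair through triton_helpers.maximum_with_index. An eager sketch of the same carry, with illustrative names; the result matches logits.argmax(dim=-1):

import torch

def argmax_chunked(logits: torch.Tensor, chunk: int = 4096) -> torch.Tensor:
    best_val = torch.full(logits.shape[:-1], float("-inf"), device=logits.device)
    best_idx = torch.zeros(logits.shape[:-1], dtype=torch.int64, device=logits.device)
    for start in range(0, logits.shape[-1], chunk):
        val, idx = logits[..., start:start + chunk].max(dim=-1)
        take = val > best_val  # strict '>' keeps the earliest index on ties
        best_val = torch.where(take, val, best_val)
        best_idx = torch.where(take, idx + start, best_idx)
    return best_idx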
SpecForge-ext/cache/compiled_kernels/ik/ciksm4jphopwjgs55fbipcxecpw4d643lh76mj27636ryec4e3kg.py ADDED
@@ -0,0 +1,552 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+
9
+ @triton_heuristics.template(
10
+
11
+ num_stages=3,
12
+ num_warps=8,
13
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_MAX': '*fp32', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'in_ptr9': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'kernel_name': 'triton_tem_fused_0', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'USE_TMA': False, 'BLOCK_M': 128, 'BLOCK_N': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
15
+
16
+ )
17
+ @triton.jit
18
+ def triton_tem_fused_0(arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5):
19
+ PRESCALE_QK : tl.constexpr = False
20
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
21
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
22
+ WRITE_DQ : tl.constexpr = True
23
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
24
+ OUTPUT_MAX : tl.constexpr = False
25
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
26
+ IS_DIVISIBLE : tl.constexpr = False
27
+ SM_SCALE : tl.constexpr = 0.08838834764831843
28
+ GQA_SHARED_HEADS : tl.constexpr = 4
29
+ HAS_FULL_BLOCKS : tl.constexpr = True
30
+ QK_HEAD_DIM : tl.constexpr = 128
31
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
32
+ V_HEAD_DIM : tl.constexpr = 128
33
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
34
+ SAFE_HEAD_DIM : tl.constexpr = True
35
+ USE_TMA : tl.constexpr = False
36
+ BLOCK_M : tl.constexpr = 128
37
+ BLOCK_N : tl.constexpr = 64
38
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
39
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
40
+ INDEX_DTYPE : tl.constexpr = tl.int32
41
+ Q = arg_Q
42
+ K = arg_K
43
+ V = arg_V
44
+ LSE = arg_LSE
45
+ MAX = arg_MAX
46
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
47
+ KV_IDX = arg_KV_IDX
48
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
49
+ FULL_KV_IDX = arg_FULL_KV_IDX
50
+
51
+ # Sub notation for this kernel:
52
+ #
53
+ # Q: Query, K: Key, V: Value
54
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
55
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
56
+ # V_HEAD_DIM: The dimension of the value embeddings
57
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
58
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
59
+ #
60
+ # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
61
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
62
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
63
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
64
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
65
+ #
66
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
67
+ #
68
+ # (Modifiable) Performance tuning options
69
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
70
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
71
+
72
+ # The below are kernel options that can be applied for certain score_mods,
73
+ # or involve a numerics vs. perf tradeoff
74
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
75
+ # about 20% more numerical error, but slightly faster.
76
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
77
+ # is not masked out? If so, we can skip an extra safety check
78
+ # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
79
+ # contiguous? If so, we don't need to do an indirect jump for every block
80
+
81
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
82
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
83
+
84
+ # Define strides of inputs
85
+ stride_qz, stride_qh, stride_qm, stride_qk = 4096*ks0, 128, 4096, 1
86
+ stride_kz, stride_kh, stride_kn, stride_kk = 1024*ks1, 128*ks1, 128, 1
87
+ stride_vz, stride_vh, stride_vn, stride_vk = 1024*ks1, 128*ks1, 128, 1
88
+
89
+ ZQ = 2
90
+ HQ = 32
91
+ Q_LEN = ks0
92
+ ZKV = 2
93
+ KV_LEN = ks1
94
+
95
+ MATMUL_PRECISION = Q.dtype.element_ty
96
+
97
+ q_start = tl.program_id(0).to(INDEX_DTYPE)
98
+ off_zq = tl.program_id(1).to(INDEX_DTYPE)
99
+ off_hq = tl.program_id(2).to(INDEX_DTYPE)
100
+
101
+ # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
102
+ # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
103
+ off_zkv = off_zq % ZKV
104
+ off_hkv = off_hq // GQA_SHARED_HEADS
105
+ off_g = off_hq % GQA_SHARED_HEADS
106
+
107
+ q_offset = off_zq * stride_qz + off_hq * stride_qh
108
+ k_offset = off_zkv * stride_kz + off_hkv * stride_kh
109
+ v_offset = off_zkv * stride_vz + off_hkv * stride_vh
110
+
111
+ Q = Q + q_offset
112
+ K = K + k_offset
113
+ V = V + v_offset
114
+
115
+ # Setting up the TMA descriptors for Q, K, V
116
+ desc_q = None
117
+ desc_k = None
118
+ desc_v = None
119
+
120
+ SPARSE_Z = 2
121
+ SPARSE_HQ = 1
122
+
123
+ sparse_idx_z = off_zq % SPARSE_Z
124
+ sparse_idx_hq = off_hq % SPARSE_HQ
125
+
126
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
127
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
128
+
129
+ stride_kv_num_blks_h = ks2
130
+ stride_kv_idx_h = ks3*ks4
131
+ stride_kv_idx_m = ks4
132
+
133
+ # initialize pointer to m and l
134
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
135
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
136
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
137
+
138
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
139
+
140
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
141
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
142
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
143
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
144
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
145
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
146
+ q = load_checked_2d(Q, offs_m, offs_k, stride_qm, stride_qk, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
147
+
148
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
149
+ # We don't know anything "special" about these blocks, so we need to apply
150
+ # both score_mod and mask_mod to them
151
+ kv_indices = KV_IDX + sparse_kv_idx_offset
152
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
153
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
154
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
155
+
156
+
157
+ # K and V pointers will be passed directly to forward_inner
158
+
159
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
160
+
161
+
162
+ acc, l_i, m_i = forward_inner(
163
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
164
+ q, K, V,
165
+ desc_k, desc_v, Q_LEN, KV_LEN,
166
+ acc, l_i, m_i,
167
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
168
+ kv_start,
169
+ kv_indices, kv_num_blocks,
170
+ 0, block_n_end,
171
+ MATMUL_PRECISION,
172
+ stride_kk, stride_kn, stride_vn, stride_vk,
173
+ IS_FULL_BLOCKS=False,
174
+ )
175
+
176
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
177
+ # We know these blocks are guaranteed to be "full", so we don't need to
178
+ # apply mask_mod to them - only score_mod
179
+ if HAS_FULL_BLOCKS:
180
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
181
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
182
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
183
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
184
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
185
+ # K and V pointers will be passed directly to forward_inner
186
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
187
+
188
+ acc, l_i, m_i = forward_inner(
189
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
190
+ q, K, V,
191
+ desc_k, desc_v, Q_LEN, KV_LEN,
192
+ acc, l_i, m_i,
193
+ off_zq, off_hq, offs_m[:, None], offs_n[None, :],
194
+ kv_start,
195
+ kv_indices, kv_num_blocks,
196
+ 0, block_n_end,
197
+ MATMUL_PRECISION,
198
+ stride_kk, stride_kn, stride_vn, stride_vk,
199
+ IS_FULL_BLOCKS=True,
200
+ )
201
+
202
+
203
+ # [Note] Handle fully masked out rows:
204
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
205
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
206
+ l_i = tl.where(l_i == 0.0, 1, l_i)
207
+
208
+ acc = acc / l_i[:, None]
209
+ idx_zq = tl.program_id(1).to(INDEX_DTYPE)
210
+ idx_hq = tl.program_id(2).to(INDEX_DTYPE)
211
+ idx_m = offs_m[:, None].to(INDEX_DTYPE)
212
+ idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :].to(INDEX_DTYPE)
213
+
214
+ mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
215
+
216
+ tl.static_assert(acc.shape == [BLOCK_M, V_HEAD_DIM_ROUNDED])
217
+ xindex = idx_d + 128*idx_m + 128*idx_hq*ks0 + 4096*idx_zq*ks0
218
+ tl.store(out_ptr0 + (tl.broadcast_to(idx_d + 128*idx_hq + 4096*idx_m + 4096*idx_zq*ks0, acc.shape)), acc, mask)
219
+
220
+ if OUTPUT_LOGSUMEXP:
221
+ off_hz = off_zq * HQ + off_hq
222
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
223
+ lse = m_i + tl.math.log2(l_i)
224
+ if IS_DIVISIBLE:
225
+ tl.store(l_ptrs, lse)
226
+ else:
227
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
228
+
229
+ if OUTPUT_MAX:
230
+ off_hz = off_zq * HQ + off_hq
231
+ max_ptrs = MAX + off_hz * Q_LEN + offs_m
232
+ if IS_DIVISIBLE:
233
+ tl.store(max_ptrs, m_i)
234
+ else:
235
+ tl.store(max_ptrs, m_i, mask=offs_m < Q_LEN)
236
+
237
+
238
+ # Utility triton funcs
239
+ @triton.jit
240
+ def get_offset_for_next_block(
241
+ loop_iter, col_indices, total_blocks,
242
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
243
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
244
+ ):
245
+ if BLOCKS_ARE_CONTIGUOUS:
246
+ return BLOCK
247
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
248
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
249
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
250
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
251
+ jump_to_block = (next_block - cur_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
252
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
253
+ return offset
254
+
255
+ @triton.jit
256
+ def get_bounded_indices(indices, max_len=None):
257
+ return indices % max_len if max_len is not None else indices
258
+
259
+ @triton.jit
260
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
261
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
262
+ return tl.load(block_ptr)
263
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
264
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
265
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
266
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
267
+ else:
268
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
269
+
270
+ @triton.jit
271
+ def load_checked_2d(
272
+ ptr,
273
+ offs_m,
274
+ offs_n,
275
+ stride_m,
276
+ stride_n,
277
+ IS_DIVISIBLE_M: tl.constexpr,
278
+ IS_DIVISIBLE_N: tl.constexpr,
279
+ M_LEN: tl.constexpr,
280
+ N_LEN: tl.constexpr,
281
+ ):
282
+ # Calculate final pointer if strides are provided
283
+ if stride_m is not None and stride_n is not None:
284
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
285
+
286
+ # Handle all masking cases
287
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
288
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
289
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
290
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
291
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
292
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
293
+ else: # Both divisible
294
+ return tl.load(ptr)
295
+
296
+
297
+ # Common Imports
298
+ @triton.jit
299
+ def forward_block_mn(
300
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
301
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
302
+ # accumulated values
303
+ acc, l_i, m_i,
304
+ # Offsets
305
+ off_z, off_h, offs_m, offs_n,
306
+ # Offsets needed for TMA loads
307
+ kv_start,
308
+ kv_offset,
309
+ MATMUL_PRECISION, RCP_LN2,
310
+ # Strides for K and V
311
+ stride_kk, stride_kn, stride_vn, stride_vk,
312
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
313
+
314
+ ):
315
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
316
+ PRESCALE_QK : tl.constexpr = False
317
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
318
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
319
+ WRITE_DQ : tl.constexpr = True
320
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
321
+ OUTPUT_MAX : tl.constexpr = False
322
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
323
+ IS_DIVISIBLE : tl.constexpr = False
324
+ SM_SCALE : tl.constexpr = 0.08838834764831843
325
+ GQA_SHARED_HEADS : tl.constexpr = 4
326
+ HAS_FULL_BLOCKS : tl.constexpr = True
327
+ QK_HEAD_DIM : tl.constexpr = 128
328
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
329
+ V_HEAD_DIM : tl.constexpr = 128
330
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
331
+ SAFE_HEAD_DIM : tl.constexpr = True
332
+ USE_TMA : tl.constexpr = False
333
+ BLOCK_M : tl.constexpr = 128
334
+ BLOCK_N : tl.constexpr = 64
335
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
336
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
337
+ INDEX_DTYPE : tl.constexpr = tl.int32
338
+
339
+
340
+ # -- load k --
341
+ # NB reversed order since K is transposed
342
+ kv_base_offset = kv_start + kv_offset
343
+
344
+ # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
345
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
346
+ offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
347
+ k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
348
+
349
+ k = tl.trans(k)
350
+ # -- compute qk ---
351
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
352
+ if not PRESCALE_QK:
353
+ qk *= SM_SCALE
354
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
355
+ # If this is the last block of a non-divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
356
+ # which is larger than the actual number of elements. To avoid out-of-bounds memory access,
357
+ # we need to mask out the elements that fall outside Q_LEN & KV_LEN.
358
+ m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
359
+ n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
360
+
361
+ tmp0 = (qk)
362
+ post_mod_scores = tmp0
363
+
364
+
365
+ if CHECK_BLOCK_BOUNDARY:
366
+ # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
367
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
368
+
369
+ if not IS_FULL_BLOCKS:
370
+ tmp1 = tl.full([1], False, tl.int1)
371
+ tmp2 = (m)
372
+ tmp3 = (n)
373
+ tmp4 = tmp2 >= tmp3
374
+ tmp5 = tmp3.to(tl.int64)
375
+ tmp6 = (off_z)
376
+ tmp7 = tl.load(in_ptr9 + tmp6)
377
+ tmp8 = tmp5 < tmp7
378
+ tmp9 = tmp2.to(tl.int64)
379
+ tmp10 = tmp9 < tmp7
380
+ tmp11 = tmp8 & tmp10
381
+ tmp12 = tmp4 & tmp11
382
+ tmp13 = tmp1 | tmp12
383
+ tmp14 = ks5
384
+ tmp15 = tmp3 >= tmp14
385
+ tmp16 = (tmp3 % tmp14)
386
+ tmp17 = tl.full([1], 0, tl.int32)
387
+ tmp18 = tmp16 != tmp17
388
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
389
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
390
+ tmp21 = tmp19 != tmp20
391
+ tmp22 = tmp18 & tmp21
392
+ tmp23 = tmp16 + tmp14
393
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
394
+ tmp25 = tmp24.to(tl.int64)
395
+ tmp26 = tmp25 < tmp7
396
+ tmp27 = tmp15 & tmp26
397
+ tmp28 = tmp3 - tmp2
398
+ tmp29 = (tmp28 % tmp14)
399
+ tmp30 = tmp29 != tmp17
400
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
401
+ tmp32 = tmp31 != tmp20
402
+ tmp33 = tmp30 & tmp32
403
+ tmp34 = tmp29 + tmp14
404
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
405
+ tmp36 = tmp35 == tmp17
406
+ tmp37 = tmp27 & tmp36
407
+ tmp38 = tmp13 | tmp37
408
+ mask_mod_output = tmp38
409
+
410
+
411
+ if CHECK_BLOCK_BOUNDARY:
412
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
413
+ # apply mask for partially unmasked blocks
414
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
415
+
416
+ if not PRESCALE_QK:
417
+ post_mod_scores *= RCP_LN2
418
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
419
+
420
+ # -- compute scaling constant ---
421
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
422
+ if not ROWS_GUARANTEED_SAFE:
423
+ masked_out_rows = (m_ij == float("-inf"))
424
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
425
+ else:
426
+ m_ij_masked = m_ij
427
+
428
+ alpha = tl.math.exp2(m_i - m_ij_masked)
429
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
430
+
431
+ # NB: l_i update is pulled up here since it's a bit faster
432
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
433
+ # m_ij
434
+ l_i = l_i * alpha + tl.sum(p, 1)
435
+ # # -- scale and update acc --
436
+ acc = acc * alpha[:, None]
437
+ # Calculate offsets for V loading - reuse kv_base_offset from K loading
438
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
439
+ v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
440
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
441
+
442
+ # -- update m_i
443
+ m_i = m_ij
444
+
445
+ return acc, l_i, m_i
446
+
447
+ @triton.jit
448
+ def forward_inner(
449
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
450
+ q, K, V,
451
+ desc_k, desc_v, Q_LEN, KV_LEN,
452
+ # accumulated values
453
+ acc, l_i, m_i,
454
+ # Offsets used as inputs to score_mod & mask_mod
455
+ # of size [BLOCK_M, BLOCK_N] or scalar.
456
+ off_z, off_h, offs_m, offs_n,
457
+ # Offsets needed for TMA loads
458
+ kv_start,
459
+ # blocksparse data
460
+ kv_indices, kv_num_blocks,
461
+ # start kv and end kv block
462
+ block_n_start, block_n_end,
463
+ MATMUL_PRECISION,
464
+ # Strides for K and V
465
+ stride_kk, stride_kn, stride_vn, stride_vk,
466
+ IS_FULL_BLOCKS,
467
+ ):
468
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
469
+ PRESCALE_QK : tl.constexpr = False
470
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
471
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
472
+ WRITE_DQ : tl.constexpr = True
473
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
474
+ OUTPUT_MAX : tl.constexpr = False
475
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
476
+ IS_DIVISIBLE : tl.constexpr = False
477
+ SM_SCALE : tl.constexpr = 0.08838834764831843
478
+ GQA_SHARED_HEADS : tl.constexpr = 4
479
+ HAS_FULL_BLOCKS : tl.constexpr = True
480
+ QK_HEAD_DIM : tl.constexpr = 128
481
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
482
+ V_HEAD_DIM : tl.constexpr = 128
483
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
484
+ SAFE_HEAD_DIM : tl.constexpr = True
485
+ USE_TMA : tl.constexpr = False
486
+ BLOCK_M : tl.constexpr = 128
487
+ BLOCK_N : tl.constexpr = 64
488
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
489
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
490
+ INDEX_DTYPE : tl.constexpr = tl.int32
491
+
492
+
493
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
494
+ RCP_LN2: tl.constexpr = 1.44269504
495
+
496
+ if PRESCALE_QK:
497
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
498
+
499
+ kv_offset = 0
500
+
501
+ # loop over k, v and update accumulator until block_n_end
502
+ for start_n in range(block_n_start, block_n_end):
503
+ # Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
504
+ if IS_DIVISIBLE:
505
+ acc, l_i, m_i = forward_block_mn(
506
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
507
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
508
+ # accumulated values
509
+ acc, l_i, m_i,
510
+ # Offsets
511
+ off_z, off_h, offs_m, offs_n,
512
+ # Offsets needed for TMA loads
513
+ kv_start,
514
+ kv_offset,
515
+ MATMUL_PRECISION, RCP_LN2,
516
+ # Strides for K and V
517
+ stride_kk, stride_kn, stride_vn, stride_vk,
518
+ IS_FULL_BLOCKS,
519
+ )
520
+ else:
521
+ # Benchmarks show that even when we apply mod & mask to each block for non-divisible seqlen,
522
+ # it's on par with or slightly faster than applying them only to the last block in fwd.
523
+ # However, we choose a different strategy for bwd, where we only apply mod & mask
524
+ # to the last block because it's a lot faster.
525
+ acc, l_i, m_i = forward_block_mn(
526
+ arg_Q, arg_K, arg_V, arg_LSE, arg_MAX, arg_KV_NUM_BLKS, arg_KV_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, in_ptr9, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5,
527
+ q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
528
+ # accumulated values
529
+ acc, l_i, m_i,
530
+ # Offsets
531
+ off_z, off_h, offs_m, offs_n,
532
+ # Offsets needed for TMA loads
533
+ kv_start,
534
+ kv_offset,
535
+ MATMUL_PRECISION, RCP_LN2,
536
+ # Strides for K and V
537
+ stride_kk, stride_kn, stride_vn, stride_vk,
538
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
539
+ )
540
+
541
+
542
+
543
+ offset = get_offset_for_next_block(
544
+ start_n, kv_indices, kv_num_blocks,
545
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
546
+ )
547
+
548
+ offs_n = offs_n + offset
549
+ kv_offset += offset
550
+
551
+
552
+ return acc, l_i, m_i
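The template above is Inductor's flex-attention forward; the rescaling inside forward_block_mn is the standard online-softmax update, written in base-2 form because the scores were pre-multiplied by RCP_LN2. A minimal sketch of one block update, ignoring the fully-masked-row special case:

import torch

def online_softmax_step(acc, l_i, m_i, scores, v):
    # scores: [M, N], already scaled by SM_SCALE * RCP_LN2; v: [N, D]
    m_ij = torch.maximum(m_i, scores.max(dim=-1).values)
    alpha = torch.exp2(m_i - m_ij)           # rescales the old running sums
    p = torch.exp2(scores - m_ij[:, None])   # block-local unnormalized probs
    l_i = l_i * alpha + p.sum(dim=-1)
    acc = acc * alpha[:, None] + p @ v
    return acc, l_i, m_ij                    # final output is acc / l_i[:, None]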
SpecForge-ext/cache/compiled_kernels/is/cisbwn452kdvm56u75a2mwmrdzns6w4vxzuweva24qshuv4gksv2.py ADDED
@@ -0,0 +1,26 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 512},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr0': '*i32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_slice_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_clone_slice_4(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x0 = (xindex % ks0)
23
+ x1 = xindex // ks0
24
+ x2 = xindex
25
+ tmp0 = tl.load(in_ptr0 + (x0 + x1 + ks0*x1), xmask, eviction_policy='evict_last')
26
+ tl.store(out_ptr0 + (x2), tmp0, xmask)
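The index arithmetic x0 + x1 + ks0*x1 simplifies to x0 + (ks0 + 1)*x1: each input row is ks0 + 1 elements wide and only the first ks0 columns are read, i.e. the kernel materializes a contiguous copy of t[..., :ks0]. A tiny eager equivalent, with t and ks0 = 4 as illustrative stand-ins:

import torch

t = torch.arange(3 * 5, dtype=torch.int32).reshape(3, 5)  # rows of width ks0 + 1 = 5
out = t[:, :4].contiguous()                               # keeps columns 0..ks0-1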
SpecForge-ext/cache/compiled_kernels/is/d02b763bc26b4a862acff11bb1d83ee2ff669b1418d106ae0058cadf26d0f276.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 128, "num_warps": 4, "num_stages": 1, "configs_hash": "1b2cc4dbebb9680d3ce31843331593b159e4046c056f195ca1ccf2464d5b37d1", "found_by_coordesc": false, "time_taken_ms": 11, "triton_cache_hash": "CLTRXNE5MHPP3O5A5W4Z4EQTTZVYMOP5IPJT6N44O6FTBZFXLMNA"}
SpecForge-ext/cache/compiled_kernels/iy/ciy3jtwq2kqsaaylz6g2uxngpmmalnqcompyd7v6diseejxhwvzs.py ADDED
@@ -0,0 +1,37 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.persistent_reduction(
11
+ size_hints={'x': 4096, 'r0_': 32},
12
+ reduction_hint=ReductionHint.OUTER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_mul_sum_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_per_fused__to_copy_mul_sum_1(in_ptr0, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr):
19
+ r0_numel = 32
20
+ R0_BLOCK: tl.constexpr = 32
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = xindex < xnumel
26
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
27
+ r0_offset = 0
28
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
29
+ roffset = r0_offset
30
+ rindex = r0_index
31
+ r0_1 = r0_index
32
+ x0 = xindex
33
+ tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_1), xmask, other=0.0)
34
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
35
+ tmp3 = tl.where(xmask, tmp1, 0)
36
+ tmp4 = tl.sum(tmp3, 1)[:, None].to(tl.float32)
37
+ tl.store(out_ptr0 + (x0), tmp4, xmask)
SpecForge-ext/cache/compiled_kernels/lk/clk4cgl52lrdnpqzv6ubpxawah5lw2cyfnmsbuouupfi5emjbchn.py ADDED
@@ -0,0 +1,1083 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT ID: ['13_backward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/ef/cefh7lkkzxkkmdldjmu75mgxgh2oczofby7slgtoagmm5sd6wlvf.py
38
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
39
+ # Source node to ATen node mapping:
40
+ # Graph fragment:
41
+ # %getitem : Tensor "bf16[8, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=getitem]
42
+ # %tangents_1 : Tensor "bf16[8, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:4" = PlaceHolder[target=tangents_1]
43
+ # %buf0 : Tensor "bf16[8, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf0]
44
+ # %full_default : Tensor "f32[8, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, %primals_10], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:4, pin_memory: False})
45
+ # %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, %primals_20, %primals_22, %primals_25, %primals_27, %primals_30, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {})
46
+ # return %buf0,%buf1
47
+ triton_red_fused_zeros_0 = async_compile.triton('triton_red_fused_zeros_0', '''
48
+ import triton
49
+ import triton.language as tl
50
+
51
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
52
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
53
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
54
+ triton_helpers.set_driver_to_gpu()
55
+
56
+ @triton_heuristics.reduction(
57
+ size_hints={'x': 524288, 'r0_': 128},
58
+ reduction_hint=ReductionHint.DEFAULT,
59
+ filename=__file__,
60
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*fp32', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
61
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_zeros_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
62
+ )
63
+ @triton.jit
64
+ def triton_red_fused_zeros_0(in_ptr0, in_ptr1, out_ptr1, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
65
+ r0_numel = 128
66
+ rnumel = r0_numel
67
+ RBLOCK: tl.constexpr = R0_BLOCK
68
+ xoffset = tl.program_id(0) * XBLOCK
69
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
70
+ xmask = xindex < xnumel
71
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
72
+ rbase = r0_base
73
+ x0 = (xindex % ks0)
74
+ x1 = ((xindex // ks0) % 32)
75
+ x2 = xindex // ks1
76
+ x5 = triton_helpers.div_floor_integer(xindex, ks0)
77
+ _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
78
+ x4 = xindex
79
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
80
+ r0_index = r0_offset + r0_base
81
+ r0_mask = r0_index < r0_numel
82
+ roffset = r0_offset
83
+ rindex = r0_index
84
+ r0_3 = r0_index
85
+ tmp0 = tl.load(in_ptr0 + (r0_3 + 128*x1 + 4096*x0 + 4096*ks0*x2), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
86
+ tmp1 = tl.load(in_ptr1 + (r0_3 + 128*x0 + 128*x5*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1)))), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
87
+ tmp2 = tmp0 * tmp1
88
+ tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
89
+ tmp5 = _tmp4 + tmp3
90
+ _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4)
91
+ tmp4 = tl.sum(_tmp4, 1)[:, None]
92
+ tmp6 = tmp4.to(tl.float32)
93
+ tmp7 = 0.0
94
+ tmp8 = tmp6 - tmp7
95
+ tl.store(out_ptr1 + (x4), tmp8, xmask)
96
+ ''', device_str='cuda')
97
+
98
+
99
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/hp/chpmrudjesqgjc4u7kzlnbev6u6xezu6edp3jbrg4g2q5z3yue3f.py
100
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
101
+ # Source node to ATen node mapping:
102
+ # Graph fragment:
103
+ # %primals_2 : Tensor "bf16[8, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=primals_2]
104
+ # %primals_4 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:4" = PlaceHolder[target=primals_4]
105
+ # %primals_6 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:4" = PlaceHolder[target=primals_6]
106
+ # %getitem_1 : Tensor "f32[8, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4" = PlaceHolder[target=getitem_1]
107
+ # %buf1 : Tensor "f32[8, 32, s37][32*s37, s37, 1]cuda:4" = PlaceHolder[target=buf1]
108
+ # %tangents_1 : Tensor "bf16[8, 32, s37, 128][4096*Max(1, s37), 128*Max(1, s37), 128, 1]cuda:4" = PlaceHolder[target=tangents_1]
109
+ # %getitem_3 : Tensor "bf16[8, 32, s37, 128][4096*s37, 128, 4096, 1]cuda:4" = PlaceHolder[target=getitem_3]
110
+ # %getitem_5 : Tensor "bf16[8, 8, s0, 128][1024*s0, 128*s0, 128, 1]cuda:4" = PlaceHolder[target=getitem_5]
111
+ # %primals_13 : Tensor "i32[8, 1, s99][s99, s99, 1]cuda:4" = PlaceHolder[target=primals_13]
112
+ # %primals_9 : Tensor "i32[8, 1, s22, s72][s22*s72, s22*s72, s72, 1]cuda:4" = PlaceHolder[target=primals_9]
113
+ # %primals_22 : Tensor "i32[8, 1, s56][s56, s56, 1]cuda:4" = PlaceHolder[target=primals_22]
114
+ # %primals_25 : Tensor "i32[8, 1, s84, s53][s53*s84, s53*s84, s53, 1]cuda:4" = PlaceHolder[target=primals_25]
115
+ # %primals_17 : Tensor "i32[8, 1, s94][s94, s94, 1]cuda:4" = PlaceHolder[target=primals_17]
116
+ # %primals_20 : Tensor "i32[8, 1, s28, s4][s28*s4, s28*s4, s4, 1]cuda:4" = PlaceHolder[target=primals_20]
117
+ # %primals_27 : Tensor "i32[8, 1, s100][s100, s100, 1]cuda:4" = PlaceHolder[target=primals_27]
118
+ # %primals_30 : Tensor "i32[8, 1, s6, s10][s10*s6, s10*s6, s10, 1]cuda:4" = PlaceHolder[target=primals_30]
119
+ # %primals_14 : Tensor "i64[8][1]cuda:4" = PlaceHolder[target=primals_14]
120
+ # %full_default : Tensor "f32[8, 32, s37][32*Max(1, s37), Max(1, s37), 1]cuda:4"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([8, 32, %primals_10], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:4, pin_memory: False})
121
+ # %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%primals_2, %primals_4, %primals_6, %getitem, %getitem_1, %tangents_1, %full_default, %fw_graph0, %joint_graph0, (%primals_10, %primals_11, %primals_13, %primals_9, %primals_17, %primals_20, %primals_22, %primals_25, %primals_27, %primals_30, 128, 128, %mask_graph0), 0.08838834764831843, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True, OUTPUT_MAX: False}, (), (%primals_14, %primals_15)), kwargs = {})
122
+ # return %getitem_4
123
+ triton_tem_fused_zeros_1 = async_compile.triton('triton_tem_fused_zeros_1', '''
124
+ import triton
125
+ import triton.language as tl
126
+
127
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
128
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
129
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
130
+
131
+ @triton_heuristics.template(
132
+
133
+ num_stages=3,
134
+ num_warps=8,
135
+ triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'ks3': 'i32', 'ks4': 'i32', 'ks5': 'i32', 'ks6': 'i32', 'ks7': 'i32', 'ks8': 'i32'}, 'device': DeviceProperties(type='cuda', index=4, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
136
+ inductor_meta={'kernel_name': 'triton_tem_fused_zeros_1', 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False, 'FLOAT32_PRECISION': "'tf32'", 'IS_DIVISIBLE': False, 'SM_SCALE': 0.08838834764831843, 'GQA_SHARED_HEADS': 4, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}},
137
+
138
+ )
139
+ @triton.jit
140
+ def triton_tem_fused_zeros_1(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8):
141
+ PRESCALE_QK : tl.constexpr = False
142
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
143
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
144
+ WRITE_DQ : tl.constexpr = True
145
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
146
+ OUTPUT_MAX : tl.constexpr = False
147
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
148
+ IS_DIVISIBLE : tl.constexpr = False
149
+ SM_SCALE : tl.constexpr = 0.08838834764831843
150
+ GQA_SHARED_HEADS : tl.constexpr = 4
151
+ HAS_FULL_BLOCKS : tl.constexpr = True
152
+ QK_HEAD_DIM : tl.constexpr = 128
153
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
154
+ V_HEAD_DIM : tl.constexpr = 128
155
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
156
+ SAFE_HEAD_DIM : tl.constexpr = True
157
+ BLOCK_M1 : tl.constexpr = 64
158
+ BLOCK_N1 : tl.constexpr = 128
159
+ BLOCK_M2 : tl.constexpr = 128
160
+ BLOCK_N2 : tl.constexpr = 64
161
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
162
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
163
+ INDEX_DTYPE : tl.constexpr = tl.int32
164
+ Q = arg_Q
165
+ K = arg_K
166
+ V = arg_V
167
+ LSE = arg_LSE
168
+ DELTA = arg_DELTA
169
+ DO = arg_DO
170
+ DQ = arg_DQ
171
+ DV = arg_DV
172
+ KV_NUM_BLKS = arg_KV_NUM_BLKS
173
+ KV_IDX = arg_KV_IDX
174
+ Q_NUM_BLKS = arg_Q_NUM_BLKS
175
+ Q_IDX = arg_Q_IDX
176
+ FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
177
+ FULL_KV_IDX = arg_FULL_KV_IDX
178
+ FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
179
+ FULL_Q_IDX = arg_FULL_Q_IDX
180
+
181
+ # Sub notation for this kernel:
182
+ #
183
+ # Q: Query, K: Key, V: Value
184
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
185
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
186
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
187
+ # DK: Derivative of Key, is the written to via the store_output call due to some limitations with
188
+ # inductor codegen
189
+ # M: Number of queries, N: Number of keys/values
190
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
191
+ # V_HEAD_DIM: The dimension of the value embeddings
192
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
193
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
194
+ # (Modifiable) Performance tuning options
195
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
196
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
197
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
198
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
199
+ #
200
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
201
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
202
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
203
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
204
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
205
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
206
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
207
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
208
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
209
+
210
+ # The below are kernel options that can be applied for certain score_mods,
211
+ # or involve a numerics vs. perf tradeoff
212
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
213
+ # about 20% more numerical error, but slightly faster.
214
+
215
+ # Define strides of inputs
216
+ stride_qz, stride_qh, stride_qm, stride_qd = 4096*ks0, 128, 4096, 1
217
+ stride_kz, stride_kh, stride_kn, stride_kd = 1024*ks1, 128*ks1, 128, 1
218
+ stride_vz, stride_vh, stride_vn, stride_vd = 1024*ks1, 128*ks1, 128, 1
219
+ stride_doz, stride_doh, stride_dom, stride_dod = 4096*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128*((1) * ((1) >= (ks0)) + (ks0) * ((ks0) > (1))), 128, 1
220
+
221
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = 4096*ks0, 128, 4096, 1
222
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = 1024*ks1, 128*ks1, 128, 1
223
+
224
+ ZQ = 8
225
+ HQ = 32
226
+ HKV = 8
227
+ Q_LEN = ks0
228
+ ZKV = 8
229
+ KV_LEN = ks1
230
+
231
+ MATMUL_PRECISION = Q.dtype.element_ty
232
+
233
+ pid = tl.program_id(0).to(INDEX_DTYPE)
234
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
235
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
236
+
237
+ off_zq = tl.program_id(1).to(INDEX_DTYPE) # q batch idx
238
+ off_hkv = tl.program_id(2).to(INDEX_DTYPE) # kv head idx
239
+ off_zkv = off_zq % ZKV # kv batch idx
240
+
241
+ SPARSE_Z = 8
242
+ SPARSE_HQ = 1
243
+
244
+ sparse_idx_z = off_zq % SPARSE_Z
245
+
246
+ k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
247
+ v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
248
+ # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
249
+ # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
250
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
251
+
252
+ # offset K, V, DV pointers for batch/kv-head
253
+ K += k_adj
254
+ V += v_adj
255
+ DV += dv_adj
256
+
257
+ RCP_LN2 = 1.44269504
258
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
259
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
260
+
261
+ if pid >= NUM_KV_BLOCKS:
262
+ off_pid = pid - NUM_KV_BLOCKS
263
+ # THIS BLOCK DOES DQ
264
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
265
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
266
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
267
+ start_m2_block = off_pid % NUM_Q_BLOCKS
268
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
269
+ stride_kv_num_blks_h = ks2
270
+ stride_kv_idx_h = ks3*ks4
271
+ stride_kv_idx_m = ks4
272
+
273
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
274
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
275
+
276
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
277
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
278
+
279
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
280
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
281
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
282
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
283
+ off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
284
+
285
+ Q2 = Q + q_adj2
286
+ DO2 = DO + do_adj2
287
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
288
+ # if Q is broadcasted)
289
+ DQ2 = DQ + dq_adj2
290
+ LSE2 = LSE + off_chz2
291
+ DELTA2 = DELTA + off_chz2
292
+
293
+ # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
294
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
295
+
296
+ start_m2 = start_m2_block * BLOCK_M2
297
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
298
+
299
+ # load Q and do: they stay in SRAM throughout the inner loop.
300
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
301
+ do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
302
+
303
+ if PRESCALE_QK:
304
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
305
+
306
+ if IS_DIVISIBLE:
307
+ Di = tl.load(DELTA2 + offs_m2)
308
+ lse = tl.load(LSE2 + offs_m2)
309
+ else:
310
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
311
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
312
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
313
+ lse = lse[:, None]
314
+
315
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
316
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
317
+ kv_indices = KV_IDX + sparse_kv_idx_offset
318
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
319
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
320
+
321
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
322
+ dq = bwd_dq_inner(
323
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
324
+ K, V,
325
+ dq, q, do, Di, lse,
326
+ off_zq, off_hq2, offs_m2, offs_n2,
327
+ stride_kn, stride_kd, stride_vn, stride_vd,
328
+ kv_indices, sparse_kv_num_blocks,
329
+ MATMUL_PRECISION,
330
+ IS_FULL_BLOCKS=False,
331
+ )
332
+
333
+ if HAS_FULL_BLOCKS:
334
+ # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
335
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
336
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
337
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
338
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
339
+
340
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
341
+ dq = bwd_dq_inner(
342
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
343
+ K, V,
344
+ dq, q, do, Di, lse,
345
+ off_zq, off_hq2, offs_m2, offs_n2,
346
+ stride_kn, stride_kd, stride_vn, stride_vd,
347
+ kv_indices, sparse_kv_num_blocks,
348
+ MATMUL_PRECISION,
349
+ IS_FULL_BLOCKS=True,
350
+ )
351
+
352
+ # Write back dQ.
353
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
354
+ dq *= SM_SCALE
355
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
356
+ tl.store(dq_ptrs, dq)
357
+ else:
358
+ tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
359
+ else:
360
+ # THIS BLOCK DOES DK & DV
361
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
362
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
363
+
364
+ pid_mask = pid // SPARSE_KV_MULTIPLE
365
+
366
+ stride_q_num_blks_h = ks5
367
+ stride_q_idx_h = ks6*ks7
368
+ stride_q_idx_n = ks6
369
+
370
+
371
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
372
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
373
+
374
+ start_n1 = pid * BLOCK_N1
375
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
376
+
377
+ # load K and V: they stay in SRAM throughout the inner loop.
378
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
379
+ v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
380
+
381
+ if PRESCALE_QK:
382
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
383
+
384
+ for off_g in range(0, GQA_SHARED_HEADS):
385
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
386
+
387
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
388
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
389
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
390
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
391
+ off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
392
+
393
+ Q1 = Q + q_adj1
394
+ DO1 = DO + do_adj1
395
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
396
+ # if Q is broadcasted)
397
+ LSE1 = LSE + off_chz1
398
+ DELTA1 = DELTA + off_chz1
399
+
400
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
401
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
402
+
403
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
404
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
405
+
406
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
407
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
408
+ q_indices = Q_IDX + sparse_q_idx_offset
409
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
410
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
411
+
412
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
413
+ dk, dv = bwd_dkdv_inner(
414
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
415
+ Q1, DO1, DELTA1, LSE1,
416
+ dk, dv, k, v,
417
+ off_zq, off_hq1, offs_n1, offs_m1,
418
+ stride_qm, stride_qd, stride_dom, stride_dod,
419
+ q_indices, sparse_q_num_blocks,
420
+ MATMUL_PRECISION,
421
+ IS_FULL_BLOCKS=False,
422
+ )
423
+
424
+
425
+ if HAS_FULL_BLOCKS:
426
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
427
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
428
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
429
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
430
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
431
+
432
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
433
+ dk, dv = bwd_dkdv_inner(
434
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
435
+ Q1, DO1, DELTA1, LSE1,
436
+ dk, dv, k, v,
437
+ off_zq, off_hq1, offs_n1, offs_m1,
438
+ stride_qm, stride_qd, stride_dom, stride_dod,
439
+ q_indices, sparse_q_num_blocks,
440
+ MATMUL_PRECISION,
441
+ IS_FULL_BLOCKS=True,
442
+ )
443
+
444
+ # Write back dV and dK.
445
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
446
+
447
+ index_n = offs_n1[:, None]
448
+ index_k = offs_k[None, :]
449
+ index_v = offs_v[None, :]
450
+
451
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
452
+ tl.store(dv_ptrs, dv)
453
+ else:
454
+ tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
455
+
456
+ dk *= SM_SCALE
457
+
458
+ if SAFE_HEAD_DIM:
459
+ mask = index_n < KV_LEN
460
+ else:
461
+ mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
462
+
463
+ # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
464
+ # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
465
+ tl.static_assert(dk.shape == [BLOCK_N1, QK_HEAD_DIM_ROUNDED])
466
+ xindex = index_k + 128*index_n + 128*off_hkv*ks1 + 1024*off_zq*ks1
467
+ tl.store(out_ptr0 + (tl.broadcast_to(xindex, dk.shape)), dk, mask)
468
+
469
+ @triton.jit
470
+ def bwd_dq_inner(
471
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
472
+ K, V, # pointers
473
+ dq, q, do, Di, lse,
474
+ off_z, off_hq, offs_m2, offs_n2,
475
+ stride_kn, stride_kd, stride_vn, stride_vd,
476
+ kv_indices, sparse_kv_num_blocks,
477
+ MATMUL_PRECISION,
478
+ IS_FULL_BLOCKS,
479
+ ):
480
+ PRESCALE_QK : tl.constexpr = False
481
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
482
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
483
+ WRITE_DQ : tl.constexpr = True
484
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
485
+ OUTPUT_MAX : tl.constexpr = False
486
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
487
+ IS_DIVISIBLE : tl.constexpr = False
488
+ SM_SCALE : tl.constexpr = 0.08838834764831843
489
+ GQA_SHARED_HEADS : tl.constexpr = 4
490
+ HAS_FULL_BLOCKS : tl.constexpr = True
491
+ QK_HEAD_DIM : tl.constexpr = 128
492
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
493
+ V_HEAD_DIM : tl.constexpr = 128
494
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
495
+ SAFE_HEAD_DIM : tl.constexpr = True
496
+ BLOCK_M1 : tl.constexpr = 64
497
+ BLOCK_N1 : tl.constexpr = 128
498
+ BLOCK_M2 : tl.constexpr = 128
499
+ BLOCK_N2 : tl.constexpr = 64
500
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
501
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
502
+ INDEX_DTYPE : tl.constexpr = tl.int32
503
+
504
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
505
+ RCP_LN2: tl.constexpr = 1.44269504
506
+ Q_LEN = ks0
507
+ KV_LEN = ks1
508
+
509
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
510
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
511
+
512
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
513
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
514
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
515
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
516
+
517
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
518
+
519
+ for start_n in range(0, hi):
520
+ dq = bwd_dq_block_mn(
521
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
522
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
523
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
524
+ stride_kn, stride_kd, stride_vn, stride_vd,
525
+ kv_indices, sparse_kv_num_blocks,
526
+ MATMUL_PRECISION, RCP_LN2,
527
+ IS_FULL_BLOCKS,
528
+ )
529
+
530
+ # Increment pointers.
531
+ offset = get_offset_for_next_block(
532
+ start_n, kv_indices, sparse_kv_num_blocks,
533
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
534
+ )
535
+
536
+ kT_ptrs += offset * stride_kn
537
+ vT_ptrs += offset * stride_vn
538
+
539
+ offs_n2 += offset
540
+
541
+ return dq
542
+
543
+
544
+ @triton.jit
545
+ def bwd_dq_block_mn(
546
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
547
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
548
+ off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
549
+ stride_kn, stride_kd, stride_vn, stride_vd,
550
+ kv_indices, sparse_kv_num_blocks,
551
+ MATMUL_PRECISION, RCP_LN2,
552
+ IS_FULL_BLOCKS,
553
+ ):
554
+ PRESCALE_QK : tl.constexpr = False
555
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
556
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
557
+ WRITE_DQ : tl.constexpr = True
558
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
559
+ OUTPUT_MAX : tl.constexpr = False
560
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
561
+ IS_DIVISIBLE : tl.constexpr = False
562
+ SM_SCALE : tl.constexpr = 0.08838834764831843
563
+ GQA_SHARED_HEADS : tl.constexpr = 4
564
+ HAS_FULL_BLOCKS : tl.constexpr = True
565
+ QK_HEAD_DIM : tl.constexpr = 128
566
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
567
+ V_HEAD_DIM : tl.constexpr = 128
568
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
569
+ SAFE_HEAD_DIM : tl.constexpr = True
570
+ BLOCK_M1 : tl.constexpr = 64
571
+ BLOCK_N1 : tl.constexpr = 128
572
+ BLOCK_M2 : tl.constexpr = 128
573
+ BLOCK_N2 : tl.constexpr = 64
574
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
575
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
576
+ INDEX_DTYPE : tl.constexpr = tl.int32
577
+
578
+
579
+ # NB reversed order to since K is transposed
580
+ kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
581
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
582
+ if not PRESCALE_QK:
583
+ qk *= SM_SCALE
584
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
585
+ pre_mod_scores = qk
586
+ n = get_bounded_indices(offs_n2[None, :], KV_LEN if not IS_DIVISIBLE else None)
587
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
588
+ # that the M reads out of bounds for the PIDS spanning the Q_LEN boundary
589
+ m = get_bounded_indices(offs_m2[:, None], Q_LEN if not IS_DIVISIBLE else None)
590
+
591
+ tmp0 = (qk)
592
+ post_mod_scores = tmp0
593
+
594
+
595
+
596
+
597
+ if not IS_DIVISIBLE:
598
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
599
+
600
+ if not IS_FULL_BLOCKS:
601
+ tmp1 = tl.full([1], False, tl.int1)
602
+ tmp2 = (m)
603
+ tmp3 = (n)
604
+ tmp4 = tmp2 >= tmp3
605
+ tmp5 = tmp3.to(tl.int64)
606
+ tmp6 = (off_z)
607
+ tmp7 = tl.load(in_ptr16 + tmp6)
608
+ tmp8 = tmp5 < tmp7
609
+ tmp9 = tmp2.to(tl.int64)
610
+ tmp10 = tmp9 < tmp7
611
+ tmp11 = tmp8 & tmp10
612
+ tmp12 = tmp4 & tmp11
613
+ tmp13 = tmp1 | tmp12
614
+ tmp14 = ks8
615
+ tmp15 = tmp3 >= tmp14
616
+ tmp16 = (tmp3 % tmp14)
617
+ tmp17 = tl.full([1], 0, tl.int32)
618
+ tmp18 = tmp16 != tmp17
619
+ tmp19 = (libdevice.signbit(tmp16) != 0) if (tmp16).dtype is tl.float32 else tmp16 < 0
620
+ tmp20 = (libdevice.signbit(tmp14) != 0) if (tmp14).dtype is tl.float32 else tmp14 < 0
621
+ tmp21 = tmp19 != tmp20
622
+ tmp22 = tmp18 & tmp21
623
+ tmp23 = tmp16 + tmp14
624
+ tmp24 = tl.where(tmp22, tmp23, tmp16)
625
+ tmp25 = tmp24.to(tl.int64)
626
+ tmp26 = tmp25 < tmp7
627
+ tmp27 = tmp15 & tmp26
628
+ tmp28 = tmp3 - tmp2
629
+ tmp29 = (tmp28 % tmp14)
630
+ tmp30 = tmp29 != tmp17
631
+ tmp31 = (libdevice.signbit(tmp29) != 0) if (tmp29).dtype is tl.float32 else tmp29 < 0
632
+ tmp32 = tmp31 != tmp20
633
+ tmp33 = tmp30 & tmp32
634
+ tmp34 = tmp29 + tmp14
635
+ tmp35 = tl.where(tmp33, tmp34, tmp29)
636
+ tmp36 = tmp35 == tmp17
637
+ tmp37 = tmp27 & tmp36
638
+ tmp38 = tmp13 | tmp37
639
+ mask_mod_output = tmp38
640
+
641
+
642
+ # apply mask for partial masked block
643
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
644
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
645
+ if not PRESCALE_QK:
646
+ post_mod_scores *= RCP_LN2
647
+ p = tl.math.exp2(post_mod_scores - lse)
648
+ # Compute dP and dS.
649
+ # NB reversed order to since V is transposed
650
+ vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
651
+
652
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
653
+ ds = p * (dp - Di[:, None])
654
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
655
+ tmp39 = (ds)
656
+ grad_scores = tmp39
657
+
658
+
659
+ if not IS_DIVISIBLE:
660
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
661
+
662
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
663
+ if WRITE_DQ:
664
+ scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
665
+
666
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
667
+ ds = grad_scores
668
+
669
+ if not IS_FULL_BLOCKS:
670
+ # (grads) apply mask for partially unmasked block
671
+ ds = tl.where(mask_mod_output, ds, 0.0)
672
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
673
+ ds = ds.to(MATMUL_PRECISION)
674
+ # Compute dQ.
675
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
676
+
677
+ return dq
678
+
679
+
680
+ @triton.jit
681
+ def bwd_dkdv_inner(
682
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
683
+ Q, DO, DELTA, LSE, # pointers
684
+ dk, dv, k, v,
685
+ off_z, off_hq, offs_n1, offs_m1,
686
+ stride_qm, stride_qd, stride_dom, stride_dod,
687
+ q_indices, sparse_q_num_blocks,
688
+ MATMUL_PRECISION,
689
+ IS_FULL_BLOCKS,
690
+ ):
691
+ PRESCALE_QK : tl.constexpr = False
692
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
693
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
694
+ WRITE_DQ : tl.constexpr = True
695
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
696
+ OUTPUT_MAX : tl.constexpr = False
697
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
698
+ IS_DIVISIBLE : tl.constexpr = False
699
+ SM_SCALE : tl.constexpr = 0.08838834764831843
700
+ GQA_SHARED_HEADS : tl.constexpr = 4
701
+ HAS_FULL_BLOCKS : tl.constexpr = True
702
+ QK_HEAD_DIM : tl.constexpr = 128
703
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
704
+ V_HEAD_DIM : tl.constexpr = 128
705
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
706
+ SAFE_HEAD_DIM : tl.constexpr = True
707
+ BLOCK_M1 : tl.constexpr = 64
708
+ BLOCK_N1 : tl.constexpr = 128
709
+ BLOCK_M2 : tl.constexpr = 128
710
+ BLOCK_N2 : tl.constexpr = 64
711
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
712
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
713
+ INDEX_DTYPE : tl.constexpr = tl.int32
714
+
715
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
716
+ RCP_LN2: tl.constexpr = 1.44269504
717
+ Q_LEN = ks0
718
+ KV_LEN = ks1
719
+
720
+ offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
721
+ offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
722
+
723
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
724
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
725
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
726
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
727
+
728
+ # The minimum is needed to handle the case where we run with a super large
729
+ # SPARSE_BLOCK_SIZE (i.e. no block-mask!)
730
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
731
+
732
+ for start_m in range(0, hi):
733
+ dk, dv = bwd_dkdv_block_mn(
734
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
735
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
736
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
737
+ stride_qm, stride_qd, stride_dom, stride_dod,
738
+ q_indices, sparse_q_num_blocks,
739
+ MATMUL_PRECISION, RCP_LN2,
740
+ IS_FULL_BLOCKS,
741
+ )
742
+ # Increment pointers.
743
+ offset = get_offset_for_next_block(
744
+ start_m, q_indices, sparse_q_num_blocks,
745
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
746
+ )
747
+
748
+ qT_ptrs += offset * stride_qm
749
+ do_ptrs += offset * stride_dom
750
+ offs_m1 += offset
751
+
752
+ return dk, dv
753
+
754
+
755
+ @triton.jit
756
+ def bwd_dkdv_block_mn(
757
+ arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, ks8,
758
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
759
+ off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
760
+ stride_qm, stride_qd, stride_dom, stride_dod,
761
+ q_indices, sparse_q_num_blocks,
762
+ MATMUL_PRECISION, RCP_LN2,
763
+ IS_FULL_BLOCKS,
764
+ ):
765
+ PRESCALE_QK : tl.constexpr = False
766
+ ROWS_GUARANTEED_SAFE : tl.constexpr = False
767
+ BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
768
+ WRITE_DQ : tl.constexpr = True
769
+ OUTPUT_LOGSUMEXP : tl.constexpr = True
770
+ OUTPUT_MAX : tl.constexpr = False
771
+ FLOAT32_PRECISION : tl.constexpr = 'tf32'
772
+ IS_DIVISIBLE : tl.constexpr = False
773
+ SM_SCALE : tl.constexpr = 0.08838834764831843
774
+ GQA_SHARED_HEADS : tl.constexpr = 4
775
+ HAS_FULL_BLOCKS : tl.constexpr = True
776
+ QK_HEAD_DIM : tl.constexpr = 128
777
+ QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
778
+ V_HEAD_DIM : tl.constexpr = 128
779
+ V_HEAD_DIM_ROUNDED : tl.constexpr = 128
780
+ SAFE_HEAD_DIM : tl.constexpr = True
781
+ BLOCK_M1 : tl.constexpr = 64
782
+ BLOCK_N1 : tl.constexpr = 128
783
+ BLOCK_M2 : tl.constexpr = 128
784
+ BLOCK_N2 : tl.constexpr = 64
785
+ SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
786
+ SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
787
+ INDEX_DTYPE : tl.constexpr = tl.int32
788
+
789
+
790
+ # NB reversed order since Q is transposed
791
+ qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
792
+ # Load LSE before computing qk to reduce pipeline stall.
793
+ if IS_DIVISIBLE:
794
+ lse = tl.load(LSE + offs_m1)
795
+ else:
796
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
797
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
798
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
799
+ if not PRESCALE_QK:
800
+ qkT *= SM_SCALE
801
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
802
+ m = get_bounded_indices(offs_m1[None, :], Q_LEN if not IS_DIVISIBLE else None)
803
+ # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
804
+ # that the n reads out of bounds for the PIDS spanning the KV_LEN boundary
805
+ n = get_bounded_indices(offs_n1[:, None], KV_LEN if not IS_DIVISIBLE else None)
806
+
807
+ pre_mod_scores = qkT
808
+ tmp40 = (qkT)
809
+ post_mod_scores = tmp40
810
+
811
+
812
+
813
+ if not IS_DIVISIBLE:
814
+ post_mod_scores = tl.where(offs_m1[None, :] < Q_LEN, post_mod_scores, float("-inf"))
815
+
816
+ if not IS_FULL_BLOCKS:
817
+ tmp41 = tl.full([1], False, tl.int1)
818
+ tmp42 = (m)
819
+ tmp43 = (n)
820
+ tmp44 = tmp42 >= tmp43
821
+ tmp45 = tmp43.to(tl.int64)
822
+ tmp46 = (off_z)
823
+ tmp47 = tl.load(in_ptr16 + tmp46)
824
+ tmp48 = tmp45 < tmp47
825
+ tmp49 = tmp42.to(tl.int64)
826
+ tmp50 = tmp49 < tmp47
827
+ tmp51 = tmp48 & tmp50
828
+ tmp52 = tmp44 & tmp51
829
+ tmp53 = tmp41 | tmp52
830
+ tmp54 = ks8
831
+ tmp55 = tmp43 >= tmp54
832
+ tmp56 = (tmp43 % tmp54)
833
+ tmp57 = tl.full([1], 0, tl.int32)
834
+ tmp58 = tmp56 != tmp57
835
+ tmp59 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
836
+ tmp60 = (libdevice.signbit(tmp54) != 0) if (tmp54).dtype is tl.float32 else tmp54 < 0
837
+ tmp61 = tmp59 != tmp60
838
+ tmp62 = tmp58 & tmp61
839
+ tmp63 = tmp56 + tmp54
840
+ tmp64 = tl.where(tmp62, tmp63, tmp56)
841
+ tmp65 = tmp64.to(tl.int64)
842
+ tmp66 = tmp65 < tmp47
843
+ tmp67 = tmp55 & tmp66
844
+ tmp68 = tmp43 - tmp42
845
+ tmp69 = (tmp68 % tmp54)
846
+ tmp70 = tmp69 != tmp57
847
+ tmp71 = (libdevice.signbit(tmp69) != 0) if (tmp69).dtype is tl.float32 else tmp69 < 0
848
+ tmp72 = tmp71 != tmp60
849
+ tmp73 = tmp70 & tmp72
850
+ tmp74 = tmp69 + tmp54
851
+ tmp75 = tl.where(tmp73, tmp74, tmp69)
852
+ tmp76 = tmp75 == tmp57
853
+ tmp77 = tmp67 & tmp76
854
+ tmp78 = tmp53 | tmp77
855
+ mask_mod_output = tmp78
856
+
857
+ # (grads) apply mask for fully masked block
858
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
859
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
860
+ if not PRESCALE_QK:
861
+ post_mod_scores *= RCP_LN2
862
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
863
+ do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
864
+ # Compute dV.
865
+ ppT = pT
866
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
867
+ if IS_DIVISIBLE:
868
+ Di = tl.load(DELTA + offs_m1)
869
+ else:
870
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
871
+ # Compute dP and dS.
872
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
873
+ dsT = pT * (dpT - Di[None, :])
874
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
875
+ tmp79 = (dsT)
876
+ grad_scores = tmp79
877
+
878
+
879
+
880
+ if not IS_DIVISIBLE:
881
+ grad_scores = tl.where(offs_m1[None, :] < Q_LEN, grad_scores, 0.0)
882
+
883
+ # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
884
+ if not WRITE_DQ:
885
+ idx_b = off_z
886
+ idx_h = off_hq
887
+ idx_m = m
888
+ idx_n = n
889
+ scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
890
+
891
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
892
+ dsT = grad_scores
893
+ if not IS_FULL_BLOCKS:
894
+ # (grads) apply mask for partially unmasked block
895
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
896
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
897
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
898
+
899
+ return dk, dv
900
+
901
+ # Utility triton funcs
902
+ @triton.jit
903
+ def get_offset_for_next_block(
904
+ loop_iter, col_indices, total_blocks,
905
+ SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
906
+ BLOCKS_ARE_CONTIGUOUS: tl.constexpr
907
+ ):
908
+ if BLOCKS_ARE_CONTIGUOUS:
909
+ return BLOCK
910
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
911
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
912
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
913
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
914
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
915
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
916
+ return offset
917
+
918
+ @triton.jit
919
+ def get_bounded_indices(indices, max_len=None):
920
+ return indices % max_len if max_len is not None else indices
921
+
922
+ @triton.jit
923
+ def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
924
+ if IS_DIVISIBLE and SAFE_HEAD_DIM:
925
+ return tl.load(block_ptr)
926
+ elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
927
+ return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
928
+ elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
929
+ return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
930
+ else:
931
+ return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
932
+
933
+ @triton.jit
934
+ def load_checked_2d(
935
+ ptr,
936
+ offs_m,
937
+ offs_n,
938
+ stride_m,
939
+ stride_n,
940
+ IS_DIVISIBLE_M: tl.constexpr,
941
+ IS_DIVISIBLE_N: tl.constexpr,
942
+ M_LEN: tl.constexpr,
943
+ N_LEN: tl.constexpr,
944
+ ):
945
+ # Calculate final pointer if strides are provided
946
+ if stride_m is not None and stride_n is not None:
947
+ ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
948
+
949
+ # Handle all masking cases
950
+ if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
951
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_LEN), other=0.0)
952
+ elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
953
+ return tl.load(ptr, mask=(offs_n[None, :] < N_LEN), other=0.0)
954
+ elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
955
+ return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
956
+ else: # Both divisible
957
+ return tl.load(ptr)
958
+ ''', device_str='cuda')
959
+
960
+
961
+ async_compile.wait(globals())
962
+ del async_compile
963
+
964
+ class Runner:
965
+ def __init__(self, partitions):
966
+ self.partitions = partitions
967
+
968
+ def recursively_apply_fns(self, fns):
969
+ new_callables = []
970
+ for fn, c in zip(fns, self.partitions):
971
+ new_callables.append(fn(c))
972
+ self.partitions = new_callables
973
+
974
+ def call(self, args):
975
+ primals_10, primals_11, primals_15, primals_7, primals_8, primals_12, primals_16, primals_18, primals_19, primals_21, primals_24, primals_23, primals_26, primals_29, primals_28, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_17, primals_20, primals_22, primals_25, primals_27, primals_30, getitem, getitem_1, tangents_1 = args
976
+ args.clear()
977
+ s37 = primals_10
978
+ s0 = primals_11
979
+ s75 = primals_15
980
+ s22 = primals_7
981
+ s72 = primals_8
982
+ s99 = primals_12
983
+ s94 = primals_16
984
+ s28 = primals_18
985
+ s4 = primals_19
986
+ s56 = primals_21
987
+ s53 = primals_24
988
+ s84 = primals_23
989
+ s100 = primals_26
990
+ s10 = primals_29
991
+ s6 = primals_28
992
+ assert_size_stride(primals_2, (8, 32, s37, 128), (4096*s37, 128, 4096, 1))
993
+ assert_size_stride(primals_4, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1))
994
+ assert_size_stride(primals_6, (8, 8, s0, 128), (1024*s0, 128*s0, 128, 1))
995
+ assert_size_stride(primals_9, (8, 1, s22, s72), (s22*s72, s22*s72, s72, 1))
996
+ assert_size_stride(primals_13, (8, 1, s99), (s99, s99, 1))
997
+ assert_size_stride(primals_14, (8, ), (1, ))
998
+ assert_size_stride(primals_17, (8, 1, s94), (s94, s94, 1))
999
+ assert_size_stride(primals_20, (8, 1, s28, s4), (s28*s4, s28*s4, s4, 1))
1000
+ assert_size_stride(primals_22, (8, 1, s56), (s56, s56, 1))
1001
+ assert_size_stride(primals_25, (8, 1, s84, s53), (s53*s84, s53*s84, s53, 1))
1002
+ assert_size_stride(primals_27, (8, 1, s100), (s100, s100, 1))
1003
+ assert_size_stride(primals_30, (8, 1, s6, s10), (s10*s6, s10*s6, s10, 1))
1004
+ assert_size_stride(getitem, (8, 32, s37, 128), (4096*s37, 128, 4096, 1))
1005
+ assert_size_stride(getitem_1, (8, 32, s37), (32*max(1, s37), max(1, s37), 1))
1006
+ assert_size_stride(tangents_1, (8, 32, s37, 128), (4096*max(1, s37), 128*max(1, s37), 128, 1))
1007
+ with torch.cuda._DeviceGuard(4):
1008
+ torch.cuda.set_device(4)
1009
+ ps0 = 32*s37
1010
+ buf1 = empty_strided_cuda((8, 32, s37), (32*s37, s37, 1), torch.float32)
1011
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
1012
+ triton_red_fused_zeros_0_xnumel = 256*s37
1013
+ stream4 = get_raw_stream(4)
1014
+ triton_red_fused_zeros_0.run(getitem, tangents_1, buf1, s37, ps0, triton_red_fused_zeros_0_xnumel, 128, stream=stream4)
1015
+ del getitem
1016
+ buf3 = empty_strided_cuda((8, 32, s37, 128), (4096*s37, 128, 4096, 1), torch.bfloat16)
1017
+ buf4 = empty_strided_cuda((8, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16)
1018
+ buf5 = empty_strided_cuda((8, 8, s0, 128), (1024*s0, 128*s0, 128, 1), torch.bfloat16)
1019
+ # Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
1020
+ stream4 = get_raw_stream(4)
1021
+ triton_tem_fused_zeros_1.run(primals_2, primals_4, primals_6, getitem_1, buf1, tangents_1, buf3, buf4, primals_13, primals_9, primals_22, primals_25, primals_17, primals_20, primals_27, primals_30, primals_14, buf5, s37, s0, s99, s22, s72, s56, s53, s84, s75, 4*((127 + s37) // 128) + ((127 + s0) // 128), 8, 8, stream=stream4)
1022
+ del buf1
1023
+ del getitem_1
1024
+ del primals_13
1025
+ del primals_14
1026
+ del primals_17
1027
+ del primals_2
1028
+ del primals_20
1029
+ del primals_22
1030
+ del primals_25
1031
+ del primals_27
1032
+ del primals_30
1033
+ del primals_4
1034
+ del primals_6
1035
+ del primals_9
1036
+ del tangents_1
1037
+ return (None, buf3, None, buf5, None, buf4, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, )
1038
+
1039
+ runner = Runner(partitions=[])
1040
+ call = runner.call
1041
+ recursively_apply_fns = runner.recursively_apply_fns
1042
+
1043
+
1044
+ def benchmark_compiled_module(times=10, repeat=10):
1045
+ from torch._dynamo.testing import rand_strided
1046
+ from torch._inductor.utils import print_performance
1047
+ primals_10 = 1896
1048
+ primals_11 = 1896
1049
+ primals_15 = 1896
1050
+ primals_7 = 15
1051
+ primals_8 = 15
1052
+ primals_12 = 15
1053
+ primals_16 = 15
1054
+ primals_18 = 15
1055
+ primals_19 = 15
1056
+ primals_21 = 15
1057
+ primals_24 = 15
1058
+ primals_23 = 15
1059
+ primals_26 = 15
1060
+ primals_29 = 15
1061
+ primals_28 = 15
1062
+ primals_2 = rand_strided((8, 32, 1896, 128), (7766016, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16)
1063
+ primals_4 = rand_strided((8, 8, 1896, 128), (1941504, 242688, 128, 1), device='cuda:4', dtype=torch.bfloat16)
1064
+ primals_6 = rand_strided((8, 8, 1896, 128), (1941504, 242688, 128, 1), device='cuda:4', dtype=torch.bfloat16)
1065
+ primals_9 = rand_strided((8, 1, 15, 15), (225, 225, 15, 1), device='cuda:4', dtype=torch.int32)
1066
+ primals_13 = rand_strided((8, 1, 15), (15, 15, 1), device='cuda:4', dtype=torch.int32)
1067
+ primals_14 = rand_strided((8, ), (1, ), device='cuda:4', dtype=torch.int64)
1068
+ primals_17 = rand_strided((8, 1, 15), (15, 15, 1), device='cuda:4', dtype=torch.int32)
1069
+ primals_20 = rand_strided((8, 1, 15, 15), (225, 225, 15, 1), device='cuda:4', dtype=torch.int32)
1070
+ primals_22 = rand_strided((8, 1, 15), (15, 15, 1), device='cuda:4', dtype=torch.int32)
1071
+ primals_25 = rand_strided((8, 1, 15, 15), (225, 225, 15, 1), device='cuda:4', dtype=torch.int32)
1072
+ primals_27 = rand_strided((8, 1, 15), (15, 15, 1), device='cuda:4', dtype=torch.int32)
1073
+ primals_30 = rand_strided((8, 1, 15, 15), (225, 225, 15, 1), device='cuda:4', dtype=torch.int32)
1074
+ getitem = rand_strided((8, 32, 1896, 128), (7766016, 128, 4096, 1), device='cuda:4', dtype=torch.bfloat16)
1075
+ getitem_1 = rand_strided((8, 32, 1896), (60672, 1896, 1), device='cuda:4', dtype=torch.float32)
1076
+ tangents_1 = rand_strided((8, 32, 1896, 128), (7766016, 242688, 128, 1), device='cuda:4', dtype=torch.bfloat16)
1077
+ fn = lambda: call([primals_10, primals_11, primals_15, primals_7, primals_8, primals_12, primals_16, primals_18, primals_19, primals_21, primals_24, primals_23, primals_26, primals_29, primals_28, primals_2, primals_4, primals_6, primals_9, primals_13, primals_14, primals_17, primals_20, primals_22, primals_25, primals_27, primals_30, getitem, getitem_1, tangents_1])
1078
+ return print_performance(fn, times=times, repeat=repeat)
1079
+
1080
+
1081
+ if __name__ == "__main__":
1082
+ from torch._inductor.wrapper_benchmark import compiled_module_main
1083
+ compiled_module_main('None', benchmark_compiled_module)
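For orientation, a minimal standalone sketch of the benchmarking pattern that `benchmark_compiled_module` above follows: allocate inputs with `rand_strided` so the guarded strides are honored, then time a callable with `print_performance`. The (8, 1024) shape and the silu workload below are illustrative stand-ins, not values guarded by this module.

import torch
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance

# Illustrative stand-ins; the real harness above passes the compiled
# module's `call` with its guarded shapes and strides.
x = rand_strided((8, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
fn = lambda: torch.nn.functional.silu(x)
print_performance(fn, times=10, repeat=10)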
SpecForge-ext/cache/compiled_kernels/lp/1e661150415d7fce0f5577d7db35f128089400ce692c8dfdf5e40cb9a867cea5.best_config ADDED
@@ -0,0 +1 @@
1
+ {"XBLOCK": 1, "num_warps": 2, "num_stages": 1, "configs_hash": "6fcabd0411a839b7b5d117b5e6638bd1b5d7bc3379312c678d803859f08278a9", "found_by_coordesc": false, "time_taken_ms": 18, "triton_cache_hash": "EB4J5U2HKNQBLXRWK6B5L6ATOH55AWD3MB7P63KH5AKRGRDZER7A"}
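Each `.best_config` entry, like the one above, is a one-line JSON record of the launch configuration the autotuner settled on for the sibling kernel file (block size, warp count, pipeline stages) plus hashes used to validate the cache. A minimal sketch of inspecting one:

import json

# Path taken from the cache listing above; keys are those in the JSON record.
path = ('SpecForge-ext/cache/compiled_kernels/lp/'
        '1e661150415d7fce0f5577d7db35f128089400ce692c8dfdf5e40cb9a867cea5.best_config')
with open(path) as f:
    cfg = json.load(f)
print(cfg['XBLOCK'], cfg['num_warps'], cfg['num_stages'])  # 1 2 1 for this entry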
SpecForge-ext/cache/compiled_kernels/lp/clp43olymjc72eay3ukgvj6r4apcbbbnz3xlli3tafgvidlacqsg.py ADDED
@@ -0,0 +1,322 @@
1
+ # AOT ID: ['4_backward']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/pz/cpzw7g6yjflpctcqkzf5osq7m5acrctaysa6th3ox3deinxluypc.py
38
+ # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add]
39
+ # Source node to ATen node mapping:
40
+ # cos => squeeze_1
41
+ # cos_1 => unsqueeze
42
+ # getitem => index
43
+ # getitem_1 => index_1
44
+ # sin => squeeze_3
45
+ # sin_1 => unsqueeze_1
46
+ # squeeze => squeeze
47
+ # squeeze_2 => squeeze_2
48
+ # Graph fragment:
49
+ # %tangents_2 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:3" = PlaceHolder[target=tangents_2]
50
+ # %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:3" = PlaceHolder[target=primals_8]
51
+ # %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:3" = PlaceHolder[target=primals_6]
52
+ # %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:3" = PlaceHolder[target=primals_4]
53
+ # %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {})
54
+ # %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {})
55
+ # %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {})
56
+ # %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {})
57
+ # %mul_84 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze_1), kwargs = {})
58
+ # %slice_5 : Tensor "bf16[s48, s25, s9, s24 - ((s24//2))][s24*s25*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, 0, %add_96), kwargs = {})
59
+ # %slice_6 : Tensor "bf16[s48, s25, s9, (s24//2)][s24*s25*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_84, 3, %sub_72, %primals_2), kwargs = {})
60
+ # %neg_2 : Tensor "bf16[s48, s25, s9, s24 - ((s24//2))][s25*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_5,), kwargs = {})
61
+ # %full_default : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_13, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:3, pin_memory: False})
62
+ # %slice_scatter_default : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %neg_2, 3, %floordiv, 9223372036854775807), kwargs = {})
63
+ # %slice_scatter_default_1 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default, %slice_6, 3, 0, %floordiv), kwargs = {})
64
+ # %add_100 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default, %slice_scatter_default_1), kwargs = {})
65
+ # %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {})
66
+ # %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {})
67
+ # %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {})
68
+ # %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {})
69
+ # %mul_85 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_2, %unsqueeze), kwargs = {})
70
+ # %add_101 : Tensor "bf16[s48, s25, s9, s24][s24*s25*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_100, %mul_85), kwargs = {})
71
+ # return %add_101
72
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', '''
73
+ import triton
74
+ import triton.language as tl
75
+
76
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
77
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
78
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
79
+ triton_helpers.set_driver_to_gpu()
80
+
81
+ @triton_heuristics.pointwise(
82
+ size_hints={'x': 4194304},
83
+ filename=__file__,
84
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
85
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
86
+ min_elem_per_thread=0
87
+ )
88
+ @triton.jit
89
+ def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr):
90
+ xoffset = tl.program_id(0) * XBLOCK
91
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
92
+ xmask = xindex < xnumel
93
+ x0 = (xindex % ks0)
94
+ x3 = xindex
95
+ x1 = ((xindex // ks0) % ks1)
96
+ tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32)
97
+ tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last')
98
+ tmp0 = x0
99
+ tmp1 = ks0 // 2
100
+ tmp2 = tmp0 >= tmp1
101
+ tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
102
+ tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0)
103
+ tmp5 = tl.broadcast_to(ks2, [XBLOCK])
104
+ tmp6 = tmp4 + tmp5
105
+ tmp7 = tmp4 < 0
106
+ tmp8 = tl.where(tmp7, tmp6, tmp4)
107
+ tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2")
108
+ tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
109
+ tmp11 = tmp3 * tmp10
110
+ tmp12 = -tmp11
111
+ tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
112
+ tmp14 = tl.where(tmp2, tmp12, tmp13)
113
+ tmp15 = 0.0
114
+ tmp16 = tl.where(tmp2, tmp14, tmp15)
115
+ tmp17 = tmp0 < tmp1
116
+ tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
117
+ tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0)
118
+ tmp20 = tl.broadcast_to(ks2, [XBLOCK])
119
+ tmp21 = tmp19 + tmp20
120
+ tmp22 = tmp19 < 0
121
+ tmp23 = tl.where(tmp22, tmp21, tmp19)
122
+ tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2")
123
+ tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
124
+ tmp26 = tmp18 * tmp25
125
+ tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype)
126
+ tmp28 = tl.where(tmp17, tmp26, tmp27)
127
+ tmp29 = tl.where(tmp17, tmp28, tmp15)
128
+ tmp30 = tmp16 + tmp29
129
+ tmp33 = ks3
130
+ tmp34 = tmp32 + tmp33
131
+ tmp35 = tmp32 < 0
132
+ tmp36 = tl.where(tmp35, tmp34, tmp32)
133
+ tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3")
134
+ tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32)
135
+ tmp39 = tmp31 * tmp38
136
+ tmp40 = tmp30 + tmp39
137
+ tl.store(out_ptr0 + (x3), tmp40, xmask)
138
+ ''', device_str='cuda')
139
+
140
+
141
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/dz/cdz3io7w5uyfrmfqvmg2kt2ay66qv4ckwtyurhik3frq7fqnk7gm.py
142
+ # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add]
143
+ # Source node to ATen node mapping:
144
+ # cos => squeeze_1
145
+ # cos_1 => unsqueeze
146
+ # getitem => index
147
+ # getitem_1 => index_1
148
+ # sin => squeeze_3
149
+ # sin_1 => unsqueeze_1
150
+ # squeeze => squeeze
151
+ # squeeze_2 => squeeze_2
152
+ # Graph fragment:
153
+ # %tangents_1 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3" = PlaceHolder[target=tangents_1]
154
+ # %primals_8 : Tensor "i64[1, s9][s9, 1]cuda:3" = PlaceHolder[target=primals_8]
155
+ # %primals_6 : Tensor "bf16[1, 1, s79, s24][s96, s96, s24, 1]cuda:3" = PlaceHolder[target=primals_6]
156
+ # %primals_4 : Tensor "bf16[1, 1, s92, s24][s96, s96, s24, 1]cuda:3" = PlaceHolder[target=primals_4]
157
+ # %squeeze_2 : Tensor "bf16[1, s79, s24][s96, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_6, 1), kwargs = {})
158
+ # %squeeze_3 : Tensor "bf16[s79, s24][s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze_2, 0), kwargs = {})
159
+ # %index_1 : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_3, [%primals_8]), kwargs = {})
160
+ # %unsqueeze_1 : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index_1, 1), kwargs = {})
161
+ # %squeeze : Tensor "bf16[1, s92, s24][s96, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%primals_4, 1), kwargs = {})
162
+ # %squeeze_1 : Tensor "bf16[s92, s24][s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%squeeze, 0), kwargs = {})
163
+ # %index : Tensor "bf16[1, s9, s24][s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%squeeze_1, [%primals_8]), kwargs = {})
164
+ # %unsqueeze : Tensor "bf16[1, 1, s9, s24][s24*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 1), kwargs = {})
165
+ # %mul_86 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze_1), kwargs = {})
166
+ # %slice_7 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, 0, %sub_72), kwargs = {})
167
+ # %slice_8 : Tensor "bf16[s48, s34, s9, (s24//2)][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_86, 3, %sub_72, %primals_2), kwargs = {})
168
+ # %neg_3 : Tensor "bf16[s48, s34, s9, s24 - ((s24//2))][s34*s9*Max(1, s24 - ((s24//2))), s9*Max(1, s24 - ((s24//2))), Max(1, s24 - ((s24//2))), 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_7,), kwargs = {})
169
+ # %full_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=2] = call_function[target=torch.ops.aten.full.default](args = ([%primals_10, %primals_11, %primals_7, %primals_2], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:3, pin_memory: False})
170
+ # %slice_scatter_default_2 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %neg_3, 3, %floordiv, 9223372036854775807), kwargs = {})
171
+ # %slice_scatter_default_3 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_2, %slice_8, 3, 0, %floordiv), kwargs = {})
172
+ # %add_106 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default_2, %slice_scatter_default_3), kwargs = {})
173
+ # %mul_87 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%tangents_1, %unsqueeze), kwargs = {})
174
+ # %add_107 : Tensor "bf16[s48, s34, s9, s24][s24*s34*s9, s24*s9, s24, 1]cuda:3"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_106, %mul_87), kwargs = {})
175
+ # return %add_107
176
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1 = async_compile.triton('triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', '''
177
+ import triton
178
+ import triton.language as tl
179
+
180
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
181
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
182
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
183
+ triton_helpers.set_driver_to_gpu()
184
+
185
+ @triton_heuristics.pointwise(
186
+ size_hints={'x': 16777216},
187
+ filename=__file__,
188
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=3, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
189
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
190
+ min_elem_per_thread=0
191
+ )
192
+ @triton.jit
193
+ def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr):
194
+ xoffset = tl.program_id(0) * XBLOCK
195
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
196
+ xmask = xindex < xnumel
197
+ x0 = (xindex % ks0)
198
+ x3 = xindex
199
+ x1 = ((xindex // ks0) % ks1)
200
+ tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32)
201
+ tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last')
202
+ tmp0 = x0
203
+ tmp1 = ks0 // 2
204
+ tmp2 = tmp0 >= tmp1
205
+ tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
206
+ tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0)
207
+ tmp5 = tl.broadcast_to(ks2, [XBLOCK])
208
+ tmp6 = tmp4 + tmp5
209
+ tmp7 = tmp4 < 0
210
+ tmp8 = tl.where(tmp7, tmp6, tmp4)
211
+ tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2")
212
+ tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
213
+ tmp11 = tmp3 * tmp10
214
+ tmp12 = -tmp11
215
+ tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
216
+ tmp14 = tl.where(tmp2, tmp12, tmp13)
217
+ tmp15 = 0.0
218
+ tmp16 = tl.where(tmp2, tmp14, tmp15)
219
+ tmp17 = tmp0 < tmp1
220
+ tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
221
+ tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0)
222
+ tmp20 = tl.broadcast_to(ks2, [XBLOCK])
223
+ tmp21 = tmp19 + tmp20
224
+ tmp22 = tmp19 < 0
225
+ tmp23 = tl.where(tmp22, tmp21, tmp19)
226
+ tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2")
227
+ tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
228
+ tmp26 = tmp18 * tmp25
229
+ tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype)
230
+ tmp28 = tl.where(tmp17, tmp26, tmp27)
231
+ tmp29 = tl.where(tmp17, tmp28, tmp15)
232
+ tmp30 = tmp16 + tmp29
233
+ tmp33 = ks3
234
+ tmp34 = tmp32 + tmp33
235
+ tmp35 = tmp32 < 0
236
+ tmp36 = tl.where(tmp35, tmp34, tmp32)
237
+ tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3")
238
+ tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32)
239
+ tmp39 = tmp31 * tmp38
240
+ tmp40 = tmp30 + tmp39
241
+ tl.store(out_ptr0 + (x3), tmp40, xmask)
242
+ ''', device_str='cuda')
243
+
244
+
245
+ async_compile.wait(globals())
246
+ del async_compile
247
+
248
+ class Runner:
249
+ def __init__(self, partitions):
250
+ self.partitions = partitions
251
+
252
+ def recursively_apply_fns(self, fns):
253
+ new_callables = []
254
+ for fn, c in zip(fns, self.partitions):
255
+ new_callables.append(fn(c))
256
+ self.partitions = new_callables
257
+
258
+ def call(self, args):
259
+ primals_2, primals_7, primals_10, primals_11, primals_13, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2 = args
260
+ args.clear()
261
+ s24 = primals_2
262
+ s9 = primals_7
263
+ s48 = primals_10
264
+ s34 = primals_11
265
+ s25 = primals_13
266
+ s92 = primals_1
267
+ s96 = primals_3
268
+ s79 = primals_5
269
+ assert_size_stride(primals_4, (1, 1, s92, s24), (s96, s96, s24, 1))
270
+ assert_size_stride(primals_6, (1, 1, s79, s24), (s96, s96, s24, 1))
271
+ assert_size_stride(primals_8, (1, s9), (s9, 1))
272
+ assert_size_stride(tangents_1, (s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1))
273
+ assert_size_stride(tangents_2, (s48, s25, s9, s24), (s24*s25*s9, s24*s9, s24, 1))
274
+ with torch.cuda._DeviceGuard(3):
275
+ torch.cuda.set_device(3)
276
+ buf0 = empty_strided_cuda((s48, s25, s9, s24), (s24*s25*s9, s24*s9, s24, 1), torch.bfloat16)
277
+ # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add]
278
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel = s24*s25*s48*s9
279
+ stream3 = get_raw_stream(3)
280
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0.run(tangents_2, primals_8, primals_6, primals_4, buf0, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0_xnumel, stream=stream3)
281
+ del tangents_2
282
+ buf1 = empty_strided_cuda((s48, s34, s9, s24), (s24*s34*s9, s24*s9, s24, 1), torch.bfloat16)
283
+ # Topologically Sorted Source Nodes: [squeeze_2, sin, getitem_1, sin_1, squeeze, cos, getitem, cos_1], Original ATen: [aten.squeeze, aten.index, aten.unsqueeze, aten.mul, aten.slice, aten.neg, aten.slice_backward, aten.add]
284
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel = s24*s34*s48*s9
285
+ stream3 = get_raw_stream(3)
286
+ triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1.run(tangents_1, primals_8, primals_6, primals_4, buf1, s24, s9, s79, s92, triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_1_xnumel, stream=stream3)
287
+ del primals_4
288
+ del primals_6
289
+ del primals_8
290
+ del tangents_1
291
+ return (None, None, None, None, None, None, None, None, None, None, None, buf1, None, buf0, )
292
+
293
+ runner = Runner(partitions=[])
294
+ call = runner.call
295
+ recursively_apply_fns = runner.recursively_apply_fns
296
+
297
+
298
+ def benchmark_compiled_module(times=10, repeat=10):
299
+ from torch._dynamo.testing import rand_strided
300
+ from torch._inductor.utils import print_performance
301
+ primals_2 = 128
302
+ primals_7 = 2048
303
+ primals_10 = 2
304
+ primals_11 = 32
305
+ primals_13 = 8
306
+ primals_1 = 2048
307
+ primals_3 = 5245440
308
+ primals_5 = 2048
309
+ floordiv = 64
310
+ add_96 = 64
311
+ primals_4 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:3', dtype=torch.bfloat16)
312
+ primals_6 = rand_strided((1, 1, 2048, 128), (5245440, 5245440, 128, 1), device='cuda:3', dtype=torch.bfloat16)
313
+ primals_8 = rand_strided((1, 2048), (2048, 1), device='cuda:3', dtype=torch.int64)
314
+ tangents_1 = rand_strided((2, 32, 2048, 128), (8388608, 262144, 128, 1), device='cuda:3', dtype=torch.bfloat16)
315
+ tangents_2 = rand_strided((2, 8, 2048, 128), (2097152, 262144, 128, 1), device='cuda:3', dtype=torch.bfloat16)
316
+ fn = lambda: call([primals_2, primals_7, primals_10, primals_11, primals_13, primals_1, primals_3, primals_5, floordiv, add_96, primals_4, primals_6, primals_8, tangents_1, tangents_2])
317
+ return print_performance(fn, times=times, repeat=repeat)
318
+
319
+
320
+ if __name__ == "__main__":
321
+ from torch._inductor.wrapper_benchmark import compiled_module_main
322
+ compiled_module_main('None', benchmark_compiled_module)
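Both fused kernels in this '4_backward' module implement the same pointwise computation, the backward of rotary position embeddings, applied once to the query tangent and once to the key tangent. Reading the graph fragments, a rough eager-mode equivalent is the sketch below; the function name is illustrative, and `cos`/`sin` stand for the squeezed `primals_4`/`primals_6` tables gathered by `primals_8`.

import torch

def rope_backward(grad_out, cos, sin, position_ids):
    # grad_out: (B, H, S, D); cos/sin: (cache_len, D); position_ids: (1, S)
    cos = cos[position_ids].unsqueeze(1)   # (1, 1, S, D), as in the aten.index ops
    sin = sin[position_ids].unsqueeze(1)
    gs = grad_out * sin
    half = grad_out.shape[-1] // 2
    # Transpose of rotate_half: (y1, y2) -> (y2, -y1), matching the two
    # slice_scatter branches on x0 >= ks0 // 2 in the kernel body.
    rot_t = torch.cat((gs[..., half:], -gs[..., :half]), dim=-1)
    return rot_t + grad_out * cos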
SpecForge-ext/cache/compiled_kernels/lp/clpt6xpoqv3wajdkyviksqw24bkxb47w4kcgihhcyrj553fxcjqs.py ADDED
@@ -0,0 +1,50 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.persistent_reduction(
11
+ size_hints={'x': 32, 'r0_': 16},
12
+ reduction_hint=ReductionHint.DEFAULT,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*i32', 'out_ptr2': '*i32', 'out_ptr3': '*i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=5, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 256, 'r0_': 4096}}
16
+ )
17
+ @triton.jit
18
+ def triton_per_fused__to_copy_clone_slice_sort_sum_transpose_3(in_ptr0, out_ptr2, out_ptr3, xnumel, r0_numel, XBLOCK : tl.constexpr):
19
+ xnumel = 32
20
+ r0_numel = 16
21
+ R0_BLOCK: tl.constexpr = 16
22
+ rnumel = r0_numel
23
+ RBLOCK: tl.constexpr = R0_BLOCK
24
+ xoffset = tl.program_id(0) * XBLOCK
25
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
26
+ xmask = xindex < xnumel
27
+ r0_index = tl.arange(0, R0_BLOCK)[None, :]
28
+ r0_offset = 0
29
+ r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
30
+ roffset = r0_offset
31
+ rindex = r0_index
32
+ r0_2 = r0_index
33
+ x0 = (xindex % 16)
34
+ x1 = xindex // 16
35
+ x3 = xindex
36
+ tmp0 = tl.load(in_ptr0 + (x0 + 17*r0_2 + 272*x1), xmask, other=0.0)
37
+ tmp1 = r0_2
38
+ tmp2 = tmp1.to(tl.int16)
39
+ tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
40
+ tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
41
+ tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, None, 1, stable=True, descending=True)
42
+ tmp7 = tmp0.to(tl.int64)
43
+ tmp8 = tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK])
44
+ tmp10 = tl.where(xmask, tmp8, 0)
45
+ tmp11 = tl.sum(tmp10, 1)[:, None].to(tl.int64)
46
+ tmp12 = tmp6.to(tl.int64)
47
+ tmp13 = tmp12.to(tl.int32)
48
+ tmp14 = tmp11.to(tl.int32)
49
+ tl.store(out_ptr2 + (r0_2 + 16*x3), tmp13, xmask)
50
+ tl.store(out_ptr3 + (x3), tmp14, xmask)
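This persistent-reduction kernel fuses a stable descending sort that keeps source indices (`triton_helpers.sort_with_index`, Inductor's building block for `torch.sort`) with an integer row sum over the same 16-element rows. Below is a rough eager sketch of what the 32 program instances produce; the dense (32, 16) input is a simplification of the strided `x0 + 17*r0_2 + 272*x1` gather in the kernel.

import torch

vals = torch.randint(0, 100, (32, 16), dtype=torch.int32, device='cuda')
# out_ptr2: per-row source indices after a stable descending sort
_, order = torch.sort(vals, dim=1, descending=True, stable=True)
order = order.to(torch.int32)           # kernel sorts int16 indices, stores int32
# out_ptr3: per-row sum, accumulated in int64 and stored as int32
row_sum = vals.sum(dim=1, dtype=torch.int64).to(torch.int32)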
SpecForge-ext/cache/compiled_kernels/ls/cls3ju4iskgwc7wepn2m46svt5vbvf47ps3tsfw7s37earyzkzz2.py ADDED
@@ -0,0 +1,62 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 4096, 'r0_': 4096},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*bf16', 'in_ptr1': 'fp64', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mean_mul_pow_rsqrt_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused__to_copy_mean_mul_pow_rsqrt_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ rnumel = r0_numel
20
+ RBLOCK: tl.constexpr = R0_BLOCK
21
+ xoffset = tl.program_id(0) * XBLOCK
22
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
23
+ xmask = xindex < xnumel
24
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
25
+ rbase = r0_base
26
+ x0 = xindex
27
+ _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
28
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
29
+ r0_index = r0_offset + r0_base
30
+ r0_mask = r0_index < r0_numel
31
+ roffset = r0_offset
32
+ rindex = r0_index
33
+ r0_1 = r0_index
34
+ tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
35
+ tmp1 = tmp0.to(tl.float32)
36
+ tmp2 = tmp1 * tmp1
37
+ tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
38
+ tmp5 = _tmp4 + tmp3
39
+ _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4)
40
+ tmp4 = tl.sum(_tmp4, 1)[:, None]
41
+ tmp9 = in_ptr1
42
+ tmp6 = ks0
43
+ tmp7 = tmp6.to(tl.float32)
44
+ tmp8 = (tmp4 / tmp7)
45
+ tmp10 = tmp9.to(tl.float32)
46
+ tmp11 = tmp8 + tmp10
47
+ tmp12 = libdevice.rsqrt(tmp11)
48
+ tl.debug_barrier()
49
+ tl.store(in_out_ptr0 + (x0), tmp12, xmask)
50
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
51
+ r0_index = r0_offset + r0_base
52
+ r0_mask = r0_index < r0_numel
53
+ roffset = r0_offset
54
+ rindex = r0_index
55
+ r0_1 = r0_index
56
+ tmp13 = tl.load(in_ptr2 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
57
+ tmp14 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
58
+ tmp15 = tmp14.to(tl.float32)
59
+ tmp16 = tmp15 * tmp12
60
+ tmp17 = tmp16.to(tl.float32)
61
+ tmp18 = tmp13 * tmp17
62
+ tl.store(out_ptr0 + (r0_1 + ks0*x0), tmp18, r0_mask & xmask)
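This two-pass reduction is a fused RMSNorm: the first loop accumulates the mean of squares and stores `rsqrt(mean + eps)` in place (`in_out_ptr0`), the second rescales the row and applies the weight (`in_ptr2`). A rough eager equivalent, with `eps` standing for the fp64 scalar `in_ptr1`:

import torch

def rms_norm(x, weight, eps=1e-6):
    # x: (rows, ks0) bf16; weight: (ks0,) bf16; the eps default is a placeholder
    xf = x.float()
    inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    out = weight * (xf * inv_rms).to(x.dtype)   # weight applied after the bf16 cast
    return out, inv_rms                         # kernel keeps inv_rms in fp32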
SpecForge-ext/cache/compiled_kernels/ng/cnglvt55axgj3x37cqns4hg7zsjeu57rkczufz7vpm5o4rwbf2w7.py ADDED
@@ -0,0 +1,164 @@
1
+ # AOT ID: ['0_inference']
2
+ from ctypes import c_void_p, c_long, c_int
3
+ import torch
4
+ import math
5
+ import random
6
+ import os
7
+ import tempfile
8
+ from math import inf, nan
9
+ from cmath import nanj
10
+ from torch._inductor.hooks import run_intermediate_hooks
11
+ from torch._inductor.utils import maybe_profile
12
+ from torch._inductor.codegen.memory_planning import _align as align
13
+ from torch import device, empty_strided
14
+ from torch._inductor.async_compile import AsyncCompile
15
+ from torch._inductor.select_algorithm import extern_kernels
16
+ import triton
17
+ import triton.language as tl
18
+ from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
19
+ from torch._C import _cuda_getCurrentRawStream as get_raw_stream
20
+
21
+ aten = torch.ops.aten
22
+ inductor_ops = torch.ops.inductor
23
+ _quantized = torch.ops._quantized
24
+ assert_size_stride = torch._C._dynamo.guards.assert_size_stride
25
+ assert_alignment = torch._C._dynamo.guards.assert_alignment
26
+ empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
27
+ empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
28
+ empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
29
+ empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
30
+ empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
31
+ reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
32
+ alloc_from_pool = torch.ops.inductor._alloc_from_pool
33
+ async_compile = AsyncCompile()
34
+ empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
35
+
36
+
37
+ # kernel path: /workspace/hanrui/SpecForge-ext/cache/compiled_kernels/mw/cmw5mlntlt7o73p24outkvtp73w3ylg6pk6fbqshalpowjpvoh47.py
38
+ # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul]
39
+ # Source node to ATen node mapping:
40
+ # getitem_1 => unsqueeze
41
+ # position_mask => mul
42
+ # target_mask => index
43
+ # target_mask_1 => convert_element_type
44
+ # target_max_token => argmax
45
+ # Graph fragment:
46
+ # %arg0_1 : Tensor "bf16[2, 2048, 151936][311164928, 151936, 1]cuda:2" = PlaceHolder[target=arg0_1]
47
+ # %argmax : Tensor "i64[2, 2048][2048, 1]cuda:2" = PlaceHolder[target=argmax]
48
+ # %arg1_1 : Tensor "b8[151936][1]cuda:2" = PlaceHolder[target=arg1_1]
49
+ # %arg2_1 : Tensor "i64[2, 2048, 1][2048, 1, 1]cuda:2" = PlaceHolder[target=arg2_1]
50
+ # %argmax : Tensor "i64[2, 2048][2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%arg0_1, -1), kwargs = {})
51
+ # %index : Tensor "b8[2, 2048][2048, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg1_1, [%argmax]), kwargs = {})
52
+ # %unsqueeze : Tensor "b8[2, 2048, 1][2048, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%index, 2), kwargs = {})
53
+ # %convert_element_type : Tensor "i32[2, 2048, 1][2048, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze, torch.int32), kwargs = {})
54
+ # %mul : Tensor "i64[2, 2048, 1][2048, 1, 1]cuda:2"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %arg2_1), kwargs = {})
55
+ # return %argmax,%mul
56
+ triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0 = async_compile.triton('triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', '''
57
+ import triton
58
+ import triton.language as tl
59
+
60
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
61
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
62
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
63
+ triton_helpers.set_driver_to_gpu()
64
+
65
+ @triton_heuristics.reduction(
66
+ size_hints={'x': 4096, 'r0_': 262144},
67
+ reduction_hint=ReductionHint.INNER,
68
+ filename=__file__,
69
+ triton_meta={'signature': {'in_out_ptr0': '*i64', 'in_ptr0': '*bf16', 'in_ptr1': '*i1', 'in_ptr2': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=2, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
70
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
71
+ )
72
+ @triton.jit
73
+ def triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
74
+ xnumel = 4096
75
+ r0_numel = 151936
76
+ rnumel = r0_numel
77
+ RBLOCK: tl.constexpr = R0_BLOCK
78
+ xoffset = tl.program_id(0) * XBLOCK
79
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
80
+ xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
81
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
82
+ rbase = r0_base
83
+ x0 = xindex
84
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
85
+ _tmp2_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
86
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
87
+ r0_index = r0_offset + r0_base
88
+ r0_mask = r0_index < r0_numel
89
+ roffset = r0_offset
90
+ rindex = r0_index
91
+ r0_1 = r0_index
92
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 151936*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
93
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
94
+ _tmp2_next, _tmp2_index_next = triton_helpers.maximum_with_index(
95
+ _tmp2, _tmp2_index, tmp1, rindex
96
+ )
97
+ _tmp2 = tl.where(r0_mask, _tmp2_next, _tmp2)
98
+ _tmp2_index = tl.where(r0_mask, _tmp2_index_next, _tmp2_index)
99
+ tmp2_val, tmp2_idx = triton_helpers.max_with_index(_tmp2, _tmp2_index, 1)
100
+ tmp2 = tmp2_idx[:, None]
101
+ tmp11 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last')
102
+ tmp3 = tl.full([XBLOCK, 1], 151936, tl.int32)
103
+ tmp4 = tmp2 + tmp3
104
+ tmp5 = tmp2 < 0
105
+ tmp6 = tl.where(tmp5, tmp4, tmp2)
106
+ tl.device_assert((0 <= tmp6) & (tmp6 < 151936), "index out of bounds: 0 <= tmp6 < 151936")
107
+ tmp8 = tl.load(in_ptr1 + (tmp6), None, eviction_policy='evict_last').to(tl.int1)
108
+ tmp9 = tmp8.to(tl.int32)
109
+ tmp10 = tmp9.to(tl.int64)
110
+ tmp12 = tmp10 * tmp11
111
+ tl.debug_barrier()
112
+ tl.store(in_out_ptr0 + (x0), tmp12, None)
113
+ ''', device_str='cuda')
114
+
115
+
116
+ async_compile.wait(globals())
117
+ del async_compile
118
+
119
+ class Runner:
120
+ def __init__(self, partitions):
121
+ self.partitions = partitions
122
+
123
+ def recursively_apply_fns(self, fns):
124
+ new_callables = []
125
+ for fn, c in zip(fns, self.partitions):
126
+ new_callables.append(fn(c))
127
+ self.partitions = new_callables
128
+
129
+ def call(self, args):
130
+ arg0_1, arg1_1, arg2_1 = args
131
+ args.clear()
132
+ assert_size_stride(arg0_1, (2, 2048, 151936), (311164928, 151936, 1))
133
+ assert_size_stride(arg1_1, (151936, ), (1, ))
134
+ assert_size_stride(arg2_1, (2, 2048, 1), (2048, 1, 1))
135
+ with torch.cuda._DeviceGuard(2):
136
+ torch.cuda.set_device(2)
137
+ buf0 = empty_strided_cuda((2, 2048), (2048, 1), torch.int64)
138
+ buf1 = reinterpret_tensor(buf0, (2, 2048, 1), (2048, 1, 1), 0); del buf0 # reuse
139
+ # Topologically Sorted Source Nodes: [target_max_token, target_mask, getitem_1, target_mask_1, position_mask], Original ATen: [aten.argmax, aten.index, aten.unsqueeze, aten._to_copy, aten.mul]
140
+ stream2 = get_raw_stream(2)
141
+ triton_red_fused__to_copy_argmax_index_mul_unsqueeze_0.run(buf1, arg0_1, arg1_1, arg2_1, 4096, 151936, stream=stream2)
142
+ del arg0_1
143
+ del arg1_1
144
+ del arg2_1
145
+ return (buf1, )
146
+
147
+ runner = Runner(partitions=[])
148
+ call = runner.call
149
+ recursively_apply_fns = runner.recursively_apply_fns
150
+
151
+
152
+ def benchmark_compiled_module(times=10, repeat=10):
153
+ from torch._dynamo.testing import rand_strided
154
+ from torch._inductor.utils import print_performance
155
+ arg0_1 = rand_strided((2, 2048, 151936), (311164928, 151936, 1), device='cuda:2', dtype=torch.bfloat16)
156
+ arg1_1 = rand_strided((151936, ), (1, ), device='cuda:2', dtype=torch.bool)
157
+ arg2_1 = rand_strided((2, 2048, 1), (2048, 1, 1), device='cuda:2', dtype=torch.int64)
158
+ fn = lambda: call([arg0_1, arg1_1, arg2_1])
159
+ return print_performance(fn, times=times, repeat=repeat)
160
+
161
+
162
+ if __name__ == "__main__":
163
+ from torch._inductor.wrapper_benchmark import compiled_module_main
164
+ compiled_module_main('None', benchmark_compiled_module)
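The single fused kernel in this '0_inference' module computes, per (batch, position): the argmax token over the 151936-entry vocabulary, a boolean table lookup on that token, and a multiply against the int64 positions. A rough eager equivalent (function name illustrative, argument roles read off the graph fragment):

import torch

def position_mask(logits, token_mask, positions):
    # logits: (B, S, V) bf16; token_mask: (V,) bool; positions: (B, S, 1) i64
    target_max_token = logits.argmax(dim=-1)           # (B, S)
    keep = token_mask[target_max_token].unsqueeze(-1)  # (B, S, 1) bool
    return keep.to(torch.int32) * positions            # promoted back to i64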
SpecForge-ext/cache/compiled_kernels/p6/cp66nvdwdzgxajxp2yjtqapnwidpmfnzcyyalh6z5w6f6lf3aoej.py ADDED
@@ -0,0 +1,56 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.pointwise(
11
+ size_hints={'x': 16777216},
12
+ filename=__file__,
13
+ triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'ks4': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
14
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1', 'mutated_arg_names': [], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
15
+ min_elem_per_thread=0
16
+ )
17
+ @triton.jit
18
+ def triton_poi_fused_add_cat_index_mul_neg_slice_squeeze_unsqueeze_1(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, ks4, xnumel, XBLOCK : tl.constexpr):
19
+ xoffset = tl.program_id(0) * XBLOCK
20
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
21
+ xmask = xindex < xnumel
22
+ x4 = xindex
23
+ x2 = ((xindex // ks0) % ks1)
24
+ x0 = (xindex % ks3)
25
+ x5 = xindex // ks3
26
+ tmp0 = tl.load(in_ptr0 + (x4), xmask, eviction_policy='evict_last').to(tl.float32)
27
+ tmp1 = tl.load(in_ptr1 + (x2), xmask, eviction_policy='evict_last')
28
+ tmp2 = ks2
29
+ tmp3 = tmp1 + tmp2
30
+ tmp4 = tmp1 < 0
31
+ tmp5 = tl.where(tmp4, tmp3, tmp1)
32
+ tl.device_assert(((0 <= tmp5) & (tmp5 < ks2)) | ~(xmask), "index out of bounds: 0 <= tmp5 < ks2")
33
+ tmp7 = tl.load(in_ptr2 + (x0 + ks3*tmp5), xmask, eviction_policy='evict_last').to(tl.float32)
34
+ tmp8 = tmp0 * tmp7
35
+ tmp9 = x0
36
+ tmp10 = tl.full([1], 0, tl.int64)
37
+ tmp11 = tmp9 >= tmp10
38
+ tmp12 = ks3 + (-1)*(ks3 // 2)
39
+ tmp13 = tmp9 < tmp12
40
+ tmp14 = tl.load(in_ptr0 + (ks3*x5 + (ks3 // 2) + (x0)), tmp13 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
41
+ tmp15 = -tmp14
42
+ tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
43
+ tmp17 = tl.where(tmp13, tmp15, tmp16)
44
+ tmp18 = tmp9 >= tmp12
45
+ tmp19 = ks3
46
+ tmp20 = tmp9 < tmp19
47
+ tmp21 = tl.load(in_ptr0 + (ks3*x5 + (x0 + ((-1)*ks3) + (ks3 // 2))), tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
48
+ tmp22 = tl.where(tmp13, tmp17, tmp21)
49
+ tmp23 = ks4
50
+ tmp24 = tmp1 + tmp23
51
+ tmp25 = tl.where(tmp4, tmp24, tmp1)
52
+ tl.device_assert(((0 <= tmp25) & (tmp25 < ks4)) | ~(xmask), "index out of bounds: 0 <= tmp25 < ks4")
53
+ tmp27 = tl.load(in_ptr3 + (x0 + ks3*tmp25), xmask, eviction_policy='evict_last').to(tl.float32)
54
+ tmp28 = tmp22 * tmp27
55
+ tmp29 = tmp8 + tmp28
56
+ tl.store(out_ptr0 + (x4), tmp29, xmask)
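This kernel is the forward counterpart of the RoPE backward sketched earlier: a gather of the cos/sin tables by position id fused with `x * cos + rotate_half(x) * sin`. Roughly, in eager terms (even head dim assumed; which of the two gathered tables is cos and which is sin is an assumption read off the multiply order):

import torch

def rotate_half(x):
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

def apply_rope(x, cos, sin, position_ids):
    cos = cos[position_ids].unsqueeze(1)   # (1, 1, S, D)
    sin = sin[position_ids].unsqueeze(1)
    return x * cos + rotate_half(x) * sin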
SpecForge-ext/cache/compiled_kernels/qa/cqambnamuby4hynvyzhccuoc4f5nkvwpn7yeizvaaaojnmlep42d.py ADDED
@@ -0,0 +1,62 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 4096, 'r0_': 4096},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*bf16', 'in_ptr1': 'fp64', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mean_mul_pow_rsqrt_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': False, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused__to_copy_mean_mul_pow_rsqrt_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ rnumel = r0_numel
20
+ RBLOCK: tl.constexpr = R0_BLOCK
21
+ xoffset = tl.program_id(0) * XBLOCK
22
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
23
+ xmask = xindex < xnumel
24
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
25
+ rbase = r0_base
26
+ x0 = xindex
27
+ _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
28
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
29
+ r0_index = r0_offset + r0_base
30
+ r0_mask = r0_index < r0_numel
31
+ roffset = r0_offset
32
+ rindex = r0_index
33
+ r0_1 = r0_index
34
+ tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
35
+ tmp1 = tmp0.to(tl.float32)
36
+ tmp2 = tmp1 * tmp1
37
+ tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
38
+ tmp5 = _tmp4 + tmp3
39
+ _tmp4 = tl.where(r0_mask & xmask, tmp5, _tmp4)
40
+ tmp4 = tl.sum(_tmp4, 1)[:, None]
41
+ tmp9 = in_ptr1
42
+ tmp6 = ks0
43
+ tmp7 = tmp6.to(tl.float32)
44
+ tmp8 = (tmp4 / tmp7)
45
+ tmp10 = tmp9.to(tl.float32)
46
+ tmp11 = tmp8 + tmp10
47
+ tmp12 = libdevice.rsqrt(tmp11)
48
+ tl.debug_barrier()
49
+ tl.store(in_out_ptr0 + (x0), tmp12, xmask)
50
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
51
+ r0_index = r0_offset + r0_base
52
+ r0_mask = r0_index < r0_numel
53
+ roffset = r0_offset
54
+ rindex = r0_index
55
+ r0_1 = r0_index
56
+ tmp13 = tl.load(in_ptr2 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
57
+ tmp14 = tl.load(in_ptr0 + (r0_1 + ks0*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
58
+ tmp15 = tmp14.to(tl.float32)
59
+ tmp16 = tmp15 * tmp12
60
+ tmp17 = tmp16.to(tl.float32)
61
+ tmp18 = tmp13 * tmp17
62
+ tl.store(out_ptr0 + (r0_1 + ks0*x0), tmp18, r0_mask & xmask)
SpecForge-ext/cache/compiled_kernels/qa/cqasclcikvb2uryr7k2gtwdnliae55wql22q6kutfmldlk5e7kks.py ADDED
@@ -0,0 +1,41 @@
1
+
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
6
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
7
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
8
+ triton_helpers.set_driver_to_gpu()
9
+
10
+ @triton_heuristics.reduction(
11
+ size_hints={'x': 2, 'r0_': 8192},
12
+ reduction_hint=ReductionHint.INNER,
13
+ filename=__file__,
14
+ triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=7, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
15
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 8, 'r0_': 131072}}
16
+ )
17
+ @triton.jit
18
+ def triton_red_fused_sum_3(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
19
+ xnumel = 2
20
+ r0_numel = 8192
21
+ rnumel = r0_numel
22
+ RBLOCK: tl.constexpr = R0_BLOCK
23
+ xoffset = tl.program_id(0) * XBLOCK
24
+ xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
25
+ xmask = xindex < xnumel
26
+ r0_base = tl.arange(0, R0_BLOCK)[None, :]
27
+ rbase = r0_base
28
+ x0 = xindex
29
+ _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
30
+ for r0_offset in range(0, r0_numel, R0_BLOCK):
31
+ r0_index = r0_offset + r0_base
32
+ r0_mask = r0_index < r0_numel
33
+ roffset = r0_offset
34
+ rindex = r0_index
35
+ r0_1 = r0_index
36
+ tmp0 = tl.load(in_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
37
+ tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
38
+ tmp3 = _tmp2 + tmp1
39
+ _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2)
40
+ tmp2 = tl.sum(_tmp2, 1)[:, None]
41
+ tl.store(out_ptr0 + (x0), tmp2, xmask)
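`triton_red_fused_sum_3` is a plain blocked row reduction: each of the two program instances accumulates an int64 sum over its 8192-element row in `R0_BLOCK`-sized chunks. Eager equivalent (random input as a stand-in):

import torch

x = torch.randint(0, 10, (2, 8192), dtype=torch.int64, device='cuda')
row_sum = x.sum(dim=1)   # one kernel program per row; same int64 accumulator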
SpecForge-ext/cache/compiled_kernels/qd/cqd6lffrumnqrtflwfoqtqs6mvn23l4bxialovx3yvqgximtpflz.py ADDED
@@ -0,0 +1,66 @@
 
+
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.pointwise(
+     size_hints={'x': 16777216},
+     filename=__file__,
+     triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'ks2': 'i64', 'ks3': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
+     inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
+     min_elem_per_thread=0
+ )
+ @triton.jit
+ def triton_poi_fused_add_index_mul_neg_slice_slice_backward_squeeze_unsqueeze_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, ks0, ks1, ks2, ks3, xnumel, XBLOCK : tl.constexpr):
+     xoffset = tl.program_id(0) * XBLOCK
+     xindex = xoffset + tl.arange(0, XBLOCK)[:]
+     xmask = xindex < xnumel
+     x0 = (xindex % ks0)
+     x3 = xindex
+     x1 = ((xindex // ks0) % ks1)
+     tmp31 = tl.load(in_ptr0 + (x3), xmask, eviction_policy='evict_last').to(tl.float32)
+     tmp32 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last')
+     tmp0 = x0
+     tmp1 = ks0 // 2
+     tmp2 = tmp0 >= tmp1
+     tmp3 = tl.load(in_ptr0 + (x3 + (-1)*(ks0 // 2)), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
+     tmp4 = tl.load(in_ptr1 + (x1), tmp2 & xmask, eviction_policy='evict_last', other=0.0)
+     tmp5 = tl.broadcast_to(ks2, [XBLOCK])
+     tmp6 = tmp4 + tmp5
+     tmp7 = tmp4 < 0
+     tmp8 = tl.where(tmp7, tmp6, tmp4)
+     tl.device_assert(((0 <= tl.broadcast_to(tmp8, [XBLOCK])) & (tl.broadcast_to(tmp8, [XBLOCK]) < ks2)) | ~(tmp2 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp8, [XBLOCK]) < ks2")
+     tmp10 = tl.load(in_ptr2 + (x0 + (-1)*(ks0 // 2) + ks0*tmp8), tmp2 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
+     tmp11 = tmp3 * tmp10
+     tmp12 = -tmp11
+     tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
+     tmp14 = tl.where(tmp2, tmp12, tmp13)
+     tmp15 = 0.0
+     tmp16 = tl.where(tmp2, tmp14, tmp15)
+     tmp17 = tmp0 < tmp1
+     tmp18 = tl.load(in_ptr0 + (ks0 + x3 + (-1)*(ks0 // 2)), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
+     tmp19 = tl.load(in_ptr1 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0)
+     tmp20 = tl.broadcast_to(ks2, [XBLOCK])
+     tmp21 = tmp19 + tmp20
+     tmp22 = tmp19 < 0
+     tmp23 = tl.where(tmp22, tmp21, tmp19)
+     tl.device_assert(((0 <= tl.broadcast_to(tmp23, [XBLOCK])) & (tl.broadcast_to(tmp23, [XBLOCK]) < ks2)) | ~(tmp17 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp23, [XBLOCK]) < ks2")
+     tmp25 = tl.load(in_ptr2 + (ks0 + x0 + (-1)*(ks0 // 2) + ks0*tmp23), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
+     tmp26 = tmp18 * tmp25
+     tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype)
+     tmp28 = tl.where(tmp17, tmp26, tmp27)
+     tmp29 = tl.where(tmp17, tmp28, tmp15)
+     tmp30 = tmp16 + tmp29
+     tmp33 = ks3
+     tmp34 = tmp32 + tmp33
+     tmp35 = tmp32 < 0
+     tmp36 = tl.where(tmp35, tmp34, tmp32)
+     tl.device_assert(((0 <= tmp36) & (tmp36 < ks3)) | ~(xmask), "index out of bounds: 0 <= tmp36 < ks3")
+     tmp38 = tl.load(in_ptr3 + (x0 + ks0*tmp36), xmask, eviction_policy='evict_last').to(tl.float32)
+     tmp39 = tmp31 * tmp38
+     tmp40 = tmp30 + tmp39
+     tl.store(out_ptr0 + (x3), tmp40, xmask)
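
Note: despite the long auto-generated name, the add/index/mul/neg/slice_backward pattern in this pointwise kernel is consistent with the input gradient of a rotate-half rotary position embedding: in_ptr0 looks like the upstream gradient, in_ptr1 holds int64 position ids (range-checked by the tl.device_assert calls), in_ptr2 and in_ptr3 look like sin and cos tables gathered at those positions, and ks0 is the (even) rotary dimension. The following is a hedged eager-mode reading under those assumptions, not a verified decompilation; every name is hypothetical:

import torch

def rotate_half(x):
    # Split the last dim in halves and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rope_input_grad(grad_out, cos_table, sin_table, position_ids):
    # If the forward was q_embed = q * cos[pos] + rotate_half(q) * sin[pos],
    # then d(loss)/dq = grad * cos[pos] - rotate_half(grad * sin[pos]),
    # which matches the two half-offset loads, the negation on the upper
    # half, and the final grad * cos add in the kernel body.
    cos = cos_table[position_ids]  # gather, as the guarded tl.load does
    sin = sin_table[position_ids]
    return grad_out * cos - rotate_half(grad_out * sin)

Assuming grad_out broadcasts against the gathered tables, this reproduces the per-element arithmetic; the kernel simply flattens all of it into one 1D grid over xnumel elements.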
SpecForge-ext/cache/compiled_kernels/qd/cqd7l2ktsaxhv4w2pgoiwvrihj6ya2rmzfvnjybryke4aa6nwpjp.py ADDED
@@ -0,0 +1,62 @@
+
+ import triton
+ import triton.language as tl
+
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
+ triton_helpers.set_driver_to_gpu()
+
+ @triton_heuristics.reduction(
+     size_hints={'x': 1, 'r0_': 4096},
+     reduction_hint=ReductionHint.INNER,
+     filename=__file__,
+     triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*i64', 'in_ptr3': '*i64', 'out_ptr2': '*fp32', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=1, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
+     inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': 'B0E5936CA26D1BCD1B577D0B65F13FE7553DC941EB93358973E9F1902BC212C5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'r0_': 131072}}
+ )
+ @triton.jit
+ def triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
+     xnumel = 1
+     r0_numel = 4096
+     rnumel = r0_numel
+     RBLOCK: tl.constexpr = R0_BLOCK
+     xoffset = tl.program_id(0) * XBLOCK
+     xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
+     xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
+     r0_base = tl.arange(0, R0_BLOCK)[None, :]
+     rbase = r0_base
+     _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
+     for r0_offset in range(0, r0_numel, R0_BLOCK):
+         r0_index = r0_offset + r0_base
+         r0_mask = r0_index < r0_numel
+         roffset = r0_offset
+         rindex = r0_index
+         r0_0 = r0_index
+         tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
+         tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
+         tmp4 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
+         tmp2 = tmp0 == tmp1
+         tmp3 = tmp2.to(tl.int64)
+         tmp5 = tmp3 * tmp4
+         tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK])
+         tmp8 = _tmp7 + tmp6
+         _tmp7 = tl.where(r0_mask, tmp8, _tmp7)
+     tmp7 = tl.sum(_tmp7, 1)[:, None]
+     _tmp11 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
+     for r0_offset in range(0, r0_numel, R0_BLOCK):
+         r0_index = r0_offset + r0_base
+         r0_mask = r0_index < r0_numel
+         roffset = r0_offset
+         rindex = r0_index
+         r0_0 = r0_index
+         tmp9 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0)
+         tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
+         tmp12 = _tmp11 + tmp10
+         _tmp11 = tl.where(r0_mask, tmp12, _tmp11)
+     tmp11 = tl.sum(_tmp11, 1)[:, None]
+     tmp13 = tmp7.to(tl.float32)
+     tmp14 = tmp11.to(tl.float32)
+     tmp15 = 1e-06
+     tmp16 = triton_helpers.maximum(tmp14, tmp15)
+     tmp17 = (tmp13 / tmp16)
+     tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp17, None)
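
Note: triton_red_fused_clamp_min_div_eq_mul_squeeze_sum_2 reduces the whole problem in a single program (xnumel = 1): the first loop accumulates sum((in_ptr0 == in_ptr1) * in_ptr2) over the int64 inputs, the second accumulates sum(in_ptr3), and the epilogue divides the two in float32 with the denominator clamped to at least 1e-06. That is the shape of a masked match-rate metric. A hedged eager-mode sketch with hypothetical names:

import torch

def masked_match_rate(pred_ids, target_ids, weight, denom, eps=1e-6):
    # Weighted count of positions where two int64 id tensors agree,
    # mirroring the kernel's eq -> int64 -> mul -> sum loop.
    num = ((pred_ids == target_ids).to(torch.int64) * weight).sum().float()
    # Normalizer with the same clamp_min(1e-06) guard against division
    # by zero that the kernel applies via triton_helpers.maximum.
    den = denom.sum().float().clamp_min(eps)
    return num / den

In a speculative-decoding training setup like SpecForge this pattern plausibly scores how often draft tokens match target tokens under a mask, though the cached kernel itself records no such intent.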