F6: chunked multi-segment dispatch produces non-deterministic silent corruption at production scale
Summary
Following up on F1 (https://huggingface.co/kernels-community/punica-sgmv/discussions/1) and F2 (https://huggingface.co/kernels-community/punica-sgmv/discussions/2): the F2 mitigation (cap each SGMV segment at MAX_SAFE_SGMV_SEGMENT = 128 tokens and chunk longer flat-batches into multiple equivalent segments) is necessary but not sufficient at d=4096 production scale. Even when every chunk is at most 128 tokens, multi-segment dispatch produces non-deterministic, data-dependent silent numerical corruption against the PEFT-factored bf16 reference.
This is a Punica F2-class regression at the multi-segment dispatch boundary, distinct from the single-segment length boundary that F2 itself describes.
Environment
- Kernel repo snapshot: kernels-community/punica-sgmv at commit be89a97bbd04562e2834d5c8f0e9342dc7ae7715
- Build variant resolved: build/torch29-cxx11-cu130-x86_64-linux/
- Build sha256: a3b5c0a1513c6721f082de7986d3525506e96d0c901d9925a94255635d8386f4
- PyTorch: 2.9.1+cu130
- CUDA runtime: 13.0
- Python: 3.11.15
- kernels runtime: 0.13.0
- peft: 0.19.0
- GPU: NVIDIA RTX PRO 6000 Blackwell Server Edition, SM 12.0
  - Listed in the kernel metadata.json backend.archs set, so the build variant is the intended one for this target.
Repro
A flat-batch input of T=361 tokens is dispatched as 3 SGMV segments of [128, 128, 105] (or 6 segments at B=2). Each individual segment is at or below the F2-safe ceiling of 128 tokens. The dispatch wrapper allocates the SGMV segment-pointer table for the chunked layout, calls add_lora_sgmv_cutlass once on the multi-segment input, and compares against the PEFT-equivalent factored bf16 reference ((x @ A.T) @ B.T) * scaling.
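The segment layout quoted above follows from simple arithmetic. A minimal sketch, assuming a greedy left-to-right chunker; the helper name chunk_bounds is hypothetical and not part of the kernel API. Whether B=2 is chunked per sequence or over the flat 722-token batch is not stated here, so both are shown as assumptions:

```python
def chunk_bounds(n_tokens, max_seg=128, offset=0):
    """Greedy (start, end) pairs covering [offset, offset + n_tokens)."""
    bounds = []
    pos = 0
    while pos < n_tokens:
        end = min(pos + max_seg, n_tokens)
        bounds.append((offset + pos, offset + end))
        pos = end
    return bounds


# B=1: one 361-token sequence.
b1 = chunk_bounds(361)
b1_lengths = [end - start for start, end in b1]

# B=2, assumption (a): each 361-token sequence is chunked on its own.
b2_per_seq = chunk_bounds(361) + chunk_bounds(361, offset=361)

# B=2, assumption (b): the flat 722-token batch is chunked as a whole.
b2_flat = chunk_bounds(722)

print(b1, b1_lengths)
print(len(b2_per_seq), len(b2_flat))
```

Under either assumption the B=2 case comes out at 6 segments, none longer than 128 tokens.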
Minimal repro shape: a single nn.Linear(4096, 1024) at v_proj with bf16 weights and a LoRA factor pair at rank=16, dispatched via SGMV with multi-segment input T=361 chunked at 128.
```python
import torch
from kernels import get_kernel

kernel = get_kernel("kernels-community/punica-sgmv")
device = "cuda"
dtype = torch.bfloat16
in_features, out_features, rank = 4096, 1024, 16
scaling = 32.0 / float(rank)
T = 361
MAX_SAFE = 128  # F2 per-segment ceiling


def make_factors(seed):
    # LoRA factor pair drawn on CPU for reproducibility, then moved to the GPU.
    g = torch.Generator(device="cpu").manual_seed(seed)
    A = (torch.randn(rank, in_features, generator=g) * 0.02).to(device, dtype)
    B = (torch.randn(out_features, rank, generator=g) * 0.02).to(device, dtype)
    return A, B


def factored_ref(x, A, B):
    # PEFT-equivalent factored bf16 reference.
    return ((x @ A.T) @ B.T) * scaling


def sgmv_chunked(x, A, B):
    n_tokens = x.shape[0]
    wa = A.unsqueeze(0).contiguous()
    wb = B.transpose(0, 1).unsqueeze(0).contiguous()
    y = torch.zeros(n_tokens, out_features, device=device, dtype=dtype)

    # Split the flat batch into segments of at most MAX_SAFE tokens.
    starts, ends = [], []
    pos = 0
    while pos < n_tokens:
        end = min(pos + MAX_SAFE, n_tokens)
        starts.append(pos)
        ends.append(end)
        pos = end
    n_seg = len(starts)

    # Every segment points at the same LoRA factor pair.
    wa_ptr = torch.tensor([wa.data_ptr()] * n_seg, dtype=torch.int64, device=device)
    wb_ptr = torch.tensor([wb.data_ptr()] * n_seg, dtype=torch.int64, device=device)
    s_start = torch.tensor(starts, dtype=torch.int32, device=device)
    s_end = torch.tensor(ends, dtype=torch.int32, device=device)

    # One dispatch covering all segments.
    kernel.add_lora_sgmv_cutlass(y, x, wa_ptr, wb_ptr, s_start, s_end, 0, rank)
    return y * scaling


for seed in [12345, 31415, 20260426, 271828, 1618033]:
    A, B = make_factors(seed)
    g = torch.Generator(device="cpu").manual_seed(seed + 1)
    x = torch.randn(T, in_features, generator=g).to(device, dtype).contiguous()
    y_kernel = sgmv_chunked(x, A, B)
    y_ref = factored_ref(x, A, B)
    err = (y_kernel.float() - y_ref.float()).abs().max().item()
    print(f"seed={seed:>10} T={T} segs=3 max_abs_err={err:.4e}")
```
Expected behavior
Under F2's documented mitigation, every segment is at most 128 tokens, so each segment's contribution should be bit-correct. The aggregated multi-segment output should match the PEFT factored bf16 reference within the bf16 GEMM accumulation tolerance (1e-2).
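The expectation of per-segment equivalence rests on the fact that each output row of the factored product depends only on the matching input row, so splitting the rows into segments should not change any value. A toy pure-Python sketch of that argument, with tiny hypothetical dimensions, plain floats instead of bf16, and no kernel involved:

```python
import random


def lora_rows(x_rows, A, B, scaling):
    """y = ((x @ A.T) @ B.T) * scaling, computed one row at a time."""
    out = []
    for x in x_rows:
        # Shrink: project the row down to the LoRA rank.
        h = [sum(xi * ai for xi, ai in zip(x, a_row)) for a_row in A]
        # Expand: project back up to the output width and scale.
        y = [sum(hi * bi for hi, bi in zip(h, b_row)) * scaling for b_row in B]
        out.append(y)
    return out


rng = random.Random(0)
d_in, d_out, rank, n_tokens, max_seg = 8, 4, 2, 11, 4  # toy sizes, not the repro shapes
A = [[rng.gauss(0, 0.02) for _ in range(d_in)] for _ in range(rank)]   # (rank, d_in)
B = [[rng.gauss(0, 0.02) for _ in range(rank)] for _ in range(d_out)]  # (d_out, rank)
x = [[rng.gauss(0, 1) for _ in range(d_in)] for _ in range(n_tokens)]
scaling = 32.0 / 16  # same value as in the repro

full = lora_rows(x, A, B, scaling)

# Process the same rows in segments of at most max_seg and concatenate.
chunked = []
for start in range(0, n_tokens, max_seg):
    chunked.extend(lora_rows(x[start:start + max_seg], A, B, scaling))

print(chunked == full)
```

A GPU kernel may tile and accumulate differently per segment, which is why the comparison against the reference uses a tolerance rather than exact equality.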
Actual behavior (measured at production scale, d=4096)
A 200-iteration microbench at production Qwen3-8B target shapes ran on the SM 12.0 RTX PRO 6000 Blackwell. Of the 14 prefill cells (T=361 across 7 target shapes at B=1 and B=2, all chunked to segments of 128 tokens or fewer), 13 fail the bf16 1e-2 numerical tolerance against the PEFT factored reference. Errors in the failing cells range from 1.56e-2 (just over the threshold) to 6.49e-1 (about 65 percent of the output magnitude, total corruption).
| B | T | shape | segs | bf16 max abs err | verdict |
|---|---|---|---|---|---|
| 1 | 361 | q_proj | 3 | 1.56e-2 | FAIL |
| 1 | 361 | k_proj | 3 | 1.56e-2 | FAIL |
| 1 | 361 | v_proj | 3 | 3.48e-1 | FAIL |
| 1 | 361 | o_proj | 3 | 1.56e-2 | FAIL |
| 1 | 361 | gate_proj | 3 | 1.56e-2 | FAIL |
| 1 | 361 | up_proj | 3 | 1.56e-2 | FAIL |
| 1 | 361 | down_proj | 3 | 6.49e-1 | FAIL |
| 2 | 361 | q_proj | 6 | 3.33e-1 | FAIL |
| 2 | 361 | k_proj | 6 | 7.81e-3 | PASS |
| 2 | 361 | v_proj | 6 | 2.74e-1 | FAIL |
| 2 | 361 | o_proj | 6 | 1.56e-2 | FAIL |
| 2 | 361 | gate_proj | 6 | 4.04e-1 | FAIL |
| 2 | 361 | up_proj | 6 | 4.98e-1 | FAIL |
| 2 | 361 | down_proj | 6 | 4.98e-1 | FAIL |
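As a cross-check, the verdict column can be recomputed from the error column. The values below are transcribed from the table above, and the threshold is the 1e-2 tolerance stated under Expected behavior:

```python
TOL = 1e-2

# (B, shape) -> bf16 max abs err, transcribed from the table.
cells = {
    (1, "q_proj"): 1.56e-2, (1, "k_proj"): 1.56e-2, (1, "v_proj"): 3.48e-1,
    (1, "o_proj"): 1.56e-2, (1, "gate_proj"): 1.56e-2, (1, "up_proj"): 1.56e-2,
    (1, "down_proj"): 6.49e-1,
    (2, "q_proj"): 3.33e-1, (2, "k_proj"): 7.81e-3, (2, "v_proj"): 2.74e-1,
    (2, "o_proj"): 1.56e-2, (2, "gate_proj"): 4.04e-1, (2, "up_proj"): 4.98e-1,
    (2, "down_proj"): 4.98e-1,
}

failed = {cell: err for cell, err in cells.items() if err > TOL}
passed = [cell for cell, err in cells.items() if err <= TOL]

print(f"{len(failed)} of {len(cells)} cells fail")
print(f"failing range: {min(failed.values()):.2e} to {max(failed.values()):.2e}")
print(f"passing cells: {passed}")
```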
The 21 decode cells (T=1, single-segment dispatch) all pass the bf16 1e-2 tolerance, with errors ranging from 0.0 (bit-equivalent) to 7.8e-3. Single-segment dispatch with at most 128 tokens remains numerically clean, consistent with F2.
The corruption is non-deterministic across seeds. At v_proj, B=1, T=361, with 3 chunked segments of [128, 128, 105]:
| seed | max abs err |
|---|---|
| 12345 | 7.81e-3 (PASS) |
| 31415 | 7.81e-3 (PASS) |
| 20260426 | 4.12e-1 (FAIL) |
| 271828 | 1.81e-1 (FAIL) |
| 1618033 | 2.30e-1 (FAIL) |
3 of 5 spot-check seeds fail; 2 pass. The pattern is data-dependent on the (LoRA factor, input) draw, not categorical on the chunked layout itself.
Why this is an F2-class regression beyond F2
F2 documents that add_lora_sgmv_cutlass produces silently incorrect results when a single segment exceeds 144 tokens, and proposes the workaround MAX_SAFE_SGMV_SEGMENT = 128. F6 finds that even with that workaround applied, the multi-segment dispatch boundary itself produces silently incorrect results at d=4096 with bf16 inputs. The bug is not tied to per-segment length; it appears in the cross-segment accumulation, when more than one segment is present in a single dispatch.
A consumer following F2's documented mitigation in good faith and verifying it on a single-seed canary could miss F6: 2 of 5 seeds at the same shape pass cleanly. A production deployment that relies on SGMV for prefill paths would emit silently wrong outputs at random under load.
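Because a single seed can pass by chance, a canary for this path would need to sweep several seeds and flag any that exceed tolerance. A minimal sketch of that idea: run_case is a hypothetical callable returning the max abs error for one seed (for instance, a wrapper around the repro loop), and the stub below simply replays the v_proj spot-check values reported above rather than running the kernel:

```python
def failing_seeds(run_case, seeds, tol=1e-2):
    """Return [(seed, err)] for every seed whose max abs error exceeds tol."""
    results = [(seed, run_case(seed)) for seed in seeds]
    return [(seed, err) for seed, err in results if err > tol]


# Stub standing in for a real kernel run: errors from the seed table.
reported = {
    12345: 7.81e-3,
    31415: 7.81e-3,
    20260426: 4.12e-1,
    271828: 1.81e-1,
    1618033: 2.30e-1,
}

single = failing_seeds(reported.__getitem__, [12345])   # single-seed canary
sweep = failing_seeds(reported.__getitem__, list(reported))  # five-seed sweep

print("single-seed failures:", single)
print("sweep failures:", [seed for seed, _ in sweep])
```

With these values the single-seed canary reports nothing, while the sweep flags the three failing seeds.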
Suggested fix direction
Investigate add_lora_sgmv_cutlass segment-descriptor handling for multi-segment inputs. Plausible failure sites in the CUTLASS dispatch:
- Segment-pointer table addressing. When the kernel reads s_start[i] and s_end[i] for segment i, does it correctly bound the per-segment GEMM, or is the residual from segment i-1 leaking into segment i's output rows? The 112-byte boundary in F2's tile analysis would correspond to a partial-tile residual write.
- Cross-segment accumulation in the expand kernel. SGMV's two-step (shrink to rank, expand to out_features) requires the rank-R intermediate buffer to be partitioned per segment. If the partition pointers are computed once for the kernel launch but the expand step reads from a shared rank-R slab, the second and third segments could read partially-overwritten intermediate values.
- Data-dependent shared memory bank conflicts at d=4096. The single-segment d=512 case in the F2 repro is bit-correct; the multi-segment d=4096 case here is not. The shared-memory layout for the rank-16 reduction at out_features=1024 (v_proj) and 4096 (gate_proj/up_proj/down_proj) at production dimensions may exercise a bank-conflict pattern that the d=512 staging case does not.
We can share intermediate dumps from the broken cases (chunked input, segment-pointer table contents, kernel output, factored reference) if helpful for narrowing the failure site.
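One way to narrow the first hypothesis is to localize the error by segment: if the rows of segment 0 are clean and later segments are not, that would point at cross-segment state instead of per-segment length. A small sketch, assuming a list of per-row max abs errors has already been computed from the kernel output and the reference; the helper name and the synthetic numbers are illustrative only, not measured data:

```python
def per_segment_max_err(row_errs, starts, ends):
    """Largest per-row error inside each [start, end) segment."""
    return [max(row_errs[s:e]) for s, e in zip(starts, ends)]


# Segment table for T=361 chunked at 128.
starts, ends = [0, 128, 256], [128, 256, 361]

# Synthetic example: clean rows everywhere except one row in segment 1.
row_errs = [7.81e-3] * 361
row_errs[130] = 3.0e-1

seg_errs = per_segment_max_err(row_errs, starts, ends)
bad_segments = [i for i, err in enumerate(seg_errs) if err > 1e-2]

print(seg_errs)
print("segments over tolerance:", bad_segments)
```

With the repro tensors, the per-row errors could be obtained with (y_kernel.float() - y_ref.float()).abs().amax(dim=1).tolist().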
Repro artifacts
- Microbench JSON (200 iterations per cell, 35 cells, all timing and numerical-equivalence data): available on request as punica_sgmv_microbench_prod_1777207023.json.
- Microbench protocol: 20 warmup iterations, 200 timed iterations interleaved across factored / sgmv / merged modes round-robin per iteration, cuda-event timing with torch.cuda.synchronize() at trial boundaries.
- Reproducibility: random seeds set, per-cell torch.Generator for layer init and input draws, kernel snapshot sha256 captured.
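The interleaving in the protocol can be sketched as follows. This illustrates the loop structure only: it uses time.perf_counter on the host in place of CUDA events so that it runs anywhere, and the three mode callables are placeholders, not the actual benchmark code.

```python
import time


def interleaved_bench(modes, warmup=20, iters=200):
    """Round-robin timing: each iteration runs every mode once, in order."""
    for _ in range(warmup):
        for fn in modes.values():
            fn()

    timings = {name: [] for name in modes}
    for _ in range(iters):
        for name, fn in modes.items():
            t0 = time.perf_counter()
            fn()
            # A GPU version would record CUDA events around fn() and
            # synchronize at the trial boundary before reading the time.
            timings[name].append(time.perf_counter() - t0)
    return timings


# Placeholder workloads standing in for the factored / sgmv / merged paths.
modes = {
    "factored": lambda: sum(range(100)),
    "sgmv": lambda: sum(range(100)),
    "merged": lambda: sum(range(100)),
}

timings = interleaved_bench(modes)
print({name: len(samples) for name, samples in timings.items()})
```

Running every mode once per iteration means slow drift in clocks or temperature tends to affect all three modes similarly, which is the usual reason for interleaving over running each mode in its own block.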
Snapshot pinning
Reproduced against commit be89a97bbd04562e2834d5c8f0e9342dc7ae7715, build variant torch29-cxx11-cu130-x86_64-linux, on SM 12.0. Confirmation that this is also reproducible on main would help us know when to remove the chunking workaround entirely. F6 closure plus the F2 closure together would let us drop the MAX_SAFE_SGMV_SEGMENT cap.