TriAttention: RTX 5080 (SM 120) Verification Plan

Goal: Confirm that TriAttention runs correctly end-to-end on an NVIDIA RTX 5080 (compute capability 12.0 / Blackwell).

Background

TriAttention compresses KV cache for long-reasoning LLMs using frequency-domain scoring. Its GPU compute surface consists of:

| Layer | Technology | SM 120 expectation |
| --- | --- | --- |
| Core scoring kernel | Triton JIT (triton_scoring.py) | Triton 3.2+ compiles PTX at runtime for the detected GPU; no hardcoded arch targets |
| All other tensor ops | PyTorch (torch.topk, gather, index_select, …) | PyTorch ≥ 2.6 with CUDA ≥ 12.8 supports SM 120 |
| Optional attention | FlashAttention (flash-attn pip package) | flash-attn setup.py includes sm_120 in default arch list; requires CUDA ≥ 12.8 to compile from source |
| Attention fallback | sdpa / eager (pure PyTorch) | No arch restriction |
| vLLM integration | Plugin monkeypatches scheduler + worker | Depends on vLLM wheel having SM 120 support; DGX Spark (SM 121, same Blackwell family) already merged |

There are no custom CUDA C/C++ kernels (.cu / .cuh) in the repository. The only GPU-compiled code is the Triton scoring kernel, which uses standard tl.* operations (load, store, cos, sin, sqrt, sum, maximum), with no inline PTX or warp-level intrinsics.

Support for the neighboring SM 121 (DGX Spark / GB10) has been community-tested and was merged into the repo as of 2026-04-14.


Phase 1: Environment Validation

1.1 Confirm hardware and driver

nvidia-smi --query-gpu=name,compute_cap,driver_version --format=csv
# Expected: RTX 5080, 12.0, driver ≥ 570.x
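The same check can be scripted so that a mismatch is reported explicitly instead of being read off the terminal. The sketch below parses the CSV text that the command above prints. The helper names `parse_gpu_csv` and `check_gpu` and the sample text are illustrative assumptions (the sample is hand-written, not captured from real hardware), and it assumes a header row followed by one comma-separated row per GPU.

```python
import csv
import io


def parse_gpu_csv(text):
    """Parse `nvidia-smi --query-gpu=... --format=csv` output into dicts."""
    reader = csv.reader(io.StringIO(text.strip()), skipinitialspace=True)
    rows = [[cell.strip() for cell in row] for row in reader if row]
    header, *gpus = rows
    return [dict(zip(header, gpu)) for gpu in gpus]


def check_gpu(info, expected_cap="12.0", min_driver_major=570):
    """Return a list of problems; an empty list means the GPU passes."""
    problems = []
    if info["compute_cap"] != expected_cap:
        problems.append(
            f"compute_cap is {info['compute_cap']}, expected {expected_cap}"
        )
    driver_major = int(info["driver_version"].split(".")[0])
    if driver_major < min_driver_major:
        problems.append(
            f"driver {info['driver_version']} is older than {min_driver_major}.x"
        )
    return problems


# Hand-written sample text (an assumption, not captured output).
sample = "name, compute_cap, driver_version\nNVIDIA GeForce RTX 5080, 12.0, 570.86\n"
for gpu in parse_gpu_csv(sample):
    print(gpu["name"], check_gpu(gpu) or "OK")
```

On a real machine the text would come from running the command, for example via `subprocess.run([...], capture_output=True, text=True).stdout`.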

1.2 Confirm CUDA toolkit

nvcc --version
# Must be ≥ 12.8
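If the toolkit version needs to be checked by a script, comparing it as a tuple of integers avoids the trap where the string "12.10" sorts before "12.8". This is a minimal sketch; `parse_nvcc_release` is a hypothetical helper, and the sample line is hand-written in the usual `release X.Y` format rather than captured output.

```python
import re


def parse_nvcc_release(text):
    """Extract (major, minor) from the 'release X.Y' token in nvcc output."""
    match = re.search(r"release (\d+)\.(\d+)", text)
    if match is None:
        raise ValueError("no 'release X.Y' token found in nvcc output")
    return int(match.group(1)), int(match.group(2))


# Hand-written sample line (an assumption, not captured output).
sample = "Cuda compilation tools, release 12.8, V12.8.61"
release = parse_nvcc_release(sample)
print("CUDA toolkit:", release, "OK" if release >= (12, 8) else "TOO OLD")
```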

1.3 Confirm PyTorch sees the GPU

import torch

print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("CUDA version:", torch.version.cuda)
print("Device name:", torch.cuda.get_device_name(0))

cap = torch.cuda.get_device_capability(0)
print(f"Compute capability: {cap[0]}.{cap[1]}")  # Expect 12.0

print("bf16 support:", torch.cuda.is_bf16_supported())

# Smoke test
a = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
b = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
c = a @ b
print("Matmul smoke test passed, output device:", c.device)

1.4 Confirm Triton compiles for SM 120

import triton
import triton.language as tl
import torch

print("Triton version:", triton.__version__)

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n: tl.constexpr):
    offs = tl.arange(0, n)
    x = tl.load(x_ptr + offs)
    y = tl.load(y_ptr + offs)
    tl.store(out_ptr + offs, x + y)

x = torch.randn(128, device="cuda")
y = torch.randn(128, device="cuda")
out = torch.empty(128, device="cuda")
add_kernel[(1,)](x, y, out, 128)
assert torch.allclose(out, x + y, atol=1e-5)
print("Triton JIT compilation + execution on SM 120: PASSED")

Phase 2: TriAttention Installation

2.1 Install the package

git clone https://github.com/WeianMao/triattention.git
cd triattention
pip install -e .

2.2 Install flash-attn (optional; test both paths)

pip install flash-attn --no-build-isolation 2>&1 | tee flash_attn_build.log

python3 -c "import flash_attn; print('flash-attn version:', flash_attn.__version__)" \
  && echo "FLASH_ATTN: OK" \
  || echo "FLASH_ATTN: FAILED (will use sdpa fallback)"
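The fallback decision printed by the shell check can also be made in Python, for instance to choose the `--attn-implementation` value used in Phase 4. The sketch below is one possible way to do it; `pick_attn_implementation` is a hypothetical helper, not part of TriAttention. It imports the module rather than only locating it, because a flash-attn build that compiled for the wrong architecture can be installed yet still fail at import time.

```python
import importlib


def pick_attn_implementation(module_name="flash_attn"):
    """Return 'flash_attention_2' if the module imports cleanly, else 'sdpa'."""
    try:
        importlib.import_module(module_name)
    except Exception:
        # ImportError, or a loader error from a broken binary build.
        return "sdpa"
    return "flash_attention_2"


print("--attn-implementation", pick_attn_implementation())
```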

2.3 Verify all imports resolve

from triattention.methods.triattention import TriAttention, TriAttentionConfig, apply_triattention_patch
from triattention.methods.pruning_utils import (
    build_rotary, score_keys_for_round, compute_frequency_statistics_from_means,
    load_head_frequency_stats, build_geometric_offsets, compute_frequency_scaling,
)
from triattention.vllm.core.kernels.triton_scoring import (
    triattention_scoring, TrigTableCache, create_trig_cache,
)
from triattention.vllm.core.scoring import compute_scores
from triattention.vllm.core.compressor import TriAttentionCompressor
from triattention.vllm.core.config import TriAttentionConfig as VLLMTriAttentionConfig
print("All imports: PASSED")
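The script above stops at the first failing import, which hides any later failures. A sketch that tries each module independently and reports all results could look like the following. `check_imports` is a hypothetical helper; the module paths are the ones used in the import statements above. It checks only that each module loads, not that every individual name exists in it.

```python
import importlib


def check_imports(module_names):
    """Try each module independently; map name -> error message or None."""
    results = {}
    for name in module_names:
        try:
            importlib.import_module(name)
            results[name] = None
        except Exception as exc:
            results[name] = f"{type(exc).__name__}: {exc}"
    return results


modules = [
    "triattention.methods.triattention",
    "triattention.methods.pruning_utils",
    "triattention.vllm.core.kernels.triton_scoring",
    "triattention.vllm.core.scoring",
    "triattention.vllm.core.compressor",
    "triattention.vllm.core.config",
]
for name, error in check_imports(modules).items():
    status = "OK" if error is None else f"FAILED ({error})"
    print(f"{name}: {status}")
```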

Phase 3: Triton Scoring Kernel Compilation & Correctness

This is the critical gate. The custom Triton kernel must JIT-compile and produce correct results on SM 120.

3.1 Kernel compiles and runs

import torch
from triattention.vllm.core.kernels.triton_scoring import triattention_scoring

batch, heads, seq_len, head_dim = 1, 8, 512, 128
freq_count = head_dim // 2

K_rot = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
q_mean_real = torch.randn(heads, freq_count, device="cuda", dtype=torch.float32)
q_mean_imag = torch.randn(heads, freq_count, device="cuda", dtype=torch.float32)
q_abs_mean = torch.randn(heads, freq_count, device="cuda", dtype=torch.float32).abs()
freq_scale_sq = torch.ones(heads, freq_count, device="cuda", dtype=torch.float32)
omega = torch.randn(freq_count, device="cuda", dtype=torch.float32)
offsets = torch.tensor([1.0, 2.0, 4.0, 8.0, 16.0], device="cuda", dtype=torch.float32)

scores = triattention_scoring(
    K_rot=K_rot,
    position_indices=None,
    q_mean_real=q_mean_real,
    q_mean_imag=q_mean_imag,
    q_abs_mean=q_abs_mean,
    freq_scale_sq=freq_scale_sq,
    omega=omega,
    offsets=offsets,
    round_start=512,
    aggregation="mean",
    rope_style="half",
)
assert scores.shape == (batch, heads, seq_len), f"Bad shape: {scores.shape}"
assert scores.isfinite().all(), "Non-finite scores detected"
print(f"Triton scoring kernel: PASSED (shape={tuple(scores.shape)}, "
      f"range=[{scores.min():.4f}, {scores.max():.4f}])")

3.2 Triton vs PyTorch numerical equivalence

import torch
from triattention.vllm.core.kernels.triton_scoring import triattention_scoring
from triattention.vllm.core.config import TriAttentionConfig
from triattention.vllm.core.scoring import compute_scores_pytorch

batch, heads, seq_len, head_dim = 1, 4, 256, 128
freq_count = head_dim // 2

K_rot = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
q_mean_real = torch.randn(heads, freq_count, device="cuda", dtype=torch.float32)
q_mean_imag = torch.randn(heads, freq_count, device="cuda", dtype=torch.float32)
q_abs_mean = torch.randn(heads, freq_count, device="cuda", dtype=torch.float32).abs() + 0.1
freq_scale_sq = torch.ones(heads, freq_count, device="cuda", dtype=torch.float32)
omega = torch.randn(freq_count, device="cuda", dtype=torch.float32) * 0.01
offsets = torch.tensor([1.0, 2.0, 4.0, 8.0, 16.0], device="cuda", dtype=torch.float32)

# Triton path
triton_scores = triattention_scoring(
    K_rot=K_rot, position_indices=None,
    q_mean_real=q_mean_real, q_mean_imag=q_mean_imag,
    q_abs_mean=q_abs_mean, freq_scale_sq=freq_scale_sq,
    omega=omega, offsets=offsets, round_start=256,
    aggregation="mean", rope_style="half",
)

# PyTorch path
config = TriAttentionConfig(
    kv_budget=128, pruning_mode="per_head", score_aggregation="mean",
    rope_style="half", use_triton_scoring=False, head_dim=head_dim,
    num_kv_heads=heads, disable_mlr=False,
)
q_mean_complex = torch.stack([q_mean_real, q_mean_imag], dim=-1)
head_stats = {"q_mean_complex": q_mean_complex, "q_abs_mean": q_abs_mean}
pytorch_scores = compute_scores_pytorch(
    key_states=K_rot, cache_positions=None,
    head_stats=head_stats, omega=omega, offsets=offsets,
    freq_scale_sq=freq_scale_sq, config=config, round_start=256,
)

max_diff = (triton_scores - pytorch_scores).abs().max().item()
mean_diff = (triton_scores - pytorch_scores).abs().mean().item()
print(f"Max diff: {max_diff:.6e}, Mean diff: {mean_diff:.6e}")
assert max_diff < 1e-2, f"Triton/PyTorch mismatch too large: {max_diff}"
print("Triton vs PyTorch equivalence: PASSED")

3.3 TrigTableCache (precomputed trig) path

import torch
from triattention.vllm.core.kernels.triton_scoring import (
    triattention_scoring, create_trig_cache,
)

batch, heads, seq_len, head_dim = 1, 8, 1024, 128
freq_count = head_dim // 2

K_rot = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
q_mean_real = torch.randn(heads, freq_count, device="cuda", dtype=torch.float32)
q_mean_imag = torch.randn(heads, freq_count, device="cuda", dtype=torch.float32)
q_abs_mean = torch.randn(heads, freq_count, device="cuda", dtype=torch.float32).abs()
freq_scale_sq = torch.ones(heads, freq_count, device="cuda", dtype=torch.float32)
omega = torch.randn(freq_count, device="cuda", dtype=torch.float32) * 0.01
offsets = torch.tensor([1.0, 2.0, 4.0, 8.0, 16.0], device="cuda", dtype=torch.float32)

trig_cache = create_trig_cache(
    max_seq_len=8192, compress_interval=128,
    offsets=offsets, omega=omega, device=torch.device("cuda"),
)
print(f"TrigTableCache created: {trig_cache}")

scores = triattention_scoring(
    K_rot=K_rot, position_indices=None,
    q_mean_real=q_mean_real, q_mean_imag=q_mean_imag,
    q_abs_mean=q_abs_mean, freq_scale_sq=freq_scale_sq,
    omega=omega, offsets=offsets, round_start=256,
    aggregation="mean", rope_style="half", trig_cache=trig_cache,
)
assert scores.shape == (batch, heads, seq_len)
assert scores.isfinite().all()
print("TrigTableCache kernel path: PASSED")

Phase 4: End-to-End Inference (Non-vLLM Path)

4.1 CLI run with SDPA (safest; no flash-attn dependency)

python scripts/cli.py run-one \
    --model Qwen/Qwen3-8B \
    --dataset aime24 \
    --method triattention \
    --budget 2048 \
    --attn-implementation sdpa \
    --max-questions 2 \
    2>&1 | tee run_sdpa.log

grep -E "accuracy|score|error|Error|CUDA|Triton" run_sdpa.log
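The grep pattern above mixes result lines with failure lines, and it also matches harmless mentions of CUDA or Triton. A small sketch that separates the two groups is shown below. The patterns and the sample log lines are assumptions for illustration; the real CLI's log format may differ.

```python
import re

ERROR_PATTERN = re.compile(r"error|traceback|out of memory", re.IGNORECASE)
RESULT_PATTERN = re.compile(r"accuracy|score", re.IGNORECASE)


def scan_log(lines):
    """Split log lines into error-like and result-like groups."""
    errors, results = [], []
    for line in lines:
        line = line.rstrip("\n")
        if ERROR_PATTERN.search(line):
            errors.append(line)
        elif RESULT_PATTERN.search(line):
            results.append(line)
    return errors, results


# Hypothetical log lines for illustration only.
sample = [
    "loading model\n",
    "accuracy: 0.5\n",
    "RuntimeError: CUDA error: no kernel image is available\n",
]
errors, results = scan_log(sample)
print("errors:", errors)
print("results:", results)
```

Against a real run, the input could simply be the open file, for example `scan_log(open("run_sdpa.log"))`.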

4.2 CLI run with flash_attention_2 (if Phase 2.2 succeeded)

python scripts/cli.py run-one \
    --model Qwen/Qwen3-8B \
    --dataset aime24 \
    --method triattention \
    --budget 2048 \
    --attn-implementation flash_attention_2 \
    --max-questions 2 \
    2>&1 | tee run_flash.log

4.3 Full Attention baseline (sanity check the model itself works)

python scripts/cli.py run-one \
    --model Qwen/Qwen3-8B \
    --dataset aime24 \
    --method full \
    --attn-implementation sdpa \
    --max-questions 2 \
    2>&1 | tee run_full.log

Phase 5: vLLM Integration

5.1 Install vLLM with SM 120 support

pip install vllm  # or build from source if pre-built wheel lacks SM 120
python3 -c "import vllm; print('vLLM:', vllm.__version__)"

5.2 Launch vLLM server with TriAttention plugin

export TRIATTN_RUNTIME_KV_BUDGET=2048
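# Note: this stats file is named for Qwen3-32B INT4, while the server below loads Qwen3-8B; confirm it is the intended file.
export TRIATTN_RUNTIME_SPARSE_STATS_PATH=triattention/vllm/stats/qwen3_32b_int4_stats.pt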

vllm serve Qwen/Qwen3-8B \
    --dtype bfloat16 \
    --max-model-len 8192 \
    --enforce-eager \
    --trust-remote-code \
    --no-enable-prefix-caching \
    2>&1 | tee vllm_server.log &

sleep 60

grep "TriAttention.*Runtime.*activated" vllm_server.log \
  && echo "PLUGIN: ACTIVATED" \
  || echo "PLUGIN: NOT FOUND, check logs"
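A fixed `sleep 60` can be too short when the model is still downloading or loading, and needlessly long when it is already cached. One alternative is to poll the server until it answers. The sketch below is a generic polling loop, not part of TriAttention or vLLM; `http_ok` and `wait_until_ready` are hypothetical helpers, and the demo uses a fake probe so that it runs without a server.

```python
import time
import urllib.error
import urllib.request


def http_ok(url, timeout_s=2.0):
    """Return True if a GET to `url` answers with HTTP 200."""
    try:
        with urllib.request.urlopen(url, timeout=timeout_s) as response:
            return response.status == 200
    except (urllib.error.URLError, OSError):
        return False


def wait_until_ready(probe, timeout_s=300.0, interval_s=2.0,
                     sleep=time.sleep, clock=time.monotonic):
    """Call `probe()` until it returns True or `timeout_s` elapses."""
    deadline = clock() + timeout_s
    while True:
        if probe():
            return True
        if clock() >= deadline:
            return False
        sleep(interval_s)


# Demo with a fake probe that succeeds on the third attempt (no network).
attempts = iter([False, False, True])
print(wait_until_ready(lambda: next(attempts), timeout_s=5.0, interval_s=0.0))
```

Against the server launched above, the probe would be something like `lambda: http_ok("http://localhost:8000/v1/models")`, which is the same endpoint queried in step 5.3.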

5.3 Query the server

curl -s http://localhost:8000/v1/models | python3 -m json.tool

curl -s http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "Qwen/Qwen3-8B",
        "messages": [{"role": "user", "content": "What is 2+2?"}],
        "max_tokens": 128
    }' | python3 -m json.tool

echo "vLLM serving with TriAttention: PASSED"
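Note that the final `echo` prints PASSED even if the request failed. To gate on the response itself, the body can be checked for a non-empty assistant message. The sketch below assumes the OpenAI-compatible chat completion shape (`choices[0].message.content`) and treats either an `error` key or `"object": "error"` as a failure; `extract_reply` is a hypothetical helper and the sample body is hand-written, not a captured response.

```python
import json


def extract_reply(raw_json):
    """Return the assistant text from an OpenAI-style chat completion body.

    Raises ValueError for an error payload or a reply with no usable text.
    """
    body = json.loads(raw_json)
    if "error" in body or body.get("object") == "error":
        raise ValueError(f"server returned an error: {body}")
    choices = body.get("choices") or []
    if not choices:
        raise ValueError("response has no choices")
    content = (choices[0].get("message") or {}).get("content")
    if not content or not content.strip():
        raise ValueError("first choice has empty content")
    return content


# Hand-written sample in the OpenAI-compatible shape (not a captured response).
sample = json.dumps({
    "choices": [
        {"index": 0, "message": {"role": "assistant", "content": "2 + 2 = 4"}}
    ]
})
print(extract_reply(sample))
```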

Phase 6: Performance Sanity Check

6.1 Triton kernel latency

import torch
import time
from triattention.vllm.core.kernels.triton_scoring import triattention_scoring

batch, heads, seq_len, head_dim = 1, 8, 4096, 128
freq_count = head_dim // 2

K_rot = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
q_mean_real = torch.randn(heads, freq_count, device="cuda", dtype=torch.float32)
q_mean_imag = torch.randn(heads, freq_count, device="cuda", dtype=torch.float32)
q_abs_mean = torch.randn(heads, freq_count, device="cuda", dtype=torch.float32).abs()
freq_scale_sq = torch.ones(heads, freq_count, device="cuda", dtype=torch.float32)
omega = torch.randn(freq_count, device="cuda", dtype=torch.float32) * 0.01
offsets = torch.tensor(
    [1., 2., 4., 8., 16., 32., 64., 128., 256., 512.,
     1024., 2048., 4096., 8192., 16384., 32768., 65536.],
    device="cuda",
)

# Warmup (includes JIT compile)
for _ in range(3):
    triattention_scoring(
        K_rot=K_rot, position_indices=None, q_mean_real=q_mean_real,
        q_mean_imag=q_mean_imag, q_abs_mean=q_abs_mean, freq_scale_sq=freq_scale_sq,
        omega=omega, offsets=offsets, round_start=4096,
        aggregation="mean", rope_style="half",
    )

torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(100):
    triattention_scoring(
        K_rot=K_rot, position_indices=None, q_mean_real=q_mean_real,
        q_mean_imag=q_mean_imag, q_abs_mean=q_abs_mean, freq_scale_sq=freq_scale_sq,
        omega=omega, offsets=offsets, round_start=4096,
        aggregation="mean", rope_style="half",
    )
torch.cuda.synchronize()
elapsed = (time.perf_counter() - t0) / 100 * 1000

print(f"Triton scoring kernel: {elapsed:.2f} ms/call (seq_len={seq_len}, heads={heads})")
print("(Compare with published ~0.1-0.5ms on A100 at similar config)")
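The loop above reports a single mean over 100 calls, which can hide an occasional slow call. If per-call timings are collected instead, a summary such as the one sketched below makes outliers visible. `summarize_latencies` is a hypothetical helper, the percentile uses the nearest-rank definition, and the sample numbers are made up for illustration.

```python
import statistics


def summarize_latencies(samples_ms):
    """Summarize per-call latencies (ms) as mean, median, p95, and max."""
    if not samples_ms:
        raise ValueError("need at least one sample")
    ordered = sorted(samples_ms)
    n = len(ordered)
    # Nearest-rank p95: rank = ceil(0.95 * n), computed with integers.
    rank = (95 * n + 99) // 100
    return {
        "mean": statistics.fmean(ordered),
        "median": statistics.median(ordered),
        "p95": ordered[rank - 1],
        "max": ordered[-1],
    }


# Made-up samples: nineteen fast calls and one slow outlier.
samples = [1] * 19 + [21]
print(summarize_latencies(samples))
```

Per-call samples could be collected on the GPU with a pair of `torch.cuda.Event(enable_timing=True)` objects and `start.elapsed_time(end)`, which returns milliseconds.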

Pass / Fail Criteria

| Test | Gate | Acceptable Failure |
| --- | --- | --- |
| 1.3: PyTorch sees RTX 5080 as 12.0 | HARD | None |
| 1.4: Triton smoke test | HARD | None |
| 3.1: Triton scoring kernel compiles | HARD | None |
| 3.2: Triton / PyTorch max diff < 1e-2 | HARD | None |
| 4.1: CLI run with sdpa produces answers | HARD | None |
| 2.2: flash-attn builds | SOFT | Use sdpa fallback |
| 4.2: CLI run with flash_attention_2 | SOFT | Use sdpa fallback |
| 5.2: vLLM plugin activates | SOFT | Depends on vLLM SM 120 wheel availability |
| 6.1: Kernel latency reasonable | SOFT | Perf regression is a bug report, not a blocker |
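The gating rule in the table can be written down as a small function: any failed HARD test fails the run, while failed SOFT tests only produce warnings. This is a sketch of that rule; the verdict labels (`FAIL`, `PASS_WITH_WARNINGS`, `PASS`) and the choice to count a missing result as a failure are assumptions, not something the plan specifies.

```python
HARD, SOFT = "HARD", "SOFT"

# Gate type per test, copied from the table above.
GATES = {
    "1.3": HARD, "1.4": HARD, "3.1": HARD, "3.2": HARD, "4.1": HARD,
    "2.2": SOFT, "4.2": SOFT, "5.2": SOFT, "6.1": SOFT,
}


def verdict(results):
    """Combine per-test booleans into (label, hard_failures, soft_failures).

    A test missing from `results` counts as not passed (an assumption).
    """
    hard = [t for t, g in GATES.items() if g == HARD and not results.get(t, False)]
    soft = [t for t, g in GATES.items() if g == SOFT and not results.get(t, False)]
    if hard:
        return "FAIL", hard, soft
    if soft:
        return "PASS_WITH_WARNINGS", hard, soft
    return "PASS", hard, soft


# Example: everything passes except the two flash-attn tests.
results = {test: True for test in GATES}
results["2.2"] = False
results["4.2"] = False
print(verdict(results))
```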

Phases 1–3 are the minimum viable check (15 min). Phase 4 adds model-level confidence (30 min + model download). Phases 5–6 cover the production serving path.


Minimum Software Versions

| Requirement | Minimum |
| --- | --- |
| CUDA Toolkit | ≥ 12.8 |
| PyTorch | ≥ 2.6 with cu128 or cu130 wheel |
| Triton | ≥ 3.2 (ships with PyTorch 2.6+; current 3.6) |
| flash-attn (optional) | ≥ 2.7, compiled from source with CUDA ≥ 12.8 |
| vLLM (if serving) | ≥ 0.19 with Blackwell support |
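The Python package minimums above can be compared against what is actually installed. The sketch below reads installed versions through `importlib.metadata` and compares the leading numeric components. `version_tuple` and `check_installed` are hypothetical helpers, the pip distribution names are assumed to be `torch`, `triton`, `flash-attn`, and `vllm`, and the check covers version numbers only: it does not verify the CUDA build tag (cu128 / cu130) or Blackwell support.

```python
import re
from importlib import metadata


def version_tuple(version, width=2):
    """Turn '2.7.0+cu128' into (2, 7): leading numeric components only."""
    numbers = re.findall(r"\d+", version.split("+")[0])
    parts = [int(n) for n in numbers[:width]]
    return tuple(parts + [0] * (width - len(parts)))


# Minimums from the table above, keyed by assumed pip distribution name.
MINIMUMS = {
    "torch": (2, 6),
    "triton": (3, 2),
    "flash-attn": (2, 7),
    "vllm": (0, 19),
}


def check_installed(minimums=MINIMUMS):
    """Return {name: (installed_version_or_None, meets_minimum)}."""
    report = {}
    for name, minimum in minimums.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = (None, False)
            continue
        report[name] = (installed, version_tuple(installed) >= minimum)
    return report


for name, (installed, ok) in check_installed().items():
    print(f"{name}: {installed or 'not installed'} -> {'OK' if ok else 'CHECK'}")
```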