TriAttention: RTX 5080 (SM 120) Verification Plan
Goal: Confirm that TriAttention runs correctly end-to-end on an NVIDIA RTX 5080 (compute capability 12.0 / Blackwell).
Background
TriAttention compresses KV cache for long-reasoning LLMs using frequency-domain scoring. Its GPU compute surface consists of:
| Layer | Technology | SM 120 expectation |
|---|---|---|
| Core scoring kernel | Triton JIT (triton_scoring.py) | Triton 3.2+ compiles PTX at runtime for the detected GPU; no hardcoded arch targets |
| All other tensor ops | PyTorch (torch.topk, gather, index_select, ...) | PyTorch ≥ 2.6 with CUDA ≥ 12.8 supports SM 120 |
| Optional attention | FlashAttention (flash-attn pip package) | flash-attn setup.py includes sm_120 in default arch list; requires CUDA ≥ 12.8 to compile from source |
| Attention fallback | sdpa / eager (pure PyTorch) | No arch restriction |
| vLLM integration | Plugin monkeypatches scheduler + worker | Depends on vLLM wheel having SM 120 support; DGX Spark (SM 121, same Blackwell family) already merged |
There are no custom CUDA C/C++ kernels (.cu / .cuh) in the repository. The only GPU-compiled code is the Triton scoring kernel, which uses standard tl.* operations (load, store, cos, sin, sqrt, sum, maximum), with no inline PTX or warp-level intrinsics.
The nearby SM 121 (DGX Spark / GB10) has been community-tested and merged into the repo as of 2026-04-14.
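The claim that the repository contains no custom CUDA sources is easy to re-check on a local clone. The following is a minimal sketch, not part of TriAttention: `find_cuda_sources` is a hypothetical helper name, and the path you pass to it is wherever you cloned the repo.

```python
from pathlib import Path

# File suffixes that would indicate hand-written CUDA C/C++ code.
CUDA_SUFFIXES = {".cu", ".cuh"}


def find_cuda_sources(repo_root):
    """Return sorted relative paths of .cu / .cuh files under repo_root.

    Hypothetical helper for this verification plan; an empty list is
    consistent with "no custom CUDA kernels in the repository".
    """
    root = Path(repo_root)
    return sorted(
        str(p.relative_to(root))
        for p in root.rglob("*")
        if p.is_file() and p.suffix in CUDA_SUFFIXES
    )
```

Calling `find_cuda_sources("triattention")` from the directory that holds the clone should return an empty list if the statement above still holds for the revision you checked out.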
Phase 1: Environment Validation
1.1 Confirm hardware and driver
nvidia-smi --query-gpu=name,compute_cap,driver_version --format=csv
# Expected: RTX 5080, 12.0, driver ≥ 570.x
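If this check needs to run unattended, the CSV output can be parsed rather than read by eye. This is a sketch under assumptions: the function names are hypothetical, the sample text (including the exact driver string) is illustrative rather than captured from a real machine, and it assumes the usual `nvidia-smi` CSV layout of one header line followed by one comma-separated row per GPU.

```python
def parse_gpu_query(csv_text):
    """Parse output of
    `nvidia-smi --query-gpu=name,compute_cap,driver_version --format=csv`.

    Returns one dict per GPU row. Hypothetical helper, not part of any tool.
    """
    lines = [ln.strip() for ln in csv_text.strip().splitlines() if ln.strip()]
    gpus = []
    for line in lines[1:]:  # first line is the CSV header
        name, cap, driver = [field.strip() for field in line.split(",")]
        major, minor = (int(part) for part in cap.split("."))
        gpus.append({
            "name": name,
            "compute_cap": (major, minor),
            "driver_major": int(driver.split(".")[0]),
        })
    return gpus


def meets_phase1_expectation(gpu):
    """True if the GPU reports compute capability 12.0 and driver >= 570."""
    return gpu["compute_cap"] == (12, 0) and gpu["driver_major"] >= 570


# Illustrative sample only; the driver version string is made up.
sample = """name, compute_cap, driver_version
NVIDIA GeForce RTX 5080, 12.0, 570.86.16
"""
```

In practice the text would come from `subprocess.run([...], capture_output=True, text=True).stdout` with the command shown above.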
1.2 Confirm CUDA toolkit
nvcc --version
# Must be ≥ 12.8
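The toolkit check can be automated the same way. The sketch below assumes the `nvcc --version` banner contains a `release X.Y` token, which is the usual format; the helper names are hypothetical.

```python
import re


def parse_nvcc_release(version_text):
    """Extract (major, minor) from the 'release X.Y' token in nvcc output."""
    match = re.search(r"release (\d+)\.(\d+)", version_text)
    if match is None:
        raise ValueError("no 'release X.Y' token found in nvcc output")
    return int(match.group(1)), int(match.group(2))


def cuda_toolkit_ok(version_text, minimum=(12, 8)):
    """True if the reported toolkit release is at least `minimum`."""
    return parse_nvcc_release(version_text) >= minimum
```

Comparing `(major, minor)` tuples avoids the string-comparison trap where `"12.10"` sorts before `"12.8"`.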
1.3 Confirm PyTorch sees the GPU
import torch
print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("CUDA version:", torch.version.cuda)
print("Device name:", torch.cuda.get_device_name(0))
cap = torch.cuda.get_device_capability(0)
print(f"Compute capability: {cap[0]}.{cap[1]}") # Expect 12.0
print("bf16 support:", torch.cuda.is_bf16_supported())
# Smoke test
a = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
b = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
c = a @ b
print("Matmul smoke test passed, output device:", c.device)
1.4 Confirm Triton compiles for SM 120
import triton
import triton.language as tl
import torch
print("Triton version:", triton.__version__)
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n: tl.constexpr):
    # A single program instance handles all n elements (n must be a power of two).
    offs = tl.arange(0, n)
    x = tl.load(x_ptr + offs)
    y = tl.load(y_ptr + offs)
    tl.store(out_ptr + offs, x + y)
x = torch.randn(128, device="cuda")
y = torch.randn(128, device="cuda")
out = torch.empty(128, device="cuda")
add_kernel[(1,)](x, y, out, 128)
assert torch.allclose(out, x + y, atol=1e-5)
print("Triton JIT compilation + execution on SM 120: PASSED")
Phase 2: TriAttention Installation
2.1 Install the package
git clone https://github.com/WeianMao/triattention.git
cd triattention
pip install -e .
2.2 Install flash-attn (optional; test both paths)
pip install flash-attn --no-build-isolation 2>&1 | tee flash_attn_build.log
python3 -c "import flash_attn; print('flash-attn version:', flash_attn.__version__)" \
  && echo "FLASH_ATTN: OK" \
  || echo "FLASH_ATTN: FAILED (will use sdpa fallback)"
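The same fallback decision can be made from Python when scripting the later phases. This is a sketch only: `pick_attn_implementation` is a hypothetical helper, and the two return values are simply the `--attn-implementation` strings used elsewhere in this plan.

```python
import importlib.util


def pick_attn_implementation(prefer_flash=True):
    """Return "flash_attention_2" if the flash_attn package is installed,
    otherwise "sdpa".

    find_spec only checks that the package can be located; it does not
    prove that the compiled extension loads on this GPU.
    """
    if prefer_flash and importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return "sdpa"
```

Because a located package can still fail at import time (for example, a wheel built without SM 120), the explicit `import flash_attn` check in the shell snippet remains the stronger test.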
2.3 Verify all imports resolve
from triattention.methods.triattention import TriAttention, TriAttentionConfig, apply_triattention_patch
from triattention.methods.pruning_utils import (
    build_rotary, score_keys_for_round, compute_frequency_statistics_from_means,
    load_head_frequency_stats, build_geometric_offsets, compute_frequency_scaling,
)
from triattention.vllm.core.kernels.triton_scoring import (
    triattention_scoring, TrigTableCache, create_trig_cache,
)
from triattention.vllm.core.scoring import compute_scores
from triattention.vllm.core.compressor import TriAttentionCompressor
from triattention.vllm.core.config import TriAttentionConfig as VLLMTriAttentionConfig
print("All imports: PASSED")
Phase 3: Triton Scoring Kernel Compilation & Correctness
This is the critical gate. The custom Triton kernel must JIT-compile and produce correct results on SM 120.
3.1 Kernel compiles and runs
import torch
from triattention.vllm.core.kernels.triton_scoring import triattention_scoring
batch, heads, seq_len, head_dim = 1, 8, 512, 128
freq_count = head_dim // 2
K_rot = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
q_mean_real = torch.randn(heads, freq_count, device="cuda", dtype=torch.float32)
q_mean_imag = torch.randn(heads, freq_count, device="cuda", dtype=torch.float32)
q_abs_mean = torch.randn(heads, freq_count, device="cuda", dtype=torch.float32).abs()
freq_scale_sq = torch.ones(heads, freq_count, device="cuda", dtype=torch.float32)
omega = torch.randn(freq_count, device="cuda", dtype=torch.float32)
offsets = torch.tensor([1.0, 2.0, 4.0, 8.0, 16.0], device="cuda", dtype=torch.float32)
scores = triattention_scoring(
    K_rot=K_rot,
    position_indices=None,
    q_mean_real=q_mean_real,
    q_mean_imag=q_mean_imag,
    q_abs_mean=q_abs_mean,
    freq_scale_sq=freq_scale_sq,
    omega=omega,
    offsets=offsets,
    round_start=512,
    aggregation="mean",
    rope_style="half",
)
assert scores.shape == (batch, heads, seq_len), f"Bad shape: {scores.shape}"
assert scores.isfinite().all(), "Non-finite scores detected"
print(f"Triton scoring kernel: PASSED (shape={tuple(scores.shape)}, "
      f"range=[{scores.min():.4f}, {scores.max():.4f}])")
3.2 Triton vs PyTorch numerical equivalence
import torch
from triattention.vllm.core.kernels.triton_scoring import triattention_scoring
from triattention.vllm.core.config import TriAttentionConfig
from triattention.vllm.core.scoring import compute_scores_pytorch
batch, heads, seq_len, head_dim = 1, 4, 256, 128
freq_count = head_dim // 2
K_rot = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
q_mean_real = torch.randn(heads, freq_count, device="cuda", dtype=torch.float32)
q_mean_imag = torch.randn(heads, freq_count, device="cuda", dtype=torch.float32)
q_abs_mean = torch.randn(heads, freq_count, device="cuda", dtype=torch.float32).abs() + 0.1
freq_scale_sq = torch.ones(heads, freq_count, device="cuda", dtype=torch.float32)
omega = torch.randn(freq_count, device="cuda", dtype=torch.float32) * 0.01
offsets = torch.tensor([1.0, 2.0, 4.0, 8.0, 16.0], device="cuda", dtype=torch.float32)
# Triton path
triton_scores = triattention_scoring(
    K_rot=K_rot, position_indices=None,
    q_mean_real=q_mean_real, q_mean_imag=q_mean_imag,
    q_abs_mean=q_abs_mean, freq_scale_sq=freq_scale_sq,
    omega=omega, offsets=offsets, round_start=256,
    aggregation="mean", rope_style="half",
)
# PyTorch path
config = TriAttentionConfig(
    kv_budget=128, pruning_mode="per_head", score_aggregation="mean",
    rope_style="half", use_triton_scoring=False, head_dim=head_dim,
    num_kv_heads=heads, disable_mlr=False,
)
q_mean_complex = torch.stack([q_mean_real, q_mean_imag], dim=-1)
head_stats = {"q_mean_complex": q_mean_complex, "q_abs_mean": q_abs_mean}
pytorch_scores = compute_scores_pytorch(
    key_states=K_rot, cache_positions=None,
    head_stats=head_stats, omega=omega, offsets=offsets,
    freq_scale_sq=freq_scale_sq, config=config, round_start=256,
)
max_diff = (triton_scores - pytorch_scores).abs().max().item()
mean_diff = (triton_scores - pytorch_scores).abs().mean().item()
print(f"Max diff: {max_diff:.6e}, Mean diff: {mean_diff:.6e}")
assert max_diff < 1e-2, f"Triton/PyTorch mismatch too large: {max_diff}"
print("Triton vs PyTorch equivalence: PASSED")
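The 1e-2 threshold is much looser than a typical float32 tolerance. One plausible reason, stated here as an assumption rather than something the repository documents, is that the key tensor is bfloat16, which keeps only 8 bits of significand precision. The sketch below emulates float32-to-bfloat16 rounding in pure Python to show the size of a single rounding step; `to_bf16` is a hypothetical helper and handles finite inputs only.

```python
import struct

# Largest relative error of one round-to-nearest bfloat16 conversion.
BF16_UNIT_ROUNDOFF = 2.0 ** -8


def to_bf16(x):
    """Round a finite Python float to the nearest bfloat16 value.

    bfloat16 is the top 16 bits of an IEEE float32, so the conversion
    rounds away the low 16 bits (round-to-nearest, ties to even).
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Adding 0x7FFF plus the lowest kept bit implements ties-to-even.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + rounding_bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]
```

A value such as `1.0 + 2**-8` rounds back to `1.0`, so two implementations that round at different points in the computation can legitimately differ in the third decimal place. This only motivates the order of magnitude; it does not derive the exact threshold.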
3.3 TrigTableCache (precomputed trig) path
import torch
from triattention.vllm.core.kernels.triton_scoring import (
    triattention_scoring, create_trig_cache,
)
batch, heads, seq_len, head_dim = 1, 8, 1024, 128
freq_count = head_dim // 2
K_rot = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
q_mean_real = torch.randn(heads, freq_count, device="cuda", dtype=torch.float32)
q_mean_imag = torch.randn(heads, freq_count, device="cuda", dtype=torch.float32)
q_abs_mean = torch.randn(heads, freq_count, device="cuda", dtype=torch.float32).abs()
freq_scale_sq = torch.ones(heads, freq_count, device="cuda", dtype=torch.float32)
omega = torch.randn(freq_count, device="cuda", dtype=torch.float32) * 0.01
offsets = torch.tensor([1.0, 2.0, 4.0, 8.0, 16.0], device="cuda", dtype=torch.float32)
trig_cache = create_trig_cache(
    max_seq_len=8192, compress_interval=128,
    offsets=offsets, omega=omega, device=torch.device("cuda"),
)
print(f"TrigTableCache created: {trig_cache}")
scores = triattention_scoring(
    K_rot=K_rot, position_indices=None,
    q_mean_real=q_mean_real, q_mean_imag=q_mean_imag,
    q_abs_mean=q_abs_mean, freq_scale_sq=freq_scale_sq,
    omega=omega, offsets=offsets, round_start=256,
    aggregation="mean", rope_style="half", trig_cache=trig_cache,
)
assert scores.shape == (batch, heads, seq_len)
assert scores.isfinite().all()
print("TrigTableCache kernel path: PASSED")
Phase 4: End-to-End Inference (Non-vLLM Path)
4.1 CLI run with SDPA (safest: no flash-attn dependency)
python scripts/cli.py run-one \
  --model Qwen/Qwen3-8B \
  --dataset aime24 \
  --method triattention \
  --budget 2048 \
  --attn-implementation sdpa \
  --max-questions 2 \
  2>&1 | tee run_sdpa.log
grep -E "accuracy|score|error|Error|CUDA|Triton" run_sdpa.log
4.2 CLI run with flash_attention_2 (if Phase 2.2 succeeded)
python scripts/cli.py run-one \
  --model Qwen/Qwen3-8B \
  --dataset aime24 \
  --method triattention \
  --budget 2048 \
  --attn-implementation flash_attention_2 \
  --max-questions 2 \
  2>&1 | tee run_flash.log
4.3 Full Attention baseline (sanity check the model itself works)
python scripts/cli.py run-one \
  --model Qwen/Qwen3-8B \
  --dataset aime24 \
  --method full \
  --attn-implementation sdpa \
  --max-questions 2 \
  2>&1 | tee run_full.log
Phase 5: vLLM Integration
5.1 Install vLLM with SM 120 support
pip install vllm # or build from source if pre-built wheel lacks SM 120
python3 -c "import vllm; print('vLLM:', vllm.__version__)"
5.2 Launch vLLM server with TriAttention plugin
export TRIATTN_RUNTIME_KV_BUDGET=2048
export TRIATTN_RUNTIME_SPARSE_STATS_PATH=triattention/vllm/stats/qwen3_32b_int4_stats.pt
vllm serve Qwen/Qwen3-8B \
  --dtype bfloat16 \
  --max-model-len 8192 \
  --enforce-eager \
  --trust-remote-code \
  --enable-prefix-caching false \
  2>&1 | tee vllm_server.log &
sleep 60
grep "TriAttention.*Runtime.*activated" vllm_server.log \
  && echo "PLUGIN: ACTIVATED" \
  || echo "PLUGIN: NOT FOUND (check logs)"
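A fixed `sleep 60` can be too short when the model is still downloading and needlessly long when it is cached. As an alternative, the log can be polled until the activation line appears. This is a sketch: `wait_for_log_pattern` is a hypothetical helper, and the pattern passed to it is the same regular expression used in the `grep` above.

```python
import re
import time


def wait_for_log_pattern(log_path, pattern, timeout_s=300.0, poll_s=2.0):
    """Poll log_path until a line matches `pattern` or the timeout expires.

    Returns True on a match, False on timeout. The file is re-read on
    every poll, which is fine for startup logs of modest size.
    """
    regex = re.compile(pattern)
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            with open(log_path, "r", errors="replace") as fh:
                if any(regex.search(line) for line in fh):
                    return True
        except FileNotFoundError:
            pass  # the server may not have created the log yet
        if time.monotonic() >= deadline:
            return False
        time.sleep(poll_s)
```

For this phase the call would be `wait_for_log_pattern("vllm_server.log", "TriAttention.*Runtime.*activated")`; the 300-second default timeout is an arbitrary choice, not a value from the project.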
5.3 Query the server
curl -s http://localhost:8000/v1/models | python3 -m json.tool
curl -s http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "Qwen/Qwen3-8B",
    "messages": [{"role": "user", "content": "What is 2+2?"}],
    "max_tokens": 128
  }' | python3 -m json.tool
echo "vLLM serving with TriAttention: PASSED"
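The curl commands print the response but do not check it, so the final `echo` reports PASSED even if the request failed. A scripted variant can assert on the reply instead. The sketch below uses only the standard library; the helper names are hypothetical, and the URL, model name, and payload fields mirror the curl call above. It relies on the OpenAI-compatible response layout (`choices[0].message.content`) that this endpoint follows.

```python
import json
import urllib.request


def build_chat_request(prompt, model="Qwen/Qwen3-8B", max_tokens=128,
                       base_url="http://localhost:8000"):
    """Build (but do not send) a POST request for /v1/chat/completions."""
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
    }
    return urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def extract_reply(response_json):
    """Return the assistant text from an OpenAI-style chat completion dict."""
    return response_json["choices"][0]["message"]["content"]
```

To send it, wrap the request in `urllib.request.urlopen(req, timeout=...)`, decode the body with `json.load`, and treat an empty string from `extract_reply` as a failure.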
Phase 6: Performance Sanity Check
6.1 Triton kernel latency
import torch
import time
from triattention.vllm.core.kernels.triton_scoring import triattention_scoring
batch, heads, seq_len, head_dim = 1, 8, 4096, 128
freq_count = head_dim // 2
K_rot = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
q_mean_real = torch.randn(heads, freq_count, device="cuda", dtype=torch.float32)
q_mean_imag = torch.randn(heads, freq_count, device="cuda", dtype=torch.float32)
q_abs_mean = torch.randn(heads, freq_count, device="cuda", dtype=torch.float32).abs()
freq_scale_sq = torch.ones(heads, freq_count, device="cuda", dtype=torch.float32)
omega = torch.randn(freq_count, device="cuda", dtype=torch.float32) * 0.01
offsets = torch.tensor(
    [1., 2., 4., 8., 16., 32., 64., 128., 256., 512.,
     1024., 2048., 4096., 8192., 16384., 32768., 65536.],
    device="cuda",
)
# Warmup (includes JIT compile)
for _ in range(3):
    triattention_scoring(
        K_rot=K_rot, position_indices=None, q_mean_real=q_mean_real,
        q_mean_imag=q_mean_imag, q_abs_mean=q_abs_mean, freq_scale_sq=freq_scale_sq,
        omega=omega, offsets=offsets, round_start=4096,
        aggregation="mean", rope_style="half",
    )
# Drain pending GPU work before starting the timer
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(100):
    triattention_scoring(
        K_rot=K_rot, position_indices=None, q_mean_real=q_mean_real,
        q_mean_imag=q_mean_imag, q_abs_mean=q_abs_mean, freq_scale_sq=freq_scale_sq,
        omega=omega, offsets=offsets, round_start=4096,
        aggregation="mean", rope_style="half",
    )
# Kernel launches are asynchronous; wait for them before reading the clock
torch.cuda.synchronize()
elapsed = (time.perf_counter() - t0) / 100 * 1000
print(f"Triton scoring kernel: {elapsed:.2f} ms/call (seq_len={seq_len}, heads={heads})")
print("(Compare with published ~0.1-0.5ms on A100 at similar config)")
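The warmup, synchronize, time, synchronize pattern used above can be factored into a small reusable helper so that the same measurement is applied to other sequence lengths or to the PyTorch fallback. This is a sketch with a hypothetical function name; the `sync` argument is a hook where `torch.cuda.synchronize` would be passed for GPU work.

```python
import time


def time_per_call_ms(fn, warmup=3, iters=100, sync=lambda: None):
    """Return the mean wall-clock time of fn() in milliseconds.

    `sync` is called once after warmup and once after the timed loop so
    that asynchronous work (e.g. CUDA kernels) is included in the
    measurement. The default no-op suits purely CPU-side functions.
    """
    for _ in range(warmup):
        fn()  # first calls may include JIT compilation
    sync()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()
    return (time.perf_counter() - t0) / iters * 1000.0
```

With the tensors defined above, a call would look like `time_per_call_ms(lambda: triattention_scoring(...), sync=torch.cuda.synchronize)`, with the keyword arguments filled in as in the benchmark.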
Pass / Fail Criteria
| Test | Gate | Acceptable Failure |
|---|---|---|
| 1.3: PyTorch sees RTX 5080 as 12.0 | HARD | none |
| 1.4: Triton smoke test | HARD | none |
| 3.1: Triton scoring kernel compiles | HARD | none |
| 3.2: Triton / PyTorch max diff < 1e-2 | HARD | none |
| 4.1: CLI run with sdpa produces answers | HARD | none |
| 2.2: flash-attn builds | SOFT | Use sdpa fallback |
| 4.2: CLI run with flash_attention_2 | SOFT | Use sdpa fallback |
| 5.2: vLLM plugin activates | SOFT | Depends on vLLM SM 120 wheel availability |
| 6.1: Kernel latency reasonable | SOFT | Perf regression is a bug report, not a blocker |
Phases 1-3 are the minimum viable check (15 min).
Phase 4 adds model-level confidence (30 min + model download).
Phases 5-6 cover the production serving path.
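The gate table reduces to a simple rule: any failed HARD test fails the run, while failed SOFT tests are only reported. A minimal sketch of that rule follows. The `GATES` mapping transcribes the table; the function name and the extra "INCOMPLETE" status for HARD tests that were never run are assumptions added for illustration.

```python
# Test id -> gate type, transcribed from the Pass / Fail table.
GATES = {
    "1.3": "HARD", "1.4": "HARD", "3.1": "HARD", "3.2": "HARD", "4.1": "HARD",
    "2.2": "SOFT", "4.2": "SOFT", "5.2": "SOFT", "6.1": "SOFT",
}


def overall_verdict(results):
    """Combine per-test results (test id -> True/False) into one verdict.

    FAIL if any HARD test failed, INCOMPLETE if none failed but some
    HARD test has no result, PASS otherwise. SOFT failures are listed
    but never change the status.
    """
    hard = [t for t, gate in GATES.items() if gate == "HARD"]
    soft = [t for t, gate in GATES.items() if gate == "SOFT"]
    hard_failed = sorted(t for t in hard if results.get(t) is False)
    hard_missing = sorted(t for t in hard if t not in results)
    soft_failed = sorted(t for t in soft if results.get(t) is False)
    if hard_failed:
        status = "FAIL"
    elif hard_missing:
        status = "INCOMPLETE"
    else:
        status = "PASS"
    return {
        "status": status,
        "hard_failed": hard_failed,
        "hard_missing": hard_missing,
        "soft_failed": soft_failed,
    }
```

For example, a run where every HARD test passed but flash-attn failed to build yields status "PASS" with `soft_failed == ["2.2"]`, matching the "use sdpa fallback" guidance.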
Minimum Software Versions
| Requirement | Minimum |
|---|---|
| CUDA Toolkit | ≥ 12.8 |
| PyTorch | ≥ 2.6 with cu128 or cu130 wheel |
| Triton | ≥ 3.2 (ships with PyTorch 2.6+; current 3.6) |
| flash-attn (optional) | ≥ 2.7, compiled from source with CUDA ≥ 12.8 |
| vLLM (if serving) | ≥ 0.19 with Blackwell support |
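These minimums can be checked programmatically before starting Phase 1. The sketch below compares only the major and minor components, which is all the table specifies. The dictionary keys and function names are hypothetical labels chosen for illustration, and the version strings would come from sources such as `torch.__version__`, `torch.version.cuda`, and `triton.__version__`.

```python
import re

# Minimum (major, minor) per component, transcribed from the table above.
MINIMUMS = {
    "cuda": (12, 8),
    "torch": (2, 6),
    "triton": (3, 2),
    "flash_attn": (2, 7),
    "vllm": (0, 19),
}


def major_minor(version):
    """Parse the leading 'X.Y' of a version string such as '2.6.0+cu128'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def below_minimum(installed):
    """Return the subset of `installed` (name -> version string) that is
    older than the minimum. Names not in MINIMUMS are ignored."""
    return {
        name: version
        for name, version in installed.items()
        if name in MINIMUMS and major_minor(version) < MINIMUMS[name]
    }
```

An empty result means every reported component meets its minimum; it does not confirm that the PyTorch wheel was built against cu128 or cu130, which still needs the `torch.version.cuda` check from Phase 1.3.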