import math

import torch

from kernels.benchmark import Benchmark


def _cdiv(a, b):
    """Ceiling division."""
    return (a + b - 1) // b


def _extract_output(result):
    """Return just the attention output from kernels that also return an LSE."""
    if isinstance(result, tuple):
        return result[0]
    return result


def _reference_mla_decode(
    q, blocked_k, block_table, cache_seqlens, head_dim_v, causal=False
):
    """Pure-PyTorch reference for paged MLA decode.

    In MLA the key and value share one cache: the first head_dim_v channels
    of each cached key row double as the value.
    """
    b, s_q, h_q, d = q.size()
    block_size = blocked_k.size(1)
    h_kv = blocked_k.size(2)

    out = torch.empty(b, s_q, h_q, head_dim_v, dtype=torch.float32, device=q.device)

    for i in range(b):
        cur_len = int(cache_seqlens[i].item())
        num_blocks = _cdiv(cur_len, block_size)
        cur_blocks = block_table[i][:num_blocks]
        # Gather this sequence's pages, flatten them to tokens, and trim the
        # padding in the final page.
        kv = blocked_k[cur_blocks].reshape(-1, h_kv, d)[:cur_len]

        query = q[i].transpose(0, 1).float()
        key_val = kv.transpose(0, 1).float()

        # Expand KV heads to match the query heads (grouped-query attention).
        if h_kv != h_q:
            key_val = key_val.repeat_interleave(h_q // h_kv, dim=0)

        attn = query @ key_val.transpose(-2, -1) / math.sqrt(d)

        s_k = key_val.size(1)
        if causal and s_q > 1:
            # Bottom-right-aligned causal mask; a single decode query
            # (s_q == 1) attends to the whole cache, so no mask is needed then.
            mask = torch.ones(s_q, s_k, dtype=torch.bool, device=q.device).tril(
                diagonal=s_k - s_q
            )
            attn.masked_fill_(~mask, float("-inf"))

        attn = torch.softmax(attn, dim=-1)
        output = attn @ key_val[..., :head_dim_v]
        out[i] = output.transpose(0, 1)

    return out.to(q.dtype)
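

# Worked example of the paged gather above (illustrative values): with
# block_size = 64 and cache_seqlens[i] = 130, the sequence spans
# _cdiv(130, 64) = 3 pages, blocked_k[block_table[i][:3]] has shape
# (3, 64, h_kv, d), and reshape(-1, h_kv, d)[:130] drops the 62 padded slots
# in the final page.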


def _varlen_reference_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, causal=False):
    """Pure-PyTorch reference for variable-length (packed) attention."""
    batch_size = cu_seqlens_q.shape[0] - 1
    total_tokens_q = q.shape[0]
    num_heads = q.shape[1]
    head_dim_v = v.shape[2]
    scale = q.shape[-1] ** (-0.5)

    out = torch.zeros(
        (total_tokens_q, num_heads, head_dim_v), device=q.device, dtype=q.dtype
    )

    for b in range(batch_size):
        start_q, end_q = cu_seqlens_q[b], cu_seqlens_q[b + 1]
        start_k, end_k = cu_seqlens_k[b], cu_seqlens_k[b + 1]

        q_b = q[start_q:end_q].transpose(0, 1).float()
        k_b = k[start_k:end_k].transpose(0, 1).float()
        v_b = v[start_k:end_k].transpose(0, 1).float()

        attn = q_b @ k_b.transpose(-2, -1) * scale

        if causal:
            # Bottom-right-aligned causal mask, matching flash-attention.
            seq_q, seq_k = q_b.size(1), k_b.size(1)
            mask = torch.ones(seq_q, seq_k, dtype=torch.bool, device=q.device).tril(
                diagonal=seq_k - seq_q
            )
            attn.masked_fill_(~mask, float("-inf"))

        attn = torch.softmax(attn, dim=-1)
        result = attn @ v_b
        out[start_q:end_q] = result.transpose(0, 1).to(q.dtype)

    return out
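

# The packed (varlen) convention assumed above: cu_seqlens_* hold cumulative
# token offsets with a leading zero, so batch b owns rows
# [cu_seqlens[b], cu_seqlens[b + 1]) of the packed (total_tokens, heads, dim)
# tensors. For per-sequence lengths [3, 5]:
#     cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")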


# MLA decode shapes. In DeepSeek-style MLA the 576-dim cached key packs a
# 512-dim latent (reused as the value, hence _HEAD_DIM_V = 512) plus a 64-dim
# RoPE component, and the KV cache is paged in 64-token blocks.
_HEAD_DIM = 576
_HEAD_DIM_V = 512
_NUM_HEADS_K = 1
_PAGE_BLOCK_SIZE = 64


def _setup_mla_decode(bench, batch_size, seq_k, num_heads_q):
    max_num_blocks = _cdiv(seq_k, _PAGE_BLOCK_SIZE)
    total_blocks = batch_size * max_num_blocks

    bench.q = (
        torch.randn(
            batch_size, 1, num_heads_q, _HEAD_DIM, device="cuda", dtype=torch.bfloat16
        )
        / 10
    )
    bench.blocked_k = (
        torch.randn(
            total_blocks,
            _PAGE_BLOCK_SIZE,
            _NUM_HEADS_K,
            _HEAD_DIM,
            device="cuda",
            dtype=torch.bfloat16,
        )
        / 10
    )
    bench.block_table = torch.arange(
        total_blocks, device="cuda", dtype=torch.int32
    ).view(batch_size, max_num_blocks)
    bench.cache_seqlens = torch.full(
        (batch_size,), seq_k, device="cuda", dtype=torch.int32
    )
    # get_mla_metadata expects the cache lengths, s_q * h_q // h_kv (s_q is 1
    # for decode), and h_kv; it returns the tile-scheduler metadata and a
    # split count, both of which flash_mla_with_kvcache consumes.
    bench.tile_scheduler_metadata, bench.num_splits = bench.kernel.get_mla_metadata(
        bench.cache_seqlens, num_heads_q // _NUM_HEADS_K, _NUM_HEADS_K
    )
    bench.out = torch.empty(
        batch_size, 1, num_heads_q, _HEAD_DIM_V, device="cuda", dtype=torch.bfloat16
    )


def _run_mla_decode(bench, causal=False):
    out, _lse = bench.kernel.flash_mla_with_kvcache(
        q=bench.q,
        k_cache=bench.blocked_k,
        block_table=bench.block_table,
        cache_seqlens=bench.cache_seqlens,
        head_dim_v=_HEAD_DIM_V,
        tile_scheduler_metadata=bench.tile_scheduler_metadata,
        num_splits=bench.num_splits,
        causal=causal,
    )
    bench.out = out


def _verify_mla_decode(bench, causal=False):
    return _reference_mla_decode(
        bench.q,
        bench.blocked_k,
        bench.block_table,
        bench.cache_seqlens,
        _HEAD_DIM_V,
        causal=causal,
    )


class FlashMLABenchmark(Benchmark):
    """Paged MLA decode benchmarks (non-causal)."""

    seed: int = 42

    def setup_small(self):
        _setup_mla_decode(self, batch_size=2, seq_k=256, num_heads_q=64)

    def benchmark_small(self):
        _run_mla_decode(self, causal=False)

    def verify_small(self) -> torch.Tensor:
        return _verify_mla_decode(self, causal=False)

    def setup_medium(self):
        _setup_mla_decode(self, batch_size=4, seq_k=1024, num_heads_q=64)

    def benchmark_medium(self):
        _run_mla_decode(self, causal=False)

    def verify_medium(self) -> torch.Tensor:
        return _verify_mla_decode(self, causal=False)

    def setup_large(self):
        _setup_mla_decode(self, batch_size=8, seq_k=4096, num_heads_q=128)

    def benchmark_large(self):
        _run_mla_decode(self, causal=False)

    def verify_large(self) -> torch.Tensor:
        return _verify_mla_decode(self, causal=False)


class FlashMLACausalBenchmark(Benchmark):
    """Causal variant of the paged MLA decode benchmarks."""

    seed: int = 42

    def setup_small(self):
        _setup_mla_decode(self, batch_size=2, seq_k=256, num_heads_q=64)

    def benchmark_small(self):
        _run_mla_decode(self, causal=True)

    def verify_small(self) -> torch.Tensor:
        return _verify_mla_decode(self, causal=True)

    def setup_medium(self):
        _setup_mla_decode(self, batch_size=4, seq_k=1024, num_heads_q=64)

    def benchmark_medium(self):
        _run_mla_decode(self, causal=True)

    def verify_medium(self) -> torch.Tensor:
        return _verify_mla_decode(self, causal=True)

    def setup_large(self):
        _setup_mla_decode(self, batch_size=8, seq_k=4096, num_heads_q=128)

    def benchmark_large(self):
        _run_mla_decode(self, causal=True)

    def verify_large(self) -> torch.Tensor:
        return _verify_mla_decode(self, causal=True)
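

# Optional smoke test for the reference path: it does not require the
# FlashMLA kernel, but it does assume a CUDA device for the bfloat16 inputs.
if __name__ == "__main__":
    torch.manual_seed(42)
    batch_size, seq_k, num_heads_q = 2, 256, 64
    total_blocks = batch_size * _cdiv(seq_k, _PAGE_BLOCK_SIZE)
    q = (
        torch.randn(
            batch_size, 1, num_heads_q, _HEAD_DIM, device="cuda", dtype=torch.bfloat16
        )
        / 10
    )
    blocked_k = (
        torch.randn(
            total_blocks,
            _PAGE_BLOCK_SIZE,
            _NUM_HEADS_K,
            _HEAD_DIM,
            device="cuda",
            dtype=torch.bfloat16,
        )
        / 10
    )
    block_table = torch.arange(total_blocks, device="cuda", dtype=torch.int32).view(
        batch_size, -1
    )
    cache_seqlens = torch.full((batch_size,), seq_k, device="cuda", dtype=torch.int32)
    ref = _reference_mla_decode(q, blocked_k, block_table, cache_seqlens, _HEAD_DIM_V)
    print(ref.shape)  # expected: torch.Size([2, 1, 64, 512])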