import torch

from kernels.benchmark import Benchmark


def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
) -> torch.Tensor:
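    """Dense (unpaged) reference attention: softmax(scale * q @ k^T) @ v."""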
    # Move the head dimension to the front so matmul batches over heads.
    q = query.transpose(0, 1)
    k = key.transpose(0, 1)
    v = value.transpose(0, 1)

    # Compute attention scores in float32 for numerical stability, then cast
    # back to the value dtype after the softmax.
    attn_weights = (scale * torch.matmul(q, k.transpose(-1, -2))).float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)

    out = torch.matmul(attn_weights, v)

    # Restore the original layout.
    return out.transpose(0, 1)


def ref_paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
) -> torch.Tensor:
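    """Reference paged attention with one query token per sequence.

    Keys and values are gathered from the block-structured caches via
    block_tables, positions past each sequence's length are masked out, and
    standard scaled dot-product attention is applied.
    """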
    num_seqs = query.shape[0]
    num_heads = query.shape[1]
    head_size = query.shape[2]
    block_size = value_cache.shape[3]
    max_seq_len = int(seq_lens.max().item())

    # Map every token position to its logical block index and its offset
    # within that block.
    positions = torch.arange(max_seq_len, device=query.device)
    block_indices = positions // block_size
    block_offsets = positions % block_size

    # Look up the physical block number for every (sequence, position) pair.
    block_numbers = block_tables[:, block_indices.long()]

    flat_block_numbers = block_numbers.reshape(-1)
    flat_offsets = block_offsets.repeat(num_seqs)

    # Gather keys. key_cache has shape
    # (num_blocks, num_heads, head_size // x, block_size, x); the gathered
    # (num_heads, head_size // x, x) slices flatten back to head_size.
    keys = key_cache[flat_block_numbers, :, :, flat_offsets, :]
    keys = keys.reshape(num_seqs, max_seq_len, num_heads, head_size)
    keys = keys.transpose(1, 2)

    # Gather values. value_cache has shape
    # (num_blocks, num_heads, head_size, block_size).
    values = value_cache[flat_block_numbers, :, :, flat_offsets]
    values = values.reshape(num_seqs, max_seq_len, num_heads, head_size)
    values = values.transpose(1, 2)

    # One query token per sequence: (num_seqs, num_heads, 1, head_size).
    q = query.unsqueeze(2)

    attn_weights = (scale * torch.matmul(q, keys.transpose(-1, -2))).float()

    # Mask out positions beyond each sequence's true length.
    seq_mask = positions.unsqueeze(0) >= seq_lens.unsqueeze(1)
    seq_mask = seq_mask.unsqueeze(1).unsqueeze(2)
    attn_weights = attn_weights.masked_fill(seq_mask, float("-inf"))

    attn_weights = torch.softmax(attn_weights, dim=-1).to(values.dtype)

    out = torch.matmul(attn_weights, values)

    return out.squeeze(2)


class PagedAttentionBenchmark(Benchmark):
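    """Benchmarks the paged_attention_v1 kernel against the PyTorch reference.

    Assumes the Benchmark base class from kernels.benchmark provides
    self.device and self.kernel (the kernel under test).
    """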
    seed: int = 42

    def setup(self):
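        """Small configuration: 4 sequences, 8 heads, head size 64, fp16."""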
        num_seqs = 4
        num_heads = 8
        head_size = 64
        block_size = 16
        max_seq_len = 128
        num_blocks = 64
        dtype = torch.float16

        self.num_heads = num_heads
        self.block_size = block_size
        self.max_seq_len = max_seq_len
        self.scale = 1.0 / (head_size**0.5)

        self.query = torch.randn(
            num_seqs, num_heads, head_size, device=self.device, dtype=dtype
        )

        # The key cache packs the head dimension into chunks of x elements so
        # that each innermost vector is 16 bytes wide.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        self.key_cache = torch.randn(
            num_blocks,
            num_heads,
            head_size // x,
            block_size,
            x,
            device=self.device,
            dtype=dtype,
        )
        self.value_cache = torch.randn(
            num_blocks,
            num_heads,
            head_size,
            block_size,
            device=self.device,
            dtype=dtype,
        )

        # Random mapping from logical to physical blocks for each sequence.
        max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
        self.block_tables = torch.randint(
            0,
            num_blocks,
            (num_seqs, max_num_blocks_per_seq),
            device=self.device,
            dtype=torch.int32,
        )

        # Fixed per-sequence lengths, all <= max_seq_len.
        self.seq_lens = torch.tensor(
            [64, 96, 48, 128], device=self.device, dtype=torch.int32
        )

        # Unit KV-cache scales (no scaling applied).
        self.k_scale = torch.tensor(1.0, dtype=torch.float32, device=self.device)
        self.v_scale = torch.tensor(1.0, dtype=torch.float32, device=self.device)

        self.out = torch.empty_like(self.query)

    def benchmark_base(self):
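        """Run the kernel's paged_attention_v1 on the small configuration."""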
        self.kernel.paged_attention_v1(
            self.out,
            self.query,
            self.key_cache,
            self.value_cache,
            num_kv_heads=self.num_heads,
            scale=self.scale,
            block_tables=self.block_tables,
            seq_lens=self.seq_lens,
            block_size=self.block_size,
            max_seq_len=self.max_seq_len,
            alibi_slopes=None,
            kv_cache_dtype="auto",
            k_scale=self.k_scale,
            v_scale=self.v_scale,
        )

    def verify_base(self) -> torch.Tensor:
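        """Compute the reference output for the small configuration."""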
        return ref_paged_attention(
            self.query,
            self.key_cache,
            self.value_cache,
            self.block_tables,
            self.seq_lens,
            self.scale,
        )

    def setup_large(self):
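        """Large configuration: 16 sequences, 32 heads, head size 128."""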
        num_seqs = 16
        num_heads = 32
        head_size = 128
        block_size = 16
        max_seq_len = 512
        num_blocks = 256
        dtype = torch.float16

        self.num_heads = num_heads
        self.block_size = block_size
        self.max_seq_len = max_seq_len
        self.scale = 1.0 / (head_size**0.5)

        self.query = torch.randn(
            num_seqs, num_heads, head_size, device=self.device, dtype=dtype
        )

        # Same KV-cache layout as in setup(), with x chosen so the innermost
        # key vector is 16 bytes wide.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        self.key_cache = torch.randn(
            num_blocks,
            num_heads,
            head_size // x,
            block_size,
            x,
            device=self.device,
            dtype=dtype,
        )
        self.value_cache = torch.randn(
            num_blocks,
            num_heads,
            head_size,
            block_size,
            device=self.device,
            dtype=dtype,
        )

        max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
        self.block_tables = torch.randint(
            0,
            num_blocks,
            (num_seqs, max_num_blocks_per_seq),
            device=self.device,
            dtype=torch.int32,
        )

        # Random per-sequence lengths between 64 and max_seq_len (inclusive).
        self.seq_lens = torch.randint(
            64, max_seq_len + 1, (num_seqs,), device=self.device, dtype=torch.int32
        )

        self.k_scale = torch.tensor(1.0, dtype=torch.float32, device=self.device)
        self.v_scale = torch.tensor(1.0, dtype=torch.float32, device=self.device)

        self.out = torch.empty_like(self.query)

    def benchmark_large(self):
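        """Run the kernel's paged_attention_v1 on the large configuration."""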
        self.kernel.paged_attention_v1(
            self.out,
            self.query,
            self.key_cache,
            self.value_cache,
            num_kv_heads=self.num_heads,
            scale=self.scale,
            block_tables=self.block_tables,
            seq_lens=self.seq_lens,
            block_size=self.block_size,
            max_seq_len=self.max_seq_len,
            alibi_slopes=None,
            kv_cache_dtype="auto",
            k_scale=self.k_scale,
            v_scale=self.v_scale,
        )

    def verify_large(self) -> torch.Tensor:
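        """Compute the reference output for the large configuration."""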
        return ref_paged_attention(
            self.query,
            self.key_cache,
            self.value_cache,
            self.block_tables,
            self.seq_lens,
            self.scale,
        )