|
import torch |
|
|
|
|
|
def load_tensor(filename):
    """Load a tensor previously saved with ``torch.save``.

    ``torch.load`` already restores the tensor onto the device recorded in
    the file, so the original ``tensor.to(tensor.device)`` call was a no-op
    and has been removed.

    Args:
        filename: Path (or file-like object) produced by ``torch.save``.

    Returns:
        The deserialized tensor, on the device it was saved from.
    """
    return torch.load(filename)
|
|
|
def check_nan(output):
    """Print the indices, values, and shape information for NaNs in *output*.

    Purely diagnostic: prints where NaNs occur, the NaN entries themselves,
    and the tensor's shape. Returns ``None``.
    """
    mask = torch.isnan(output)
    print("NaN indices:", torch.nonzero(mask, as_tuple=True))
    print("NaN values:", output[mask])
    print("shape:", output.shape)
|
|
|
# ---------------------------------------------------------------------------
# Repro driver for the paged-attention kernel. Dimensions below must match
# the shapes of the tensors saved in the *.pt files alongside this script.
# ---------------------------------------------------------------------------
num_seqs = 50
num_heads = 32
num_splits = 4
head_size = 128

# Verbose-dump switch (replaces the previous dead `if 0:` block).
DEBUG = False


def _read_scalar(path, cast):
    """Read a single whitespace-stripped scalar from *path*, converted via *cast*."""
    with open(path, "r") as f:
        return cast(f.read().strip())


out0 = load_tensor("output.pt")
# torch.empty_like already allocates on out0's device/dtype; the previous
# extra `.to(out0.device)` was a no-op and has been dropped.
out = torch.empty_like(out0)

query = load_tensor("query.pt")
key_cache = load_tensor("key_cache.pt")
value_cache = load_tensor("value_cache.pt")
block_tables = load_tensor("block_tables.pt")

# Scratch buffers for the split-K reduction of the kernel.
tmp_out = torch.empty(
    size=(num_seqs, num_heads, num_splits, head_size),
    dtype=out.dtype,
    device=out.device,
)

exp_sums = torch.empty(
    size=(num_seqs, num_heads, num_splits),
    dtype=torch.float32,
    device=out.device,
)
max_logits = torch.empty_like(exp_sums)

context_lens = load_tensor("context_lens.pt")

# Kernel hyper-parameters captured from the original run.
num_kv_heads = _read_scalar("num_kv_heads.txt", int)
attn_scale = _read_scalar("attn_scale.txt", float)
kv_block_size = _read_scalar("kv_block_size.txt", int)
max_context_len = _read_scalar("max_context_len.txt", int)

alibi_slope = None

if DEBUG:
    print("out:", out)
    print("query:", query)
    print("key_cache:", key_cache)
    print("value_cache:", value_cache)
    print("block_tables:", block_tables)
    print("exp_sums:", exp_sums)
    print("max_logits:", max_logits)
    print("tmp_out:", tmp_out)
    print("context_lens:", context_lens)
    print("num_kv_heads:", num_kv_heads)
    print("attn_scale:", attn_scale)
    print("kv_block_size:", kv_block_size)
    print("max_context_len:", max_context_len)
    print("alibi_slope:", alibi_slope)
    print("value_cache:", value_cache)
    check_nan(value_cache)

from blade_llm import llm_ops

# Run the custom paged-attention kernel; writes the result into `out`.
# NOTE(review): trailing "auto", 1.0, 1.0 presumably select kv-cache dtype
# and k/v scaling factors — confirm against the llm_ops signature.
llm_ops.paged_attention_custom(
    out,
    exp_sums,
    max_logits,
    tmp_out,
    query,
    key_cache,
    value_cache,
    num_kv_heads,
    attn_scale,
    block_tables,
    context_lens,
    kv_block_size,
    max_context_len,
    alibi_slope,
    "auto",
    1.0,
    1.0,
)

# Inspect the kernel output for NaNs — the point of this repro script.
check_nan(out)
|
|