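"""Standalone repro script: reload the tensors and scalars captured from a
paged_attention_custom call, re-run the kernel via blade_llm.llm_ops, and check
the value cache and the kernel output for NaNs."""
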
import torch
# import os
# os.chdir("/mnt/raid0/haoyanli/ws/ali/release/6302/bladellm")

def load_tensor(filename):
    """Load a tensor and keep it on the device recorded in the file."""
    tensor = torch.load(filename)
    return tensor.to(tensor.device)


def check_nan(output):
    """Print the indices and values of any NaNs in a tensor, plus its shape."""
    nan_mask = output.isnan()
    nan_indices = torch.nonzero(nan_mask, as_tuple=True)
    print("NaN indices:", nan_indices)
    nan_values = output[nan_mask]
    print("NaN values:", nan_values)
    print("shape:", output.shape)

num_seqs = 50
num_heads = 32
num_splits = 4
head_size = 128
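# Tensor dumps from the original run. output.pt is only used as a shape/dtype
# template; the kernel writes into the freshly allocated `out` below.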
out0 = load_tensor("output.pt")
out = torch.empty_like(out0).to(out0.device)
query = load_tensor("query.pt")
key_cache = load_tensor("key_cache.pt")
value_cache = load_tensor("value_cache.pt")
block_tables = load_tensor("block_tables.pt")
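# Scratch buffers for the kernel: presumably tmp_out holds per-split partial
# outputs while exp_sums / max_logits hold the per-split softmax statistics
# used to combine them.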
tmp_out = torch.empty(
    size=(num_seqs, num_heads, num_splits, head_size),
    dtype=out.dtype,
    device=out.device,
)
exp_sums = torch.empty(
    size=(num_seqs, num_heads, num_splits),
    dtype=torch.float32,
    device=out.device,
)  # shape: [num_seqs, num_heads, num_splits]
max_logits = torch.empty_like(exp_sums)  # shape: [num_seqs, num_heads, num_splits]
context_lens = load_tensor("context_lens.pt")
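# Scalar parameters of the original call were dumped as plain-text files.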
with open("num_kv_heads.txt", "r") as f:
num_kv_heads = int(f.read().strip())
with open("attn_scale.txt", "r") as f:
attn_scale = float(f.read().strip())
with open("kv_block_size.txt", "r") as f:
kv_block_size = int(f.read().strip())
with open("max_context_len.txt", "r") as f:
max_context_len = int(f.read().strip())
alibi_slope = None
if 0:  # flip to 1 to dump every kernel input
    print("out:", out)
    print("query:", query)
    print("key_cache:", key_cache)
    print("value_cache:", value_cache)
    print("block_tables:", block_tables)
    print("exp_sums:", exp_sums)
    print("max_logits:", max_logits)
    print("tmp_out:", tmp_out)
    print("context_lens:", context_lens)
    print("num_kv_heads:", num_kv_heads)
    print("attn_scale:", attn_scale)
    print("kv_block_size:", kv_block_size)
    print("max_context_len:", max_context_len)
    print("alibi_slope:", alibi_slope)

# Check the value cache for NaNs before running the kernel.
print("value_cache:", value_cache)
check_nan(value_cache)
from blade_llm import llm_ops
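# Replay the captured kernel call. The argument order below is taken verbatim
# from the original invocation; "auto" is presumably the kv-cache dtype selector
# and the trailing 1.0, 1.0 are presumably k/v scale factors (assumptions, not
# confirmed from the blade_llm sources).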
llm_ops.paged_attention_custom(  # 6302/bladellm/blade_llm/llm_ops.cpython-310-x86_64-linux-gnu.so
    out,
    exp_sums,
    max_logits,
    tmp_out,
    query,
    key_cache,
    value_cache,
    num_kv_heads,
    attn_scale,
    block_tables,
    context_lens,
    kv_block_size,
    max_context_len,
    alibi_slope,
    "auto",
    1.0,
    1.0,
)
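# If the kernel is healthy, this should report empty NaN indices and values.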
check_nan(out)