import torch
import numpy as np
from typing import List, Tuple, Optional, Dict
from pathlib import Path
from tqdm import tqdm
from axengine import InferenceSession
from ml_dtypes import bfloat16
from transformers import AutoTokenizer, AutoConfig
import json
from loguru import logger


class KVCacheTools:
    """
    Save and load k/v caches to and from local files.
    """
    def __init__(self, axmodel_num: int, dtype=np.float32):
        self.axmodel_num = axmodel_num
        self.dtype = dtype

    def save_kvcache(
        self,
        target_dir: str,
        system_prompt: str,
        precompute_len: int,
        k_caches: List[np.ndarray],
        v_caches: List[np.ndarray],
        metadata: Optional[Dict] = None
    ) -> bool:
        try:
            target_path = Path(target_dir)
            target_path.mkdir(parents=True, exist_ok=True)

            # Dump one pair of binary files per decoder layer.
            for i, (k, v) in enumerate(zip(k_caches, v_caches)):
                k.astype(self.dtype).tofile(target_path / f"k_cache_{i}.bin")
                v.astype(self.dtype).tofile(target_path / f"v_cache_{i}.bin")

            config = {
                "precompute_len": precompute_len,
                "system_prompt": system_prompt,
                "axmodel_num": self.axmodel_num,
                "dtype": str(self.dtype),
                "metadata": metadata or {},
            }
            with open(target_path / "config.json", "w", encoding="utf8") as f:
                json.dump(config, f, indent=2, ensure_ascii=False)

            return True
        except Exception as e:
            logger.error(f"Save failed: {e}")
            return False

    def load_kvcache(
        self,
        cache_dir: str
    ) -> Tuple[
        Tuple[List[np.ndarray], List[np.ndarray]],
        str,
        int,
        Dict
    ]:
        try:
            cache_path = Path(cache_dir)
            k_caches, v_caches = [], []

            with open(cache_path / "config.json") as f:
                config = json.load(f)

            if config["axmodel_num"] != self.axmodel_num:
                raise ValueError(
                    f"Model layer mismatch: "
                    f"Expected {self.axmodel_num}, got {config['axmodel_num']}"
                )

            # Each layer's cache is stored flat on disk; reshape to (1, seq_len, kv_dim).
            # kv_dim is hardcoded to 256 here.
            for i in range(self.axmodel_num):
                k_data = np.fromfile(cache_path / f"k_cache_{i}.bin", dtype=self.dtype).reshape(1, -1, 256)
                v_data = np.fromfile(cache_path / f"v_cache_{i}.bin", dtype=self.dtype).reshape(1, -1, 256)
                k_caches.append(k_data)
                v_caches.append(v_data)

            return (
                (k_caches, v_caches),
                config["system_prompt"],
                config["precompute_len"],
                config.get("metadata", {})
            )
        except Exception as e:
            logger.error(f"Load failed: {e}")
            exit()

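
# --- Illustrative usage sketch (not called anywhere in this module). ---
# A minimal round-trip through KVCacheTools. The layer count, the (1, 16, 256)
# cache shape, and the "./kvcache_demo" directory are hypothetical placeholders
# chosen for illustration, not values used by this script; real caches come
# from InferManager.prefill() below.
def _kvcache_roundtrip_example():
    num_layers = 28  # assumed layer count, purely for illustration
    tools = KVCacheTools(axmodel_num=num_layers, dtype=np.float32)

    # Toy caches shaped (1, seq_len, kv_dim).
    k_caches = [np.zeros((1, 16, 256), dtype=np.float32) for _ in range(num_layers)]
    v_caches = [np.zeros((1, 16, 256), dtype=np.float32) for _ in range(num_layers)]

    tools.save_kvcache(
        target_dir="./kvcache_demo",
        system_prompt="example system prompt",
        precompute_len=16,
        k_caches=k_caches,
        v_caches=v_caches,
        metadata={"note": "demo"},
    )
    (k_loaded, v_loaded), system_prompt, precompute_len, metadata = tools.load_kvcache("./kvcache_demo")
    assert len(k_loaded) == num_layers and precompute_len == 16

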
class InferManager:
    def __init__(self, hf_model_path: str, axmodel_path: str):
        self.device = "cpu"
        self.hf_model_path = hf_model_path
        self.axmodel_path = axmodel_path

        self.hf_config = AutoConfig.from_pretrained(self.hf_model_path, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model_path, trust_remote_code=True, use_fast=False)
        # Default system prompt (Chinese). Rough translation: "Your name is Xiaozhi (allen); you are a
        # harmless AI assistant. Shenzhen is overcast today (April 1st, April Fools' Day), 14°C to 19°C, light breeze."
        self.system_prompt = "你的名字叫小智(allen), 你是一个人畜无害的 AI 助手. 深圳市今天(4月1日)阴天, 愚人节, 气温在 14°C 至 19°C 之间, 微风."
        # Pre-extracted embedding table, used to look up input embeddings on the CPU.
        self.embeds = np.load(f"{self.axmodel_path}/model.embed_tokens.weight.npy")

    def build_system_prompt(self):
        messages = [
            {"role": "system", "content": self.system_prompt},
        ]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )
        self.system_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)
        self.system_input_ids = self.system_inputs.input_ids[0].cpu().numpy().tolist()
        self.system_input_embeds = np.take(self.embeds, self.system_input_ids, axis=0)
        self.system_input_ids_len = len(self.system_input_ids)
        self.model_inputs = {
            "input_ids": self.system_input_ids,
            "input_embeds": self.system_input_embeds,
            "input_ids_len": self.system_input_ids_len
        }
        self.precompute_len = self.system_input_ids_len

    def encoder_prompt(self, prompt):
        text = f'<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n'
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)
        input_ids = model_inputs.input_ids[0].cpu().numpy().tolist()
        input_embeds = np.take(self.embeds, input_ids, axis=0)
        input_ids_len = len(input_ids)

        model_inputs = {
            "message": text,
            "model_inputs": model_inputs,
            "input_ids": input_ids,
            "input_embeds": input_embeds,
            "input_ids_len": input_ids_len
        }
        return model_inputs

    def build_kvcache(self, kv_cache_len: int = 2559):
        # kv_dim = head_dim * num_key_value_heads (grouped-query attention).
        kv_dim = self.hf_config.hidden_size // self.hf_config.num_attention_heads * self.hf_config.num_key_value_heads
        self.k_caches = [
            np.zeros((1, kv_cache_len, kv_dim), dtype=bfloat16)
            for _ in range(self.hf_config.num_hidden_layers)
        ]
        self.v_caches = [
            np.zeros((1, kv_cache_len, kv_dim), dtype=bfloat16)
            for _ in range(self.hf_config.num_hidden_layers)
        ]

    def get_kvcache(self):
        return [self.k_caches, self.v_caches]

    def update_kvcache(self, update_kv_cache):
        self.k_caches = update_kv_cache[0]
        self.v_caches = update_kv_cache[1]

    def get_tokenizer(self):
        return self.tokenizer

    def get_system_prompt(self):
        return self.system_prompt

    def set_system_prompt(self, prompt):
        self.system_prompt = prompt

    def build_infer_model(self):
        self.prefill_decoder_sessins = []

        # One axmodel per decoder layer, plus a post-processing (lm_head) model.
        for i in tqdm(range(self.hf_config.num_hidden_layers), desc="Init InferenceSession"):
            session = InferenceSession(
                f"{self.axmodel_path}/qwen2_p128_l{i}_together.axmodel"
            )
            self.prefill_decoder_sessins.append(session)

        self.post_process_session = InferenceSession(
            f"{self.axmodel_path}/qwen2_post.axmodel"
        )
        print("The models have been loaded!")

    def get_infer_session(self):
        return [self.prefill_decoder_sessins, self.post_process_session]

    @staticmethod
    def _top_p(probs: np.ndarray, p: float) -> np.ndarray:
        # Nucleus filtering: keep the highest-probability tokens until their cumulative
        # mass reaches p, zero out the rest, then renormalize over the kept tokens.
        sorted_indices = np.argsort(probs)
        filtered = probs.copy()
        cumulative = 0.0
        for idx in sorted_indices[::-1]:
            if cumulative >= p:
                filtered[idx] = 0
            cumulative += filtered[idx]
        return filtered / cumulative

    @staticmethod
    def _softmax(logits: np.ndarray) -> np.ndarray:
        # Numerically stable softmax.
        logits = logits - logits.max()
        exp_logits = np.exp(logits)
        return (exp_logits / np.sum(exp_logits)).astype(np.float64)

    def post_process(self, logits, top_k=1, top_p=0.9, temperature=0.6):
        # Top-k -> temperature -> top-p sampling over the final logits.
        logits = logits.astype(np.float32).flatten()
        candidate_indices = np.argpartition(logits, -top_k)[-top_k:]
        candidate_logits = logits[candidate_indices] / temperature
        candidate_probs = self._softmax(candidate_logits)
        candidate_probs = self._top_p(candidate_probs, top_p)
        candidate_probs = candidate_probs.astype(np.float64) / candidate_probs.sum()
        chosen_idx = np.random.multinomial(1, candidate_probs).argmax()
        next_token = candidate_indices[chosen_idx]
        return next_token, candidate_indices, candidate_probs

    def gen_slice_indices(self, token_len, prefill=128, expand=128):
        # Number of prefill chunks for token_len tokens: one base chunk of `prefill`
        # tokens plus ceil((token_len - prefill) / expand) expansion chunks.
        remaining = max(0, token_len - prefill)
        extra_blocks = (remaining + expand - 1) // expand
        return list(range(extra_blocks + 1))

    def prefill(
        self,
        model_inputs,
        slice_len=128,
        precompute_len=0,
    ):
        """
        Prefill step for chunked inference: the prompt is processed in chunks of
        `slice_len` tokens, and the resulting k/v entries are written into the
        preallocated caches starting at offset `precompute_len`.
        """
        token_ids = model_inputs["input_ids"]
        token_embeds = model_inputs["input_embeds"]
        token_len = model_inputs["input_ids_len"]

        seq_len = len(token_ids)
        slice_indices = [i for i in range(seq_len // slice_len + 1)]
        print(f"slice_indices: {slice_indices}")

        total_prefill_len = slice_len * (slice_indices[-1] + 1)
        kv_mask_expand_len = 128

        if total_prefill_len > 0:
            for slice_index in slice_indices:
                if slice_index == 0:
                    current_slice_len = slice_len
                else:
                    current_slice_len = kv_mask_expand_len

                # Absolute cache positions covered by this chunk.
                indices = np.array(
                    list(
                        range(
                            precompute_len + slice_index * slice_len,
                            precompute_len + (slice_index + 1) * slice_len,
                        )
                    ),
                    np.uint32,
                ).reshape((1, slice_len))
                indices[:, min(token_len, slice_len):] = 0

                # Causal attention mask: positions start fully masked (-65536) and
                # are opened up token by token below.
                mask = (
                    np.zeros((1, slice_len, current_slice_len * slice_index + slice_len))
                    - 65536
                )
                data = np.zeros((1, slice_len, self.hf_config.hidden_size)).astype(bfloat16)

                for i, t in enumerate(
                    range(
                        slice_index * slice_len,
                        (slice_index + 1) * slice_len,
                    )
                ):
                    if t < len(token_ids):
                        data[:, i : i + 1, :] = (
                            token_embeds[t]
                            .reshape((1, 1, self.hf_config.hidden_size))
                            .astype(bfloat16)
                        )
                    if t < len(token_ids) + precompute_len:
                        mask[:, i, 0: slice_index * slice_len + i + 1] = 0

                # The last chunk may be only partially filled with real tokens.
                if slice_index == slice_indices[-1]:
                    curlen_procd = token_len - slice_index * slice_len
                else:
                    curlen_procd = slice_len

                mask = mask.astype(bfloat16)
                for i in range(self.hf_config.num_hidden_layers):
                    input_feed = {
                        "K_cache": (
                            self.k_caches[i][:, 0: current_slice_len * slice_index, :]
                            if slice_index
                            else np.zeros((1, 1, self.hf_config.hidden_size), dtype=bfloat16)
                        ),
                        "V_cache": (
                            self.v_caches[i][:, 0: current_slice_len * slice_index, :]
                            if slice_index
                            else np.zeros((1, 1, self.hf_config.hidden_size), dtype=bfloat16)
                        ),
                        "indices": indices,
                        "input": data,
                        "mask": mask,
                    }
                    outputs = self.prefill_decoder_sessins[i].run(None, input_feed, shape_group=slice_index + 1)
                    # Write this chunk's k/v entries back into the caches.
                    self.k_caches[i][
                        :,
                        slice_index * slice_len + precompute_len : slice_index * slice_len + curlen_procd + precompute_len,
                        :,
                    ] = outputs[0][:, :curlen_procd, :]
                    self.v_caches[i][
                        :,
                        slice_index * slice_len + precompute_len : slice_index * slice_len + curlen_procd + precompute_len,
                        :,
                    ] = outputs[1][:, :curlen_procd, :]
                    # The hidden states feed the next layer.
                    data = outputs[2]

                print("slice prefill done", slice_index)
        else:
            print("No prefill needed.")

        return (self.k_caches, self.v_caches)

    def decode(
        self,
        token_ids,
        prefill_len=128,
        slice_len=128
    ):
        token_len = len(token_ids)

        print("answer: >> ", end='', flush=True)
        kv_cache_len = 2559
        # Start with everything masked, then unmask the already-prefilled positions.
        mask = np.zeros((1, 1, kv_cache_len + 1), dtype=np.float32).astype(bfloat16)
        mask[:, :, :kv_cache_len] -= 65536
        if prefill_len > 0:
            mask[:, :, :token_len + self.precompute_len] = 0

        for start_indice in range(kv_cache_len):
            # Positions before precompute_len were already filled during prefill.
            if self.precompute_len > 0 and start_indice < self.precompute_len:
                continue
            next_token = token_ids[start_indice - self.precompute_len]
            indices = np.array([start_indice], np.uint32).reshape((1, 1))
            data = self.embeds[next_token, :].reshape((1, 1, self.hf_config.hidden_size)).astype(bfloat16)
            for i in range(self.hf_config.num_hidden_layers):
                input_feed = {
                    "K_cache": self.k_caches[i],
                    "V_cache": self.v_caches[i],
                    "indices": indices,
                    "input": data,
                    "mask": mask,
                }
                outputs = self.prefill_decoder_sessins[i].run(None, input_feed, shape_group=0)
                self.k_caches[i][:, start_indice, :] = outputs[0][:, :, :]
                self.v_caches[i][:, start_indice, :] = outputs[1][:, :, :]
                data = outputs[2]
            mask[..., start_indice] = 0
            if start_indice < token_len + self.precompute_len - 1:
                # Still consuming the prompt tokens; nothing is sampled yet.
                pass
            else:
                post_out = self.post_process_session.run(None, {"input": data})[0]
                next_token, possible_tokens, possible_soft = self.post_process(post_out)
                token_ids.append(next_token)
                print(self.tokenizer.decode(next_token, skip_special_tokens=True), end='', flush=True)

                if next_token == self.tokenizer.eos_token_id and start_indice > token_len + self.precompute_len:
                    break
        print("\n")
        self.precompute_len = len(token_ids) + self.precompute_len - 1
        return self.tokenizer.decode(token_ids[self.precompute_len - 1:], skip_special_tokens=True)
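

# --- Illustrative end-to-end driver (a sketch only; this module does not define
# its own entry point). The model paths and the user prompt below are hypothetical
# placeholders, and the call order is inferred from the methods above: load the
# per-layer sessions, allocate caches, chunk-prefill the system prompt, then decode
# the user turn token by token. ---
if __name__ == "__main__":
    hf_model_path = "./hf_model"   # placeholder: HuggingFace config/tokenizer directory
    axmodel_path = "./axmodel"     # placeholder: directory containing the *.axmodel files

    mgr = InferManager(hf_model_path, axmodel_path)
    mgr.build_infer_model()
    mgr.build_kvcache()

    # Chunked prefill of the system prompt; its k/v entries land at cache offset 0.
    mgr.build_system_prompt()
    mgr.prefill(mgr.model_inputs, slice_len=128, precompute_len=0)

    # Optionally snapshot the precomputed system-prompt cache for later reuse.
    tools = KVCacheTools(axmodel_num=mgr.hf_config.num_hidden_layers, dtype=np.float32)
    tools.save_kvcache(
        "./precomputed_kvcache",
        mgr.get_system_prompt(),
        mgr.precompute_len,
        *mgr.get_kvcache(),
    )

    # Encode one user turn and generate a reply.
    user_inputs = mgr.encoder_prompt("Hello, please introduce yourself.")
    answer = mgr.decode(user_inputs["input_ids"])
    logger.info(f"answer: {answer}")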