wli1995 committed on
Commit 29211a0 · verified · 1 Parent(s): c8a8be4

Upload folder using huggingface_hub

utils/__pycache__/infer_func.cpython-310.pyc ADDED
Binary file (11.1 kB).
 
utils/infer_func.py ADDED
@@ -0,0 +1,383 @@
import torch
import numpy as np
from typing import List, Tuple, Optional, Dict
from pathlib import Path
from tqdm import tqdm
from axengine import InferenceSession
from ml_dtypes import bfloat16
from transformers import AutoTokenizer, AutoConfig
import json
from loguru import logger


class KVCacheTools:
    """
    Save and load the k/v caches locally.
    """
    def __init__(self, axmodel_num: int, dtype=np.float32):
        self.axmodel_num = axmodel_num
        self.dtype = dtype

    def save_kvcache(
        self,
        target_dir: str,
        system_prompt: str,
        precompute_len: int,
        k_caches: List[np.ndarray],
        v_caches: List[np.ndarray],
        metadata: Optional[Dict] = None
    ) -> bool:
        try:
            target_path = Path(target_dir)
            target_path.mkdir(parents=True, exist_ok=True)

            for i, (k, v) in enumerate(zip(k_caches, v_caches)):
                k.astype(self.dtype).tofile(target_path / f"k_cache_{i}.bin")
                v.astype(self.dtype).tofile(target_path / f"v_cache_{i}.bin")

            config = {
                "precompute_len": precompute_len,
                "system_prompt": system_prompt,
                "axmodel_num": self.axmodel_num,
                "dtype": str(self.dtype),
                "metadata": metadata or {},
            }
            with open(target_path / "config.json", "w", encoding="utf8") as f:
                json.dump(config, f, indent=2, ensure_ascii=False)

            return True
        except Exception as e:
            print(f"Save failed: {str(e)}")
            return False

    def load_kvcache(
        self,
        cache_dir: str
    ) -> Tuple[
        Tuple[List[np.ndarray], List[np.ndarray]],
        str,
        int,
        Dict
    ]:
        try:
            cache_path = Path(cache_dir)
            k_caches, v_caches = [], []

            with open(cache_path / "config.json") as f:
                config = json.load(f)

            if config["axmodel_num"] != self.axmodel_num:
                raise ValueError(
                    f"Model layer mismatch: "
                    f"Expected {self.axmodel_num}, got {config['axmodel_num']}"
                )

            for i in range(self.axmodel_num):
                k_data = np.fromfile(cache_path / f"k_cache_{i}.bin", dtype=self.dtype).reshape(1, -1, 256)
                v_data = np.fromfile(cache_path / f"v_cache_{i}.bin", dtype=self.dtype).reshape(1, -1, 256)
                k_caches.append(k_data)
                v_caches.append(v_data)

            return (
                (k_caches, v_caches),
                config["system_prompt"],
                config["precompute_len"],
                config.get("metadata", {})
            )
        except Exception as e:
            print(f"Load failed: {str(e)}")
            exit()

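# Round-trip sketch for KVCacheTools (hypothetical directory and layer count; the
# per-layer width of 256 used in load_kvcache must match the caches being saved):
#   tools = KVCacheTools(axmodel_num=24)
#   tools.save_kvcache("kvcache_dir", system_prompt, precompute_len, k_caches, v_caches)
#   (k_caches, v_caches), system_prompt, precompute_len, meta = tools.load_kvcache("kvcache_dir")
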
class InferManager:
    def __init__(self, hf_model_path: str, axmodel_path: str):
        self.device = "cpu"
        self.hf_model_path = hf_model_path
        self.axmodel_path = axmodel_path

        self.hf_config = AutoConfig.from_pretrained(self.hf_model_path, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model_path, trust_remote_code=True, use_fast=False)
        self.system_prompt = "你的名字叫小智(allen), 你是一个人畜无害的 AI 助手. 深圳市今天(4月1日)阴天, 愚人节, 气温在 14°C 至 19°C 之间, 微风."
        self.embeds = np.load(f"{self.axmodel_path}/model.embed_tokens.weight.npy")

    def build_system_prompt(self):
        messages = [
            {"role": "system", "content": self.system_prompt},
            # {"role": "user", "content": prompt}
        ]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )
        self.system_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)
        self.system_input_ids = self.system_inputs.input_ids[0].cpu().numpy().tolist()
        self.system_input_embeds = np.take(self.embeds, self.system_input_ids, axis=0)
        self.system_input_ids_len = len(self.system_input_ids)
        self.model_inputs = {
            "input_ids": self.system_input_ids,
            "input_embeds": self.system_input_embeds,
            "input_ids_len": self.system_input_ids_len
        }
        self.precompute_len = self.system_input_ids_len
        # logger.info(f"system prompt ids len: {self.system_input_ids_len}")
126
+
127
+ def encoder_prompt(self, prompt):
128
+
129
+ text = f'<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n'
130
+ model_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)
131
+ input_ids = model_inputs.input_ids[0].cpu().numpy().tolist()
132
+ input_embeds = np.take(self.embeds, input_ids, axis=0)
133
+ input_ids_len = len(input_ids)
134
+ # logger.info(f"user prompt token_len: {input_ids_len}")
135
+
136
+ model_inputs = {
137
+ "message": text,
138
+ "model_inputs": model_inputs,
139
+ "input_ids": input_ids,
140
+ "input_embeds": input_embeds,
141
+ "input_ids_len": input_ids_len
142
+ }
143
+ return model_inputs
144
+
145
+ def build_kvcache(self, kv_cache_len: int = 2559):
146
+
147
+ kv_dim = self.hf_config.hidden_size // self.hf_config.num_attention_heads * self.hf_config.num_key_value_heads
148
+ self.k_caches = [
149
+ np.zeros((1, kv_cache_len, kv_dim), dtype=bfloat16)
150
+ for _ in range(self.hf_config.num_hidden_layers)
151
+ ]
152
+ self.v_caches = [
153
+ np.zeros((1, kv_cache_len, kv_dim), dtype=bfloat16)
154
+ for _ in range(self.hf_config.num_hidden_layers)
155
+ ]
156
+
157
+ def get_kvcache(self):
158
+ return [self.k_caches, self.v_caches]
159
+
160
+ def update_kvcache(self, update_kv_cache):
161
+ self.k_caches = update_kv_cache[0]
162
+ self.v_caches = update_kv_cache[1]
163
+
164
+ def get_tokenizer(self):
165
+ return self.tokenizer
166
+
167
+ def get_system_prompt(self):
168
+ return self.system_prompt
169
+
170
+ def set_system_prompt(self, prompt):
171
+ self.system_prompt = prompt
172
+
    def build_infer_model(self):
        self.prefill_decoder_sessins = []

        for i in tqdm(range(self.hf_config.num_hidden_layers), desc="Init InferenceSession"):
            session = InferenceSession(
                f"{self.axmodel_path}/qwen2_p128_l{i}_together.axmodel"
            )
            self.prefill_decoder_sessins.append(session)

        self.post_process_session = InferenceSession(
            f"{self.axmodel_path}/qwen2_post.axmodel"
        )
        print("The models have been loaded!")

    def get_infer_session(self):
        return [self.prefill_decoder_sessins, self.post_process_session]

    @staticmethod
    def _top_p(probs: np.ndarray, p: float) -> np.ndarray:
        # Nucleus filtering: walk tokens from highest to lowest probability and,
        # once the kept mass reaches p, zero out the remaining tokens.
        sorted_indices = np.argsort(probs)
        filtered = probs.copy()
        cumulative = 0
        for idx in sorted_indices[::-1]:
            if cumulative >= p:
                filtered[idx] = 0
            cumulative += filtered[idx]
        return filtered / cumulative

    @staticmethod
    def _softmax(logits: np.ndarray) -> np.ndarray:
        logits = logits - logits.max()
        exp_logits = np.exp(logits)
        return (exp_logits / np.sum(exp_logits)).astype(np.float64)

    def post_process(self, logits, top_k=1, top_p=0.9, temperature=0.6):
        logits = logits.astype(np.float32).flatten()
        candidate_indices = np.argpartition(logits, -top_k)[-top_k:]
        candidate_logits = logits[candidate_indices] / temperature
        candidate_probs = self._softmax(candidate_logits)
        candidate_probs = self._top_p(candidate_probs, top_p)
        candidate_probs = candidate_probs.astype(np.float64) / candidate_probs.sum()
        chosen_idx = np.random.multinomial(1, candidate_probs).argmax()
        next_token = candidate_indices[chosen_idx]
        return next_token, candidate_indices, candidate_probs
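
    # Note: with the default top_k=1, argpartition keeps only the single highest logit,
    # so post_process reduces to greedy decoding and top_p / temperature have no effect.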

    def gen_slice_indices(self, token_len, prefill=128, expand=128):
        remaining = max(0, token_len - prefill)
        extra_blocks = (remaining + expand - 1) // expand
        return list(range(extra_blocks + 1))
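
    # Worked example (hypothetical length): token_len=300 with prefill=128 and expand=128
    # gives remaining=172 and extra_blocks=2, i.e. slice indices [0, 1, 2].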
222
+
223
+ def prefill(
224
+ self,
225
+ model_inputs,
226
+ slice_len=128,
227
+ precompute_len=0, # system prompt prefill 的时候, 只能设置为 0
228
+ ):
229
+ """
230
+ Prefill step for chunked inference.
231
+ """
232
+ token_ids = model_inputs["input_ids"]
233
+ token_embeds = model_inputs["input_embeds"]
234
+ token_len = model_inputs["input_ids_len"]
235
+
236
+ seq_len = len(token_ids)
237
+ slice_indices = [i for i in range(seq_len // slice_len + 1)]
238
+ print(f"slice_indices: {slice_indices}")
239
+ # total_prefill_len = (
240
+ # slice_len * slice_indices[-1]
241
+ # if slice_indices[-1] != 0
242
+ # else slice_len
243
+ # )
244
+ # slice_indices = self.gen_slice_indices(seq_len)
245
+ total_prefill_len = slice_len * (slice_indices[-1] + 1)
246
+ kv_mask_expand_len = 128
247
+
248
+ if total_prefill_len > 0:
249
+ for slice_index in slice_indices:
250
+ if slice_index == 0:
251
+ current_slice_len = slice_len
252
+ else:
253
+ current_slice_len = kv_mask_expand_len
254
+
255
+ indices = np.array(
256
+ list(
257
+ range(
258
+ precompute_len + slice_index * slice_len,
259
+ precompute_len + (slice_index + 1) * slice_len,
260
+ )
261
+ ),
262
+ np.uint32,
263
+ ).reshape((1, slice_len))
264
+ indices[:, min(token_len, slice_len):] = 0
265
+
266
+ mask = (
267
+ np.zeros((1, slice_len, current_slice_len * slice_index + slice_len))
268
+ - 65536
269
+ )
270
+ data = np.zeros((1, slice_len, self.hf_config.hidden_size)).astype(bfloat16)
271
+
272
+ for i, t in enumerate(
273
+ range(
274
+ slice_index * slice_len,
275
+ (slice_index + 1) * slice_len,
276
+ )
277
+ ):
278
+ if t < len(token_ids):
279
+ # mask[:, i, 0: slice_index * slice_len + i + 1] = 0
280
+ data[:, i : i + 1, :] = (
281
+ token_embeds[t]
282
+ .reshape((1, 1, self.hf_config.hidden_size))
283
+ .astype(bfloat16)
284
+ )
285
+ if t < len(token_ids) + precompute_len:
286
+ mask[:, i, 0: slice_index * slice_len + i + 1] = 0
287
+
288
+ if slice_index == slice_indices[-1]:
289
+ curlen_procd = token_len - slice_index * slice_len # curlen_procd 是当前处理数据的长度
290
+ else:
291
+ curlen_procd = slice_len
292
+
293
+ mask = mask.astype(bfloat16)
294
+ for i in range(self.hf_config.num_hidden_layers):
295
+ input_feed = {
296
+ "K_cache": (
297
+ self.k_caches[i][:, 0: current_slice_len * slice_index, :]
298
+ if slice_index
299
+ else np.zeros((1, 1, self.hf_config.hidden_size), dtype=bfloat16)
300
+ ),
301
+ "V_cache": (
302
+ self.v_caches[i][:, 0: current_slice_len * slice_index, :]
303
+ if slice_index
304
+ else np.zeros((1, 1, self.hf_config.hidden_size), dtype=bfloat16)
305
+ ),
306
+ "indices": indices,
307
+ "input": data,
308
+ "mask": mask,
309
+ }
310
+ outputs = self.prefill_decoder_sessins[i].run(None, input_feed, shape_group=slice_index + 1)
311
+ self.k_caches[i][
312
+ :,
313
+ slice_index
314
+ * slice_len + precompute_len : slice_index
315
+ * slice_len + curlen_procd + precompute_len,
316
+ :,
317
+ ] = outputs[0][:, :curlen_procd, :]
318
+
319
+ self.v_caches[i][
320
+ :,
321
+ slice_index
322
+ * slice_len + precompute_len: slice_index
323
+ * slice_len + curlen_procd + precompute_len,
324
+ :,
325
+ ] = outputs[1][:, :curlen_procd, :]
326
+
327
+ data = outputs[2]
328
+
329
+ print("slice prefill done", slice_index)
330
+ else:
331
+ print("No prefill needed.")
332
+ # return "Calculated the kv cache of the system prompt."
333
+ return (self.k_caches, self.v_caches)
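
    # Example (hypothetical sizes): prefilling a 200-token prompt with slice_len=128 and
    # precompute_len=0 runs slices [0, 1]; slice 0 fills cache rows 0..127 and slice 1
    # fills rows 128..199 (curlen_procd = 200 - 128 = 72 on the last slice).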
334
+
335
+ def decode(
336
+ self,
337
+ token_ids,
338
+ prefill_len=128,
339
+ slice_len=128
340
+ ):
341
+ token_len = len(token_ids)
342
+ # set to decoder
343
+ print("answer: >> ", end='', flush=True)
344
+ kv_cache_len = 2559
345
+ mask = np.zeros((1, 1, kv_cache_len + 1), dtype=np.float32).astype(bfloat16)
346
+ mask[:, :, :kv_cache_len] -= 65536
347
+ if prefill_len > 0:
348
+ mask[:, :, :token_len + self.precompute_len] = 0
349
+
350
+ for start_indice in range(kv_cache_len):
351
+ if self.precompute_len > 0 and start_indice < self.precompute_len:
352
+ continue
353
+ next_token = token_ids[start_indice - self.precompute_len]
354
+ indices = np.array([start_indice], np.uint32).reshape((1, 1))
355
+ data = self.embeds[next_token, :].reshape((1, 1, self.hf_config.hidden_size)).astype(bfloat16)
356
+ for i in range(self.hf_config.num_hidden_layers):
357
+ input_feed = {
358
+ "K_cache": self.k_caches[i],
359
+ "V_cache": self.v_caches[i],
360
+ "indices": indices,
361
+ "input": data,
362
+ "mask": mask,
363
+ }
364
+ outputs = self.prefill_decoder_sessins[i].run(None, input_feed, shape_group=0)
365
+ self.k_caches[i][:, start_indice, :] = outputs[0][:, :, :]
366
+ self.v_caches[i][:, start_indice, :] = outputs[1][:, :, :]
367
+ data = outputs[2]
368
+ mask[..., start_indice] = 0
369
+ if start_indice < token_len + self.precompute_len - 1:
370
+ pass
371
+ else:
372
+ post_out = self.post_process_session.run(None, {"input": data})[0]
373
+ next_token, posssible_tokens, possible_soft = self.post_process(post_out)
374
+ token_ids.append(next_token)
375
+ print(self.tokenizer.decode(next_token, skip_special_tokens=True), end='', flush=True)
376
+
377
+ if next_token == self.tokenizer.eos_token_id and start_indice > token_len + self.precompute_len:
378
+ # print("\n>> HINT: The next_token encountered EOS token, generation completed.")
379
+ break
380
+ print("\n")
381
+ self.precompute_len = len(token_ids) + self.precompute_len - 1
382
+ return self.tokenizer.decode(token_ids[self.precompute_len - 1:], skip_special_tokens=True)
383
+
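
A minimal end-to-end usage sketch of InferManager, pieced together from the methods above. The driver script is not part of this commit, and the model paths and the question are placeholders:

im = InferManager(hf_model_path="path/to/hf_model", axmodel_path="path/to/axmodel_dir")
im.build_infer_model()                          # load the per-layer axmodel sessions and the post-process head
im.build_kvcache(kv_cache_len=2559)             # allocate zeroed bfloat16 K/V caches
im.build_system_prompt()                        # tokenize/embed the system prompt and set precompute_len
im.prefill(im.model_inputs, precompute_len=0)   # write the system prompt into the K/V caches

user_inputs = im.encoder_prompt("What is the weather like in Shenzhen today?")
answer = im.decode(user_inputs["input_ids"])    # streams tokens and returns the generated answer
print(answer)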