Limour committed
Commit afa463a
1 Parent(s): 7d38177

Upload llama_cpp_python_streamingllm.py

Files changed (1):
  1. llama_cpp_python_streamingllm.py +543 -8

llama_cpp_python_streamingllm.py CHANGED
@@ -1,14 +1,516 @@
- from typing import Optional, Sequence, Generator
-
- from llama_cpp import Llama, LogitsProcessorList, LlamaGrammar, llama_cpp, npt, np, StoppingCriteriaList
- from ctypes import POINTER

  from KMP_list import kmp_search, compute_lps_array


  class StreamingLLM(Llama):
-     def __init__(self, model_path: str, **kwargs):
-         super().__init__(model_path, **kwargs)
          self._venv_init()

      def str_detokenize(self, tokens) -> str:
@@ -63,8 +565,9 @@ class StreamingLLM(Llama):
          if name not in self.venv_idx_map:
              return False
          venv_idx = self.venv_idx_map.index(name) + 1
          while self.venv_idx_map:
-             if keep_last and self.venv_idx_map.count(name) <= keep_last:
                  break  # keep the last n
              self.venv_idx_map.pop(venv_idx - 1)  # remove it
              if venv_idx == len(self.venv) - 1:
@@ -81,6 +584,7 @@ class StreamingLLM(Llama):
                  venv_idx = self.venv_idx_map.index(name, venv_idx - 1) + 1
              except ValueError:  # no more occurrences
                  break
          return True

      def venv_pop_token(self, n=1):
@@ -92,6 +596,36 @@ class StreamingLLM(Llama):
      def venv_info(self):
          return str((self.n_tokens, self.venv, self.venv_idx_map))

      def kv_cache_seq_ltrim(self, n_keep, n_discard=256, n_past=-1, im_start=None):
          if n_keep < 0:
              return
@@ -106,6 +640,7 @@ class StreamingLLM(Llama):
              _idx = kmp_search(self.input_ids, im_start, n_keep, n_past, lps)
              if _idx >= n_keep:
                  n_keep = _idx + len(im_start)  # keep at least one im_start sequence structure
          self._ctx.kv_cache_seq_rm(-1, n_keep, n_keep + n_discard)
          self._ctx.kv_cache_seq_shift(0, n_keep + n_discard, n_past, -n_discard)
          self.input_ids[n_keep:n_past - n_discard] = self.input_ids[n_keep + n_discard:n_past]
@@ -287,7 +822,7 @@ class StreamingLLM(Llama):
              tokens = [token]

      def load_session(self, filepath: str):
-         n_tokens = POINTER(llama_cpp.c_size_t)(llama_cpp.c_size_t(0))
          tokens = (llama_cpp.llama_token * self.n_ctx())()
          retn = llama_cpp.llama_load_session_file(self._ctx.ctx,
                                                   filepath.encode('utf-8'),

+ from llama_cpp import *
+ from ctypes import POINTER, c_size_t
+ from llama_cpp._internals import (
+     _LlamaModel,  # type: ignore
+     _LlamaContext,  # type: ignore
+     _LlamaBatch,  # type: ignore
+     _LlamaTokenDataArray,  # type: ignore
+ )

  from KMP_list import kmp_search, compute_lps_array
+ from Turbo_Colormap import map_value_to_color, NOCOLOR, LEGEND, BACK_WHITE
+
+
+ class LLMGenerate:
+     def __init__(
+             self,
+             model,
+             n_keep,
+             n_discard: int = 256,
+             im_start=None,
+             top_k: int = 40,
+             top_p: float = 0.95,
+             min_p: float = 0.05,
+             typical_p: float = 1.0,
+             temp: float = 0.80,
+             repeat_penalty: float = 1.1,
+             repeat_last_n: int = 64,
+             frequency_penalty: float = 0.0,
+             presence_penalty: float = 0.0,
+             tfs_z: float = 1.0,
+             mirostat_mode: int = 0,
+             mirostat_tau: float = 5.0,
+             mirostat_eta: float = 0.1
+     ):
+         def _eval_t(tokens):
+             return model.eval_t(
+                 tokens=tokens,
+                 n_keep=n_keep,
+                 n_discard=n_discard,
+                 im_start=im_start
+             )
+
+         def _sample_t(logits_processor):
+             return model.sample_t(
+                 top_k=top_k,
+                 top_p=top_p,
+                 min_p=min_p,
+                 typical_p=typical_p,
+                 temp=temp,
+                 repeat_penalty=repeat_penalty,
+                 repeat_last_n=repeat_last_n,
+                 frequency_penalty=frequency_penalty,
+                 presence_penalty=presence_penalty,
+                 tfs_z=tfs_z,
+                 mirostat_mode=mirostat_mode,
+                 mirostat_tau=mirostat_tau,
+                 mirostat_eta=mirostat_eta,
+                 logits_processor=logits_processor
+             )
+
+         self._eval_t = _eval_t
+         self._sample_t = _sample_t
+         self.str_detokenize = model.str_detokenize
+         self.venv_pop_token = model.venv_pop_token
+         # ========== store the output ==========
+         self.t_bot = []
+         self.completion_tokens = []
+         self.history = ''
+         self.token = None
+
+     def eval_t(self, tokens):
+         # ========== avoid splitting incomplete utf-8 sequences ==========
+         self.completion_tokens.extend(tokens)
+         all_text = self.str_detokenize(self.completion_tokens)
+         if all_text:
+             self.t_bot.extend(self.completion_tokens)
+             self.history += all_text
+             self.completion_tokens = []
+         return self._eval_t(tokens)
+
+     def sample_t(self, logits_processor):
+         self.token = self._sample_t(logits_processor)
+         return self.token
+
+     def detokenize_sample_t(self):
+         self.completion_tokens.append(self.token)
+         all_text = self.str_detokenize(self.completion_tokens)
+         if not all_text:
+             return False
+         self.t_bot.extend(self.completion_tokens)
+         self.history += all_text
+         self.completion_tokens = []
+         return True
+
+     def eval_sample_t(self):
+         return self._eval_t([self.token])
+
+     def endswith_t(self, token_list):
+         return self.token in token_list
+
+     def endswith_s(self, start_func, str_list, com_func=str.rstrip):
+         if self.completion_tokens:  # detokenization still incomplete
+             return False
+
+         history = self.history
+         t_bot = self.t_bot
+
+         if start_func(history):
+             history = com_func(history)
+             for x in str_list:
+                 if history.endswith(x):
+                     n = len(t_bot)
+                     for i in range(1, n):  # find how many tokens need to be discarded
+                         tmp = self.str_detokenize(t_bot[n - i:])
+                         tmp = com_func(tmp)
+                         if tmp.endswith(x):
+                             if i > 1:  # the last token has not been added to the kv_cache
+                                 self.venv_pop_token(i - 1)
+                             if history.endswith(tmp):
+                                 self.history = history[:-len(tmp)]  # strip the trailing tmp
+                             return True
+         return False
+

+ kv_cache_type = {
+     'f32': 0,
+     'f16': 1,
+     'q8_0': 8,
+     'q4_0': 2,
+     'q4_1': 3,
+     'iq4_nl': 20,
+     'q5_0': 6,
+     'q5_1': 7
+ }

  class StreamingLLM(Llama):
+
+     __backend_initialized = False
+
+     def __init__(
+             self,
+             model_path: str,
+             *,
+             # Model Params
+             n_gpu_layers: int = 0,
+             split_mode: int = llama_cpp.LLAMA_SPLIT_MODE_LAYER,
+             main_gpu: int = 0,
+             tensor_split: Optional[List[float]] = None,
+             vocab_only: bool = False,
+             use_mmap: bool = True,
+             use_mlock: bool = False,
+             kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None,
+             # Context Params
+             seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
+             n_ctx: int = 512,
+             n_batch: int = 512,
+             n_threads: Optional[int] = None,
+             n_threads_batch: Optional[int] = None,
+             rope_scaling_type: Optional[int] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
+             pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED,
+             rope_freq_base: float = 0.0,
+             rope_freq_scale: float = 0.0,
+             yarn_ext_factor: float = -1.0,
+             yarn_attn_factor: float = 1.0,
+             yarn_beta_fast: float = 32.0,
+             yarn_beta_slow: float = 1.0,
+             yarn_orig_ctx: int = 0,
+             logits_all: bool = False,
+             embedding: bool = False,
+             offload_kqv: bool = True,
+             # Sampling Params
+             last_n_tokens_size: int = 64,
+             # LoRA Params
+             lora_base: Optional[str] = None,
+             lora_scale: float = 1.0,
+             lora_path: Optional[str] = None,
+             # Backend Params
+             numa: Union[bool, int] = False,
+             # Chat Format Params
+             chat_format: Optional[str] = None,
+             chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
+             # Speculative Decoding
+             draft_model: Optional[LlamaDraftModel] = None,
+             # Tokenizer Override
+             tokenizer: Optional[BaseLlamaTokenizer] = None,
+             # Misc
+             verbose: bool = True,
+             # Extra Params
+             type_k: str = 'f16',
+             type_v: str = 'f16',
+             **kwargs,  # type: ignore
+     ):
+         """Load a llama.cpp model from `model_path`.
+
+         Examples:
+             Basic usage
+
+             >>> import llama_cpp
+             >>> model = llama_cpp.Llama(
+             ...     model_path="path/to/model",
+             ... )
+             >>> print(model("The quick brown fox jumps ", stop=["."])["choices"][0]["text"])
+             the lazy dog
+
+             Loading a chat model
+
+             >>> import llama_cpp
+             >>> model = llama_cpp.Llama(
+             ...     model_path="path/to/model",
+             ...     chat_format="llama-2",
+             ... )
+             >>> print(model.create_chat_completion(
+             ...     messages=[{
+             ...         "role": "user",
+             ...         "content": "what is the meaning of life?"
+             ...     }]
+             ... ))
+
+         Args:
+             model_path: Path to the model.
+             n_gpu_layers: Number of layers to offload to GPU (-ngl). If -1, all layers are offloaded.
+             split_mode: How to split the model across GPUs. See llama_cpp.LLAMA_SPLIT_* for options.
+             main_gpu: main_gpu interpretation depends on split_mode: LLAMA_SPLIT_NONE: the GPU that is used for the entire model. LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results. LLAMA_SPLIT_LAYER: ignored
+             tensor_split: How split tensors should be distributed across GPUs. If None, the model is not split.
+             vocab_only: Only load the vocabulary, no weights.
+             use_mmap: Use mmap if possible.
+             use_mlock: Force the system to keep the model in RAM.
+             kv_overrides: Key-value overrides for the model.
+             seed: RNG seed, -1 for random
+             n_ctx: Text context, 0 = from model
+             n_batch: Prompt processing maximum batch size
+             n_threads: Number of threads to use for generation
+             n_threads_batch: Number of threads to use for batch processing
+             rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
+             pooling_type: Pooling type, from `enum llama_pooling_type`.
+             rope_freq_base: RoPE base frequency, 0 = from model
+             rope_freq_scale: RoPE frequency scaling factor, 0 = from model
+             yarn_ext_factor: YaRN extrapolation mix factor, negative = from model
+             yarn_attn_factor: YaRN magnitude scaling factor
+             yarn_beta_fast: YaRN low correction dim
+             yarn_beta_slow: YaRN high correction dim
+             yarn_orig_ctx: YaRN original context size
+             logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.
+             embedding: Embedding mode only.
+             offload_kqv: Offload K, Q, V to GPU.
+             last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
+             lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
+             lora_path: Path to a LoRA file to apply to the model.
+             numa: numa policy
+             chat_format: String specifying the chat format to use when calling create_chat_completion.
+             chat_handler: Optional chat handler to use when calling create_chat_completion.
+             draft_model: Optional draft model to use for speculative decoding.
+             tokenizer: Optional tokenizer to override the default tokenizer from llama.cpp.
+             verbose: Print verbose output to stderr.
+
+         Raises:
+             ValueError: If the model path does not exist.
+
+         Returns:
+             A Llama instance.
+         """
+         self.verbose = verbose
+
+         set_verbose(verbose)
+
+         if not StreamingLLM.__backend_initialized:
+             with suppress_stdout_stderr(disable=verbose):
+                 llama_cpp.llama_backend_init()
+             StreamingLLM.__backend_initialized = True
+
+         if isinstance(numa, bool):
+             self.numa = (
+                 llama_cpp.GGML_NUMA_STRATEGY_DISTRIBUTE
+                 if numa
+                 else llama_cpp.GGML_NUMA_STRATEGY_DISABLED
+             )
+         else:
+             self.numa = numa
+
+         if self.numa != llama_cpp.GGML_NUMA_STRATEGY_DISABLED:
+             with suppress_stdout_stderr(disable=verbose):
+                 llama_cpp.llama_numa_init(self.numa)
+
+         self.model_path = model_path
+
+         # Model Params
+         self.model_params = llama_cpp.llama_model_default_params()
+         self.model_params.n_gpu_layers = (
+             0x7FFFFFFF if n_gpu_layers == -1 else n_gpu_layers
+         )  # 0x7FFFFFFF is INT32 max, will be auto set to all layers
+         self.model_params.split_mode = split_mode
+         self.model_params.main_gpu = main_gpu
+         self.tensor_split = tensor_split
+         self._c_tensor_split = None
+         if self.tensor_split is not None:
+             if len(self.tensor_split) > llama_cpp.LLAMA_MAX_DEVICES:
+                 raise ValueError(
+                     f"Attempt to split tensors that exceed maximum supported devices. Current LLAMA_MAX_DEVICES={llama_cpp.LLAMA_MAX_DEVICES}"
+                 )
+             # Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
+             FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES
+             self._c_tensor_split = FloatArray(
+                 *tensor_split  # type: ignore
+             )  # keep a reference to the array so it is not gc'd
+             self.model_params.tensor_split = self._c_tensor_split
+         self.model_params.vocab_only = vocab_only
+         self.model_params.use_mmap = use_mmap if lora_path is None else False
+         self.model_params.use_mlock = use_mlock
+
+         # kv_overrides is the original python dict
+         self.kv_overrides = kv_overrides
+         if kv_overrides is not None:
+             # _kv_overrides_array is a ctypes.Array of llama_model_kv_override Structs
+             kvo_array_len = len(kv_overrides) + 1  # for sentinel element
+             self._kv_overrides_array = (
+                 llama_cpp.llama_model_kv_override * kvo_array_len
+             )()
+
+             for i, (k, v) in enumerate(kv_overrides.items()):
+                 self._kv_overrides_array[i].key = k.encode("utf-8")
+                 if isinstance(v, bool):
+                     self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL
+                     self._kv_overrides_array[i].value.bool_value = v
+                 elif isinstance(v, int):
+                     self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT
+                     self._kv_overrides_array[i].value.int_value = v
+                 elif isinstance(v, float):
+                     self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
+                     self._kv_overrides_array[i].value.float_value = v
+                 else:
+                     raise ValueError(f"Unknown value type for {k}: {v}")
+
+             self._kv_overrides_array[-1].key = (
+                 b"\0"  # ensure sentinel element is zeroed
+             )
+             self.model_params.kv_overrides = self._kv_overrides_array
+
+         self.n_batch = min(n_ctx, n_batch)  # ???
+         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
+         self.n_threads_batch = n_threads_batch or max(
+             multiprocessing.cpu_count() // 2, 1
+         )
+
+         # Context Params
+         self.context_params = llama_cpp.llama_context_default_params()
+         self.context_params.seed = seed
+         self.context_params.n_ctx = n_ctx
+         self.context_params.n_batch = self.n_batch
+         self.context_params.n_threads = self.n_threads
+         self.context_params.n_threads_batch = self.n_threads_batch
+         self.context_params.rope_scaling_type = (
+             rope_scaling_type
+             if rope_scaling_type is not None
+             else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
+         )
+         self.context_params.pooling_type = pooling_type
+         self.context_params.rope_freq_base = (
+             rope_freq_base if rope_freq_base != 0.0 else 0
+         )
+         self.context_params.rope_freq_scale = (
+             rope_freq_scale if rope_freq_scale != 0.0 else 0
+         )
+         self.context_params.yarn_ext_factor = (
+             yarn_ext_factor if yarn_ext_factor != 0.0 else 0
+         )
+         self.context_params.yarn_attn_factor = (
+             yarn_attn_factor if yarn_attn_factor != 0.0 else 0
+         )
+         self.context_params.yarn_beta_fast = (
+             yarn_beta_fast if yarn_beta_fast != 0.0 else 0
+         )
+         self.context_params.yarn_beta_slow = (
+             yarn_beta_slow if yarn_beta_slow != 0.0 else 0
+         )
+         self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
+         self.context_params.logits_all = (
+             logits_all if draft_model is None else True
+         )  # Must be set to True for speculative decoding
+         self.context_params.embeddings = embedding  # TODO: Rename to embeddings
+
+         # KV cache quantization
+         print(self.context_params.type_k, self.context_params.type_v)
+         self.context_params.type_k = kv_cache_type[type_k]
+         self.context_params.type_v = kv_cache_type[type_v]
+
+         self.context_params.offload_kqv = offload_kqv
+
+         # Sampling Params
+         self.last_n_tokens_size = last_n_tokens_size
+
+         self.cache: Optional[BaseLlamaCache] = None
+
+         self.lora_base = lora_base
+         self.lora_scale = lora_scale
+         self.lora_path = lora_path
+
+         if not os.path.exists(model_path):
+             raise ValueError(f"Model path does not exist: {model_path}")
+
+         self._model = _LlamaModel(
+             path_model=self.model_path, params=self.model_params, verbose=self.verbose
+         )
+
+         # Override tokenizer
+         self.tokenizer_ = tokenizer or LlamaTokenizer(self)
+
+         # Set the default value for the context and correct the batch
+         if n_ctx == 0:
+             n_ctx = self._model.n_ctx_train()
+             self.n_batch = min(n_ctx, n_batch)
+             self.context_params.n_ctx = self._model.n_ctx_train()
+             self.context_params.n_batch = self.n_batch
+
+         self._ctx = _LlamaContext(
+             model=self._model,
+             params=self.context_params,
+             verbose=self.verbose,
+         )
+
+         self._batch = _LlamaBatch(
+             n_tokens=self.n_batch,
+             embd=0,
+             n_seq_max=self.context_params.n_ctx,
+             verbose=self.verbose,
+         )
+
+         if self.lora_path:
+             if self._model.apply_lora_from_file(
+                 self.lora_path,
+                 self.lora_scale,
+                 self.lora_base,
+                 self.n_threads,
+             ):
+                 raise RuntimeError(
+                     f"Failed to apply LoRA from lora path: {self.lora_path} to base path: {self.lora_base}"
+                 )
+
+         if self.verbose:
+             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
+
+         self.chat_format = chat_format
+         self.chat_handler = chat_handler
+
+         self.draft_model = draft_model
+
+         self._n_vocab = self.n_vocab()
+         self._n_ctx = self.n_ctx()
+
+         self._token_nl = self.token_nl()
+         self._token_eos = self.token_eos()
+
+         self._candidates = _LlamaTokenDataArray(n_vocab=self._n_vocab)
+
+         self.n_tokens = 0
+         self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
+         self.scores: npt.NDArray[np.single] = np.ndarray(
+             (n_ctx, self._n_vocab), dtype=np.single
+         )
+
+         self._mirostat_mu = ctypes.c_float(
+             2.0 * 5.0
+         )  # TODO: Move this to sampling context
+
+         try:
+             self.metadata = self._model.metadata()
+         except Exception as e:
+             self.metadata = {}
+             if self.verbose:
+                 print(f"Failed to load metadata: {e}", file=sys.stderr)
+
+         if self.verbose:
+             print(f"Model metadata: {self.metadata}", file=sys.stderr)
+
+         if (
+             self.chat_format is None
+             and self.chat_handler is None
+             and "tokenizer.chat_template" in self.metadata
+         ):
+             chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(
+                 self.metadata
+             )
+
+             if chat_format is not None:
+                 self.chat_format = chat_format
+                 if self.verbose:
+                     print(f"Guessed chat format: {chat_format}", file=sys.stderr)
+             else:
+                 template = self.metadata["tokenizer.chat_template"]
+                 try:
+                     eos_token_id = int(self.metadata["tokenizer.ggml.eos_token_id"])
+                 except:
+                     eos_token_id = self.token_eos()
+                 try:
+                     bos_token_id = int(self.metadata["tokenizer.ggml.bos_token_id"])
+                 except:
+                     bos_token_id = self.token_bos()
+
+                 eos_token = self._model.token_get_text(eos_token_id)
+                 bos_token = self._model.token_get_text(bos_token_id)
+
+                 if self.verbose:
+                     print(f"Using gguf chat template: {template}", file=sys.stderr)
+                     print(f"Using chat eos_token: {eos_token}", file=sys.stderr)
+                     print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
+
+                 self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
+                     template=template, eos_token=eos_token, bos_token=bos_token
+                 ).to_chat_handler()
+
+         if self.chat_format is None and self.chat_handler is None:
+             self.chat_format = "llama-2"
+             if self.verbose:
+                 print(f"Using fallback chat format: {chat_format}", file=sys.stderr)
          self._venv_init()

      def str_detokenize(self, tokens) -> str:

          if name not in self.venv_idx_map:
              return False
          venv_idx = self.venv_idx_map.index(name) + 1
+         count_name = self.venv_idx_map.count(name) if keep_last else 0
          while self.venv_idx_map:
+             if keep_last and count_name <= keep_last:
                  break  # keep the last n
              self.venv_idx_map.pop(venv_idx - 1)  # remove it
              if venv_idx == len(self.venv) - 1:

                  venv_idx = self.venv_idx_map.index(name, venv_idx - 1) + 1
              except ValueError:  # no more occurrences
                  break
+             count_name -= 1  # decrement the count
          return True

      def venv_pop_token(self, n=1):

      def venv_info(self):
          return str((self.n_tokens, self.venv, self.venv_idx_map))

+     def venv_viz(self):
+         completion_tokens = []
+         history = LEGEND + '\n'
+         text_color = NOCOLOR
+         for i in range(self.venv[-1]):
+             idx = self.n_tokens - self.venv[-1] + i
+             token = self._input_ids[idx]
+             if not completion_tokens:  # empty buffer means this is the first token of a character
+                 # ========== get the probability of this token ==========
+                 score = self.scores[idx-1: idx, :].ravel()  # the score for token i comes from the prediction at i-1, hence the -1
+                 score = np.exp(score)  # an empty row becomes all ones, which is fine here
+                 sum_score = np.sum(score)
+                 probabilities = score[token] / sum_score
+                 if probabilities < 0.001:
+                     text_color = NOCOLOR
+                 else:
+                     if text_color is NOCOLOR:
+                         text_color = BACK_WHITE + map_value_to_color(probabilities)
+                     else:
+                         text_color = map_value_to_color(probabilities)
+                 history += text_color
+             # ========== avoid splitting incomplete utf-8 sequences ==========
+             completion_tokens.append(token)
+             all_text = self.str_detokenize(completion_tokens)
+             if not all_text:
+                 continue
+             completion_tokens = []  # complete, clear the buffer
+             history += repr(all_text)[1:-1]
+         return history + NOCOLOR
+
      def kv_cache_seq_ltrim(self, n_keep, n_discard=256, n_past=-1, im_start=None):
          if n_keep < 0:
              return

              _idx = kmp_search(self.input_ids, im_start, n_keep, n_past, lps)
              if _idx >= n_keep:
                  n_keep = _idx + len(im_start)  # keep at least one im_start sequence structure
+             print(im_start, n_keep, n_discard, _idx)
          self._ctx.kv_cache_seq_rm(-1, n_keep, n_keep + n_discard)
          self._ctx.kv_cache_seq_shift(0, n_keep + n_discard, n_past, -n_discard)
          self.input_ids[n_keep:n_past - n_discard] = self.input_ids[n_keep + n_discard:n_past]
 
              tokens = [token]

      def load_session(self, filepath: str):
+         n_tokens = POINTER(c_size_t)(c_size_t(0))
          tokens = (llama_cpp.llama_token * self.n_ctx())()
          retn = llama_cpp.llama_load_session_file(self._ctx.ctx,
                                                   filepath.encode('utf-8'),
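For reference, a minimal sketch of how the pieces introduced in this commit could be driven together. It assumes a ChatML-style model and that sample_t accepts logits_processor=None; the model path, prompt, token budget, and choice of n_keep below are placeholders and are not part of the commit, and the call order is inferred from the LLMGenerate methods above rather than prescribed by the file itself.

# Hypothetical driver for the classes above; placeholder values are marked as such.
llm = StreamingLLM(
    model_path='path/to/model.gguf',        # placeholder path
    n_ctx=4096, n_gpu_layers=-1,
    type_k='q8_0', type_v='q8_0',           # quantized KV cache via the kv_cache_type table
)
prompt = '<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n'       # placeholder prompt
prompt_tokens = llm.tokenize(prompt.encode('utf-8'), special=True)
im_start = llm.tokenize('<|im_start|>'.encode('utf-8'), add_bos=False, special=True)

gen = LLMGenerate(model=llm, n_keep=len(prompt_tokens), n_discard=256,
                  im_start=im_start, temp=0.8, top_p=0.95)
gen.eval_t(prompt_tokens)                   # prefill; detokenized text accumulates in gen.history
for _ in range(512):                        # placeholder generation budget
    gen.sample_t(logits_processor=None)
    if gen.endswith_t([llm.token_eos()]):   # stop on end-of-sequence
        break
    gen.detokenize_sample_t()               # flushes to gen.history once the utf-8 sequence is complete
    gen.eval_sample_t()                     # feed the sampled token back into the context
print(gen.history)
print(llm.venv_viz())                       # colorized per-token probabilities (Turbo_Colormap legend)

Context trimming (kv_cache_seq_ltrim) is driven from eval_t through the n_keep / n_discard / im_start arguments, and load_session(filepath) can be called on the StreamingLLM instance separately to restore a saved llama.cpp session.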