Limour committed on
Commit
f4e6998
1 Parent(s): 1ee05f3

Upload 2 files

Files changed (2)
  1. KMP_list.py +55 -0
  2. llama_cpp_python_streamingllm.py +282 -0
KMP_list.py ADDED
@@ -0,0 +1,55 @@
+ def compute_lps_array(sublist):
+     """
+     Compute the longest proper prefix-suffix (LPS) array for the pattern.
+     """
+     lps = [0] * len(sublist)
+     j = 0
+     i = 1
+     while i < len(sublist):
+         if sublist[i] == sublist[j]:
+             j += 1
+             lps[i] = j
+             i += 1
+         else:
+             if j != 0:
+                 j = lps[j - 1]
+             else:
+                 lps[i] = 0
+                 i += 1
+     return lps
+
+
+ def kmp_search(main_list, sublist, _start=0, _end=None, lps=None):
+     """
+     Search for a sublist within a list using the KMP algorithm.
+     """
+     if not sublist:
+         return 0
+     if _end is None:
+         _end = len(main_list)
+     if lps is None:
+         lps = compute_lps_array(sublist)
+     i = _start  # index into the main list
+     j = 0  # index into the sublist
+     while i < _end:
+         if main_list[i] == sublist[j]:
+             i += 1
+             j += 1
+             if j == len(sublist):
+                 return i - j
+         else:
+             if j != 0:
+                 j = lps[j - 1]
+             else:
+                 i += 1
+     return -1
+
+
+ if __name__ == '__main__':
+     a = [1, 1, 3, 2, 3, 6, 7, 8, 3, 2, 3]
+     b = [3, 2, 3]
+     c = compute_lps_array(b)
+     print(kmp_search(a, b, lps=c))
+     print(kmp_search(a, b, 3, lps=c))
+     print(kmp_search(a, b, 3, 10, lps=c))
+     print(kmp_search(a, b, 9, lps=c))
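+     # Expected output (editor's note, by manual trace): the pattern occurs
+     # at indices 2 and 8; no complete match fits inside [3, 10) or starts
+     # at index 9 or later, so the four calls print 2, 8, -1, -1.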
llama_cpp_python_streamingllm.py ADDED
@@ -0,0 +1,282 @@
+ from typing import Optional, Sequence, Generator
+
+ from llama_cpp import Llama, LogitsProcessorList, LlamaGrammar, llama_cpp, npt, np, StoppingCriteriaList
+ from ctypes import POINTER
+
+ from KMP_list import kmp_search, compute_lps_array
+
+
+ def is_UTF8_incomplete(all_text):
+     multibyte_fix = 0
+     if len(all_text) < 3:
+         all_text = b'000' + all_text  # pad with ASCII so the 3-byte tail below exists
+     for k, char in enumerate(all_text[-3:]):
+         k = 3 - k
+         # 192/224/240 are the lead bytes of 2-/3-/4-byte UTF-8 sequences
+         for num, pattern in [(2, 192), (3, 224), (4, 240)]:
+             # Bitwise AND check
+             if num > k and pattern & char == pattern:
+                 multibyte_fix = num - k
+     return multibyte_fix
+
+
+ def get_complete_UTF8(all_text):
+     multibyte_fix = is_UTF8_incomplete(all_text)
+     if multibyte_fix > 0:
+         multibyte_fix = multibyte_fix - 3
+         return all_text[:multibyte_fix].decode("utf-8")
+     else:
+         return all_text.decode("utf-8")
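+ # Editor's sanity check for the two helpers above: the first two bytes of
+ # '你'.encode('utf-8') are a dangling 3-byte sequence, so
+ # get_complete_UTF8(b'hi' + '你'.encode('utf-8')[:2]) drops the incomplete
+ # tail and returns 'hi'.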
+
+
+ class StreamingLLM(Llama):
+     def __init__(self, model_path: str, **kwargs):
+         super().__init__(model_path, **kwargs)
+         self.venv = [0]
+
+     def str_detokenize(self, tokens) -> str:
+         return get_complete_UTF8(self.detokenize(tokens))
+
+     def kv_cache_seq_trim(self):
+         self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
+
+     def venv_create(self):
+         self.venv.append(0)
+         return len(self.venv) - 1
+
+     def venv_disband(self):
+         if len(self.venv) <= 1:
+             return 0
+         tmp = self.venv.pop()
+         self.venv[-1] += tmp
+         return len(self.venv) - 1
+
+     def venv_remove(self, venv_idx=None):
+         if venv_idx is None:
+             venv_idx = len(self.venv) - 1
+         if venv_idx <= 0 or venv_idx >= len(self.venv):
+             return len(self.venv) - 1
+         if venv_idx == len(self.venv) - 1:
+             # top layer: drop its tokens from the tail of the cache
+             self.n_tokens -= min(self.venv.pop(), self.n_tokens)
+             self.kv_cache_seq_trim()
+         else:
+             # inner layer: remove its span and shift later layers left
+             n_keep = self.n_tokens - sum(self.venv[i] for i in range(venv_idx, len(self.venv)))
+             n_discard = self.venv.pop(venv_idx)
+             self.kv_cache_seq_ltrim(n_keep, n_discard)
+         return len(self.venv) - 1
+
+     def venv_pop_token(self):
+         self.n_tokens -= 1
+         self.venv[-1] -= 1
+         self.kv_cache_seq_trim()
+
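+     # Editor's sketch of the bookkeeping above: self.venv is a stack of
+     # per-layer token counts, one entry per "virtual environment". With
+     # self.venv == [0, 12, 5], venv_disband() merges the top layer into its
+     # parent, giving [0, 17], while venv_remove() instead drops the top
+     # layer's 5 tokens from the KV cache entirely.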
+     def kv_cache_seq_ltrim(self, n_keep, n_discard=256, n_past=-1, im_start=None):
+         if n_past < 0:
+             n_past = self.n_tokens
+         if im_start is not None:  # [<|im_start|>, name, nl]
+             lps = compute_lps_array(im_start)
+             _idx = kmp_search(self.input_ids, im_start, n_keep + n_discard, n_past, lps)
+             if _idx >= n_keep:  # in fact >= n_keep + n_discard here
+                 n_discard = _idx - n_keep  # cut at the nearest im_start sequence
+             else:
+                 _idx = kmp_search(self.input_ids, im_start, n_keep, n_past, lps)
+                 if _idx >= n_keep:
+                     n_keep = _idx + len(im_start)  # keep at least one im_start sequence
+         self._ctx.kv_cache_seq_rm(-1, n_keep, n_keep + n_discard)
+         self._ctx.kv_cache_seq_shift(0, n_keep + n_discard, n_past, -n_discard)
+         self.input_ids[n_keep:n_past - n_discard] = self.input_ids[n_keep + n_discard:n_past]
+         self.n_tokens = n_past - n_discard
+
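+     # Editor's sketch of the trim above: with n_keep=2, n_discard=2 and
+     # cached tokens [a b | c d | e f], c and d are removed, e and f shift
+     # left by two cache positions, and n_tokens drops from 6 to 4.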
+     def eval_t(self, tokens, n_keep=4, n_discard=256, im_start=None):
+         if self._n_ctx < self.n_tokens + len(tokens):
+             tmp_n_discard = max(n_discard, self.n_tokens + len(tokens) - self._n_ctx)
+             self.kv_cache_seq_ltrim(n_keep, tmp_n_discard, im_start=im_start)
+         for i in range(0, len(tokens), self.n_batch):
+             batch = tokens[i: i + self.n_batch]
+             n_past = self.n_tokens
+             n_tokens = len(batch)
+             self._batch.set_batch(
+                 batch=batch, n_past=n_past, logits_all=self.context_params.logits_all
+             )
+             self._ctx.decode(self._batch)
+             # Save tokens
+             self.input_ids[n_past: n_past + n_tokens] = batch
+             # Save logits
+             rows = n_tokens
+             cols = self._n_vocab
+             offset = (
+                 0 if self.context_params.logits_all else n_tokens - 1
+             )  # NOTE: Only save the last token logits if logits_all is False
+             self.scores[n_past + offset: n_past + n_tokens, :].reshape(-1)[
+                 :
+             ] = self._ctx.get_logits()[offset * cols: rows * cols]
+             # Update n_tokens
+             self.n_tokens += n_tokens
+             self.venv[-1] += n_tokens
+         return self.n_tokens
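+     # Editor's note on the overflow guard above: if the context would be
+     # exceeded, at least (n_tokens + len(tokens) - n_ctx) positions must be
+     # freed before decoding, e.g. n_ctx=8, n_tokens=6, len(tokens)=4 gives
+     # tmp_n_discard = max(n_discard, 2).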
+
+     def sample_t(
+         self,
+         top_k: int = 40,
+         top_p: float = 0.95,
+         min_p: float = 0.05,
+         typical_p: float = 1.0,
+         temp: float = 0.80,
+         repeat_penalty: float = 1.1,
+         repeat_last_n: int = 64,
+         frequency_penalty: float = 0.0,
+         presence_penalty: float = 0.0,
+         tfs_z: float = 1.0,
+         mirostat_mode: int = 0,
+         mirostat_eta: float = 0.1,
+         mirostat_tau: float = 5.0,
+         penalize_nl: bool = True,
+         logits_processor: Optional[LogitsProcessorList] = None,
+         grammar: Optional[LlamaGrammar] = None,
+     ):
+         last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
+             0, repeat_last_n - self.n_tokens
+         ) + self._input_ids[-repeat_last_n:].tolist()
+         last_n_tokens_size = len(last_n_tokens_data)
+         n_vocab = self._n_vocab
+         n_ctx = self._n_ctx
+         top_k = n_vocab if top_k <= 0 else top_k
+         last_n_tokens_size = n_ctx if last_n_tokens_size < 0 else last_n_tokens_size
+         last_n_tokens_data_c = (llama_cpp.llama_token * last_n_tokens_size)(
+             *last_n_tokens_data
+         )
+         logits: npt.NDArray[np.single] = self.scores[self.n_tokens - 1: self.n_tokens, :].ravel()
+
+         if logits_processor is not None:
+             logits[:] = logits_processor(self._input_ids, logits)
+
+         self._candidates.copy_logits(logits)
+         self._ctx.sample_repetition_penalties(
+             candidates=self._candidates,
+             last_tokens_data=last_n_tokens_data_c,
+             penalty_last_n=last_n_tokens_size,
+             penalty_repeat=repeat_penalty,
+             penalty_freq=frequency_penalty,
+             penalty_present=presence_penalty,
+         )
+         if not penalize_nl:
+             nl_logit = logits[self._token_nl]
+             self._candidates.candidates.data[self._token_nl].logit = llama_cpp.c_float(
+                 nl_logit
+             )
+
+         if grammar is not None:
+             self._ctx.sample_grammar(
+                 candidates=self._candidates,
+                 grammar=grammar,
+             )
+
+         if temp < 0.0:
+             self._ctx.sample_softmax(candidates=self._candidates)
+             id_ = self._candidates.candidates.data[0].id
+         elif temp == 0.0:
+             id_ = self._ctx.sample_token_greedy(candidates=self._candidates)
+         elif mirostat_mode == 1:
+             self._ctx.sample_temp(candidates=self._candidates, temp=temp)
+             id_ = self._ctx.sample_token_mirostat(
+                 candidates=self._candidates,
+                 tau=mirostat_tau,
+                 eta=mirostat_eta,
+                 mu=2.0 * mirostat_tau,
+                 m=100,
+             )
+         elif mirostat_mode == 2:
+             self._ctx.sample_temp(candidates=self._candidates, temp=temp)
+             id_ = self._ctx.sample_token_mirostat_v2(
+                 candidates=self._candidates,
+                 tau=mirostat_tau,
+                 eta=mirostat_eta,
+                 mu=2.0 * mirostat_tau,
+             )
+         else:
+             self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
+             self._ctx.sample_tail_free(candidates=self._candidates, z=tfs_z, min_keep=1)
+             self._ctx.sample_typical(
+                 candidates=self._candidates, p=typical_p, min_keep=1
+             )
+             self._ctx.sample_top_p(candidates=self._candidates, p=top_p, min_keep=1)
+             self._ctx.sample_min_p(candidates=self._candidates, p=min_p, min_keep=1)
+             self._ctx.sample_temp(candidates=self._candidates, temp=temp)
+             id_ = self._ctx.sample_token(candidates=self._candidates)
+         if grammar is not None:
+             self._ctx.grammar_accept_token(grammar=grammar, token=id_)
+         return id_
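+     # Editor's note: with the default settings above (temp > 0, mirostat
+     # off), candidates flow through top_k -> tail_free -> typical -> top_p
+     # -> min_p -> temperature before the final draw.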
+
+     def generate_t(
+         self,
+         tokens: Sequence[int],
+         n_keep,
+         n_discard: int = 256,
+         im_start=None,
+         top_k: int = 40,
+         top_p: float = 0.95,
+         min_p: float = 0.05,
+         typical_p: float = 1.0,
+         temp: float = 0.80,
+         repeat_penalty: float = 1.1,
+         repeat_last_n: int = 64,
+         frequency_penalty: float = 0.0,
+         presence_penalty: float = 0.0,
+         tfs_z: float = 1.0,
+         mirostat_mode: int = 0,
+         mirostat_tau: float = 5.0,
+         mirostat_eta: float = 0.1,
+         logits_processor: Optional[LogitsProcessorList] = None,
+         stopping_criteria: Optional[StoppingCriteriaList] = None,
+         grammar: Optional[LlamaGrammar] = None,
+     ) -> Generator[int, Optional[Sequence[int]], None]:
+         typical_p = float(typical_p)
+         frequency_penalty = float(frequency_penalty)
+         presence_penalty = float(presence_penalty)
+         tfs_z = float(tfs_z)
+         mirostat_tau = float(mirostat_tau)
+         while True:
+             self.eval_t(tokens, n_keep, n_discard, im_start=im_start)
+             token = self.sample_t(
+                 top_k=top_k,
+                 top_p=top_p,
+                 min_p=min_p,
+                 typical_p=typical_p,
+                 temp=temp,
+                 repeat_penalty=repeat_penalty,
+                 repeat_last_n=repeat_last_n,
+                 frequency_penalty=frequency_penalty,
+                 presence_penalty=presence_penalty,
+                 tfs_z=tfs_z,
+                 mirostat_mode=mirostat_mode,
+                 mirostat_tau=mirostat_tau,
+                 mirostat_eta=mirostat_eta,
+                 logits_processor=logits_processor,
+                 grammar=grammar,
+             )
+             if stopping_criteria is not None and stopping_criteria(
+                 self._input_ids, self._scores[-1, :]
+             ):
+                 return
+             tokens_or_none = yield token
+             tokens = [token]
+             if tokens_or_none is not None:
+                 tokens.extend(tokens_or_none)
+
+     def load_session(self, filepath: str):
+         n_tokens = POINTER(llama_cpp.c_size_t)(llama_cpp.c_size_t(0))
+         tokens = (llama_cpp.llama_token * self.n_ctx())()
+         retn = llama_cpp.llama_load_session_file(self._ctx.ctx,
+                                                  filepath.encode('utf-8'),
+                                                  tokens,
+                                                  self.n_ctx(),
+                                                  n_tokens)
+         self.n_tokens = n_tokens.contents.value
+         self.input_ids[:self.n_tokens] = tokens[:self.n_tokens]
+         return retn
+
+     def save_session(self, filepath: str):
+         tokens = self._input_ids.tolist()
+         tokens = (llama_cpp.llama_token * len(tokens))(*tokens)
+         return llama_cpp.llama_save_session_file(self._ctx.ctx, filepath.encode('utf-8'), tokens, self.n_tokens)
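
For readers landing on this commit, a minimal usage sketch of StreamingLLM (editor's addition, not part of the commit; the model path is a placeholder, and the n_ctx, n_keep, and n_discard values are arbitrary):

    from llama_cpp_python_streamingllm import StreamingLLM

    llm = StreamingLLM("/path/to/model.gguf", n_ctx=4096)  # placeholder path
    prompt = llm.tokenize(b"Hello, ", add_bos=True)
    venv = llm.venv_create()  # open a scratch layer for this exchange
    out = []
    for token in llm.generate_t(prompt, n_keep=4, n_discard=256):
        if token == llm.token_eos() or len(out) >= 64:
            break
        out.append(token)
    print(llm.str_detokenize(out))  # decode only the complete UTF-8 prefix
    llm.venv_remove(venv)  # roll the whole exchange back out of the KV cache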