AllenYkl committed on
Commit
e6a65ef
1 Parent(s): c5ac60d

Delete bin_public/utils/utils.py

Browse files
Files changed (1) hide show
  1. bin_public/utils/utils.py +0 -472
bin_public/utils/utils.py DELETED
@@ -1,472 +0,0 @@
1
- # -*- coding:utf-8 -*-
2
- from __future__ import annotations
3
- from typing import TYPE_CHECKING, List, Tuple
4
- import logging
5
- import json
6
- import gradio as gr
7
- # import openai
8
- import os
9
- import requests
10
- # import markdown
11
- import csv
12
- import mdtex2html
13
- from pypinyin import lazy_pinyin
14
- from bin_public.config.presets import *
15
- import tiktoken
16
- from tqdm import tqdm
17
- from duckduckgo_search import ddg
18
- from bin_public.utils.utils_db import *
19
- import datetime
20
-
21
# Configure the root logger at import time: INFO level, each record carrying
# timestamp, level name, source file name and line number.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")

# This branch is only taken by static type checkers (TYPE_CHECKING is False at
# runtime), so DataframeData exists for annotations only.
if TYPE_CHECKING:
    from typing import TypedDict

    class DataframeData(TypedDict):
        # Typed-dict shape with a list of header strings and a list of rows.
        # NOTE(review): presumably mirrors a Gradio dataframe payload - confirm.
        headers: List[str]
        data: List[List[str | int | bool]]

# NOTE(review): presumably the default system prompt; not referenced in this file.
initial_prompt = "You are a helpful assistant."
# Endpoint that get_response() POSTs chat-completion requests to.
API_URL = "https://api.openai.com/v1/chat/completions"
# Directory that save_file() writes to and load_chat_history() reads from.
HISTORY_DIR = "history"
# Directory that load_template() reads template files from.
TEMPLATES_DIR = "templates"
34
-
35
-
36
def postprocess(
    self, y: List[Tuple[str | None, str | None]]
) -> List[Tuple[str | None, str | None]]:
    """Convert every (message, response) pair in ``y`` to HTML, in place.

    Parameters:
        y: List of (message, response) tuples. Each element is either a
            string (which may contain Markdown) or ``None``.
    Returns:
        ``y`` itself, with each string replaced by the output of
        ``mdtex2html.convert`` and each ``None`` left untouched. A new empty
        list is returned when ``y`` is ``None``.
    """
    if y is None:
        return []

    def _render(cell):
        # ``None`` marks a missing side of the exchange and is passed through.
        return cell if cell is None else mdtex2html.convert(cell)

    for index, pair in enumerate(y):
        user_msg, bot_msg = pair
        y[index] = (_render(user_msg), _render(bot_msg))
    return y
55
-
56
-
57
def count_token(input_str):
    """Return how many ``cl100k_base`` tokens ``input_str`` encodes to."""
    tokens = tiktoken.get_encoding("cl100k_base").encode(input_str)
    return len(tokens)
61
-
62
-
63
def parse_text(text):
    """Convert chat text containing fenced code blocks into an HTML fragment.

    Empty lines are dropped. A line containing three backticks toggles code
    mode: an opening fence becomes ``<pre><code class="language-X">`` (X is
    whatever follows the last backtick on that line) and a closing fence
    becomes ``<br></code></pre>``. Every other line except the very first is
    prefixed with ``<br>``; lines inside a code block additionally have
    Markdown/HTML-significant characters replaced by escapes so they are
    displayed literally.

    Parameters:
        text: The raw message text.
    Returns:
        The converted lines concatenated into one string.
    """
    # Escapes applied to lines inside a code block. None of the replacement
    # strings contains a character that is itself a key, so one translate()
    # pass produces exactly what the sequence of replace() calls produced.
    # "\\`" spells backslash + backtick explicitly; the previous "\`" relied
    # on an invalid escape sequence, which newer Python versions warn about.
    code_escapes = str.maketrans({
        "`": "\\`",
        "<": "&lt;",
        ">": "&gt;",
        " ": "&nbsp;",
        "*": "&ast;",
        "_": "&lowbar;",
        "-": "&#45;",
        ".": "&#46;",
        "!": "&#33;",
        "(": "&#40;",
        ")": "&#41;",
        "$": "&#36;",
    })
    lines = [line for line in text.split("\n") if line != ""]
    fence_count = 0  # an odd value means we are inside a code block
    for i, line in enumerate(lines):
        if "```" in line:
            fence_count += 1
            if fence_count % 2 == 1:
                language = line.split("`")[-1]
                lines[i] = f'<pre><code class="language-{language}">'
            else:
                lines[i] = "<br></code></pre>"
        elif i > 0:
            # The first line is emitted as-is, without a leading <br>.
            if fence_count % 2 == 1:
                line = line.translate(code_escapes)
            lines[i] = "<br>" + line
    return "".join(lines)
93
-
94
-
95
def construct_text(role, text):
    """Build a chat message dict with ``role`` and ``content`` keys."""
    return dict(role=role, content=text)
97
-
98
-
99
def construct_user(text):
    """Build a chat message dict whose role is ``"user"``."""
    return construct_text(role="user", text=text)
101
-
102
-
103
def construct_system(text):
    """Build a chat message dict whose role is ``"system"``."""
    return construct_text(role="system", text=text)
105
-
106
-
107
def construct_assistant(text):
    """Build a chat message dict whose role is ``"assistant"``."""
    return construct_text(role="assistant", text=text)
109
-
110
-
111
def construct_token_message(token, stream=False):
    """Format the token-count status line shown to the user.

    ``stream`` is accepted but does not influence the result.
    """
    return "Token 计数: {}".format(token)
113
-
114
-
115
def get_response(openai_api_key, system_prompt, history, temperature, top_p, stream, selected_model):
    """POST a chat-completion request to ``API_URL`` and return the response.

    The system prompt is sent as the first message, followed by every entry
    of ``history``. The request timeout is ``timeout_streaming`` when
    ``stream`` is truthy and ``timeout_all`` otherwise. The HTTP response is
    always opened in streaming mode; ``stream`` only controls the ``stream``
    field of the JSON payload.

    Returns:
        The ``requests`` response object, unread.
    """
    messages = [construct_system(system_prompt)]
    messages.extend(history)
    request_body = dict(
        model=selected_model,
        messages=messages,
        temperature=temperature,
        top_p=top_p,
        n=1,
        stream=stream,
        presence_penalty=0,
        frequency_penalty=0,
    )
    auth_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openai_api_key}",
    }
    timeout = timeout_streaming if stream else timeout_all
    return requests.post(API_URL, headers=auth_headers, json=request_body, stream=True, timeout=timeout)
139
-
140
-
141
def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, selected_model):
    """Stream one chat completion, yielding the UI state after every update.

    Appends the user message and an empty assistant placeholder to
    ``history`` and a matching row to ``chatbot``, requests a streamed
    completion through ``get_response`` and fills the placeholder in as
    chunks arrive.

    Yields:
        ``(chatbot, history, status_text, all_token_counts)`` tuples. The
        three containers are the caller's objects, mutated in place.
    """
    def get_return_value():
        # Closure over the enclosing locals: evaluated at call time, so each
        # yield reports the current status_text.
        return chatbot, history, status_text, all_token_counts

    logging.info("实时回答模式")
    partial_words = ""
    counter = 0
    status_text = "开始实时传输回答……"
    history.append(construct_user(inputs))
    history.append(construct_assistant(""))
    chatbot.append((parse_text(inputs), ""))
    user_token_count = 0
    # The system prompt is only counted on the first turn of a conversation.
    if len(all_token_counts) == 0:
        system_prompt_token_count = count_token(system_prompt)
        user_token_count = count_token(inputs) + system_prompt_token_count
    else:
        user_token_count = count_token(inputs)
    all_token_counts.append(user_token_count)
    logging.info(f"输入token计数: {user_token_count}")
    yield get_return_value()
    try:
        # Note: ``history`` already ends with the empty assistant placeholder
        # appended above, and that is what gets sent.
        response = get_response(openai_api_key, system_prompt, history, temperature, top_p, True, selected_model)
    except requests.exceptions.ConnectTimeout:
        status_text = standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
        yield get_return_value()
        return
    except requests.exceptions.ReadTimeout:
        status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
        yield get_return_value()
        return

    yield get_return_value()
    error_json_str = ""

    for chunk in tqdm(response.iter_lines()):
        # The first line of the response is skipped unconditionally.
        if counter == 0:
            counter += 1
            continue
        counter += 1
        # check whether each line is non-empty
        if chunk:
            chunk = chunk.decode()
            chunklength = len(chunk)
            try:
                # Drops the first six characters before parsing.
                # NOTE(review): presumably the "data: " prefix of a
                # server-sent event - confirm against the API.
                chunk = json.loads(chunk[6:])
            except json.JSONDecodeError:
                # Unparseable lines are accumulated and shown to the user;
                # the stream keeps being consumed.
                logging.info(chunk)
                error_json_str += chunk
                status_text = f"JSON解析错误。请重置对话。收到的内容: {error_json_str}"
                yield get_return_value()
                continue
            # decode each line as response data is in bytes
            if chunklength > 6 and "delta" in chunk['choices'][0]:
                finish_reason = chunk['choices'][0]['finish_reason']
                status_text = construct_token_message(sum(all_token_counts), stream=True)
                if finish_reason == "stop":
                    yield get_return_value()
                    break
                try:
                    partial_words = partial_words + chunk['choices'][0]["delta"]["content"]
                except KeyError:
                    # A delta without "content" ends the stream with an error status.
                    status_text = standard_error_msg + "API回复中找不到内容。很可能是Token计数达到上限。请重置对话。当前Token计数: " + str(sum(all_token_counts))
                    yield get_return_value()
                    break
                history[-1] = construct_assistant(partial_words)
                chatbot[-1] = (parse_text(inputs), parse_text(partial_words))
                # The running count for this turn grows by one per content chunk.
                all_token_counts[-1] += 1
                yield get_return_value()
209
-
210
-
211
def predict_all(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, selected_model):
    """Request one complete (non-streamed) chat completion.

    Appends the user message and an assistant placeholder to ``history`` and
    a matching row to ``chatbot``, then replaces the placeholder with the
    content of the API reply. The containers are mutated in place.

    Returns:
        ``(chatbot, history, status_text, all_token_counts)``. On a connect
        timeout, proxy error or SSL error, ``status_text`` is the matching
        error message and the placeholder entries are left as they are.
    """
    logging.info("一次性回答模式")
    history.extend((construct_user(inputs), construct_assistant("")))
    chatbot.append((parse_text(inputs), ""))
    all_token_counts.append(count_token(inputs))

    handled_errors = (
        requests.exceptions.ConnectTimeout,
        requests.exceptions.ProxyError,
        requests.exceptions.SSLError,
    )
    try:
        raw_response = get_response(openai_api_key, system_prompt, history, temperature, top_p, False, selected_model)
    except handled_errors as err:
        # Choose the prompt in the same precedence as the error types above.
        if isinstance(err, requests.exceptions.ConnectTimeout):
            reason = connection_timeout_prompt
        elif isinstance(err, requests.exceptions.ProxyError):
            reason = proxy_error_prompt
        else:
            reason = ssl_error_prompt
        return chatbot, history, standard_error_msg + reason + error_retrieve_prompt, all_token_counts

    reply = json.loads(raw_response.text)
    answer = reply["choices"][0]["message"]["content"]
    history[-1] = construct_assistant(answer)
    chatbot[-1] = (parse_text(inputs), parse_text(answer))
    reported_total = reply["usage"]["total_tokens"]
    # The sum on the right still includes this turn's own input count.
    all_token_counts[-1] = reported_total - sum(all_token_counts)
    return chatbot, history, construct_token_message(reported_total), all_token_counts
236
-
237
-
238
def predict(openai_api_key, invite_code, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, stream=False, selected_model = MODELS[0], use_websearch_checkbox = False, should_check_token_count = True): # repetition_penalty, top_k
    """Answer one user input, yielding the UI state as it evolves.

    Steps, in order: optionally rewrite ``inputs`` with web-search results,
    validate the API key, delegate to ``stream_predict`` or ``predict_all``,
    record the exchange through ``holo_query_insert_chat_message``, and, when
    the summed token counts exceed the limit, run ``reduce_token_size``.

    Yields:
        ``(chatbot, history, status_text, all_token_counts)`` tuples.
    """
    # (disabled) colourised logging of the user input
    if use_websearch_checkbox:
        results = ddg(inputs, max_results=3)
        web_results = []
        # NOTE(review): assumes ddg() returns an iterable of dicts with
        # "body" and "href" keys - confirm, including the no-result case.
        for idx, result in enumerate(results):
            logging.info(f"搜索结果{idx + 1}:{result}")
            web_results.append(f'[{idx+1}]"{result["body"]}"\nURL: {result["href"]}')
        web_results = "\n\n".join(web_results)
        today = datetime.datetime.today().strftime("%Y-%m-%d")
        # From here on ``inputs`` is the filled-in web-search prompt, not the raw query.
        inputs = websearch_prompt.replace("{current_date}", today).replace("{query}", inputs).replace("{web_results}", web_results)
    # Any key whose length is not exactly 51 characters is rejected without a request.
    if len(openai_api_key) != 51:
        status_text = standard_error_msg + no_apikey_msg
        logging.info(status_text)
        chatbot.append((parse_text(inputs), ""))
        if len(history) == 0:
            history.append(construct_user(inputs))
            # The placeholder here is a plain empty string, not a message dict.
            history.append("")
            all_token_counts.append(0)
        else:
            history[-2] = construct_user(inputs)
        yield chatbot, history, status_text, all_token_counts
        return
    if stream:
        yield chatbot, history, "开始生成回答……", all_token_counts
    if stream:
        logging.info("使用流式传输")
        # ``iter`` shadows the builtin of the same name inside this function.
        iter = stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, selected_model)
        for chatbot, history, status_text, all_token_counts in iter:
            yield chatbot, history, status_text, all_token_counts
    else:
        logging.info("不使用流式传输")
        chatbot, history, status_text, all_token_counts = predict_all(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, selected_model)
        yield chatbot, history, status_text, all_token_counts
    logging.info(f"传输完毕。当前token计数为{all_token_counts}")
    # Only record the exchange when the last history entry differs from the input.
    if len(history) > 1 and history[-1]['content'] != inputs:
        # (disabled) colourised logging of the answer
        try:
            token = all_token_counts[-1]
        except:
            # NOTE(review): bare except; the lookup above can raise
            # IndexError on an empty list, everything else is swallowed too.
            token = 0
        holo_query_insert_chat_message(invite_code, inputs, history[-1]['content'], token, history)
    if stream:
        max_token = max_token_streaming
    else:
        max_token = max_token_all
    # reduce_token_size() calls back into predict() with
    # should_check_token_count=False, which keeps this branch from recursing.
    if sum(all_token_counts) > max_token and should_check_token_count:
        status_text = f"精简token中{all_token_counts}/{max_token}"
        logging.info(status_text)
        yield chatbot, history, status_text, all_token_counts
        iter = reduce_token_size(openai_api_key, invite_code, system_prompt, history, chatbot, all_token_counts, top_p, temperature, stream=False, selected_model=selected_model, hidden=True)
        for chatbot, history, status_text, all_token_counts in iter:
            status_text = f"Token 达到上限,已自动降低Token计数至 {status_text}"
            yield chatbot, history, status_text, all_token_counts
292
-
293
-
294
def retry(openai_api_key, invite_code, system_prompt, history, chatbot, token_count, top_p, temperature, stream=False, selected_model = MODELS[0]):
    """Re-ask the most recent user message.

    Removes the last two ``history`` entries and the last ``token_count``
    entry, then replays the removed user message through ``predict`` and
    relays everything it yields. With an empty ``history`` a single error
    state is yielded instead.
    """
    logging.info("重试中……")
    if not history:
        yield chatbot, history, f"{standard_error_msg}上下文是空的", token_count
        return

    history.pop()  # discard the previous answer
    previous_question = history.pop()
    inputs = previous_question["content"]
    token_count.pop()
    replay = predict(openai_api_key, invite_code, system_prompt, history, inputs, chatbot, token_count, top_p, temperature, stream=stream, selected_model=selected_model)
    # Logged before the generator is consumed, i.e. before the request runs.
    logging.info("重试完毕")
    for update in replay:
        yield update
306
-
307
-
308
def reduce_token_size(openai_api_key, invite_code, system_prompt, history, chatbot, token_count, top_p, temperature, stream=False, selected_model = MODELS[0], hidden=False):
    """Shrink the conversation by asking the model for a summary.

    Sends ``summarize_prompt`` through ``predict`` (with the token-limit
    check disabled). For every state it yields, only the last two history
    entries and the last token count are passed on; when ``hidden`` is
    truthy the newest chatbot row is popped first.

    Yields:
        ``(chatbot, history, token_message, token_count)`` tuples.
    """
    logging.info("开始减少token数量……")
    summary_updates = predict(openai_api_key, invite_code, system_prompt, history, summarize_prompt, chatbot, token_count, top_p, temperature, stream=stream, selected_model = selected_model, should_check_token_count=False)
    logging.info(f"chatbot: {chatbot}")
    for bot_rows, full_history, _status, counts in summary_updates:
        trimmed_history = full_history[-2:]
        kept_counts = counts[-1:]
        if hidden:
            bot_rows.pop()
        yield bot_rows, trimmed_history, construct_token_message(sum(kept_counts), stream=stream), kept_counts
    logging.info("减少token数量完毕")
319
-
320
-
321
def delete_last_conversation(chatbot, history, previous_token_count):
    """Remove the most recent exchange from the conversation state.

    When the last chatbot reply contains ``standard_error_msg``, only that
    chatbot row is removed. Otherwise the last two ``history`` entries, the
    last chatbot row and the last token count are removed, each only if
    present. All three containers are mutated in place.

    Returns:
        ``(chatbot, history, previous_token_count, token_message)`` on every
        path. The error-row path previously returned only two values, which
        did not match the four values of the normal path.
    """
    last_reply = chatbot[-1][1] if len(chatbot) > 0 else None
    # A reply of None cannot contain the error marker; testing it with
    # ``in`` would raise TypeError.
    if isinstance(last_reply, str) and standard_error_msg in last_reply:
        logging.info("由于包含报错信息,只删除chatbot记录")
        chatbot.pop()
    else:
        if len(history) > 0:
            logging.info("删除了一组对话历史")
            # Slice deletion removes up to two entries and cannot raise when
            # only a single entry is left.
            del history[-2:]
        if len(chatbot) > 0:
            logging.info("删除了一组chatbot对话")
            chatbot.pop()
        if len(previous_token_count) > 0:
            logging.info("删除了一组对话的token计数记录")
            previous_token_count.pop()
    return chatbot, history, previous_token_count, construct_token_message(sum(previous_token_count))
337
-
338
-
339
def save_file(filename, system, history, chatbot):
    """Write the conversation to ``HISTORY_DIR/filename``.

    A ``.json`` name stores ``system``, ``history`` and ``chatbot`` as one
    JSON object (which is also printed to stdout). A ``.md`` name stores the
    system prompt followed by one ``role``/``content`` section per history
    entry. Any other extension writes nothing.

    Returns:
        The joined target path, whether or not a file was written.
    """
    logging.info("保存对话历史中……")
    os.makedirs(HISTORY_DIR, exist_ok=True)
    target_path = os.path.join(HISTORY_DIR, filename)
    if filename.endswith(".json"):
        payload = {"system": system, "history": history, "chatbot": chatbot}
        print(payload)
        with open(target_path, "w") as out:
            json.dump(payload, out)
    elif filename.endswith(".md"):
        sections = [f"system: \n- {system} \n"]
        sections.extend(
            f"\n{entry['role']}: \n- {entry['content']} \n" for entry in history
        )
        with open(target_path, "w", encoding="utf8") as out:
            out.write("".join(sections))
    logging.info("保存对话历史完毕")
    return target_path
355
-
356
-
357
def save_chat_history(filename, system, history, chatbot):
    """Save the conversation as JSON via ``save_file``.

    ``.json`` is appended to ``filename`` when missing. An empty ``filename``
    saves nothing and returns ``None``.
    """
    if filename == "":
        return None
    json_name = filename if filename.endswith(".json") else filename + ".json"
    return save_file(json_name, system, history, chatbot)
363
-
364
-
365
def export_markdown(filename, system, history, chatbot):
    """Export the conversation as Markdown via ``save_file``.

    ``.md`` is appended to ``filename`` when missing. An empty ``filename``
    exports nothing and returns ``None``.
    """
    if filename == "":
        return None
    md_name = filename if filename.endswith(".md") else filename + ".md"
    return save_file(md_name, system, history, chatbot)
371
-
372
-
373
def load_chat_history(filename, system, history, chatbot):
    """Load a conversation saved as JSON under ``HISTORY_DIR``.

    Parameters:
        filename: A file name, or any non-string object whose ``name``
            attribute holds one.
        system, history, chatbot: Current state, returned unchanged when the
            file does not exist.
    Returns:
        ``(filename, system, history, chatbot)`` taken from the file. A
        history stored as a flat list of strings is converted to message
        dicts, alternating user and assistant starting with user.
    """
    logging.info("加载对话历史中……")
    if not isinstance(filename, str):
        filename = filename.name
    try:
        with open(os.path.join(HISTORY_DIR, filename), "r") as f:
            json_s = json.load(f)
    except FileNotFoundError:
        logging.info("没有找到对话历史文件,不执行任何操作")
        return filename, system, history, chatbot

    try:
        if isinstance(json_s["history"][0], str):
            logging.info("历史记录格式为旧版,正在转换……")
            new_history = []
            for index, item in enumerate(json_s["history"]):
                if index % 2 == 0:
                    new_history.append(construct_user(item))
                else:
                    new_history.append(construct_assistant(item))
            json_s["history"] = new_history
            logging.info(new_history)
    except (KeyError, IndexError, TypeError):
        # No usable history to convert (missing key, empty list or an
        # unexpected shape). Only lookup failures are ignored now; the former
        # bare ``except`` also swallowed KeyboardInterrupt and SystemExit.
        pass
    logging.info("加载对话历史完毕")
    return filename, json_s["system"], json_s["history"], json_s["chatbot"]
399
-
400
-
401
def sorted_by_pinyin(list):
    """Return the items sorted by the first letter of their pinyin.

    The key of an item is the first character of the first element returned
    by ``lazy_pinyin`` for it.
    """
    def first_pinyin_letter(item):
        return lazy_pinyin(item)[0][0]

    return sorted(list, key=first_pinyin_letter)
403
-
404
-
405
def get_file_names(dir, plain=False, filetypes=None):
    """List the files in ``dir`` whose names end with one of ``filetypes``.

    Parameters:
        dir: Directory to scan. A missing directory counts as empty.
        plain: When truthy, return the names as a list; otherwise return a
            ``gr.Dropdown.update`` carrying them as choices.
        filetypes: Iterable of name suffixes. Defaults to ``[".json"]``.
    Returns:
        The matching names sorted with ``sorted_by_pinyin``, or ``[""]`` when
        nothing matches, wrapped as described under ``plain``.
    """
    # A fresh default per call replaces the former shared mutable default.
    if filetypes is None:
        filetypes = [".json"]
    logging.info(f"获取文件名列表,目录为{dir},文件类型为{filetypes},是否为纯文本列表{plain}")
    try:
        # Read the directory once rather than once per suffix.
        entries = os.listdir(dir)
    except FileNotFoundError:
        entries = []
    # Grouped by suffix in ``filetypes`` order, as before. A name matching
    # several suffixes is still listed once per match.
    files = [name for suffix in filetypes for name in entries if name.endswith(suffix)]
    files = sorted_by_pinyin(files)
    if not files:
        files = [""]
    if plain:
        return files
    return gr.Dropdown.update(choices=files)
420
-
421
-
422
def get_history_names(plain=False):
    """Return the ``get_file_names`` listing for ``HISTORY_DIR``."""
    logging.info("获取历史记录文件名列表")
    history_files = get_file_names(HISTORY_DIR, plain)
    return history_files
425
-
426
-
427
def load_template(filename, mode=0):
    """Load a prompt template file from ``TEMPLATES_DIR``.

    A ``.json`` file must hold a list of objects with ``act`` and ``prompt``
    keys. Any other file is read as CSV; its first row is discarded and the
    first two columns of each remaining row are used as name and prompt.

    Parameters:
        filename: Template file name relative to ``TEMPLATES_DIR``.
        mode: ``1`` returns the sorted template names, ``2`` returns a
            ``{name: prompt}`` dict, anything else returns a tuple of that
            dict and a ``gr.Dropdown.update`` with the sorted names.
    """
    logging.info(f"加载模板文件{filename},模式为{mode}(0为返回字典和下拉菜单,1为返回下拉菜单,2为返回字典)")
    logging.info("Loading template...")
    template_path = os.path.join(TEMPLATES_DIR, filename)
    if filename.endswith(".json"):
        with open(template_path, "r", encoding="utf8") as f:
            entries = json.load(f)
        lines = [[entry["act"], entry["prompt"]] for entry in entries]
    else:
        # newline="" is what the csv module asks for, so that line breaks
        # inside quoted fields reach the parser untranslated.
        with open(template_path, "r", encoding="utf8", newline="") as csvfile:
            lines = list(csv.reader(csvfile))
        lines = lines[1:]

    if mode == 1:
        return sorted_by_pinyin([row[0] for row in lines])
    templates = {row[0]: row[1] for row in lines}
    if mode == 2:
        return templates
    choices = sorted_by_pinyin([row[0] for row in lines])
    if choices:
        dropdown = gr.Dropdown.update(choices=choices, value=choices[0])
    else:
        # An empty template has no first entry to preselect; indexing
        # choices[0] here used to raise IndexError.
        dropdown = gr.Dropdown.update(choices=choices)
    return templates, dropdown
447
-
448
-
449
def get_template_names(plain=False):
    """Return the ``get_file_names`` listing of template files.

    Only names ending in ``.csv`` or ``.json`` are listed. The JSON suffix
    previously lacked its leading dot and therefore matched any name that
    merely ended in the letters "json".
    """
    logging.info("获取模板文件名列表")
    return get_file_names(TEMPLATES_DIR, plain, filetypes=[".csv", ".json"])
452
-
453
-
454
def get_template_content(templates, selection, original_system_prompt):
    """Return the template stored under ``selection``.

    Falls back to ``original_system_prompt`` when the lookup fails: unknown
    key, out-of-range index, unhashable key, or a ``templates`` object that
    does not support indexing (for example ``None``).
    """
    logging.info(f"应用模板中,选择为{selection},原始系统提示为{original_system_prompt}")
    try:
        return templates[selection]
    except (KeyError, IndexError, TypeError):
        # Only lookup failures fall back; the former bare ``except`` also
        # swallowed KeyboardInterrupt and SystemExit.
        return original_system_prompt
460
-
461
-
462
def reset_state():
    """Return a blank state: three new empty lists and the zero-token message."""
    logging.info("重置状态")
    zero_token_message = construct_token_message(0)
    return [], [], [], zero_token_message
465
-
466
-
467
def reset_textbox():
    """Return a ``gr.update`` that sets the target's value to an empty string."""
    empty_value = ''
    return gr.update(value=empty_value)
469
-
470
-
471
def reset_file(file):
    """Call ``file.clear`` with ``reset_textbox`` and two empty lists.

    NOTE(review): ``file`` is presumably a Gradio component whose ``clear``
    registers an event handler - confirm against the caller.
    """
    no_inputs, no_outputs = [], []
    return file.clear(reset_textbox, no_inputs, no_outputs)