Tuchuanhuhuhu
feat: 训练支持预估 token 消耗
eb87ba4
raw
history blame
22.4 kB
# -*- coding:utf-8 -*-
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Type
import logging
import commentjson as json
import os
import datetime
import csv
import requests
import re
import html
import hashlib
import gradio as gr
from pypinyin import lazy_pinyin
import tiktoken
from markdown import markdown
from pygments import highlight
from pygments.lexers import get_lexer_by_name
from pygments.formatters import HtmlFormatter
import pandas as pd
from modules.presets import *
from . import shared
from modules.config import retrieve_proxy, hide_history_when_not_logged_in
if TYPE_CHECKING:
    # Only evaluated by static type checkers; nothing here runs at import time.
    from typing import TypedDict

    # Typed shape of a dataframe-like payload: a list of column headers plus
    # nested lists of cell values (presumably one inner list per row — verify
    # against the Gradio component that consumes it).
    class DataframeData(TypedDict):
        headers: List[str]  # column header names
        data: List[List[str | int | bool]]  # cell values, nested per row
def predict(current_model, *args):
    """Stream every item produced by ``current_model.predict(*args)``.

    ``yield from`` replaces the previous manual loop, whose loop variable was
    named ``iter`` and shadowed the builtin. It also forwards ``close()`` /
    ``throw()`` on this wrapper to the model's generator, so cancelling the
    stream reaches the model.
    """
    yield from current_model.predict(*args)
def billing_info(current_model):
    """Return whatever ``current_model.billing_info()`` reports."""
    query = current_model.billing_info
    return query()
def set_key(current_model, *args):
    """Forward all positional arguments to ``current_model.set_key`` and return its result."""
    setter = current_model.set_key
    outcome = setter(*args)
    return outcome
def load_chat_history(current_model, *args):
    """Delegate to ``current_model.load_chat_history(*args)`` and pass its result through."""
    loader = current_model.load_chat_history
    return loader(*args)
def delete_chat_history(current_model, *args):
    """Delegate to ``current_model.delete_chat_history(*args)`` and return its result."""
    return getattr(current_model, "delete_chat_history")(*args)
def interrupt(current_model, *args):
    """Delegate to ``current_model.interrupt(*args)`` and return its result."""
    stop = current_model.interrupt
    return stop(*args)
def reset(current_model, *args):
    """Delegate to ``current_model.reset(*args)`` and return its result."""
    outcome = current_model.reset(*args)
    return outcome
def retry(current_model, *args):
    """Stream every item produced by ``current_model.retry(*args)``.

    Uses ``yield from`` instead of a manual loop over a variable named
    ``iter``, which shadowed the builtin. It also forwards ``close()`` /
    ``throw()`` to the model's generator.
    """
    yield from current_model.retry(*args)
def delete_first_conversation(current_model, *args):
    """Delegate to ``current_model.delete_first_conversation(*args)`` and return its result."""
    handler = current_model.delete_first_conversation
    return handler(*args)
def delete_last_conversation(current_model, *args):
    """Delegate to ``current_model.delete_last_conversation(*args)`` and return its result."""
    return getattr(current_model, "delete_last_conversation")(*args)
def set_system_prompt(current_model, *args):
    """Delegate to ``current_model.set_system_prompt(*args)`` and return its result."""
    apply_prompt = current_model.set_system_prompt
    result = apply_prompt(*args)
    return result
def save_chat_history(current_model, *args):
    """Delegate to ``current_model.save_chat_history(*args)`` and return its result."""
    saver = current_model.save_chat_history
    return saver(*args)
def export_markdown(current_model, *args):
    """Delegate to ``current_model.export_markdown(*args)`` and return its result."""
    return getattr(current_model, "export_markdown")(*args)
# NOTE(review): this re-defines ``load_chat_history`` with a body identical to
# the earlier definition in this module. Being later, this is the definition
# that is bound at import time; one of the two copies can be removed.
def load_chat_history(current_model, *args):
    """Delegate to ``current_model.load_chat_history(*args)`` and return its result."""
    return current_model.load_chat_history(*args)
def upload_chat_history(current_model, *args):
    """Load an uploaded history file.

    This delegates to ``current_model.load_chat_history(*args)``; there is no
    separate upload method on the model. Returns whatever that call returns.
    """
    loader = current_model.load_chat_history
    return loader(*args)
def set_token_upper_limit(current_model, *args):
    """Delegate to ``current_model.set_token_upper_limit(*args)`` and return its result."""
    outcome = getattr(current_model, "set_token_upper_limit")(*args)
    return outcome
def set_temperature(current_model, *args):
    """Forward ``*args`` to ``current_model.set_temperature``; the result is discarded (returns None)."""
    setter = current_model.set_temperature
    setter(*args)
    return None
def set_top_p(current_model, *args):
    """Forward ``*args`` to ``current_model.set_top_p``; the result is discarded (returns None)."""
    getattr(current_model, "set_top_p")(*args)
def set_n_choices(current_model, *args):
    """Forward ``*args`` to ``current_model.set_n_choices``; the result is discarded (returns None)."""
    apply_choices = current_model.set_n_choices
    apply_choices(*args)
def set_stop_sequence(current_model, *args):
    """Forward ``*args`` to ``current_model.set_stop_sequence``; the result is discarded (returns None)."""
    getattr(current_model, "set_stop_sequence")(*args)
    return None
def set_max_tokens(current_model, *args):
    """Forward ``*args`` to ``current_model.set_max_tokens``; the result is discarded (returns None)."""
    apply_limit = current_model.set_max_tokens
    apply_limit(*args)
    return None
def set_presence_penalty(current_model, *args):
    """Forward ``*args`` to ``current_model.set_presence_penalty``; the result is discarded (returns None)."""
    getattr(current_model, "set_presence_penalty")(*args)
def set_frequency_penalty(current_model, *args):
    """Forward ``*args`` to ``current_model.set_frequency_penalty``; the result is discarded (returns None)."""
    setter = current_model.set_frequency_penalty
    setter(*args)
def set_logit_bias(current_model, *args):
    """Forward ``*args`` to ``current_model.set_logit_bias``; the result is discarded (returns None)."""
    getattr(current_model, "set_logit_bias")(*args)
    return None
def set_user_identifier(current_model, *args):
    """Forward ``*args`` to ``current_model.set_user_identifier``; the result is discarded (returns None)."""
    assign_user = current_model.set_user_identifier
    assign_user(*args)
    return None
def set_single_turn(current_model, *args):
    """Forward ``*args`` to ``current_model.set_single_turn``; the result is discarded (returns None)."""
    getattr(current_model, "set_single_turn")(*args)
def handle_file_upload(current_model, *args):
    """Delegate to ``current_model.handle_file_upload(*args)`` and return its result."""
    upload = current_model.handle_file_upload
    outcome = upload(*args)
    return outcome
def handle_summarize_index(current_model, *args):
    """Delegate to ``current_model.summarize_index(*args)`` and return its result."""
    return getattr(current_model, "summarize_index")(*args)
def like(current_model, *args):
    """Delegate to ``current_model.like(*args)`` and return its result."""
    vote = current_model.like
    return vote(*args)
def dislike(current_model, *args):
    """Delegate to ``current_model.dislike(*args)`` and return its result."""
    outcome = getattr(current_model, "dislike")(*args)
    return outcome
def count_token(input_str):
    """Return the number of ``cl100k_base`` tokens in *input_str*.

    Args:
        input_str: Plain text, or a chat message dict with ``"role"`` and
            ``"content"`` keys. A dict is flattened to
            ``"role: <role>, content: <content>"`` before encoding.

    Returns:
        The token count as an int.
    """
    encoding = tiktoken.get_encoding("cl100k_base")
    # isinstance (not ``type(...) == dict``) so dict subclasses such as
    # OrderedDict are flattened too instead of being passed to encode().
    if isinstance(input_str, dict):
        input_str = f"role: {input_str['role']}, content: {input_str['content']}"
    return len(encoding.encode(input_str))
def markdown_to_html_with_syntax_highlight(md_str): # deprecated
    """Render Markdown to HTML, highlighting fenced code blocks with Pygments first.

    Each fenced block is replaced by Pygments' HTML output wrapped in
    ``<pre><code class="LANG">``. A missing or unknown language falls back to
    the ``text`` lexer. The resulting text is then passed through ``markdown()``.
    """
    fence = r"```(\w+)?\n([\s\S]+?)\n```"

    def _render_fence(found):
        language = found.group(1) or "text"
        body = found.group(2)
        try:
            lexer = get_lexer_by_name(language, stripall=True)
        except ValueError:
            # Pygments signals an unknown lexer name with ClassNotFound, a ValueError subclass.
            lexer = get_lexer_by_name("text", stripall=True)
        rendered = highlight(body, lexer, HtmlFormatter())
        return f'<pre><code class="{language}">{rendered}</code></pre>'

    return markdown(re.sub(fence, _render_fence, md_str, flags=re.MULTILINE))
def normalize_markdown(md_text: str) -> str: # deprecated
    """Tidy blank lines around Markdown list items.

    A list item (``1.``, ``-``, ``*`` or ``+`` followed by whitespace) that
    starts a list right after a non-blank line gets a blank line inserted
    before it. While inside a list, a blank line is kept only when another line
    follows and that line is not a list item; otherwise it is dropped. Any
    non-blank, non-item line ends the list.
    """
    item_re = re.compile(r"^(\d+\.|-|\*|\+)\s")
    rows = md_text.split("\n")
    is_item = [item_re.match(row.strip()) is not None for row in rows]
    last = len(rows) - 1
    out = []
    in_list = False
    for idx, row in enumerate(rows):
        blank = row.strip() == ""
        if is_item[idx]:
            needs_gap = not in_list and idx > 0 and rows[idx - 1].strip() != ""
            if needs_gap:
                out.append("")
            in_list = True
            out.append(row)
        elif in_list and blank:
            if idx < last and not is_item[idx + 1]:
                out.append(row)
        else:
            in_list = False
            out.append(row)
    return "\n".join(out)
def convert_mdtext(md_text): # deprecated
    """Convert a chat message to HTML: rendered Markdown plus an escaped raw copy.

    Text outside fenced code is normalized and rendered with the ``tables``
    extension. Each fenced code body is rendered through
    ``markdown_to_html_with_syntax_highlight``. The output concatenates the
    rendered ``md-message`` div, a ``raw-message`` div holding the
    HTML-escaped source, and ``ALREADY_CONVERTED_MARK``.
    """
    fence_re = re.compile(r"```(.*?)(?:```|$)", re.DOTALL)
    fenced = fence_re.findall(md_text)
    # split() keeps the captured group, so even positions are the non-code text.
    prose = fence_re.split(md_text)[::2]
    pieces = []
    for text_part, code_part in zip(prose, fenced + [""]):
        if text_part.strip():
            pieces.append(markdown(normalize_markdown(text_part), extensions=["tables"]))
        if code_part.strip():
            pieces.append(markdown_to_html_with_syntax_highlight(f"\n```{code_part}\n\n```"))
    rendered = "".join(pieces)
    raw = f'<div class="raw-message hideM">{html.escape(md_text)}</div>'
    return f'<div class="md-message">{rendered}</div>' + raw + ALREADY_CONVERTED_MARK
def clip_rawtext(chat_message, need_escape=True):
    """Prepare a bot message for the raw-text view.

    Everything from the first appended ``<hr class="append-display no-in-raw" />``
    marker onward is cut off. The remaining text is escaped with
    ``escape_markdown`` when *need_escape* is true, except for agent-prefix
    paragraphs, which are re-emitted verbatim so their HTML survives.
    """
    def _maybe_escape(text):
        return escape_markdown(text) if need_escape else text

    hr = re.search(r'\n\n<hr class="append-display no-in-raw" />(.*?)', chat_message, re.DOTALL)
    clipped = chat_message if hr is None else chat_message[:hr.start()]

    prefix_re = r'<!-- S O PREFIX --><p class="agent-prefix">(.*?)<\/p><!-- E O PREFIX -->'
    # With one capture group, split() alternates [text, prefix, text, ...];
    # a single element means there was no prefix at all.
    parts = re.split(prefix_re, clipped)
    if len(parts) == 1:
        return _maybe_escape(clipped)
    pieces = []
    for idx, part in enumerate(parts):
        if idx % 2:
            pieces.append(f'<!-- S O PREFIX --><p class="agent-prefix">{part}</p><!-- E O PREFIX -->')
        else:
            pieces.append(_maybe_escape(part))
    return "".join(pieces)
def convert_bot_before_marked(chat_message):
    """
    Wrap a bot message for front-end rendering by marked.

    Returns the message unchanged if it already contains an ``md-message``
    div. Otherwise it returns a hidden raw copy (from ``clip_rawtext``)
    followed by an ``md-message`` div containing the text with fenced code
    blocks re-fenced. Note: the output must not be indented, or marked will
    parse it as a code block.
    """
    if '<div class="md-message">' in chat_message:
        return chat_message
    raw = f'<div class="raw-message hideM"><pre>{clip_rawtext(chat_message)}</pre></div>'
    fence_re = re.compile(r"```(.*?)(?:```|$)", re.DOTALL)
    fenced = fence_re.findall(chat_message)
    prose = fence_re.split(chat_message)[::2]
    pieces = []
    for text_part, code_part in zip(prose, fenced + [""]):
        if text_part.strip():
            pieces.append(text_part)
        if code_part.strip():
            pieces.append(f"\n```{code_part}\n```")
    body = "".join(pieces)
    return raw + f'<div class="md-message">\n\n{body}\n</div>'
def convert_user_before_marked(chat_message):
    """Wrap a user message in a ``user-message`` div with Markdown characters escaped.

    Messages that already contain the wrapper are returned unchanged.
    """
    wrapper = '<div class="user-message">'
    if wrapper in chat_message:
        return chat_message
    return wrapper + escape_markdown(chat_message) + '</div>'
def escape_markdown(text):
    """
    Replace Markdown/HTML-significant characters with numeric HTML entities.

    Spaces are expanded to non-breaking spaces first, then every character in
    the table below is mapped in one pass and newlines become ``<br>``.
    ``&`` is left untouched.
    """
    table = str.maketrans({
        '_': '&#95;',
        '*': '&#42;',
        '[': '&#91;',
        ']': '&#93;',
        '(': '&#40;',
        ')': '&#41;',
        '{': '&#123;',
        '}': '&#125;',
        '#': '&#35;',
        '+': '&#43;',
        '-': '&#45;',
        '.': '&#46;',
        '!': '&#33;',
        '`': '&#96;',
        '>': '&#62;',
        '<': '&#60;',
        '|': '&#124;',
        '$': '&#36;',
        ':': '&#58;',
        '\n': '<br>',
    })
    spaced = text.replace(' ', '&nbsp;&nbsp;&nbsp;&nbsp;')
    return spaced.translate(table)
def convert_asis(userinput): # deprecated
    """Show *userinput* verbatim: HTML-escape it inside a ``pre-wrap`` paragraph and append ``ALREADY_CONVERTED_MARK``."""
    escaped = html.escape(userinput)
    paragraph = f'<p style="white-space:pre-wrap;">{escaped}</p>'
    return paragraph + ALREADY_CONVERTED_MARK
def detect_converted_mark(userinput): # deprecated
    """Return True if *userinput* ends with ``ALREADY_CONVERTED_MARK``.

    Inputs without a usable ``endswith`` (e.g. ``None`` or non-string values)
    are reported as True, matching the previous fallback. Only
    ``AttributeError``/``TypeError`` are caught now, instead of a bare
    ``except`` that also hid ``KeyboardInterrupt`` and ``NameError``.
    """
    try:
        return userinput.endswith(ALREADY_CONVERTED_MARK)
    except (AttributeError, TypeError):
        return True
def detect_language(code): # deprecated
    """Split an optional language tag off the first line of a fenced code body.

    Args:
        code: The text that followed an opening code fence.

    Returns:
        ``(language, remainder)``: the lowercased first line and the rest of
        the code with leading whitespace removed. If *code* starts with a
        newline (bare fence) or is blank, returns ``("", code)`` unchanged.
    """
    if code.startswith("\n"):
        return "", code
    first_line = code.strip().split("\n", 1)[0]
    if not first_line:
        return "", code
    # first_line was taken from the stripped text, so slice the left-stripped
    # text too; slicing the raw ``code`` was misaligned by any leading spaces.
    remainder = code.lstrip()[len(first_line):].lstrip()
    return first_line.lower(), remainder
def construct_text(role, text):
    """Build a chat message dict with ``role`` and ``content`` keys (in that order)."""
    return dict(role=role, content=text)
def construct_user(text):
    """Build a ``user`` chat message dict for *text*."""
    return construct_text(role="user", text=text)
def construct_system(text):
    """Build a ``system`` chat message dict for *text*."""
    return construct_text(role="system", text=text)
def construct_assistant(text):
    """Build an ``assistant`` chat message dict for *text*."""
    return construct_text(role="assistant", text=text)
def save_file(filename, system, history, chatbot, user_name):
    """Save a conversation to disk and return the path of the saved file.

    Args:
        filename: Target name. A name ending in ``.json`` stores ``system``,
            ``history`` and ``chatbot``. A name ending in ``.md`` stores a
            Markdown transcript of ``system`` and ``history``. Any other
            suffix writes nothing.
        system: System prompt text.
        history: Iterable of message dicts with ``"role"`` and ``"content"``.
        chatbot: Chatbot display data; stored only in the JSON format.
        user_name: Selects the sub-directory ``HISTORY_DIR/<user_name>``,
            which is created if missing.

    Returns:
        The path actually written. For a ``.json`` filename containing ``/``
        or ``\\`` this is ``filename`` itself; otherwise the path lies under
        ``HISTORY_DIR/<user_name>``.
    """
    logging.debug(f"{user_name} 保存对话历史中……")
    user_dir = os.path.join(HISTORY_DIR, user_name)
    os.makedirs(user_dir, exist_ok=True)
    history_file_path = os.path.join(user_dir, filename)
    if filename.endswith(".json"):
        if "/" in filename or "\\" in filename:
            # The caller passed a path: use it verbatim.
            # NOTE(review): if filename can come from UI input, this allows
            # writing outside HISTORY_DIR; confirm callers pass trusted paths.
            history_file_path = filename
        json_s = {"system": system, "history": history, "chatbot": chatbot}
        with open(history_file_path, "w", encoding="utf-8") as f:
            json.dump(json_s, f, ensure_ascii=False)
    elif filename.endswith(".md"):
        md_s = f"system: \n- {system} \n" + "".join(
            f"\n{data['role']}: \n- {data['content']} \n" for data in history
        )
        with open(history_file_path, "w", encoding="utf-8") as f:
            f.write(md_s)
    logging.debug(f"{user_name} 保存对话历史完毕")
    # Return the path really written. Previously the .json-with-path case
    # returned a HISTORY_DIR-based path that did not match the saved file.
    return history_file_path
def sorted_by_pinyin(list):
    """Return the items of *list* sorted by the first letter of their pinyin.

    Only the first character of the first pinyin syllable is used as the key.
    Because ``sorted`` is stable, items sharing that letter keep their original
    relative order. An item for which ``lazy_pinyin`` yields no syllable
    (e.g. an empty string) gets key ``""`` and sorts first, instead of
    raising ``IndexError`` as before.

    The parameter name shadows the builtin ``list``; it is kept for
    compatibility with keyword callers.
    """
    def _initial(item):
        syllables = lazy_pinyin(item)
        if not syllables or not syllables[0]:
            return ""
        return syllables[0][0]

    return sorted(list, key=_initial)
def get_file_names(dir, plain=False, filetypes=None):
    """List files in *dir* whose names end with one of *filetypes*.

    Args:
        dir: Directory to scan. A missing directory yields no files.
        plain: If true, return a plain list of names. Otherwise return a
            ``gr.Dropdown.update`` whose choices are those names.
        filetypes: Iterable of name suffixes; defaults to ``[".json"]``.
            A single string is treated as one suffix.

    Returns:
        Matching names ordered by ``sorted_by_pinyin``, or ``[""]`` when
        nothing matched, either as a list or wrapped in a dropdown update.
    """
    # None instead of a mutable list default; a bare string used to be
    # iterated character by character.
    if filetypes is None:
        filetypes = [".json"]
    elif isinstance(filetypes, str):
        filetypes = [filetypes]
    logging.debug(f"获取文件名列表,目录为{dir},文件类型为{filetypes},是否为纯文本列表{plain}")
    try:
        entries = os.listdir(dir)
    except FileNotFoundError:
        entries = []
    # A single directory scan; str.endswith accepts a tuple, so each file is
    # listed once even if it matches several suffixes.
    suffixes = tuple(filetypes)
    files = sorted_by_pinyin([name for name in entries if name.endswith(suffixes)])
    if not files:
        files = [""]
    logging.debug(f"files are:{files}")
    if plain:
        return files
    return gr.Dropdown.update(choices=files)
def get_history_names(plain=False, user_name=""):
    """List saved history files for *user_name* (see ``get_file_names`` for the return shape).

    When no user is logged in (empty *user_name*) and
    ``hide_history_when_not_logged_in`` is set, returns ``""``.
    NOTE(review): that is a different type from the list / dropdown update
    returned otherwise; confirm callers handle it.
    """
    logging.debug(f"从用户 {user_name} 中获取历史记录文件名列表")
    hidden = user_name == "" and hide_history_when_not_logged_in
    if hidden:
        return ""
    history_dir = os.path.join(HISTORY_DIR, user_name)
    return get_file_names(history_dir, plain)
def load_template(filename, mode=0):
    """Load a prompt-template file from ``TEMPLATES_DIR``.

    ``.json`` files are read as a list of ``{"act", "prompt"}`` objects. Any
    other file is read as CSV with a header row, which is skipped.

    Args:
        filename: Template file name inside ``TEMPLATES_DIR``.
        mode: 1 returns the pinyin-sorted names only; 2 returns a
            ``{name: prompt}`` dict; anything else returns both the dict and
            a ``gr.Dropdown.update`` with the sorted names.
    """
    logging.debug(f"加载模板文件{filename},模式为{mode}(0为返回字典和下拉菜单,1为返回下拉菜单,2为返回字典)")
    path = os.path.join(TEMPLATES_DIR, filename)
    if filename.endswith(".json"):
        with open(path, "r", encoding="utf8") as f:
            rows = [[entry["act"], entry["prompt"]] for entry in json.load(f)]
    else:
        with open(path, "r", encoding="utf8") as csvfile:
            rows = list(csv.reader(csvfile))[1:]
    if mode == 1:
        return sorted_by_pinyin([row[0] for row in rows])
    if mode == 2:
        return {row[0]: row[1] for row in rows}
    choices = sorted_by_pinyin([row[0] for row in rows])
    mapping = {row[0]: row[1] for row in rows}
    return mapping, gr.Dropdown.update(choices=choices)
def get_template_names(plain=False):
    """List prompt-template files (``.csv`` and ``.json``) in ``TEMPLATES_DIR``.

    See ``get_file_names`` for the meaning of *plain* and the return value.
    """
    logging.debug("获取模板文件名列表")
    # ".json" with the dot (was "json"): names like "foojson" must not be
    # listed, since load_template() parses every non-".json" name as CSV.
    return get_file_names(TEMPLATES_DIR, plain, filetypes=[".csv", ".json"])
def get_template_content(templates, selection, original_system_prompt):
    """Return the prompt stored under *selection* in *templates*.

    Falls back to *original_system_prompt* when the lookup fails: unknown
    key, missing templates (``None``), or an unhashable selection. Only lookup
    errors are caught now, instead of a bare ``except``.
    """
    logging.debug(f"应用模板中,选择为{selection},原始系统提示为{original_system_prompt}")
    try:
        return templates[selection]
    except (LookupError, TypeError):
        return original_system_prompt
def reset_textbox():
    """Return a Gradio update that clears a textbox."""
    logging.debug("重置文本框")
    cleared = gr.update(value="")
    return cleared
def reset_default():
    """Reset the API host and clear the proxy.

    Returns:
        A tuple of (update showing the default host, update clearing the
        proxy field, status message).
    """
    host = shared.state.reset_api_host()
    retrieve_proxy("")
    status = "API-Host 和代理已重置"
    return gr.update(value=host), gr.update(value=""), status
def change_api_host(host):
    """Set the shared API host to *host*, log the change, and return the status message."""
    shared.state.set_api_host(host)
    notice = f"API-Host更改为了{host}"
    logging.info(notice)
    return notice
def change_proxy(proxy):
    """Apply *proxy* via ``retrieve_proxy`` and ``HTTPS_PROXY``, log it, and return the status message."""
    retrieve_proxy(proxy)
    os.environ["HTTPS_PROXY"] = proxy
    notice = f"代理更改为了{proxy}"
    logging.info(notice)
    return notice
def hide_middle_chars(s):
    """Mask a secret, keeping only its first and last four characters.

    ``None`` becomes ``""``; values of length 8 or less are returned unchanged.
    """
    if s is None:
        return ""
    size = len(s)
    if size <= 8:
        return s
    return s[:4] + "*" * (size - 8) + s[-4:]
def submit_key(key):
    """Strip *key*, log a masked confirmation, and return ``(key, message)``."""
    cleaned = key.strip()
    notice = f"API密钥更改为了{hide_middle_chars(cleaned)}"
    logging.info(notice)
    return cleaned, notice
def replace_today(prompt):
    """Substitute today's local date (``YYYY-MM-DD``) for every ``{current_date}`` in *prompt*."""
    placeholder = "{current_date}"
    return prompt.replace(placeholder, datetime.datetime.today().strftime("%Y-%m-%d"))
def get_geoip():
    """Look up the public IP's country via ipapi.co and return a user-facing message.

    This is best effort: any failure of the request or JSON decoding is turned
    into an error message instead of raising. A response lacking
    ``country_name`` is reported as an unknown region rather than crashing.

    Returns:
        A (mostly i18n-wrapped) status string describing the detected region
        or why detection failed.
    """
    try:
        with retrieve_proxy():
            response = requests.get("https://ipapi.co/json/", timeout=5)
        data = response.json()
    except Exception as e:
        # ``Exception`` rather than a bare except, so Ctrl-C / SystemExit still propagate.
        logging.debug(f"ipapi request failed: {e}")
        data = {"error": True, "reason": "连接ipapi失败"}
    if "error" in data:
        logging.warning(f"无法获取IP地址信息。\n{data}")
        reason = data.get("reason")
        if reason == "RateLimited":
            return i18n("您的IP区域:未知。")
        return i18n("获取IP地理位置失败。原因:") + f"{reason}" + i18n("。你仍然可以使用聊天功能。")
    country = data.get("country_name")
    if country is None:
        logging.warning(f"无法获取IP地址信息。\n{data}")
        return i18n("您的IP区域:未知。")
    if country == "China":
        text = "**您的IP区域:中国。请立即检查代理设置,在不受支持的地区使用API可能导致账号被封禁。**"
    else:
        text = i18n("您的IP区域:") + f"{country}。"
    logging.info(text)
    return text
def find_n(lst, max_num):
    """Count how many trailing items of *lst* sum to less than *max_num*.

    Leading items are dropped one at a time until the sum of the rest is
    below *max_num*, and the number of remaining items is returned.
    Returns ``len(lst)`` when the total is already below *max_num*. May
    return 0 when even the last item alone reaches *max_num*. Returns 1 if no
    suffix sum ever drops below *max_num*.
    """
    count = len(lst)
    remaining = sum(lst)
    if remaining < max_num:
        return count
    for dropped, value in enumerate(lst):
        remaining -= value
        if remaining < max_num:
            return count - dropped - 1
    return 1
def start_outputing():
    """Return button updates for the streaming state: (hide send button, show cancel button)."""
    logging.debug("显示取消按钮,隐藏发送按钮")
    hide_submit = gr.Button.update(visible=False)
    show_cancel = gr.Button.update(visible=True)
    return hide_submit, show_cancel
def end_outputing():
    """Return button updates that undo ``start_outputing``: (show first button, hide second)."""
    show_first = gr.Button.update(visible=True)
    hide_second = gr.Button.update(visible=False)
    return show_first, hide_second
def cancel_outputing():
    """Log the abort and call ``shared.state.interrupt()`` to stop the current output."""
    logging.info("中止输出……")
    state = shared.state
    state.interrupt()
def transfer_input(inputs):
    """Hand the user's input onward and switch the UI into the streaming state.

    Returns:
        ``(inputs, textbox-clearing update, send-button update,
        cancel-button update)``, all in a single return to reduce latency.
    """
    # Reuse the helpers instead of computing their results into unused locals
    # and then rebuilding identical updates inline.
    return (inputs, reset_textbox(), *start_outputing())
def update_chuanhu():
    """Run the background updater and return a Markdown update describing the outcome.

    The message carries a hidden ``#update-status`` span containing
    ``success`` or ``failure``.
    """
    from .repo import background_update

    print("[Updater] Trying to update...")
    succeeded = background_update() == "success"
    if succeeded:
        logging.info("Successfully updated, restart needed")
        outcome = "success"
        message = i18n("更新成功,请重启本程序")
    else:
        outcome = "failure"
        message = i18n("更新失败,请尝试[手动更新](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/使用教程#手动更新)")
    status = f'<span id="update-status" class="hideK">{outcome}</span>'
    return gr.Markdown.update(value=message + status)
def add_source_numbers(lst, source_name="Source", use_source=True):
    """Prefix each item with a 1-based ``[n]`` citation number.

    With *use_source*, each item is a ``(text, source)`` pair rendered as
    ``[n]\\t "text"\\n<source_name>: source``. Otherwise each item is rendered
    as ``[n]\\t "item"``.
    """
    numbered = []
    for number, item in enumerate(lst, start=1):
        if use_source:
            numbered.append(f'[{number}]\t "{item[0]}"\n{source_name}: {item[1]}')
        else:
            numbered.append(f'[{number}]\t "{item}"')
    return numbered
def add_details(lst):
    """Wrap each text in a collapsible ``<details>`` element.

    The summary is the first 25 characters with newlines removed, followed by
    ``...``.
    """
    def _collapsible(txt):
        preview = txt[:25].replace("\n", "")
        return f"<details><summary>{preview}...</summary><p>{txt}</p></details>"

    return [_collapsible(txt) for txt in lst]
def sheet_to_string(sheet, sheet_name=None):
    """Render each row of a DataFrame as ``"col1: v1, col2: v2."``.

    Args:
        sheet: A pandas DataFrame.
        sheet_name: Accepted for callers but currently unused.

    Returns:
        One string per row, in row order.
    """
    result = []
    for _, row in sheet.iterrows():
        # Join instead of appending ", " and rstrip(", "): rstrip removes a
        # *set* of characters, so a last cell ending in "," or " " was truncated.
        cells = [f"{column}: {row[column]}" for column in sheet.columns]
        result.append(", ".join(cells) + ".")
    return result
def excel_to_string(file_path):
    """Read every worksheet of an Excel file and flatten their rows into strings.

    Each sheet is rendered with ``sheet_to_string``; the row strings of all
    sheets are concatenated in sheet order.
    """
    # sheet_name=None makes read_excel return a {sheet name: DataFrame} dict.
    sheets = pd.read_excel(file_path, engine='openpyxl', sheet_name=None)
    lines = []
    for name, frame in sheets.items():
        lines.extend(sheet_to_string(frame, sheet_name=name))
    return lines
def get_last_day_of_month(any_day):
    """Return the last day of *any_day*'s month, keeping any time-of-day component."""
    # Day 28 exists in every month, and 4 days later is always in the next
    # month; snapping that to day 1 and stepping back one day gives month-end.
    first_of_next = (any_day.replace(day=28) + datetime.timedelta(days=4)).replace(day=1)
    return first_of_next - datetime.timedelta(days=1)
def get_model_source(model_name, alternative_source):
    """Return the source URL for *model_name*.

    Only ``gpt2-medium`` has a built-in URL. For any other model,
    *alternative_source* is returned; previously it was ignored and ``None``
    was returned implicitly.
    """
    if model_name == "gpt2-medium":
        return "https://huggingface.co/gpt2-medium"
    return alternative_source
def refresh_ui_elements_on_load(current_model, selected_model_name, user_name):
    """Bind *user_name* to the model, then return the like-button update followed by ``auto_load()``'s outputs."""
    current_model.set_user_identifier(user_name)
    like_btn_update = toggle_like_btn_visibility(selected_model_name)
    loaded = current_model.auto_load()
    return (like_btn_update, *loaded)
def toggle_like_btn_visibility(selected_model_name):
    """Return a Gradio update that is visible only when *selected_model_name* is ``"xmchat"``."""
    return gr.update(visible=selected_model_name == "xmchat")
def new_auto_history_filename(dirname):
    """Pick a file name for a new auto-saved history in *dirname*.

    If the newest timestamp-named file in *dirname* is empty, it is reused.
    Otherwise a new ``YYYY-MM-DD_HH-MM-SS.json`` name based on the current
    local time is returned.
    """
    latest_file = get_latest_filepath(dirname)
    # Check the size on disk instead of reading and decoding the whole file,
    # which was slow for large histories and could raise UnicodeDecodeError.
    if latest_file and os.path.getsize(os.path.join(dirname, latest_file)) == 0:
        return latest_file
    now = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    return f'{now}.json'
def get_latest_filepath(dirname):
    """Return the name of the newest timestamp-named file in *dirname*, or None.

    Only regular files whose names start with ``YYYY-MM-DD_HH-MM-SS`` are
    considered. Names that look like a timestamp but do not parse as a valid
    date/time are skipped; previously they raised ``ValueError``.
    """
    pattern = re.compile(r'\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}')
    latest_time = None
    latest_file = None
    for filename in os.listdir(dirname):
        if not os.path.isfile(os.path.join(dirname, filename)):
            continue
        # match() anchors at the start, replacing search() plus a prefix compare.
        if not pattern.match(filename):
            continue
        try:
            filetime = datetime.datetime.strptime(filename[:19], '%Y-%m-%d_%H-%M-%S')
        except ValueError:
            continue
        if latest_time is None or filetime > latest_time:
            latest_time = filetime
            latest_file = filename
    return latest_file
def get_history_filepath(username):
    """Return the path of *username*'s current history file.

    Creates ``HISTORY_DIR/<username>`` if needed. Uses the newest
    timestamp-named file there, or a freshly generated name when there is none.
    """
    dirname = os.path.join(HISTORY_DIR, username)
    os.makedirs(dirname, exist_ok=True)
    filename = get_latest_filepath(dirname) or new_auto_history_filename(dirname)
    return os.path.join(dirname, filename)
def beautify_err_msg(err_msg):
    """Replace known API error messages with a friendlier, localized hint.

    The first matching substring (checked in order) wins; unrecognized
    messages are returned unchanged.
    """
    known_errors = (
        ("insufficient_quota", "剩余配额不足,[进一步了解](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/%E5%B8%B8%E8%A7%81%E9%97%AE%E9%A2%98#you-exceeded-your-current-quota-please-check-your-plan-and-billing-details)"),
        ("The model: gpt-4 does not exist", "你没有权限访问 GPT4,[进一步了解](https://github.com/GaiZhenbiao/ChuanhuChatGPT/issues/843)"),
        ("Resource not found", "请查看 config_example.json,配置 Azure OpenAI"),
    )
    for needle, friendly in known_errors:
        if needle in err_msg:
            return i18n(friendly)
    return err_msg
def auth_from_conf(username, password):
    """Check *username*/*password* against the ``users`` list in ``config.json``.

    ``config.json`` is read from the current working directory on every call.
    Each ``users`` entry is indexed as ``[username, password, ...]``; if a
    username occurs more than once, only its first entry counts.

    Returns:
        True if the username exists and its password matches, otherwise
        False. Any error reading or interpreting the config also yields False
        and is logged.
    """
    import hmac

    try:
        with open("config.json", encoding="utf-8") as f:
            conf = json.load(f)
        for entry in conf["users"]:
            if entry[0] != username:
                continue
            stored = entry[1]
            if isinstance(stored, str) and isinstance(password, str):
                # Constant-time comparison so response timing does not leak
                # how much of the password matched.
                return hmac.compare_digest(stored.encode("utf-8"), password.encode("utf-8"))
            return bool(stored == password)
        return False
    except Exception as e:
        # Deny on any failure, but no longer swallow KeyboardInterrupt/SystemExit silently.
        logging.warning(f"Failed to verify credentials from config.json: {e}")
        return False
def get_file_hash(file_src=None, file_paths=None):
    """Return the MD5 hex digest of the concatenated contents of several files.

    Args:
        file_src: Optional iterable of objects exposing a ``.name`` path
            (e.g. uploaded-file wrappers). When given and non-empty, it takes
            precedence over *file_paths*.
        file_paths: Optional list of file paths.

    Files are hashed in order of their base names, so the digest does not
    depend on the order they were supplied in. The caller's list is no longer
    sorted in place, and passing neither argument hashes nothing instead of
    raising.
    """
    if file_src:
        file_paths = [x.name for x in file_src]
    ordered = sorted(file_paths or [], key=os.path.basename)
    md5_hash = hashlib.md5()
    for file_path in ordered:
        with open(file_path, "rb") as f:
            while chunk := f.read(8192):
                md5_hash.update(chunk)
    return md5_hash.hexdigest()