# -*- coding:utf-8 -*- from __future__ import annotations from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Type import logging import commentjson as json import os import datetime import csv import requests import re import html import hashlib import gradio as gr from pypinyin import lazy_pinyin import tiktoken from markdown import markdown from pygments import highlight from pygments.lexers import get_lexer_by_name from pygments.formatters import HtmlFormatter import pandas as pd from modules.presets import * from . import shared from modules.config import retrieve_proxy, hide_history_when_not_logged_in if TYPE_CHECKING: from typing import TypedDict class DataframeData(TypedDict): headers: List[str] data: List[List[str | int | bool]] def predict(current_model, *args): iter = current_model.predict(*args) for i in iter: yield i def billing_info(current_model): return current_model.billing_info() def set_key(current_model, *args): return current_model.set_key(*args) def load_chat_history(current_model, *args): return current_model.load_chat_history(*args) def delete_chat_history(current_model, *args): return current_model.delete_chat_history(*args) def interrupt(current_model, *args): return current_model.interrupt(*args) def reset(current_model, *args): return current_model.reset(*args) def retry(current_model, *args): iter = current_model.retry(*args) for i in iter: yield i def delete_first_conversation(current_model, *args): return current_model.delete_first_conversation(*args) def delete_last_conversation(current_model, *args): return current_model.delete_last_conversation(*args) def set_system_prompt(current_model, *args): return current_model.set_system_prompt(*args) def rename_chat_history(current_model, *args): return current_model.rename_chat_history(*args) def auto_name_chat_history(current_model, *args): return current_model.auto_name_chat_history(*args) def export_markdown(current_model, *args): return current_model.export_markdown(*args) def upload_chat_history(current_model, *args): return current_model.load_chat_history(*args) def set_token_upper_limit(current_model, *args): return current_model.set_token_upper_limit(*args) def set_temperature(current_model, *args): current_model.set_temperature(*args) def set_top_p(current_model, *args): current_model.set_top_p(*args) def set_n_choices(current_model, *args): current_model.set_n_choices(*args) def set_stop_sequence(current_model, *args): current_model.set_stop_sequence(*args) def set_max_tokens(current_model, *args): current_model.set_max_tokens(*args) def set_presence_penalty(current_model, *args): current_model.set_presence_penalty(*args) def set_frequency_penalty(current_model, *args): current_model.set_frequency_penalty(*args) def set_logit_bias(current_model, *args): current_model.set_logit_bias(*args) def set_user_identifier(current_model, *args): current_model.set_user_identifier(*args) def set_single_turn(current_model, *args): current_model.set_single_turn(*args) def handle_file_upload(current_model, *args): return current_model.handle_file_upload(*args) def handle_summarize_index(current_model, *args): return current_model.summarize_index(*args) def like(current_model, *args): return current_model.like(*args) def dislike(current_model, *args): return current_model.dislike(*args) def count_token(input_str): encoding = tiktoken.get_encoding("cl100k_base") if type(input_str) == dict: input_str = f"role: {input_str['role']}, content: {input_str['content']}" length = len(encoding.encode(input_str)) return length def markdown_to_html_with_syntax_highlight(md_str): # deprecated def replacer(match): lang = match.group(1) or "text" code = match.group(2) try: lexer = get_lexer_by_name(lang, stripall=True) except ValueError: lexer = get_lexer_by_name("text", stripall=True) formatter = HtmlFormatter() highlighted_code = highlight(code, lexer, formatter) return f'
{highlighted_code}
' code_block_pattern = r"```(\w+)?\n([\s\S]+?)\n```" md_str = re.sub(code_block_pattern, replacer, md_str, flags=re.MULTILINE) html_str = markdown(md_str) return html_str def normalize_markdown(md_text: str) -> str: # deprecated lines = md_text.split("\n") normalized_lines = [] inside_list = False for i, line in enumerate(lines): if re.match(r"^(\d+\.|-|\*|\+)\s", line.strip()): if not inside_list and i > 0 and lines[i - 1].strip() != "": normalized_lines.append("") inside_list = True normalized_lines.append(line) elif inside_list and line.strip() == "": if i < len(lines) - 1 and not re.match( r"^(\d+\.|-|\*|\+)\s", lines[i + 1].strip() ): normalized_lines.append(line) continue else: inside_list = False normalized_lines.append(line) return "\n".join(normalized_lines) def convert_mdtext(md_text): # deprecated code_block_pattern = re.compile(r"```(.*?)(?:```|$)", re.DOTALL) inline_code_pattern = re.compile(r"`(.*?)`", re.DOTALL) code_blocks = code_block_pattern.findall(md_text) non_code_parts = code_block_pattern.split(md_text)[::2] result = [] raw = f'
{html.escape(md_text)}
' for non_code, code in zip(non_code_parts, code_blocks + [""]): if non_code.strip(): non_code = normalize_markdown(non_code) result.append(markdown(non_code, extensions=["tables"])) if code.strip(): # _, code = detect_language(code) # 暂时去除代码高亮功能,因为在大段代码的情况下会出现问题 # code = code.replace("\n\n", "\n") # 暂时去除代码中的空行,因为在大段代码的情况下会出现问题 code = f"\n```{code}\n\n```" code = markdown_to_html_with_syntax_highlight(code) result.append(code) result = "".join(result) output = f'
{result}
' output += raw output += ALREADY_CONVERTED_MARK return output def clip_rawtext(chat_message, need_escape=True): # first, clip hr line hr_pattern = r'\n\n
(.*?)' hr_match = re.search(hr_pattern, chat_message, re.DOTALL) message_clipped = chat_message[:hr_match.start()] if hr_match else chat_message # second, avoid agent-prefix being escaped agent_prefix_pattern = r'

(.*?)<\/p>' agent_matches = re.findall(agent_prefix_pattern, message_clipped) final_message = "" if agent_matches: agent_parts = re.split(agent_prefix_pattern, message_clipped) for i, part in enumerate(agent_parts): if i % 2 == 0: final_message += escape_markdown(part) if need_escape else part else: final_message += f'

{part}

' else: final_message = escape_markdown(message_clipped) if need_escape else message_clipped return final_message def convert_bot_before_marked(chat_message): """ 注意不能给输出加缩进, 否则会被marked解析成代码块 """ if '
' in chat_message: return chat_message else: raw = f'
{clip_rawtext(chat_message)}
' # really_raw = f'{START_OF_OUTPUT_MARK}
{clip_rawtext(chat_message, need_escape=False)}\n
{END_OF_OUTPUT_MARK}' code_block_pattern = re.compile(r"```(.*?)(?:```|$)", re.DOTALL) code_blocks = code_block_pattern.findall(chat_message) non_code_parts = code_block_pattern.split(chat_message)[::2] result = [] for non_code, code in zip(non_code_parts, code_blocks + [""]): if non_code.strip(): result.append(non_code) if code.strip(): code = f"\n```{code}\n```" result.append(code) result = "".join(result) md = f'
\n\n{result}\n
' return raw + md def convert_user_before_marked(chat_message): if '
' in chat_message: return chat_message else: return f'
{escape_markdown(chat_message)}
' def escape_markdown(text): """ Escape Markdown special characters to HTML-safe equivalents. """ escape_chars = { # ' ': ' ', '_': '_', '*': '*', '[': '[', ']': ']', '(': '(', ')': ')', '{': '{', '}': '}', '#': '#', '+': '+', '-': '-', '.': '.', '!': '!', '`': '`', '>': '>', '<': '<', '|': '|', '$': '$', ':': ':', '\n': '
', } text = text.replace(' ', '    ') return ''.join(escape_chars.get(c, c) for c in text) def convert_asis(userinput): # deprecated return ( f'

{html.escape(userinput)}

' + ALREADY_CONVERTED_MARK ) def detect_converted_mark(userinput): # deprecated try: if userinput.endswith(ALREADY_CONVERTED_MARK): return True else: return False except: return True def detect_language(code): # deprecated if code.startswith("\n"): first_line = "" else: first_line = code.strip().split("\n", 1)[0] language = first_line.lower() if first_line else "" code_without_language = code[len(first_line) :].lstrip() if first_line else code return language, code_without_language def construct_text(role, text): return {"role": role, "content": text} def construct_user(text): return construct_text("user", text) def construct_system(text): return construct_text("system", text) def construct_assistant(text): return construct_text("assistant", text) def save_file(filename, system, history, chatbot, user_name): os.makedirs(os.path.join(HISTORY_DIR, user_name), exist_ok=True) if filename is None: filename = new_auto_history_filename(user_name) if filename.endswith(".md"): filename = filename[:-3] if not filename.endswith(".json") and not filename.endswith(".md"): filename += ".json" if filename == ".json": raise Exception("文件名不能为空") json_s = {"system": system, "history": history, "chatbot": chatbot} repeat_file_index = 2 if not filename == os.path.basename(filename): history_file_path = filename else: history_file_path = os.path.join(HISTORY_DIR, user_name, filename) with open(history_file_path, "w", encoding='utf-8') as f: json.dump(json_s, f, ensure_ascii=False) filename = os.path.basename(filename) filename_md = filename[:-5] + ".md" md_s = f"system: \n- {system} \n" for data in history: md_s += f"\n{data['role']}: \n- {data['content']} \n" with open(os.path.join(HISTORY_DIR, user_name, filename_md), "w", encoding="utf8") as f: f.write(md_s) return os.path.join(HISTORY_DIR, user_name, filename) def sorted_by_pinyin(list): return sorted(list, key=lambda char: lazy_pinyin(char)[0][0]) def sorted_by_last_modified_time(list, dir): return sorted(list, key=lambda char: os.path.getctime(os.path.join(dir, char)), reverse=True) def get_file_names_by_type(dir, filetypes=[".json"]): logging.debug(f"获取文件名列表,目录为{dir},文件类型为{filetypes}") files = [] for type in filetypes: files += [f for f in os.listdir(dir) if f.endswith(type)] logging.debug(f"files are:{files}") return files def get_file_names_by_pinyin(dir, filetypes=[".json"]): files = get_file_names_by_type(dir, filetypes) if files != [""]: files = sorted_by_pinyin(files) logging.debug(f"files are:{files}") return files def get_file_names_dropdown_by_pinyin(dir, filetypes=[".json"]): files = get_file_names_by_pinyin(dir, filetypes) return gr.Dropdown.update(choices=files) def get_file_names_by_last_modified_time(dir, filetypes=[".json"]): files = get_file_names_by_type(dir, filetypes) if files != [""]: files = sorted_by_last_modified_time(files, dir) logging.debug(f"files are:{files}") return files def get_history_names(user_name=""): logging.debug(f"从用户 {user_name} 中获取历史记录文件名列表") if user_name == "" and hide_history_when_not_logged_in: return [] else: history_files = get_file_names_by_last_modified_time(os.path.join(HISTORY_DIR, user_name)) history_files = [f[:f.rfind(".")] for f in history_files] return history_files def get_first_history_name(user_name=""): history_names = get_history_names(user_name) return history_names[0] if history_names else None def get_history_list(user_name=""): history_names = get_history_names(user_name) return gr.Radio.update(choices=history_names) def init_history_list(user_name=""): history_names = get_history_names(user_name) return gr.Radio.update(choices=history_names, value=history_names[0] if history_names else "") def filter_history(user_name, keyword): history_names = get_history_names(user_name) try: history_names = [name for name in history_names if re.search(keyword, name)] return gr.update(choices=history_names) except: return gr.update(choices=history_names) def load_template(filename, mode=0): logging.debug(f"加载模板文件{filename},模式为{mode}(0为返回字典和下拉菜单,1为返回下拉菜单,2为返回字典)") lines = [] if filename.endswith(".json"): with open(os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8") as f: lines = json.load(f) lines = [[i["act"], i["prompt"]] for i in lines] else: with open( os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8" ) as csvfile: reader = csv.reader(csvfile) lines = list(reader) lines = lines[1:] if mode == 1: return sorted_by_pinyin([row[0] for row in lines]) elif mode == 2: return {row[0]: row[1] for row in lines} else: choices = sorted_by_pinyin([row[0] for row in lines]) return {row[0]: row[1] for row in lines}, gr.Dropdown.update( choices=choices ) def get_template_names(): logging.debug("获取模板文件名列表") return get_file_names_by_pinyin(TEMPLATES_DIR, filetypes=[".csv", "json"]) def get_template_dropdown(): logging.debug("获取模板下拉菜单") template_names = get_template_names() return gr.Dropdown.update(choices=template_names) def get_template_content(templates, selection, original_system_prompt): logging.debug(f"应用模板中,选择为{selection},原始系统提示为{original_system_prompt}") try: return templates[selection] except: return original_system_prompt def reset_textbox(): logging.debug("重置文本框") return gr.update(value="") def reset_default(): default_host = shared.state.reset_api_host() retrieve_proxy("") return gr.update(value=default_host), gr.update(value=""), "API-Host 和代理已重置" def change_api_host(host): shared.state.set_api_host(host) msg = f"API-Host更改为了{host}" logging.info(msg) return msg def change_proxy(proxy): retrieve_proxy(proxy) os.environ["HTTPS_PROXY"] = proxy msg = f"代理更改为了{proxy}" logging.info(msg) return msg def hide_middle_chars(s): if s is None: return "" if len(s) <= 8: return s else: head = s[:4] tail = s[-4:] hidden = "*" * (len(s) - 8) return head + hidden + tail def submit_key(key): key = key.strip() msg = f"API密钥更改为了{hide_middle_chars(key)}" logging.info(msg) return key, msg def replace_today(prompt): today = datetime.datetime.today().strftime("%Y-%m-%d") return prompt.replace("{current_date}", today) def get_geoip(): try: with retrieve_proxy(): response = requests.get("https://ipapi.co/json/", timeout=5) data = response.json() except: data = {"error": True, "reason": "连接ipapi失败"} if "error" in data.keys(): logging.warning(f"无法获取IP地址信息。\n{data}") if data["reason"] == "RateLimited": return ( i18n("您的IP区域:未知。") ) else: return i18n("获取IP地理位置失败。原因:") + f"{data['reason']}" + i18n("。你仍然可以使用聊天功能。") else: country = data["country_name"] if country == "China": text = "**您的IP区域:中国。请立即检查代理设置,在不受支持的地区使用API可能导致账号被封禁。**" else: text = i18n("您的IP区域:") + f"{country}。" logging.info(text) return text def find_n(lst, max_num): n = len(lst) total = sum(lst) if total < max_num: return n for i in range(len(lst)): if total - lst[i] < max_num: return n - i - 1 total = total - lst[i] return 1 def start_outputing(): logging.debug("显示取消按钮,隐藏发送按钮") return gr.Button.update(visible=False), gr.Button.update(visible=True) def end_outputing(): return ( gr.Button.update(visible=True), gr.Button.update(visible=False), ) def cancel_outputing(): logging.info("中止输出……") shared.state.interrupt() def transfer_input(inputs): # 一次性返回,降低延迟 textbox = reset_textbox() outputing = start_outputing() return ( inputs, gr.update(value=""), gr.Button.update(visible=False), gr.Button.update(visible=True), ) def update_chuanhu(): from .repo import background_update print("[Updater] Trying to update...") update_status = background_update() if update_status == "success": logging.info("Successfully updated, restart needed") status = 'success' return gr.Markdown.update(value=i18n("更新成功,请重启本程序")+status) else: status = 'failure' return gr.Markdown.update(value=i18n("更新失败,请尝试[手动更新](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/使用教程#手动更新)")+status) def add_source_numbers(lst, source_name = "Source", use_source = True): if use_source: return [f'[{idx+1}]\t "{item[0]}"\n{source_name}: {item[1]}' for idx, item in enumerate(lst)] else: return [f'[{idx+1}]\t "{item}"' for idx, item in enumerate(lst)] def add_details(lst): nodes = [] for index, txt in enumerate(lst): brief = txt[:25].replace("\n", "") nodes.append( f"
{brief}...

{txt}

" ) return nodes def sheet_to_string(sheet, sheet_name = None): result = [] for index, row in sheet.iterrows(): row_string = "" for column in sheet.columns: row_string += f"{column}: {row[column]}, " row_string = row_string.rstrip(", ") row_string += "." result.append(row_string) return result def excel_to_string(file_path): # 读取Excel文件中的所有工作表 excel_file = pd.read_excel(file_path, engine='openpyxl', sheet_name=None) # 初始化结果字符串 result = [] # 遍历每一个工作表 for sheet_name, sheet_data in excel_file.items(): # 处理当前工作表并添加到结果字符串 result += sheet_to_string(sheet_data, sheet_name=sheet_name) return result def get_last_day_of_month(any_day): # The day 28 exists in every month. 4 days later, it's always next month next_month = any_day.replace(day=28) + datetime.timedelta(days=4) # subtracting the number of the current day brings us back one month return next_month - datetime.timedelta(days=next_month.day) def get_model_source(model_name, alternative_source): if model_name == "gpt2-medium": return "https://huggingface.co/gpt2-medium" def refresh_ui_elements_on_load(current_model, selected_model_name, user_name): current_model.set_user_identifier(user_name) return toggle_like_btn_visibility(selected_model_name), *current_model.auto_load() def toggle_like_btn_visibility(selected_model_name): if selected_model_name == "xmchat": return gr.update(visible=True) else: return gr.update(visible=False) def new_auto_history_filename(username): latest_file = get_first_history_name(username) if latest_file: with open(os.path.join(HISTORY_DIR, username, latest_file + ".json"), 'r', encoding="utf-8") as f: if len(f.read()) == 0: return latest_file now = i18n("新对话 ") + datetime.datetime.now().strftime('%m-%d %H-%M') return f'{now}.json' def get_history_filepath(username): dirname = os.path.join(HISTORY_DIR, username) os.makedirs(dirname, exist_ok=True) latest_file = get_first_history_name(username) if not latest_file: latest_file = new_auto_history_filename(username) latest_file = os.path.join(dirname, latest_file) return latest_file def beautify_err_msg(err_msg): if "insufficient_quota" in err_msg: return i18n("剩余配额不足,[进一步了解](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/%E5%B8%B8%E8%A7%81%E9%97%AE%E9%A2%98#you-exceeded-your-current-quota-please-check-your-plan-and-billing-details)") if "The model `gpt-4` does not exist" in err_msg: return i18n("你没有权限访问 GPT4,[进一步了解](https://github.com/GaiZhenbiao/ChuanhuChatGPT/issues/843)") if "Resource not found" in err_msg: return i18n("请查看 config_example.json,配置 Azure OpenAI") return err_msg def auth_from_conf(username, password): try: with open("config.json", encoding="utf-8") as f: conf = json.load(f) usernames, passwords = [i[0] for i in conf["users"]], [i[1] for i in conf["users"]] if username in usernames: if passwords[usernames.index(username)] == password: return True return False except: return False def get_file_hash(file_src=None, file_paths=None): if file_src: file_paths = [x.name for x in file_src] file_paths.sort(key=lambda x: os.path.basename(x)) md5_hash = hashlib.md5() for file_path in file_paths: with open(file_path, "rb") as f: while chunk := f.read(8192): md5_hash.update(chunk) return md5_hash.hexdigest() def myprint(**args): print(args) def replace_special_symbols(string, replace_string=" "): # 定义正则表达式,匹配所有特殊符号 pattern = r'[!@#$%^&*()<>?/\|}{~:]' new_string = re.sub(pattern, replace_string, string) return new_string