Spaces:
Build error
Build error
import gradio as gr | |
import numpy as np | |
import json | |
from tts_api import TTSapi, DEFAULT_TTS_MODEL_NAME | |
from config import * | |
from utils import * | |
from knowledge_base import LocalRAG, CosPlayer | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
_THINKING_SEPARATOR = '\n==================\n'


def _drop_last_thinking_turn(thinking_text):
    """Return *thinking_text* with its most recent turn removed.

    The thinking log is a sequence of turns, each terminated by
    ``_THINKING_SEPARATOR``. The trailing separator is kept on the result so
    the next appended turn starts as a separate entry. Non-string input
    (e.g. a list left over from an older reset) yields ``''``.
    """
    if not isinstance(thinking_text, str):
        return ''
    turns = thinking_text.split(_THINKING_SEPARATOR)
    # A well-formed log ends with the separator: the last split item is ''
    # and the one before it is the turn being retried.
    if len(turns) > 2:
        return _THINKING_SEPARATOR.join(turns[:-2]) + _THINKING_SEPARATOR
    return ''


def _is_system_message(item):
    """True when *item* is a chat message dict whose role is 'system'."""
    return isinstance(item, dict) and item.get('role') == 'system'


def handle_retry(history, thinking_history, config, section_state, retry_data: gr.RetryData):
    """Regenerate the reply for the chatbot message at ``retry_data.index``.

    ``history[retry_data.index]`` is resent as the user message. Everything
    from that position on is dropped from the chatbot history, from the
    model-side history in ``section_state`` and from both thinking logs.
    Then ``predict`` is called again.

    Returns:
        Whatever ``predict`` returns (``msg``, chatbot, thinking text, audio).
    """
    # The user message to send again.
    previous_message = history[retry_data.index]['content']
    # Discard the old reply (and anything after it) from the visible history.
    new_history = history[:retry_data.index]

    model_history = section_state['chat_history']
    # predict() may prepend a system prompt to the model-side history that is
    # not shown in the chatbot. Skip over it so both histories are cut at the
    # same turn; without one, cutting at index+1 would keep the retried user
    # message and predict() would append it a second time.
    offset = 0
    if model_history and _is_system_message(model_history[0]) and not (
        history and _is_system_message(history[0])
    ):
        offset = 1
    section_state['chat_history'] = model_history[:retry_data.index + offset]

    new_thinking_history = _drop_last_thinking_turn(thinking_history)
    section_state['thinking_history'] = _drop_last_thinking_turn(
        section_state.get('thinking_history', '')
    )

    # Generate the reply again.
    return predict(previous_message, new_history, new_thinking_history, config, section_state)
def _retrieve_context(message, config):
    """Collect retrieval context for *message* from the local KB and/or the web.

    Reads the module-level ``local_rag``.

    Returns:
        ``(context, docs, scores, net_search_res)``. ``context`` is the text to
        splice into the prompt, or ``''`` when nothing was retrieved.
    """
    context = ''
    docs, scores, net_search_res = [], [], []
    if config['kb_on']:
        if not config['current_knowledge_base'] or local_rag is None:
            # The KB selector can be None/empty and the index is loaded lazily.
            # Skip RAG for this turn instead of crashing on len(None) or
            # None.vector_db.
            gr.Warning("未选择知识库或知识库尚未加载,本次回复未使用本地RAG...")
        else:
            # NOTE(review): each hit is unpacked as a (document, score) pair;
            # confirm vector_db.similarity_search really returns pairs.
            doc_and_scores = local_rag.vector_db.similarity_search(message, k=local_rag.rag_top_k)
            if len(doc_and_scores) > 0:
                docs, scores = zip(*doc_and_scores)
                docs, scores = list(docs), list(scores)
                context += "【本地知识库】" + "\n".join(
                    concate_metadata(d.metadata) + d.page_content for d in docs
                )
    if config['net_on']:
        ret = web_search(message, max_results=MAX_RESULTS)
        net_search_res = parse_net_search(ret)
        context += "\n【网络搜索结果】" + ''.join(net_search_res)
    return context, docs, scores, net_search_res


def _apply_character_setting(context, config, section_state):
    """Apply the role-play description and return the (possibly extended) context.

    In 'by system' mode the description is inserted as a system message at the
    head of ``section_state['chat_history']`` (once). In 'by prompt' mode any
    such system message is removed and the description is prepended to
    *context*.

    Raises:
        ValueError: for an unknown ``character_setting_mode``.
    """
    description = config['character_description']
    if not description:
        return context
    history = section_state['chat_history']
    mode = config['character_setting_mode']
    if mode == 'by system':
        if len(history) == 0 or history[0]['role'] != 'system':
            history.insert(0, {"role": "system", "content": description})
        return context
    if mode == 'by prompt':
        if len(history) > 0 and history[0]['role'] == 'system':
            history.pop(0)
        # Prepend, do not replace: retrieved KB/web context must survive.
        return f'【系统核心设定】:{description}\n' + context
    raise ValueError(f"未知的角色设定模式:{config['character_setting_mode']}")


def _ensure_llm_loaded(model_name):
    """Make sure the global LLM/tokenizer pair is *model_name*, loading it if needed.

    Rebinds the module globals ``core_llm``, ``core_tokenizer`` and
    ``LLM_LOADED``. An already-loaded model is kept when its config reports the
    same ``name_or_path``. Otherwise it is replaced, so changing the model
    dropdown takes effect, as the TTS model selection already does.
    """
    global LLM_LOADED, core_llm, core_tokenizer
    if LLM_LOADED and core_llm is not None:
        # Fall back to model_name when the attribute is missing, so an unknown
        # model class is never reloaded on every turn.
        loaded_name = getattr(getattr(core_llm, 'config', None), 'name_or_path', model_name)
        if loaded_name == model_name:
            return
        print(f'当前LLM{loaded_name}非所选,重新加载')
        # Drop the references to the old weights before loading the new ones.
        core_llm, core_tokenizer = None, None
        LLM_LOADED = False
    core_llm = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto"
    )
    core_tokenizer = AutoTokenizer.from_pretrained(model_name)
    LLM_LOADED = True


def _generate_reply(input_message):
    """Run one non-thinking chat completion over *input_message* and return the text."""
    token_cnt = count_tokens_local(input_message, core_tokenizer)
    if token_cnt >= MAX_MODEL_CTX:
        gr.Warning("当前对话已经超出模型上下文长度,请开启新会话...")
    text = core_tokenizer.apply_chat_template(
        input_message,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False
    )
    model_inputs = core_tokenizer([text], return_tensors="pt").to(core_llm.device)
    generated_ids = core_llm.generate(
        **model_inputs,
        max_new_tokens=32768
    )
    # Keep only the newly generated tokens (drop the echoed prompt).
    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
    return core_tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")


def _format_rag_response(response_content, docs, scores, net_search_res):
    """Wrap *response_content* in a collapsible HTML panel listing retrieved sources.

    String items are passed to ``wash_up_content`` as-is. Any other item is
    passed as ``(item.page_content, scores[idx])``.
    """
    sources = docs + net_search_res
    snippets = "<br>".join(
        "<br>".join(wash_up_content(item if isinstance(item, str) else (item.page_content, scores[idx])))
        for idx, item in enumerate(sources)
    )
    return (
        "\n"
        '<details class="rag-details">\n'
        "<summary style='cursor: pointer; color: #666;'>\n"
        f"🔍 检索完成✅(共{len(sources)}条)\n"
        "</summary>\n"
        "<div style='margin:10px 0;padding:10px;background:#f5f5f5;border-radius:8px;'>\n"
        f"{snippets}\n"
        "</div>\n"
        "</details>\n"
        f'<div style="margin-top: 10px;">{response_content}</div> <!-- 增加顶部间距容器 -->\n'
    )


def _synthesize_speech(text, config):
    """Synthesise *text* with the selected TTS model, loading or switching it first.

    Rebinds the module globals ``synthesiser`` and ``TTS_LOADED`` on first use.

    Raises:
        NotImplementedError: a reference audio is set but has no transcript.
    """
    global synthesiser, TTS_LOADED
    if not TTS_LOADED:
        print('TTS模型首次加载...')
        gr.Info("初次加载TTS模型,请稍候..", duration=63)
        synthesiser = TTSapi(model_name=config['tts_model'])
        TTS_LOADED = True
        print('加载完毕...')
    # Reload when the selected model differs from the loaded one.
    if config['tts_model'] != synthesiser.model_name:
        print(f'当前TTS模型{synthesiser.model_name}非所选,重新加载')
        synthesiser.reload(model_name=config['tts_model'])
    # With a reference audio, its transcript is prefixed to the text to speak.
    if config['ref_audio']:
        prompt_text = config['ref_audio_transcribe']
        if prompt_text is None:
            raise NotImplementedError('暂时必须提供文本')  # TODO: consider adding an ASR model
        text = prompt_text + text
    return synthesiser.forward(text, speech_prompt=config['ref_audio'])


def predict(message, chat_history, thinking_history, config, section_state):
    """Answer *message* with the selected LLM and optionally speak the reply.

    Args:
        message: The user's input text.
        chat_history: Chatbot-side history (list of ``{'role', 'content'}``
            dicts). Appended to in place; RAG replies are shown with an HTML
            sources panel.
        thinking_history: Text of the "思考过程" panel.
        config: Run-time settings dict built by the UI.
        section_state: Per-session state. Its ``chat_history`` (model-side
            history, plain reply text) and ``thinking_history`` are updated.

    Returns:
        ``("", chat_history, thinking_history, (sample_rate, audio))`` for the
        UI outputs ``[msg, chatbot, thinking_display, audio_player]``.
    """
    print(f"当前模式:{config['mode_selected']}")
    print(f'角色扮演描述:{config["character_description"]}')
    print(f"写入角色设定方式:{config['character_setting_mode']}")
    print(f"选中LLM:{config['llm_model']}")
    print(f"是否使用RAG本地知识库:{config['kb_on']}")
    print(f"选中知识库:{config['current_knowledge_base']}")
    print(f"是否联网搜索:{config['net_on']}")
    print(f"选中TTS模型:{config['tts_model']}")
    print(f"是否合成语音:{config['tts_on']}")
    print(f"参考音频路径:{config['ref_audio']}")
    print(f"参考音频文本:{config['ref_audio_transcribe']}")

    context, docs, scores, net_search_res = _retrieve_context(message, config)
    context = _apply_character_setting(context, config, section_state)

    if context:
        prompt = (
            "请充分理解以下上下文信息,并结合当前及历史对话产生回复':\n\n"
            f"上下文:{context}\n"
            f"用户当前输入:{message}\n"
            "回复:\n"
        )
        input_message = section_state["chat_history"] + [{"role": "user", "content": prompt}]
    else:
        input_message = section_state["chat_history"] + [{"role": "user", "content": message}]
    # Turn off Qwen3's default thinking mode (the chat template also receives
    # enable_thinking=False). The last entry is a fresh dict, so the stored
    # history is not modified.
    if config['llm_model'].startswith('Qwen3'):
        input_message[-1]['content'] += '/no_think'
    # The model-side history stores the raw user message, not the RAG prompt.
    section_state["chat_history"].append({"role": "user", "content": message})

    # Silent placeholder. Also returned when anything below fails, so the
    # return statement never hits an unbound name.
    audio_output = np.array([0], dtype=np.int16)
    user_shown = False
    try:
        _ensure_llm_loaded(config['llm_model'])
        response_content = _generate_reply(input_message)
        thinking = None  # thinking output is disabled (enable_thinking=False)
        print('回复:', response_content)

        chat_history.append({'role': 'user', 'content': message})
        user_shown = True
        if context:
            shown_reply = _format_rag_response(response_content, docs, scores, net_search_res)
        else:
            shown_reply = response_content
        chat_history.append({'role': 'assistant', 'content': shown_reply})

        thinking_entry = f"User: {message}\nThinking: {thinking}" + '\n==================\n'
        thinking_history += thinking_entry
        section_state["chat_history"].append({"role": "assistant", "content": response_content})
        prior = section_state.get("thinking_history", "")
        if not isinstance(prior, str):
            # Older resets stored a list here; `list += str` would splice in
            # single characters, so turn it back into text first.
            prior = "".join(map(str, prior or []))
        section_state["thinking_history"] = prior + thinking_entry

        if (not config['tts_on']) or len(response_content) == 0:
            if len(response_content) == 0:
                print("LLM 回复为空,无法合成语音")
        else:
            audio_output = _synthesize_speech(response_content, config)
    except Exception as e:
        print('!!!!!!!!')
        print(e)
        print('!!!!!!!!')
        error_msg = f"Error: {str(e)}"
        # Chatbot(type='messages') only accepts role/content dicts.
        if not user_shown:
            chat_history.append({'role': 'user', 'content': message})
        chat_history.append({'role': 'assistant', 'content': error_msg})
        thinking_history += f"Error occurred: {str(e)}" + '\n'
    return "", chat_history, thinking_history, (synthesiser.sr if synthesiser else 16000, audio_output)
def init_model(init_llm=True, init_rag=False, init_tts=False):
    """Optionally pre-load the LLM, the local RAG index and the TTS engine.

    Each flag controls one component. A component that is not requested is
    returned as ``None`` (and its loaded flag as ``False``).

    Returns:
        ``(local_rag, synthesiser, core_llm, core_tokenizer, TTS_LOADED, LLM_LOADED)``.
    """
    llm = tokenizer = rag = tts_engine = None
    llm_ready = tts_ready = False

    if init_llm:
        print(f'正在加载LLM:{DEFAULT_MODEL_NAME}...')
        llm = AutoModelForCausalLM.from_pretrained(
            DEFAULT_MODEL_NAME, torch_dtype="auto", device_map="auto",
        )
        print('device:', llm.device)
        tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL_NAME)
        llm_ready = True

    if init_rag:
        gr.Info("正在加载知识库,请稍候...")
        rag = LocalRAG(rag_top_k=RAG_TOP_K)

    if init_tts:
        print(f'正在加载TTS模型:{DEFAULT_TTS_MODEL_NAME}...')
        tts_engine = TTSapi()
        tts_ready = True

    return rag, tts_engine, llm, tokenizer, tts_ready, llm_ready
if __name__ == "__main__":
    import time

    # ---- Eager loading: init_model() defaults load only the LLM. ----
    st = time.time()
    print('********************模型加载中************************')
    # These become module globals that predict(), update_knowledge_base() and
    # init_kb() read and rebind through `global` statements.
    local_rag, synthesiser, core_llm, core_tokenizer, TTS_LOADED, LLM_LOADED = init_model()
    print('********************模型加载完成************************')
    print('耗时:', time.time() - st)
    state = {}
    resp, state = log_in(0, state)
    cosplayer = CosPlayer(description_file=DEFAULT_COSPLAY_SETTING)
    print("===== 初始化开始 =====")
    with gr.Blocks(css=CSS, title="LLM Chat Demo", theme=gr.themes.Soft(font=["sans-serif", "Arial"])) as demo:
        gr.Markdown("""
        # LLM Chat Demo
        ## 用法介绍
        ### 用户登录
        * 输入用户名,点击Log In按钮。首次登录会自动创建用户目录,聊天记录会保存在下面,如不登录,默认为公共目录'0'
        ### 模型选择
        目前支持Qwen、Deepseek-R1蒸馏系列等部分模型,可下拉菜单选择
        ### 高级设置
        * 模式选择:可以选择角色扮演模式/普通模式
        * 角色设定选择:支持加载不同角色设定文件
        * 角色配置方式:
          * by system: 角色设定将作为system prompt存在于输入首部
          * by prompt: 角色设定每次被添加到当前上下文中
        * 知识库配置: 支持自由选择、组合知识库
        """)
        section_state = gr.State(value=state)  # per-session state object
        with gr.Row():
            uid_input = gr.Textbox(label="Type Your UID:")
            response = gr.Textbox(label='', value=resp)
            login_button = gr.Button("Log In")
        llm_select = gr.Dropdown(label="模型选择", choices=AVALIABLE_MODELS, value=DEFAULT_MODEL_NAME, visible=True)
        gr.Markdown("## 高级设置")
        with gr.Accordion("点击展开折叠", open=False, visible=True):
            mode_select = gr.Radio(label='模式选择', choices=SUPPORT_MODES, value=DEFAULT_MODE)
            coser_select = gr.Dropdown(label="角色设定选择", choices=cosplayer.get_all_characters(),
                                       value=DEFAULT_COSPLAY_SETTING, visible=True)
            coser_setting = gr.Radio(label='角色配置方式', choices=CHARACTER_SETTING_MODES,
                                     value=DEFAULT_C_SETTING_MODE, visible=True)
            kb_select = gr.Dropdown(label="知识库配置", choices=AVALIABLE_KNOWLEDGE_BASE, value=None,
                                    visible=True, multiselect=True)
        with gr.Row():
            # Left side of the page: the conversation.
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(label="对话记录", height=500, show_copy_button=True, type='messages')
                with gr.Row():
                    msg = gr.Textbox(label="输入消息", placeholder="请输入您的问题...", scale=7)
                    with gr.Column(scale=1, min_width=15):
                        with gr.Row():
                            rag_switch = gr.Checkbox(label="本地RAG", value=False, info="")
                            net_switch = gr.Checkbox(label="联网搜索", value=False, info="")
                        submit_btn = gr.Button("发送", variant="primary", min_width=15)
                with gr.Row():
                    gr.Examples(
                        examples=[[example] for example in EXAMPLES],
                        inputs=msg,
                        outputs=chatbot,
                        fn=predict,
                        visible=True,
                        cache_examples=False
                    )
                with gr.Row():
                    save_btn = gr.Button("保存对话")
                    clear_btn = gr.Button("清空对话")
                    chat_history_select = gr.Dropdown(label='加载历史对话', choices=state['available_history'],
                                                      visible=True, interactive=True)
            # Right side of the page: thinking trace and audio.
            with gr.Column(scale=2):
                thinking_display = gr.TextArea(label="思考过程", interactive=False,
                                               placeholder="模型思考过程将在此显示...")
                tts_switch = gr.Checkbox(label="TTS开关", value=False, info="Check me to hear voice")
                with gr.Tabs():
                    # Tab 1: audio playback.
                    with gr.Tab("音频输出", id="audio_output"):
                        audio_player = gr.Audio(
                            label="听听我声音~",
                            type="numpy",
                            interactive=False
                        )
                    # Tab 2: TTS configuration.
                    with gr.Tab("TTS配置", id="tts_config"):
                        tts_model = gr.Dropdown(
                            label="选择TTS模型",
                            choices=["ChatTTS", "HKUSTAudio/LLasa-1B", "HKUSTAudio/LLasa-3B", "HKUSTAudio/LLasa-8B"],
                            value=DEFAULT_TTS_MODEL_NAME,
                            interactive=True
                        )
                        # Reference audio upload.
                        ref_audio = gr.Audio(
                            label="上传参考音频",
                            type="filepath",
                            interactive=True
                        )
                        ref_audio_transcribe = gr.Textbox(label="参考音频对应文本", visible=True)

        # ================= State management =================
        current_config = gr.State({
            "llm_model": DEFAULT_MODEL_NAME,
            "tts_model": DEFAULT_TTS_MODEL_NAME,
            "tts_on": False,
            "kb_on": False,
            "net_on": False,
            "ref_audio": None,
            "ref_audio_transcribe": None,
            "mode_selected": DEFAULT_MODE,
            "character_description": cosplayer.get_core_setting(),
            "character_setting_mode": DEFAULT_C_SETTING_MODE,
            "current_knowledge_base": AVALIABLE_KNOWLEDGE_BASE[0]
        })

        # ================= Event handling =================
        login_button.click(log_in, inputs=[uid_input, section_state], outputs=[response, section_state])

        def collect_config(llm_model, tts_model_name, audio, text, tts_on, kb_on, net_on, mode,
                           character_setting, knowledge_base):
            """Rebuild the whole run-time config dict from the current widget values."""
            return {
                "llm_model": llm_model,
                "tts_model": tts_model_name,
                "ref_audio": audio,
                "ref_audio_transcribe": text,
                "tts_on": tts_on,
                "kb_on": kb_on,
                'net_on': net_on,
                "mode_selected": mode,
                "character_description": None if mode == '普通模式' else cosplayer.get_core_setting(),
                "character_setting_mode": character_setting,
                # The KB multiselect starts as None; predict() calls len() on this.
                "current_knowledge_base": knowledge_base or [],
            }

        gr.on(triggers=[llm_select.change, tts_model.change, ref_audio.change,
                        ref_audio_transcribe.change, tts_switch.select, rag_switch.select, net_switch.select,
                        mode_select.change],
              fn=collect_config,
              inputs=[llm_select, tts_model, ref_audio, ref_audio_transcribe, tts_switch, rag_switch,
                      net_switch, mode_select, coser_setting, kb_select],
              outputs=current_config
              )
        msg.submit(
            predict,
            [msg, chatbot, thinking_display, current_config, section_state],
            [msg, chatbot, thinking_display, audio_player],
            queue=False
        )
        chatbot.retry(fn=handle_retry,
                      inputs=[chatbot, thinking_display, current_config, section_state],
                      outputs=[msg, chatbot, thinking_display, audio_player])
        submit_btn.click(
            predict,
            [msg, chatbot, thinking_display, current_config, section_state],
            [msg, chatbot, thinking_display, audio_player],
            queue=False
        )

        def save_chat(state):
            """Write the chat and thinking histories to timestamped files in the user dir.

            Registers the new chat file name in ``state['available_history']``
            and returns the updated state.
            """
            from datetime import datetime
            now = datetime.now().strftime('%Y%m%d_%H%M%S')
            user_dir = state['user_dir']
            with open(user_dir / f'chat_history_{now}.json', 'w', encoding='utf-8') as file:
                json.dump(state["chat_history"], file, ensure_ascii=False, indent=4)
            thinking = state["thinking_history"]
            # Explicit UTF-8: the log holds Chinese text, and the platform
            # default encoding may not be able to represent it.
            with open(user_dir / f'thinking_history_{now}.txt', 'w', encoding='utf-8') as file:
                if isinstance(thinking, list):
                    file.writelines(item + '\n' for item in thinking)
                else:
                    file.write(thinking)
            gr.Info("聊天记录已保存!")
            state['available_history'].append(f'chat_history_{now}')
            return state

        def clear_chat(state):
            """Reset the conversation, re-seeding it with the character's prologue if any.

            Returns the new chatbot value, the new thinking text and the state.
            """
            state["chat_history"] = []
            # Must stay a str: predict() appends with `+=` and handle_retry() splits it.
            state["thinking_history"] = ''
            prologue = cosplayer.get_prologue()
            chatbot_value = []
            if prologue:
                state['chat_history'].append({'role': 'assistant', 'content': prologue})
                chatbot_value = [{'role': 'assistant', 'content': prologue}]
            return chatbot_value, '', state

        def load_chat(state, chat_file):
            """Load a saved conversation and its thinking log into the session.

            NOTE: intended to be used before a conversation starts; if one is
            already in progress, it is overwritten.
            """
            if not chat_file:
                return [], '', state
            think_file = chat_file.replace("chat_", "thinking_")
            chat_file_path = state['user_dir'] / (chat_file + '.json')
            think_file_path = state['user_dir'] / (think_file + '.txt')
            if not chat_file_path.exists():
                gr.Warning(f'聊天记录文件:{chat_file}.json不存在, 加载失败')
                return [], '', state
            with open(chat_file_path, 'r', encoding='utf-8') as f:
                content = json.load(f)
            state['chat_history'] = content
            think = ''
            if think_file_path.exists():
                with open(think_file_path, 'r', encoding='utf-8') as f:
                    think = f.read()
            # Always overwrite so no stale log survives when none was saved.
            state['thinking_history'] = think
            # With Chatbot(type='messages') the saved message list is shown as-is.
            return content, think, state

        def update_history(state):
            """Refresh the history dropdown choices from the session state."""
            return gr.update(choices=state['available_history'])

        def update_visible(mode):
            """Show the character controls in role-play mode and hide them in '普通模式'."""
            if mode != '普通模式':
                gr.Warning("当前为角色扮演模式,请确认已配置好该角色的知识库...")
                return gr.update(visible=True), gr.update(visible=True)
            return gr.update(visible=False), gr.update(visible=False)

        def update_cosplay(cos_select, config, chatbot_value, thinking_value, state):
            """Switch to another character: save a non-trivial chat, then reset it."""
            cosplayer.update(cos_select)
            config['character_description'] = cosplayer.get_core_setting()
            # A new character invalidates the conversation: keep a copy if it
            # holds more than the opening message, then start over.
            if len(state['chat_history']) > 1:
                state = save_chat(state)
            gr.Warning("我的角色已更换,对话已重置。请检查知识库是否需要更新...")
            chatbot_value, thinking_value, state = clear_chat(state)
            return gr.update(value=cos_select), config, chatbot_value, thinking_value, state

        def update_character_setting_mode(coser_setting_value, config):
            """Store the selected character-setting mode ('by system' / 'by prompt')."""
            config['character_setting_mode'] = coser_setting_value
            return gr.update(value=coser_setting_value), config

        def update_knowledge_base(knowledge_base, config):
            """Record the selected knowledge bases and (re)build the global RAG index."""
            global local_rag
            config['current_knowledge_base'] = knowledge_base
            if not knowledge_base:
                gr.Warning("当前未选中任何知识库,本地RAG将失效。请确认...")
            elif local_rag is None:
                gr.Info("初次加载知识库,请稍候...")
                local_rag = LocalRAG(rag_top_k=RAG_TOP_K, doc_dir=knowledge_base)
                gr.Info("知识库加载完成!")
            else:
                gr.Info("重新加载知识库,请稍候...")
                local_rag.reload_knowledge_base(knowledge_base)
                gr.Info("知识库加载完成!")
            return gr.update(value=knowledge_base), config

        def init_kb(rag_on, kb_select_value, config):
            """Lazily build the global RAG index the first time local RAG is enabled."""
            global local_rag
            if rag_on:
                # Same notion of role-play as update_visible()/collect_config():
                # any mode other than '普通模式'.
                if config['mode_selected'] != '普通模式':
                    gr.Warning("当前为角色扮演模式,请确认已配置好该角色的知识库...")
                if local_rag is None:
                    gr.Info("初次加载知识库,请稍候...")
                    local_rag = LocalRAG(rag_top_k=RAG_TOP_K, doc_dir=kb_select_value)
                    gr.Info("知识库加载完成!")
            return gr.update(value=rag_on)

        # A non-'普通模式' (role-play) mode shows the character pickers.
        mode_select.change(update_visible,
                           inputs=mode_select,
                           outputs=[coser_select, coser_setting])
        coser_select.change(update_cosplay,
                            inputs=[coser_select, current_config, chatbot, thinking_display, section_state],
                            outputs=[coser_select, current_config, chatbot, thinking_display, section_state])
        # TODO: show examples dynamically according to the selected character.
        coser_setting.change(update_character_setting_mode,
                             inputs=[coser_setting, current_config],
                             outputs=[coser_setting, current_config])
        kb_select.change(update_knowledge_base,
                         inputs=[kb_select, current_config],
                         outputs=[kb_select, current_config])
        # Enabling local RAG in role-play mode reminds the user to configure the KB.
        rag_switch.select(init_kb, inputs=[rag_switch, kb_select, current_config], outputs=rag_switch)
        clear_btn.click(
            clear_chat,
            inputs=section_state,
            outputs=[chatbot, thinking_display, section_state],
            queue=False
        )
        save_btn.click(
            save_chat,
            inputs=section_state,
            outputs=section_state,
            queue=False
        ).then(
            fn=update_history,
            inputs=section_state,
            outputs=chat_history_select
        )
        chat_history_select.change(load_chat,
                                   inputs=[section_state, chat_history_select],
                                   outputs=[chatbot, thinking_display, section_state])
        section_state.change(update_history,
                             inputs=section_state,
                             outputs=chat_history_select)
    print("===== 初始化完成 =====")
    demo.launch(share=False)