|
import logging |
|
|
|
import gradio as gr |
|
|
|
from src import ChatWorld |
|
|
|
# One ChatWorld instance, created at import time and used by the callbacks below.
chatWorld = ChatWorld()

# Role data published by getContent(); both stay None until it has run.
role_name_list_global = None
role_name_dict_global = None

# NOTE(review): not referenced anywhere in the visible part of this file.
Meta = {"uuid": "111"}
|
|
|
|
|
def getContent(input_file):
    """Load an uploaded text file into ``chatWorld`` and refresh the role pickers.

    Registered as the ``upload`` callback of the file component, with the two
    role radios as outputs.

    Args:
        input_file: Object whose ``name`` attribute is the path of the file
            to read (opened as UTF-8 text).

    Returns:
        A 2-tuple of ``gr.Radio`` instances, both carrying the role names
        returned by ``chatWorld.getRoleNameFromFile`` as their choices.
    """
    global role_name_list_global, role_name_dict_global

    path = input_file.name
    with open(path, "r", encoding="utf-8") as handle:
        logging.info(f"read file {path}")
        story_text = handle.read()
    logging.info(f"file content: {story_text}")

    chatWorld.setStory(stories=story_text, metas=None)

    # Publish the role data at module level so get_role_list() and
    # change_role_list() can read it later.
    role_name_list_global, role_name_dict_global = chatWorld.getRoleNameFromFile(
        story_text
    )

    model_side_radio = gr.Radio(choices=role_name_list_global, interactive=True)
    user_side_radio = gr.Radio(choices=role_name_list_global, interactive=True)
    return model_side_radio, user_side_radio
|
|
|
|
|
def submit_message(
    message,
    history,
    model_role_name,
    role_name,
    model_role_nickname,
    role_nickname,
    withCharacter,
):
    """Produce a reply to ``message`` via ``chatWorld`` with ``use_local_model=True``.

    Passed as the handler of the first ``gr.ChatInterface`` in this file.

    Args:
        message: Text forwarded to ``chatWorld`` as ``text``.
        history: Accepted as part of the handler signature; not used here.
        model_role_name: Forwarded to ``chatWithCharacter``.
        role_name: Forwarded to ``chatWithCharacter``.
        model_role_nickname: Forwarded to ``chatWithCharacter``.
        role_nickname: Forwarded to ``chatWithCharacter``.
        withCharacter: Truthy selects ``chatWithCharacter``; falsy selects
            ``chatWithoutCharacter``, which only receives the message.

    Returns:
        Whatever the selected ``chatWorld`` method returns.
    """
    if not withCharacter:
        return chatWorld.chatWithoutCharacter(text=message, use_local_model=True)

    return chatWorld.chatWithCharacter(
        text=message,
        role_name=role_name,
        role_nickname=role_nickname,
        model_role_name=model_role_name,
        model_role_nickname=model_role_nickname,
        use_local_model=True,
    )
|
|
|
|
|
def submit_message_api(
    message,
    history,
    model_role_name,
    role_name,
    model_role_nickname,
    role_nickname,
    withCharacter,
):
    """Produce a reply to ``message`` via ``chatWorld`` with ``use_local_model=False``.

    Passed as the handler of the second ``gr.ChatInterface`` in this file; it
    mirrors ``submit_message`` except for the ``use_local_model`` flag.

    Args:
        message: Text forwarded to ``chatWorld`` as ``text``.
        history: Accepted as part of the handler signature; not used here.
        model_role_name: Forwarded to ``chatWithCharacter``.
        role_name: Forwarded to ``chatWithCharacter``.
        model_role_nickname: Forwarded to ``chatWithCharacter``.
        role_nickname: Forwarded to ``chatWithCharacter``.
        withCharacter: Truthy selects ``chatWithCharacter``; falsy selects
            ``chatWithoutCharacter``, which only receives the message.

    Returns:
        Whatever the selected ``chatWorld`` method returns.
    """
    if withCharacter:
        return chatWorld.chatWithCharacter(
            text=message,
            role_name=role_name,
            role_nickname=role_nickname,
            model_role_name=model_role_name,
            model_role_nickname=model_role_nickname,
            use_local_model=False,
        )

    return chatWorld.chatWithoutCharacter(text=message, use_local_model=False)
|
|
|
|
|
def get_role_list():
    """Return the current role-name list, or a new empty list if it is unset or empty."""
    # Reading a module global needs no ``global`` declaration.
    return role_name_list_global or []
|
|
|
|
|
def change_role_list(name):
    """Look up the entry stored for ``name`` in ``role_name_dict_global``.

    Registered as the ``change`` callback of both role radios, with the
    matching nickname textbox as output.

    Args:
        name: Selected role name. May be ``None`` or a name that is absent
            from the current mapping.

    Returns:
        The value stored under ``name``, or an empty string when the mapping
        has not been populated yet or does not contain ``name``.
    """
    try:
        return role_name_dict_global[name]
    except (KeyError, TypeError):
        # TypeError: the mapping is still None (getContent() has not run yet).
        # KeyError: the selection is not in the current mapping.
        # Either way, return an empty nickname instead of raising in the callback.
        return ""
|
|
|
|
|
# Build the page: file upload, two role pickers with nickname boxes, a
# role-play toggle, and two chat panels sharing the same extra inputs.
with gr.Blocks() as demo:
    upload_c = gr.File(label="上传文档文件")

    with gr.Row():
        model_role_name = gr.Radio(choices=get_role_list(), label="模型角色名")
        model_role_nickname = gr.Textbox(label="模型角色昵称")

    with gr.Row():
        role_name = gr.Radio(choices=get_role_list(), label="角色名")
        role_nickname = gr.Textbox(label="角色昵称")

    # Selecting a role writes change_role_list()'s result into its nickname box.
    model_role_name.change(
        fn=change_role_list,
        inputs=[model_role_name],
        outputs=[model_role_nickname],
    )
    role_name.change(
        fn=change_role_list,
        inputs=[role_name],
        outputs=[role_nickname],
    )

    # An upload runs getContent(), whose two return values target both radios.
    upload_c.upload(
        fn=getContent,
        inputs=upload_c,
        outputs=[model_role_name, role_name],
    )

    withCharacter = gr.Radio(
        choices=[True, False],
        value=True,
        label="是否进行角色扮演",
    )

    # Listed in the same order as the handler parameters that follow
    # ``message`` and ``history`` in submit_message / submit_message_api.
    chat_extra_inputs = (
        model_role_name,
        role_name,
        model_role_nickname,
        role_nickname,
        withCharacter,
    )

    with gr.Row():
        chatBox_local = gr.ChatInterface(
            fn=submit_message,
            chatbot=gr.Chatbot(height=400, label="本地模型", render=False),
            additional_inputs=list(chat_extra_inputs),
        )

        chatBox_api = gr.ChatInterface(
            fn=submit_message_api,
            chatbot=gr.Chatbot(height=400, label="API模型", render=False),
            additional_inputs=list(chat_extra_inputs),
        )


demo.launch(server_name="0.0.0.0", share=True)
|
|