xiaoximew's picture
Init
031beb8
raw
history blame
No virus
8.44 kB
import os
from typing import Dict, List, Optional, Tuple, Union
import gradio as gr
from common.call_llm import chat, chat_stream_generator
from plugin_task.model import Plugin, ReActStep
from plugin_task.plugins import PLUGIN_JSON_SCHEMA, PLUGINS
from plugin_task.prompt import (
FILLING_SLOT_PROMPT,
FINAL_PROMPT,
INTENT_RECOGNITION_PROMPT,
)
from plugin_task.util import (
build_prompt_plugin_variables,
parse_reAct_step,
plugin_parameter_validator,
)
# Endpoint handed to every LLM call made from this module (passed as the
# ``endpoint`` argument). Read once at import time from the PLUGIN_ENDPOINT
# environment variable; it is None when that variable is not set.
PLUGIN_ENDPOINT = os.environ.get("PLUGIN_ENDPOINT")
def api_plugin_chat(
    session: Dict,
    message: str,
    chat_history: List[List[str]],
    *radio_plugins,
):
    """Handle one user turn of the plugin chat (generator).

    With an empty ``session`` the message goes through intent recognition
    against the enabled plugins; with a non-empty ``session`` it is treated as
    a follow-up answer and goes through slot filling instead.

    Args:
        session: Mutable plugin-session state. It is populated when a plugin
            conversation starts and cleared on the "fail" and "end" outcomes.
        message: The user's latest message.
        chat_history: Chat pairs; the reply slot of the last pair
            (``chat_history[-1][1]``) is written in place.
        *radio_plugins: One status string per plugin; see ``prepare_plugins``.

    Yields:
        ``(session, None, chat_history)`` for the "fail",
        "ask_user_for_required_params" and "end" outcomes; otherwise whatever
        ``call_final_answer`` or ``call_chat`` yield.
    """
    if check_in_plugin_session(session):
        # A plugin conversation is already running: keep filling its slots.
        intention, reAct_step = filling_slot_with_loop(session, message)
    else:
        plugins = prepare_plugins(radio_plugins)
        if not plugins:
            gr.Warning("没有启用插件")
            return
        intention, reAct_step = intent_recognition(message, plugins)
        if intention in ("ask_user_for_required_params", "plugin"):
            # Remember what started this plugin conversation for later turns.
            session["origin_message"] = message
            session["choice_plugin"] = reAct_step.thought["tool_to_use_for_user"]
            session["reAct_step"] = [reAct_step]

    divider = "=" * 25
    print(
        f"[API_PLUGIN_CHAT]. message: {message},\n intention: {intention},\n session: {session}\n"
        + divider
        + "END"
        + divider
    )

    if intention == "fail":
        # On failure the second element is the raw model response string.
        chat_history[-1][1] = reAct_step
        session.clear()
        yield session, None, chat_history
    elif intention == "ask_user_for_required_params":
        chat_history[-1][1] = reAct_step.action_input.get("question", "")
        yield session, None, chat_history
    elif intention == "plugin":
        yield from call_final_answer(session, reAct_step, chat_history)
    elif intention == "chat":
        yield from call_chat(session, message, chat_history)
    elif intention == "end":
        session.clear()
        chat_history[-1][1] = "[系统消息]:当前插件对话结束"
        yield session, None, chat_history
def filling_slot_with_loop(
    session: Dict, message: str, retry: int = 3
) -> Tuple[str, Optional[Union[ReActStep, str]]]:
    """Continue an in-progress plugin conversation by filling parameter slots.

    Each round stores ``message`` as the observation of the latest recorded
    step (unless that step already has one), renders the whole step history
    into ``FILLING_SLOT_PROMPT`` and asks the model for the next step.

    Args:
        session: Plugin-session state. Reads ``choice_plugin``,
            ``origin_message`` and ``reAct_step``; accepted steps are appended
            to ``session["reAct_step"]``.
        message: The user's latest message.
        retry: How many unusable model rounds are tolerated before giving up.
            A round is unusable when the response cannot be parsed, when it
            matches none of the expected action/tool combinations, or when the
            proposed plugin call fails parameter validation. All three share
            this one budget, so the loop always terminates.

    Returns:
        A ``(intention, payload)`` pair:

        * ``("end", step)`` - the model chose to end the conversation.
        * ``("plugin", step)`` - parameters validated; the plugin can be run.
        * ``("ask_user_for_required_params", step)`` - more input is needed.
        * ``("fail", model_response)`` - the retry budget ran out; the payload
          is the last raw model response.
    """
    plugin = PLUGINS[session["choice_plugin"]]
    plugin_name = plugin.unique_name_for_model
    while True:
        lastest_reAct_step = session["reAct_step"][-1]
        if not lastest_reAct_step.observation:
            lastest_reAct_step.observation = {"user_answer": message}
        reAct_step_str = "\n".join(step.to_str() for step in session["reAct_step"])
        ask_content = FILLING_SLOT_PROMPT.format(
            plugin_name=plugin_name,
            description_for_human=plugin.description_for_human,
            parameter_schema=plugin.parameter_schema,
            question=session["origin_message"],
            reAct_step_str=reAct_step_str,
        )
        model_response = chat(
            [{"content": ask_content, "role": "user"}],
            stop="Observation",
            endpoint=PLUGIN_ENDPOINT,
        )
        print(
            f"[FILLING_SLOT_WITH_LOOP] message: {message} ask_content: {ask_content}\n model_response: {model_response}\n"
            + "=" * 25
            + "END"
            + "=" * 25
        )
        reAct_step = parse_reAct_step(model_response)
        if reAct_step:
            tool_to_use_for_user = reAct_step.thought.get("tool_to_use_for_user")
            known_parameter = reAct_step.thought.get("known_params", {})
            if (
                reAct_step.action == "end_conversation"
                or tool_to_use_for_user == "end_conversation"
            ):
                return "end", reAct_step
            if (
                reAct_step.action == "ASK_USER_FOR_REQUIRED_PARAMS"
                and tool_to_use_for_user == plugin_name
            ):
                passed, _ = plugin_parameter_validator(
                    known_parameter,
                    tool_to_use_for_user,
                )
                if passed:
                    # The model wants to ask again although validation already
                    # passes: go straight to the plugin call instead.
                    reAct_step.action = tool_to_use_for_user
                    action = "plugin"
                else:
                    action = "ask_user_for_required_params"
                session["reAct_step"].append(reAct_step)
                return action, reAct_step
            if (
                reAct_step.action == plugin_name
                and tool_to_use_for_user == plugin_name
            ):
                passed, invalid_info = plugin_parameter_validator(
                    known_parameter,
                    tool_to_use_for_user,
                )
                if passed:
                    session["reAct_step"].append(reAct_step)
                    return "plugin", reAct_step
                # Record the validation result as this step's observation so
                # the next prompt shows the model what was wrong.
                reAct_step.observation = {"tool_parameters_verification": invalid_info}
                session["reAct_step"].append(reAct_step)
        # This round produced nothing usable. Charging it to the retry budget
        # is what keeps the loop from calling the model indefinitely.
        retry -= 1
        if retry < 0:
            return "fail", model_response
def call_chat(session: Dict, message: str, chat_history: List[List[str]]):
    """Answer with the ordinary chat flow, prefixing each item with ``session``.

    Yields:
        ``(session, *chunk)`` for every ``chunk`` produced by
        ``generate_chat(message, chat_history, PLUGIN_ENDPOINT)``.
    """
    # NOTE(review): imported at call time rather than at module level,
    # presumably to avoid a circular import - confirm before hoisting it.
    from chat_task.chat import generate_chat

    reply_stream = generate_chat(message, chat_history, PLUGIN_ENDPOINT)
    for piece in reply_stream:
        yield (session, *piece)
def check_in_plugin_session(session: Dict) -> bool:
    """Return True when ``session`` is non-empty, i.e. a plugin conversation is active."""
    return True if session else False
def prepare_plugins(
    radio_plugins: List[str],
) -> List[Plugin]:
    """Return the plugins whose status string equals "开启".

    The position of each status in ``radio_plugins`` indexes
    ``PLUGIN_JSON_SCHEMA``; that entry's ``unique_name_for_model`` is the key
    used to look the plugin up in ``PLUGINS``.
    """
    enabled = []
    for position, status in enumerate(radio_plugins):
        if status != "开启":
            continue
        model_name = PLUGIN_JSON_SCHEMA[position]["unique_name_for_model"]
        enabled.append(PLUGINS[model_name])
    return enabled
def intent_recognition(
    message: str, choice_plugins: List[Plugin]
) -> Tuple[str, Union[ReActStep, str]]:
    """Ask the model what the user wants and map its answer to an intention.

    The model is queried up to three times until its response parses into a
    ReAct step.

    Args:
        message: The user's message.
        choice_plugins: The plugins currently enabled.

    Returns:
        A ``(intention, payload)`` pair. ``intention`` is one of ``"fail"``
        (payload is the last raw model response), ``"chat"``, ``"end"``,
        ``"plugin"`` or ``"ask_user_for_required_params"`` (payload is the
        parsed step).
    """
    plugins, plugin_names = build_prompt_plugin_variables(choice_plugins)
    ask_content = INTENT_RECOGNITION_PROMPT.format(
        plugins=plugins, plugin_names=plugin_names, question=message
    )
    divider = "=" * 25
    print(
        f"[INTENT_RECOGNITION] message:{message} ask_content: {ask_content}"
        + divider
        + "END"
        + divider
    )

    for _ in range(3):
        model_response = chat(
            [{"content": ask_content, "role": "user"}],
            stop="Observation",
            endpoint=PLUGIN_ENDPOINT,
        )
        reAct_step = parse_reAct_step(model_response)
        if reAct_step:
            break
    else:
        # No attempt produced a parsable step.
        print(f"[INTENT_RECOGNITION] model fail: {model_response}")
        return "fail", model_response

    action = reAct_step.action
    tool_to_use_for_user = reAct_step.thought.get("tool_to_use_for_user")
    known_params = reAct_step.thought.get("known_params", {})

    if action == "TOOL_OTHER":
        return "chat", reAct_step
    if action == "end_conversation" and tool_to_use_for_user == "end_conversation":
        return "end", reAct_step

    enabled_names = plugin_names.split(",")
    if tool_to_use_for_user not in enabled_names:
        return "chat", reAct_step

    if action in ("ASK_USER_FOR_INTENT", "ASK_USER_FOR_REQUIRED_PARAMS"):
        passed, _ = plugin_parameter_validator(
            known_params,
            tool_to_use_for_user,
        )
        if not passed:
            return "ask_user_for_required_params", reAct_step
        # Validation already passes, so skip the question and call the plugin.
        reAct_step.action = tool_to_use_for_user
        return "plugin", reAct_step

    if action in enabled_names:
        return "plugin", reAct_step
    return "chat", reAct_step
def call_final_answer(session: Dict, reAct_step: ReActStep, history: List[List[str]]):
    """Run the selected plugin and stream the model's final answer (generator).

    The plugin named by ``reAct_step.action`` is run with
    ``reAct_step.action_input`` as keyword arguments. Its result is stored as
    the observation of the latest step in ``session["reAct_step"]``, the step
    history is rendered into ``FINAL_PROMPT``, and the streamed model output is
    appended piece by piece to ``history[-1][1]``.

    Args:
        session: Plugin-session state; reads ``reAct_step`` and
            ``origin_message``. Cleared once the stream is exhausted.
        reAct_step: The step describing the plugin call.
        history: Chat pairs; the reply slot of the last pair is extended in
            place.

    Yields:
        ``(session, None, history)`` after every streamed piece.
    """
    plugin_result = PLUGINS[reAct_step.action].run(**reAct_step.action_input)
    lastest_reAct_step = session["reAct_step"][-1]
    lastest_reAct_step.observation = {"tool_response": plugin_result}
    reAct_step_str = "\n".join(step.to_str() for step in session["reAct_step"])
    final_prompt = FINAL_PROMPT.format(
        question=session["origin_message"],
        reAct_step_str=reAct_step_str,
    )
    print(
        f"[CALL_FINAL_ANSWER] final_prompt: {final_prompt}\n"
        + "=" * 25
        + "END"
        + "=" * 25
    )
    stream_response = chat_stream_generator(
        [{"content": final_prompt, "role": "user"}],
        endpoint=PLUGIN_ENDPOINT,
    )
    for character in stream_response:
        # The reply slot may still be None when the first piece arrives (the
        # usual placeholder for a pending bot message); ``None += str`` would
        # raise TypeError, so start from the piece itself in that case.
        so_far = history[-1][1]
        history[-1][1] = character if so_far is None else so_far + character
        yield session, None, history
    session.clear()