|
import os |
|
from typing import Dict, List, Optional, Tuple, Union |
|
|
|
import gradio as gr |
|
|
|
from common.call_llm import chat, chat_stream_generator |
|
from plugin_task.model import Plugin, ReActStep |
|
from plugin_task.plugins import PLUGIN_JSON_SCHEMA, PLUGINS |
|
from plugin_task.prompt import ( |
|
FILLING_SLOT_PROMPT, |
|
FINAL_PROMPT, |
|
INTENT_RECOGNITION_PROMPT, |
|
) |
|
from plugin_task.util import ( |
|
build_prompt_plugin_variables, |
|
parse_reAct_step, |
|
plugin_parameter_validator, |
|
) |
|
|
|
# Endpoint passed to every LLM call made by this module (chat,
# chat_stream_generator, generate_chat). None when the PLUGIN_ENDPOINT
# environment variable is not set.
PLUGIN_ENDPOINT = os.getenv("PLUGIN_ENDPOINT")
|
|
|
|
|
def api_plugin_chat(
    session: Dict,
    message: str,
    chat_history: List[List[str]],
    *radio_plugins,
):
    """Route one user message through the plugin workflow (generator).

    When no plugin conversation is active, the message is classified by
    ``intent_recognition`` against the enabled plugins. Otherwise it is fed to
    ``filling_slot_with_loop`` to continue the current plugin conversation.
    The resulting intention decides what gets yielded.

    Args:
        session: Per-user state dict. Empty means no plugin conversation is in
            progress; otherwise it holds ``origin_message``, ``choice_plugin``
            and the list of ``reAct_step`` objects collected so far.
        message: The user's latest message.
        chat_history: Chat history; this function writes the reply slot of the
            last entry (``chat_history[-1][1]``).
        *radio_plugins: One status value per entry of ``PLUGIN_JSON_SCHEMA``;
            the value "开启" marks a plugin as enabled.

    Yields:
        ``(session, None, chat_history)`` triples from this function and from
        ``call_final_answer``. For the "chat" intention, whatever
        ``call_chat`` yields (the session followed by ``generate_chat``'s
        chunk).

    If no plugin is enabled, a Gradio warning is shown and nothing is yielded.
    """
    if not check_in_plugin_session(session):
        plugins = prepare_plugins(radio_plugins)
        if not plugins:
            gr.Warning("没有启用插件")
            return

        intention, reAct_step = intent_recognition(message, plugins)
        if intention in ("ask_user_for_required_params", "plugin"):
            session["origin_message"] = message
            # A "plugin" step's action is the plugin that will actually run.
            # intent_recognition can return "plugin" purely because the action
            # names a plugin, even when the thought does not, so indexing the
            # thought could raise KeyError or record the wrong plugin.
            session["choice_plugin"] = (
                reAct_step.action
                if intention == "plugin"
                else reAct_step.thought["tool_to_use_for_user"]
            )
            session["reAct_step"] = [reAct_step]
    else:
        intention, reAct_step = filling_slot_with_loop(session, message)

    print(
        f"[API_PLUGIN_CHAT]. message: {message},\n intention: {intention},\n session: {session}\n"
        + "=" * 25
        + "END"
        + "=" * 25
    )

    if intention == "fail":
        # On failure reAct_step is the raw model response string; show it and
        # drop any partial plugin state.
        chat_history[-1][1] = reAct_step
        session.clear()
        yield session, None, chat_history
    elif intention == "ask_user_for_required_params":
        # Keep the session so the user's answer continues slot filling.
        chat_history[-1][1] = reAct_step.action_input.get("question", "")
        yield session, None, chat_history
    elif intention == "plugin":
        yield from call_final_answer(session, reAct_step, chat_history)
    elif intention == "chat":
        yield from call_chat(session, message, chat_history)
    elif intention == "end":
        session.clear()
        chat_history[-1][1] = "[系统消息]:当前插件对话结束"
        yield session, None, chat_history
|
|
|
|
|
def filling_slot_with_loop(
    session: Dict, message: str, retry: int = 3
) -> Tuple[str, Optional[Union[ReActStep, str]]]:
    """Continue an in-progress plugin conversation (slot filling).

    The user's message is recorded as the observation of the latest ReAct
    step. The whole step history is then sent to the model with
    ``FILLING_SLOT_PROMPT``, and the model's next step is interpreted.

    Args:
        session: Active plugin session holding ``choice_plugin``,
            ``origin_message`` and the ``reAct_step`` list. Accepted steps
            are appended to that list.
        message: The user's latest reply.
        retry: Number of extra model calls allowed after a failed attempt,
            i.e. unparsable output or plugin parameters that fail validation.
            Both kinds of failure share this budget.

    Returns:
        ``(intention, step)`` where intention is one of:

        - ``"end"``: the model chose ``end_conversation``.
        - ``"ask_user_for_required_params"``: the model asks the user for
          more parameters; ``step.action_input`` carries the question.
        - ``"plugin"``: the step should be executed.
        - ``"fail"``: the retry budget is exhausted; ``step`` is the raw
          model response string.
    """
    plugin = PLUGINS[session["choice_plugin"]]

    while True:
        latest_reAct_step = session["reAct_step"][-1]
        # Only attach the user's reply once. Retries after a validation
        # failure leave an observation on the latest step, so the reply is
        # not re-added.
        if not latest_reAct_step.observation:
            latest_reAct_step.observation = {"user_answer": message}

        reAct_step_str = "\n".join(step.to_str() for step in session["reAct_step"])

        ask_content = FILLING_SLOT_PROMPT.format(
            plugin_name=plugin.unique_name_for_model,
            description_for_human=plugin.description_for_human,
            parameter_schema=plugin.parameter_schema,
            question=session["origin_message"],
            reAct_step_str=reAct_step_str,
        )

        model_response = chat(
            [{"content": ask_content, "role": "user"}],
            stop="Observation",
            endpoint=PLUGIN_ENDPOINT,
        )
        print(
            f"[FILLING_SLOT_WITH_LOOP] message: {message} ask_content: {ask_content}\n model_response: {model_response}\n"
            + "=" * 25
            + "END"
            + "=" * 25
        )

        reAct_step = parse_reAct_step(model_response)
        if not reAct_step:
            if (retry := retry - 1) < 0:
                return "fail", model_response
            continue

        tool_to_use_for_user = reAct_step.thought.get("tool_to_use_for_user")
        known_parameter = reAct_step.thought.get("known_params", {})

        if (
            reAct_step.action == "end_conversation"
            or tool_to_use_for_user == "end_conversation"
        ):
            return "end", reAct_step

        if (
            reAct_step.action == "ASK_USER_FOR_REQUIRED_PARAMS"
            and tool_to_use_for_user == plugin.unique_name_for_model
        ):
            passed, _ = plugin_parameter_validator(
                known_parameter,
                tool_to_use_for_user,
            )
            if passed:
                # All parameters are already known, so call the plugin
                # directly. An ask step's action_input holds the question
                # (the caller reads action_input["question"]), not the tool
                # arguments. Replace it with the validated parameters so
                # PLUGINS[action].run(**action_input) gets the right kwargs.
                reAct_step.action = tool_to_use_for_user
                reAct_step.action_input = known_parameter
                action = "plugin"
            else:
                action = "ask_user_for_required_params"

            session["reAct_step"].append(reAct_step)
            return action, reAct_step

        if (
            reAct_step.action == plugin.unique_name_for_model
            and tool_to_use_for_user == plugin.unique_name_for_model
        ):
            passed, invalid_info = plugin_parameter_validator(
                known_parameter,
                tool_to_use_for_user,
            )
            if not passed:
                # Feed the validation error back to the model on the next
                # round.
                reAct_step.observation = {"tool_parameters_verification": invalid_info}
                session["reAct_step"].append(reAct_step)
                # Charge the shared retry budget. Otherwise a model that keeps
                # producing invalid arguments would loop (and call the LLM)
                # forever.
                if (retry := retry - 1) < 0:
                    return "fail", model_response
                continue

        session["reAct_step"].append(reAct_step)
        return "plugin", reAct_step
|
|
|
|
|
def call_chat(session: Dict, message: str, chat_history: List[List[str]]):
    """Handle a non-plugin message as ordinary chat.

    Relays every chunk produced by ``generate_chat`` (run against
    ``PLUGIN_ENDPOINT``), prefixed with the unchanged session.
    """
    # Imported at call time rather than at module top. Presumably this avoids
    # an import cycle with chat_task; verify before hoisting.
    from chat_task.chat import generate_chat

    chunks = generate_chat(message, chat_history, PLUGIN_ENDPOINT)
    for chunk_values in chunks:
        yield (session, *chunk_values)
|
|
|
|
|
def check_in_plugin_session(session: Dict) -> bool:
    """Return True when a plugin conversation is currently in progress.

    Any non-empty (truthy) session counts as an active plugin conversation;
    an empty or falsy session means none is active.
    """
    return True if session else False
|
|
|
|
|
def prepare_plugins(
    radio_plugins: List[str],
) -> List[Plugin]:
    """Return the Plugin objects the user has switched on.

    ``radio_plugins[i]`` is the status of the plugin described by
    ``PLUGIN_JSON_SCHEMA[i]``. Entries equal to "开启" are enabled, and each
    one is resolved through ``PLUGINS`` by its ``unique_name_for_model``.
    Schema order is preserved.
    """
    enabled = []
    for position, status in enumerate(radio_plugins):
        if status != "开启":
            continue
        model_name = PLUGIN_JSON_SCHEMA[position]["unique_name_for_model"]
        enabled.append(PLUGINS[model_name])
    return enabled
|
|
|
|
|
def intent_recognition(
    message: str, choice_plugins: List[Plugin]
) -> Tuple[str, Union[ReActStep, str]]:
    """Classify a new user message against the enabled plugins.

    Sends ``INTENT_RECOGNITION_PROMPT`` to the model. Unparsable responses
    are retried, for at most 3 model calls.

    Args:
        message: The user's message.
        choice_plugins: Enabled plugins to offer to the model.

    Returns:
        ``(intention, step)`` where intention is one of:

        - ``"fail"``: no parsable step; ``step`` is the last raw response.
        - ``"chat"``: ordinary chat (``TOOL_OTHER`` or no plugin matched).
        - ``"end"``: the model chose ``end_conversation``.
        - ``"ask_user_for_required_params"``: a plugin was chosen but
          parameters are missing; ``step.action_input`` carries the question.
        - ``"plugin"``: ``step.action`` names the plugin to run with
          ``step.action_input``.
    """
    plugins, plugin_names = build_prompt_plugin_variables(choice_plugins)
    ask_content = INTENT_RECOGNITION_PROMPT.format(
        plugins=plugins, plugin_names=plugin_names, question=message
    )

    print(
        f"[INTENT_RECOGNITION] message:{message} ask_content: {ask_content}"
        + "=" * 25
        + "END"
        + "=" * 25
    )

    # plugin_names is the comma-separated string injected into the prompt.
    # Split it once, tolerating spaces after commas. Kept as a list (not a
    # set) because model-supplied values may be unhashable.
    available_plugins = [name.strip() for name in plugin_names.split(",")]

    reAct_step = None
    model_response = None
    for _ in range(3):
        model_response = chat(
            [{"content": ask_content, "role": "user"}],
            stop="Observation",
            endpoint=PLUGIN_ENDPOINT,
        )
        reAct_step = parse_reAct_step(model_response)
        if reAct_step:
            break

    if not reAct_step:
        print(f"[INTENT_RECOGNITION] model fail: {model_response}")
        return "fail", model_response

    tool_to_use_for_user = reAct_step.thought.get("tool_to_use_for_user")
    known_params = reAct_step.thought.get("known_params", {})

    if reAct_step.action == "TOOL_OTHER":
        return "chat", reAct_step

    if (
        reAct_step.action == "end_conversation"
        and tool_to_use_for_user == "end_conversation"
    ):
        return "end", reAct_step

    if tool_to_use_for_user in available_plugins and reAct_step.action in (
        "ASK_USER_FOR_INTENT",
        "ASK_USER_FOR_REQUIRED_PARAMS",
    ):
        passed, _ = plugin_parameter_validator(
            known_params,
            tool_to_use_for_user,
        )
        if passed:
            # Everything needed is already known, so call the plugin now. An
            # ask step's action_input holds the question (the caller reads
            # action_input["question"]), not the tool arguments. Swap in the
            # validated parameters so run(**action_input) gets the right
            # kwargs.
            reAct_step.action = tool_to_use_for_user
            reAct_step.action_input = known_params
            return "plugin", reAct_step

        return "ask_user_for_required_params", reAct_step

    if reAct_step.action in available_plugins:
        return "plugin", reAct_step

    return "chat", reAct_step
|
|
|
|
|
def call_final_answer(session: Dict, reAct_step: ReActStep, history: List[List[str]]):
    """Run the chosen plugin, then stream the model's final answer.

    The plugin result is stored as the observation of the latest step in
    ``session["reAct_step"]``. The full step history is formatted with
    ``FINAL_PROMPT``, and the streamed answer is appended character by
    character to ``history[-1][1]``.

    Yields:
        ``(session, None, history)`` after each streamed piece.

    The session is always cleared when this generator finishes, whether it
    completes, raises, or is closed early. This ensures the next message
    starts a fresh intent recognition instead of resuming a finished or
    broken plugin round.
    """
    try:
        plugin_result = PLUGINS[reAct_step.action].run(**reAct_step.action_input)

        latest_reAct_step = session["reAct_step"][-1]
        latest_reAct_step.observation = {"tool_response": plugin_result}

        reAct_step_str = "\n".join(step.to_str() for step in session["reAct_step"])
        final_prompt = FINAL_PROMPT.format(
            question=session["origin_message"],
            reAct_step_str=reAct_step_str,
        )

        print(
            f"[CALL_FINAL_ANSWER] final_prompt: {final_prompt}\n"
            + "=" * 25
            + "END"
            + "=" * 25
        )
        stream_response = chat_stream_generator(
            [{"content": final_prompt, "role": "user"}],
            endpoint=PLUGIN_ENDPOINT,
        )

        # The pending reply slot may still be None (assumption: the UI
        # appends [message, None] before calling us; confirm). Start from ""
        # so += works.
        if history[-1][1] is None:
            history[-1][1] = ""

        for character in stream_response:
            history[-1][1] += character
            yield session, None, history
    finally:
        session.clear()
|
|