import json
import asyncio
from typing import AsyncIterable, Optional, List

from fastapi import Body
from sse_starlette.sse import EventSourceResponse
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferWindowMemory
from langchain.agents import LLMSingleActionAgent, AgentExecutor

from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from server.knowledge_base.kb_service.base import get_kb_details
from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
from server.agent.tools_select import tools, tool_names
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
from server.chat.utils import History
from server.agent import model_container
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate


async def agent_chat(query: str = Body(..., description="User input", examples=["恼羞成怒"]),
                     history: List[History] = Body([],
                                                   description="Conversation history",
                                                   examples=[[
                                                       {"role": "user",
                                                        "content": "Please use the knowledge base tool to look up today's weather in Beijing"},
                                                       {"role": "assistant",
                                                        "content": "The weather tool reports Beijing is cloudy today, 10-14°C, northeast wind force 2; easy to catch a cold"}]]
                                                   ),
                     stream: bool = Body(False, description="Stream the output"),
                     model_name: str = Body(LLM_MODELS[0], description="Name of the LLM model."),
                     temperature: float = Body(TEMPERATURE, description="LLM sampling temperature", ge=0.0, le=1.0),
                     max_tokens: Optional[int] = Body(None, description="Cap on the number of tokens the LLM may generate; None (the default) means the model's maximum"),
                     prompt_name: str = Body("default", description="Name of the prompt template to use (configured in configs/prompt_config.py)"),
                     ):
    history = [History.from_data(h) for h in history]

    async def agent_chat_iterator(
            query: str,
            history: Optional[List[History]],
            model_name: str = LLM_MODELS[0],
            prompt_name: str = prompt_name,
    ) -> AsyncIterable[str]:
        nonlocal max_tokens
        callback = CustomAsyncIteratorCallbackHandler()
        if isinstance(max_tokens, int) and max_tokens <= 0:
            max_tokens = None

        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[callback],
        )

        # Expose the knowledge base names and descriptions to the agent tools
        # via the shared container.
        kb_list = {x["kb_name"]: x for x in get_kb_details()}
        model_container.DATABASE = {name: details["kb_info"] for name, details in kb_list.items()}

        # If a dedicated agent model is configured, use it for the agent;
        # otherwise fall back to the chat model.
        if Agent_MODEL:
            model_agent = get_ChatOpenAI(
                model_name=Agent_MODEL,
                temperature=temperature,
                max_tokens=max_tokens,
                callbacks=[callback],
            )
            model_container.MODEL = model_agent
        else:
            model_container.MODEL = model

        prompt_template = get_prompt_template("agent_chat", prompt_name)
        prompt_template_agent = CustomPromptTemplate(
            template=prompt_template,
            tools=tools,
            input_variables=["input", "intermediate_steps", "history"],
        )
        output_parser = CustomOutputParser()
        llm_chain = LLMChain(llm=model, prompt=prompt_template_agent)

        # Replay the request history into a sliding-window memory.
        memory = ConversationBufferWindowMemory(k=HISTORY_LEN * 2)
        for message in history:
            if message.role == "user":
                memory.chat_memory.add_user_message(message.content)
            else:
                memory.chat_memory.add_ai_message(message.content)

        # ChatGLM3 and zhipu-api use a dedicated agent implementation.
        if "chatglm3" in model_container.MODEL.model_name or "zhipu-api" in model_container.MODEL.model_name:
            agent_executor = initialize_glm3_agent(
                llm=model,
                tools=tools,
                callback_manager=None,
                prompt=prompt_template,
                input_variables=["input", "intermediate_steps", "history"],
                memory=memory,
                verbose=True,
            )
        else:
            agent = LLMSingleActionAgent(
                llm_chain=llm_chain,
                output_parser=output_parser,
                stop=["\nObservation:", "Observation"],
                allowed_tools=tool_names,
            )
            agent_executor = AgentExecutor.from_agent_and_tools(
                agent=agent,
                tools=tools,
                verbose=True,
                memory=memory,
            )

        # Kick off the agent run in the background; wrap_done sets the
        # callback's done event once the call finishes.
        task = asyncio.create_task(wrap_done(
            agent_executor.acall(query, callbacks=[callback], include_run_info=True),
            callback.done))

        if stream:
            # Stream intermediate tool calls and LLM tokens as server-sent events.
            async for chunk in callback.aiter():
                tools_use = []
                data = json.loads(chunk)
                if data["status"] == Status.start or data["status"] == Status.complete:
                    continue
                elif data["status"] == Status.error:
                    tools_use.append("\n```\n")
                    tools_use.append("Tool name: " + data["tool_name"])
                    tools_use.append("Tool status: call failed")
                    tools_use.append("Error message: " + data["error"])
                    tools_use.append("Retrying")
                    tools_use.append("\n```\n")
                    yield json.dumps({"tools": tools_use}, ensure_ascii=False)
                elif data["status"] == Status.tool_finish:
                    tools_use.append("\n```\n")
                    tools_use.append("Tool name: " + data["tool_name"])
                    tools_use.append("Tool status: call succeeded")
                    tools_use.append("Tool input: " + data["input_str"])
                    tools_use.append("Tool output: " + data["output_str"])
                    tools_use.append("\n```\n")
                    yield json.dumps({"tools": tools_use}, ensure_ascii=False)
                elif data["status"] == Status.agent_finish:
                    yield json.dumps({"final_answer": data["final_answer"]}, ensure_ascii=False)
                else:
                    yield json.dumps({"answer": data["llm_token"]}, ensure_ascii=False)
        else:
            # Non-streaming: accumulate the whole run and yield a single message.
            answer = ""
            final_answer = ""
            async for chunk in callback.aiter():
                data = json.loads(chunk)
                if data["status"] == Status.start or data["status"] == Status.complete:
                    continue
                elif data["status"] == Status.error:
                    answer += "\n```\n"
                    answer += "Tool name: " + data["tool_name"] + "\n"
                    answer += "Tool status: call failed\n"
                    answer += "Error message: " + data["error"] + "\n"
                    answer += "\n```\n"
                elif data["status"] == Status.tool_finish:
                    answer += "\n```\n"
                    answer += "Tool name: " + data["tool_name"] + "\n"
                    answer += "Tool status: call succeeded\n"
                    answer += "Tool input: " + data["input_str"] + "\n"
                    answer += "Tool output: " + data["output_str"] + "\n"
                    answer += "\n```\n"
                elif data["status"] == Status.agent_finish:
                    final_answer = data["final_answer"]
                else:
                    answer += data["llm_token"]
            yield json.dumps({"answer": answer, "final_answer": final_answer}, ensure_ascii=False)
        await task

    return EventSourceResponse(agent_chat_iterator(query=query,
                                                   history=history,
                                                   model_name=model_name,
                                                   prompt_name=prompt_name),
                               )
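
# A minimal sketch of how this handler can be mounted on a FastAPI app.
# The `app` instance and route path below are illustrative assumptions;
# the project wires up its API routes in its own startup code.
#
# from fastapi import FastAPI
#
# app = FastAPI()
# app.post("/chat/agent_chat", tags=["Chat"], summary="Chat with the agent")(agent_chat)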