|
import os |
|
from typing import Sequence, List |
|
|
|
import streamlit as st |
|
from langchain.agents import AgentExecutor |
|
from langchain.schema.language_model import BaseLanguageModel |
|
from langchain.tools import BaseTool |
|
|
|
from backend.chat_bot.message_converter import DefaultClickhouseMessageConverter |
|
from backend.constants.prompts import DEFAULT_SYSTEM_PROMPT |
|
from backend.constants.streamlit_keys import AVAILABLE_RETRIEVAL_TOOLS |
|
from backend.constants.variables import GLOBAL_CONFIG, RETRIEVER_TOOLS |
|
from logger import logger |
|
|
|
try: |
|
from sqlalchemy.orm import declarative_base |
|
except ImportError: |
|
from sqlalchemy.ext.declarative import declarative_base |
|
from langchain.chat_models import ChatOpenAI |
|
from langchain.prompts.chat import MessagesPlaceholder |
|
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory |
|
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent |
|
from langchain.schema.messages import SystemMessage |
|
from langchain.memory import SQLChatMessageHistory |
|
|
|
|
|
def create_agent_executor( |
|
agent_name: str, |
|
session_id: str, |
|
llm: BaseLanguageModel, |
|
tools: Sequence[BaseTool], |
|
system_prompt: str, |
|
**kwargs |
|
) -> AgentExecutor: |
|
agent_name = agent_name.replace(" ", "_") |
|
conn_str = f'clickhouse://{os.environ["MYSCALE_USER"]}:{os.environ["MYSCALE_PASSWORD"]}@{os.environ["MYSCALE_HOST"]}:{os.environ["MYSCALE_PORT"]}' |
|
chat_memory = SQLChatMessageHistory( |
|
session_id, |
|
connection_string=f'{conn_str}/chat?protocol=http' if GLOBAL_CONFIG.myscale_enable_https == False else f'{conn_str}/chat?protocol=https', |
|
custom_message_converter=DefaultClickhouseMessageConverter(agent_name)) |
|
memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory) |
|
|
|
prompt = OpenAIFunctionsAgent.create_prompt( |
|
system_message=SystemMessage(content=system_prompt), |
|
extra_prompt_messages=[MessagesPlaceholder(variable_name="history")], |
|
) |
|
agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt) |
|
return AgentExecutor( |
|
agent=agent, |
|
tools=tools, |
|
memory=memory, |
|
verbose=True, |
|
return_intermediate_steps=True, |
|
**kwargs |
|
) |
|
|
|
|
|
def build_agents( |
|
session_id: str, |
|
tool_names: List[str], |
|
model: str = "gpt-3.5-turbo-0125", |
|
temperature: float = 0.6, |
|
system_prompt: str = DEFAULT_SYSTEM_PROMPT |
|
): |
|
chat_llm = ChatOpenAI( |
|
model_name=model, |
|
temperature=temperature, |
|
base_url=GLOBAL_CONFIG.openai_api_base, |
|
api_key=GLOBAL_CONFIG.openai_api_key, |
|
streaming=True |
|
) |
|
tools = st.session_state.get(AVAILABLE_RETRIEVAL_TOOLS, st.session_state.get(RETRIEVER_TOOLS)) |
|
selected_tools = [tools[k] for k in tool_names] |
|
logger.info(f"create agent, use tools: {selected_tools}") |
|
agent = create_agent_executor( |
|
agent_name="chat_memory", |
|
session_id=session_id, |
|
llm=chat_llm, |
|
tools=selected_tools, |
|
system_prompt=system_prompt |
|
) |
|
return agent |
|
|