ainur1's picture
first
d31e8ca
raw
history blame contribute delete
No virus
6.2 kB
import re
from typing import Any, Dict, List, Optional

import langchain
from langchain.agents import AgentType
from langchain.agents import create_csv_agent
from langchain.chains.conversation.memory import ConversationBufferWindowMemory
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import HumanMessage
from langchain.vectorstores import Chroma
class Bot:
    """Question-answering bot combining text retrieval with optional CSV-table lookup.

    On each call the LLM is first asked which of the described tables (if any)
    may contain the answer. A conversational ReAct-style agent is then built
    with a text-search tool and, when a known table was selected, a
    tabular-search tool backed by a CSV agent over that table's file.

    Each entry of ``table_descriptions`` is read with the keys ``"number"``,
    ``"columns"``, ``"tittle"`` (spelled this way in the input data) and
    ``"path"``.
    """

    # System prompt installed on every agent. Each literal ends with an explicit
    # separator; previously adjacent literals were glued together without spaces
    # ("incredible.When", "another toolif").
    _SYSTEM_MESSAGE = (
        "You are an expert summarizer and deliverer of information. "
        "Yet, the reason you are so intelligent is that you make complex "
        "information incredibly simple to understand. It's actually rather incredible. "
        "When users ask for information you refer to the relevant tools. "
        "If one of the tools helped you with only a part of the necessary information, you must "
        "try to find the missing information using another tool. "
        "If you can't find the information using the provided tools, you MUST "
        "say 'I don't know'. Don't try to make up an answer."
    )

    def __init__(
        self,
        openai_api_key: str,
        table_descriptions: List[Dict[str, Any]],
        text_documents: List["langchain.schema.Document"],
        verbose: bool = False,
    ):
        """Build the LLM, the text retriever, the text-search tool and chat memory.

        Args:
            openai_api_key: Key passed to both the chat model and the embeddings.
            table_descriptions: Metadata for each CSV table (see class docstring).
            text_documents: Documents indexed into a Chroma vector store.
            verbose: Forwarded to the LangChain agents; also enables debug prints.
        """
        self.verbose = verbose
        self.table_descriptions = table_descriptions
        self.llm = ChatOpenAI(
            openai_api_key=openai_api_key,
            temperature=0,
            model_name="gpt-3.5-turbo",
        )
        embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
        vector_store = Chroma.from_documents(text_documents, embeddings)
        self.text_retriever = langchain.chains.RetrievalQAWithSourcesChain.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=vector_store.as_retriever(),
        )
        self.text_search_tool = langchain.agents.Tool(
            func=self._text_search,
            description="Use this tool when searching for text information",
            name="search text information",
        )
        # Created once (not per call) so the last k=5 exchanges actually persist
        # across questions; recreating it in every call discarded all history.
        self.conversational_memory = ConversationBufferWindowMemory(
            memory_key="chat_history",
            k=5,
            return_messages=True,
        )

    def __call__(self, question: str):
        """Answer *question* and return the agent's raw response.

        If the LLM names a table whose number matches one of
        ``table_descriptions``, a tabular-search tool over that CSV is added;
        otherwise only the text-search tool is available.
        """
        self.tools = [self.text_search_tool]
        table = self._define_appropriate_table(question)
        if self.verbose:
            print(table)

        number = self._parse_table_number(table)
        table_description = (
            self._find_table_description(number) if number is not None else None
        )
        if table_description is not None:
            self.csv_agent = create_csv_agent(
                llm=self.llm,
                path=table_description["path"],
                verbose=self.verbose,
            )
            self._init_tabular_search_tool(table_description)
            self.tools.append(self.tabular_search_tool)

        self._init_chatbot()
        return self.agent(question)

    def _init_chatbot(self):
        """(Re)build ``self.agent`` over the current ``self.tools`` and install the system prompt."""
        self.agent = langchain.agents.initialize_agent(
            agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
            tools=self.tools,
            llm=self.llm,
            verbose=self.verbose,
            max_iterations=5,
            early_stopping_method="generate",
            memory=self.conversational_memory,
        )
        # Replace the default agent prompt with one carrying our system message.
        prompt = self.agent.agent.create_prompt(
            system_message=self._SYSTEM_MESSAGE,
            tools=self.tools,
        )
        self.agent.agent.llm_chain.prompt = prompt

    def _text_search(self, query: str) -> str:
        """Tool callback: answer *query* from the indexed text documents."""
        # Chain.__call__ wraps a bare string into the chain's single input key
        # itself, so an explicit prep_inputs() call is unnecessary.
        return self.text_retriever(query)["answer"]

    def _tabular_search(self, query: str) -> str:
        """Tool callback: answer *query* using the CSV agent for the selected table."""
        return self.csv_agent.run(query)

    @staticmethod
    def _quote_columns(columns: List[str]) -> str:
        """Render column names as a comma-separated list of double-quoted names."""
        return '"' + '", "'.join(columns) + '"'

    @staticmethod
    def _tabular_tool_description(table_description: Dict[str, Any]) -> str:
        """Build the tool description text that tells the agent what the table holds."""
        columns = Bot._quote_columns(table_description["columns"])
        title = table_description["tittle"]
        return (
            "\n"
            "Use this tool when searching for tabular information.\n"
            "With this tool you could get access to table.\n"
            f'This table tittle is "{title}" and the names of the columns in this table: {columns}\n'
        )

    def _init_tabular_search_tool(self, table_description: Dict[str, Any]) -> None:
        """Create ``self.tabular_search_tool`` describing the selected table."""
        self.tabular_search_tool = langchain.agents.Tool(
            func=self._tabular_search,
            description=self._tabular_tool_description(table_description),
            name="search tabular information",
        )

    def _build_table_prompt(self, question: str) -> str:
        """Build the table-selection prompt listing every table description and *question*."""
        parts = ["I have list of table descriptions: \n"]
        for index, description in enumerate(self.table_descriptions, start=1):
            columns = self._quote_columns(description["columns"])
            parts.append(
                f" {index}) description for Table №{description['number']}:\n"
                f"a) table consist of columns with names: {columns};\n"
                f"b) table tittle: {description['tittle']}.\n"
            )
        parts.append(
            f' How do you think, which table can help answer the question: "{question}" .\n'
            "Your answer MUST be specific,\n"
            'for example if you think that Table №2 can help answer the question, you MUST just write "Table №2".\n'
            'If you think that none of the tables can help answer the question just write "None of the tables"\n'
            "Don't include to answer information about your thinking.\n"
        )
        return "".join(parts)

    @staticmethod
    def _normalize_table_answer(answer: str) -> str:
        """Strip surrounding whitespace, quotes and periods from the LLM's reply.

        Replaces a blind ``[:-1]`` slice, which chopped the last real character
        whenever the model did not append punctuation.
        """
        return answer.strip().strip('."\'').strip()

    @staticmethod
    def _parse_table_number(answer: str) -> Optional[int]:
        """Return the first integer in *answer* (e.g. 2 for "Table №2"), or None if there is none."""
        match = re.search(r"\d+", answer)
        return int(match.group()) if match else None

    def _find_table_description(self, number: int) -> Optional[Dict[str, Any]]:
        """Return the table description whose ``"number"`` equals *number*, or None."""
        return next(
            (d for d in self.table_descriptions if d["number"] == number),
            None,
        )

    def _define_appropriate_table(self, question: str) -> str:
        """Ask the LLM which table may contain the answer to *question*.

        Returns:
            The normalized reply. The model is instructed to answer exactly
            "Table №<n>" or "None of the tables", but compliance is not
            guaranteed, so callers should parse it defensively.
        """
        message = self._build_table_prompt(question)
        reply = self.llm([HumanMessage(content=message)])
        return self._normalize_table_answer(reply.content)