# Spaces:
# Runtime error
# Runtime error
import re
from typing import Any, Dict, List, Optional

import langchain
from langchain.agents import AgentType
from langchain.agents import create_csv_agent
from langchain.chains.conversation.memory import ConversationBufferWindowMemory
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import HumanMessage
from langchain.vectorstores import Chroma
class Bot:
    """Conversational agent that answers questions from text documents and CSV tables.

    For every question the bot first asks the LLM which of the described
    tables (if any) may contain the answer.  A text-search tool is always
    offered to the agent; a tabular-search tool backed by a CSV agent is added
    only when a known table was selected.

    Each entry of ``table_descriptions`` is read with the keys ``"number"``,
    ``"columns"``, ``"tittle"`` (sic - the key is spelled this way throughout
    the code) and ``"path"``.
    """

    # Canonical reply meaning that no described table is relevant.
    _NO_TABLE = "None of the tables"

    # Finds a table reference such as 'Table №2', 'Table № 2', 'table #2' or
    # 'Table No. 2' anywhere inside a free-form LLM reply.
    _TABLE_REFERENCE = re.compile(r"table\s*(?:№|#|no\.?)?\s*(\d+)", re.IGNORECASE)

    # System prompt installed into the agent.  The fragments are joined by
    # implicit literal concatenation, so each one carries its own separator.
    _SYSTEM_MESSAGE = (
        "You are an expert summarizer and deliverer of information. "
        "Yet, the reason you are so intelligent is that you make complex "
        "information incredibly simple to understand. It's actually rather incredible. "
        "When users ask information you refer to the relevant tools. "
        "If one of the tools helped you with only a part of the necessary information, you must "
        "try to find the missing information using another tool. "
        "If you can't find the information using the provided tools, you MUST "
        "say 'I don't know'. Don't try to make up an answer."
    )

    def __init__(
        self,
        openai_api_key: str,
        table_descriptions: List[Dict[str, Any]],
        text_documents: "List[langchain.schema.Document]",
        verbose: bool = False
    ):
        """Build the LLM, the text retrieval chain and the text-search tool.

        Args:
            openai_api_key: Key passed to both the chat model and the embeddings.
            table_descriptions: One dict per available table (see class docstring).
            text_documents: Documents indexed into the Chroma vector store.
            verbose: Forwarded to the agents; also enables a debug print of the
                selected table in ``__call__``.
        """
        self.verbose = verbose
        self.table_descriptions = table_descriptions
        self.llm = ChatOpenAI(
            openai_api_key=openai_api_key,
            temperature=0,
            model_name="gpt-3.5-turbo"
        )
        embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
        vector_store = Chroma.from_documents(text_documents, embeddings)
        self.text_retriever = langchain.chains.RetrievalQAWithSourcesChain.from_chain_type(
            llm=self.llm,
            chain_type='stuff',
            retriever=vector_store.as_retriever()
        )
        self.text_search_tool = langchain.agents.Tool(
            func=self._text_search,
            description="Use this tool when searching for text information",
            name="search text information"
        )
        # Created once so the chat history survives across questions; the
        # agent itself is rebuilt per call because its tool list can change.
        self.memory = ConversationBufferWindowMemory(
            memory_key='chat_history',
            k=5,
            return_messages=True
        )

    def __call__(
        self,
        question: str
    ):
        """Answer ``question`` and return the agent's response.

        The tabular tool is only offered when the LLM picked a table that
        actually exists in ``table_descriptions``; otherwise the agent works
        with the text-search tool alone instead of failing.
        """
        tools = [self.text_search_tool]
        table = self._define_appropriate_table(question)

        table_description = None
        reference = self._TABLE_REFERENCE.search(table)
        if table != self._NO_TABLE and reference is not None:
            table_description = self._find_table_description(int(reference.group(1)))

        if table_description is not None:
            self.csv_agent = create_csv_agent(
                llm=self.llm,
                path=table_description['path'],
                verbose=self.verbose
            )
            self._init_tabular_search_tool(table_description)
            tools.append(self.tabular_search_tool)

        self.tools = tools
        self._init_chatbot()
        if self.verbose:
            print(table)
        return self.agent(question)

    def _init_chatbot(self):
        """(Re)create the conversational agent for the current ``self.tools``."""
        # Older instances (or ones built without __init__) may lack the shared
        # memory; create it lazily so the history is still kept from then on.
        if getattr(self, "memory", None) is None:
            self.memory = ConversationBufferWindowMemory(
                memory_key='chat_history',
                k=5,
                return_messages=True
            )
        self.agent = langchain.agents.initialize_agent(
            agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
            tools=self.tools,
            llm=self.llm,
            verbose=self.verbose,
            max_iterations=5,
            early_stopping_method='generate',
            memory=self.memory
        )
        # Replace the default prompt with one that carries our system message.
        prompt = self.agent.agent.create_prompt(
            system_message=self._SYSTEM_MESSAGE,
            tools=self.tools
        )
        self.agent.agent.llm_chain.prompt = prompt

    def _text_search(
        self,
        query: str
    ) -> str:
        """Run ``query`` through the retrieval chain and return its 'answer' field."""
        inputs = self.text_retriever.prep_inputs(query)
        return self.text_retriever(inputs)['answer']

    def _tabular_search(
        self,
        query: str
    ) -> str:
        """Run ``query`` through the CSV agent created for the selected table."""
        return self.csv_agent.run(query)

    def _init_tabular_search_tool(
        self,
        table_description: Dict[str, Any]
    ) -> None:
        """Create ``self.tabular_search_tool`` described by the table's title and columns."""
        columns = self._format_columns(table_description["columns"])
        tittle = table_description["tittle"]
        description = (
            "\n"
            "Use this tool when searching for tabular information.\n"
            "With this tool you could get access to table.\n"
            f"This table tittle is \"{tittle}\" and the names of the columns in this table: {columns}\n"
        )
        self.tabular_search_tool = langchain.agents.Tool(
            func=self._tabular_search,
            description=description,
            name="search tabular information"
        )

    def _define_appropriate_table(
        self,
        question: str
    ) -> str:
        """Ask the LLM which described table may contain the answer to ``question``.

        Returns the choice in the form "Table №1", or "None of the tables"
        when no table fits or the reply cannot be interpreted.
        """
        message = self._build_table_selection_prompt(question)
        res = self.llm([HumanMessage(content=message)])
        return self._parse_table_choice(res.content)

    def _build_table_selection_prompt(self, question: str) -> str:
        """Compose the prompt listing every table description followed by the question."""
        parts = ['I have list of table descriptions: \n']
        for position, description in enumerate(self.table_descriptions, start=1):
            number = description["number"]
            columns = self._format_columns(description["columns"])
            tittle = description["tittle"]
            parts.append(
                f" {position}) description for Table №{number}:\n"
                f"a) table consist of columns with names: {columns};\n"
                f"b) table tittle: {tittle}.\n"
            )
        parts.append(
            f" How do you think, which table can help answer the question: \"{question}\" .\n"
            "Your answer MUST be specific,\n"
            "for example if you think that Table №2 can help answer the question, "
            "you MUST just write \"Table №2\".\n"
            "If you think that none of the tables can help answer the question "
            "just write \"None of the tables\"\n"
            "Don't include to answer information about your thinking.\n"
        )
        return "".join(parts)

    @classmethod
    def _parse_table_choice(cls, reply: str) -> str:
        """Normalise a free-form LLM reply to "Table №N" or "None of the tables".

        Tolerates surrounding quotes, trailing punctuation and whitespace
        instead of assuming the reply ends with exactly one extra character.
        """
        text = reply or ""
        if cls._NO_TABLE.lower() in text.lower():
            return cls._NO_TABLE
        reference = cls._TABLE_REFERENCE.search(text)
        if reference is None:
            # Unintelligible reply: fall back to text-only search.
            return cls._NO_TABLE
        return f"Table №{int(reference.group(1))}"

    def _find_table_description(self, number: int) -> Optional[Dict[str, Any]]:
        """Return the description whose "number" equals ``number``, or None if absent."""
        return next(
            (item for item in self.table_descriptions if item["number"] == number),
            None
        )

    @staticmethod
    def _format_columns(columns) -> str:
        """Render column names as a comma-separated list of double-quoted names."""
        return '"' + '", "'.join(columns) + '"'