import csv
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import langchain
from langchain.agents import AgentType
from langchain.agents import create_csv_agent
from langchain.chains.conversation.memory import ConversationBufferWindowMemory
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.schema import HumanMessage
from langchain.tools import StructuredTool
from langchain.vectorstores import Chroma
from PIL import Image

from utils.functions import Matcha_model


class Bot:
    """Question-answering agent over text documents, CSV tables and charts.

    For every question the bot first asks the LLM which described file (if
    any) is relevant, then builds an agent equipped with the text-search tool
    plus either a tabular tool (CSV files) or a chart tool (any other file).
    """

    # Extracts the file number from an LLM answer such as "File №2".
    _FILE_NUMBER_RE = re.compile(r"№\s*(\d+)")

    def __init__(
        self,
        openai_api_key: str,
        file_descriptions: List[Dict[str, Any]],
        text_documents: List[langchain.schema.Document],
        verbose: bool = False
    ):
        """Build the LLM, the text retriever and the chart model.

        Args:
            openai_api_key: Key passed to ``ChatOpenAI``.
            file_descriptions: One dict per file. The keys ``'number'`` and
                ``'path'`` are read by this class; ``'title'`` and
                ``'columns'`` are used when present (assumed optional —
                TODO confirm against the caller's schema).
            text_documents: Documents indexed into the Chroma vector store.
            verbose: Forwarded to the agents; also enables debug printing.
        """
        self.verbose = verbose
        self.file_descriptions = file_descriptions
        self.llm = ChatOpenAI(
            openai_api_key=openai_api_key,
            temperature=0,
            model_name="gpt-3.5-turbo"
        )

        embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
        vector_store = Chroma.from_documents(text_documents, embedding_function)
        self.text_retriever = langchain.chains.RetrievalQAWithSourcesChain.from_chain_type(
            llm=self.llm,
            chain_type='stuff',
            retriever=vector_store.as_retriever()
        )
        self.text_search_tool = langchain.agents.Tool(
            func=self._text_search,
            description="Use this tool when searching for text information",
            name="search text information"
        )

        self.chart_model = Matcha_model()

        # Created once so the 5-turn window survives across calls; building it
        # inside _init_chatbot would hand every question an empty memory.
        self.memory = ConversationBufferWindowMemory(
            memory_key='chat_history',
            k=5,
            return_messages=True
        )

    def __call__(
        self,
        question: str
    ):
        """Answer ``question`` with an agent built for the most relevant file.

        Returns:
            Whatever the agent executor returns for the question.
        """
        self.tools = [self.text_search_tool]

        file_description = self._find_file_description(
            self._define_appropriate_file(question)
        )
        if file_description is not None:
            file_path = file_description['path']
            # The suffix must come from the file path; the LLM answer
            # ("File №2") never has one.
            if Path(file_path).suffix.lower() == '.csv':
                self.csv_agent = create_csv_agent(
                    llm=self.llm,
                    path=file_path,
                    verbose=self.verbose
                )
                self._init_tabular_search_tool(file_description)
                self.tools.append(self.tabular_search_tool)
            else:
                self._init_chart_search_tool(file_description)
                self.tools.append(self.chart_search_tool)

        self._init_chatbot()
        return self.agent(question)

    def _find_file_description(
        self,
        answer: str
    ) -> Optional[Dict[str, Any]]:
        """Map an LLM answer such as "File №2" to its description dict.

        Returns ``None`` when the answer contains no file number or the number
        matches no known description, so the caller falls back to text search.
        """
        match = self._FILE_NUMBER_RE.search(answer or '')
        if match is None:
            return None
        number = int(match.group(1))
        return next(
            (item for item in self.file_descriptions if item['number'] == number),
            None
        )

    def _init_chatbot(self):
        """(Re)create ``self.agent`` with the tools selected for this question."""
        # NOTE(review): the structured-chat prompt built below does not
        # reference `chat_history`, so the stored turns are presumably not
        # shown to the model — confirm and add a placeholder if needed.
        self.agent = langchain.agents.initialize_agent(
            agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
            tools=self.tools,
            llm=self.llm,
            verbose=self.verbose,
            max_iterations=5,
            early_stopping_method='generate',
            memory=self.memory
        )

        # Adjacent literals are concatenated verbatim, so each sentence needs
        # its own trailing space.
        sys_msg = (
            "You are an expert summarizer and deliverer of information. "
            "Yet, the reason you are so intelligent is that you make complex "
            "information incredibly simple to understand. It's actually rather incredible. "
            "When users ask information you refer to the relevant tools. "
            "if one of the tools helped you with only a part of the necessary information, you must "
            "try to find the missing information using another tool. "
            "if you can't find the information using the provided tools, you MUST "
            "say 'I don't know'. Don't try to make up an answer."
        )

        prompt = self.agent.agent.create_prompt(
            tools=self.tools,
            prefix=sys_msg
        )
        self.agent.agent.llm_chain.prompt = prompt

    def _text_search(
        self,
        query: str
    ) -> str:
        """Answer ``query`` from the indexed text documents."""
        # prep_inputs wraps the bare string into the chain's input dict.
        chain_inputs = self.text_retriever.prep_inputs(query)
        return self.text_retriever(chain_inputs)['answer']

    def _tabular_search(
        self,
        query: str
    ) -> str:
        """Answer ``query`` with the CSV agent created in ``__call__``."""
        return self.csv_agent.run(query)

    def _chart_search(
        self,
        image: str,
        query: str
    ) -> str:
        """Answer ``query`` about the chart image stored at ``image``."""
        # Close the file handle once the chart model has produced its answer.
        with Image.open(image) as chart:
            return self.chart_model.chart_qa(chart, query)

    def _init_chart_search_tool(
        self,
        title: Union[str, Dict[str, Any]]
    ) -> None:
        """Create ``self.chart_search_tool`` for one chart.

        Args:
            title: Either the chart title, or the file description dict (which
                is what ``__call__`` passes). For a dict, the title is taken
                from its ``'title'`` key (assumed — TODO confirm) and the
                image path from ``'path'``.
        """
        image_path = None
        if isinstance(title, dict):
            image_path = title.get('path')
            title = title.get('title', f"File №{title.get('number')}")

        description = (
            "Use this tool when searching for information on charts. "
            "With this tool you can answer the question about related chart. "
            "You should ask simple question about a chart, then the tool will give you number. "
            f"This chart is called {title}."
        )
        if image_path is not None:
            # The tool needs the image location; without it the agent has no
            # grounded value to pass as `image`.
            description += f' Pass "{image_path}" as the `image` argument.'

        # from_function infers the argument schema from _chart_search; the
        # bare constructor is given no args_schema to work with.
        self.chart_search_tool = StructuredTool.from_function(
            func=self._chart_search,
            name="Ask over charts",
            description=description
        )

    def _init_tabular_search_tool(
        self,
        file_: Dict[str, Any]
    ) -> None:
        """Create ``self.tabular_search_tool`` for the CSV described by ``file_``.

        The title and column names come from the ``'title'`` / ``'columns'``
        keys when present (assumed schema — TODO confirm); missing columns are
        read from the CSV header row at ``file_['path']``.
        """
        title = file_.get('title', f"File №{file_.get('number')}")
        columns = file_.get('columns')
        if columns is None:
            columns = self._read_csv_header(file_.get('path'))
        if not isinstance(columns, str):
            columns = ', '.join(str(column) for column in columns)

        description = (
            "Use this tool when searching for tabular information. "
            "With this tool you could get access to table. "
            f'This table title is "{title}" and the names of the columns in this table: {columns}'
        )

        self.tabular_search_tool = langchain.agents.Tool(
            func=self._tabular_search,
            description=description,
            name="search tabular information"
        )

    @staticmethod
    def _read_csv_header(path) -> List[str]:
        """Return the first row of the CSV at ``path``, or ``[]`` if unreadable."""
        if not path:
            return []
        try:
            with open(path, newline='', encoding='utf-8') as handle:
                return next(csv.reader(handle), [])
        except (OSError, UnicodeDecodeError):
            # Best effort only: the tool description is still usable without
            # the column names.
            return []

    def _define_appropriate_file(
        self,
        question: str
    ) -> str:
        """Ask the LLM which described file may contain the answer.

        Returns:
            The model's answer with surrounding whitespace, quotes and the
            trailing "!" removed — expected to look like "File №1" or
            "None of the files".
        """
        parts = ['I have list of descriptions: \n']
        for index, description in enumerate(self.file_descriptions, start=1):
            parts.append(f"\n{index}) description for File №{description['number']}:\n")
            parts.extend(f"{key} : {value}\n" for key, value in description.items())

        parts.append(
            f"""
How do you think, which file can help answer the question: "{question}" .
Your answer MUST be specific,
for example if you think that File №2 can help answer the question, you MUST just write "File №2!".
If you think that none of the files can help answer the question just write "None of the files!"
Don't include to answer information about your thinking.
"""
        )
        message = ''.join(parts)
        if self.verbose:
            print(message)

        res = self.llm([HumanMessage(content=message)])
        # Strip the requested "!" explicitly rather than dropping whatever the
        # last character happens to be.
        answer = res.content.strip().strip('"').rstrip('!.').strip()
        if self.verbose:
            print(res.content)
            print(answer)
        return answer