import langchain
from langchain.agents import create_csv_agent
from langchain.schema import HumanMessage
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from typing import List, Dict, Any
from langchain.agents import AgentType
from langchain.chains.conversation.memory import ConversationBufferWindowMemory
from utils.functions import Matcha_model
from PIL import Image
from pathlib import Path
from langchain.tools import StructuredTool
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings

class Bot:
    def __init__(
        self,
        openai_api_key: str,
        file_descriptions: List[Dict[str, Any]],
        text_documents: List[langchain.schema.Document],
        verbose: bool = False
    ):
        self.verbose = verbose
        self.file_descriptions = file_descriptions

        self.llm = ChatOpenAI(
            openai_api_key=openai_api_key,
            temperature=0,
            model_name="gpt-3.5-turbo"
        )

        # Local sentence-transformer embeddings; OpenAI embeddings could be used instead.
        embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
        # embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
        vector_store = Chroma.from_documents(text_documents, embedding_function)
        self.text_retriever = langchain.chains.RetrievalQAWithSourcesChain.from_chain_type(
            llm=self.llm,
            chain_type='stuff',
            retriever=vector_store.as_retriever()
        )

        self.text_search_tool = langchain.agents.Tool(
            func=self._text_search,
            description="Use this tool when searching for text information",
            name="search text information"
        )

        self.chart_model = Matcha_model()

    def __call__(
        self,
        question: str
    ):
        self.tools = [self.text_search_tool]

        file = self._define_appropriate_file(question)
        if file != "None of the files":
            number = int(file[file.find('№') + 1:])
            file_description = [x for x in self.file_descriptions if x['number'] == number][0]
            file_path = file_description['path']

            # Check the suffix of the actual file path, not of the "File №N" label.
            if Path(file_path).suffix == '.csv':
                self.csv_agent = create_csv_agent(
                    llm=self.llm,
                    path=file_path,
                    verbose=self.verbose
                )
                self._init_tabular_search_tool(file_description)
                self.tools.append(self.tabular_search_tool)
            else:
                self._init_chart_search_tool(file_description)
                self.tools.append(self.chart_search_tool)

        self._init_chatbot()
        response = self.agent(question)
        return response

    def _init_chatbot(self):
        conversational_memory = ConversationBufferWindowMemory(
            memory_key='chat_history',
            k=5,
            return_messages=True
        )

        self.agent = langchain.agents.initialize_agent(
            agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
            tools=self.tools,
            llm=self.llm,
            verbose=self.verbose,
            max_iterations=5,
            early_stopping_method='generate',
            memory=conversational_memory
        )

        sys_msg = (
            "You are an expert summarizer and deliverer of information. "
            "Yet, the reason you are so intelligent is that you make complex "
            "information incredibly simple to understand. It's actually rather incredible. "
            "When users ask for information, you refer to the relevant tools. "
            "If one of the tools helped you with only a part of the necessary information, you must "
            "try to find the missing information using another tool. "
            "If you can't find the information using the provided tools, you MUST "
            "say 'I don't know'. Don't try to make up an answer."
        )

        prompt = self.agent.agent.create_prompt(
            tools=self.tools,
            prefix=sys_msg
        )
        self.agent.agent.llm_chain.prompt = prompt

    def _text_search(
        self,
        query: str
    ) -> str:
        # The chain prepares its own inputs, so it can be called with the raw query string.
        res = self.text_retriever(query)['answer']
        return res

    def _tabular_search(
        self,
        query: str
    ) -> str:
        res = self.csv_agent.run(query)
        return res

    def _chart_search(
        self,
        image: str,
        query: str
    ) -> str:
        # `image` is a path to the chart image file.
        image = Image.open(image)
        res = self.chart_model.chart_qa(image, query)
        return res

    def _init_chart_search_tool(
        self,
        file_: Dict[str, Any]
    ) -> None:
        # Assumes the description dict contains a 'title' key (the method is called with
        # the full file description, the same way as _init_tabular_search_tool).
        title = file_['title']

        description = f"""
        Use this tool when searching for information on charts.
        With this tool you can answer questions about the related chart.
        You should ask a simple question about the chart, and the tool will give you a number.
        This chart is called {title}.
        """

        # from_function infers the argument schema from the function signature.
        self.chart_search_tool = StructuredTool.from_function(
            func=self._chart_search,
            description=description,
            name="Ask over charts"
        )

    def _init_tabular_search_tool(
        self,
        file_: Dict[str, Any]
    ) -> None:
        # Assumes the description dict contains 'title' and 'columns' keys.
        title = file_['title']
        columns = file_['columns']

        description = f"""
        Use this tool when searching for tabular information.
        With this tool you can get access to the table.
        The table title is "{title}" and the names of the columns in this table are: {columns}
        """

        self.tabular_search_tool = langchain.agents.Tool(
            func=self._tabular_search,
            description=description,
            name="search tabular information"
        )

    def _define_appropriate_file(
        self,
        question: str
    ) -> str:
        '''Uses the file descriptions to decide which file may contain the answer to the question.
        Returns the file number in the form "File №1", or "None of the files".'''
        message = 'I have a list of descriptions: \n'

        k = 0
        for description in self.file_descriptions:
            k += 1
            str_description = f""" {k}) description for File №{description['number']}: """
            for key, value in description.items():
                string_val = str(key) + ' : ' + str(value) + '\n'
                str_description += string_val
            message += str_description

        if self.verbose:
            print(message)

        question = f""" Which file do you think can help answer the question: "{question}"?
        Your answer MUST be specific,
        for example if you think that File №2 can help answer the question, you MUST just write "File №2!".
        If you think that none of the files can help answer the question, just write "None of the files!"
        Don't include your reasoning in the answer.
        """

        message += question

        res = self.llm([HumanMessage(content=message)])
        if self.verbose:
            print(res.content)
        # Drop the trailing "!" that the prompt asks the model to append.
        return res.content[:-1]
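

# Minimal usage sketch (illustrative only, not part of the Space's entry point): the file
# paths, the description keys ('number', 'path', 'title', 'columns'), and the API key below
# are assumptions; Bot only requires that 'number' and 'path' exist and that 'title'/'columns'
# are present for the chart and CSV tools.
if __name__ == "__main__":
    from langchain.schema import Document

    # Hypothetical file descriptions: one CSV table and one chart image.
    file_descriptions = [
        {
            'number': 1,
            'path': 'data/sales.csv',
            'title': 'Monthly sales',
            'columns': ['month', 'revenue', 'units_sold'],
        },
        {
            'number': 2,
            'path': 'data/revenue_chart.png',
            'title': 'Revenue by quarter',
        },
    ]

    # Free-text documents that back the text-search tool.
    text_documents = [
        Document(
            page_content="The company was founded in 2010 and sells hardware.",
            metadata={"source": "about.txt"},
        ),
    ]

    bot = Bot(
        openai_api_key="sk-...",  # placeholder, replace with a real key
        file_descriptions=file_descriptions,
        text_documents=text_documents,
        verbose=True,
    )
    print(bot("What was the revenue in March?"))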