File size: 6,197 Bytes
d31e8ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import re
from typing import Any, Dict, List, Optional

import langchain
from langchain.agents import create_csv_agent
from langchain.schema import HumanMessage
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.agents import AgentType
from langchain.chains.conversation.memory import ConversationBufferWindowMemory


class Bot:
    """Conversational agent that answers questions from text documents and CSV tables.

    Text documents are embedded into a Chroma vector store and queried through a
    ``RetrievalQAWithSourcesChain``. For every question the LLM is first asked
    which of the described tables (if any) may hold the answer; when one is
    chosen, a CSV agent for that table is exposed to the chat agent as a second
    tool.
    """

    # Reply the table-selection prompt asks the LLM to give when no table fits.
    _NO_TABLE = "None of the tables"
    # Finds a table reference ("Table №2", "table № 2", "Table No. 2", "Table #2")
    # anywhere in the LLM reply, so surrounding quotes/periods/newlines are harmless.
    _TABLE_ANSWER_RE = re.compile(r"Table\s*(?:№|No\.?|#)?\s*(\d+)", re.IGNORECASE)

    def __init__(
            self,
            openai_api_key: str,
            table_descriptions: List[Dict[str, Any]],
            text_documents: "List[langchain.schema.Document]",
            verbose: bool = False
    ):
        """Build the LLM, the text retrieval chain and the text search tool.

        Args:
            openai_api_key: Key passed to both ``ChatOpenAI`` and ``OpenAIEmbeddings``.
            table_descriptions: One dict per table; the keys read by this class
                are ``"number"``, ``"columns"``, ``"tittle"`` and ``"path"``.
            text_documents: Documents indexed into the Chroma vector store.
            verbose: Forwarded to the agents; also enables printing the selected table.
        """
        self.verbose = verbose
        self.table_descriptions = table_descriptions

        self.llm = ChatOpenAI(
            openai_api_key=openai_api_key,
            temperature=0,
            model_name="gpt-3.5-turbo"
        )

        # Created once so the chat history survives across calls; the agent
        # itself is rebuilt per question because its tool set may change.
        self.conversational_memory = ConversationBufferWindowMemory(
            memory_key='chat_history',
            k=5,
            return_messages=True
        )

        embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
        vector_store = Chroma.from_documents(text_documents, embeddings)
        self.text_retriever = langchain.chains.RetrievalQAWithSourcesChain.from_chain_type(
            llm=self.llm,
            chain_type='stuff',
            retriever=vector_store.as_retriever()
        )
        self.text_search_tool = langchain.agents.Tool(
            func=self._text_search,
            description="Use this tool when searching for text information",
            name="search text information"
        )

    def __call__(
            self,
            question: str
    ):
        """Answer ``question`` with the chat agent and return the agent's response.

        The text search tool is always available. A tabular search tool is added
        only when the LLM picks a table that exists in ``table_descriptions``.
        """
        self.tools = [self.text_search_tool]
        table = self._define_appropriate_table(question)
        table_description = self._find_table_description(table)
        if table_description is not None:
            self.csv_agent = create_csv_agent(
                llm=self.llm,
                path=table_description['path'],
                verbose=self.verbose
            )

            self._init_tabular_search_tool(table_description)
            self.tools.append(self.tabular_search_tool)

        self._init_chatbot()
        if self.verbose:
            print(table)
        return self.agent(question)

    def _init_chatbot(self):
        """(Re)build the chat agent over the current ``self.tools`` and set its system prompt."""
        # Instances created without __init__ still get a memory object.
        memory = getattr(self, 'conversational_memory', None)
        if memory is None:
            memory = ConversationBufferWindowMemory(
                memory_key='chat_history',
                k=5,
                return_messages=True
            )
            self.conversational_memory = memory

        self.agent = langchain.agents.initialize_agent(
            agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
            tools=self.tools,
            llm=self.llm,
            verbose=self.verbose,
            max_iterations=5,
            early_stopping_method='generate',
            memory=memory
        )
        # Adjacent literals are concatenated verbatim, so each sentence carries
        # its own trailing space.
        sys_msg = (
            "You are an expert summarizer and deliverer of information. "
            "Yet, the reason you are so intelligent is that you make complex "
            "information incredibly simple to understand. It's actually rather incredible. "
            "When users ask information you refer to the relevant tools. "
            "if one of the tools helped you with only a part of the necessary information, you must "
            "try to find the missing information using another tool. "
            "if you can't find the information using the provided tools, you MUST "
            "say 'I don't know'. Don't try to make up an answer."
        )
        prompt = self.agent.agent.create_prompt(
            system_message=sys_msg,
            tools=self.tools
        )
        self.agent.agent.llm_chain.prompt = prompt

    def _text_search(
            self,
            query: str
    ) -> str:
        """Run ``query`` through the text retrieval chain and return its ``'answer'`` field."""
        query = self.text_retriever.prep_inputs(query)
        res = self.text_retriever(query)['answer']
        return res

    def _tabular_search(
            self,
            query: str
    ) -> str:
        """Run ``query`` through the CSV agent built for the currently selected table."""
        res = self.csv_agent.run(query)
        return res

    @staticmethod
    def _format_columns(columns) -> str:
        """Return the column names double-quoted and comma-separated: ``"a", "b"``."""
        return '"' + '", "'.join(columns) + '"'

    @classmethod
    def _parse_table_answer(cls, answer: str) -> str:
        """Normalise a raw LLM reply to ``"Table №<n>"`` or ``"None of the tables"``."""
        match = cls._TABLE_ANSWER_RE.search(answer or "")
        if match is None:
            return cls._NO_TABLE
        return f"Table №{int(match.group(1))}"

    def _find_table_description(
            self,
            table: str
    ) -> Optional[Dict[str, Any]]:
        """Return the description whose ``"number"`` matches ``table``, or ``None``.

        ``None`` is returned both when ``table`` names no table and when the
        named number is absent from ``self.table_descriptions``.
        """
        match = self._TABLE_ANSWER_RE.search(table or "")
        if match is None:
            return None
        number = int(match.group(1))
        return next(
            (d for d in self.table_descriptions if d['number'] == number),
            None
        )

    def _init_tabular_search_tool(
            self,
            table_description: Dict[str, Any]
    ) -> None:
        """Create ``self.tabular_search_tool`` described by the table's title and columns."""
        columns = self._format_columns(table_description["columns"])
        tittle = table_description["tittle"]

        description = f"""
            Use this tool when searching for tabular information.
            With this tool you could get access to table.
            This table tittle is "{tittle}" and the names of the columns in this table: {columns} 
        """

        self.tabular_search_tool = langchain.agents.Tool(
            func=self._tabular_search,
            description=description,
            name="search tabular information"
        )

    def _define_appropriate_table(
            self,
            question: str
    ) -> str:
        """Ask the LLM which described table may contain the answer to ``question``.

        Returns the table in the form ``"Table №1"``, or ``"None of the tables"``
        when no table fits, the reply cannot be parsed, or there are no tables.
        """
        if not self.table_descriptions:
            # Nothing to choose from; skip the LLM round-trip.
            return self._NO_TABLE

        message = 'I have list of table descriptions: \n'
        for k, description in enumerate(self.table_descriptions, start=1):
            number = description["number"]
            columns = self._format_columns(description["columns"])
            tittle = description["tittle"]
            str_description = f"""  {k}) description for Table №{number}:
            a) table consist of columns with names: {columns}; 
            b) table tittle: {tittle}.\n"""
            message += str_description

        question = f""" How do you think, which table can help answer the question: "{question}" .
        Your answer MUST be specific, 
        for example if you think that Table №2 can help answer the question, you MUST just write  "Table №2". 
        If you think that none of the tables can help answer the question just write "None of the tables"
        Don't include to answer information about your thinking.
        """
        message += question

        res = self.llm([HumanMessage(content=message)])
        # Parse instead of slicing: the reply may or may not end with a period/quote.
        return self._parse_table_answer(res.content)