Boardpac/theekshanas committed on
Commit
027bfbf
1 Parent(s): 9e2dc86

agent with memory

.env CHANGED
@@ -11,7 +11,7 @@ TARGET_SOURCE_CHUNKS=4
11
 
12
  #API token keys
13
  HUGGINGFACEHUB_API_TOKEN=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
14
- OPENAI_API_KEY=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
15
 
16
  #api app
17
  APP_HOST=127.0.0.1
 
11
 
12
  #API token keys
13
  HUGGINGFACEHUB_API_TOKEN=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
14
+ OPENAI_API_KEY=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
15
 
16
  #api app
17
  APP_HOST=127.0.0.1
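These keys are read into the process environment by load_dotenv() at app startup; a minimal sketch of the consuming side (variable names here are illustrative):

import os
from dotenv import load_dotenv

load_dotenv()  # loads the .env above into os.environ

hf_token = os.environ.get('HUGGINGFACEHUB_API_TOKEN')
openai_key = os.environ.get('OPENAI_API_KEY')
app_host = os.environ.get('APP_HOST', '127.0.0.1')  # falls back if APP_HOST is unset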
__pycache__/conversationBufferWindowMemory.cpython-311.pyc ADDED
Binary file (6.23 kB).
 
__pycache__/qaPipeline.cpython-311.pyc CHANGED
Binary files a/__pycache__/qaPipeline.cpython-311.pyc and b/__pycache__/qaPipeline.cpython-311.pyc differ
 
__pycache__/qaPipeline_functions.cpython-311.pyc ADDED
Binary file (10.3 kB).
 
app.py CHANGED
@@ -16,6 +16,7 @@ from ui.htmlTemplates import css, bot_template, user_template, source_template
16
  from config import MODELS, DATASETS
17
 
18
  from qaPipeline import QAPipeline
 
19
  from faissDb import create_faiss
20
 
21
  # loads environment variables
@@ -25,7 +26,12 @@ load_dotenv()
25
  isHuggingFaceHubEnabled = os.environ.get('ENABLE_HUGGINGFSCE_HUB_MODELS')
26
  isOpenAiApiEnabled = os.environ.get('ENABLE_OPENAI_API_MODELS')
27
 
 
 
 
 
28
  qaPipeline = QAPipeline()
 
29
 
30
  def initialize_session_state():
31
  # Initialise all session state variables with defaults
@@ -109,13 +115,22 @@ def side_bar():
109
 
110
  def chat_body():
111
  st.header("Chat with your own data:")
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
- user_question = st.text_input(
114
- "Ask a question about your documents:",
115
- placeholder="enter question",
116
- key='user_question',
117
- on_change=submit_user_question,
118
- )
119
 
120
  # if user_question:
121
  # submit_user_question()
@@ -128,18 +143,15 @@ def submit_user_question():
128
  user_question = st.session_state.user_question
129
  # st.success(user_question)
130
  handle_userinput(user_question)
131
- st.session_state.user_question=''
132
 
133
 
134
  def main():
135
 
136
- st.set_page_config(page_title="Chat with data",
137
- page_icon=":books:")
138
- st.write(css, unsafe_allow_html=True)
139
-
140
  initialize_session_state()
141
 
142
  side_bar()
 
143
  chat_body()
144
 
145
 
@@ -157,7 +169,7 @@ def parameters_change_button(chat_model, show_source):
157
  time.sleep(1) # Wait for 1 second
158
  alert.empty() # Clear the alert
159
 
160
-
161
  def get_answer_from_backend(query, model, dataset):
162
  # response = qaPipeline.run(query=query, model=model, dataset=dataset)
163
  response = qaPipeline.run_agent(query=query, model=model, dataset=dataset)
@@ -167,7 +179,7 @@ def get_answer_from_backend(query, model, dataset):
167
  def show_query_response(query, response, show_source_files):
168
  docs = []
169
  if isinstance(response, dict):
170
- answer, docs = response['result'], response['source_documents']
171
  else:
172
  answer = response
173
 
@@ -217,9 +229,15 @@ def handle_userinput(query):
217
 
218
  except Exception as e:
219
  # logger.error(f"Answer retrieval failed with {e}")
220
- st.error(f"Error : {e}")#, icon=":books:")
221
  return
222
 
223
 
224
  if __name__ == "__main__":
225
  main()

16
  from config import MODELS, DATASETS
17
 
18
  from qaPipeline import QAPipeline
19
+ import qaPipeline_functions
20
  from faissDb import create_faiss
21
 
22
  # loads environment variables
 
26
  isHuggingFaceHubEnabled = os.environ.get('ENABLE_HUGGINGFSCE_HUB_MODELS')
27
  isOpenAiApiEnabled = os.environ.get('ENABLE_OPENAI_API_MODELS')
28
 
29
+ st.set_page_config(page_title="Chat with data",
30
+ page_icon=":books:")
31
+ st.write(css, unsafe_allow_html=True)
32
+
33
  qaPipeline = QAPipeline()
34
+ # qaPipeline = qaPipeline_functions
35
 
36
  def initialize_session_state():
37
  # Initialise all session state variables with defaults
 
115
 
116
  def chat_body():
117
  st.header("Chat with your own data:")
118
+ with st.form('chat_body'):
119
+
120
+ user_question = st.text_input(
121
+ "Ask a question about your documents:",
122
+ placeholder="enter question",
123
+ key='user_question',
124
+ # on_change=submit_user_question,
125
+ )
126
+
127
+ submitted = st.form_submit_button(
128
+ "Submit",
129
+ # on_click=update_parameters_change
130
+ )
131
 
132
+ if submitted:
133
+ submit_user_question()
 
 
 
 
134
 
135
  # if user_question:
136
  # submit_user_question()
 
143
  user_question = st.session_state.user_question
144
  # st.success(user_question)
145
  handle_userinput(user_question)
146
+ # st.session_state.user_question=''
147
 
148
 
149
  def main():
150
 
 
 
 
 
151
  initialize_session_state()
152
 
153
  side_bar()
154
+
155
  chat_body()
156
 
157
 
 
169
  time.sleep(1) # Wait for 1 second
170
  alert.empty() # Clear the alert
171
 
172
+ # @st.cache_data
173
  def get_answer_from_backend(query, model, dataset):
174
  # response = qaPipeline.run(query=query, model=model, dataset=dataset)
175
  response = qaPipeline.run_agent(query=query, model=model, dataset=dataset)
 
179
  def show_query_response(query, response, show_source_files):
180
  docs = []
181
  if isinstance(response, dict):
182
+ answer, docs = response['answer'], response['source_documents']
183
  else:
184
  answer = response
185
 
 
229
 
230
  except Exception as e:
231
  # logger.error(f"Answer retrieval failed with {e}")
232
+ st.error(f"Streamlit handle_userinput Error : {e}")#, icon=":books:")
233
  return
234
 
235
 
236
  if __name__ == "__main__":
237
  main()
238
+
239
+ # initialize_session_state()
240
+
241
+ # side_bar()
242
+
243
+ # chat_body()
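The reworked chat_body groups the question input and the submit button in an st.form instead of relying on an on_change callback; a minimal standalone sketch of that pattern (the handler name mirrors the one above):

import streamlit as st

with st.form('chat_body'):
    user_question = st.text_input(
        "Ask a question about your documents:",
        placeholder="enter question",
        key='user_question',
    )
    submitted = st.form_submit_button("Submit")

if submitted:
    # stand-in for submit_user_question(), which reads st.session_state.user_question
    st.write(f"You asked: {st.session_state.user_question}")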
app2.py ADDED
@@ -0,0 +1,184 @@
1
+ """
2
+ Python Backend API to chat with private data
3
+
4
+ 08/16/2023
5
+ D.M. Theekshana Samaradiwakara
6
+ """
7
+
8
+ import os
9
+ import time
10
+ import streamlit as st
11
+ from streamlit.logger import get_logger
12
+
13
+ logger = get_logger(__name__)
14
+
15
+ from ui.htmlTemplates import css, bot_template, user_template, source_template
16
+ from config import MODELS, DATASETS
17
+
18
+ from qaPipeline import QAPipeline
19
+ from faissDb import create_faiss
20
+
21
+ # loads environment variables
22
+ from dotenv import load_dotenv
23
+ load_dotenv()
24
+
25
+ isHuggingFaceHubEnabled = os.environ.get('ENABLE_HUGGINGFSCE_HUB_MODELS')
26
+ isOpenAiApiEnabled = os.environ.get('ENABLE_OPENAI_API_MODELS')
27
+
28
+ st.set_page_config(page_title="Chat with data",
29
+ page_icon=":books:")
30
+ st.write(css, unsafe_allow_html=True)
31
+
32
+
33
+ SESSION_DEFAULTS = {
34
+ "model": MODELS["DEFAULT"],
35
+ "dataset": DATASETS["DEFAULT"],
36
+ "chat_history": None,
37
+ "is_parameters_changed":False,
38
+ "show_source_files": False,
39
+ "user_question":'',
40
+ }
41
+
42
+ for k, v in SESSION_DEFAULTS.items():
43
+ if k not in st.session_state:
44
+ st.session_state[k] = v
45
+
46
+
47
+ with st.sidebar:
48
+ st.subheader("Chat parameters")
49
+
50
+ with st.form('param_form'):
51
+
52
+ chat_model = st.selectbox(
53
+ "Chat model",
54
+ MODELS,
55
+ key="chat_model",
56
+ help="Select the LLM model for the chat",
57
+ # on_change=update_parameters_change,
58
+ )
59
+
60
+ st.session_state.dataset = "DEFAULT"
61
+
62
+ show_source = st.checkbox(
63
+ label="show source files",
64
+ key="show_source",
65
+ help="Select this to show relevant source files for the query",
66
+ )
67
+
68
+ submitted = st.form_submit_button(
69
+ "Submit",
70
+ # on_click=parameters_change_button,
71
+ # args=[chat_model, show_source]
72
+ )
73
+
74
+ # submitted = st.button(
75
+ # "Submit",
76
+ # # on_click=parameters_change_button,
77
+ # # args=[chat_model, show_source]
78
+ # )
79
+
80
+ if submitted:
81
+ st.session_state.model = chat_model
82
+ st.session_state.dataset = "DEFAULT"
83
+ st.session_state.show_source_files = show_source
84
+ st.session_state.is_parameters_changed = False
85
+
86
+ alert = st.success("chat parameters updated")
87
+ time.sleep(1) # Wait for 1 second
88
+ alert.empty() # Clear the alert
89
+
90
+ st.markdown("\n")
91
+
92
+ # if st.button("Create FAISS db"):
93
+ # try:
94
+ # with st.spinner('creating faiss vector store'):
95
+ # create_faiss()
96
+ # st.success('faiss saved')
97
+ # except Exception as e:
98
+ # st.error(f"Error : {e}")#, icon=":books:")
99
+ # return
100
+
101
+ st.markdown(
102
+ "### How to use\n"
103
+ "1. Select the chat model\n" # noqa: E501
104
+ "2. Select \"show source files\" to show the source files related to the answer.📄\n"
105
+ "3. Ask a question about the documents💬\n"
106
+ )
107
+
108
+
109
+
110
+ st.header("Chat with your own data:")
111
+ @st.experimental_singleton # 👈 Add the caching decorator
112
+ def load_QaPipeline():
113
+ print('> QAPipeline loaded for front end')
114
+ return QAPipeline()
115
+
116
+ qaPipeline = load_QaPipeline()
117
+ # qaPipeline = QAPipeline()
118
+ with st.form('chat_body'):
119
+
120
+
121
+ user_question = st.text_input(
122
+ "Ask a question about your documents:",
123
+ placeholder="enter question",
124
+ key='user_question',
125
+ # on_change=submit_user_question,
126
+ )
127
+
128
+ submitted = st.form_submit_button(
129
+ "Submit",
130
+ # on_click=submit_user_question
131
+ )
132
+
133
+ if submitted:
134
+ with st.spinner("Processing"):
135
+ user_question = st.session_state.user_question
136
+ # st.success(user_question)
137
+ query = user_question
138
+ # st.session_state.user_question=''
139
+
140
+ # Get the answer from the chain
141
+ try:
142
+ if (not query) or (query.strip() == ''):
143
+ st.error("Please enter a question!")
144
+ st.stop()
145
+
146
+ model = MODELS[st.session_state.model]
147
+ dataset = DATASETS[st.session_state.dataset]
148
+ show_source_files = st.session_state.show_source_files
149
+
150
+ # Try to access openai and deeplake
151
+ print(f">\n model: {model} \n dataset : {dataset} \n show_source_files : {show_source_files}")
152
+
153
+ # response = qaPipeline.run(query=query, model=model, dataset=dataset)
154
+ response = qaPipeline.run_agent(query=query, model=model, dataset=dataset)
155
+
156
+
157
+ docs = []
158
+ if isinstance(response, dict):
159
+ answer, docs = response['answer'], response['source_documents']
160
+ else:
161
+ answer = response
162
+
163
+ st.write(user_template.replace(
164
+ "{{MSG}}", query), unsafe_allow_html=True)
165
+ st.write(bot_template.replace(
166
+ "{{MSG}}", answer ), unsafe_allow_html=True)
167
+
168
+ if show_source_files:
169
+ # st.write(source_template.replace(
170
+ # "{{MSG}}", "source files" ), unsafe_allow_html=True)
171
+
172
+ if len(docs)>0 :
173
+ st.markdown("#### source files : ")
174
+ for source in docs:
175
+ # st.info(source.metadata)
176
+ with st.expander(source.metadata["source"]):
177
+ st.markdown(source.page_content)
178
+
179
+ # st.write(response)
180
+
181
+ except Exception as e:
182
+ # logger.error(f"Answer retrieval failed with {e}")
183
+ st.error(f"Error : {e}")#, icon=":books:")
184
+
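app2.py caches the pipeline with @st.experimental_singleton so the model is loaded once per server process; newer Streamlit releases deprecate that decorator in favour of st.cache_resource, so an equivalent sketch would be:

import streamlit as st
from qaPipeline import QAPipeline

@st.cache_resource  # newer replacement for st.experimental_singleton
def load_QaPipeline():
    print('> QAPipeline loaded for front end')
    return QAPipeline()

qaPipeline = load_QaPipeline()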
conversationBufferWindowMemory.py ADDED
@@ -0,0 +1,118 @@
1
+ from abc import ABC
2
+ from typing import Any, Dict, Optional, Tuple
3
+ # import json
4
+
5
+ from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
6
+ from langchain.memory.utils import get_prompt_input_key
7
+ from langchain.pydantic_v1 import Field
8
+ from langchain.schema import BaseChatMessageHistory, BaseMemory
9
+
10
+ from typing import List, Union
11
+
12
+ # from langchain.memory.chat_memory import BaseChatMemory
13
+ from langchain.schema.messages import BaseMessage, get_buffer_string
14
+
15
+
16
+ class BaseChatMemory(BaseMemory, ABC):
17
+ """Abstract base class for chat memory."""
18
+
19
+ chat_memory: BaseChatMessageHistory = Field(default_factory=ChatMessageHistory)
20
+ output_key: Optional[str] = None
21
+ input_key: Optional[str] = None
22
+ return_messages: bool = False
23
+
24
+ def _get_input_output(
25
+ self, inputs: Dict[str, Any], outputs: Dict[str, str]
26
+ ) -> Tuple[str, str]:
27
+
28
+
29
+ if self.input_key is None:
30
+ prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
31
+ else:
32
+ prompt_input_key = self.input_key
33
+
34
+ if self.output_key is None:
35
+ """
36
+ output for agent with LLM chain tool = {answer}
37
+ output for agent with ConversationalRetrievalChain tool = {'question', 'chat_history', 'answer','source_documents'}
38
+ """
39
+
40
+ LLM_key = 'output'
41
+ Retrieval_key = 'answer'
42
+ if isinstance(outputs[LLM_key], dict):
43
+ Retrieval_dict = outputs[LLM_key]
44
+ if Retrieval_key in Retrieval_dict.keys():
45
+ #output keys are 'answer' , 'source_documents'
46
+ output = Retrieval_dict[Retrieval_key]
47
+ else:
48
+ raise ValueError(f"output key: {LLM_key} not a valid dictionary")
49
+
50
+ else:
51
+ #otherwise output key will be 'output'
52
+ output_key = list(outputs.keys())[0]
53
+ output = outputs[output_key]
54
+
55
+ # if len(outputs) != 1:
56
+ # raise ValueError(f"One output key expected, got {outputs.keys()}")
57
+
58
+
59
+ else:
60
+ output_key = self.output_key
61
+ output = outputs[output_key]
62
+
63
+ return inputs[prompt_input_key], output
64
+
65
+ def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
66
+ """Save context from this conversation to buffer."""
67
+ input_str, output_str = self._get_input_output(inputs, outputs)
68
+ self.chat_memory.add_user_message(input_str)
69
+ self.chat_memory.add_ai_message(output_str)
70
+
71
+ def clear(self) -> None:
72
+ """Clear memory contents."""
73
+ self.chat_memory.clear()
74
+
75
+
76
+
77
+
78
+
79
+ class ConversationBufferWindowMemory(BaseChatMemory):
80
+ """Buffer for storing conversation memory inside a limited size window."""
81
+
82
+ human_prefix: str = "Human"
83
+ ai_prefix: str = "AI"
84
+ memory_key: str = "history" #: :meta private:
85
+ k: int = 5
86
+ """Number of messages to store in buffer."""
87
+
88
+ @property
89
+ def buffer(self) -> Union[str, List[BaseMessage]]:
90
+ """String buffer of memory."""
91
+ return self.buffer_as_messages if self.return_messages else self.buffer_as_str
92
+
93
+ @property
94
+ def buffer_as_str(self) -> str:
95
+ """Exposes the buffer as a string in case return_messages is True."""
96
+ messages = self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
97
+ return get_buffer_string(
98
+ messages,
99
+ human_prefix=self.human_prefix,
100
+ ai_prefix=self.ai_prefix,
101
+ )
102
+
103
+ @property
104
+ def buffer_as_messages(self) -> List[BaseMessage]:
105
+ """Exposes the buffer as a list of messages in case return_messages is False."""
106
+ return self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
107
+
108
+ @property
109
+ def memory_variables(self) -> List[str]:
110
+ """Will always return list of memory variables.
111
+
112
+ :meta private:
113
+ """
114
+ return [self.memory_key]
115
+
116
+ def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
117
+ """Return history buffer."""
118
+ return {self.memory_key: self.buffer}
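This module vendors LangChain's chat memory classes so that _get_input_output can unwrap the dict returned by the ConversationalRetrievalChain tool (the stock class insists on exactly one output key). A short usage sketch of the window behaviour, assuming the file is importable as conversationBufferWindowMemory:

from conversationBufferWindowMemory import ConversationBufferWindowMemory

memory = ConversationBufferWindowMemory(memory_key="chat_history", input_key="question", k=2)

for i in range(4):
    memory.save_context({"question": f"q{i}"}, {"output": f"a{i}"})

# Only the last k exchanges (k * 2 = 4 messages) survive in the buffer.
print(memory.load_memory_variables({})["chat_history"])  # q2/a2 and q3/a3 only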
qaPipeline.py CHANGED
@@ -23,10 +23,12 @@ from langchain.chat_models import ChatOpenAI
23
  # from chromaDb import load_store
24
  from faissDb import load_FAISS_store
25
 
26
- from langchain.agents import initialize_agent, Tool
27
- from langchain.agents import AgentType
28
  from langchain.prompts import PromptTemplate
29
- from langchain.chains import LLMChain
 
 
30
 
31
  load_dotenv()
32
 
@@ -43,10 +45,35 @@ verbose = os.environ.get('VERBOSE')
43
  # activate/deactivate the streaming StdOut callback for LLMs
44
  callbacks = [StreamingStdOutCallbackHandler()]
45

46
  class QAPipeline:
47
 
48
  def __init__(self):
49
-
 
50
  self.llm_name = None
51
  self.llm = None
52
 
@@ -56,6 +83,7 @@ class QAPipeline:
56
  self.qa_chain = None
57
  self.agent = None
58
 
 
59
  def run(self,query, model, dataset):
60
 
61
  if (self.llm_name != model) or (self.dataset_name != dataset) or (self.qa_chain == None):
@@ -79,24 +107,31 @@ class QAPipeline:
79
 
80
  def run_agent(self,query, model, dataset):
81
 
82
- if (self.llm_name != model) or (self.dataset_name != dataset) or (self.agent == None):
83
- self.set_model(model)
84
- self.set_vectorstore(dataset)
85
- self.set_qa_chain_with_agent()
86
 
87
- # Get the answer from the chain
88
- start = time.time()
89
- res = self.agent(query)
90
- # answer, docs = res['result'],res['source_documents']
91
- end = time.time()
92
 
93
- # Print the result
94
- print("\n\n> Question:")
95
- print(query)
96
- print(f"\n> Answer (took {round(end - start, 2)} s.):")
97
- print( res)
98
 
99
- return res["output"]
100
 
101
 
102
 
@@ -139,67 +174,115 @@ class QAPipeline:
139
 
140
  def set_qa_chain_with_agent(self):
141
 
142
- # Define a custom prompt
143
- general_qa_template = (
144
- """You are the AI assistant of the Boardpac company which provide services for companies board members.
145
- You can have a general conversation with the users like greetings.
146
- But only answer questions related to banking sector like financial and legal.
147
- If you dont know the answer say you dont know, dont try to makeup answers.
148
- each answer should start with code word BoardPac AI (Conversation):
149
- Question: {question}
150
- """
151
- )
152
-
153
- general_qa_chain_prompt = PromptTemplate.from_template(general_qa_template)
154
- general_qa_chain = LLMChain(llm=self.llm, prompt=general_qa_chain_prompt)
155
-
156
- # Define a custom prompt
157
- retrieval_qa_template = (
158
- """You are the AI assistant of the Boardpac company which provide services for companies board members.
159
- You have provided context information below related to central bank acts published in various years. The content of a bank act can updated by a bank act from a latest year.
160
- {context}
161
- Given this information, please answer the question with the latest information.
162
- If you dont know the answer say you dont know, dont try to makeup answers.
163
- each answer should start with code word BoardPac AI (Retrieval):
164
- Question: {question}
165
- """
166
- )
167
- retrieval_qa_chain_prompt = PromptTemplate.from_template(retrieval_qa_template)
168
-
169
- bank_regulations_qa = RetrievalQA.from_chain_type(
170
- llm=self.llm,
171
- chain_type="stuff",
172
- retriever = self.vectorstore.as_retriever(),
173
- # retriever = self.vectorstore.as_retriever(search_kwargs={"k": target_source_chunks}
174
- return_source_documents= True,
175
- input_key="question",
176
- chain_type_kwargs={"prompt": retrieval_qa_chain_prompt},
177
- )
178
-
179
- tools = [
180
- Tool(
181
  name="bank regulations",
182
- func= lambda query: bank_regulations_qa({"question": query}),
183
  description='''useful for when you need to answer questions about
184
  financial and legal information issued from central bank regarding banks and bank regulations.
185
  Input should be a fully formed question.''',
186
  return_direct=True,
187
- ),
188
 
189
- Tool(
190
- name="general qa",
191
- func= general_qa_chain.run,
192
- description='''useful for when you need to have a general conversation with the users like greetings
193
- or to answer general purpose questions related to banking sector like financial and legal.
194
- Input should be a fully formed question.''',
195
- return_direct=True,
196
- ),
197
- ]
198
-
199
- self.agent = initialize_agent(
200
- tools,
201
- self.llm,
202
- agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
203
- verbose=True,
204
- max_iterations=3,
205
- )

23
  # from chromaDb import load_store
24
  from faissDb import load_FAISS_store
25
 
26
+ from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
27
+
28
  from langchain.prompts import PromptTemplate
29
+ from langchain.chains import LLMChain, ConversationalRetrievalChain
30
+ from conversationBufferWindowMemory import ConversationBufferWindowMemory
31
+ from langchain.memory import ReadOnlySharedMemory
32
 
33
  load_dotenv()
34
 
 
45
  # activate/deactivate the streaming StdOut callback for LLMs
46
  callbacks = [StreamingStdOutCallbackHandler()]
47
 
48
+ memory = ConversationBufferWindowMemory(
49
+ memory_key="chat_history",
50
+ input_key="question",
51
+ return_messages=True,
52
+ k=3
53
+ )
54
+
55
+ readonlymemory = ReadOnlySharedMemory(memory=memory)
56
+
57
+ class Singleton:
58
+ __instance = None
59
+ @staticmethod
60
+ def getInstance():
61
+ """ Static access method. """
62
+ if Singleton.__instance == None:
63
+ Singleton()
64
+ return Singleton.__instance
65
+ def __init__(self):
66
+ """ Virtually private constructor. """
67
+ if Singleton.__instance != None:
68
+ raise Exception("This class is a singleton!")
69
+ else:
70
+ Singleton.__instance = QAPipeline()
71
+
72
  class QAPipeline:
73
 
74
  def __init__(self):
75
+
76
+ print("\n\n> Initializing QAPipeline:")
77
  self.llm_name = None
78
  self.llm = None
79
 
 
83
  self.qa_chain = None
84
  self.agent = None
85
 
86
+
87
  def run(self,query, model, dataset):
88
 
89
  if (self.llm_name != model) or (self.dataset_name != dataset) or (self.qa_chain == None):
 
107
 
108
  def run_agent(self,query, model, dataset):
109
 
110
+ try:
 
 
 
111
 
112
+ if (self.llm_name != model) or (self.dataset_name != dataset) or (self.agent == None):
113
+ self.set_model(model)
114
+ self.set_vectorstore(dataset)
115
+ self.set_qa_chain_with_agent()
 
116
 
117
+ # Get the answer from the chain
118
+ start = time.time()
119
+ res = self.agent(query)
120
+ # answer, docs = res['result'],res['source_documents']
121
+ end = time.time()
122
 
123
+ # Print the result
124
+ print("\n\n> Question:")
125
+ print(query)
126
+ print(f"\n> Answer (took {round(end - start, 2)} s.):")
127
+ print( res)
128
+
129
+ return res["output"]
130
+
131
+ except Exception as e:
132
+ # logger.error(f"Answer retrieval failed with {e}")
133
+ print(f"> QAPipeline run_agent Error : {e}")#, icon=":books:")
134
+ return
135
 
136
 
137
 
 
174
 
175
  def set_qa_chain_with_agent(self):
176
 
177
+ try:
178
+
179
+ # Define a custom prompt
180
+ general_qa_template = (
181
+ """You can have a general conversation with the users like greetings.
182
+ Continue the conversation and only answer questions related to banking sector like financial and legal.
183
+ If you dont know the answer say you dont know, dont try to makeup answers.
184
+ Conversation: {chat_history}
185
+ Question: {question}
186
+ """
187
+ )
188
+
189
+ general_qa_chain_prompt = PromptTemplate(input_variables=["question", "chat_history"], template=general_qa_template)
190
+
191
+ general_qa_chain = LLMChain(
192
+ llm=self.llm,
193
+ prompt=general_qa_chain_prompt,
194
+ verbose=True,
195
+ memory=readonlymemory, # use the read-only memory to prevent the tool from modifying the memory
196
+ )
197
+
198
+ general_qa_chain_tool = Tool(
199
+ name="general qa",
200
+ func= general_qa_chain.run,
201
+ description='''useful for when you need to have a general conversation with the users like greetings
202
+ or to answer general purpose questions related to banking sector like financial and legal.
203
+ Input should be a fully formed question.''',
204
+ return_direct=True,
205
+
206
+ )
207
+
208
+ # Define a custom prompt
209
+ retrieval_qa_template = (
210
+ """
211
+ please answer the question based on the chat history and context with the latest information.
212
+ You have provided context information below related to central bank acts published in various years.
213
+ The content of a bank act can updated by a bank act from a latest year.
214
+ If you dont know the answer say you dont know, dont try to makeup answers.
215
+ Conversation: {chat_history}
216
+ Context: {context}
217
+ Question : {question}
218
+ """
219
+ )
220
+ retrieval_qa_chain_prompt = PromptTemplate(
221
+ input_variables=["question", "context", "chat_history"],
222
+ template=retrieval_qa_template
223
+ )
224
+
225
+ bank_regulations_qa = ConversationalRetrievalChain.from_llm(
226
+ llm=self.llm,
227
+ chain_type="stuff",
228
+ retriever = self.vectorstore.as_retriever(),
229
+ # retriever = self.vectorstore.as_retriever(search_kwargs={"k": target_source_chunks}
230
+ return_source_documents= True,
231
+ get_chat_history=lambda h : h,
232
+ combine_docs_chain_kwargs={"prompt": retrieval_qa_chain_prompt},
233
+ verbose=True,
234
+ memory=readonlymemory, # use the read-only memory to prevent the tool from modifying the memory
235
+ )
236
+
237
+ bank_regulations_qa_tool = Tool(
238
  name="bank regulations",
239
+ func= lambda question: bank_regulations_qa({"question": question}),
240
  description='''useful for when you need to answer questions about
241
  financial and legal information issued from central bank regarding banks and bank regulations.
242
  Input should be a fully formed question.''',
243
  return_direct=True,
244
+ )
245
 
246
+ tools = [
247
+ bank_regulations_qa_tool,
248
+ general_qa_chain_tool
249
+ ]
250
+
251
+ prefix = """Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:"""
252
+ suffix = """Begin!"
253
+
254
+ {chat_history}
255
+ Question: {question}
256
+ {agent_scratchpad}"""
257
+
258
+ agent_prompt = ZeroShotAgent.create_prompt(
259
+ tools,
260
+ prefix=prefix,
261
+ suffix=suffix,
262
+ input_variables=["question", "chat_history", "agent_scratchpad"],
263
+ )
264
+
265
+ llm_chain = LLMChain(llm=self.llm, prompt=agent_prompt)
266
+
267
+ agent = ZeroShotAgent(
268
+ llm_chain=llm_chain,
269
+ tools=tools,
270
+ verbose=True,
271
+ )
272
+
273
+ agent_chain = AgentExecutor.from_agent_and_tools(
274
+ agent=agent,
275
+ tools=tools,
276
+ verbose=True,
277
+ memory=memory,
278
+ handle_parsing_errors=True,
279
+ )
280
+
281
+ self.agent = agent_chain
282
+
283
+ print(f"\n> agent_chain created")
284
+
285
+ except Exception as e:
286
+ # logger.error(f"Answer retrieval failed with {e}")
287
+ print(f"> QAPipeline set_qa_chain_with_agent Error : {e}")#, icon=":books:")
288
+ return
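Because both tools set return_direct=True, the AgentExecutor hands back the tool's raw return value as res["output"]: a plain string from the general-qa LLMChain, or a dict with 'answer' and 'source_documents' from the ConversationalRetrievalChain. That is why show_query_response in app.py branches on the type; condensed from the app code above:

response = qaPipeline.run_agent(query=query, model=model, dataset=dataset)

docs = []
if isinstance(response, dict):   # retrieval tool answered
    answer, docs = response['answer'], response['source_documents']
else:                            # general-qa tool answered with a plain string
    answer = response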
qaPipeline_functions.py ADDED
@@ -0,0 +1,278 @@
1
+ """
2
+ Python Backend API to chat with private data
3
+
4
+ 08/14/2023
5
+ D.M. Theekshana Samaradiwakara
6
+ """
7
+
8
+ import os
9
+ import time
10
+
11
+ from dotenv import load_dotenv
12
+
13
+ from langchain.chains import RetrievalQA
14
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
15
+
16
+ from langchain.llms import GPT4All
17
+ from langchain.llms import HuggingFaceHub
18
+ from langchain.chat_models import ChatOpenAI
19
+
20
+ # from langchain.retrievers._query.base import SelfQueryRetriever
21
+ # from langchain.chains.query_constructor.base import AttributeInfo
22
+
23
+ # from chromaDb import load_store
24
+ from faissDb import load_FAISS_store
25
+
26
+ from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
27
+
28
+ from langchain.prompts import PromptTemplate
29
+ from langchain.chains import LLMChain, ConversationalRetrievalChain
30
+ from conversationBufferWindowMemory import ConversationBufferWindowMemory
31
+ from langchain.memory import ReadOnlySharedMemory
32
+
33
+ load_dotenv()
34
+
35
+ #gpt4 all model
36
+ gpt4all_model_path = os.environ.get('GPT4ALL_MODEL_PATH')
37
+ model_n_ctx = os.environ.get('MODEL_N_CTX')
38
+ model_n_batch = int(os.environ.get('MODEL_N_BATCH',8))
39
+ target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS',4))
40
+
41
+ openai_api_key = os.environ.get('OPENAI_API_KEY')
42
+
43
+ verbose = os.environ.get('VERBOSE')
44
+
45
+ # activate/deactivate the streaming StdOut callback for LLMs
46
+ callbacks = [StreamingStdOutCallbackHandler()]
47
+
48
+ memory = ConversationBufferWindowMemory(
49
+ memory_key="chat_history",
50
+ input_key="question",
51
+ return_messages=True,
52
+ k=3
53
+ )
54
+
55
+ readonlymemory = ReadOnlySharedMemory(memory=memory)
56
+
57
+
58
+ print("\n\n> Initializing QAPipeline:")
59
+
60
+ global llm_name
61
+ llm_name = 'None'
62
+ global llm
63
+ llm = 'None'
64
+
65
+ global dataset_name
66
+ dataset_name = 'None'
67
+ global vectorstore
68
+ vectorstore = 'None'
69
+
70
+ qa_chain = None
71
+ agent = None
72
+
73
+
74
+ def run(query, model, dataset):
75
+
76
+ if (llm_name != model) or (dataset_name != dataset) or (qa_chain == None):
77
+ set_model(model)
78
+ set_vectorstore(dataset)
79
+ set_qa_chain()
80
+
81
+ # Get the answer from the chain
82
+ start = time.time()
83
+ res = qa_chain(query)
84
+ # answer, docs = res['result'],res['source_documents']
85
+ end = time.time()
86
+
87
+ # Print the result
88
+ print("\n\n> Question:")
89
+ print(query)
90
+ print(f"\n> Answer (took {round(end - start, 2)} s.):")
91
+ print( res)
92
+
93
+ return res
94
+
95
+ def run_agent(query, model, dataset):
96
+
97
+ try:
98
+
99
+ if (llm_name != model) or (dataset_name != dataset) or (agent == None):
100
+ set_model(model)
101
+ set_vectorstore(dataset)
102
+ set_qa_chain_with_agent()
103
+
104
+ # Get the answer from the chain
105
+ start = time.time()
106
+ res = agent(query)
107
+ # answer, docs = res['result'],res['source_documents']
108
+ end = time.time()
109
+
110
+ # Print the result
111
+ print("\n\n> Question:")
112
+ print(query)
113
+ print(f"\n> Answer (took {round(end - start, 2)} s.):")
114
+ print( res)
115
+
116
+ return res["output"]
117
+
118
+ except Exception as e:
119
+ # logger.error(f"Answer retrieval failed with {e}")
120
+ print(f"> QAPipeline run_agent Error : {e}")#, icon=":books:")
121
+ return
122
+
123
+
124
+
125
+ def set_model(model_type):
126
+ if model_type != llm_name:
127
+ global llm
128
+ match model_type:
129
+ case "gpt4all":
130
+ # llm = GPT4All(model=gpt4all_model_path, n_ctx=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=verbose)
131
+ llm = GPT4All(model=gpt4all_model_path, max_tokens=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=verbose)
132
+ # llm = HuggingFaceHub(repo_id="nomic-ai/gpt4all-j", model_kwargs={"temperature":0.001, "max_length":1024})
133
+ case "google/flan-t5-xxl":
134
+ llm = HuggingFaceHub(repo_id="google/flan-t5-xxl", model_kwargs={"temperature":0.001, "max_length":1024})
135
+ case "tiiuae/falcon-7b-instruct":
136
+ llm = HuggingFaceHub(repo_id=model_type, model_kwargs={"temperature":0.001, "max_length":1024})
137
+ case "openai":
138
+ llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
139
+ case _default:
140
+ # raise exception if model_type is not supported
141
+ raise Exception(f"Model type {model_type} is not supported. Please choose a valid one")
142
+ # global llm_name
143
+ llm_name = model_type
144
+
145
+ def set_vectorstore( dataset):
146
+ if dataset != dataset_name:
147
+ # vectorstore = load_store(dataset)
148
+ global vectorstore
149
+ vectorstore = load_FAISS_store()
150
+ print("\n\n> vectorstore loaded:")
151
+ dataset_name = dataset
152
+
153
+ def set_qa_chain():
154
+ global qa_chain
155
+ qa_chain = RetrievalQA.from_chain_type(
156
+ llm=llm,
157
+ chain_type="stuff",
158
+ retriever = vectorstore.as_retriever(),
159
+ # retriever = vectorstore.as_retriever(search_kwargs={"k": target_source_chunks}
160
+ return_source_documents= True
161
+ )
162
+
163
+
164
+ def set_qa_chain_with_agent():
165
+
166
+ try:
167
+
168
+ # Define a custom prompt
169
+ general_qa_template = (
170
+ """You can have a general conversation with the users like greetings.
171
+ Continue the conversation and only answer questions related to banking sector like financial and legal.
172
+ If you dont know the answer say you dont know, dont try to makeup answers.
173
+ Conversation: {chat_history}
174
+ Question: {question}
175
+ """
176
+ )
177
+
178
+ general_qa_chain_prompt = PromptTemplate(input_variables=["question", "chat_history"], template=general_qa_template)
179
+
180
+ general_qa_chain = LLMChain(
181
+ llm=llm,
182
+ prompt=general_qa_chain_prompt,
183
+ verbose=True,
184
+ memory=readonlymemory, # use the read-only memory to prevent the tool from modifying the memory
185
+ )
186
+
187
+ general_qa_chain_tool = Tool(
188
+ name="general qa",
189
+ func= general_qa_chain.run,
190
+ description='''useful for when you need to have a general conversation with the users like greetings
191
+ or to answer general purpose questions related to banking sector like financial and legal.
192
+ Input should be a fully formed question.''',
193
+ return_direct=True,
194
+
195
+ )
196
+
197
+ # Define a custom prompt
198
+ retrieval_qa_template = (
199
+ """
200
+ please answer the question based on the chat history and context with the latest information.
201
+ You have provided context information below related to central bank acts published in various years.
202
+ The content of a bank act can updated by a bank act from a latest year.
203
+ If you dont know the answer say you dont know, dont try to makeup answers.
204
+ Conversation: {chat_history}
205
+ Context: {context}
206
+ Question : {question}
207
+ """
208
+ )
209
+ retrieval_qa_chain_prompt = PromptTemplate(
210
+ input_variables=["question", "context", "chat_history"],
211
+ template=retrieval_qa_template
212
+ )
213
+
214
+ bank_regulations_qa = ConversationalRetrievalChain.from_llm(
215
+ llm=llm,
216
+ chain_type="stuff",
217
+ retriever = vectorstore.as_retriever(),
218
+ # retriever = vectorstore.as_retriever(search_kwargs={"k": target_source_chunks}
219
+ return_source_documents= True,
220
+ get_chat_history=lambda h : h,
221
+ combine_docs_chain_kwargs={"prompt": retrieval_qa_chain_prompt},
222
+ verbose=True,
223
+ memory=readonlymemory, # use the read-only memory to prevent the tool from modifying the memory
224
+ )
225
+
226
+ bank_regulations_qa_tool = Tool(
227
+ name="bank regulations",
228
+ func= lambda question: bank_regulations_qa({"question": question}),
229
+ description='''useful for when you need to answer questions about
230
+ financial and legal information issued from central bank regarding banks and bank regulations.
231
+ Input should be a fully formed question.''',
232
+ return_direct=True,
233
+ )
234
+
235
+ tools = [
236
+ bank_regulations_qa_tool,
237
+ general_qa_chain_tool
238
+ ]
239
+
240
+ prefix = """Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:"""
241
+ suffix = """Begin!"
242
+
243
+ {chat_history}
244
+ Question: {question}
245
+ {agent_scratchpad}"""
246
+
247
+ agent_prompt = ZeroShotAgent.create_prompt(
248
+ tools,
249
+ prefix=prefix,
250
+ suffix=suffix,
251
+ input_variables=["question", "chat_history", "agent_scratchpad"],
252
+ )
253
+
254
+ llm_chain = LLMChain(llm=llm, prompt=agent_prompt)
255
+
256
+ zeroShotAgent = ZeroShotAgent(
257
+ llm_chain=llm_chain,
258
+ tools=tools,
259
+ verbose=True,
260
+ )
261
+
262
+ agent_chain = AgentExecutor.from_agent_and_tools(
263
+ agent=zeroShotAgent,
264
+ tools=tools,
265
+ verbose=True,
266
+ memory=memory,
267
+ handle_parsing_errors=True,
268
+ )
269
+
270
+ global agent
271
+ agent = agent_chain
272
+
273
+ print(f"\n> agent_chain created")
274
+
275
+ except Exception as e:
276
+ # logger.error(f"Answer retrieval failed with {e}")
277
+ print(f"> QAPipeline set_qa_chain_with_agent Error : {e}")#, icon=":books:")
278
+ return
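One caveat in this functional variant: set_model and set_vectorstore read and then assign llm_name / dataset_name without declaring them global, so Python treats those names as locals and raises UnboundLocalError on the first comparison (run_agent's try/except swallows the error). A minimal sketch of the fix:

def set_model(model_type):
    global llm, llm_name              # declare before the first read
    if model_type != llm_name:
        ...                           # build llm as above
        llm_name = model_type

def set_vectorstore(dataset):
    global vectorstore, dataset_name  # likewise for the vector store state
    if dataset != dataset_name:
        vectorstore = load_FAISS_store()
        dataset_name = dataset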