Spaces: Boardpac/theekshanas (Runtime error)

Commit 027bfbf, committed by Boardpac/theekshanas
Parent(s): 9e2dc86

agent with memory
Files changed:
- .env +1 -1
- __pycache__/conversationBufferWindowMemory.cpython-311.pyc +0 -0
- __pycache__/qaPipeline.cpython-311.pyc +0 -0
- __pycache__/qaPipeline_functions.cpython-311.pyc +0 -0
- app.py +32 -14
- app2.py +184 -0
- conversationBufferWindowMemory.py +118 -0
- qaPipeline.py +160 -77
- qaPipeline_functions.py +278 -0
.env
CHANGED
@@ -11,7 +11,7 @@ TARGET_SOURCE_CHUNKS=4
 
 #API token keys
 HUGGINGFACEHUB_API_TOKEN=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
-OPENAI_API_KEY=sk-
+OPENAI_API_KEY=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
 
 #api app
 APP_HOST=127.0.0.1
__pycache__/conversationBufferWindowMemory.cpython-311.pyc
ADDED
Binary file (6.23 kB).

__pycache__/qaPipeline.cpython-311.pyc
CHANGED
Binary files a/__pycache__/qaPipeline.cpython-311.pyc and b/__pycache__/qaPipeline.cpython-311.pyc differ

__pycache__/qaPipeline_functions.cpython-311.pyc
ADDED
Binary file (10.3 kB).
app.py
CHANGED
@@ -16,6 +16,7 @@ from ui.htmlTemplates import css, bot_template, user_template, source_template
 from config import MODELS, DATASETS
 
 from qaPipeline import QAPipeline
+import qaPipeline_functions
 from faissDb import create_faiss
 
 # loads environment variables
@@ -25,7 +26,12 @@ load_dotenv()
 isHuggingFaceHubEnabled = os.environ.get('ENABLE_HUGGINGFSCE_HUB_MODELS')
 isOpenAiApiEnabled = os.environ.get('ENABLE_OPENAI_API_MODELS')
 
+st.set_page_config(page_title="Chat with data",
+                   page_icon=":books:")
+st.write(css, unsafe_allow_html=True)
+
 qaPipeline = QAPipeline()
+# qaPipeline = qaPipeline_functions
 
 def initialize_session_state():
     # Initialise all session state variables with defaults
@@ -109,13 +115,22 @@ def side_bar():
 
 def chat_body():
     st.header("Chat with your own data:")
+    with st.form('chat_body'):
+
+        user_question = st.text_input(
+            "Ask a question about your documents:",
+            placeholder="enter question",
+            key='user_question',
+            # on_change=submit_user_question,
+        )
+
+        submitted = st.form_submit_button(
+            "Submit",
+            # on_click=update_parameters_change
+        )
 
-    user_question = st.text_input(
-        "Ask a question about your documents:",
-        placeholder="enter question",
-        key='user_question',
-        on_change=submit_user_question,
-    )
+        if submitted:
+            submit_user_question()
 
     # if user_question:
     #     submit_user_question()
@@ -128,18 +143,15 @@ def submit_user_question():
     user_question = st.session_state.user_question
     # st.success(user_question)
     handle_userinput(user_question)
-    st.session_state.user_question=''
+    # st.session_state.user_question=''
 
 
 def main():
 
-    st.set_page_config(page_title="Chat with data",
-                       page_icon=":books:")
-    st.write(css, unsafe_allow_html=True)
-
     initialize_session_state()
 
     side_bar()
+
     chat_body()
 
 
@@ -157,7 +169,7 @@ def parameters_change_button(chat_model, show_source):
     time.sleep(1) # Wait for 3 seconds
     alert.empty() # Clear the alert
 
-
+# @st.cache_data
 def get_answer_from_backend(query, model, dataset):
     # response = qaPipeline.run(query=query, model=model, dataset=dataset)
     response = qaPipeline.run_agent(query=query, model=model, dataset=dataset)
@@ -167,7 +179,7 @@ def get_answer_from_backend(query, model, dataset):
 def show_query_response(query, response, show_source_files):
     docs = []
     if isinstance(response, dict):
-        answer, docs = response['
+        answer, docs = response['answer'], response['source_documents']
     else:
         answer = response
 
@@ -217,9 +229,15 @@ def handle_userinput(query):
 
     except Exception as e:
         # logger.error(f"Answer retrieval failed with {e}")
-        st.error(f"Error : {e}")#, icon=":books:")
+        st.error(f"Streamlit handle_userinput Error : {e}")#, icon=":books:")
         return
 
 
 if __name__ == "__main__":
     main()
+
+    # initialize_session_state()
+
+    # side_bar()
+
+    # chat_body()
app2.py
ADDED
@@ -0,0 +1,184 @@
"""
Python Backend API to chat with private data

08/16/2023
D.M. Theekshana Samaradiwakara
"""

import os
import time
import streamlit as st
from streamlit.logger import get_logger

logger = get_logger(__name__)

from ui.htmlTemplates import css, bot_template, user_template, source_template
from config import MODELS, DATASETS

from qaPipeline import QAPipeline
from faissDb import create_faiss

# loads environment variables
from dotenv import load_dotenv
load_dotenv()

isHuggingFaceHubEnabled = os.environ.get('ENABLE_HUGGINGFSCE_HUB_MODELS')
isOpenAiApiEnabled = os.environ.get('ENABLE_OPENAI_API_MODELS')

st.set_page_config(page_title="Chat with data",
                   page_icon=":books:")
st.write(css, unsafe_allow_html=True)


SESSION_DEFAULTS = {
    "model": MODELS["DEFAULT"],
    "dataset": DATASETS["DEFAULT"],
    "chat_history": None,
    "is_parameters_changed": False,
    "show_source_files": False,
    "user_question": '',
}

for k, v in SESSION_DEFAULTS.items():
    if k not in st.session_state:
        st.session_state[k] = v


with st.sidebar:
    st.subheader("Chat parameters")

    with st.form('param_form'):

        chat_model = st.selectbox(
            "Chat model",
            MODELS,
            key="chat_model",
            help="Select the LLM model for the chat",
            # on_change=update_parameters_change,
        )

        st.session_state.dataset = "DEFAULT"

        show_source = st.checkbox(
            label="show source files",
            key="show_source",
            help="Select this to show relavant source files for the query",
        )

        submitted = st.form_submit_button(
            "Submit",
            # on_click=parameters_change_button,
            # args=[chat_model, show_source]
        )

        # submitted = st.button(
        #     "Submit",
        #     # on_click=parameters_change_button,
        #     # args=[chat_model, show_source]
        # )

        if submitted:
            st.session_state.model = chat_model
            st.session_state.dataset = "DEFAULT"
            st.session_state.show_source_files = show_source
            st.session_state.is_parameters_changed = False

            alert = st.success("chat parameters updated")
            time.sleep(1) # Wait for 3 seconds
            alert.empty() # Clear the alert

    st.markdown("\n")

    # if st.button("Create FAISS db"):
    #     try:
    #         with st.spinner('creating faiss vector store'):
    #             create_faiss()
    #             st.success('faiss saved')
    #     except Exception as e:
    #         st.error(f"Error : {e}")#, icon=":books:")
    #         return

    st.markdown(
        "### How to use\n"
        "1. Select the chat model\n"  # noqa: E501
        "2. Select \"show source files\" to show the source files related to the answer.📄\n"
        "3. Ask a question about the documents💬\n"
    )



st.header("Chat with your own data:")

@st.experimental_singleton  # 👈 Add the caching decorator
def load_QaPipeline():
    print('> QAPipeline loaded for front end')
    return QAPipeline()

qaPipeline = load_QaPipeline()
# qaPipeline = QAPipeline()

with st.form('chat_body'):

    user_question = st.text_input(
        "Ask a question about your documents:",
        placeholder="enter question",
        key='user_question',
        # on_change=submit_user_question,
    )

    submitted = st.form_submit_button(
        "Submit",
        # on_click=submit_user_question
    )

    if submitted:
        with st.spinner("Processing"):
            user_question = st.session_state.user_question
            # st.success(user_question)
            query = user_question
            # st.session_state.user_question=''

            # Get the answer from the chain
            try:
                if (not query) or (query.strip() == ''):
                    st.error("Please enter a question!")
                    st.stop()

                model = MODELS[st.session_state.model]
                dataset = DATASETS[st.session_state.dataset]
                show_source_files = st.session_state.show_source_files

                # Try to access openai and deeplake
                print(f">\n model: {model} \n dataset : {dataset} \n show_source_files : {show_source_files}")

                # response = qaPipeline.run(query=query, model=model, dataset=dataset)
                response = qaPipeline.run_agent(query=query, model=model, dataset=dataset)

                docs = []
                if isinstance(response, dict):
                    answer, docs = response['answer'], response['source_documents']
                else:
                    answer = response

                st.write(user_template.replace(
                    "{{MSG}}", query), unsafe_allow_html=True)
                st.write(bot_template.replace(
                    "{{MSG}}", answer ), unsafe_allow_html=True)

                if show_source_files:
                    # st.write(source_template.replace(
                    #     "{{MSG}}", "source files" ), unsafe_allow_html=True)

                    if len(docs)>0 :
                        st.markdown("#### source files : ")
                        for source in docs:
                            # st.info(source.metadata)
                            with st.expander(source.metadata["source"]):
                                st.markdown(source.page_content)

                # st.write(response)

            except Exception as e:
                # logger.error(f"Answer retrieval failed with {e}")
                st.error(f"Error : {e}")#, icon=":books:")
conversationBufferWindowMemory.py
ADDED
@@ -0,0 +1,118 @@
from abc import ABC
from typing import Any, Dict, Optional, Tuple
# import json

from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
from langchain.memory.utils import get_prompt_input_key
from langchain.pydantic_v1 import Field
from langchain.schema import BaseChatMessageHistory, BaseMemory

from typing import List, Union

# from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema.messages import BaseMessage, get_buffer_string


class BaseChatMemory(BaseMemory, ABC):
    """Abstract base class for chat memory."""

    chat_memory: BaseChatMessageHistory = Field(default_factory=ChatMessageHistory)
    output_key: Optional[str] = None
    input_key: Optional[str] = None
    return_messages: bool = False

    def _get_input_output(
        self, inputs: Dict[str, Any], outputs: Dict[str, str]
    ) -> Tuple[str, str]:

        if self.input_key is None:
            prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
        else:
            prompt_input_key = self.input_key

        if self.output_key is None:
            """
            output for agent with LLM chain tool = {answer}
            output for agent with ConversationalRetrievalChain tool = {'question', 'chat_history', 'answer','source_documents'}
            """

            LLM_key = 'output'
            Retrieval_key = 'answer'
            if isinstance(outputs[LLM_key], dict):
                Retrieval_dict = outputs[LLM_key]
                if Retrieval_key in Retrieval_dict.keys():
                    # output keys are 'answer' , 'source_documents'
                    output = Retrieval_dict[Retrieval_key]
                else:
                    raise ValueError(f"output key: {LLM_key} not a valid dictionary")

            else:
                # otherwise output key will be 'output'
                output_key = list(outputs.keys())[0]
                output = outputs[output_key]

            # if len(outputs) != 1:
            #     raise ValueError(f"One output key expected, got {outputs.keys()}")

        else:
            output_key = self.output_key
            output = outputs[output_key]

        return inputs[prompt_input_key], output

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save context from this conversation to buffer."""
        input_str, output_str = self._get_input_output(inputs, outputs)
        self.chat_memory.add_user_message(input_str)
        self.chat_memory.add_ai_message(output_str)

    def clear(self) -> None:
        """Clear memory contents."""
        self.chat_memory.clear()




class ConversationBufferWindowMemory(BaseChatMemory):
    """Buffer for storing conversation memory inside a limited size window."""

    human_prefix: str = "Human"
    ai_prefix: str = "AI"
    memory_key: str = "history"  #: :meta private:
    k: int = 5
    """Number of messages to store in buffer."""

    @property
    def buffer(self) -> Union[str, List[BaseMessage]]:
        """String buffer of memory."""
        return self.buffer_as_messages if self.return_messages else self.buffer_as_str

    @property
    def buffer_as_str(self) -> str:
        """Exposes the buffer as a string in case return_messages is True."""
        messages = self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
        return get_buffer_string(
            messages,
            human_prefix=self.human_prefix,
            ai_prefix=self.ai_prefix,
        )

    @property
    def buffer_as_messages(self) -> List[BaseMessage]:
        """Exposes the buffer as a list of messages in case return_messages is False."""
        return self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []

    @property
    def memory_variables(self) -> List[str]:
        """Will always return list of memory variables.

        :meta private:
        """
        return [self.memory_key]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return history buffer."""
        return {self.memory_key: self.buffer}
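
A minimal sketch of how this windowed memory behaves (not part of the commit; it assumes the module above is importable as conversationBufferWindowMemory and that the installed langchain version exposes the imports it uses):

from conversationBufferWindowMemory import ConversationBufferWindowMemory

# Keep only the last k=2 question/answer pairs in the window.
memory = ConversationBufferWindowMemory(
    memory_key="chat_history", input_key="question", k=2
)

# save_context() unwraps dict outputs that carry an 'answer' key
# (the ConversationalRetrievalChain tool) and falls back to the first
# output key otherwise (the plain LLMChain tool).
memory.save_context({"question": "hi"}, {"output": "hello"})
memory.save_context(
    {"question": "what is a bank act?"},
    {"output": {"answer": "a statute issued by the central bank", "source_documents": []}},
)
memory.save_context({"question": "who issues it?"}, {"output": "the central bank"})

# Only the last two exchanges (four messages) come back under 'chat_history'.
print(memory.load_memory_variables({}))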
qaPipeline.py
CHANGED
@@ -23,10 +23,12 @@ from langchain.chat_models import ChatOpenAI
 # from chromaDb import load_store
 from faissDb import load_FAISS_store
 
-from langchain.agents import
-
+from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
+
 from langchain.prompts import PromptTemplate
-from langchain.chains import LLMChain
+from langchain.chains import LLMChain, ConversationalRetrievalChain
+from conversationBufferWindowMemory import ConversationBufferWindowMemory
+from langchain.memory import ReadOnlySharedMemory
 
 load_dotenv()
 
@@ -43,10 +45,35 @@ verbose = os.environ.get('VERBOSE')
 # activate/deactivate the streaming StdOut callback for LLMs
 callbacks = [StreamingStdOutCallbackHandler()]
 
+memory = ConversationBufferWindowMemory(
+    memory_key="chat_history",
+    input_key="question",
+    return_messages=True,
+    k=3
+)
+
+readonlymemory = ReadOnlySharedMemory(memory=memory)
+
+class Singleton:
+    __instance = None
+    @staticmethod
+    def getInstance():
+        """ Static access method. """
+        if Singleton.__instance == None:
+            Singleton()
+        return Singleton.__instance
+    def __init__(self):
+        """ Virtually private constructor. """
+        if Singleton.__instance != None:
+            raise Exception("This class is a singleton!")
+        else:
+            Singleton.__instance = QAPipeline()
+
 class QAPipeline:
 
     def __init__(self):
 
+        print("\n\n> Initializing QAPipeline:")
         self.llm_name = None
         self.llm = None
 
@@ -56,6 +83,7 @@ class QAPipeline:
         self.qa_chain = None
         self.agent = None
 
+
     def run(self,query, model, dataset):
 
         if (self.llm_name != model) or (self.dataset_name != dataset) or (self.qa_chain == None):
@@ -79,24 +107,31 @@ class QAPipeline:
 
     def run_agent(self,query, model, dataset):
 
-        self.set_model(model)
-        self.set_vectorstore(dataset)
-        self.set_qa_chain_with_agent()
-        end = time.time()
+        try:
+
+            if (self.llm_name != model) or (self.dataset_name != dataset) or (self.agent == None):
+                self.set_model(model)
+                self.set_vectorstore(dataset)
+                self.set_qa_chain_with_agent()
+
+            # Get the answer from the chain
+            start = time.time()
+            res = self.agent(query)
+            # answer, docs = res['result'],res['source_documents']
+            end = time.time()
+
+            # Print the result
+            print("\n\n> Question:")
+            print(query)
+            print(f"\n> Answer (took {round(end - start, 2)} s.):")
+            print( res)
+
+            return res["output"]
+
+        except Exception as e:
+            # logger.error(f"Answer retrieval failed with {e}")
+            print(f"> QAPipeline run_agent Error : {e}")#, icon=":books:")
+            return
 
 
@@ -139,67 +174,115 @@
 
     def set_qa_chain_with_agent(self):
 
+        try:
+
+            # Define a custom prompt
+            general_qa_template = (
+                """You can have a general conversation with the users like greetings.
+                Continue the conversation and only answer questions related to banking sector like financial and legal.
+                If you dont know the answer say you dont know, dont try to makeup answers.
+                Conversation: {chat_history}
+                Question: {question}
+                """
+            )
+
+            general_qa_chain_prompt = PromptTemplate(input_variables=["question", "chat_history"], template=general_qa_template)
+
+            general_qa_chain = LLMChain(
+                llm=self.llm,
+                prompt=general_qa_chain_prompt,
+                verbose=True,
+                memory=readonlymemory, # use the read-only memory to prevent the tool from modifying the memory
+            )
+
+            general_qa_chain_tool = Tool(
+                name="general qa",
+                func= general_qa_chain.run,
+                description='''useful for when you need to have a general conversation with the users like greetings
+                or to answer general purpose questions related to banking sector like financial and legal.
+                Input should be a fully formed question.''',
+                return_direct=True,
+
+            )
+
+            # Define a custom prompt
+            retrieval_qa_template = (
+                """
+                please answer the question based on the chat history and context with the latest information.
+                You have provided context information below related to central bank acts published in various years.
+                The content of a bank act can updated by a bank act from a latest year.
+                If you dont know the answer say you dont know, dont try to makeup answers.
+                Conversation: {chat_history}
+                Context: {context}
+                Question : {question}
+                """
+            )
+            retrieval_qa_chain_prompt = PromptTemplate(
+                input_variables=["question", "context", "chat_history"],
+                template=retrieval_qa_template
+            )
+
+            bank_regulations_qa = ConversationalRetrievalChain.from_llm(
+                llm=self.llm,
+                chain_type="stuff",
+                retriever = self.vectorstore.as_retriever(),
+                # retriever = self.vectorstore.as_retriever(search_kwargs={"k": target_source_chunks}
+                return_source_documents= True,
+                get_chat_history=lambda h : h,
+                combine_docs_chain_kwargs={"prompt": retrieval_qa_chain_prompt},
+                verbose=True,
+                memory=readonlymemory, # use the read-only memory to prevent the tool from modifying the memory
+            )
+
+            bank_regulations_qa_tool = Tool(
                 name="bank regulations",
-                func= lambda
+                func= lambda question: bank_regulations_qa({"question": question}),
                 description='''useful for when you need to answer questions about
                 financial and legal information issued from central bank regarding banks and bank regulations.
                 Input should be a fully formed question.''',
                 return_direct=True,
-            )
+            )
 
+            tools = [
+                bank_regulations_qa_tool,
+                general_qa_chain_tool
+            ]
+
+            prefix = """Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:"""
+            suffix = """Begin!"
+
+            {chat_history}
+            Question: {question}
+            {agent_scratchpad}"""
+
+            agent_prompt = ZeroShotAgent.create_prompt(
+                tools,
+                prefix=prefix,
+                suffix=suffix,
+                input_variables=["question", "chat_history", "agent_scratchpad"],
+            )
+
+            llm_chain = LLMChain(llm=self.llm, prompt=agent_prompt)
+
+            agent = ZeroShotAgent(
+                llm_chain=llm_chain,
+                tools=tools,
+                verbose=True,
+            )
+
+            agent_chain = AgentExecutor.from_agent_and_tools(
+                agent=agent,
+                tools=tools,
+                verbose=True,
+                memory=memory,
+                handle_parsing_errors=True,
+            )
+
+            self.agent = agent_chain
+
+            print(f"\n> agent_chain created")
+
+        except Exception as e:
+            # logger.error(f"Answer retrieval failed with {e}")
+            print(f"> QAPipeline set_qa_chain_with_agent Error : {e}")#, icon=":books:")
+            return
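
After this change the Streamlit layer only calls run_agent; a minimal sketch of exercising the pipeline outside the UI (the "openai" and "DEFAULT" values are placeholders for whatever config.MODELS and config.DATASETS map to, which this commit does not show):

from qaPipeline import QAPipeline

pipeline = QAPipeline()

# run_agent() builds the LLM, FAISS retriever, the two tools and the
# AgentExecutor on the first call (or when model/dataset change) and
# reuses them afterwards; it returns res["output"], or None on error.
answer = pipeline.run_agent(
    query="What does the latest bank act say about licensing?",
    model="openai",       # placeholder; app.py passes MODELS[st.session_state.model]
    dataset="DEFAULT",    # placeholder; app.py passes DATASETS[st.session_state.dataset]
)
print(answer)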
qaPipeline_functions.py
ADDED
@@ -0,0 +1,278 @@
"""
Python Backend API to chat with private data

08/14/2023
D.M. Theekshana Samaradiwakara
"""

import os
import time

from dotenv import load_dotenv

from langchain.chains import RetrievalQA
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

from langchain.llms import GPT4All
from langchain.llms import HuggingFaceHub
from langchain.chat_models import ChatOpenAI

# from langchain.retrievers._query.base import SelfQueryRetriever
# from langchain.chains.query_constructor.base import AttributeInfo

# from chromaDb import load_store
from faissDb import load_FAISS_store

from langchain.agents import ZeroShotAgent, Tool, AgentExecutor

from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain, ConversationalRetrievalChain
from conversationBufferWindowMemory import ConversationBufferWindowMemory
from langchain.memory import ReadOnlySharedMemory

load_dotenv()

#gpt4 all model
gpt4all_model_path = os.environ.get('GPT4ALL_MODEL_PATH')
model_n_ctx = os.environ.get('MODEL_N_CTX')
model_n_batch = int(os.environ.get('MODEL_N_BATCH',8))
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS',4))

openai_api_key = os.environ.get('OPENAI_API_KEY')

verbose = os.environ.get('VERBOSE')

# activate/deactivate the streaming StdOut callback for LLMs
callbacks = [StreamingStdOutCallbackHandler()]

memory = ConversationBufferWindowMemory(
    memory_key="chat_history",
    input_key="question",
    return_messages=True,
    k=3
)

readonlymemory = ReadOnlySharedMemory(memory=memory)


print("\n\n> Initializing QAPipeline:")

global llm_name
llm_name = 'None'
global llm
llm = 'None'

global dataset_name
dataset_name = 'None'
global vectorstore
vectorstore = 'None'

qa_chain = None
agent = None


def run(query, model, dataset):

    if (llm_name != model) or (dataset_name != dataset) or (qa_chain == None):
        set_model(model)
        set_vectorstore(dataset)
        set_qa_chain()

    # Get the answer from the chain
    start = time.time()
    res = qa_chain(query)
    # answer, docs = res['result'],res['source_documents']
    end = time.time()

    # Print the result
    print("\n\n> Question:")
    print(query)
    print(f"\n> Answer (took {round(end - start, 2)} s.):")
    print( res)

    return res

def run_agent(query, model, dataset):

    try:

        if (llm_name != model) or (dataset_name != dataset) or (agent == None):
            set_model(model)
            set_vectorstore(dataset)
            set_qa_chain_with_agent()

        # Get the answer from the chain
        start = time.time()
        res = agent(query)
        # answer, docs = res['result'],res['source_documents']
        end = time.time()

        # Print the result
        print("\n\n> Question:")
        print(query)
        print(f"\n> Answer (took {round(end - start, 2)} s.):")
        print( res)

        return res["output"]

    except Exception as e:
        # logger.error(f"Answer retrieval failed with {e}")
        print(f"> QAPipeline run_agent Error : {e}")#, icon=":books:")
        return



def set_model(model_type):
    if model_type != llm_name:
        global llm
        match model_type:
            case "gpt4all":
                # llm = GPT4All(model=gpt4all_model_path, n_ctx=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=verbose)
                llm = GPT4All(model=gpt4all_model_path, max_tokens=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=verbose)
                # llm = HuggingFaceHub(repo_id="nomic-ai/gpt4all-j", model_kwargs={"temperature":0.001, "max_length":1024})
            case "google/flan-t5-xxl":
                llm = HuggingFaceHub(repo_id="google/flan-t5-xxl", model_kwargs={"temperature":0.001, "max_length":1024})
            case "tiiuae/falcon-7b-instruct":
                llm = HuggingFaceHub(repo_id=model_type, model_kwargs={"temperature":0.001, "max_length":1024})
            case "openai":
                llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
            case _default:
                # raise exception if model_type is not supported
                raise Exception(f"Model type {model_type} is not supported. Please choose a valid one")
        # global llm_name
        llm_name = model_type

def set_vectorstore( dataset):
    if dataset != dataset_name:
        # vectorstore = load_store(dataset)
        global vectorstore
        vectorstore = load_FAISS_store()
        print("\n\n> vectorstore loaded:")
        dataset_name = dataset

def set_qa_chain():
    global qa_chain
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever = vectorstore.as_retriever(),
        # retriever = vectorstore.as_retriever(search_kwargs={"k": target_source_chunks}
        return_source_documents= True
    )


def set_qa_chain_with_agent():

    try:

        # Define a custom prompt
        general_qa_template = (
            """You can have a general conversation with the users like greetings.
            Continue the conversation and only answer questions related to banking sector like financial and legal.
            If you dont know the answer say you dont know, dont try to makeup answers.
            Conversation: {chat_history}
            Question: {question}
            """
        )

        general_qa_chain_prompt = PromptTemplate(input_variables=["question", "chat_history"], template=general_qa_template)

        general_qa_chain = LLMChain(
            llm=llm,
            prompt=general_qa_chain_prompt,
            verbose=True,
            memory=readonlymemory, # use the read-only memory to prevent the tool from modifying the memory
        )

        general_qa_chain_tool = Tool(
            name="general qa",
            func= general_qa_chain.run,
            description='''useful for when you need to have a general conversation with the users like greetings
            or to answer general purpose questions related to banking sector like financial and legal.
            Input should be a fully formed question.''',
            return_direct=True,

        )

        # Define a custom prompt
        retrieval_qa_template = (
            """
            please answer the question based on the chat history and context with the latest information.
            You have provided context information below related to central bank acts published in various years.
            The content of a bank act can updated by a bank act from a latest year.
            If you dont know the answer say you dont know, dont try to makeup answers.
            Conversation: {chat_history}
            Context: {context}
            Question : {question}
            """
        )
        retrieval_qa_chain_prompt = PromptTemplate(
            input_variables=["question", "context", "chat_history"],
            template=retrieval_qa_template
        )

        bank_regulations_qa = ConversationalRetrievalChain.from_llm(
            llm=llm,
            chain_type="stuff",
            retriever = vectorstore.as_retriever(),
            # retriever = vectorstore.as_retriever(search_kwargs={"k": target_source_chunks}
            return_source_documents= True,
            get_chat_history=lambda h : h,
            combine_docs_chain_kwargs={"prompt": retrieval_qa_chain_prompt},
            verbose=True,
            memory=readonlymemory, # use the read-only memory to prevent the tool from modifying the memory
        )

        bank_regulations_qa_tool = Tool(
            name="bank regulations",
            func= lambda question: bank_regulations_qa({"question": question}),
            description='''useful for when you need to answer questions about
            financial and legal information issued from central bank regarding banks and bank regulations.
            Input should be a fully formed question.''',
            return_direct=True,
        )

        tools = [
            bank_regulations_qa_tool,
            general_qa_chain_tool
        ]

        prefix = """Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:"""
        suffix = """Begin!"

        {chat_history}
        Question: {question}
        {agent_scratchpad}"""

        agent_prompt = ZeroShotAgent.create_prompt(
            tools,
            prefix=prefix,
            suffix=suffix,
            input_variables=["question", "chat_history", "agent_scratchpad"],
        )

        llm_chain = LLMChain(llm=llm, prompt=agent_prompt)

        zeroShotAgent = ZeroShotAgent(
            llm_chain=llm_chain,
            tools=tools,
            verbose=True,
        )

        agent_chain = AgentExecutor.from_agent_and_tools(
            agent=zeroShotAgent,
            tools=tools,
            verbose=True,
            memory=memory,
            handle_parsing_errors=True,
        )

        global agent
        agent = agent_chain

        print(f"\n> agent_chain created")

    except Exception as e:
        # logger.error(f"Answer retrieval failed with {e}")
        print(f"> QAPipeline set_qa_chain_with_agent Error : {e}")#, icon=":books:")
        return
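
app.py already imports this module, and its commented `# qaPipeline = qaPipeline_functions` line treats it as a drop-in replacement for the class-based pipeline; a minimal sketch of that module-level usage (same placeholder model/dataset values as above):

import qaPipeline_functions

# Same call shape as QAPipeline.run_agent, but state (llm, vectorstore,
# agent) is kept in module-level globals instead of on an instance.
answer = qaPipeline_functions.run_agent(
    query="What does the latest bank act say about licensing?",
    model="openai",      # placeholder for a config.MODELS value
    dataset="DEFAULT",   # placeholder for a config.DATASETS value
)
print(answer)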