theekshana committed • Commit 99f2e6a • Parent(s): cb81a79

changed to ConversationalRetrievalChain
Files changed:
- .env (+1 -1)
- __pycache__/config.cpython-311.pyc (+0 -0)
- __pycache__/qaPipeline.cpython-311.pyc (+0 -0)
- __pycache__/qaPipeline_chain_only.cpython-311.pyc (+0 -0)
- app.py (+9 -8)
- app_agent.py (+246 -0)
- config.py (+7 -4)
- qaPipeline.py (+9 -6)
- qaPipeline_chain_only.py (+241 -0)
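The gist of the change: app.py now imports QAPipeline from the new qaPipeline_chain_only.py, which answers with a plain ConversationalRetrievalChain, while the previous ZeroShotAgent-based UI is preserved as app_agent.py. A minimal sketch of the chain pattern being adopted, assuming a langchain 0.0.x release contemporary with this commit (the vectorstore argument stands in for the repo's load_FAISS_store()):

from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferWindowMemory

def build_chain(vectorstore):
    # Windowed memory keeps only the last k exchanges; the chain condenses
    # chat history + new question, retrieves documents, then answers.
    memory = ConversationBufferWindowMemory(
        memory_key="chat_history",
        input_key="question",
        output_key="answer",   # outputs: "answer" and "source_documents"
        return_messages=True,
        k=3,
    )
    return ConversationalRetrievalChain.from_llm(
        llm=ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0),
        retriever=vectorstore.as_retriever(),
        return_source_documents=True,
        memory=memory,
    )

# response = build_chain(vs)("what changed in the latest act?")
# response["answer"], response["source_documents"]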
.env
CHANGED
@@ -11,7 +11,7 @@ TARGET_SOURCE_CHUNKS=4
 
 #API token keys
 HUGGINGFACEHUB_API_TOKEN=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
-OPENAI_API_KEY=sk-
+OPENAI_API_KEY=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
 ANYSCALE_ENDPOINT_TOKEN=esecret_n1svfld85uklyx5ebaasyiw2m9
 
 #api app
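These keys are read once at startup through python-dotenv; a condensed view of how qaPipeline_chain_only.py below consumes them:

import os
from dotenv import load_dotenv

load_dotenv()  # copies .env entries into os.environ
openai_api_key = os.environ.get('OPENAI_API_KEY')
anyscale_api_key = os.environ.get('ANYSCALE_ENDPOINT_TOKEN')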
__pycache__/config.cpython-311.pyc
CHANGED
Binary files a/__pycache__/config.cpython-311.pyc and b/__pycache__/config.cpython-311.pyc differ
__pycache__/qaPipeline.cpython-311.pyc
CHANGED
Binary files a/__pycache__/qaPipeline.cpython-311.pyc and b/__pycache__/qaPipeline.cpython-311.pyc differ
__pycache__/qaPipeline_chain_only.cpython-311.pyc
ADDED
Binary file (11.2 kB)
app.py
CHANGED
@@ -17,9 +17,7 @@ logger = get_logger(__name__)
 from ui.htmlTemplates import css, bot_template, user_template, source_template
 from config import MODELS, DATASETS
 
-from qaPipeline import QAPipeline
-import qaPipeline_functions
-from faissDb import create_faiss
+from qaPipeline_chain_only import QAPipeline
 
 # loads environment variables
 from dotenv import load_dotenv
@@ -117,6 +115,7 @@ def side_bar():
 
 def chat_body():
     st.header("Chat with your own data:")
+    # st.text("Implemented using ConversationalRetrievalChain")
     with st.form('chat_body'):
 
         user_question = st.text_input(
@@ -195,11 +194,13 @@ def show_query_response(query, response, show_source_files):
         # "{{MSG}}", "source files" ), unsafe_allow_html=True)
 
         if len(docs)>0 :
-            st.markdown("#### source files : ")
-            for source in docs:
-                # st.info(source.metadata)
-                with st.expander(source.metadata["source"]):
-                    st.markdown(source.page_content)
+            code_word = 'Boardpac AI(QA):'
+            if code_word in answer:
+                st.markdown("#### source files : ")
+                for source in docs:
+                    # st.info(source.metadata)
+                    with st.expander(source.metadata["source"]):
+                        st.markdown(source.page_content)
 
         # st.write(response)
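The code_word check couples the UI to the prompt defined in qaPipeline_chain_only.py: the template instructs the model to prefix every answer with Boardpac AI(chat):, Boardpac AI(OD):, or Boardpac AI(QA):, and only QA-type answers, the ones grounded in retrieved documents, get a source-file listing. The routing in isolation (helper name hypothetical, not in the repo):

def should_show_sources(answer: str) -> bool:
    # Greeting answers ('Boardpac AI(chat):') and out-of-domain answers
    # ('Boardpac AI(OD):') have no meaningful sources, so only answers the
    # model tagged as retrieval-grounded trigger the expander list.
    return 'Boardpac AI(QA):' in answer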
app_agent.py
ADDED
@@ -0,0 +1,246 @@

"""
Python Backend API to chat with private data

08/16/2023
D.M. Theekshana Samaradiwakara

python -m streamlit run app.py
"""

import os
import time
import streamlit as st
from streamlit.logger import get_logger

logger = get_logger(__name__)

from ui.htmlTemplates import css, bot_template, user_template, source_template
from config import MODELS, DATASETS

from qaPipeline import QAPipeline
import qaPipeline_functions
from faissDb import create_faiss

# loads environment variables
from dotenv import load_dotenv
load_dotenv()

isHuggingFaceHubEnabled = os.environ.get('ENABLE_HUGGINGFSCE_HUB_MODELS')
isOpenAiApiEnabled = os.environ.get('ENABLE_OPENAI_API_MODELS')

st.set_page_config(page_title="Chat with data",
                   page_icon=":books:")
st.write(css, unsafe_allow_html=True)

qaPipeline = QAPipeline()
# qaPipeline = qaPipeline_functions

def initialize_session_state():
    # Initialise all session state variables with defaults
    SESSION_DEFAULTS = {
        "model": MODELS["DEFAULT"],
        "dataset": DATASETS["DEFAULT"],
        "chat_history": None,
        "is_parameters_changed": False,
        "show_source_files": False,
        "user_question": '',
    }

    for k, v in SESSION_DEFAULTS.items():
        if k not in st.session_state:
            st.session_state[k] = v

def side_bar():
    with st.sidebar:
        st.subheader("Chat parameters")

        with st.form('param_form'):
            st.info('Info: use openai chat model for best results')
            chat_model = st.selectbox(
                "Chat model",
                MODELS,
                key="chat_model",
                help="Select the LLM model for the chat",
                # on_change=update_parameters_change,
            )

            # data_source = st.selectbox(
            #     "dataset",
            #     DATASETS,
            #     key="data_source",
            #     help="Select the private data_source for the chat",
            #     on_change=update_parameters_change,
            # )

            st.session_state.dataset = "DEFAULT"

            show_source = st.checkbox(
                label="show source files",
                key="show_source",
                help="Select this to show relavant source files for the query",
                # on_change=update_parameters_change,
            )

            submitted = st.form_submit_button(
                "Save Parameters",
                # on_click=update_parameters_change
            )

            if submitted:
                parameters_change_button(chat_model, show_source)

        # if st.session_state.is_parameters_changed:
        #     st.button("Update",
        #         on_click=parameters_change_button,
        #         args=[chat_model, show_source]
        #     )

        st.markdown("\n")

        # if st.button("Create FAISS db"):
        #     try:
        #         with st.spinner('creating faiss vector store'):
        #             create_faiss()
        #             st.success('faiss saved')
        #     except Exception as e:
        #         st.error(f"Error : {e}")#, icon=":books:")
        #         return

        st.markdown(
            "### How to use\n"
            "1. Select the chat model\n"  # noqa: E501
            "2. Select \"show source files\" to show the source files related to the answer.📄\n"
            "3. Ask a question about the documents💬\n"
        )


def chat_body():
    st.header("Chat with your own data:")
    with st.form('chat_body'):

        user_question = st.text_input(
            "Ask a question about your documents:",
            placeholder="enter question",
            key='user_question',
            # on_change=submit_user_question,
        )

        submitted = st.form_submit_button(
            "Submit",
            # on_click=update_parameters_change
        )

        if submitted:
            submit_user_question()

    # if user_question:
    #     submit_user_question()
    #     # user_question = False


def submit_user_question():
    with st.spinner("Processing"):
        user_question = st.session_state.user_question
        # st.success(user_question)
        handle_userinput(user_question)
        # st.session_state.user_question=''


def main():

    initialize_session_state()

    side_bar()

    chat_body()


def update_parameters_change():
    st.session_state.is_parameters_changed = True


def parameters_change_button(chat_model, show_source):
    st.session_state.model = chat_model
    st.session_state.dataset = "DEFAULT"
    st.session_state.show_source_files = show_source
    st.session_state.is_parameters_changed = False

    alert = st.success("chat parameters updated")
    time.sleep(1)  # Wait for 1 second
    alert.empty()  # Clear the alert

@st.cache_data
def get_answer_from_backend(query, model, dataset):
    # response = qaPipeline.run(query=query, model=model, dataset=dataset)
    response = qaPipeline.run_agent(query=query, model=model, dataset=dataset)
    return response


def show_query_response(query, response, show_source_files):
    docs = []
    if isinstance(response, dict):
        answer, docs = response['answer'], response['source_documents']
    else:
        answer = response

    st.write(user_template.replace(
        "{{MSG}}", query), unsafe_allow_html=True)
    st.write(bot_template.replace(
        "{{MSG}}", answer), unsafe_allow_html=True)

    if show_source_files:
        # st.write(source_template.replace(
        #     "{{MSG}}", "source files" ), unsafe_allow_html=True)

        if len(docs) > 0:
            st.markdown("#### source files : ")
            for source in docs:
                # st.info(source.metadata)
                with st.expander(source.metadata["source"]):
                    st.markdown(source.page_content)

        # st.write(response)


def is_query_valid(query: str) -> bool:
    if (not query) or (query.strip() == ''):
        st.error("Please enter a question!")
        return False
    return True


def handle_userinput(query):
    # Get the answer from the chain
    try:
        if not is_query_valid(query):
            st.stop()

        model = MODELS[st.session_state.model]
        dataset = DATASETS[st.session_state.dataset]
        show_source_files = st.session_state.show_source_files

        # Try to access openai and deeplake
        print(f">\n model: {model} \n dataset : {dataset} \n show_source_files : {show_source_files}")

        response = get_answer_from_backend(query, model, dataset)

        show_query_response(query, response, show_source_files)

    except Exception as e:
        # logger.error(f"Answer retrieval failed with {e}")
        st.error(f"Error occurred! see log info for more details.")#, icon=":books:")
        print(f"Streamlit handle_userinput Error : {e}")#, icon=":books:")
        return


if __name__ == "__main__":
    main()

    # initialize_session_state()

    # side_bar()

    # chat_body()
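app_agent.py is a snapshot of the previous app.py, still wired to the agent pipeline (qaPipeline.QAPipeline plus qaPipeline_functions and create_faiss). One behavior shared with the new app.py: get_answer_from_backend is wrapped in @st.cache_data, so Streamlit memoizes responses by the (query, model, dataset) arguments. A condensed view with the caching consequence spelled out:

import streamlit as st

@st.cache_data
def get_answer_from_backend(query, model, dataset):
    # A repeated (query, model, dataset) triple returns the cached response
    # without re-running the pipeline; note the cache key does not include
    # the chain's conversation memory.
    return qaPipeline.run_agent(query=query, model=model, dataset=dataset)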
config.py
CHANGED
@@ -2,14 +2,17 @@ MODELS={
     "DEFAULT":"tiiuae/falcon-7b-instruct",
     # "gpt4all":"gpt4all",
     # "flan-t5-xxl":"google/flan-t5-xxl",
-    "falcon-7b-instruct":"tiiuae/falcon-7b-instruct",
+    "hf/falcon-7b-instruct":"tiiuae/falcon-7b-instruct",
+
     "anyscale/Llama-2-13b":"anyscale/Llama-2-13b-chat-hf",
     "anyscale/Llama-2-70b":"anyscale/Llama-2-70b-chat-hf",
-
+
+    "local/Llama-2-13b":"local/LLAMA2",
+
+    "openai/gpt-3.5":"openai",
     # "Deci/DeciLM-6b-instruct":"Deci/DeciLM-6b-instruct",
     # "Deci/DeciLM-6b":"Deci/DeciLM-6b",
-
-
+
 }
 
 DATASETS={
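The new entries follow a provider/name convention for the sidebar labels; the label picked in the UI resolves through MODELS to the backend id that set_model() in the pipeline matches on. The lookup path with values from this diff:

MODELS = {
    "DEFAULT": "tiiuae/falcon-7b-instruct",
    "hf/falcon-7b-instruct": "tiiuae/falcon-7b-instruct",
    "anyscale/Llama-2-13b": "anyscale/Llama-2-13b-chat-hf",
    "anyscale/Llama-2-70b": "anyscale/Llama-2-70b-chat-hf",
    "local/Llama-2-13b": "local/LLAMA2",
    "openai/gpt-3.5": "openai",
}

# sidebar label -> backend id -> set_model() match arm
assert MODELS["openai/gpt-3.5"] == "openai"           # ChatOpenAI("gpt-3.5-turbo")
assert MODELS["local/Llama-2-13b"] == "local/LLAMA2"  # get_local_LLAMA2()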
qaPipeline.py
CHANGED
@@ -309,12 +309,15 @@ class QAPipeline:
             general_qa_chain_tool
         ]
 
-        prefix = """
-        suffix = """Begin!
-
-        {chat_history}
-
-
+        prefix = """<<SYS>> You are the AI of company boardpac which provide services to company board members related to banking and financial sector. Have a conversation with the user, answering the following questions as best you can. You have access to the following tools:"""
+        suffix = """Begin! "
+        {agent_scratchpad}
+        <chat history>: {chat_history}
+        <</SYS>>
+
+        [INST]
+        <Question>: {question}
+        [/INST]"""
 
         agent_prompt = ZeroShotAgent.create_prompt(
             tools,
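For context, ZeroShotAgent.create_prompt joins the prefix, the rendered tool descriptions, the ReAct format instructions, and the suffix into one PromptTemplate, so the Llama-style <<SYS>>/[INST] markers added here bracket the entire agent prompt. A runnable sketch of that assembly (the tool and the shortened prefix are stand-ins; input_variables inferred from the placeholders in this hunk):

from langchain.agents import ZeroShotAgent, Tool

tools = [Tool(name="general QA",
              func=lambda q: q,  # stand-in for the repo's general_qa_chain_tool
              description="answer general questions")]

prefix = """<<SYS>> You are the AI of company boardpac. You have access to the following tools:"""
suffix = """Begin! "
{agent_scratchpad}
<chat history>: {chat_history}
<</SYS>>

[INST]
<Question>: {question}
[/INST]"""

agent_prompt = ZeroShotAgent.create_prompt(
    tools,
    prefix=prefix,
    suffix=suffix,
    input_variables=["question", "chat_history", "agent_scratchpad"],
)
print(agent_prompt.template)  # prefix + tool list + format instructions + suffix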
qaPipeline_chain_only.py
ADDED
@@ -0,0 +1,241 @@

"""
Python Backend API to chat with private data

08/14/2023
D.M. Theekshana Samaradiwakara
"""

import os
import time

from dotenv import load_dotenv

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

from langchain.llms import GPT4All
from langchain.llms import HuggingFaceHub
from langchain.chat_models import ChatOpenAI
from langchain.chat_models import ChatAnyscale

# from langchain.retrievers.self_query.base import SelfQueryRetriever
# from langchain.chains.query_constructor.base import AttributeInfo

# from chromaDb import load_store
from faissDb import load_FAISS_store


from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain, ConversationalRetrievalChain
from conversationBufferWindowMemory import ConversationBufferWindowMemory

load_dotenv()

#gpt4 all model
gpt4all_model_path = os.environ.get('GPT4ALL_MODEL_PATH')
model_n_ctx = os.environ.get('MODEL_N_CTX')
model_n_batch = int(os.environ.get('MODEL_N_BATCH',8))
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS',4))

openai_api_key = os.environ.get('OPENAI_API_KEY')
anyscale_api_key = os.environ.get('ANYSCALE_ENDPOINT_TOKEN')

verbose = os.environ.get('VERBOSE')

# activate/deactivate the streaming StdOut callback for LLMs
callbacks = [StreamingStdOutCallbackHandler()]

def get_local_LLAMA2():
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-13b-chat-hf",
        # use_auth_token=True,
    )

    model = AutoModelForCausalLM.from_pretrained("NousResearch/Llama-2-13b-chat-hf",
        device_map='auto',
        torch_dtype=torch.float16,
        use_auth_token=True,
        # load_in_8bit=True,
        # load_in_4bit=True
    )
    from transformers import pipeline

    pipe = pipeline("text-generation",
        model=model,
        tokenizer=tokenizer,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        max_new_tokens=512,
        do_sample=True,
        top_k=30,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id
    )

    from langchain import HuggingFacePipeline
    LLAMA2 = HuggingFacePipeline(pipeline=pipe, model_kwargs={'temperature': 0})
    print(f"\n\n> torch.cuda.is_available(): {torch.cuda.is_available()}")
    print("\n\n> local LLAMA2 loaded")
    return LLAMA2

class QAPipeline:

    def __init__(self):

        print("\n\n> Initializing QAPipeline:")
        self.llm_name = None
        self.llm = None

        self.dataset_name = None
        self.vectorstore = None

        self.qa_chain = None

    def run_agent(self, query, model, dataset):

        try:
            if (self.llm_name != model) or (self.dataset_name != dataset) or (self.qa_chain == None):
                self.set_model(model)
                self.set_vectorstore(dataset)
                self.set_qa_chain()

            # Get the answer from the chain
            start = time.time()
            res = self.qa_chain(query)
            # answer, docs = res['result'],res['source_documents']
            end = time.time()

            # Print the result
            print("\n\n> Question:")
            print(query)
            print(f"\n> Answer (took {round(end - start, 2)} s.):")
            print(res)

            return res

        except Exception as e:
            # logger.error(f"Answer retrieval failed with {e}")
            print(f"> QAPipeline run_agent Error : {e}")#, icon=":books:")
            return


    def set_model(self, model_type):
        if model_type != self.llm_name:
            match model_type:
                case "gpt4all":
                    # self.llm = GPT4All(model=gpt4all_model_path, n_ctx=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=verbose)
                    self.llm = GPT4All(model=gpt4all_model_path, max_tokens=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=verbose)
                    # self.llm = HuggingFaceHub(repo_id="nomic-ai/gpt4all-j", model_kwargs={"temperature":0.001, "max_length":1024})
                case "google/flan-t5-xxl":
                    self.llm = HuggingFaceHub(repo_id="google/flan-t5-xxl", model_kwargs={"temperature":0.001, "max_length":1024})
                case "tiiuae/falcon-7b-instruct":
                    self.llm = HuggingFaceHub(repo_id=model_type, model_kwargs={"temperature":0.001, "max_length":1024})
                case "openai":
                    self.llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
                case "Deci/DeciLM-6b-instruct":
                    self.llm = ChatOpenAI(model_name="Deci/DeciLM-6b-instruct", temperature=0)
                case "Deci/DeciLM-6b":
                    self.llm = ChatOpenAI(model_name="Deci/DeciLM-6b", temperature=0)
                case "local/LLAMA2":
                    self.llm = get_local_LLAMA2()
                case "anyscale/Llama-2-13b-chat-hf":
                    self.llm = ChatAnyscale(anyscale_api_key=anyscale_api_key, temperature=0, model_name='meta-llama/Llama-2-13b-chat-hf', streaming=False)
                case "anyscale/Llama-2-70b-chat-hf":
                    self.llm = ChatAnyscale(anyscale_api_key=anyscale_api_key, temperature=0, model_name='meta-llama/Llama-2-70b-chat-hf', streaming=False)
                case _default:
                    # raise exception if model_type is not supported
                    raise Exception(f"Model type {model_type} is not supported. Please choose a valid one")

            self.llm_name = model_type


    def set_vectorstore(self, dataset):
        if dataset != self.dataset_name:
            # self.vectorstore = load_store(dataset)
            self.vectorstore = load_FAISS_store()
            print("\n\n> vectorstore loaded:")
            self.dataset_name = dataset


    def set_qa_chain(self):

        try:
            memory = ConversationBufferWindowMemory(
                memory_key="chat_history",
                input_key="question",
                output_key="answer",
                return_messages=True,
                k=3
            )

            # Define a custom prompt
            B_INST, E_INST = "[INST]", "[/INST]"
            B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

            retrieval_qa_template = (
                """<<SYS>>
                You are the AI assistant of company boardpac which provide services to company board members related to banking and financial sector.
                You have 2 tasks to do.

                Task 1: combine the given chat history and user question to come up with a follow-up question
                <chat history>: {chat_history}

                Task 2:
                Identify the type of the follow-up question using following 3 types and answer accordingly.
                Answer should be short and simple as possible.
                Dont add any extra details that is not mentioned in the context.

                <Type 1>
                If the user asks questions like welcome messages, greetings and goodbyes.
                Just reply accordingly with a short and simple answer as possible.
                Dont use context information provided below to answer the question.
                Start the answer with code word Boardpac AI(chat):
                </Type 1>

                <Type 2>
                If the follow-up question doesn't belong to type 1 or type 3, that means if the question is not about greetings or Banking and Financial Services say that the question is out of your domain.
                Start the answer with code word Boardpac AI(OD):
                </Type 2>

                <Type 3>
                If the follow-up question is related to Banking and Financial Services Sector like Banking & Financial regulations, legal framework, governance framework, compliance requirements as per Central Bank regulations.
                please answer the question based on the context information provided in bottom containing few related documents of central bank acts published in various years.
                The published year is mentioned as the metadata 'year' of each source document.
                The content of a bank act of a past year can updated by a bank act from a latest year.
                Always try to answer with latest information and mention the year which information extracted.
                If you dont know the answer say you dont know, dont try to makeup answers.
                Start the answer with code word Boardpac AI(QA):
                </Type 3>

                <Context information>: {context}
                <</SYS>>

                [INST]Question : {question}[/INST]"""
            )

            retrieval_qa_chain_prompt = PromptTemplate(
                input_variables=["question", "context", "chat_history"],
                template=retrieval_qa_template
            )

            self.qa_chain = ConversationalRetrievalChain.from_llm(
                llm=self.llm,
                chain_type="stuff",
                retriever=self.vectorstore.as_retriever(),
                # retriever = self.vectorstore.as_retriever(search_kwargs={"k": target_source_chunks}
                return_source_documents=True,
                get_chat_history=lambda h: h,
                combine_docs_chain_kwargs={"prompt": retrieval_qa_chain_prompt},
                verbose=True,
                memory=memory,
            )

            print(f"\n> agent_chain created")

        except Exception as e:
            # logger.error(f"Answer retrieval failed with {e}")
            print(f"> QAPipeline set_qa_chain_with_agent Error : {e}")#, icon=":books:")
            return
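A hypothetical end-to-end driver for the new pipeline, mirroring what app.py does through Streamlit (the model and dataset values come from config.py, and the FAISS index must already exist for load_FAISS_store()):

from qaPipeline_chain_only import QAPipeline

pipeline = QAPipeline()
res = pipeline.run_agent(
    query="What does the Banking Act say about board composition?",
    model="openai",     # a MODELS value; selects ChatOpenAI in set_model()
    dataset="DEFAULT",  # any value; set_vectorstore() always loads FAISS
)
if isinstance(res, dict):
    print(res["answer"])
    for doc in res["source_documents"]:
        print(doc.metadata["source"])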