Spaces:
Runtime error
Runtime error
Boardpac/theekshanas
committed on
Commit
•
3715d20
1
Parent(s):
c9744f1
agent implementation
Browse files- __pycache__/qaPipeline.cpython-311.pyc +0 -0
- app.py +7 -2
- faiss_index/dummy +0 -0
- qaPipeline.py +96 -1
__pycache__/qaPipeline.cpython-311.pyc
CHANGED
Binary files a/__pycache__/qaPipeline.cpython-311.pyc and b/__pycache__/qaPipeline.cpython-311.pyc differ
|
|
app.py
CHANGED
@@ -132,11 +132,16 @@ def parameters_change_button(chat_model, show_source):
|
|
132 |
|
133 |
|
134 |
def get_answer_from_backend(query, model, dataset):
|
135 |
-
response = qaPipeline.run(query=query, model=model, dataset=dataset)
|
|
|
136 |
return response
|
137 |
|
138 |
def show_query_response(query, response, show_source_files):
|
139 |
-
|
|
|
|
|
|
|
|
|
140 |
|
141 |
st.write(user_template.replace(
|
142 |
"{{MSG}}", query), unsafe_allow_html=True)
|
|
|
132 |
|
133 |
|
134 |
def get_answer_from_backend(query, model, dataset):
    """Forward the user's query to the backend agent pipeline and return its response.

    The response may be a plain string or a dict with 'result' and
    'source_documents' keys, depending on which tool the agent used.
    """
    # Routed through the agent entry point (replaces the earlier qaPipeline.run call).
    return qaPipeline.run_agent(query=query, model=model, dataset=dataset)
|
138 |
|
139 |
def show_query_response(query, response, show_source_files):
|
140 |
+
docs = []
|
141 |
+
if isinstance(response, dict):
|
142 |
+
answer, docs = response['result'], response['source_documents']
|
143 |
+
else:
|
144 |
+
answer = response
|
145 |
|
146 |
st.write(user_template.replace(
|
147 |
"{{MSG}}", query), unsafe_allow_html=True)
|
faiss_index/dummy
DELETED
File without changes
|
qaPipeline.py
CHANGED
@@ -23,6 +23,11 @@ from langchain.chat_models import ChatOpenAI
|
|
23 |
# from chromaDb import load_store
|
24 |
from faissDb import load_FAISS_store
|
25 |
|
|
|
|
|
|
|
|
|
|
|
26 |
load_dotenv()
|
27 |
|
28 |
#gpt4 all model
|
@@ -49,6 +54,7 @@ class QAPipeline:
|
|
49 |
self.vectorstore = None
|
50 |
|
51 |
self.qa_chain = None
|
|
|
52 |
|
53 |
def run(self,query, model, dataset):
|
54 |
|
@@ -70,6 +76,28 @@ class QAPipeline:
|
|
70 |
print( res)
|
71 |
|
72 |
return res
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
|
75 |
def set_model(self,model_type):
|
@@ -107,4 +135,71 @@ class QAPipeline:
|
|
107 |
# retriever = self.vectorstore.as_retriever(search_kwargs={"k": target_source_chunks}
|
108 |
return_source_documents= True
|
109 |
)
|
110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
# from chromaDb import load_store
|
24 |
from faissDb import load_FAISS_store
|
25 |
|
26 |
+
from langchain.agents import initialize_agent, Tool
|
27 |
+
from langchain.agents import AgentType
|
28 |
+
from langchain.prompts import PromptTemplate
|
29 |
+
from langchain.chains import LLMChain
|
30 |
+
|
31 |
load_dotenv()
|
32 |
|
33 |
#gpt4 all model
|
|
|
54 |
self.vectorstore = None
|
55 |
|
56 |
self.qa_chain = None
|
57 |
+
self.agent = None
|
58 |
|
59 |
def run(self,query, model, dataset):
|
60 |
|
|
|
76 |
print( res)
|
77 |
|
78 |
return res
|
79 |
+
|
80 |
+
def run_agent(self, query, model, dataset):
    """Answer ``query`` via the LangChain agent, rebuilding the pipeline first
    when the requested model/dataset differ from the current ones or no agent
    has been built yet.

    Args:
        query: user question, passed straight to the agent.
        model: model identifier; compared against ``self.llm_name``.
        dataset: dataset identifier; compared against ``self.dataset_name``.

    Returns:
        The agent's final answer string (``res["output"]``).
    """
    # Rebuild model, vectorstore and agent only when something relevant changed.
    # Fixed: compare to None with `is None`, not `== None` (PEP 8 idiom).
    if (self.llm_name != model) or (self.dataset_name != dataset) or (self.agent is None):
        self.set_model(model)
        self.set_vectorstore(dataset)
        self.set_qa_chain_with_agent()

    # Time the agent call so the latency can be reported below.
    start = time.time()
    res = self.agent(query)
    end = time.time()

    # Server-side debug output mirroring the `run` method's logging style.
    print("\n\n> Question:")
    print(query)
    print(f"\n> Answer (took {round(end - start, 2)} s.):")
    print(res)

    return res["output"]
|
100 |
+
|
101 |
|
102 |
|
103 |
def set_model(self,model_type):
|
|
|
135 |
# retriever = self.vectorstore.as_retriever(search_kwargs={"k": target_source_chunks}
|
136 |
return_source_documents= True
|
137 |
)
|
138 |
+
|
139 |
+
|
140 |
+
def set_qa_chain_with_agent(self):
    """Build the agent used by ``run_agent``.

    Wires together two tools — a plain-conversation LLM chain and a
    retrieval QA chain over the FAISS vectorstore — and stores the
    resulting zero-shot ReAct agent on ``self.agent``.
    """
    # Prompt for general conversation (greetings / general banking questions).
    general_qa_template = (
        """You are the AI assistant of the Boardpac company which provide services for companies board members.
        You can have a general conversation with the users like greetings.
        But only answer questions related to banking sector like financial and legal.
        If you dont know the answer say you dont know, dont try to makeup answers.
        each answer should start with code word BoardPac Conversation AI:
        Question: {question}
        """
    )
    general_prompt = PromptTemplate.from_template(general_qa_template)
    general_qa_chain = LLMChain(llm=self.llm, prompt=general_prompt)

    # Prompt for retrieval over central-bank acts; {context} is filled in
    # by RetrievalQA from the documents the retriever returns.
    retrieval_qa_template = (
        """You are the AI assistant of the Boardpac company which provide services for companies board members.
        You have provided context information below related to central bank acts published in various years. The content of a bank act can updated by a bank act from a latest year.
        {context}
        Given this information, please answer the question with the latest information.
        If you dont know the answer say you dont know, dont try to makeup answers.
        each answer should start with code word BoardPac Retrieval AI:
        Question: {question}
        """
    )
    retrieval_prompt = PromptTemplate.from_template(retrieval_qa_template)

    # Retrieval chain: returns a dict with 'result' and 'source_documents'
    # so the UI can show the sources it drew from.
    bank_regulations_qa = RetrievalQA.from_chain_type(
        llm=self.llm,
        chain_type="stuff",
        retriever=self.vectorstore.as_retriever(),
        # retriever = self.vectorstore.as_retriever(search_kwargs={"k": target_source_chunks}
        return_source_documents=True,
        input_key="question",
        chain_type_kwargs={"prompt": retrieval_prompt},
    )

    # Both tools return their result directly (return_direct=True), so the
    # agent does one tool call at most and hands the output straight back.
    regulations_tool = Tool(
        name="bank regulations",
        func=lambda query: bank_regulations_qa({"question": query}),
        description='''useful for when you need to answer questions about
        financial and legal information issued from central bank regarding banks and bank regulations.
        Input should be a fully formed question.''',
        return_direct=True,
    )
    conversation_tool = Tool(
        name="general qa",
        func=general_qa_chain.run,
        description='''useful for when you need to have a general conversation with the users like greetings
        or to answer general purpose questions related to banking sector like financial and legal.
        Input should be a fully formed question.''',
        return_direct=True,
    )

    self.agent = initialize_agent(
        [regulations_tool, conversation_tool],
        self.llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        max_iterations=3,
    )
|