""" | |
Python Backend API to chat with private data | |
08/14/2023 | |
D.M. Theekshana Samaradiwakara | |
""" | |
import os
import time

from dotenv import load_dotenv

from langchain.chains import RetrievalQA
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import GPT4All
from langchain.llms import HuggingFaceHub
from langchain.chat_models import ChatOpenAI

# from langchain.retrievers.self_query.base import SelfQueryRetriever
# from langchain.chains.query_constructor.base import AttributeInfo

# from chromaDb import load_store
from faissDb import load_FAISS_store

from langchain.agents import initialize_agent, Tool
from langchain.agents import AgentType
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
load_dotenv()

# gpt4all model
gpt4all_model_path = os.environ.get('GPT4ALL_MODEL_PATH')
model_n_ctx = int(os.environ.get('MODEL_N_CTX', 1000))  # context window size; 1000 is an assumed fallback
model_n_batch = int(os.environ.get('MODEL_N_BATCH', 8))
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4))
openai_api_key = os.environ.get('OPENAI_API_KEY')
verbose = os.environ.get('VERBOSE', 'False').lower() in ('true', '1')  # parse the flag as a boolean

# activate/deactivate the streaming StdOut callback for LLMs
callbacks = [StreamingStdOutCallbackHandler()]
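# A minimal example .env for the variables read above (a sketch; paths and keys
# are placeholders, not the project's actual values):
#
#   GPT4ALL_MODEL_PATH=models/<your-gpt4all-model>.bin
#   MODEL_N_CTX=1000
#   MODEL_N_BATCH=8
#   TARGET_SOURCE_CHUNKS=4
#   OPENAI_API_KEY=<your-openai-key>
#   HUGGINGFACEHUB_API_TOKEN=<your-hf-token>   # used by the HuggingFaceHub models
#   VERBOSE=False
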
class QAPipeline:
    """Holds the currently loaded LLM, vector store, QA chain and agent."""

    def __init__(self):
        self.llm_name = None
        self.llm = None
        self.dataset_name = None
        self.vectorstore = None
        self.qa_chain = None
        self.agent = None
    def run(self, query, model, dataset):
        # Rebuild the LLM, vector store and chain only when something changed
        if (self.llm_name != model) or (self.dataset_name != dataset) or (self.qa_chain is None):
            self.set_model(model)
            self.set_vectorstore(dataset)
            self.set_qa_chain()

        # Get the answer from the chain
        start = time.time()
        res = self.qa_chain(query)
        # answer, docs = res['result'], res['source_documents']
        end = time.time()

        # Print the result
        print("\n\n> Question:")
        print(query)
        print(f"\n> Answer (took {round(end - start, 2)} s.):")
        print(res)

        return res
    def run_agent(self, query, model, dataset):
        # Rebuild the LLM, vector store and agent only when something changed
        if (self.llm_name != model) or (self.dataset_name != dataset) or (self.agent is None):
            self.set_model(model)
            self.set_vectorstore(dataset)
            self.set_qa_chain_with_agent()

        # Get the answer from the agent
        start = time.time()
        res = self.agent(query)
        # answer, docs = res['result'], res['source_documents']
        end = time.time()

        # Print the result
        print("\n\n> Question:")
        print(query)
        print(f"\n> Answer (took {round(end - start, 2)} s.):")
        print(res)

        return res["output"]
    def set_model(self, model_type):
        if model_type != self.llm_name:
            match model_type:
                case "gpt4all":
                    # self.llm = GPT4All(model=gpt4all_model_path, n_ctx=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=verbose)
                    self.llm = GPT4All(model=gpt4all_model_path, max_tokens=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=verbose)
                    # self.llm = HuggingFaceHub(repo_id="nomic-ai/gpt4all-j", model_kwargs={"temperature": 0.001, "max_length": 1024})
                case "google/flan-t5-xxl":
                    self.llm = HuggingFaceHub(repo_id="google/flan-t5-xxl", model_kwargs={"temperature": 0.001, "max_length": 1024})
                case "tiiuae/falcon-7b-instruct":
                    self.llm = HuggingFaceHub(repo_id=model_type, model_kwargs={"temperature": 0.001, "max_length": 1024})
                case "openai":
                    self.llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
                case _:
                    # raise an exception if model_type is not supported
                    raise Exception(f"Model type {model_type} is not supported. Please choose a valid one.")

            self.llm_name = model_type
    def set_vectorstore(self, dataset):
        if dataset != self.dataset_name:
            # self.vectorstore = load_store(dataset)
            # NOTE: the same FAISS store is loaded regardless of the requested dataset
            self.vectorstore = load_FAISS_store()
            print("\n\n> vectorstore loaded:")

            self.dataset_name = dataset
    def set_qa_chain(self):
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.vectorstore.as_retriever(),
            # retriever=self.vectorstore.as_retriever(search_kwargs={"k": target_source_chunks}),
            return_source_documents=True,
        )
    def set_qa_chain_with_agent(self):
        # Define a custom prompt for general conversation
        general_qa_template = (
            """You are the AI assistant of the Boardpac company, which provides services for company board members.
            You can have a general conversation with the users, like greetings.
            But only answer questions related to the banking sector, like financial and legal matters.
            If you don't know the answer, say you don't know; don't try to make up answers.
            Each answer should start with the code word BoardPac Conversation AI:

            Question: {question}
            """
        )

        general_qa_chain_prompt = PromptTemplate.from_template(general_qa_template)
        general_qa_chain = LLMChain(llm=self.llm, prompt=general_qa_chain_prompt)
        # Define a custom prompt for retrieval over the central bank acts
        retrieval_qa_template = (
            """You are the AI assistant of the Boardpac company, which provides services for company board members.
            You are provided with context information below related to central bank acts published in various years. The content of a bank act can be updated by a bank act from a later year.

            {context}

            Given this information, please answer the question with the latest information.
            If you don't know the answer, say you don't know; don't try to make up answers.
            Each answer should start with the code word BoardPac Retrieval AI:

            Question: {question}
            """
        )

        retrieval_qa_chain_prompt = PromptTemplate.from_template(retrieval_qa_template)

        bank_regulations_qa = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.vectorstore.as_retriever(),
            # retriever=self.vectorstore.as_retriever(search_kwargs={"k": target_source_chunks}),
            return_source_documents=True,
            input_key="question",
            chain_type_kwargs={"prompt": retrieval_qa_chain_prompt},
        )
        tools = [
            Tool(
                name="bank regulations",
                func=lambda query: bank_regulations_qa({"question": query}),
                description='''useful for when you need to answer questions about
                financial and legal information issued by the central bank regarding banks and bank regulations.
                Input should be a fully formed question.''',
                return_direct=True,
            ),
            Tool(
                name="general qa",
                func=general_qa_chain.run,
                description='''useful for when you need to have a general conversation with the users, like greetings,
                or to answer general-purpose questions related to the banking sector, like financial and legal matters.
                Input should be a fully formed question.''',
                return_direct=True,
            ),
        ]
        self.agent = initialize_agent(
            tools,
            self.llm,
            agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
            verbose=True,
            max_iterations=3,
        )
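
# A minimal usage sketch (an addition for illustration, not part of the original
# backend). The model and dataset names below are placeholders; valid model names
# are the ones handled in set_model().
if __name__ == "__main__":
    pipeline = QAPipeline()
    answer = pipeline.run_agent(
        query="What are the latest central bank regulations on capital adequacy?",
        model="openai",
        dataset="default",
    )
    print(answer)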