Spaces:
Runtime error
Runtime error
""" | |
Python Backend API to chat with private data | |
08/14/2023 | |
D.M. Theekshana Samaradiwakara | |
""" | |
import os | |
import time | |
from dotenv import load_dotenv | |
from langchain.chains import RetrievalQA | |
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler | |
from langchain.llms import GPT4All | |
from langchain.llms import HuggingFaceHub | |
from langchain.chat_models import ChatOpenAI | |
# from langchain.retrievers._query.base import SelfQueryRetriever | |
# from langchain.chains.query_constructor.base import AttributeInfo | |
# from chromaDb import load_store | |
from faissDb import load_FAISS_store | |
from langchain.agents import ZeroShotAgent, Tool, AgentExecutor | |
from langchain.prompts import PromptTemplate | |
from langchain.chains import LLMChain, ConversationalRetrievalChain | |
from conversationBufferWindowMemory import ConversationBufferWindowMemory | |
from langchain.memory import ReadOnlySharedMemory | |
load_dotenv()

# gpt4all model
gpt4all_model_path = os.environ.get('GPT4ALL_MODEL_PATH')
model_n_ctx = os.environ.get('MODEL_N_CTX')
model_n_batch = int(os.environ.get('MODEL_N_BATCH', 8))
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4))
openai_api_key = os.environ.get('OPENAI_API_KEY')
verbose = os.environ.get('VERBOSE')

# activate/deactivate the streaming StdOut callback for LLMs
callbacks = [StreamingStdOutCallbackHandler()]

# Shared sliding-window conversation memory (last k=3 exchanges). The agent
# executor is given this object directly; the tools only get the read-only
# wrapper below so they cannot modify the history.
memory = ConversationBufferWindowMemory(
    memory_key="chat_history",
    input_key="question",
    return_messages=True,
    k=3,
)
readonlymemory = ReadOnlySharedMemory(memory=memory)

print("\n\n> Initializing QAPipeline:")

# Pipeline state, filled lazily by set_model / set_vectorstore / set_qa_chain*.
# These are real ``None`` sentinels; a ``global`` statement at module level is
# a no-op, so plain assignments are all that is needed here.
llm_name = None
llm = None
dataset_name = None
vectorstore = None
qa_chain = None
agent = None
# (model, dataset) pair that the current ``qa_chain`` was built for. It is
# tracked separately from the shared ``llm_name`` because run_agent() may
# replace the shared LLM without rebuilding ``qa_chain``.
_qa_chain_config = None


def run(query, model, dataset):
    """Answer ``query`` with the RetrievalQA chain for ``model`` and ``dataset``.

    The LLM, vector store and chain are (re)built whenever no chain exists yet
    or the existing one was built for a different model/dataset pair.

    Args:
        query: The question to ask.
        model: Model identifier accepted by ``set_model``.
        dataset: Dataset identifier passed to ``set_vectorstore``.

    Returns:
        Whatever calling the chain returns. It is also printed together with
        the elapsed time.
    """
    global _qa_chain_config
    config = (model, dataset)
    if qa_chain is None or _qa_chain_config != config:
        set_model(model)
        set_vectorstore(dataset)
        set_qa_chain()
        # Only recorded once every step succeeded; on an exception the old
        # chain keeps its old (still accurate) configuration.
        _qa_chain_config = config

    # Get the answer from the chain
    start = time.perf_counter()
    res = qa_chain(query)
    elapsed = time.perf_counter() - start

    print("\n\n> Question:")
    print(query)
    print(f"\n> Answer (took {round(elapsed, 2)} s.):")
    print(res)
    return res
# (model, dataset) pair that the current ``agent`` was built for. It is
# tracked separately from the shared ``llm_name`` because run() may replace
# the shared LLM without rebuilding the agent.
_agent_config = None


def run_agent(query, model, dataset):
    """Answer ``query`` with the conversational agent for ``model`` and ``dataset``.

    The LLM, vector store and agent are (re)built whenever no agent exists yet
    or the existing one was built for a different model/dataset pair.

    Args:
        query: The question to ask.
        model: Model identifier accepted by ``set_model``.
        dataset: Dataset identifier passed to ``set_vectorstore``.

    Returns:
        ``res["output"]`` from the agent on success. On failure the error is
        printed and ``None`` is returned; nothing is raised.
    """
    global _agent_config
    try:
        config = (model, dataset)
        if agent is None or _agent_config != config:
            previous_agent = agent
            set_model(model)
            set_vectorstore(dataset)
            set_qa_chain_with_agent()
            # set_qa_chain_with_agent() prints and swallows its own errors,
            # leaving the old (or no) agent in place. Do not answer with an
            # agent that belongs to another configuration.
            if agent is None or agent is previous_agent:
                raise RuntimeError("agent could not be built")
            _agent_config = config

        # Get the answer from the agent
        start = time.perf_counter()
        res = agent(query)
        elapsed = time.perf_counter() - start

        print("\n\n> Question:")
        print(query)
        print(f"\n> Answer (took {round(elapsed, 2)} s.):")
        print(res)
        return res["output"]
    except Exception as e:
        # Deliberate best-effort boundary: callers receive None on failure.
        print(f"> QAPipeline run_agent Error : {e}")
        return None
def set_model(model_type): | |
if model_type != llm_name: | |
global llm | |
match model_type: | |
case "gpt4all": | |
# llm = GPT4All(model=gpt4all_model_path, n_ctx=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=verbose) | |
llm = GPT4All(model=gpt4all_model_path, max_tokens=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=verbose) | |
# llm = HuggingFaceHub(repo_id="nomic-ai/gpt4all-j", model_kwargs={"temperature":0.001, "max_length":1024}) | |
case "google/flan-t5-xxl": | |
llm = HuggingFaceHub(repo_id="google/flan-t5-xxl", model_kwargs={"temperature":0.001, "max_length":1024}) | |
case "tiiuae/falcon-7b-instruct": | |
llm = HuggingFaceHub(repo_id=model_type, model_kwargs={"temperature":0.001, "max_length":1024}) | |
case "openai": | |
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0) | |
case _default: | |
# raise exception if model_type is not supported | |
raise Exception(f"Model type {model_type} is not supported. Please choose a valid one") | |
# global llm_name | |
llm_name = model_type | |
def set_vectorstore(dataset):
    """Load the vector store for ``dataset`` into the global ``vectorstore``.

    The store is reloaded only when ``dataset`` differs from the dataset that
    is currently loaded.

    NOTE(review): ``load_FAISS_store()`` is called without ``dataset``, so
    every dataset currently maps to the same store, and ``dataset`` only acts
    as a cache key. Confirm whether per-dataset stores are intended.
    """
    # Both names are assigned below. Without this declaration Python treats
    # them as locals, and the comparison raises UnboundLocalError.
    global vectorstore, dataset_name
    if dataset == dataset_name:
        return
    vectorstore = load_FAISS_store()
    print("\n\n> vectorstore loaded:")
    dataset_name = dataset
def set_qa_chain():
    """(Re)build the global ``qa_chain`` from the current ``llm`` and ``vectorstore``.

    The chain uses the "stuff" combine strategy and is configured to return
    the retrieved source documents along with the answer.
    """
    global qa_chain
    # TODO(review): ``target_source_chunks`` (TARGET_SOURCE_CHUNKS) is read at
    # module level but not applied here. Passing
    # ``search_kwargs={"k": target_source_chunks}`` to as_retriever would use it.
    retriever = vectorstore.as_retriever()
    chain_options = dict(
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
    qa_chain = RetrievalQA.from_chain_type(llm=llm, **chain_options)
def set_qa_chain_with_agent():
    """Build the conversational banking agent and store it in the global ``agent``.

    Two tools are wired around the module-level ``llm``:

    * "bank regulations": a ConversationalRetrievalChain over
      ``vectorstore.as_retriever()`` using ``retrieval_qa_template``.
    * "general qa": a plain LLMChain using ``general_qa_template``.

    Both tools see the conversation only through ``readonlymemory``. The
    writable ``memory`` is given to the AgentExecutor. The agent is a
    ZeroShotAgent wrapped in an AgentExecutor with
    ``handle_parsing_errors=True``.

    Any exception is printed and swallowed, and the function returns ``None``
    before the global ``agent`` is reassigned. A previously built agent (or
    ``None``) therefore stays in place on failure.
    """
    try:
        # Define a custom prompt
        # Prompt for the "general qa" tool. Note that the indentation before
        # the closing quotes becomes part of the template text.
        general_qa_template = (
        """You can have a general conversation with the users like greetings.
Continue the conversation and only answer questions related to banking sector like financial and legal.
If you dont know the answer say you dont know, dont try to makeup answers.
Conversation: {chat_history}
Question: {question}
        """
        )
        general_qa_chain_prompt = PromptTemplate(input_variables=["question", "chat_history"], template=general_qa_template)
        # ``chat_history`` is presumably filled from readonlymemory
        # (memory_key="chat_history"), so callers only pass the question.
        general_qa_chain = LLMChain(
            llm=llm,
            prompt=general_qa_chain_prompt,
            verbose=True,
            memory=readonlymemory, # use the read-only memory to prevent the tool from modifying the memory
        )
        # return_direct=True: the agent hands this tool's output back as the
        # final answer instead of reasoning further.
        general_qa_chain_tool = Tool(
            name="general qa",
            func= general_qa_chain.run,
            description='''useful for when you need to have a general conversation with the users like greetings
or to answer general purpose questions related to banking sector like financial and legal.
Input should be a fully formed question.''',
            return_direct=True,
        )
        # Define a custom prompt
        # Prompt used to combine retrieved documents for the "bank regulations"
        # tool. The template starts with a newline and ends with the closing
        # quotes' indentation.
        retrieval_qa_template = (
        """
please answer the question based on the chat history and context with the latest information.
You have provided context information below related to central bank acts published in various years.
The content of a bank act can updated by a bank act from a latest year.
If you dont know the answer say you dont know, dont try to makeup answers.
Conversation: {chat_history}
Context: {context}
Question : {question}
        """
        )
        retrieval_qa_chain_prompt = PromptTemplate(
            input_variables=["question", "context", "chat_history"],
            template=retrieval_qa_template
        )
        # get_chat_history passes the stored history through unchanged. The
        # memory uses return_messages=True, so this is presumably a list of
        # message objects rather than a string; confirm that the prompt
        # rendering is acceptable.
        bank_regulations_qa = ConversationalRetrievalChain.from_llm(
            llm=llm,
            chain_type="stuff",
            retriever = vectorstore.as_retriever(),
            # retriever = vectorstore.as_retriever(search_kwargs={"k": target_source_chunks}
            return_source_documents= True,
            get_chat_history=lambda h : h,
            combine_docs_chain_kwargs={"prompt": retrieval_qa_chain_prompt},
            verbose=True,
            memory=readonlymemory, # use the read-only memory to prevent the tool from modifying the memory
        )
        # The tool returns the chain's whole result, not just the answer text.
        # With return_source_documents=True that is presumably a dict that
        # includes the source documents; with return_direct=True it becomes the
        # agent's "output".
        bank_regulations_qa_tool = Tool(
            name="bank regulations",
            func= lambda question: bank_regulations_qa({"question": question}),
            description='''useful for when you need to answer questions about
financial and legal information issued from central bank regarding banks and bank regulations.
Input should be a fully formed question.''',
            return_direct=True,
        )
        tools = [
            bank_regulations_qa_tool,
            general_qa_chain_tool
        ]
        prefix = """Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:"""
        # NOTE(review): the '"' right after "Begin!" looks like a stray quote.
        # It is part of the agent prompt as written; confirm whether it is
        # intended.
        suffix = """Begin!"
{chat_history}
Question: {question}
{agent_scratchpad}"""
        # The agent's input key is "question", matching input_key="question"
        # on the shared memory, which supplies "chat_history".
        agent_prompt = ZeroShotAgent.create_prompt(
            tools,
            prefix=prefix,
            suffix=suffix,
            input_variables=["question", "chat_history", "agent_scratchpad"],
        )
        llm_chain = LLMChain(llm=llm, prompt=agent_prompt)
        zeroShotAgent = ZeroShotAgent(
            llm_chain=llm_chain,
            tools=tools,
            verbose=True,
        )
        # The executor gets the writable memory, so conversation turns are
        # recorded here rather than inside the tools.
        agent_chain = AgentExecutor.from_agent_and_tools(
            agent=zeroShotAgent,
            tools=tools,
            verbose=True,
            memory=memory,
            handle_parsing_errors=True,
        )
        # Declared this late on purpose: ``agent`` is not referenced earlier
        # in the function, so the global declaration is legal here. It is only
        # reached when every step above succeeded.
        global agent
        agent = agent_chain
        print(f"\n> agent_chain created")
    except Exception as e:
        # Best-effort: the error is reported and swallowed; ``agent`` is left unchanged.
        # logger.error(f"Answer retrieval failed with {e}")
        print(f"> QAPipeline set_qa_chain_with_agent Error : {e}")#, icon=":books:")
        return