""" | |
Python Backend API to chat with private data | |
08/14/2023 | |
D.M. Theekshana Samaradiwakara | |
""" | |
import os
import time

from dotenv import load_dotenv

from langchain.chains import RetrievalQA
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import GPT4All
from langchain.llms import HuggingFaceHub
from langchain.chat_models import ChatOpenAI

# from langchain.retrievers.self_query.base import SelfQueryRetriever
# from langchain.chains.query_constructor.base import AttributeInfo

# from chromaDb import load_store
from faissDb import load_FAISS_store

from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain, ConversationalRetrievalChain
from conversationBufferWindowMemory import ConversationBufferWindowMemory
from langchain.memory import ReadOnlySharedMemory

load_dotenv()
# GPT4All model settings (loaded from the .env file)
gpt4all_model_path = os.environ.get('GPT4ALL_MODEL_PATH')
model_n_ctx = os.environ.get('MODEL_N_CTX')
model_n_batch = int(os.environ.get('MODEL_N_BATCH', 8))
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4))

openai_api_key = os.environ.get('OPENAI_API_KEY')
verbose = os.environ.get('VERBOSE')
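
# Example .env contents this module reads (illustrative values only, not from the
# original repository; adjust the model path and flags to your environment):
#
#   GPT4ALL_MODEL_PATH=./models/ggml-gpt4all-j-v1.3-groovy.bin
#   MODEL_N_CTX=1024
#   MODEL_N_BATCH=8
#   TARGET_SOURCE_CHUNKS=4
#   OPENAI_API_KEY=sk-...
#   VERBOSE=false
#
# The HuggingFaceHub models additionally expect HUGGINGFACEHUB_API_TOKEN to be set.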
# activate/deactivate the streaming StdOut callback for LLMs
callbacks = [StreamingStdOutCallbackHandler()]
# shared conversation memory; k=3 keeps only the last three exchanges in the prompt
memory = ConversationBufferWindowMemory(
    memory_key="chat_history",
    input_key="question",
    return_messages=True,
    k=3
)

# read-only view of the memory, so individual tools can read but not modify the chat history
readonlymemory = ReadOnlySharedMemory(memory=memory)
class Singleton:
    """Provides a single shared QAPipeline instance."""

    __instance = None

    @staticmethod
    def getInstance():
        """ Static access method. """
        if Singleton.__instance is None:
            Singleton()
        return Singleton.__instance

    def __init__(self):
        """ Virtually private constructor. """
        if Singleton.__instance is not None:
            raise Exception("This class is a singleton!")
        else:
            Singleton.__instance = QAPipeline()
def get_local_LLAMA2():
    """Load a local Llama-2-13b-chat model and wrap it as a LangChain LLM."""
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    tokenizer = AutoTokenizer.from_pretrained(
        "NousResearch/Llama-2-13b-chat-hf",
        # use_auth_token=True,
    )

    model = AutoModelForCausalLM.from_pretrained(
        "NousResearch/Llama-2-13b-chat-hf",
        device_map='auto',
        torch_dtype=torch.float16,
        use_auth_token=True,
        # load_in_8bit=True,
        # load_in_4bit=True
    )

    from transformers import pipeline

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        max_new_tokens=512,
        do_sample=True,
        top_k=30,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id
    )

    from langchain import HuggingFacePipeline
    LLAMA2 = HuggingFacePipeline(pipeline=pipe, model_kwargs={'temperature': 0})

    print(f"\n\n> torch.cuda.is_available(): {torch.cuda.is_available()}")
    print("\n\n> local LLAMA2 loaded")
    return LLAMA2
class QAPipeline:
    """Question-answering pipeline that lazily builds the LLM, vectorstore and QA chain/agent."""

    def __init__(self):
        print("\n\n> Initializing QAPipeline:")

        self.llm_name = None
        self.llm = None

        self.dataset_name = None
        self.vectorstore = None

        self.qa_chain = None
        self.agent = None
    def run(self, query, model, dataset):
        # rebuild the chain only if the model, dataset or chain itself has changed
        if (self.llm_name != model) or (self.dataset_name != dataset) or (self.qa_chain is None):
            self.set_model(model)
            self.set_vectorstore(dataset)
            self.set_qa_chain()

        # Get the answer from the chain
        start = time.time()
        res = self.qa_chain(query)
        # answer, docs = res['result'], res['source_documents']
        end = time.time()

        # Print the result
        print("\n\n> Question:")
        print(query)
        print(f"\n> Answer (took {round(end - start, 2)} s.):")
        print(res)

        return res
    def run_agent(self, query, model, dataset):
        try:
            # rebuild the agent only if the model, dataset or agent itself has changed
            if (self.llm_name != model) or (self.dataset_name != dataset) or (self.agent is None):
                self.set_model(model)
                self.set_vectorstore(dataset)
                self.set_qa_chain_with_agent()

            # Get the answer from the agent
            start = time.time()
            res = self.agent(query)
            # answer, docs = res['result'], res['source_documents']
            end = time.time()

            # Print the result
            print("\n\n> Question:")
            print(query)
            print(f"\n> Answer (took {round(end - start, 2)} s.):")
            print(res)

            return res["output"]

        except Exception as e:
            # logger.error(f"Answer retrieval failed with {e}")
            print(f"> QAPipeline run_agent Error : {e}")  # , icon=":books:")
            return
    def set_model(self, model_type):
        if model_type != self.llm_name:
            match model_type:
                case "gpt4all":
                    # self.llm = GPT4All(model=gpt4all_model_path, n_ctx=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=verbose)
                    self.llm = GPT4All(model=gpt4all_model_path, max_tokens=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=verbose)
                    # self.llm = HuggingFaceHub(repo_id="nomic-ai/gpt4all-j", model_kwargs={"temperature": 0.001, "max_length": 1024})
                case "google/flan-t5-xxl":
                    self.llm = HuggingFaceHub(repo_id="google/flan-t5-xxl", model_kwargs={"temperature": 0.001, "max_length": 1024})
                case "tiiuae/falcon-7b-instruct":
                    self.llm = HuggingFaceHub(repo_id=model_type, model_kwargs={"temperature": 0.001, "max_length": 1024})
                case "openai":
                    self.llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
                case "Deci/DeciLM-6b-instruct":
                    self.llm = ChatOpenAI(model_name="Deci/DeciLM-6b-instruct", temperature=0)
                case "Deci/DeciLM-6b":
                    self.llm = ChatOpenAI(model_name="Deci/DeciLM-6b", temperature=0)
                case "local/LLAMA2":
                    self.llm = get_local_LLAMA2()
                case _:
                    # raise an exception if the model_type is not supported
                    raise Exception(f"Model type {model_type} is not supported. Please choose a valid one.")

            self.llm_name = model_type
    def set_vectorstore(self, dataset):
        if dataset != self.dataset_name:
            # self.vectorstore = load_store(dataset)
            self.vectorstore = load_FAISS_store()
            print("\n\n> vectorstore loaded:")

            self.dataset_name = dataset
    def set_qa_chain(self):
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.vectorstore.as_retriever(),
            # retriever=self.vectorstore.as_retriever(search_kwargs={"k": target_source_chunks}),
            return_source_documents=True
        )
    def set_qa_chain_with_agent(self):
        try:
            # Define a custom prompt for general conversation
            general_qa_template = (
                """You can have a general conversation with the users like greetings.
                Continue the conversation and only answer questions related to the banking sector, like financial and legal matters.
                If you don't know the answer, say that you don't know; don't try to make up answers.

                Conversation: {chat_history}

                Question: {question}
                """
            )

            general_qa_chain_prompt = PromptTemplate(input_variables=["question", "chat_history"], template=general_qa_template)

            general_qa_chain = LLMChain(
                llm=self.llm,
                prompt=general_qa_chain_prompt,
                verbose=True,
                memory=readonlymemory,  # use the read-only memory to prevent the tool from modifying the memory
            )

            general_qa_chain_tool = Tool(
                name="general qa",
                func=general_qa_chain.run,
                description='''useful for when you need to have a general conversation with the users like greetings
                or to answer general purpose questions related to the banking sector like financial and legal matters.
                Input should be a fully formed question.''',
                return_direct=True,
            )
            # Define a custom prompt for retrieval QA over bank regulations
            retrieval_qa_template = (
                """
                Please answer the question based on the chat history and context with the latest information.
                You have been provided context information below related to central bank acts published in various years.
                The content of a bank act can be updated by a bank act from a later year.
                If you don't know the answer, say that you don't know; don't try to make up answers.

                Conversation: {chat_history}

                Context: {context}

                Question: {question}
                """
            )

            retrieval_qa_chain_prompt = PromptTemplate(
                input_variables=["question", "context", "chat_history"],
                template=retrieval_qa_template
            )

            bank_regulations_qa = ConversationalRetrievalChain.from_llm(
                llm=self.llm,
                chain_type="stuff",
                retriever=self.vectorstore.as_retriever(),
                # retriever=self.vectorstore.as_retriever(search_kwargs={"k": target_source_chunks}),
                return_source_documents=True,
                get_chat_history=lambda h: h,
                combine_docs_chain_kwargs={"prompt": retrieval_qa_chain_prompt},
                verbose=True,
                memory=readonlymemory,  # use the read-only memory to prevent the tool from modifying the memory
            )

            bank_regulations_qa_tool = Tool(
                name="bank regulations",
                func=lambda question: bank_regulations_qa({"question": question}),
                description='''useful for when you need to answer questions about
                financial and legal information issued by the central bank regarding banks and bank regulations.
                Input should be a fully formed question.''',
                return_direct=True,
            )
            tools = [
                bank_regulations_qa_tool,
                general_qa_chain_tool
            ]

            prefix = """Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:"""
            suffix = """Begin!

            {chat_history}
            Question: {question}
            {agent_scratchpad}"""

            agent_prompt = ZeroShotAgent.create_prompt(
                tools,
                prefix=prefix,
                suffix=suffix,
                input_variables=["question", "chat_history", "agent_scratchpad"],
            )

            llm_chain = LLMChain(llm=self.llm, prompt=agent_prompt)

            agent = ZeroShotAgent(
                llm_chain=llm_chain,
                tools=tools,
                verbose=True,
            )

            agent_chain = AgentExecutor.from_agent_and_tools(
                agent=agent,
                tools=tools,
                verbose=True,
                memory=memory,
                handle_parsing_errors=True,
            )

            self.agent = agent_chain

            print("\n> agent_chain created")
        except Exception as e:
            # logger.error(f"Answer retrieval failed with {e}")
            print(f"> QAPipeline set_qa_chain_with_agent Error : {e}")  # , icon=":books:")
            return
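
# --- Usage sketch (assumption; not part of the original module) ---
# A minimal example of how a caller (e.g. an HTTP route handler) might use this
# pipeline through the Singleton wrapper. The query, model and dataset values
# below are illustrative; model must be one of the types handled in set_model(),
# and dataset is forwarded to set_vectorstore().
if __name__ == "__main__":
    pipeline = Singleton.getInstance()
    answer = pipeline.run_agent(
        query="What does the latest central bank act say about licensing of banks?",
        model="openai",
        dataset="default",
    )
    print(answer)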