""" | |
Python Backend API to chat with private data | |
08/14/2023 | |
D.M. Theekshana Samaradiwakara | |
""" | |
import os
import re
import time

from dotenv import load_dotenv

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import GPT4All
from langchain.llms import HuggingFaceHub
from langchain.chat_models import ChatOpenAI
from langchain.chat_models import ChatAnyscale

# from langchain.retrievers.self_query.base import SelfQueryRetriever
# from langchain.chains.query_constructor.base import AttributeInfo

# from chromaDb import load_store
from faissDb import load_FAISS_store

from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain, ConversationalRetrievalChain

from conversationBufferWindowMemory import ConversationBufferWindowMemory

load_dotenv()
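
# Expected .env entries (a sketch inferred from the os.environ.get calls below;
# the sample values are placeholders, not values shipped with this repo):
#   GPT4ALL_MODEL_PATH=./models/ggml-gpt4all-j.bin
#   MODEL_N_CTX=1024
#   MODEL_N_BATCH=8
#   TARGET_SOURCE_CHUNKS=4
#   OPENAI_API_KEY=sk-...
#   ANYSCALE_ENDPOINT_TOKEN=...
#   VERBOSE=false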
# GPT4All model settings
gpt4all_model_path = os.environ.get('GPT4ALL_MODEL_PATH')
model_n_ctx = int(os.environ.get('MODEL_N_CTX', 1024))  # default of 1024 is an assumption
model_n_batch = int(os.environ.get('MODEL_N_BATCH', 8))
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4))

openai_api_key = os.environ.get('OPENAI_API_KEY')
anyscale_api_key = os.environ.get('ANYSCALE_ENDPOINT_TOKEN')

# Parse VERBOSE into a real boolean instead of passing the raw env string through.
verbose = os.environ.get('VERBOSE', 'false').lower() in ('true', '1', 'yes')

# Activate/deactivate the streaming StdOut callback for LLMs.
callbacks = [StreamingStdOutCallbackHandler()]
def is_valid_open_ai_api_key(secret_key):
    """Return True if the key matches the expected OpenAI secret key format."""
    return bool(re.match(r"^sk-[a-zA-Z0-9]{32,}$", secret_key))
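
# Usage sketch (the keys below are made-up placeholders, not real credentials):
#   is_valid_open_ai_api_key("sk-" + "a" * 48)  # -> True
#   is_valid_open_ai_api_key("not-a-key")       # -> False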
def get_local_LLAMA2():
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
    from langchain import HuggingFacePipeline

    tokenizer = AutoTokenizer.from_pretrained(
        "NousResearch/Llama-2-13b-chat-hf",
        # use_auth_token=True,
    )

    model = AutoModelForCausalLM.from_pretrained(
        "NousResearch/Llama-2-13b-chat-hf",
        device_map='auto',
        torch_dtype=torch.float16,
        use_auth_token=True,
        # load_in_8bit=True,
        # load_in_4bit=True,
    )

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        max_new_tokens=512,
        do_sample=True,
        top_k=30,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )

    LLAMA2 = HuggingFacePipeline(pipeline=pipe, model_kwargs={'temperature': 0})

    print(f"\n\n> torch.cuda.is_available(): {torch.cuda.is_available()}")
    print("\n\n> local LLAMA2 loaded")
    return LLAMA2
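
# Usage sketch (assumes a CUDA-capable GPU with enough memory for the 13B
# checkpoint and valid Hugging Face credentials for use_auth_token):
#   llm = get_local_LLAMA2()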
# Windowed chat memory: keeps only the last k=3 question/answer turns.
memory = ConversationBufferWindowMemory(
    memory_key="chat_history",
    input_key="question",
    output_key="answer",
    return_messages=True,
    k=3,
)
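
# Note: this memory object lives at module level, so every QAPipeline
# instance shares the same 3-turn conversation window across requests.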
class QAPipeline:

    def __init__(self):
        print("\n\n> Initializing QAPipeline:")
        self.llm_name = None
        self.llm = None
        self.dataset_name = None
        self.vectorstore = None
        self.qa_chain = None
    def run_agent(self, query, model, dataset, openai_api_key=None):
        try:
            # Rebuild the chain only when the model or dataset changes.
            if (self.llm_name != model) or (self.dataset_name != dataset) or (self.qa_chain is None):
                self.set_model(model, openai_api_key)
                self.set_vectorstore(dataset)
                self.set_qa_chain()

            # Get the answer from the chain
            start = time.time()
            res = self.qa_chain(query)
            # answer, docs = res['answer'], res['source_documents']
            end = time.time()

            # Print the result
            print("\n\n> Question:")
            print(query)
            print(f"\n> Answer (took {round(end - start, 2)} s.):")
            print(res)

            return res

        except Exception as e:
            # logger.error(f"Answer retrieval failed with {e}")
            print(f"> QAPipeline run_agent Error: {e}")
            return
    def set_model(self, model_type, openai_api_key):
        if model_type != self.llm_name:
            match model_type:
                case "gpt4all":
                    # self.llm = GPT4All(model=gpt4all_model_path, n_ctx=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=verbose)
                    self.llm = GPT4All(model=gpt4all_model_path, max_tokens=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=verbose)
                    # self.llm = HuggingFaceHub(repo_id="nomic-ai/gpt4all-j", model_kwargs={"temperature": 0.001, "max_length": 1024})

                case "google/flan-t5-xxl":
                    self.llm = HuggingFaceHub(repo_id="google/flan-t5-xxl", model_kwargs={"temperature": 0.001, "max_length": 1024})

                case "tiiuae/falcon-7b-instruct":
                    self.llm = HuggingFaceHub(repo_id=model_type, model_kwargs={"temperature": 0.001, "max_length": 1024})

                case "openai":
                    # Don't log the key itself; it's a secret.
                    print("> validating openai_api_key")
                    if is_valid_open_ai_api_key(openai_api_key):
                        self.llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, openai_api_key=openai_api_key)
                    else:
                        raise KeyError("openai_api_key is not valid")

                case "Deci/DeciLM-6b":
                    # Note: uses the ChatOpenAI client, which assumes an
                    # OpenAI-compatible endpoint serving this model.
                    self.llm = ChatOpenAI(model_name="Deci/DeciLM-6b", temperature=0)

                case "local/LLAMA2":
                    self.llm = get_local_LLAMA2()

                case "anyscale/Llama-2-13b-chat-hf":
                    self.llm = ChatAnyscale(anyscale_api_key=anyscale_api_key, temperature=0, model_name='meta-llama/Llama-2-13b-chat-hf', streaming=False)

                case "anyscale/Llama-2-70b-chat-hf":
                    self.llm = ChatAnyscale(anyscale_api_key=anyscale_api_key, temperature=0, model_name='meta-llama/Llama-2-70b-chat-hf', streaming=False)

                case _:
                    # Raise if model_type is not supported.
                    raise Exception(f"Model type {model_type} is not supported. Please choose a valid one.")

            self.llm_name = model_type
    def set_vectorstore(self, dataset):
        if dataset != self.dataset_name:
            # self.vectorstore = load_store(dataset)
            # Note: load_FAISS_store takes no arguments, so the same FAISS
            # index is loaded regardless of the requested dataset.
            self.vectorstore = load_FAISS_store()
            print("\n\n> vectorstore loaded:")
            self.dataset_name = dataset
    def set_qa_chain(self):
        print("\n> creating agent_chain")
        try:
            # Llama-2 chat markers (written literally in the template below).
            B_INST, E_INST = "[INST]", "[/INST]"
            B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

            # Define a custom prompt.
            retrieval_qa_template = (
                """<<SYS>>
                You are the AI assistant of the company Boardpac, which provides services to company board members related to the banking and financial sector.
                Please answer the question based on the chat history provided below.

                <chat history>: {chat_history}

                Identify the type of the question using the following 3 types and answer accordingly.
                The answer should be as short and simple as possible.
                Don't add any extra details that are not mentioned in the context.

                <Type 1>
                If the user asks questions like welcome messages, greetings and goodbyes,
                just reply accordingly with an answer as short and simple as possible.
                Don't use the context information provided below to answer the question.
                Start the answer with the code word Boardpac AI(chat):

                <Type 2>
                If the question doesn't belong to type 1 or type 3, that is, if the question is not about greetings or Banking and Financial Services, say that the question is out of your domain.
                Start the answer with the code word Boardpac AI(OD):

                <Type 3>
                If the question is related to the Banking and Financial Services sector, like Banking & Financial regulations, legal framework, governance framework, or compliance requirements as per Central Bank regulations,
                please answer the question based only on the information provided in the following central bank documents published in various years.
                The published year is mentioned as the metadata 'year' of each source document.
                Please notice that the content of a document from a past year can be updated by a new document from a recent year.
                Always try to answer with the latest information and mention the year the information was extracted from.
                If you don't know the answer, say you don't know; don't try to make up answers.
                Start the answer with the code word Boardpac AI(QA):
                <</SYS>>

                [INST]
                <DOCUMENTS>
                {context}
                </DOCUMENTS>

                Question : {question}[/INST]"""
            )

            retrieval_qa_chain_prompt = PromptTemplate(
                input_variables=["question", "context", "chat_history"],
                template=retrieval_qa_template,
            )

            self.qa_chain = ConversationalRetrievalChain.from_llm(
                llm=self.llm,
                chain_type="stuff",
                retriever=self.vectorstore.as_retriever(),
                # retriever=self.vectorstore.as_retriever(search_kwargs={"k": target_source_chunks}),
                return_source_documents=True,
                get_chat_history=lambda h: h,
                combine_docs_chain_kwargs={"prompt": retrieval_qa_chain_prompt},
                verbose=True,
                memory=memory,
            )

            print("\n> agent_chain created")

        except Exception as e:
            # logger.error(f"Answer retrieval failed with {e}")
            print(f"> QAPipeline set_qa_chain Error: {e}")
            return
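

# Minimal usage sketch (assumes the FAISS index and the .env variables above
# are in place; the query and dataset name are made-up placeholders):
if __name__ == "__main__":
    pipeline = QAPipeline()
    pipeline.run_agent(
        query="What are the latest capital adequacy requirements?",
        model="anyscale/Llama-2-13b-chat-hf",
        dataset="default",
    )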