# boardpac_chat_app_test / qaPipeline.py
"""
Python Backend API to chat with private data
08/14/2023
D.M. Theekshana Samaradiwakara
"""
import os
import time
from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import GPT4All
from langchain.llms import HuggingFaceHub
from langchain.chat_models import ChatOpenAI
# from langchain.retrievers.self_query.base import SelfQueryRetriever
# from langchain.chains.query_constructor.base import AttributeInfo
# from chromaDb import load_store
from faissDb import load_FAISS_store
load_dotenv()
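
# A minimal example .env consumed below via load_dotenv(). The paths and
# values here are illustrative assumptions, not this project's actual
# settings (HUGGINGFACEHUB_API_TOKEN is the variable LangChain's
# HuggingFaceHub wrapper reads):
#
#   GPT4ALL_MODEL_PATH=models/ggml-gpt4all-j-v1.3-groovy.bin
#   MODEL_N_CTX=1024
#   MODEL_N_BATCH=8
#   TARGET_SOURCE_CHUNKS=4
#   OPENAI_API_KEY=sk-...
#   HUGGINGFACEHUB_API_TOKEN=hf_...
#   VERBOSE=false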
# gpt4all model
gpt4all_model_path = os.environ.get('GPT4ALL_MODEL_PATH')
# cast numeric settings to int; MODEL_N_CTX is passed to GPT4All as max_tokens
model_n_ctx = int(os.environ.get('MODEL_N_CTX', 1024))  # 1024 fallback is an assumption
model_n_batch = int(os.environ.get('MODEL_N_BATCH', 8))
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4))
openai_api_key = os.environ.get('OPENAI_API_KEY')
# VERBOSE arrives as a string, so convert it to a real boolean
verbose = os.environ.get('VERBOSE', 'false').lower() == 'true'

# activate/deactivate the streaming StdOut callback for LLMs
callbacks = [StreamingStdOutCallbackHandler()]
class QAPipeline:
    """Lazily builds an LLM, a FAISS retriever, and a RetrievalQA chain,
    caching them between calls and rebuilding only when the requested
    model or dataset changes."""

    def __init__(self):
        self.llm_name = None
        self.llm = None
        self.dataset_name = None
        self.vectorstore = None
        self.qa_chain = None
    def run(self, query, model, dataset):
        # rebuild the chain only if the model, the dataset, or the chain itself changed
        if (self.llm_name != model) or (self.dataset_name != dataset) or (self.qa_chain is None):
            self.set_model(model)
            self.set_vectorstore(dataset)
            self.set_qa_chain()

        # Get the answer from the chain
        start = time.time()
        res = self.qa_chain(query)
        # answer, docs = res['result'], res['source_documents']
        end = time.time()

        # Print the result
        print("\n\n> Question:")
        print(query)
        print(f"\n> Answer (took {round(end - start, 2)} s.):")
        print(res)

        return res
    def set_model(self, model_type):
        if model_type != self.llm_name:
            match model_type:
                case "gpt4all":
                    # self.llm = GPT4All(model=gpt4all_model_path, n_ctx=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=verbose)
                    self.llm = GPT4All(model=gpt4all_model_path, max_tokens=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=verbose)
                    # self.llm = HuggingFaceHub(repo_id="nomic-ai/gpt4all-j", model_kwargs={"temperature": 0.001, "max_length": 1024})
                case "google/flan-t5-xxl":
                    self.llm = HuggingFaceHub(repo_id="google/flan-t5-xxl", model_kwargs={"temperature": 0.001, "max_length": 1024})
                case "tiiuae/falcon-7b-instruct":
                    self.llm = HuggingFaceHub(repo_id=model_type, model_kwargs={"temperature": 0.001, "max_length": 1024})
                case "openai":
                    self.llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
                case _:
                    # raise an exception if model_type is not supported
                    raise Exception(f"Model type {model_type} is not supported. Please choose a valid one.")
            self.llm_name = model_type
    def set_vectorstore(self, dataset):
        if dataset != self.dataset_name:
            # self.vectorstore = load_store(dataset)
            self.vectorstore = load_FAISS_store()
            print("\n\n> vectorstore loaded")
            self.dataset_name = dataset
    def set_qa_chain(self):
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.vectorstore.as_retriever(),
            # retriever=self.vectorstore.as_retriever(search_kwargs={"k": target_source_chunks}),
            return_source_documents=True,
        )
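

if __name__ == "__main__":
    # Minimal usage sketch, assuming a FAISS index has already been built for
    # load_FAISS_store() and the env vars above are set. The query text and
    # the "default" dataset name are illustrative values, not fixed identifiers.
    pipeline = QAPipeline()
    res = pipeline.run(
        query="Summarise the latest board meeting decisions.",
        model="openai",
        dataset="default",
    )
    # With return_source_documents=True the chain returns a dict holding the
    # answer under 'result' and the retrieved chunks under 'source_documents'.
    print(res["result"])
    for doc in res["source_documents"]:
        print(doc.metadata)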