#!/usr/bin/env python3
from dotenv import load_dotenv
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings import OllamaEmbeddings
from langchain.vectorstores.chroma import Chroma
from langchain.llms.ollama import Ollama
from langchain.memory import ConversationBufferMemory
import chromadb
import os
import time
from flask import Flask, jsonify, Blueprint, request
from constants import CHROMA_SETTINGS
from prompt_verified import create_prompt_template

if not load_dotenv(".env"):
    print("Could not load .env file or it is empty. Please check if it exists and is readable.")
    exit(1)
embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
persist_directory = os.environ.get('PERSIST_DIRECTORY')
model_type = os.environ.get('MODEL_TYPE')
model_path = os.environ.get('MODEL_PATH')
model_n_ctx = os.environ.get('MODEL_N_CTX')
model_n_batch = int(os.environ.get('MODEL_N_BATCH', 8))
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4))
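
# Example .env for the variables read above. The values are illustrative
# placeholders, not the project's actual configuration; adjust them to your
# local Ollama models and Chroma directory:
#   EMBEDDINGS_MODEL_NAME=nomic-embed-text
#   PERSIST_DIRECTORY=db
#   MODEL_TYPE=ollama
#   MODEL_PATH=llama2
#   MODEL_N_CTX=2048
#   MODEL_N_BATCH=8
#   TARGET_SOURCE_CHUNKS=4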

chat = Blueprint('chat', __name__)


# Route path assumed; without a decorator this handler would never be
# registered with the blueprint.
@chat.route('/')
def base():
    # Simple welcome/health-check endpoint.
    return jsonify(
        {
            "status": "success",
            "message": "Welcome to the chatbot system",
            "responseCode": 200
        }
    ), 200
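
# Illustrative check that the service is up (host/port as configured below):
#   curl http://localhost:8088/
#   -> {"status": "success", "message": "Welcome to the chatbot system", "responseCode": 200}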

# Shared conversation memory. Note this is a single module-level buffer,
# so all users currently share one chat history.
memory = ConversationBufferMemory(
    memory_key="chat_history",
    input_key="question",
    output_key='answer',
    return_messages=True,
    # human_prefix = "John Doe",
    # ai_prefix = "AFEX-trade-bot",
)
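
# Illustrative use of the memory API (these ConversationBufferMemory calls
# are how the chain persists each turn internally):
#   memory.save_context({"question": "hi"}, {"answer": "hello"})
#   memory.load_memory_variables({})  # -> {"chat_history": [HumanMessage(...), AIMessage(...)]}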


def load_qa_chain(memory, prompt):
    # Embed queries with the same Ollama model used at ingestion time.
    embeddings = OllamaEmbeddings(model=embeddings_model_name)
    chroma_client = chromadb.PersistentClient(
        settings=CHROMA_SETTINGS,
        path=persist_directory
    )
    db = Chroma(
        persist_directory=persist_directory,
        embedding_function=embeddings,
        client_settings=CHROMA_SETTINGS,
        client=chroma_client
    )
    # Fetch the top-k most similar chunks for each question.
    retriever = db.as_retriever(
        search_kwargs={
            "k": target_source_chunks
        }
    )
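    # Illustrative: the retriever can also be queried on its own, e.g.
    #   retriever.get_relevant_documents("What is the refund policy?")
    # returns a list of Document objects with .page_content and .metadata.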
    # Prepare the LLM
    match model_type:
        case "ollama":
            llm = Ollama(
                model=model_path,
                temperature=0.2
            )
        case _:
            # Only the Ollama backend is wired up in this app.
            raise Exception(f"Model type {model_type} is not supported. Please set MODEL_TYPE=ollama.")
    qa = ConversationalRetrievalChain.from_llm(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        memory=memory,
        return_source_documents=True,
        combine_docs_chain_kwargs={
            'prompt': prompt,
        },
        verbose=True,
    )
    return qa


# Route path and method assumed; without a decorator this handler could not
# read `request` or return a JSON response.
@chat.route('/chat', methods=['POST'])
def main():
    # try:
    # -------------- TO-DO ------------------ #
    # Add a constraint to raise an error if   #
    # the userID is not passed in the request #
    # -------------- TO-DO ------------------ #
    userID = str(request.args.get('userID'))
    customer_name = str(request.args.get('customerName'))
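
    # Minimal sketch of the TO-DO above: reject requests that omit userID.
    # The error payload shape is assumed, mirroring the other responses here.
    if request.args.get('userID') is None:
        return jsonify(
            {
                "status": "error",
                "message": "userID query parameter is required",
                "responseCode": 400
            }
        ), 400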
    request_data = request.get_json()
    # print(request_data['query'])
    query = request_data['query']
    # Answer a single question per request; an empty query is rejected
    # instead of looping (the old `while True` would spin forever on "").
    if query.strip() == "":
        return jsonify(
            {
                "status": "error",
                "message": "Query must not be empty",
                "responseCode": 400
            }
        ), 400
    start_time = time.time()
    prompt = create_prompt_template(customerName=customer_name)
    qa = load_qa_chain(prompt=prompt, memory=memory)
    response = qa(
        {
            "question": query,
        }
    )
    end_time = time.time()
    time_taken = round(end_time - start_time, 2)
    answer = str(response['answer'])
    docs = response['source_documents']
    print(response)
    # Print the relevant sources used for the answer
    for document in docs:
        print("\n> " + document.metadata["source"] + ":")
        # print(document.page_content)
    return jsonify(
        {
            "Query": query,
            "UserID": userID,
            "Time_taken": time_taken,
            "reply": answer,
            # "chain_response": response,
            "customer_name": customer_name,
            "responseCode": 200
        }
    ), 200
    # except Exception as e:
    #     print(e)
    #     return jsonify(
    #         {
    #             "Status": "An error occurred",
    #             # "error": e,
    #             "responseCode": 201
    #         }
    #     ), 201


# Flask App setup
app = Flask(__name__)
app.register_blueprint(chat)

if __name__ == "__main__":
    app.run(debug=True, host='0.0.0.0', port=8088)
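
# Example request against the chat endpoint (path and query parameters as
# assumed above; adjust host/port to your deployment):
#   curl -X POST 'http://localhost:8088/chat?userID=123&customerName=Jane' \
#        -H 'Content-Type: application/json' \
#        -d '{"query": "What products do you offer?"}'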