# NOTE(review): removed non-code paste residue that preceded the file
# ("Spaces:" / "Runtime error" x2 — artifacts of a copy-paste, not Python).
#!/usr/bin/env python3
"""Flask chatbot service: answers queries via a LangChain conversational
retrieval chain over a persisted Chroma vector store, using Ollama models."""

# Standard library
import os
import time

# Third-party
import chromadb
from dotenv import load_dotenv
from flask import Flask, Blueprint, jsonify, request
from langchain.chains import ConversationalRetrievalChain, RetrievalQA
from langchain.chat_models import ChatOllama
from langchain.embeddings import OllamaEmbeddings
from langchain.llms.ollama import Ollama
from langchain.memory import ConversationBufferMemory
from langchain.vectorstores.chroma import Chroma

# Local
from constants import CHROMA_SETTINGS
from prompt_verified import create_prompt_template
# Load configuration from .env; fail fast if it is missing or empty.
if not load_dotenv(".env"):
    print("Could not load .env file or it is empty. Please check if it exists and is readable.")
    exit(1)

# Model / store configuration read from the environment.
embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
persist_directory = os.environ.get('PERSIST_DIRECTORY')
model_type = os.environ.get('MODEL_TYPE')
model_path = os.environ.get('MODEL_PATH')
model_n_ctx = os.environ.get('MODEL_N_CTX')
model_n_batch = int(os.environ.get('MODEL_N_BATCH', 8))
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4))
# Blueprint holding the chat endpoints; registered on the Flask app below.
chat = Blueprint('chat', __name__)


# NOTE(review): the original never attached base() to the blueprint, so the
# registered blueprint exposed no routes — the decorator was most likely lost
# in formatting. Confirm the intended path.
@chat.route("/")
def base():
    """Health-check / welcome endpoint for the chatbot service."""
    return jsonify(
        {
            "status": "success",
            "message": "Welcome to the chatbot system",
            "responseCode": 200,
        }
    ), 200
# Module-level conversation memory used by the retrieval chain.
# NOTE(review): a single global memory is shared across ALL requests/users,
# so chat histories will mix between concurrent users — consider keying
# memory by userID.
memory = ConversationBufferMemory(
    memory_key="chat_history",
    input_key="question",
    output_key='answer',
    return_messages=True,
)
| def load_qa_chain(memory, prompt): | |
| embeddings = OllamaEmbeddings(model=embeddings_model_name) | |
| chroma_client = chromadb.PersistentClient( | |
| settings=CHROMA_SETTINGS, | |
| path=persist_directory | |
| ) | |
| db = Chroma( | |
| persist_directory=persist_directory, | |
| embedding_function=embeddings, | |
| client_settings=CHROMA_SETTINGS, | |
| client=chroma_client | |
| ) | |
| retriever = db.as_retriever( | |
| search_kwargs={ | |
| "k": target_source_chunks | |
| } | |
| ) | |
| # Prepare the LLM | |
| match model_type: | |
| case "ollama": | |
| llm = Ollama( | |
| model=model_path, | |
| temperature=0.2 | |
| ) | |
| case _default: | |
| # raise exception if model_type is not supported | |
| raise Exception(f"Model type {model_type} is not supported. Please choose one of the following: LlamaCpp, GPT4All") | |
| qa = RetrievalQA.from_chain_type( | |
| llm=llm, | |
| chain_type="stuff", | |
| retriever=retriever, | |
| return_source_documents= True | |
| ) | |
| qa = ConversationalRetrievalChain.from_llm( | |
| llm=llm, | |
| chain_type="stuff", | |
| retriever=retriever, | |
| memory=memory, | |
| return_source_documents=True, | |
| combine_docs_chain_kwargs={ | |
| 'prompt': prompt, | |
| }, | |
| verbose=True, | |
| ) | |
| return qa | |
# NOTE(review): the original never attached main() to the blueprint — the
# decorator was most likely lost in formatting. Confirm the intended path.
@chat.route("/chat", methods=["POST"])
def main():
    """Answer a user query through the conversational retrieval chain.

    Expects `userID` and `customerName` as query parameters and a JSON body
    containing `query`. Returns the chain's answer plus timing metadata.
    """
    global memory

    # Fulfils the original TODO: reject requests without a userID.
    userID = request.args.get('userID')
    if userID is None:
        return jsonify(
            {
                "status": "error",
                "message": "userID query parameter is required",
                "responseCode": 400,
            }
        ), 400
    userID = str(userID)
    customer_name = str(request.args.get('customerName'))

    request_data = request.get_json()
    query = request_data['query']

    # The original looped `while True` and `continue`d on an empty query,
    # hanging the request forever — replaced with an explicit 400.
    if query.strip() == "":
        return jsonify(
            {
                "status": "error",
                "message": "query must be a non-empty string",
                "responseCode": 400,
            }
        ), 400

    start_time = time.time()
    prompt = create_prompt_template(customerName=customer_name)
    qa = load_qa_chain(prompt=prompt, memory=memory)
    response = qa(
        {
            "question": query,
        }
    )
    time_taken = round(time.time() - start_time, 2)

    answer = str(response['answer'])
    # Log the relevant sources used for the answer.
    for document in response['source_documents']:
        print("\n> " + document.metadata["source"] + ":")

    return jsonify(
        {
            "Query": query,
            "UserID": userID,
            "Time_taken": time_taken,
            "reply": answer,
            "customer_name": customer_name,
            "responseCode": 200,
        }
    ), 200
# Flask app setup: create the app and mount the chat blueprint.
app = Flask(__name__)
app.register_blueprint(chat)

if __name__ == "__main__":
    # NOTE(review): debug=True on 0.0.0.0 exposes the Werkzeug debugger to
    # the network — disable debug (or bind to localhost) outside development.
    app.run(debug=True, host='0.0.0.0', port=8088)