|
import logging |
|
import os |
|
import shutil |
|
import subprocess |
|
|
|
import torch |
|
from auto_gptq import AutoGPTQForCausalLM |
|
from flask import Flask, jsonify, request |
|
from langchain.chains import RetrievalQA |
|
from langchain.embeddings import HuggingFaceInstructEmbeddings |
|
|
|
|
|
from langchain.llms import HuggingFacePipeline |
|
from run_localGPT import load_model |
|
|
|
|
|
from langchain.vectorstores import Chroma |
|
from transformers import ( |
|
AutoModelForCausalLM, |
|
AutoTokenizer, |
|
GenerationConfig, |
|
LlamaForCausalLM, |
|
LlamaTokenizer, |
|
pipeline, |
|
) |
|
from werkzeug.utils import secure_filename |
|
|
|
from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME |
|
|
|
# Select the compute device once at import time: CUDA when a GPU is visible, else CPU.
DEVICE_TYPE = "cuda" if torch.cuda.is_available() else "cpu"

# When True, /api/prompt_route responses include the retrieved source documents.
SHOW_SOURCES = True

# NOTE(review): these run before logging.basicConfig (which is only called under
# __main__), so with the root logger's default WARNING level they emit nothing.
logging.info(f"Running on: {DEVICE_TYPE}")

logging.info(f"Display Source Documents set to: {SHOW_SOURCES}")

# Embedding model used both for the startup ingest below and for query-time
# retrieval against the Chroma store.
EMBEDDINGS = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": DEVICE_TYPE})

# Wipe any previously persisted vector store so the ingest below starts clean.
if os.path.exists(PERSIST_DIRECTORY):

    try:

        shutil.rmtree(PERSIST_DIRECTORY)

    except OSError as e:

        print(f"Error: {e.filename} - {e.strerror}.")

else:

    print("The directory does not exist")

# Build the ingest command. The device flag is only forwarded for CPU runs;
# ingest.py presumably defaults to cuda otherwise -- TODO confirm its CLI default.
run_langest_commands = ["python", "ingest.py"]

if DEVICE_TYPE == "cpu":

    run_langest_commands.append("--device_type")

    run_langest_commands.append(DEVICE_TYPE)

# Run ingest.py synchronously at startup; a non-zero exit status is treated as
# "no source documents present" and aborts the API with FileNotFoundError.
result = subprocess.run(run_langest_commands, capture_output=True)

if result.returncode != 0:

    raise FileNotFoundError(

        "No files were found inside SOURCE_DOCUMENTS, please put a starter file inside before starting the API!"

    )

# Open the freshly built Chroma store and expose it as a retriever.
DB = Chroma(

    persist_directory=PERSIST_DIRECTORY,

    embedding_function=EMBEDDINGS,

    client_settings=CHROMA_SETTINGS,

)

RETRIEVER = DB.as_retriever()

# load_model is a project helper (run_localGPT.py) that returns a LangChain LLM.
LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME)

# "stuff" chain type: retrieved documents are concatenated into a single prompt.
QA = RetrievalQA.from_chain_type(

    llm=LLM, chain_type="stuff", retriever=RETRIEVER, return_source_documents=SHOW_SOURCES

)

app = Flask(__name__)
|
|
|
|
|
@app.route("/api/delete_source", methods=["GET"])
def delete_source_route():
    """Delete and recreate the SOURCE_DOCUMENTS folder.

    Removes every uploaded source document by deleting the folder and
    recreating it empty.

    Returns:
        JSON confirmation message (200) on success, or a JSON error
        message (500) if the filesystem operation fails.

    NOTE(review): this is a destructive operation exposed via GET; it should
    arguably be POST/DELETE, but the method is kept for caller compatibility.
    """
    folder_name = "SOURCE_DOCUMENTS"

    try:
        # rmtree may fail (permissions, files held open); previously this
        # crashed the request with an unhandled OSError.
        if os.path.exists(folder_name):
            shutil.rmtree(folder_name)
        os.makedirs(folder_name)
    except OSError as e:
        return jsonify({"message": f"Error handling folder '{folder_name}': {e.strerror}"}), 500

    return jsonify({"message": f"Folder '{folder_name}' successfully deleted and recreated."})
|
|
|
|
|
@app.route("/api/save_document", methods=["GET", "POST"])
def save_document_route():
    """Save an uploaded file into the SOURCE_DOCUMENTS folder.

    Expects a multipart form upload under the field name "document". The
    filename is sanitized with werkzeug's secure_filename before writing.

    Returns:
        ("File saved successfully", 200) on success, or a plain-text error
        with status 400 when the upload is missing or invalid.
    """
    if "document" not in request.files:
        return "No document part", 400
    file = request.files["document"]
    if file.filename == "":
        return "No selected file", 400

    filename = secure_filename(file.filename)
    # secure_filename strips dangerous components and can return "" (e.g. for
    # "../.."); previously this crashed on file.save with a directory path.
    if not filename:
        return "Invalid filename", 400

    folder_path = "SOURCE_DOCUMENTS"
    # exist_ok avoids the exists()/makedirs() race of the original code.
    os.makedirs(folder_path, exist_ok=True)
    file_path = os.path.join(folder_path, filename)
    file.save(file_path)
    # Explicit return on every path: the original fell off the end (returning
    # None -> Flask error) when the file object was falsy.
    return "File saved successfully", 200
|
|
|
|
|
@app.route("/api/run_ingest", methods=["GET"])

def run_ingest_route():

    """Re-run ingest.py and rebuild the retrieval pipeline.

    Mirrors the module-level startup logic: deletes the persisted vector
    store, shells out to ingest.py, then rebinds the module-level DB,
    RETRIEVER and QA globals so subsequent /api/prompt_route calls use the
    freshly built index.

    Returns:
        ("Script executed successfully: <stdout>", 200) on success,
        ("Script execution failed: <stderr>", 500) when ingest.py exits
        non-zero, or ("Error occurred: <exc>", 500) on any other failure.
    """

    global DB

    global RETRIEVER

    global QA

    try:

        # Drop the old persisted index so ingest.py rebuilds it from scratch.
        if os.path.exists(PERSIST_DIRECTORY):

            try:

                shutil.rmtree(PERSIST_DIRECTORY)

            except OSError as e:

                print(f"Error: {e.filename} - {e.strerror}.")

        else:

            print("The directory does not exist")

        # Device flag is only forwarded on CPU; ingest.py presumably defaults
        # to cuda otherwise -- TODO confirm against ingest.py's CLI.
        run_langest_commands = ["python", "ingest.py"]

        if DEVICE_TYPE == "cpu":

            run_langest_commands.append("--device_type")

            run_langest_commands.append(DEVICE_TYPE)

        # Synchronous run; the request blocks until ingestion finishes.
        result = subprocess.run(run_langest_commands, capture_output=True)

        if result.returncode != 0:

            return "Script execution failed: {}".format(result.stderr.decode("utf-8")), 500

        # Rebind the module-level globals so live requests see the new index.
        DB = Chroma(

            persist_directory=PERSIST_DIRECTORY,

            embedding_function=EMBEDDINGS,

            client_settings=CHROMA_SETTINGS,

        )

        RETRIEVER = DB.as_retriever()

        QA = RetrievalQA.from_chain_type(

            llm=LLM, chain_type="stuff", retriever=RETRIEVER, return_source_documents=SHOW_SOURCES

        )

        return "Script executed successfully: {}".format(result.stdout.decode("utf-8")), 200

    except Exception as e:

        # Top-level request boundary: report any unexpected failure as a 500.
        return f"Error occurred: {str(e)}", 500
|
|
|
|
|
@app.route("/api/prompt_route", methods=["GET", "POST"])
def prompt_route():
    """Answer a user prompt via the retrieval-QA chain.

    Reads "user_prompt" from the form data, runs it through the module-level
    QA chain, and returns a JSON payload with the prompt, the generated
    answer, and (basename, page_content) pairs for each source document.

    Returns:
        (JSON response, 200) on success, or a plain-text 400 when no
        prompt was supplied.
    """
    global QA
    user_prompt = request.form.get("user_prompt")

    # Guard clause: bail out early when no prompt was provided.
    if not user_prompt:
        return "No user prompt received", 400

    res = QA(user_prompt)
    answer, docs = res["result"], res["source_documents"]

    prompt_response_dict = {
        "Prompt": user_prompt,
        "Answer": answer,
        "Sources": [
            (os.path.basename(str(document.metadata["source"])), str(document.page_content))
            for document in docs
        ],
    }

    return jsonify(prompt_response_dict), 200
|
|
|
|
|
if __name__ == "__main__":

    # Logging is configured only when run as a script; the module-level
    # logging.info calls above execute before this and are therefore dropped.
    logging.basicConfig(

        format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s - %(message)s", level=logging.INFO

    )

    # Development server on port 5110; debug disabled (no reloader/debugger).
    app.run(debug=False, port=5110)
|
|