|
|
|
|
|
|
|
import os, random, logging, pickle, shutil |
|
from dotenv import load_dotenv, find_dotenv |
|
from typing import Optional |
|
from pydantic import BaseModel, Field |
|
|
|
from fastapi import FastAPI, HTTPException, File, UploadFile, status |
|
from fastapi.responses import HTMLResponse |
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
|
# Load environment variables from a file named "env", if one can be found.
# This stays best-effort (the application still starts when loading fails),
# but the failure is reported instead of being swallowed silently.
try:
    load_dotenv(find_dotenv('env'))
except Exception as exc:
    logging.getLogger(__name__).warning("Could not load environment file: %s", exc)
|
|
|
from app.engine.processing import ( |
|
process_pdf, |
|
process_txt, |
|
index_data, |
|
empty_collection, |
|
vector_search, |
|
vector_search_raw |
|
) |
|
from app.rag.rag import rag_it |
|
|
|
from app.engine.logger import logger |
|
|
|
from app.settings import datadir, datadir2 |
|
|
|
# Make sure both storage directories exist before any endpoint touches them.
# `exist_ok=True` already turns the call into a no-op for an existing
# directory, so no separate existence check (or repeated call) is needed.
os.makedirs(datadir, exist_ok=True)
os.makedirs(datadir2, exist_ok=True)

# Lower-case file extensions (without the dot) that erase_data() removes.
EXTENSIONS = ["pdf", "txt"]

app = FastAPI()

# Deployment environment name; "dev" enables auto-reload when this module is
# run as a script.
environment = os.getenv("ENVIRONMENT", "dev")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
def read_root():
    """Serve the static HTML welcome page and log that it was displayed."""
    logger("Title displayed on home page")
    page = """
    <html>
        <body>
            <h1>Welcome to MultiRAG, a RAG system designed by JP Bianchi!</h1>
        </body>
    </html>
    """
    return page
|
|
|
|
|
@app.get("/ping/")
def ping():
    """Testing endpoint: log the call and answer with a random integer.

    The integer lies in the range 0-99 and is returned as a string under
    the "answer" key.
    """
    logger("Someone is pinging the server")
    number = int(random.random() * 100)
    return {"answer": str(number)}
|
|
|
|
|
@app.delete("/erase_data/")
def erase_data():
    """Erase the uploaded files at the first level of the data directory.

    Only regular files named '.DS_Store' or ending in one of EXTENSIONS
    (case-insensitive) are removed; sub-directories and any other entries
    are left alone (in case we would like to use the directory for
    something else). The vector store and the parquet file are not touched.
    We can do it since the embeddings are in the parquet file already.

    Returns:
        A dict whose "message" says either that there was nothing to erase
        or that all data has been erased.
    """
    # List the directory once so the emptiness check and the loop agree.
    entries = os.listdir(datadir)
    if not entries:
        logger("No data to erase")
        return {"message": "No data to erase"}

    # Match on a real ".ext" suffix so an extension-less entry that happens
    # to be called "pdf" or "txt" is not deleted by accident.
    suffixes = tuple(f".{ext}" for ext in EXTENSIONS)
    for name in entries:
        path = os.path.join(datadir, name)
        if not os.path.isfile(path):
            # os.remove() cannot delete directories; skip them.
            continue
        if name == '.DS_Store' or name.lower().endswith(suffixes):
            logger(f"Removing {name}")
            os.remove(path)

    logger("All data has been erased")
    return {"message": "All data has been erased"}
|
|
|
|
|
@app.delete("/empty_collection/")
def delete_vectors():
    """Empty the collection in the vector store.

    Returns:
        A dict whose "message" reports whether the collection was erased,
        based on the truthiness of what empty_collection() returns.

    Raises:
        HTTPException: with status 500 when empty_collection() raises.
    """
    try:
        # Do NOT name this local `status`: that would shadow the imported
        # fastapi `status` module for the whole function and make the
        # handler below fail with UnboundLocalError.
        erased = empty_collection()
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e),
        ) from e
    return {"message": f"Collection{'' if erased else ' NOT'} erased!"}
|
|
|
|
|
@app.get("/list_files/")
def list_files():
    """Return the names of all entries in the data directory.

    The listing is also written to the log.
    """
    print("Listing files")
    names = os.listdir(datadir)
    logger(f"Files in data directory: {names}")
    return {"files": names}
|
|
|
|
|
def _store_upload(filename, contents):
    """Write *contents* as *filename* into datadir and datadir2; return the datadir path."""
    primary = os.path.join(datadir, filename)
    for target in (primary, os.path.join(datadir2, filename)):
        with open(target, 'wb') as out:
            out.write(contents)
    return primary


@app.post("/upload/")
async def upload_file(file: UploadFile = File(...)):
    """Upload a PDF or TXT file into the data directory, for later indexing.

    The client-supplied file name is reduced to its last path component
    before it is used, so a crafted name such as "../../x.pdf" cannot write
    outside the storage directories. The file is saved to both datadir and
    datadir2 and then passed to process_pdf() or process_txt().

    Args:
        file: The uploaded file.

    Returns:
        A dict with a "message" key. On success the processor's result is
        merged into it via dict.update(). An already-existing file, an
        unsupported extension, a failed read and a failed extraction are all
        reported through "message" rather than raised.

    Raises:
        OSError: if writing the file to disk fails (not caught here).
    """
    try:
        logger(f"Fiename detected: {file.filename}")
        # The name is untrusted input: keep only the final component
        # (treating backslashes as separators too).
        filename = os.path.basename((file.filename or "").replace("\\", "/"))
        if not filename:
            return {"message": "Only PDF & txt files are accepted"}

        filepath = os.path.join(datadir, filename)
        if os.path.exists(filepath):
            logger(f"File {filename} already exists: no processing done")
            return {"message": f"File {filename} already exists: no processing done"}

        logger(f"Receiving file: {filename}")
        contents = await file.read()
        logger(f"File reception complete!")
    except Exception as e:
        logger(f"Error during file upload: {str(e)}")
        return {"message": f"Error during file upload: {str(e)}"}

    # Pick the processor for the extension; both kinds share the same
    # save-then-process sequence below.
    if filename.endswith('.pdf'):
        processor, kind, start_message = process_pdf, "PDF", f"Starting to process {filename}"
    elif filename.endswith('.txt'):
        processor, kind, start_message = process_txt, "TXT", f"Reading {filename}"
    else:
        return {"message": "Only PDF & txt files are accepted"}

    filepath = _store_upload(filename, contents)

    try:
        logger(start_message)
        new_content = processor(filepath)
        success = {"message": f"Successfully uploaded {filename}"}
        success.update(new_content)
        return success
    except Exception as e:
        failure = f"Failed to extract text from {kind}: {str(e)}"
        logger(failure)
        return {"message": failure}
|
|
|
|
|
@app.post("/create_index/")
async def create_index():
    """Build the index for the uploaded files by calling index_data().

    Returns the value of index_data() under the "message" key; any exception
    it raises is converted into an HTTP 500 response carrying the error text.
    """
    logger("Creating index for uploaded files")
    try:
        outcome = index_data()
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
    else:
        return {"message": outcome}
|
|
|
|
|
# Request body shared by the /ask/ and /ragit/ endpoints.
class Question(BaseModel):
    # The question text; required (no default value).
    question: str = Field(...)
|
|
|
@app.post("/ask/")
async def hybrid_search(question: Question):
    """Run vector_search() on the question text and return its result.

    The result is logged and returned under the "answer" key; any exception
    is converted into an HTTP 500 response carrying the error text.
    """
    query = question.question
    logger(f"Processing question: {query}")
    try:
        hits = vector_search(query)
        logger(f"Answer: {hits}")
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
    else:
        return {"answer": hits}
|
|
|
|
|
@app.post("/ragit/")
async def ragit(question: Question):
    """Answer a question from retrieved search results.

    Calls vector_search_raw() on the question text, then passes the question
    and those results to rag_it(). The answer is returned under the "answer"
    key; any exception is converted into an HTTP 500 response carrying the
    error text.
    """
    query = question.question
    logger(f"Processing question: {query}")
    try:
        retrieved = vector_search_raw(query)
        logger(f"Search results generated: {retrieved}")

        reply = rag_it(query, retrieved)
        logger(f"Answer: {reply}")
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
    else:
        return {"answer": reply}
|
|
|
|
|
if __name__ == '__main__':
    import uvicorn

    # The listening port comes from the PORT environment variable (default 80).
    port = int(os.getenv("PORT", 80))
    print(f"Starting server on port {port}")
    # Auto-reload is switched on only for the "dev" environment.
    reload = environment == "dev"
    uvicorn.run("main:app", host="0.0.0.0", port=port, reload=reload)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|