from fastapi import FastAPI, UploadFile, File, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import uvicorn
import tempfile
import os

print("CWD:", os.getcwd())
print("Static files:", os.listdir("static") if os.path.exists("static") else "No static directory")

from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader
from aimakerspace.openai_utils.prompts import (
    UserRolePrompt,
    SystemRolePrompt,
)
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.chatmodel import ChatOpenAI

app = FastAPI()
api_router = APIRouter(prefix="/api")

# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize global variables
text_splitter = CharacterTextSplitter()
vector_db = None
qa_pipeline = None

system_template = """\
Use the following context to answer a user's question.
If you cannot find the answer in the context, say you don't know the answer."""
system_role_prompt = SystemRolePrompt(system_template)

user_prompt_template = """\
Context:
{context}

Question:
{question}
"""
user_role_prompt = UserRolePrompt(user_prompt_template)


class RetrievalAugmentedQAPipeline:
    def __init__(self, llm: ChatOpenAI, vector_db_retriever: VectorDatabase) -> None:
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever

    async def arun_pipeline(self, user_query: str):
        # Retrieve the top-k most similar chunks for the query
        context_list = self.vector_db_retriever.search_by_text(user_query, k=4)

        context_prompt = ""
        for context in context_list:
            context_prompt += context[0] + "\n"

        # Format the prompts as strings
        formatted_user_prompt = user_prompt_template.format(
            question=user_query, context=context_prompt
        )

        messages = [
            {"role": "system", "content": system_template},
            {"role": "user", "content": formatted_user_prompt},
        ]

        # Use astream for async streaming and accumulate the full response
        response_text = ""
        async for chunk in self.llm.astream(messages):
            response_text += chunk

        return {"response": response_text, "context": context_list}


def process_file(file_path: str, file_name: str):
    # Create the appropriate loader for the file type
    if file_name.lower().endswith('.pdf'):
        loader = PDFLoader(file_path)
    else:
        loader = TextFileLoader(file_path)

    # Load and split the documents into chunks
    documents = loader.load_documents()
    texts = text_splitter.split_texts(documents)
    return texts


class Question(BaseModel):
    query: str


@api_router.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    global vector_db, qa_pipeline

    # Create a temporary file
    suffix = f".{file.filename.split('.')[-1]}"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        # Copy uploaded file content
        content = await file.read()
        temp_file.write(content)
        temp_file.flush()

        try:
            # Process the file
            texts = process_file(temp_file.name, file.filename)

            # Create vector store
            vector_db = VectorDatabase()
            vector_db = await vector_db.abuild_from_list(texts)

            # Initialize QA pipeline
            chat_openai = ChatOpenAI()
            qa_pipeline = RetrievalAugmentedQAPipeline(
                vector_db_retriever=vector_db,
                llm=chat_openai
            )

            return {"message": f"Successfully processed {file.filename}", "num_chunks": len(texts)}
        finally:
            # Clean up
            os.unlink(temp_file.name)


@api_router.post("/ask")
async def ask_question(question: Question):
    if not qa_pipeline:
        return {"error": "Please upload a document first"}

    result = await qa_pipeline.arun_pipeline(question.query)
    return {
        "answer": result["response"],
        "context": [context[0] for context in result["context"]]
    }

# Include API router
app.include_router(api_router)

# Mount static files last
app.mount("/static", StaticFiles(directory="static"), name="static")
app.mount("/", StaticFiles(directory="static", html=True), name="root")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
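
# Example usage (a minimal sketch — it assumes the server is started from a
# directory containing a "static/" folder, that OPENAI_API_KEY is set in the
# environment, and that "sample.pdf" is a placeholder for a real local file):
#
#   # Upload a document (PDF or plain text) to build the vector store:
#   curl -X POST -F "file=@sample.pdf" http://localhost:8000/api/upload
#
#   # Ask a question against the uploaded document:
#   curl -X POST -H "Content-Type: application/json" \
#        -d '{"query": "What is this document about?"}' \
#        http://localhost:8000/api/ask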