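"""FastAPI backend for a simple retrieval-augmented QA (RAG) app.

Upload a PDF or text file via POST /api/upload, then ask questions about its
contents via POST /api/ask. A static frontend is served from ./static.
"""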
from fastapi import FastAPI, UploadFile, File, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import uvicorn
import tempfile
import os
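# Startup debug output: log the working directory and bundled static assets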
print("CWD:", os.getcwd())
print("Static files:", os.listdir("static") if os.path.exists("static") else "No static directory")
from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader
from aimakerspace.openai_utils.prompts import (
    UserRolePrompt,
    SystemRolePrompt,
)
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.chatmodel import ChatOpenAI
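# FastAPI app; JSON endpoints are grouped under the /api prefix, static files at /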
app = FastAPI()
api_router = APIRouter(prefix="/api")
# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize global variables
text_splitter = CharacterTextSplitter()
vector_db = None
qa_pipeline = None
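# System prompt: the model should answer only from the supplied context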
system_template = """\
Use the following context to answer the user's question. If you cannot find the answer in the context, say you don't know the answer."""
system_role_prompt = SystemRolePrompt(system_template)
user_prompt_template = """\
Context:
{context}
Question:
{question}
"""
user_role_prompt = UserRolePrompt(user_prompt_template)
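# Minimal RAG pipeline: retrieve the most similar chunks for a query, then
# stream an LLM answer grounded in that context.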
class RetrievalAugmentedQAPipeline:
    def __init__(self, llm: ChatOpenAI, vector_db_retriever: VectorDatabase) -> None:
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever

    async def arun_pipeline(self, user_query: str):
        # Retrieve the top-4 most similar chunks for the query
        context_list = self.vector_db_retriever.search_by_text(user_query, k=4)

        context_prompt = ""
        for context in context_list:
            context_prompt += context[0] + "\n"

        # Format the prompts as strings
        formatted_user_prompt = user_prompt_template.format(
            question=user_query,
            context=context_prompt
        )

        messages = [
            {"role": "system", "content": system_template},
            {"role": "user", "content": formatted_user_prompt}
        ]

        # Use astream for async streaming
        response_text = ""
        async for chunk in self.llm.astream(messages):
            response_text += chunk

        return {"response": response_text, "context": context_list}

def process_file(file_path: str, file_name: str):
    # Create appropriate loader
    if file_name.lower().endswith('.pdf'):
        loader = PDFLoader(file_path)
    else:
        loader = TextFileLoader(file_path)

    # Load and process the documents
    documents = loader.load_documents()
    texts = text_splitter.split_texts(documents)
    return texts
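
# Request body schema for the /api/ask endpoint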
class Question(BaseModel):
    query: str

@api_router.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    global vector_db, qa_pipeline

    # Create a temporary file
    suffix = f".{file.filename.split('.')[-1]}"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        # Copy uploaded file content
        content = await file.read()
        temp_file.write(content)
        temp_file.flush()

        try:
            # Process the file
            texts = process_file(temp_file.name, file.filename)

            # Create vector store
            vector_db = VectorDatabase()
            vector_db = await vector_db.abuild_from_list(texts)

            # Initialize QA pipeline
            chat_openai = ChatOpenAI()
            qa_pipeline = RetrievalAugmentedQAPipeline(
                vector_db_retriever=vector_db,
                llm=chat_openai
            )

            return {"message": f"Successfully processed {file.filename}", "num_chunks": len(texts)}
        finally:
            # Clean up
            os.unlink(temp_file.name)

@api_router.post("/ask")
async def ask_question(question: Question):
    if not qa_pipeline:
        return {"error": "Please upload a document first"}

    result = await qa_pipeline.arun_pipeline(question.query)
    return {
        "answer": result["response"],
        "context": [context[0] for context in result["context"]]
    }
# Include API router
app.include_router(api_router)
# Mount static files last
app.mount("/static", StaticFiles(directory="static"), name="static")
app.mount("/", StaticFiles(directory="static", html=True), name="root")
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
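
# Example usage (assuming the server is running locally on port 8000 and a
# document named mydoc.pdf is on hand):
#   curl -F "file=@mydoc.pdf" http://localhost:8000/api/upload
#   curl -X POST http://localhost:8000/api/ask \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What is this document about?"}'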