# NOTE(review): removed non-Python scrape residue ("Spaces: / Running / Running")
# that preceded the module and broke parsing.
from fastapi import FastAPI, UploadFile, File, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import uvicorn
import tempfile
import os

# Import-time debug output: confirm the working directory and whether the
# "static" asset directory (mounted at the bottom of this module) exists.
print("CWD:", os.getcwd())
print("Static files:", os.listdir("static") if os.path.exists("static") else "No static directory")

# Project-local RAG helpers (loaders, splitter, vector store, chat client).
# Kept after the prints to preserve the original import/side-effect order.
from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader
from aimakerspace.openai_utils.prompts import (
    UserRolePrompt,
    SystemRolePrompt,
)
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.chatmodel import ChatOpenAI
app = FastAPI()
# All JSON endpoints hang off this router so static files can own "/".
api_router = APIRouter(prefix="/api")

# Enable CORS so a separately-served frontend can call the API.
# NOTE(review): per the CORS spec, browsers reject allow_origins=["*"]
# combined with allow_credentials=True — pin concrete origins (or drop
# credentials) before production; confirm which the frontend needs.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Module-level state: the splitter is stateless and shared across requests;
# vector_db and qa_pipeline start empty and are (re)built by the upload
# endpoint, which declares them `global`.
text_splitter = CharacterTextSplitter()
vector_db = None
qa_pipeline = None

# System prompt sent verbatim with every chat request (wording preserved).
system_template = """\
Use the following context to answer a users question. If you cannot find the answer in the context, say you don't know the answer."""
system_role_prompt = SystemRolePrompt(system_template)

# User prompt template; `.format()` fills {context} and {question} at query time.
user_prompt_template = """\
Context:
{context}
Question:
{question}
"""
user_role_prompt = UserRolePrompt(user_prompt_template)
class RetrievalAugmentedQAPipeline:
    """Retrieval-augmented QA: retrieve top-k chunks, then stream an LLM answer."""

    # BUG FIX: the annotation was `llm: ChatOpenAI()` — calling the class in
    # the annotation constructs a throwaway client instance at class-definition
    # time. The annotation must be the type itself.
    def __init__(self, llm: ChatOpenAI, vector_db_retriever: VectorDatabase) -> None:
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever

    async def arun_pipeline(self, user_query: str) -> dict:
        """Answer `user_query` using retrieved context.

        Returns a dict with:
          - "response": the full accumulated LLM answer text
          - "context": the raw retriever results (list of (text, score)-style
            pairs — only element [0], the chunk text, is used here)
        """
        context_list = self.vector_db_retriever.search_by_text(user_query, k=4)

        # Concatenate the retrieved chunk texts, one per line, for the prompt.
        context_prompt = "".join(context[0] + "\n" for context in context_list)

        # Format the prompts as plain strings (the Role-prompt objects created
        # at module level are not used by this pipeline).
        formatted_user_prompt = user_prompt_template.format(
            question=user_query,
            context=context_prompt,
        )

        messages = [
            {"role": "system", "content": system_template},
            {"role": "user", "content": formatted_user_prompt},
        ]

        # Stream the completion and accumulate it into a single string.
        response_text = ""
        async for chunk in self.llm.astream(messages):
            response_text += chunk

        return {"response": response_text, "context": context_list}
def process_file(file_path: str, file_name: str) -> list:
    """Load a document from `file_path` and split it into text chunks.

    The loader is chosen by the (case-insensitive) extension of `file_name`:
    PDFLoader for ".pdf", TextFileLoader for everything else. Splitting uses
    the shared module-level `text_splitter`.
    """
    if file_name.lower().endswith('.pdf'):
        loader = PDFLoader(file_path)
    else:
        loader = TextFileLoader(file_path)

    documents = loader.load_documents()
    return text_splitter.split_texts(documents)
class Question(BaseModel):
    """Request body for the ask endpoint."""

    # The user's natural-language question.
    query: str
# BUG FIX: the handler had no route decorator, so it was never registered on
# `api_router` and the endpoint was unreachable even though the router is
# included below. Path chosen to fit the "/api" prefix — confirm it matches
# what the frontend calls.
@api_router.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Accept a document upload, index it, and (re)build the QA pipeline.

    Writes the upload to a named temporary file (the loaders need a path),
    chunks it via `process_file`, builds a fresh VectorDatabase, and replaces
    the module-level `qa_pipeline`. The temp file is always removed.
    """
    global vector_db, qa_pipeline

    # Keep the original extension so process_file picks the right loader.
    suffix = f".{file.filename.split('.')[-1]}"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        content = await file.read()
        temp_file.write(content)
        temp_file.flush()

    # Process outside the `with` so the handle is closed before the loaders
    # reopen the file by path (required on Windows, harmless elsewhere).
    try:
        texts = process_file(temp_file.name, file.filename)

        # Build a fresh vector store from the chunks.
        vector_db = VectorDatabase()
        vector_db = await vector_db.abuild_from_list(texts)

        # Wire up a new QA pipeline against the new store.
        qa_pipeline = RetrievalAugmentedQAPipeline(
            vector_db_retriever=vector_db,
            llm=ChatOpenAI(),
        )

        return {"message": f"Successfully processed {file.filename}", "num_chunks": len(texts)}
    finally:
        # Clean up the temp file even if processing failed.
        os.unlink(temp_file.name)
# BUG FIX: like upload_file, this handler had no route decorator and was
# never registered on `api_router`. Path chosen to fit the "/api" prefix —
# confirm it matches what the frontend calls.
@api_router.post("/ask")
async def ask_question(question: Question):
    """Answer `question` against the most recently uploaded document.

    Returns {"answer", "context"} on success, or {"error"} when no document
    has been uploaded yet (qa_pipeline is still None).
    """
    if not qa_pipeline:
        return {"error": "Please upload a document first "}

    result = await qa_pipeline.arun_pipeline(question.query)
    return {
        "answer": result["response"],
        # Strip scores: expose only the chunk text of each retrieved context.
        "context": [context[0] for context in result["context"]],
    }
# Register the API routes first, then mount static assets; the catch-all
# "/" mount goes last so it cannot shadow the /api routes.
app.include_router(api_router)

app.mount("/static", StaticFiles(directory="static"), name="static")
app.mount("/", StaticFiles(directory="static", html=True), name="root")

# Debug output repeated post-setup (same check as at import top).
print("CWD:", os.getcwd())
print("Static files:", os.listdir("static") if os.path.exists("static") else "No static directory")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)