from fastapi import FastAPI, UploadFile, File, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import uvicorn
import tempfile
import os
print("CWD:", os.getcwd())
print("Static files:", os.listdir("static") if os.path.exists("static") else "No static directory")
from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader
from aimakerspace.openai_utils.prompts import (
    UserRolePrompt,
    SystemRolePrompt,
)
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.chatmodel import ChatOpenAI
app = FastAPI()
api_router = APIRouter(prefix="/api")
# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize global variables
text_splitter = CharacterTextSplitter()
vector_db = None
qa_pipeline = None
system_template = """\
Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer."""
system_role_prompt = SystemRolePrompt(system_template)
user_prompt_template = """\
Context:
{context}
Question:
{question}
"""
user_role_prompt = UserRolePrompt(user_prompt_template)
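
# As an illustrative example (values hypothetical), formatting the template with
# context "Paris is the capital of France.\n" and question "What is the capital
# of France?" produces:
#
#     Context:
#     Paris is the capital of France.
#
#     Question:
#     What is the capital of France?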
class RetrievalAugmentedQAPipeline:
    def __init__(self, llm: ChatOpenAI, vector_db_retriever: VectorDatabase) -> None:
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever

    async def arun_pipeline(self, user_query: str):
        # Retrieve the top-k most similar chunks for the query
        context_list = self.vector_db_retriever.search_by_text(user_query, k=4)
        context_prompt = ""
        for context in context_list:
            context_prompt += context[0] + "\n"
        # Format the prompts as strings
        formatted_user_prompt = user_prompt_template.format(
            question=user_query,
            context=context_prompt
        )
        messages = [
            {"role": "system", "content": system_template},
            {"role": "user", "content": formatted_user_prompt}
        ]
        # Use astream for async streaming and collect the full response
        response_text = ""
        async for chunk in self.llm.astream(messages):
            response_text += chunk
        return {"response": response_text, "context": context_list}
def process_file(file_path: str, file_name: str):
    # Choose a loader based on the file extension
    if file_name.lower().endswith('.pdf'):
        loader = PDFLoader(file_path)
    else:
        loader = TextFileLoader(file_path)
    # Load the documents and split them into chunks
    documents = loader.load_documents()
    texts = text_splitter.split_texts(documents)
    return texts
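
# For example, process_file("/tmp/report.pdf", "report.pdf") would route the
# file through PDFLoader and return a list of chunk strings; a .txt file would
# go through TextFileLoader instead. (The file names here are hypothetical.)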
class Question(BaseModel):
    query: str
@api_router.post("/upload")
async def upload_file(file: UploadFile = File(...)):
global vector_db, qa_pipeline
# Create a temporary file
suffix = f".{file.filename.split('.')[-1]}"
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
# Copy uploaded file content
content = await file.read()
temp_file.write(content)
temp_file.flush()
try:
# Process the file
texts = process_file(temp_file.name, file.filename)
# Create vector store
vector_db = VectorDatabase()
vector_db = await vector_db.abuild_from_list(texts)
# Initialize QA pipeline
chat_openai = ChatOpenAI()
qa_pipeline = RetrievalAugmentedQAPipeline(
vector_db_retriever=vector_db,
llm=chat_openai
)
return {"message": f"Successfully processed {file.filename}", "num_chunks": len(texts)}
finally:
# Clean up
os.unlink(temp_file.name)
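
# Example request (assuming the server runs on localhost:8000; the file name
# is illustrative):
#
#     curl -F "file=@report.pdf" http://localhost:8000/api/upload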
@api_router.post("/ask")
async def ask_question(question: Question):
if not qa_pipeline:
return {"error": "Please upload a document first "}
result = await qa_pipeline.arun_pipeline(question.query)
return {
"answer": result["response"],
"context": [context[0] for context in result["context"]]
}
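
# Example request, once a document has been uploaded (query text is illustrative):
#
#     curl -X POST http://localhost:8000/api/ask \
#          -H "Content-Type: application/json" \
#          -d '{"query": "What is the main topic?"}'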
# Include API router
app.include_router(api_router)
# Mount static files last so they don't shadow the API routes
app.mount("/static", StaticFiles(directory="static"), name="static")
app.mount("/", StaticFiles(directory="static", html=True), name="root")
print("CWD:", os.getcwd())
print("Static files:", os.listdir("static") if os.path.exists("static") else "No static directory")
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
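
# For development, the server can also be started with auto-reload, assuming
# this file is named app.py and uvicorn is on the PATH:
#
#     uvicorn app:app --reload --host 0.0.0.0 --port 8000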