File size: 4,611 Bytes
868e135
 
 
 
 
 
fdc14d4
868e135
512bf9c
 
 
fdc14d4
 
 
 
 
 
 
868e135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdc14d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee6d6f4
 
 
 
 
fdc14d4
ee6d6f4
 
 
 
 
 
868e135
ee6d6f4
868e135
fdc14d4
868e135
fdc14d4
868e135
 
 
 
 
 
 
 
 
 
 
fdc14d4
868e135
 
fdc14d4
868e135
 
 
fdc14d4
868e135
 
fdc14d4
868e135
 
 
 
fdc14d4
 
868e135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdc14d4
868e135
 
fdc14d4
868e135
 
 
908b273
fdc14d4
868e135
 
 
 
 
fdc14d4
868e135
 
fdc14d4
868e135
f1f1db3
 
fdc14d4
e09b558
 
 
868e135
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from fastapi import FastAPI, UploadFile, File, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import uvicorn
import tempfile
import os

# Debug output at import time: working directory and static-asset listing.
# NOTE(review): duplicated near the bottom of this file — consider logging instead.
print("CWD:", os.getcwd())
print("Static files:", os.listdir("static") if os.path.exists("static") else "No static directory")

from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader
from aimakerspace.openai_utils.prompts import (
    UserRolePrompt,
    SystemRolePrompt,
)
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.chatmodel import ChatOpenAI

# Application and API router; all JSON endpoints live under the /api prefix.
app = FastAPI()
api_router = APIRouter(prefix="/api")

# Enable CORS for all origins.
# NOTE(review): per the CORS spec, browsers reject `Access-Control-Allow-Origin: *`
# combined with credentials — if cookies/auth are needed, list explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Module-level state: the splitter is stateless and shared; vector_db and
# qa_pipeline are populated by the /upload endpoint and read by /ask.
text_splitter = CharacterTextSplitter()
vector_db = None
qa_pipeline = None

# System prompt: constrains the LLM to the retrieved context.
system_template = """\
Use the following context to answer a users question. If you cannot find the answer in the context, say you don't know the answer."""
system_role_prompt = SystemRolePrompt(system_template)

# User prompt template; filled with retrieved context and the user's question
# by RetrievalAugmentedQAPipeline.arun_pipeline.
user_prompt_template = """\
Context:
{context}

Question:
{question}
"""
user_role_prompt = UserRolePrompt(user_prompt_template)

class RetrievalAugmentedQAPipeline:
    """Answer user queries by retrieving relevant chunks from a vector DB
    and streaming an LLM completion grounded in that context.

    Uses the module-level ``system_template`` / ``user_prompt_template``.
    """

    def __init__(self, llm: ChatOpenAI, vector_db_retriever: VectorDatabase) -> None:
        # Fixed: the annotation was previously `ChatOpenAI()` (a call), which
        # constructed a client at class-definition time just to serve as an
        # annotation; the class itself is the correct annotation.
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever

    async def arun_pipeline(self, user_query: str) -> dict:
        """Retrieve the top-4 chunks for ``user_query``, build the RAG prompt,
        and accumulate the streamed LLM answer.

        Returns:
            dict with "response" (full answer text) and "context" (the raw
            retriever results; each entry's first element is the chunk text).
        """
        context_list = self.vector_db_retriever.search_by_text(user_query, k=4)

        # Join retrieved chunk texts into one newline-separated context blob
        # (join avoids the quadratic string += loop).
        context_prompt = "".join(ctx[0] + "\n" for ctx in context_list)

        # Format the prompts as plain strings.
        formatted_user_prompt = user_prompt_template.format(
            question=user_query,
            context=context_prompt,
        )

        messages = [
            {"role": "system", "content": system_template},
            {"role": "user", "content": formatted_user_prompt},
        ]

        # Consume the async stream, accumulating the full response text.
        response_text = ""
        async for chunk in self.llm.astream(messages):
            response_text += chunk

        return {"response": response_text, "context": context_list}

def process_file(file_path: str, file_name: str):
    """Load the document at ``file_path`` and return its split text chunks.

    The loader is chosen from ``file_name``'s extension: PDFLoader for
    ``.pdf`` (case-insensitive), TextFileLoader otherwise.
    """
    is_pdf = file_name.lower().endswith('.pdf')
    loader = PDFLoader(file_path) if is_pdf else TextFileLoader(file_path)

    # Load the raw documents and split them with the shared module splitter.
    documents = loader.load_documents()
    return text_splitter.split_texts(documents)

class Question(BaseModel):
    """Request body for the /api/ask endpoint."""
    # The user's natural-language question.
    query: str

@api_router.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Accept a document upload, build the vector DB, and initialize the QA pipeline.

    The upload is spooled to a temporary file (preserving its extension so the
    loader choice in ``process_file`` works), processed, and then deleted.

    Returns:
        dict with a success message and the number of text chunks indexed.
    """
    global vector_db, qa_pipeline

    # Preserve the real extension ("" when there is none) — the previous
    # `filename.split('.')[-1]` produced a bogus suffix for extensionless names.
    _, suffix = os.path.splitext(file.filename or "")
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        # Copy the uploaded content and remember the path.
        content = await file.read()
        temp_file.write(content)
        temp_path = temp_file.name

    # The handle is closed now: re-opening (in process_file) and unlinking an
    # open file both fail on Windows, so processing happens outside the `with`.
    try:
        # Split the document into chunks.
        texts = process_file(temp_path, file.filename)

        # Build the vector store from the chunks.
        vector_db = VectorDatabase()
        vector_db = await vector_db.abuild_from_list(texts)

        # Initialize the QA pipeline used by /ask.
        chat_openai = ChatOpenAI()
        qa_pipeline = RetrievalAugmentedQAPipeline(
            vector_db_retriever=vector_db,
            llm=chat_openai,
        )

        return {"message": f"Successfully processed {file.filename}", "num_chunks": len(texts)}
    finally:
        # Always remove the temporary file, even on failure.
        os.unlink(temp_path)

@api_router.post("/ask")
async def ask_question(question: Question):
    """Answer ``question.query`` against the currently uploaded document.

    Returns:
        dict with "answer" (the LLM response) and "context" (the retrieved
        chunk texts), or an error payload when no document is uploaded yet.
    """
    if not qa_pipeline:
        # Fixed: the original message carried stray trailing whitespace.
        return {"error": "Please upload a document first"}

    result = await qa_pipeline.arun_pipeline(question.query)
    return {
        "answer": result["response"],
        # Each context entry's first element is the chunk text; drop the score.
        "context": [ctx[0] for ctx in result["context"]],
    }

# Include the API router (all /api/* endpoints) before any static mounts
# so API routes take precedence over the catch-all root mount.
app.include_router(api_router)

# Mount static files last: /static serves assets explicitly, and the root
# mount with html=True serves index.html for "/".
app.mount("/static", StaticFiles(directory="static"), name="static")
app.mount("/", StaticFiles(directory="static", html=True), name="root")

# NOTE(review): these debug prints duplicate the pair at the top of the file —
# consider removing one pair or switching to the logging module.
print("CWD:", os.getcwd())
print("Static files:", os.listdir("static") if os.path.exists("static") else "No static directory")

# Run a development server when executed directly (not under an external ASGI server).
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)