|
from fastapi import FastAPI, UploadFile, File, HTTPException |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.responses import FileResponse |
|
from fastapi.staticfiles import StaticFiles |
|
from pydantic import BaseModel |
|
import os |
|
import logging |
|
from parser import parse_pdf, parse_text |
|
from rag import RAG |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
import torch |
|
|
|
|
|
# Configure root logging once at import time so all module loggers emit INFO+.
logging.basicConfig(level=logging.INFO)

# Module-level logger, keyed by module name per stdlib convention.
logger = logging.getLogger(__name__)

# Single application instance; routes below attach to this object.
app = FastAPI(title="NotebookLM-like Tool")
|
|
|
|
|
# Development-friendly CORS: accept requests from any browser origin.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# effectively "trust everything" — browsers reject a literal "*" on
# credentialed requests, so the middleware echoes the caller's Origin back.
# Restrict to the real frontend origin(s) before production — confirm.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
# Serve the frontend assets (HTML/JS/CSS) from ./static under the /static path.
app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
|
@app.get("/")
def serve_index():
    """Serve the single-page frontend shell at the site root."""
    index_page = FileResponse("static/index.html")
    return index_page
|
|
|
|
|
# Load tokenizer and model once at import time so the first request does not
# pay the multi-second initialization cost.
# NOTE(review): this blocks app startup and loads the model on CPU by default;
# consider a FastAPI lifespan hook and explicit device placement — confirm.
# NOTE(review): trust_remote_code=True executes code from the model repo;
# pinning a revision would reduce supply-chain risk — confirm.
logger.info("Loading Qwen 1.8B Chat model...")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-1_8B-Chat", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-1_8B-Chat", trust_remote_code=True)
logger.info("Model loaded.")
|
|
|
def generate_response(prompt: str) -> str:
    """Generate a completion for *prompt* with the module-level Qwen model.

    Args:
        prompt: Fully formed prompt text (question plus retrieved context).

    Returns:
        Only the newly generated continuation, with the echoed prompt tokens
        and special tokens stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    # Inference only: disable autograd bookkeeping to cut memory and latency.
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=300, do_sample=True)
    # Bug fix: causal-LM generate() returns prompt + continuation in one
    # sequence; slice off the prompt so callers don't get their input echoed.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
|
|
|
|
|
# Process-wide singletons shared by all request handlers.
rag = RAG()

# In-memory registry: file_id -> {"file_path": ..., "file_name": ...}.
# Lost on restart; not shared across worker processes.
documents = {}
|
|
|
class QueryRequest(BaseModel):
    """Request body for /query: a question scoped to one uploaded document."""

    # Natural-language question to answer from the document.
    question: str
    # Identifier returned by /upload for the target document.
    file_id: str
|
|
|
@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Accept a PDF or TXT upload, persist it, and index it for retrieval.

    Returns:
        {"file_id": ..., "file_name": ...} for use with /query and /summarize.

    Raises:
        HTTPException: 400 for unsupported content types, 500 on any other
        processing failure.
    """
    try:
        logger.info(f"Received file: {file.filename}")
        if file.content_type not in ["application/pdf", "text/plain"]:
            raise HTTPException(status_code=400, detail="Only PDF or TXT allowed")

        # Sequential ids are adequate for this single-process, in-memory registry.
        file_id = str(len(documents) + 1)
        # Strip path separators so a crafted filename cannot escape uploads/.
        safe_name = file.filename.replace("/", "_").replace("\\", "_")
        file_path = f"uploads/{file_id}_{safe_name}"

        os.makedirs("uploads", exist_ok=True)
        file_bytes = await file.read()

        with open(file_path, "wb") as f:
            f.write(file_bytes)

        # NOTE: parser choice keys on the filename extension, not content_type;
        # a ".pdf" name sent as text/plain is still parsed as PDF.
        text = parse_pdf(file_bytes) if file.filename.endswith(".pdf") else parse_text(file_bytes)
        if not text.strip():
            text = "No extractable text found in this file."

        documents[file_id] = {"file_path": file_path, "file_name": file.filename}
        await rag.embed_document(file_id, text)

        return {"file_id": file_id, "file_name": file.filename}
    except HTTPException:
        # Bug fix: without this passthrough, the 400 raised above was caught by
        # the generic handler below and re-surfaced to clients as a 500.
        raise
    except Exception:
        # logger.exception records the full traceback, not just the message.
        logger.exception("Upload failed")
        raise HTTPException(status_code=500, detail="Failed to process file")
|
|
|
@app.post("/query")
async def query_file(request: QueryRequest):
    """Answer a question against one previously uploaded document.

    Retrieves relevant chunks via the RAG index, builds a prompt from them,
    and returns the model's answer. 404 if the file_id is unknown; 500 on
    any retrieval/generation failure.
    """
    if request.file_id not in documents:
        raise HTTPException(status_code=404, detail="File not found")
    try:
        retrieved = await rag.query_document(request.question, request.file_id)
        if not retrieved:
            return {"answer": "No relevant info found in the document."}

        # Stitch the retrieved chunks into a single context section.
        joined = "\n".join(retrieved)
        prompt = f"Using the context, answer the question: {request.question}\n\nContext:\n{joined}"
        return {"answer": generate_response(prompt)}
    except Exception as e:
        logger.error(f"Query failed: {e}")
        raise HTTPException(status_code=500, detail="Error answering question")
|
|
|
@app.post("/summarize")
async def summarize_file(file_id: str):
    """Produce a short (<=100 word) summary of an uploaded document.

    Re-reads and re-parses the stored file, then asks the model for a
    summary of the first 5000 characters. 404 if the file_id is unknown;
    500 on any read/parse/generation failure (including empty text).
    """
    if file_id not in documents:
        raise HTTPException(status_code=404, detail="File not found")
    try:
        record = documents[file_id]
        with open(record["file_path"], "rb") as f:
            raw = f.read()
        is_pdf = record["file_name"].endswith(".pdf")
        text = parse_pdf(raw) if is_pdf else parse_text(raw)

        if not text.strip():
            # Deliberately funneled into the 500 handler below.
            raise ValueError("No text found to summarize.")

        # Cap the input so the prompt stays within the model's context window.
        prompt = f"Summarize this text in 100 words or less:\n\n{text[:5000]}"
        return {"summary": generate_response(prompt)}
    except Exception as e:
        logger.error(f"Summarization failed: {e}")
        raise HTTPException(status_code=500, detail="Error generating summary")
|
|