from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import os
import logging
from parser import parse_pdf, parse_text
from rag import RAG
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# App setup
app = FastAPI(title="NotebookLM-like Tool")
# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Serve static/index.html
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.get("/")
def serve_index():
    return FileResponse("static/index.html")
# Load smaller Qwen model for Hugging Face CPU
logger.info("Loading Qwen 1.8B Chat model...")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-1_8B-Chat", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-1_8B-Chat", trust_remote_code=True)
logger.info("Model loaded.")
def generate_response(prompt: str) -> str:
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=300, do_sample=True)
    # Decode only the newly generated tokens, not the echoed prompt
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
# In-memory storage and RAG engine
rag = RAG()
documents = {}
class QueryRequest(BaseModel):
    question: str
    file_id: str
@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    try:
        logger.info(f"Received file: {file.filename}")
        if file.content_type not in ["application/pdf", "text/plain"]:
            raise HTTPException(status_code=400, detail="Only PDF or TXT allowed")
        file_id = str(len(documents) + 1)
        safe_name = file.filename.replace("/", "_").replace("\\", "_")
        file_path = f"uploads/{file_id}_{safe_name}"
        os.makedirs("uploads", exist_ok=True)
        file_bytes = await file.read()
        with open(file_path, "wb") as f:
            f.write(file_bytes)
        text = parse_pdf(file_bytes) if file.filename.endswith(".pdf") else parse_text(file_bytes)
        if not text.strip():
            text = "No extractable text found in this file."
        documents[file_id] = {"file_path": file_path, "file_name": file.filename}
        await rag.embed_document(file_id, text)
        return {"file_id": file_id, "file_name": file.filename}
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 400 above) pass through unchanged
        raise
    except Exception as e:
        logger.error(f"Upload failed: {e}")
        raise HTTPException(status_code=500, detail="Failed to process file")
@app.post("/query")
async def query_file(request: QueryRequest):
    if request.file_id not in documents:
        raise HTTPException(status_code=404, detail="File not found")
    try:
        context = await rag.query_document(request.question, request.file_id)
        if not context:
            return {"answer": "No relevant info found in the document."}
        context_text = "\n".join(context)
        prompt = f"Using the context, answer the question: {request.question}\n\nContext:\n{context_text}"
        answer = generate_response(prompt)
        return {"answer": answer}
    except Exception as e:
        logger.error(f"Query failed: {e}")
        raise HTTPException(status_code=500, detail="Error answering question")
@app.post("/summarize")
async def summarize_file(file_id: str):
    if file_id not in documents:
        raise HTTPException(status_code=404, detail="File not found")
    try:
        with open(documents[file_id]["file_path"], "rb") as f:
            content = f.read()
        text = parse_pdf(content) if documents[file_id]["file_name"].endswith(".pdf") else parse_text(content)
        if not text.strip():
            raise ValueError("No text found to summarize.")
        prompt = f"Summarize this text in 100 words or less:\n\n{text[:5000]}"
        summary = generate_response(prompt)
        return {"summary": summary}
    except Exception as e:
        logger.error(f"Summarization failed: {e}")
        raise HTTPException(status_code=500, detail="Error generating summary")