"""FastAPI backend for a NotebookLM-like document Q&A tool.

Endpoints:
    POST /upload    - accept a PDF/TXT file, extract its text, embed for RAG.
    POST /query     - answer a question against a previously uploaded file.
    POST /summarize - summarize a previously uploaded file.

NOTE(review): the Qwen model is loaded at import time, so the first startup
blocks until the (multi-GB) weights are downloaded/loaded.
"""

from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import os
import logging
from parser import parse_pdf, parse_text
from rag import RAG
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# App setup
app = FastAPI(title="NotebookLM-like Tool")

# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve static/index.html
app.mount("/static", StaticFiles(directory="static"), name="static")


@app.get("/")
def serve_index():
    """Serve the single-page frontend."""
    return FileResponse("static/index.html")


# Load smaller Qwen model for Hugging Face CPU
logger.info("Loading Qwen 1.8B Chat model...")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-1_8B-Chat", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-1_8B-Chat", trust_remote_code=True)
logger.info("Model loaded.")


def generate_response(prompt: str) -> str:
    """Generate a completion for *prompt*, returning only the newly generated text.

    FIX: ``model.generate`` returns the input tokens followed by the new
    tokens; the original decoded the whole sequence, so the prompt (question
    plus retrieved context) was echoed back to the client. We slice off the
    prompt tokens before decoding.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    # Inference only: no_grad avoids building an autograd graph (memory win on CPU).
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=300, do_sample=True)
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


# In-memory storage and RAG engine.
# NOTE(review): `documents` lives only in process memory — uploads are lost on
# restart, and file_id generation below assumes entries are never deleted.
rag = RAG()
documents = {}


class QueryRequest(BaseModel):
    # Natural-language question to answer.
    question: str
    # Identifier returned by /upload for the document to query.
    file_id: str


@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Accept a PDF or TXT upload, persist it, extract text, and embed it.

    Returns ``{"file_id", "file_name"}`` on success.
    Raises 400 for unsupported content types, 500 on any processing failure.
    """
    try:
        logger.info(f"Received file: {file.filename}")
        if file.content_type not in ["application/pdf", "text/plain"]:
            raise HTTPException(status_code=400, detail="Only PDF or TXT allowed")
        file_id = str(len(documents) + 1)
        # Strip path separators so a crafted filename can't escape uploads/.
        safe_name = file.filename.replace("/", "_").replace("\\", "_")
        file_path = f"uploads/{file_id}_{safe_name}"
        os.makedirs("uploads", exist_ok=True)
        file_bytes = await file.read()
        with open(file_path, "wb") as f:
            f.write(file_bytes)
        # Parser is chosen by filename extension; content_type was validated above.
        text = parse_pdf(file_bytes) if file.filename.endswith(".pdf") else parse_text(file_bytes)
        if not text.strip():
            text = "No extractable text found in this file."
        documents[file_id] = {"file_path": file_path, "file_name": file.filename}
        await rag.embed_document(file_id, text)
        return {"file_id": file_id, "file_name": file.filename}
    except HTTPException:
        # FIX: the original blanket `except Exception` swallowed the 400 raised
        # above and converted it into a generic 500. Re-raise HTTP errors as-is.
        raise
    except Exception as e:
        logger.error(f"Upload failed: {e}")
        raise HTTPException(status_code=500, detail="Failed to process file")


@app.post("/query")
async def query_file(request: QueryRequest):
    """Answer ``request.question`` using RAG context from the given document.

    Raises 404 for an unknown file_id, 500 on retrieval/generation failure.
    """
    if request.file_id not in documents:
        raise HTTPException(status_code=404, detail="File not found")
    try:
        context = await rag.query_document(request.question, request.file_id)
        if not context:
            return {"answer": "No relevant info found in the document."}
        context_text = "\n".join(context)
        prompt = f"Using the context, answer the question: {request.question}\n\nContext:\n{context_text}"
        answer = generate_response(prompt)
        return {"answer": answer}
    except Exception as e:
        logger.error(f"Query failed: {e}")
        raise HTTPException(status_code=500, detail="Error answering question")


@app.post("/summarize")
async def summarize_file(file_id: str):
    """Summarize an uploaded document in ~100 words.

    NOTE(review): `file_id` is a bare parameter on a POST route, so FastAPI
    treats it as a query parameter (``/summarize?file_id=1``) — confirm the
    frontend calls it that way before changing.
    Raises 404 for an unknown file_id, 500 on parse/generation failure.
    """
    if file_id not in documents:
        raise HTTPException(status_code=404, detail="File not found")
    try:
        # Re-read and re-parse from disk rather than caching extracted text.
        with open(documents[file_id]["file_path"], "rb") as f:
            content = f.read()
        text = parse_pdf(content) if documents[file_id]["file_name"].endswith(".pdf") else parse_text(content)
        if not text.strip():
            raise ValueError("No text found to summarize.")
        # Truncate to keep the prompt within the small model's context budget.
        prompt = f"Summarize this text in 100 words or less:\n\n{text[:5000]}"
        summary = generate_response(prompt)
        return {"summary": summary}
    except Exception as e:
        logger.error(f"Summarization failed: {e}")
        raise HTTPException(status_code=500, detail="Error generating summary")