# ChatLM / main.py
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import os
import logging
from parser import parse_pdf, parse_text
from rag import RAG
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# App setup
app = FastAPI(title="NotebookLM-like Tool")
# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Serve static/index.html
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.get("/")
def serve_index():
    return FileResponse("static/index.html")
# Load smaller Qwen model for Hugging Face CPU
logger.info("Loading Qwen 1.8B Chat model...")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-1_8B-Chat", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-1_8B-Chat", trust_remote_code=True)
logger.info("Model loaded.")
def generate_response(prompt: str) -> str:
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model.generate(**inputs, max_new_tokens=300, do_sample=True)
    # generate() returns the prompt plus the completion; decode only the new
    # tokens so the answer doesn't echo the prompt back
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
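
# Note: Qwen chat checkpoints loaded with trust_remote_code=True also ship a
# model.chat() helper that applies the model's chat template. A minimal sketch,
# assuming the remote code's usual signature (not used above):
#
#     response, _history = model.chat(tokenizer, prompt, history=None)
#
# Plain generate() works too, but skips the chat template, so replies may be
# less well-formed than with chat().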
# In-memory storage and RAG engine
rag = RAG()
documents = {}

class QueryRequest(BaseModel):
    question: str
    file_id: str
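
# Example /query payload (file_id is the id returned by /upload; the question
# text is illustrative):
#   {"question": "What does section 2 cover?", "file_id": "1"}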
@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
try:
logger.info(f"Received file: {file.filename}")
if file.content_type not in ["application/pdf", "text/plain"]:
raise HTTPException(status_code=400, detail="Only PDF or TXT allowed")
file_id = str(len(documents) + 1)
safe_name = file.filename.replace("/", "_").replace("\\", "_")
file_path = f"uploads/{file_id}_{safe_name}"
os.makedirs("uploads", exist_ok=True)
file_bytes = await file.read()
with open(file_path, "wb") as f:
f.write(file_bytes)
text = parse_pdf(file_bytes) if file.filename.endswith(".pdf") else parse_text(file_bytes)
if not text.strip():
text = "No extractable text found in this file."
documents[file_id] = {"file_path": file_path, "file_name": file.filename}
await rag.embed_document(file_id, text)
return {"file_id": file_id, "file_name": file.filename}
except Exception as e:
logger.error(f"Upload failed: {e}")
raise HTTPException(status_code=500, detail="Failed to process file")
@app.post("/query")
async def query_file(request: QueryRequest):
if request.file_id not in documents:
raise HTTPException(status_code=404, detail="File not found")
try:
context = await rag.query_document(request.question, request.file_id)
if not context:
return {"answer": "No relevant info found in the document."}
context_text = "\n".join(context)
prompt = f"Using the context, answer the question: {request.question}\n\nContext:\n{context_text}"
answer = generate_response(prompt)
return {"answer": answer}
except Exception as e:
logger.error(f"Query failed: {e}")
raise HTTPException(status_code=500, detail="Error answering question")
@app.post("/summarize")
async def summarize_file(file_id: str):
if file_id not in documents:
raise HTTPException(status_code=404, detail="File not found")
try:
with open(documents[file_id]["file_path"], "rb") as f:
content = f.read()
text = parse_pdf(content) if documents[file_id]["file_name"].endswith(".pdf") else parse_text(content)
if not text.strip():
raise ValueError("No text found to summarize.")
prompt = f"Summarize this text in 100 words or less:\n\n{text[:5000]}"
summary = generate_response(prompt)
return {"summary": summary}
except Exception as e:
logger.error(f"Summarization failed: {e}")
raise HTTPException(status_code=500, detail="Error generating summary")