from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import os
import logging
from parser import parse_pdf, parse_text
from rag import RAG
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# App setup
app = FastAPI(title="NotebookLM-like Tool")

# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve static/index.html
app.mount("/static", StaticFiles(directory="static"), name="static")

@app.get("/")
def serve_index():
    return FileResponse("static/index.html")

# Load a small Qwen chat model so the app can run on CPU-only hosts (e.g. a Hugging Face Space)
logger.info("Loading Qwen 1.8B Chat model...")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-1_8B-Chat", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-1_8B-Chat", trust_remote_code=True)
logger.info("Model loaded.")

def generate_response(prompt: str) -> str:
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=300, do_sample=True)
    # Decode only the newly generated tokens so the prompt is not echoed back in the answer.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
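
# Note: Qwen chat checkpoints loaded with trust_remote_code also expose a convenience
# chat API. The line below is a hedged sketch of that alternative (not used by this app),
# which handles the chat prompt formatting internally:
#   response, _history = model.chat(tokenizer, prompt, history=None)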

# In-memory storage and RAG engine
rag = RAG()
documents = {}
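
# Assumed interface of the local rag.RAG helper, inferred from how it is called below:
# `embed_document(file_id, text)` and `query_document(question, file_id)` are awaitable,
# and `query_document` returns a list of relevant text chunks for the question.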

class QueryRequest(BaseModel):
    question: str
    file_id: str
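    # Example request body (hypothetical values):
    #   {"question": "What is this document about?", "file_id": "1"}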

@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    try:
        logger.info(f"Received file: {file.filename}")
        if file.content_type not in ["application/pdf", "text/plain"]:
            raise HTTPException(status_code=400, detail="Only PDF or TXT allowed")

        file_id = str(len(documents) + 1)
        safe_name = file.filename.replace("/", "_").replace("\\", "_")
        file_path = f"uploads/{file_id}_{safe_name}"

        os.makedirs("uploads", exist_ok=True)
        file_bytes = await file.read()

        with open(file_path, "wb") as f:
            f.write(file_bytes)

        text = parse_pdf(file_bytes) if file.filename.endswith(".pdf") else parse_text(file_bytes)
        if not text.strip():
            text = "No extractable text found in this file."

        documents[file_id] = {"file_path": file_path, "file_name": file.filename}
        await rag.embed_document(file_id, text)

        return {"file_id": file_id, "file_name": file.filename}
    except HTTPException:
        # Let deliberate client errors (e.g. the 400 for unsupported file types) pass
        # through instead of being masked as a 500 below.
        raise
    except Exception as e:
        logger.error(f"Upload failed: {e}")
        raise HTTPException(status_code=500, detail="Failed to process file")
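
# Example upload call (a sketch; host and port are assumptions, with uvicorn on localhost:8000):
#   curl -F "file=@notes.pdf" http://localhost:8000/upload
#   -> {"file_id": "1", "file_name": "notes.pdf"}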

@app.post("/query")
async def query_file(request: QueryRequest):
    if request.file_id not in documents:
        raise HTTPException(status_code=404, detail="File not found")
    try:
        context = await rag.query_document(request.question, request.file_id)
        if not context:
            return {"answer": "No relevant info found in the document."}

        context_text = "\n".join(context)
        prompt = f"Using the context, answer the question: {request.question}\n\nContext:\n{context_text}"
        answer = generate_response(prompt)
        return {"answer": answer}
    except Exception as e:
        logger.error(f"Query failed: {e}")
        raise HTTPException(status_code=500, detail="Error answering question")

@app.post("/summarize")
async def summarize_file(file_id: str):
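    # FastAPI treats the bare `file_id: str` parameter on this POST route as a query
    # parameter, so clients call it as e.g. POST /summarize?file_id=1.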
    if file_id not in documents:
        raise HTTPException(status_code=404, detail="File not found")
    try:
        with open(documents[file_id]["file_path"], "rb") as f:
            content = f.read()
            text = parse_pdf(content) if documents[file_id]["file_name"].endswith(".pdf") else parse_text(content)

        if not text.strip():
            raise ValueError("No text found to summarize.")

        prompt = f"Summarize this text in 100 words or less:\n\n{text[:5000]}"
        summary = generate_response(prompt)
        return {"summary": summary}
    except Exception as e:
        logger.error(f"Summarization failed: {e}")
        raise HTTPException(status_code=500, detail="Error generating summary")
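
# Minimal local entry point (a sketch; host and port are assumptions, and a hosted
# deployment such as a Hugging Face Space would normally launch the ASGI app itself):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)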