"""Group 5 Study Helper: a minimal RAG app built with Gradio, FAISS,
sentence-transformers, and a small Flan-T5 generator."""

import io
import os

import faiss
import gradio as gr
import numpy as np
from pypdf import PdfReader
from sentence_transformers import SentenceTransformer
from transformers import pipeline

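# Lightweight, CPU-friendly defaults: MiniLM sentence embeddings for retrieval
# and Flan-T5-small for answer generation.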
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
GEN_MODEL_NAME = "google/flan-t5-small"

embedder = SentenceTransformer(EMBED_MODEL_NAME)
generator = pipeline("text2text-generation", model=GEN_MODEL_NAME)

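# PDF ingestion: pull the raw text out of every uploaded file so it can be
# chunked and embedded below.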
def pdfs_to_texts(files):
    """Extract the plain text of each uploaded PDF as one string per file."""
    texts = []
    for f in files:
        # Depending on the Gradio version, gr.Files may hand back file paths
        # or file-like objects; handle both.
        if isinstance(f, (str, os.PathLike)):
            reader = PdfReader(f)
        else:
            reader = PdfReader(io.BytesIO(f.read()))
        pages = [page.extract_text() or "" for page in reader.pages]
        texts.append("\n".join(pages))
    return texts

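# Chunking: split long documents into overlapping word windows so each piece
# fits comfortably in the embedding model's input.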
def chunk_text(text, chunk_size=600, overlap=120):
    """Split text into overlapping chunks of roughly `chunk_size` words."""
    words = text.split()
    chunks = []
    step = max(chunk_size - overlap, 1)  # guard against a non-positive step
    i = 0
    while i < len(words):
        chunks.append(" ".join(words[i:i + chunk_size]))
        i += step
    return chunks

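# In-memory retrieval state: a FAISS index over normalized chunk embeddings
# plus the chunk texts themselves, rebuilt whenever new PDFs are indexed.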
index = None
corpus_chunks = []

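# Index building: extract, chunk, embed, and add everything to a fresh
# inner-product FAISS index.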
def build_index(files, progress=gr.Progress()):
    global index, corpus_chunks
    texts = pdfs_to_texts(files)

    corpus_chunks = []
    for t in texts:
        if not t.strip():
            continue
        corpus_chunks += chunk_text(t)

    if not corpus_chunks:
        return "No text extracted from PDFs.", None

    progress(0.3, desc="Embedding chunks…")
    embeddings = embedder.encode(corpus_chunks, convert_to_numpy=True, show_progress_bar=False)
    d = embeddings.shape[1]

    progress(0.6, desc="Creating FAISS index…")
    index = faiss.IndexFlatIP(d)

    # Normalize so the inner-product search behaves like cosine similarity.
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-10
    embeddings = embeddings / norms
    index.add(embeddings.astype(np.float32))

    return f"Indexed {len(corpus_chunks)} chunks.", len(corpus_chunks)

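# Question answering: embed the query, retrieve the top-k most similar chunks,
# and let the generator answer from that context only.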
def answer_question(question, top_k=5, max_new_tokens=256):
    if index is None or not corpus_chunks:
        return "Index not built yet. Upload PDFs and click **Build Index** first."

    # Embed and normalize the query the same way as the corpus embeddings.
    q = embedder.encode([question], convert_to_numpy=True)
    q = q / (np.linalg.norm(q, axis=1, keepdims=True) + 1e-10)

    D, I = index.search(q.astype(np.float32), int(top_k))
    # FAISS pads missing results with -1, so keep only valid indices.
    retrieved = [corpus_chunks[i] for i in I[0] if 0 <= i < len(corpus_chunks)]

    context = "\n\n".join(retrieved)
    prompt = (
        "You are a helpful study assistant. Using ONLY the context, answer the question.\n"
        "If the answer isn't in the context, say you don't have enough information.\n\n"
        f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"
    )
    # do_sample=True makes the temperature setting take effect; without it the
    # pipeline decodes greedily and warns that temperature is ignored.
    out = generator(prompt, max_new_tokens=int(max_new_tokens), do_sample=True, temperature=0.2)
    return out[0]["generated_text"].strip()

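# Gradio UI: upload PDFs, build the index, then ask questions against it.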
with gr.Blocks(title="Group 5 Study Helper (RAG)") as demo:
    gr.Markdown("# Group 5 Study Helper (RAG)\nUpload PDFs → Build Index → Ask questions.")

    with gr.Row():
        file_in = gr.Files(file_types=[".pdf"], label="Upload PDF files")
    with gr.Row():
        build_btn = gr.Button("Build Index", variant="primary")
        status = gr.Markdown()
        chunk_count = gr.Number(label="Chunk count", interactive=False)

    with gr.Row():
        question = gr.Textbox(label="Your question")
    with gr.Row():
        topk = gr.Slider(1, 10, value=5, step=1, label="Top-K passages")
        max_tokens = gr.Slider(64, 512, value=256, step=16, label="Max new tokens")
    with gr.Row():
        ask_btn = gr.Button("Ask", variant="primary")
    with gr.Row():
        answer = gr.Markdown(label="Answer")

    def _build(files):
        msg, n = build_index(files)
        return msg, n or 0

    build_btn.click(_build, inputs=[file_in], outputs=[status, chunk_count])
    ask_btn.click(answer_question, inputs=[question, topk, max_tokens], outputs=[answer])


demo.launch()