#!/usr/bin/env python
# coding: utf-8
"""Retrieval-Augmented Generation (RAG) QA demo.

Builds a FAISS index over Wikipedia/SQuAD/TriviaQA passages, retrieves and
cross-encoder-reranks contexts for a question, and answers with a seq2seq
model. Run with --eval for retrieval/QA evaluation; otherwise a Gradio demo
is launched.
"""

import os
import pickle
import argparse

import faiss
import numpy as np
import torch
import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, CrossEncoder
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    pipeline as hf_pipeline,
)
import evaluate

# ── 1. Configuration ──
DATA_DIR = os.path.join(os.getcwd(), "data")
INDEX_PATH = os.path.join(DATA_DIR, "faiss_index.faiss")
EMB_PATH = os.path.join(DATA_DIR, "embeddings.npy")
PCTX_PATH = os.path.join(DATA_DIR, "passages.pkl")

MODEL_NAME = os.getenv("MODEL_NAME", "google/flan-t5-small")
EMBEDDER_MODEL = os.getenv("EMBEDDER_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
# Gate for "I don't know" answers: compared against the retriever score of
# the top reranked context in retrieve_and_answer.
DIST_THRESHOLD = float(os.getenv("DIST_THRESHOLD", 1.0))
MAX_CTX_WORDS = int(os.getenv("MAX_CTX_WORDS", 200))

# transformers pipeline device convention: 0 = first CUDA device, -1 = CPU.
DEVICE = 0 if torch.cuda.is_available() else -1

os.makedirs(DATA_DIR, exist_ok=True)


# ── 2. Helpers ──
def make_context_snippets(contexts, max_words=MAX_CTX_WORDS):
    """Truncate each context to at most `max_words` whitespace words.

    Truncated contexts get a visible " ... [truncated]" marker so the demo
    output makes the cut obvious.
    """
    snippets = []
    for c in contexts:
        words = c.split()
        if len(words) > max_words:
            c = " ".join(words[:max_words]) + " ... [truncated]"
        snippets.append(c)
    return snippets


def chunk_text(text, max_tokens, stride=None):
    """Split `text` into overlapping windows of whitespace words.

    Windows are `max_tokens` words long and start every `stride` words
    (default: max_tokens // 4, i.e. 75% overlap between windows).
    """
    words = text.split()
    if stride is None:
        stride = max_tokens // 4
    # FIX: for max_tokens < 4 (or an explicit stride of 0) the stride was 0,
    # so `start` never advanced and the loop ran forever. Clamp to >= 1.
    stride = max(1, stride)
    chunks, start = [], 0
    while start < len(words):
        end = start + max_tokens
        chunks.append(" ".join(words[start:end]))
        start += stride
    return chunks


# ── 3. Load & preprocess passages ──
def load_passages():
    """Download source corpora, dedupe and chunk them, then cache to disk.

    Returns a list of passage chunks, each short enough in *tokenizer tokens*
    to fit the generator's context window. The list is pickled to PCTX_PATH
    so later runs can skip this step.
    """
    wiki_ds = load_dataset("rag-datasets/rag-mini-wikipedia", "text-corpus", split="passages")
    squad_ds = load_dataset("rajpurkar/squad_v2", split="train[:100]")
    trivia_ds = load_dataset("mandarjoshi/trivia_qa", "rc", split="validation[:100]")

    wiki_passages = wiki_ds["passage"]
    squad_passages = [ex["context"] for ex in squad_ds]
    trivia_passages = []
    for ex in trivia_ds:
        # NOTE(review): trivia_qa "rc" examples may not expose these flat
        # fields (contexts usually live under entity_pages/search_results) —
        # confirm; missing fields are silently skipped here.
        for fld in ("wiki_context", "search_context"):
            txt = ex.get(fld) or ""
            if txt:
                trivia_passages.append(txt)

    # dict.fromkeys dedupes while preserving first-seen order.
    all_passages = list(dict.fromkeys(wiki_passages + squad_passages + trivia_passages))

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    max_tokens = tokenizer.model_max_length
    # FIX: some tokenizers report a huge sentinel (~1e30) instead of a real
    # limit, which would disable chunking entirely; fall back to a sane window.
    if max_tokens > 100_000:
        max_tokens = 512

    chunks = []
    for p in all_passages:
        toks = tokenizer.encode(p, add_special_tokens=False)
        if len(toks) > max_tokens:
            # FIX: the previous word-based chunking (chunk_text) could still
            # exceed the token budget, because one word maps to >= 1 subword
            # tokens. Chunk in token space instead.
            chunks.extend(_chunk_by_tokens(toks, tokenizer, max_tokens))
        else:
            chunks.append(p)

    print(f"[load_passages] total chunks: {len(chunks)}")
    with open(PCTX_PATH, "wb") as f:
        pickle.dump(chunks, f)
    return chunks


def _chunk_by_tokens(token_ids, tokenizer, max_tokens):
    """Split a token-id sequence into overlapping windows and decode each.

    Windows are `max_tokens` ids long with a stride of max_tokens // 4
    (clamped to >= 1), guaranteeing every decoded chunk fits the budget.
    """
    stride = max(1, max_tokens // 4)
    pieces = []
    start = 0
    while start < len(token_ids):
        window = token_ids[start:start + max_tokens]
        pieces.append(tokenizer.decode(window, skip_special_tokens=True))
        if start + max_tokens >= len(token_ids):
            break  # last window already reaches the end; avoid tiny tails
        start += stride
    return pieces


# ── 4. Build or load FAISS ──
def load_faiss_index(passages):
    """Return (embedder, reranker, FAISS index), building the index if needed.

    Passage embeddings are L2-normalized before being added to an
    IndexFlatIP, so inner-product search scores are cosine similarities.
    Index and embeddings are cached at INDEX_PATH / EMB_PATH.
    """
    embedder = SentenceTransformer(EMBEDDER_MODEL)
    reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

    if os.path.exists(INDEX_PATH) and os.path.exists(EMB_PATH):
        print("Loading FAISS index & embeddings…")
        index = faiss.read_index(INDEX_PATH)
        embeddings = np.load(EMB_PATH)
    else:
        print("Encoding passages & building FAISS index…")
        embeddings = embedder.encode(
            passages, show_progress_bar=True, convert_to_numpy=True, batch_size=32
        )
        embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
        dim = embeddings.shape[1]
        index = faiss.IndexFlatIP(dim)
        index.add(embeddings)
        faiss.write_index(index, INDEX_PATH)
        np.save(EMB_PATH, embeddings)

    return embedder, reranker, index


# ── 5. Initialize RAG components ──
def setup_rag():
    """Load (or build) passages, retriever components, and the QA pipeline."""
    if os.path.exists(PCTX_PATH):
        with open(PCTX_PATH, "rb") as f:
            passages = pickle.load(f)
    else:
        passages = load_passages()

    embedder, reranker, index = load_faiss_index(passages)

    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
    qa_pipe = hf_pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tok,
        device=DEVICE,
        truncation=True,
        max_length=512,
        num_beams=4,
        early_stopping=True,
    )
    return passages, embedder, reranker, index, qa_pipe


# ── 6. Retrieval & generation ──
def retrieve(question, passages, embedder, reranker, index, k=20, rerank_k=5):
    """Dense-retrieve k candidates, then cross-encoder rerank down to rerank_k.

    Returns (contexts, scores): the reranked contexts and their original
    FAISS inner-product scores, both in reranked (best-first) order.
    """
    q_emb = embedder.encode([question], convert_to_numpy=True)
    # FIX: the index stores L2-normalized passage embeddings, so the query
    # must be normalized too for inner-product scores to be true cosine
    # similarities (and thus comparable to DIST_THRESHOLD).
    q_emb = q_emb / np.linalg.norm(q_emb, axis=1, keepdims=True)
    faiss_scores, idxs = index.search(q_emb, k)
    cands = [passages[i] for i in idxs[0]]
    rerank_scores = reranker.predict([[question, c] for c in cands])
    top = np.argsort(rerank_scores)[-rerank_k:][::-1]
    return [cands[i] for i in top], [faiss_scores[0][i] for i in top]


def generate(question, contexts, qa_pipe):
    """Answer `question` using ONLY the given contexts, via a guarded prompt."""
    lines = [
        f"Context {i+1}: {s}"
        for i, s in enumerate(make_context_snippets(contexts))
    ]
    prompt = (
        "You are a helpful assistant. Use ONLY the following contexts to answer. "
        "If the answer is not contained, say 'Sorry, I don't know.'\n\n"
        + "\n".join(lines)
        + f"\n\nQuestion: {question}\nAnswer:"
    )
    return qa_pipe(prompt)[0]["generated_text"].strip()


def retrieve_and_answer(question, passages, embedder, reranker, index, qa_pipe):
    """Retrieve contexts and generate an answer, or refuse if retrieval is weak."""
    contexts, scores = retrieve(question, passages, embedder, reranker, index)
    # NOTE(review): with cosine-similarity scores, *higher* means more
    # relevant, so `score > threshold` rejects the best matches; this gate
    # looks written for L2 distances — confirm the intended direction and
    # default threshold before changing it.
    if not contexts or scores[0] > DIST_THRESHOLD:
        return "Sorry, I don't know.", []
    return generate(question, contexts, qa_pipe), contexts


def answer_and_contexts(question, passages, embedder, reranker, index, qa_pipe):
    """Gradio-facing wrapper: answer plus a displayable contexts string."""
    ans, ctxs = retrieve_and_answer(question, passages, embedder, reranker, index, qa_pipe)
    if not ctxs:
        return ans, ""
    snippets = [
        f"Context {i+1}: {s}"
        for i, s in enumerate(make_context_snippets(ctxs))
    ]
    return ans, "\n\n---\n\n".join(snippets)


# ── 7. Evaluation routines ──
def _contexts_for_question(question, passages, embedder, reranker, index, k, rerank_k):
    """Shared retrieval step for the recall evaluations.

    With rerank_k set, use the full retrieve+rerank path; otherwise take
    the raw top-k dense-retrieval hits.
    """
    if rerank_k:
        ctxs, _ = retrieve(question, passages, embedder, reranker, index, k=k, rerank_k=rerank_k)
        return ctxs
    q_emb = embedder.encode([question], convert_to_numpy=True)
    _, idxs = index.search(q_emb, k)
    return [passages[i] for i in idxs[0]]


def _answer_in_contexts(gold_answers, ctxs):
    """True if any gold answer string appears verbatim in any context."""
    return any(any(ans in ctx for ctx in ctxs) for ans in gold_answers)


def retrieval_recall(dataset, passages, embedder, reranker, index, k=20, rerank_k=None, num_samples=100):
    """Recall@k over the first num_samples questions.

    A question counts as a hit when a gold answer string appears verbatim in
    a retrieved context; unanswerable questions (empty gold) count as misses.
    """
    hits = 0
    for ex in dataset.select(range(num_samples)):
        ctxs = _contexts_for_question(ex["question"], passages, embedder, reranker, index, k, rerank_k)
        if _answer_in_contexts(ex["answers"]["text"], ctxs):
            hits += 1
    recall = hits / num_samples
    print(f"Retrieval Recall@{k} (rerank_k={rerank_k}): {recall:.3f} ({hits}/{num_samples})")
    return recall


def retrieval_recall_answerable(dataset, passages, embedder, reranker, index, k=20, rerank_k=None, num_samples=100):
    """Same as retrieval_recall, but only over questions with a gold answer."""
    hits, total = 0, 0
    for ex in dataset.select(range(num_samples)):
        gold = ex["answers"]["text"]
        if not gold:
            continue  # skip unanswerable SQuAD v2 questions
        total += 1
        ctxs = _contexts_for_question(ex["question"], passages, embedder, reranker, index, k, rerank_k)
        if _answer_in_contexts(gold, ctxs):
            hits += 1
    recall = hits / total if total > 0 else 0.0
    print(f"Retrieval Recall@{k} on answerable only (rerank_k={rerank_k}): {recall:.3f} ({hits}/{total})")
    return recall


def qa_eval_answerable(dataset, passages, embedder, reranker, index, qa_pipe, k=20, num_samples=100):
    """End-to-end EM/F1 (SQuAD metric) on answerable questions only.

    NOTE(review): `k` is accepted but never forwarded to retrieve_and_answer,
    which uses its own default k — confirm whether it should be threaded
    through.
    """
    squad_metric = evaluate.load("squad")
    preds, refs = [], []
    for ex in dataset.select(range(num_samples)):
        gold = ex["answers"]["text"]
        if not gold:
            continue
        qid = ex["id"]
        answer, _ = retrieve_and_answer(ex["question"], passages, embedder, reranker, index, qa_pipe)
        preds.append({"id": qid, "prediction_text": answer})
        refs.append({"id": qid, "answers": ex["answers"]})
    results = squad_metric.compute(predictions=preds, references=refs)
    print(f"Answerable-only QA EM: {results['exact_match']:.2f}, F1: {results['f1']:.2f}")
    return results


# ── 8. Main entry ──
def main():
    """CLI entry point: run evals with --eval, else launch the Gradio demo."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--eval",
        action="store_true",
        help="Run retrieval/QA evaluations on SQuAD instead of launching the demo",
    )
    # FIX: parse CLI first so --help and bad flags fail fast, before the
    # expensive model/index loading in setup_rag().
    args = parser.parse_args()

    passages, embedder, reranker, index, qa_pipe = setup_rag()

    if args.eval:
        squad = load_dataset("rajpurkar/squad_v2", split="validation")
        retrieval_recall(squad, passages, embedder, reranker, index, k=20, rerank_k=5, num_samples=100)
        retrieval_recall_answerable(squad, passages, embedder, reranker, index, k=20, rerank_k=5, num_samples=100)
        qa_eval_answerable(squad, passages, embedder, reranker, index, qa_pipe, k=20, num_samples=100)
    else:
        demo = gr.Interface(
            fn=lambda q: answer_and_contexts(q, passages, embedder, reranker, index, qa_pipe),
            inputs=gr.Textbox(lines=1, placeholder="Ask me anything…", label="Question"),
            outputs=[gr.Textbox(label="Answer"), gr.Textbox(label="Contexts")],
            title="🔍 RAG QA Demo",
            description="Retrieval-Augmented QA with threshold and context preview",
            examples=[
                "When was Abraham Lincoln inaugurated?",
                "What is the capital of France?",
                "Who wrote '1984'?",
            ],
            # NOTE(review): allow_flagging is deprecated in Gradio 4.x
            # (renamed flagging_mode) — keep while the pinned version accepts it.
            allow_flagging="never",
        )
        demo.launch()


if __name__ == "__main__":
    main()