# Source: my-rag-qa/app.py — Hugging Face Space by VictorTomas09
# Commit 74d4f11 (verified): "Add app.py for RAG QA demo" (11.4 kB)
#!/usr/bin/env python
# coding: utf-8
# # Retrieval-Augmented QA Demo
#
# This notebook builds a minimal RAG (Retrieval-Augmented Generation) pipeline with enhancements:
#
# - Slimmed & deduplicated corpora
# - Chunking long passages
# - Persistent FAISS index & embeddings
# - Distance threshold to avoid hallucinations
# - Context-length control
# - Polished Gradio interface with separate contexts panel
# ## 1. Configuration & Imports
#
# We detect device, print settings, and support loading saved index.
# In[2]:
import os
import pickle
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, CrossEncoder
import faiss
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from transformers import AutoTokenizer as _AutoTokenizer
import gradio as gr
import evaluate
# Settings
# All persisted artifacts (FAISS index, embeddings, passage list) live under ./data.
data_dir = os.path.join(os.getcwd(), "data")
os.makedirs(data_dir, exist_ok=True)
INDEX_PATH, EMB_PATH, PCTX_PATH = (
    os.path.join(data_dir, filename)
    for filename in ("faiss_index.faiss", "embeddings.npy", "passages.pkl")
)
# Model choices can be overridden through environment variables.
MODEL_NAME = os.getenv("MODEL_NAME", "google/flan-t5-small")
EMBEDDER_MODEL = os.getenv("EMBEDDER_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
# Pipeline device index: 0 (first GPU) when CUDA is available, otherwise -1 (CPU).
if torch.cuda.is_available():
    device = 0
else:
    device = -1
device_label = "GPU" if device == 0 else "CPU"
print(f"Using model: {MODEL_NAME}, embedder: {EMBEDDER_MODEL}, device: {device_label}")
# Maximum acceptable L2 distance for retrieved passages; above it the app
# refuses to answer instead of guessing. Tune as needed.
dist_threshold = 1.0
# Max words per context snippet (longer snippets are truncated for prompt/display).
max_context_words = 200
# ## Useful functions
def make_context_snippets(contexts, max_words=200):
    """
    Clip each context to at most `max_words` whitespace-separated words.

    A context that is already short enough is returned unchanged (original
    whitespace preserved). A longer one is rebuilt from its first `max_words`
    words joined by single spaces, followed by " ... [truncated]".
    Returns a new list in the same order as `contexts`.
    """
    def _clip(text):
        words = text.split()
        if len(words) <= max_words:
            return text
        return " ".join(words[:max_words]) + " ... [truncated]"

    return [_clip(text) for text in contexts]
# ## 2. Load, Deduplicate & Chunk Corpora
#
# For this demo we sample small slices and remove duplicates. Any passage whose tokenized length exceeds the tokenizer's `model_max_length` is split into overlapping word-window chunks.
#
# Tokenizer used only to measure passage length before chunking.
chunk_tokenizer = _AutoTokenizer.from_pretrained(MODEL_NAME)
max_tokens = chunk_tokenizer.model_max_length
# Tokenizers without a configured limit report a huge sentinel
# (transformers' VERY_LARGE_INTEGER, int(1e30)). Measuring against that would
# never split anything, so fall back to a T5-sized 512-token window.
if not max_tokens or max_tokens > 100_000:
    max_tokens = 512
def chunk_text(text: str, max_tokens: int, stride: int = None) -> list[str]:
    """
    Split `text` into overlapping chunks of up to `max_tokens` whitespace-separated words.

    Despite the parameter name, the window is measured in words, not tokenizer
    tokens, so a chunk may still tokenize to more than `max_tokens` tokens.

    Args:
        text: the passage to split.
        max_tokens: window size in words (must be >= 1).
        stride: number of words to advance between consecutive windows (>= 1).
            Defaults to 75% of the window, so consecutive chunks share 25% of
            their words (e.g. window 512 -> stride 384, overlap 128).

    Returns:
        The list of chunks, empty for empty/whitespace-only text. Windowing stops
        as soon as a chunk reaches the end of the text, so the final words are not
        repeated in extra, ever-shorter tail chunks.

    Raises:
        ValueError: if `max_tokens` or an explicit `stride` is < 1 (a stride of 0
            would never advance and loop forever).
    """
    if max_tokens < 1:
        raise ValueError(f"max_tokens must be >= 1, got {max_tokens}")
    if stride is None:
        # Advance by window minus the 25% overlap; max(1, ...) keeps tiny windows moving.
        stride = max(1, max_tokens - max_tokens // 4)
    elif stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")

    words = text.split()
    chunks = []
    for start in range(0, len(words), stride):
        end = start + max_tokens
        chunks.append(" ".join(words[start:end]))
        if end >= len(words):
            break  # this window already covers the tail of the text
    return chunks
# Load corpora
# NOTE: changing what is loaded here changes `passages`; any index cached under
# data/ from an earlier corpus no longer lines up with it.
wiki_ds = load_dataset("rag-datasets/rag-mini-wikipedia", "text-corpus", split="passages")
# list(...) so this concatenates like a plain list below (newer `datasets`
# releases may return a lazy column object from ds["column"]).
wiki_passages = list(wiki_ds["passage"])
squad_ds = load_dataset("rajpurkar/squad_v2", split="train[:100]")
squad_passages = [ex["context"] for ex in squad_ds]
trivia_ds = load_dataset("mandarjoshi/trivia_qa", "rc", split="validation[:100]")
trivia_passages = []
for ex in trivia_ds:
    # In trivia_qa "rc" the evidence texts are nested lists of strings:
    # entity_pages.wiki_context and search_results.search_context. A top-level
    # field of the same name is still honoured if a dataset variant provides one.
    for parent, field in (("entity_pages", "wiki_context"), ("search_results", "search_context")):
        texts = (ex.get(parent) or {}).get(field) or ex.get(field) or []
        if isinstance(texts, str):
            texts = [texts]
        trivia_passages.extend(t for t in texts if t)
# Combine all sources, drop exact duplicates (keeping first-seen order), then
# split any passage whose token count exceeds the model's input limit.
def _split_if_long(text):
    # tokenize() only counts tokens (no ids / length checks), so no overflow warnings
    if len(chunk_tokenizer.tokenize(text)) > max_tokens:
        return chunk_text(text, max_tokens)
    return [text]


all_passages = wiki_passages + squad_passages + trivia_passages
unique_passages = list(dict.fromkeys(all_passages))
passages = [piece for text in unique_passages for piece in _split_if_long(text)]
print(f"Total passages after dedupe & chunk: {len(passages)}")
# Persist the final passage list alongside the index artifacts.
with open(PCTX_PATH, "wb") as fh:
    pickle.dump(passages, fh)
# ## 3. Build or Load FAISS Index & Embeddings
#
# We save embeddings & index to disk to skip slow re-encoding.
# ── Initialize embedder and reranker ──
from sentence_transformers import SentenceTransformer
from torch import no_grad
embedder = SentenceTransformer(EMBEDDER_MODEL)
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")


def _cache_is_usable(cached_index, cached_embeddings):
    """Return True if a cached index/embedding pair matches the current passages and embedder.

    `passages` is rebuilt on every run, so a cache from an earlier corpus, a
    different embedder, or an older non-inner-product index must not be reused:
    its row ids would point at the wrong passages.
    """
    expected_dim = embedder.get_sentence_embedding_dimension()
    return (
        cached_index.ntotal == len(passages)
        and cached_embeddings.shape[0] == len(passages)
        and (expected_dim is None or cached_index.d == expected_dim)
        and cached_index.metric_type == faiss.METRIC_INNER_PRODUCT
    )


# ── Load or (re)build FAISS index with cosine similarity ──
index = None
if os.path.exists(INDEX_PATH) and os.path.exists(EMB_PATH):
    print("Loading saved index and embeddings…")
    _cached_index = faiss.read_index(INDEX_PATH)
    _cached_embeddings = np.load(EMB_PATH)
    if _cache_is_usable(_cached_index, _cached_embeddings):
        index, embeddings = _cached_index, _cached_embeddings
    else:
        print("Saved index does not match the current passages/embedder; rebuilding…")

if index is None:
    print("Encoding passages (with overlap)…")
    embeddings = embedder.encode(
        passages,
        show_progress_bar=True,
        convert_to_numpy=True,
        batch_size=32,
    )
    # Normalize to unit length so that inner product = cosine similarity.
    # The floor avoids 0/0 -> NaN for an all-zero embedding.
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    embeddings = (embeddings / np.maximum(norms, 1e-12)).astype(np.float32)
    # Build a FAISS index over inner-product (cosine) space
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(embeddings)
    # Persist to disk for faster reload
    faiss.write_index(index, INDEX_PATH)
    np.save(EMB_PATH, embeddings)
print(f"Indexed {index.ntotal} vectors.")
# ## 4. Load QA Model & Pipeline
# Seq2seq generator (tokenizer + model) wrapped in a HF "text2text-generation" pipeline.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
_qa_pipeline_kwargs = {
    "model": model,
    "tokenizer": tokenizer,
    "device": device,
    # Forwarded to generation. NOTE(review): early_stopping only matters for
    # beam search (num_beams > 1); confirm it is still wanted with greedy decoding.
    "early_stopping": True,
}
qa_pipeline = pipeline("text2text-generation", **_qa_pipeline_kwargs)
print("QA pipeline ready.")
# ## 5. Retrieval + Generation Functions
#
# We bail out early if top distance > threshold to avoid hallucination.
def retrieve(question: str, k: int = 20, rerank_k: int = 5):
    """Return the `rerank_k` best passages for `question` and their dense scores.

    Two-stage retrieval:
      1. dense search: embed the question and take the `k` highest-scoring
         passages from the FAISS inner-product index (k is capped at the index size);
      2. rerank: score each (question, passage) pair with the cross-encoder and
         keep the `rerank_k` best.

    Returns:
        (contexts, scores): contexts ordered by descending cross-encoder score and,
        for each, its dense-search score. The index built in this file holds
        unit-normalized passage vectors and the query is normalized here, so these
        are cosine similarities (higher = closer), not L2 distances.
        Both lists are empty when nothing can be retrieved or `k`/`rerank_k` <= 0.
    """
    k = min(k, index.ntotal)
    if k <= 0 or rerank_k <= 0:
        return [], []

    # 1) dense search. Normalize the query like the passages, so the inner product
    #    is a true cosine even for embedders that do not normalize their output.
    q_emb = embedder.encode([question], convert_to_numpy=True).astype(np.float32)
    q_emb /= np.maximum(np.linalg.norm(q_emb, axis=1, keepdims=True), 1e-12)
    scores_row, ids_row = index.search(q_emb, k)

    # 2) FAISS pads missing results with id -1. A raw -1 would silently select the
    #    last passage, so drop those slots.
    hits = [(int(i), float(s)) for i, s in zip(ids_row[0], scores_row[0]) if i >= 0]
    if not hits:
        return [], []
    candidates = [passages[i] for i, _ in hits]

    # 3) score with the cross-encoder
    ce_scores = np.asarray(reranker.predict([[question, ctx] for ctx in candidates]))

    # 4) keep the top rerank_k by descending cross-encoder score
    top_idxs = np.argsort(ce_scores)[::-1][:rerank_k]
    final_ctxs = [candidates[j] for j in top_idxs]
    final_dist = [hits[j][1] for j in top_idxs]
    return final_ctxs, final_dist
def generate(question: str, contexts: list) -> str:
    """
    Build a RAG prompt from the retrieved contexts and generate
    an answer using the HF text2text pipeline.

    The prompt has three parts, in order:
      - the instructions;
      - "Context i: ..." lines, each clipped to `max_context_words` words;
      - the question and an "Answer:" cue.

    The question comes last, and the tokenizer truncates from the right by
    default. An over-long prompt would therefore silently lose the question
    itself. To avoid that, context lines are first fitted into the model's input
    budget (see `_fit_contexts`); `truncation=True` is kept only as a backstop.

    Returns the generated text, stripped of surrounding whitespace.
    """
    header = (
        "You are a helpful assistant. Use ONLY the following contexts to answer. "
        "If the answer is not contained, say 'Sorry, I don't know.'\n\n"
    )
    footer = f"\n\nQuestion: {question}\nAnswer:"
    # 1) Turn each context into a truncated snippet
    snippet_lines = [
        f"Context {i+1}: {s}"
        for i, s in enumerate(make_context_snippets(contexts, max_context_words))
    ]
    # 2) Drop/trim trailing contexts so instructions + question always fit
    snippet_lines = _fit_contexts(snippet_lines, header + footer, qa_pipeline.tokenizer)
    # 3) Build the full prompt (identical to the untrimmed one when everything fits)
    prompt = header + "\n".join(snippet_lines) + footer
    # 4) Call the pipeline (it handles tokenization + generation + decoding)
    result = qa_pipeline(prompt, truncation=True, max_new_tokens=200)[0]["generated_text"]
    return result.strip()


def _fit_contexts(lines, fixed_text, tok, safety_margin=8):
    """Keep the leading context lines that fit in `tok`'s input limit alongside `fixed_text`.

    Lines are kept in order while they fit. The first line that does not fit is cut
    at the token level to the remaining budget, and everything after it is dropped.
    Per-line token counts are only approximately additive once the pieces are
    joined, hence `safety_margin`. Tokenizers reporting no real limit (a huge
    sentinel `model_max_length`) get the lines back unchanged.
    """
    limit = getattr(tok, "model_max_length", None)
    if not limit or limit > 100_000:
        return lines
    remaining = limit - len(tok(fixed_text)["input_ids"]) - safety_margin
    kept = []
    for line in lines:
        if remaining <= 0:
            break
        ids = tok(line, add_special_tokens=False)["input_ids"]
        if len(ids) + 1 <= remaining:  # +1 for the joining newline
            kept.append(line)
            remaining -= len(ids) + 1
        else:
            kept.append(tok.decode(ids[:remaining], skip_special_tokens=True))
            break
    return kept
def retrieve_and_answer(question, k=5):
    """Retrieve contexts for `question`, apply the relevance threshold, then answer.

    Args:
        question: the user question.
        k: number of reranked contexts handed to the generator. Dense search
           always considers at least 20 candidates.

    Returns:
        (answer, contexts). It is ("Sorry, I don't know.", []) when nothing is
        retrieved or when even the closest passage is farther than `dist_threshold`.
    """
    contexts, scores = retrieve(question, k=max(20, k), rerank_k=k)
    if not contexts:
        return "Sorry, I don't know.", []
    # `retrieve` returns inner-product scores against unit-normalized passage
    # vectors, i.e. cosine similarities where higher means closer. `dist_threshold`
    # is a maximum L2 distance, so convert: for unit vectors
    # ||q - p|| = sqrt(2 - 2*cos). The default 1.0 thus requires cos >= 0.5.
    # NOTE(review): assumes the query embedding is unit-norm as well.
    best_sim = float(max(scores))
    best_dist = float(np.sqrt(max(0.0, 2.0 - 2.0 * best_sim)))
    if best_dist > dist_threshold:
        return "Sorry, I don't know.", []
    ans = generate(question, contexts)
    return ans, contexts
import random
# Print a few random passages as a quick sanity check of the corpus.
# min(...) keeps random.sample from raising ValueError on a corpus with < 5 passages.
print("Some sample passages:\n")
for p in random.sample(passages, min(5, len(passages))):
    print(p, "\n" + "-"*80 + "\n")
# ## 6. Gradio Demo Interface
#
# Separate panels for answer and contexts.
def answer_and_contexts(question: str):
    """
    Gradio callback: answer `question` end-to-end and render the contexts used.

    Returns a pair (answer, contexts_text). When `retrieve_and_answer` yields no
    contexts (e.g. the threshold refusal), contexts_text is "". Otherwise each
    context is clipped to `max_context_words` words and numbered from 1. The
    entries are separated by a "---" rule.
    """
    answer, contexts = retrieve_and_answer(question)
    if contexts:
        clipped = make_context_snippets(contexts, max_context_words)
        rendered = "\n\n---\n\n".join(
            f"Context {n}: {text}" for n, text in enumerate(clipped, start=1)
        )
    else:
        rendered = ""
    return answer, rendered
# Gradio UI: one question box in, separate panels for the answer and the contexts.
iface = gr.Interface(
    fn=answer_and_contexts,
    inputs=gr.Textbox(lines=1, placeholder="Enter your question here...", label="Question"),
    outputs=[
        gr.Textbox(label="Answer"),
        gr.Textbox(label="Retrieved Contexts"),
    ],
    # The title emoji had been mis-encoded (UTF-8 bytes shown in a legacy code
    # page); restored to the magnifying glass U+1F50D.
    title="🔍 RAG QA Demo",
    description="Retrieval-Augmented QA with distance threshold and context preview",
)
# NOTE(review): when run as a script, launch() blocks by default. The module-level
# evaluation code further down therefore only runs after the server stops.
iface.launch()
# # Test the Model
# load SQuAD v2 (we only need validation split)
squad = load_dataset("rajpurkar/squad_v2", split="validation")
# Load the SQuAD v1.1 metric. It has no native notion of "no answer":
# qa_eval_all below substitutes an empty gold answer for unanswerable rows and
# maps refusals to an empty prediction itself. (evaluate's "squad_v2" metric is
# the no-answer-aware variant, but it expects a `no_answer_probability` field in
# every prediction.)
squad_metric = evaluate.load("squad")
def retrieval_recall(dataset, k=20, num_samples=100):
    """Fraction of sampled questions whose gold answer text appears in a retrieved context.

    Uses the first `num_samples` rows of `dataset`, clamped to its length. For each
    question it retrieves `k` contexts (dense top-k, reranked, all k kept). A hit
    is counted when any gold answer string occurs verbatim (case-sensitive
    substring) in any of them. Unanswerable questions (empty gold list) can never
    hit, so they pull the score down.

    Returns the recall as a float, or NaN when no rows were evaluated.
    """
    n = min(num_samples, len(dataset))
    hits = 0
    for ex in dataset.select(range(n)):
        gold_answers = ex["answers"]["text"]  # list, empty if unanswerable
        ctxs, _ = retrieve(ex["question"], k=k, rerank_k=k)
        if any(ans in ctx for ans in gold_answers for ctx in ctxs):
            hits += 1
    recall = hits / n if n else float("nan")
    print(f"Retrieval Recall@{k}: {recall:.3f}")
    return recall
# ## Only answerable Questions
def retrieval_recall_answerable(dataset, k=20, num_samples=100):
    """Retrieval recall restricted to answerable questions.

    Works like `retrieval_recall`, except that rows without a gold answer are
    skipped rather than counted as misses. The denominator is the number of
    answerable rows among the first `num_samples` (clamped to the dataset length).

    Returns the recall as a float, or NaN when the sample contains no answerable
    question.
    """
    hits = 0
    total = 0
    for ex in dataset.select(range(min(num_samples, len(dataset)))):
        gold_answers = ex["answers"]["text"]
        if not gold_answers:
            continue  # skip unanswerable
        total += 1
        ctxs, _ = retrieve(ex["question"], k=k, rerank_k=k)
        if any(ans in ctx for ans in gold_answers for ctx in ctxs):
            hits += 1
    recall = hits / total if total else float("nan")
    print(f"Retrieval Recall@{k} on answerable only: {recall:.3f} ({hits}/{total})")
    return recall
def qa_eval_all(dataset, num_samples=100, k=20):
    """SQuAD EM/F1 of the full RAG pipeline on all sampled questions.

    Uses the first `num_samples` rows (clamped to the dataset length):
      - unanswerable rows get an empty gold answer;
      - a refusal ("Sorry, ...") becomes an empty prediction, so correctly
        abstaining matches that empty gold answer.
    NOTE(review): with the v1 `squad` metric, F1 for an empty prediction against
    an empty gold answer is likely 0, so abstentions may only be rewarded by EM.
    Confirm before reading the F1 figure.

    Returns the metric dict (`exact_match`, `f1`); both are NaN when nothing was
    evaluated.
    """
    preds, refs = [], []
    for ex in dataset.select(range(min(num_samples, len(dataset)))):
        qid = ex["id"]
        gold = ex["answers"]
        # ensure metric has something to iterate over
        if not gold["text"]:
            gold = {"text": [""], "answer_start": [0]}
        ans, _ = retrieve_and_answer(ex["question"], k=k)
        # for metric purposes, treat our refusal as empty string
        pred_text = "" if ans.strip().lower().startswith("sorry") else ans
        preds.append({"id": qid, "prediction_text": pred_text})
        refs.append({"id": qid, "answers": gold})
    if not preds:
        # the metric divides by the number of predictions
        print("Full QA: no samples evaluated.")
        return {"exact_match": float("nan"), "f1": float("nan")}
    results = squad_metric.compute(predictions=preds, references=refs)
    print(f"Full QA EM: {results['exact_match']:.2f}, F1: {results['f1']:.2f}")
    return results
def qa_eval_answerable(dataset, num_samples=100, k=20):
    """SQuAD EM/F1 of the full RAG pipeline on answerable questions only.

    Unanswerable rows among the first `num_samples` (clamped to the dataset
    length) are skipped. As in `qa_eval_all`, a refusal ("Sorry, ...") is scored
    as an empty prediction. This way the refusal text can never earn partial F1
    by sharing a token (e.g. "know") with a gold answer.

    Returns the metric dict (`exact_match`, `f1`); both are NaN when nothing was
    evaluated.
    """
    preds, refs = [], []
    for ex in dataset.select(range(min(num_samples, len(dataset)))):
        if not ex["answers"]["text"]:
            continue  # skip unanswerable
        qid = ex["id"]
        ans, _ = retrieve_and_answer(ex["question"], k=k)
        pred_text = "" if ans.strip().lower().startswith("sorry") else ans
        preds.append({"id": qid, "prediction_text": pred_text})
        refs.append({"id": qid, "answers": ex["answers"]})
    if not preds:
        # the metric divides by the number of predictions
        print("Answerable-only QA: no answerable samples evaluated.")
        return {"exact_match": float("nan"), "f1": float("nan")}
    results = squad_metric.compute(predictions=preds, references=refs)
    print(f"Answerable-only QA EM: {results['exact_match']:.2f}, F1: {results['f1']:.2f}")
    return results
# Evaluate retrieval and end-to-end QA on the first 100 SQuAD v2 validation
# questions with k=2. This runs at module level, after iface.launch() above has
# returned.
for _evaluate_retrieval in (retrieval_recall, retrieval_recall_answerable):
    _evaluate_retrieval(squad, k=2, num_samples=100)
for _evaluate_qa in (qa_eval_all, qa_eval_answerable):
    _evaluate_qa(squad, num_samples=100, k=2)