my-rag-qa / app.py
VictorTomas09's picture
Update app.py
728106c verified
#!/usr/bin/env python
# coding: utf-8
import os
import pickle
import argparse
import faiss
import numpy as np
import torch
import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, CrossEncoder
from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
pipeline as hf_pipeline,
)
import evaluate
# ── 1. Configuration ──
# Filesystem cache locations, resolved relative to the current working directory.
DATA_DIR = os.path.join(os.getcwd(), "data")
INDEX_PATH, EMB_PATH, PCTX_PATH = (
    os.path.join(DATA_DIR, fname)
    for fname in ("faiss_index.faiss", "embeddings.npy", "passages.pkl")
)

# Model identifiers, overridable through environment variables.
MODEL_NAME = os.environ.get("MODEL_NAME", "google/flan-t5-small")
EMBEDDER_MODEL = os.environ.get("EMBEDDER_MODEL", "sentence-transformers/all-MiniLM-L6-v2")

# Threshold applied to the top retrieval score in retrieve_and_answer.
DIST_THRESHOLD = float(os.environ.get("DIST_THRESHOLD", "1.0"))
# Default per-context word cap for prompt building and context previews.
MAX_CTX_WORDS = int(os.environ.get("MAX_CTX_WORDS", "200"))

# transformers pipeline device index: GPU 0 when CUDA is available, CPU (-1) otherwise.
DEVICE = -1
if torch.cuda.is_available():
    DEVICE = 0

# Side effect at import time: make sure the cache directory exists.
os.makedirs(DATA_DIR, exist_ok=True)
# ── 2. Helpers ──
def make_context_snippets(contexts, max_words=MAX_CTX_WORDS):
    """Return *contexts* with every over-long entry cut down to *max_words* words.

    A context with more than ``max_words`` whitespace-separated words is
    replaced by its first ``max_words`` words joined by single spaces plus the
    marker " ... [truncated]". Shorter contexts are passed through unchanged.
    """
    def _clip(text):
        tokens = text.split()
        if len(tokens) <= max_words:
            return text
        return " ".join(tokens[:max_words]) + " ... [truncated]"

    return [_clip(text) for text in contexts]
def chunk_text(text, max_tokens, stride=None):
    """Split *text* into overlapping windows of whitespace-separated words.

    Args:
        text: The string to split.
        max_tokens: Maximum number of words per chunk. These are counted as
            whitespace-separated words, not tokenizer tokens.
        stride: Words to advance between consecutive windows. Defaults to
            ``max_tokens // 4`` (at least 1), which gives heavily overlapping
            windows.

    Returns:
        A list of chunk strings, empty if *text* has no words. Windowing stops
        once a chunk reaches the end of the text, so no chunk is merely a
        suffix of the one before it.

    Raises:
        ValueError: If ``max_tokens`` or ``stride`` is smaller than 1. Either
            would otherwise make the loop spin forever.
    """
    if max_tokens < 1:
        raise ValueError(f"max_tokens must be >= 1, got {max_tokens}")
    if stride is None:
        # max_tokens // 4 is 0 for max_tokens < 4, which would never advance.
        stride = max(1, max_tokens // 4)
    elif stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")

    words = text.split()
    chunks = []
    start = 0
    while start < len(words):
        end = start + max_tokens
        chunks.append(" ".join(words[start:end]))
        if end >= len(words):
            # This window already covers the last word; further windows would
            # be strict subsets of it and only add duplicate text to the index.
            break
        start += stride
    return chunks
# ── 3. Load & preprocess passages ──
def _trivia_contexts(example):
    """Yield every non-empty evidence text attached to one TriviaQA example.

    NOTE(review): in the ``rc`` config the evidence is expected under nested
    feature groups (``entity_pages.wiki_context`` and
    ``search_results.search_context``, each a list of strings). Confirm this
    against the dataset card. A flat top-level field of the same name is still
    accepted, which preserves the previous lookup.
    """
    for group, field in (("entity_pages", "wiki_context"), ("search_results", "search_context")):
        nested = example.get(group) or {}
        values = nested.get(field) or example.get(field) or []
        if isinstance(values, str):
            values = [values]
        for txt in values:
            if txt:
                yield txt


def load_passages():
    """Build the passage corpus, chunk long passages, and cache it to PCTX_PATH.

    Sources:
      * all passages of rag-mini-wikipedia,
      * contexts of the first 100 SQuAD v2 training examples,
      * evidence texts of the first 100 TriviaQA (rc) validation examples.

    Exact duplicates are dropped, keeping first-occurrence order. Any passage
    whose tokenized length exceeds the generator tokenizer's
    ``model_max_length`` is split with ``chunk_text``.

    Returns:
        list[str]: The chunked passages. They are also pickled to PCTX_PATH.
    """
    wiki_ds = load_dataset("rag-datasets/rag-mini-wikipedia", "text-corpus", split="passages")
    squad_ds = load_dataset("rajpurkar/squad_v2", split="train[:100]")
    trivia_ds = load_dataset("mandarjoshi/trivia_qa", "rc", split="validation[:100]")

    # Materialise the column: some `datasets` versions may return a lazy column
    # object rather than a list, which cannot be concatenated with `+`.
    wiki_passages = list(wiki_ds["passage"])
    squad_passages = [ex["context"] for ex in squad_ds]
    trivia_passages = [txt for ex in trivia_ds for txt in _trivia_contexts(ex)]

    all_passages = list(dict.fromkeys(wiki_passages + squad_passages + trivia_passages))

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    max_tokens = tokenizer.model_max_length
    chunks = []
    for p in all_passages:
        if len(tokenizer.tokenize(p)) > max_tokens:
            # NOTE(review): chunk_text windows by whitespace words, not tokenizer
            # tokens, so a chunk of `max_tokens` words can still exceed the limit.
            chunks.extend(chunk_text(p, max_tokens))
        else:
            chunks.append(p)
    print(f"[load_passages] total chunks: {len(chunks)}")
    with open(PCTX_PATH, "wb") as f:
        pickle.dump(chunks, f)
    return chunks
# ── 4. Build or load FAISS ──
def load_faiss_index(passages):
    """Load the embedder, the reranker, and a FAISS index over *passages*.

    A cached index at INDEX_PATH (with EMB_PATH present) is reused only if it
    holds exactly ``len(passages)`` vectors. Otherwise its row ids would not
    line up with ``passages``, so it is rebuilt. A freshly built index stores
    L2-normalised float32 embeddings in an inner-product index, so search
    scores are cosine similarities for unit-length queries.

    Args:
        passages: Passage strings; row i of the index corresponds to passages[i].

    Returns:
        tuple: (embedder, reranker, index).

    Raises:
        ValueError: If the index must be built and *passages* is empty.
    """
    embedder = SentenceTransformer(EMBEDDER_MODEL)
    reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

    index = None
    if os.path.exists(INDEX_PATH) and os.path.exists(EMB_PATH):
        print("Loading FAISS index & embeddings…")
        index = faiss.read_index(INDEX_PATH)
        if index.ntotal != len(passages):
            print(
                f"Cached index has {index.ntotal} vectors but there are "
                f"{len(passages)} passages; rebuilding…"
            )
            index = None

    if index is None:
        if not passages:
            raise ValueError("cannot build a FAISS index from an empty passage list")
        print("Encoding passages & building FAISS index…")
        embeddings = embedder.encode(
            passages,
            show_progress_bar=True,
            convert_to_numpy=True,
            batch_size=32,
        )
        # FAISS requires float32; astype also copies, so the in-place divide is safe.
        embeddings = np.asarray(embeddings).astype("float32")
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        # Guard against zero vectors (e.g. empty strings) to avoid NaNs.
        embeddings /= np.maximum(norms, 1e-12)
        index = faiss.IndexFlatIP(embeddings.shape[1])
        index.add(embeddings)
        faiss.write_index(index, INDEX_PATH)
        np.save(EMB_PATH, embeddings)
    return embedder, reranker, index
# ── 5. Initialize RAG components ──
def setup_rag():
    """Initialise every component of the RAG pipeline.

    Passages come from the pickle cache at PCTX_PATH when it exists; otherwise
    they are rebuilt by load_passages(). The generator is a seq2seq
    "text2text-generation" pipeline built from MODEL_NAME on DEVICE.

    Returns:
        tuple: (passages, embedder, reranker, index, qa_pipe).
    """
    if not os.path.exists(PCTX_PATH):
        passages = load_passages()
    else:
        # Local cache written by load_passages(); pickle must only be loaded
        # from files this app created itself.
        with open(PCTX_PATH, "rb") as fh:
            passages = pickle.load(fh)

    embedder, reranker, index = load_faiss_index(passages)

    generator_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    generator_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
    generation_kwargs = {
        "truncation": True,
        "max_length": 512,
        "num_beams": 4,
        "early_stopping": True,
    }
    qa_pipe = hf_pipeline(
        "text2text-generation",
        model=generator_model,
        tokenizer=generator_tokenizer,
        device=DEVICE,
        **generation_kwargs,
    )
    return passages, embedder, reranker, index, qa_pipe
# ── 6. Retrieval & generation ──
def retrieve(question, passages, embedder, reranker, index, k=20, rerank_k=5):
    """Dense-retrieve *k* candidates for *question*, then rerank them.

    Args:
        question: The query string.
        passages: Passage list whose positions match the index row ids.
        embedder: Object with ``encode([...], convert_to_numpy=True)``.
        reranker: Object with ``predict([[question, passage], ...])``.
        index: FAISS index searched with ``index.search(query, k)``.
        k: Number of first-stage candidates.
        rerank_k: Number of passages kept after reranking (0 keeps none).

    Returns:
        tuple: (contexts, scores). ``contexts`` holds up to ``rerank_k``
        passages in descending cross-encoder score order. ``scores`` holds,
        for each of them, its first-stage FAISS inner-product score as a
        float. Both lists are empty when the index returns no hits.
    """
    q_emb = np.asarray(embedder.encode([question], convert_to_numpy=True), dtype="float32")
    # Passage vectors are L2-normalised when the index is built; normalise the
    # query as well so inner product equals cosine similarity.
    q_emb = q_emb / np.maximum(np.linalg.norm(q_emb, axis=1, keepdims=True), 1e-12)
    sims, idxs = index.search(q_emb, k)

    # FAISS pads results with id -1 when fewer than k vectors are available.
    # In Python, passages[-1] would silently return the last passage, so drop them.
    hits = [(int(i), float(s)) for i, s in zip(idxs[0], sims[0]) if i >= 0]
    if not hits:
        return [], []

    cands = [passages[i] for i, _ in hits]
    rerank_scores = reranker.predict([[question, c] for c in cands])
    # Highest reranker score first; slicing after the reversal makes rerank_k=0 empty.
    top = np.argsort(rerank_scores)[::-1][:max(rerank_k, 0)]
    return [cands[j] for j in top], [hits[j][1] for j in top]
def generate(question, contexts, qa_pipe):
    """Answer *question* with *qa_pipe*, grounded on the given *contexts*.

    Each context is word-capped by make_context_snippets and numbered from 1
    in the prompt. The model is told to answer only from those contexts. The
    first generated text is returned with surrounding whitespace stripped.
    """
    instructions = (
        "You are a helpful assistant. Use ONLY the following contexts to answer. "
        "If the answer is not contained, say 'Sorry, I don't know.'\n\n"
    )
    numbered_contexts = "\n".join(
        f"Context {n}: {snippet}"
        for n, snippet in enumerate(make_context_snippets(contexts), start=1)
    )
    prompt = f"{instructions}{numbered_contexts}\n\nQuestion: {question}\nAnswer:"
    outputs = qa_pipe(prompt)
    return outputs[0]["generated_text"].strip()
def retrieve_and_answer(question, passages, embedder, reranker, index, qa_pipe):
    """Retrieve supporting passages for *question* and generate an answer.

    ``retrieve`` also returns the first-stage FAISS scores of the reranked
    passages. The index is an inner-product index over L2-normalised passage
    vectors, so these scores are similarities (higher means closer), not
    distances. The top passage's score is therefore converted to cosine
    distance (``1 - similarity``) before comparing with DIST_THRESHOLD. This
    is exact when the query embedding is unit-length too. With the default
    threshold of 1.0, an answer is refused only when the best passage has
    negative similarity.

    Returns:
        tuple: (answer, contexts). When nothing sufficiently close is
        retrieved, returns ("Sorry, I don't know.", []).
    """
    contexts, scores = retrieve(question, passages, embedder, reranker, index)
    if not contexts:
        return "Sorry, I don't know.", []
    best_distance = 1.0 - float(scores[0])
    if best_distance > DIST_THRESHOLD:
        return "Sorry, I don't know.", []
    return generate(question, contexts, qa_pipe), contexts
def answer_and_contexts(question, passages, embedder, reranker, index, qa_pipe):
    """Return the answer plus a printable preview of the contexts behind it.

    The preview numbers each (word-capped) context and separates consecutive
    contexts with a "---" divider line. It is an empty string when no
    contexts were used.
    """
    answer, used_contexts = retrieve_and_answer(
        question, passages, embedder, reranker, index, qa_pipe
    )
    if len(used_contexts) == 0:
        return answer, ""
    divider = "\n\n---\n\n"
    preview = divider.join(
        f"Context {pos}: {text}"
        for pos, text in enumerate(make_context_snippets(used_contexts), start=1)
    )
    return answer, preview
# ── 7. Evaluation routines ──
def retrieval_recall(dataset, passages, embedder, reranker, index, k=20, rerank_k=None, num_samples=100):
    """Measure how often any gold answer string appears in the retrieved passages.

    Unanswerable examples (empty gold list) always count as misses.

    Args:
        dataset: SQuAD-style dataset with ``question`` and ``answers["text"]``
            that supports ``len()`` and ``select()``.
        k: Number of first-stage FAISS candidates.
        rerank_k: If truthy, rerank with ``retrieve`` and keep this many.
            Otherwise the raw top-k FAISS hits are used.
        num_samples: Number of leading examples to evaluate, capped at
            ``len(dataset)``.

    Returns:
        float: hits / samples evaluated (0.0 when there are no samples).
    """
    num_samples = min(num_samples, len(dataset))
    hits = 0
    for ex in dataset.select(range(num_samples)):
        question = ex["question"]
        gold_answers = ex["answers"]["text"]
        if rerank_k:
            ctxs, _ = retrieve(question, passages, embedder, reranker, index, k=k, rerank_k=rerank_k)
        else:
            q_emb = np.asarray(embedder.encode([question], convert_to_numpy=True), dtype="float32")
            # Normalise like the indexed passages so inner product is cosine similarity.
            q_emb = q_emb / np.maximum(np.linalg.norm(q_emb, axis=1, keepdims=True), 1e-12)
            _, idxs = index.search(q_emb, k)
            # Skip FAISS padding ids (-1), which would otherwise index passages[-1].
            ctxs = [passages[i] for i in idxs[0] if i >= 0]
        if any(ans in ctx for ans in gold_answers for ctx in ctxs):
            hits += 1
    recall = hits / num_samples if num_samples > 0 else 0.0
    print(f"Retrieval Recall@{k} (rerank_k={rerank_k}): {recall:.3f} ({hits}/{num_samples})")
    return recall
def retrieval_recall_answerable(dataset, passages, embedder, reranker, index, k=20, rerank_k=None, num_samples=100):
    """Like retrieval_recall, but only over examples that have a gold answer.

    Of the first ``num_samples`` examples (capped at ``len(dataset)``), the
    unanswerable ones are skipped. For the rest, a hit is counted when any
    gold answer string occurs in any retrieved passage.

    Args:
        k: Number of first-stage FAISS candidates.
        rerank_k: If truthy, rerank with ``retrieve`` and keep this many.
            Otherwise the raw top-k FAISS hits are used.

    Returns:
        float: hits / answerable examples (0.0 when there are none).
    """
    num_samples = min(num_samples, len(dataset))
    hits, total = 0, 0
    for ex in dataset.select(range(num_samples)):
        gold = ex["answers"]["text"]
        if not gold:
            continue
        total += 1
        question = ex["question"]
        if rerank_k:
            ctxs, _ = retrieve(question, passages, embedder, reranker, index, k=k, rerank_k=rerank_k)
        else:
            q_emb = np.asarray(embedder.encode([question], convert_to_numpy=True), dtype="float32")
            # Normalise like the indexed passages so inner product is cosine similarity.
            q_emb = q_emb / np.maximum(np.linalg.norm(q_emb, axis=1, keepdims=True), 1e-12)
            _, idxs = index.search(q_emb, k)
            # Skip FAISS padding ids (-1), which would otherwise index passages[-1].
            ctxs = [passages[i] for i in idxs[0] if i >= 0]
        if any(ans in ctx for ans in gold for ctx in ctxs):
            hits += 1
    recall = hits / total if total > 0 else 0.0
    print(f"Retrieval Recall@{k} on answerable only (rerank_k={rerank_k}): {recall:.3f} ({hits}/{total})")
    return recall
def qa_eval_answerable(dataset, passages, embedder, reranker, index, qa_pipe, k=20, num_samples=100):
    """Compute SQuAD exact-match / F1 for end-to-end answers on answerable examples.

    Of the first ``num_samples`` examples (capped at ``len(dataset)``), those
    with an empty gold answer list are skipped. Each remaining question goes
    through retrieve_and_answer and is scored with the ``squad`` metric from
    ``evaluate``.

    Note: ``k`` is accepted for API compatibility but is not used.
    retrieve_and_answer always applies retrieve's own default candidate count.

    Returns:
        dict: The metric result (keys include ``exact_match`` and ``f1``).
        Returns ``{"exact_match": 0.0, "f1": 0.0}`` when there is nothing to
        score, since the metric would otherwise divide by zero.
    """
    num_samples = min(num_samples, len(dataset))
    preds, refs = [], []
    for ex in dataset.select(range(num_samples)):
        if not ex["answers"]["text"]:
            continue
        qid = ex["id"]
        answer, _ = retrieve_and_answer(ex["question"], passages, embedder, reranker, index, qa_pipe)
        preds.append({"id": qid, "prediction_text": answer})
        refs.append({"id": qid, "answers": ex["answers"]})

    if not preds:
        print(f"Answerable-only QA: no answerable examples in the first {num_samples} samples")
        return {"exact_match": 0.0, "f1": 0.0}

    # Load the metric only once there is something to score.
    squad_metric = evaluate.load("squad")
    results = squad_metric.compute(predictions=preds, references=refs)
    print(f"Answerable-only QA EM: {results['exact_match']:.2f}, F1: {results['f1']:.2f}")
    return results
# ── 8. Main entry ──
def main():
    """Run the SQuAD v2 evaluations (``--eval``) or launch the Gradio demo."""
    # Parse CLI flags before loading models and the index, so `--help` and
    # argument errors return immediately rather than after the full setup.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--eval", action="store_true",
        help="Run retrieval/QA evaluations on SQuAD instead of launching the demo"
    )
    args = parser.parse_args()

    passages, embedder, reranker, index, qa_pipe = setup_rag()

    if args.eval:
        squad = load_dataset("rajpurkar/squad_v2", split="validation")
        retrieval_recall(squad, passages, embedder, reranker, index, k=20, rerank_k=5, num_samples=100)
        retrieval_recall_answerable(squad, passages, embedder, reranker, index, k=20, rerank_k=5, num_samples=100)
        qa_eval_answerable(squad, passages, embedder, reranker, index, qa_pipe, k=20, num_samples=100)
    else:
        demo = gr.Interface(
            fn=lambda q: answer_and_contexts(q, passages, embedder, reranker, index, qa_pipe),
            inputs=gr.Textbox(lines=1, placeholder="Ask me anything…", label="Question"),
            outputs=[gr.Textbox(label="Answer"), gr.Textbox(label="Contexts")],
            # The magnifying-glass emoji had been mis-decoded into mojibake.
            title="🔍 RAG QA Demo",
            description="Retrieval-Augmented QA with threshold and context preview",
            examples=[
                "When was Abraham Lincoln inaugurated?",
                "What is the capital of France?",
                "Who wrote '1984'?"
            ],
            allow_flagging="never",
        )
        demo.launch()


if __name__ == "__main__":
    main()