#!/usr/bin/env python
# coding: utf-8

# # Retrieval-Augmented QA Demo
# 
# This notebook builds a minimal RAG (Retrieval-Augmented Generation) pipeline with several enhancements:
# 
# - Slimmed & deduplicated corpora
# - Chunking long passages
# - Persistent FAISS index & embeddings
# - Similarity threshold to avoid hallucinations
# - Context-length control
# - Polished Gradio interface with a separate contexts panel

# ## 1. Configuration & Imports
# 
# We detect the device, print the settings, and support loading a previously saved index.


import os
import pickle
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, CrossEncoder
import faiss
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import gradio as gr
import evaluate


# Settings
data_dir = os.path.join(os.getcwd(), "data")
os.makedirs(data_dir, exist_ok=True)
INDEX_PATH = os.path.join(data_dir, "faiss_index.faiss")
EMB_PATH   = os.path.join(data_dir, "embeddings.npy")
PCTX_PATH  = os.path.join(data_dir, "passages.pkl")

MODEL_NAME     = os.getenv("MODEL_NAME", "google/flan-t5-small")
EMBEDDER_MODEL = os.getenv("EMBEDDER_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
device         = 0 if torch.cuda.is_available() else -1
print(f"Using model: {MODEL_NAME}, embedder: {EMBEDDER_MODEL}, device: {'GPU' if device==0 else 'CPU'}")

# Minimum acceptable cosine similarity between query and top context.
# (The index built below stores L2-normalized vectors in an inner-product
# index, so search scores are cosine similarities: higher is better.)
sim_threshold = 0.3  # tune as needed; this default is a guess
# Max words per context snippet
max_context_words = 200


# ## Useful functions

def make_context_snippets(contexts, max_words=200):
    """Truncate each context to at most max_words words, marking truncation."""
    snippets = []
    for c in contexts:
        words = c.split()
        if len(words) > max_words:
            c = " ".join(words[:max_words]) + " ... [truncated]"
        snippets.append(c)
    return snippets
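

# A quick sanity check (illustrative): a 5-word limit should truncate the
# longer of these two strings and leave the shorter one untouched.
print(make_context_snippets(
    ["short passage", "one two three four five six seven"], max_words=5
))
# -> ['short passage', 'one two three four five ... [truncated]']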


# ## 2. Load, Deduplicate & Chunk Corpora
# 
# For this demo we sample small slices of each corpus and remove duplicates. We also chunk any passage longer than the model's maximum input length (512 tokens for flan-t5-small).
# 


# tokenizer for chunking (reuses the generator's tokenizer so chunk sizes match its limits)
chunk_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
max_tokens = chunk_tokenizer.model_max_length

def chunk_text(text: str, max_tokens: int, stride: int | None = None) -> list[str]:
    """
    Split `text` into overlapping chunks of up to `max_tokens` whitespace-delimited
    words (a cheap proxy for tokens; the generation pipeline truncates any
    overflow). By default consecutive chunks overlap by 25%, i.e. the window
    advances by 3/4 of its length each step.
    """
    words = text.split()
    if stride is None:
        stride = max(1, (3 * max_tokens) // 4)  # 25% overlap between chunks
    chunks = []
    start = 0
    while start < len(words):
        end = start + max_tokens
        chunks.append(" ".join(words[start:end]))
        # advance by stride, not the full window, so consecutive chunks overlap
        start += stride
    return chunks
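

# Illustrative check of the chunker (toy input): a 10-word text with a 4-word
# window and the default 25% overlap yields overlapping chunks.
print(chunk_text("w1 w2 w3 w4 w5 w6 w7 w8 w9 w10", max_tokens=4))
# -> ['w1 w2 w3 w4', 'w4 w5 w6 w7', 'w7 w8 w9 w10', 'w10']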


# Load corpora
wiki_ds = load_dataset("rag-datasets/rag-mini-wikipedia", "text-corpus", split="passages")
wiki_passages = wiki_ds["passage"]

squad_ds = load_dataset("rajpurkar/squad_v2", split="train[:100]")
squad_passages = [ex["context"] for ex in squad_ds]

trivia_ds = load_dataset("mandarjoshi/trivia_qa", "rc", split="validation[:100]")
trivia_passages = []
for ex in trivia_ds:
    for field in ("wiki_context", "search_context"):
        txt = ex.get(field) or ""
        if txt:
            trivia_passages.append(txt)

# Combine, dedupe, chunk
all_passages = wiki_passages + squad_passages + trivia_passages
unique_passages = list(dict.fromkeys(all_passages))
passages = []
for p in unique_passages:
    # count tokens without encoding to avoid warnings
    tokens = chunk_tokenizer.tokenize(p)
    if len(tokens) > max_tokens:
        passages.extend(chunk_text(p, max_tokens))
    else:
        passages.append(p)
print(f"Total passages after dedupe & chunk: {len(passages)}")

# Persist raw passages list
with open(PCTX_PATH, "wb") as f:
    pickle.dump(passages, f)


# ## 3. Build or Load FAISS Index & Embeddings
# 
# We save embeddings & index to disk to skip slow re-encoding.


# ── Initialize embedder and reranker ──
# (SentenceTransformer and CrossEncoder were imported in Section 1)
embedder = SentenceTransformer(EMBEDDER_MODEL)
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

# ── Load or (re)build FAISS index with cosine similarity ──
if os.path.exists(INDEX_PATH) and os.path.exists(EMB_PATH):
    # NOTE: the cache assumes `passages` is unchanged since the index was
    # built; delete the files in data/ to force a rebuild.
    print("Loading saved index and embeddings…")
    index = faiss.read_index(INDEX_PATH)
    embeddings = np.load(EMB_PATH)
else:
    print("Encoding passages (with overlap)…")
    embeddings = embedder.encode(
        passages,
        show_progress_bar=True,
        convert_to_numpy=True,
        batch_size=32
    )
    # Normalize to unit length so that inner product = cosine similarity
    embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

    # Build a FAISS index over inner-product (cosine) space
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(embeddings)

    # Persist to disk for faster reload
    faiss.write_index(index, INDEX_PATH)
    np.save(EMB_PATH, embeddings)
    print(f"Indexed {index.ntotal} vectors.")


# ## 4. Load QA Model & Pipeline


tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model     = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
qa_pipeline = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device  # note: early_stopping only affects beam search, so it's omitted here
)
print("QA pipeline ready.")


# ## 5. Retrieval + Generation Functions
# 
# We bail out early if the top similarity falls below the threshold, to avoid hallucinated answers.


def retrieve(question: str, k: int = 20, rerank_k: int = 5):
    # 1) dense-search the top k (normalize the query so inner product = cosine sim)
    q_emb = embedder.encode([question], convert_to_numpy=True)
    q_emb = q_emb / np.linalg.norm(q_emb, axis=1, keepdims=True)
    sims, indices = index.search(q_emb, k)

    # 2) pull out those k contexts
    candidates = [passages[i] for i in indices[0]]

    # 3) score each (question, context) pair with the cross-encoder
    pairs = [[question, ctx] for ctx in candidates]
    scores = reranker.predict(pairs)

    # 4) keep the rerank_k best by cross-encoder score
    top_idxs = np.argsort(scores)[-rerank_k:][::-1]
    final_ctxs = [candidates[i] for i in top_idxs]
    final_sims = [sims[0][i] for i in top_idxs]

    return final_ctxs, final_sims
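

# Illustrative smoke test of retrieval (runs the embedder and reranker, so it
# assumes the models above loaded successfully):
_ctxs, _sims = retrieve("Who was the first president of the United States?", k=10, rerank_k=2)
print(f"top cosine sim: {_sims[0]:.3f}")
print(_ctxs[0][:300])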



def generate(question: str, contexts: list) -> str:
    """
    Build a RAG prompt from the retrieved contexts and generate
    an answer using the HF text2text pipeline.
    """
    # 1) Turn each context into a truncated snippet
    snippet_lines = [
        f"Context {i+1}: {s}"
        for i, s in enumerate(make_context_snippets(contexts, max_context_words))
    ]

    # 2) Build the full prompt
    prompt = (
        "You are a helpful assistant. Use ONLY the following contexts to answer. "
        "If the answer is not contained, say 'Sorry, I don't know.'\n\n"
        + "\n".join(snippet_lines)
        + f"\n\nQuestion: {question}\nAnswer:"
    )

    # 3) Call the pipeline (it handles tokenization + generation + decoding)
    result = qa_pipeline(prompt, truncation=True, max_new_tokens=200)[0]["generated_text"]
    return result.strip()
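

# Minimal smoke test of the generator alone (toy context, purely illustrative):
# the model should answer from the supplied snippet rather than from memory.
print(generate("What color is the house?", ["The house at the end of the street is painted bright red."]))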


def retrieve_and_answer(question, k=20):
    contexts, sims = retrieve(question, k=k)
    if not contexts or sims[0] < sim_threshold:
        return "Sorry, I don't know.", []

    ans = generate(question, contexts)
    return ans, contexts


import random

print("Some sample passages:\n")
for p in random.sample(passages, 5):
    print(p, "\n" + "-"*80 + "\n")


# ## 6. Gradio Demo Interface
# 
# Separate panels for answer and contexts.

def answer_and_contexts(question: str):
    """
    Full end-to-end: retrieve, threshold-check, generate answer,
    and return both the answer and a formatted string of contexts.
    """
    answer, contexts = retrieve_and_answer(question)

    # If no valid contexts, just return the apology
    if not contexts:
        return answer, ""

    # Otherwise format each snippet for display
    ctx_snippets = [
        f"Context {i+1}: {s}"
        for i, s in enumerate(make_context_snippets(contexts, max_context_words))
    ]

    return answer, "\n\n---\n\n".join(ctx_snippets)



iface = gr.Interface(
    fn=answer_and_contexts,
    inputs=gr.Textbox(lines=1, placeholder="Enter your question here...", label="Question"),
    outputs=[
        gr.Textbox(label="Answer"),
        gr.Textbox(label="Retrieved Contexts")
    ],
    title="🔍 RAG QA Demo",
    description="Retrieval-Augmented QA with distance threshold and context preview"
)

iface.launch()


# # Test the Model

# load SQuAD v2 (we only need the validation split)
squad = load_dataset("rajpurkar/squad_v2", split="validation")

# load the SQuAD v1 metric; unanswerable questions are handled in
# qa_eval_all below by mapping them to empty-string references
squad_metric = evaluate.load("squad")
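

# For reference, a toy call showing the metric's expected input shape
# (hypothetical id and texts, not taken from the dataset):
_toy = squad_metric.compute(
    predictions=[{"id": "q1", "prediction_text": "Paris"}],
    references=[{"id": "q1", "answers": {"text": ["Paris"], "answer_start": [0]}}],
)
print(_toy)  # {'exact_match': 100.0, 'f1': 100.0}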


def retrieval_recall(dataset, k=20, num_samples=100):
    hits = 0
    for ex in dataset.select(range(num_samples)):
        question = ex["question"]
        gold_answers = ex["answers"]["text"]  # list, empty if unanswerable
        
        # get the top-k contexts (rerank_k=k keeps all k; pass a smaller rerank_k to rerank down)
        ctxs, _ = retrieve(question, k=k, rerank_k=k)
        # check if any gold answer appears in any context
        if any(any(ans in ctx for ctx in ctxs) for ans in gold_answers):
            hits += 1

    recall = hits / num_samples
    print(f"Retrieval Recall@{k}: {recall:.3f}")
    return recall


# ## Only answerable Questions


def retrieval_recall_answerable(dataset, k=20, num_samples=100):
    hits = 0
    total = 0
    for ex in dataset.select(range(num_samples)):
        if not ex["answers"]["text"]:
            continue   # skip unanswerable
        total += 1
        ctxs, _ = retrieve(ex["question"], k=k, rerank_k=k)
        if any(any(ans in ctx for ctx in ctxs) for ans in ex["answers"]["text"]):
            hits += 1
    recall = hits / total
    print(f"Retrieval Recall@{k} on answerable only: {recall:.3f} ({hits}/{total})")
    return recall

def qa_eval_all(dataset, num_samples=100, k=20):
    preds, refs = [], []
    for ex in dataset.select(range(num_samples)):
        qid  = ex["id"]
        gold = ex["answers"]
        # ensure metric has something to iterate over
        if not gold["text"]:
            gold = {"text":[""], "answer_start":[0]}

        ans, _ = retrieve_and_answer(ex["question"], k=k)
        # for metric purposes, treat our refusal as empty string
        pred_text = "" if ans.strip().lower().startswith("sorry") else ans

        preds.append({"id": qid, "prediction_text": pred_text})
        refs.append({"id": qid, "answers": gold})

    results = squad_metric.compute(predictions=preds, references=refs)
    print(f"Full QA EM: {results['exact_match']:.2f}, F1: {results['f1']:.2f}")
    return results

def qa_eval_answerable(dataset, num_samples=100, k=20):
    preds, refs = [], []
    for ex in dataset.select(range(num_samples)):
        if not ex["answers"]["text"]:
            continue                # skip unanswerable
        qid  = ex["id"]
        ans, _ = retrieve_and_answer(ex["question"], k=k)

        preds.append({"id": qid, "prediction_text": ans})
        refs.append({"id": qid, "answers": ex["answers"]})

    results = squad_metric.compute(predictions=preds, references=refs)
    print(f"Answerable-only QA EM: {results['exact_match']:.2f}, F1: {results['f1']:.2f}")
    return results


retrieval_recall(squad, k=2, num_samples=100)
retrieval_recall_answerable(squad, k=2, num_samples=100)
qa_eval_all(squad, num_samples=100, k=2)
qa_eval_answerable(squad, num_samples=100, k=2)