VictorTomas09 commited on
Commit
8b017be
·
verified ·
1 Parent(s): 86211f8

Delete Evaluators

Browse files
Files changed (1) hide show
  1. Evaluators +0 -110
Evaluators DELETED
@@ -1,110 +0,0 @@
1
- import os
2
- import pickle
3
- import numpy as np
4
- import faiss
5
- import torch
6
- from datasets import load_dataset
7
- import evaluate
8
-
9
- # Import RAG setup and retrieval logic from app.py
10
- from app import setup_rag, retrieve, retrieve_and_answer
11
-
12
-
13
- def retrieval_recall(dataset, passages, embedder, reranker, index, k=20, rerank_k=None, num_samples=100):
14
- """
15
- Compute raw Retrieval Recall@k on the first num_samples examples.
16
- If rerank_k is set, apply cross-encoder reranking via `retrieve`.
17
- Otherwise, use the FAISS index only (top-k) without reranking.
18
- """
19
- hits = 0
20
- for ex in dataset.select(range(num_samples)):
21
- question = ex["question"]
22
- gold_answers = ex["answers"]["text"]
23
-
24
- if rerank_k:
25
- # use two-stage retrieval (dense + rerank)
26
- ctxs, _ = retrieve(question, passages, embedder, reranker, index, k=k, rerank_k=rerank_k)
27
- else:
28
- # single-stage: FAISS only
29
- q_emb = embedder.encode([question], convert_to_numpy=True)
30
- distances, idxs = index.search(q_emb, k)
31
- ctxs = [passages[i] for i in idxs[0]]
32
-
33
- # check if any gold answer appears in any retrieved context
34
- if any(any(ans in ctx for ctx in ctxs) for ans in gold_answers):
35
- hits += 1
36
-
37
- recall = hits / num_samples
38
- print(f"Retrieval Recall@{k} (rerank_k={rerank_k}): {recall:.3f} ({hits}/{num_samples})")
39
- return recall
40
-
41
-
42
- def retrieval_recall_answerable(dataset, passages, embedder, reranker, index, k=20, rerank_k=None, num_samples=100):
43
- """
44
- Retrieval Recall@k evaluated only on answerable questions (answers list non-empty).
45
- """
46
- hits = 0
47
- total = 0
48
- for ex in dataset.select(range(num_samples)):
49
- gold = ex["answers"]["text"]
50
- if not gold:
51
- continue
52
- total += 1
53
- question = ex["question"]
54
-
55
- if rerank_k:
56
- ctxs, _ = retrieve(question, passages, embedder, reranker, index, k=k, rerank_k=rerank_k)
57
- else:
58
- q_emb = embedder.encode([question], convert_to_numpy=True)
59
- distances, idxs = index.search(q_emb, k)
60
- ctxs = [passages[i] for i in idxs[0]]
61
-
62
- if any(any(ans in ctx for ctx in ctxs) for ans in gold):
63
- hits += 1
64
-
65
- recall = hits / total if total > 0 else 0.0
66
- print(f"Retrieval Recall@{k} on answerable only (rerank_k={rerank_k}): {recall:.3f} ({hits}/{total})")
67
- return recall
68
-
69
-
70
- def qa_eval_answerable(dataset, passages, embedder, reranker, index, qa_pipe, k=20, num_samples=100):
71
- """
72
- End-to-end QA EM/F1 on answerable subset using retrieve_and_answer.
73
- """
74
- squad_metric = evaluate.load("squad")
75
- preds = []
76
- refs = []
77
-
78
- for ex in dataset.select(range(num_samples)):
79
- gold = ex["answers"]["text"]
80
- if not gold:
81
- continue
82
- qid = ex["id"]
83
- # retrieve and generate
84
- answer, _ = retrieve_and_answer(
85
- ex["question"], passages, embedder, reranker, index, qa_pipe
86
- )
87
- preds.append({"id": qid, "prediction_text": answer})
88
- refs.append({"id": qid, "answers": ex["answers"]})
89
-
90
- results = squad_metric.compute(predictions=preds, references=refs)
91
- print(f"Answerable-only QA EM: {results['exact_match']:.2f}, F1: {results['f1']:.2f}")
92
- return results
93
-
94
-
95
- def main():
96
- # 1) Setup RAG components
97
- passages, embedder, reranker, index, qa_pipe = setup_rag()
98
-
99
- # 2) Load SQuAD v2 validation split
100
- squad = load_dataset("rajpurkar/squad_v2", split="validation")
101
-
102
- # 3) Run evaluations
103
- retrieval_recall(squad, passages, embedder, reranker, index, k=20, rerank_k=5, num_samples=100)
104
- retrieval_recall_answerable(squad, passages, embedder, reranker, index, k=20, rerank_k=5, num_samples=100)
105
- qa_eval_answerable(squad, passages, embedder, reranker, index, qa_pipe, k=20, num_samples=100)
106
-
107
-
108
- if __name__ == "__main__":
109
- main()
110
-