ramailkk committed
Commit 6cb3d7c · 1 Parent(s): a865c33

proper evaluator changes

Files changed (4):
  1. config.yaml +1 -1
  2. main.py +9 -3
  3. retriever/evaluator.py +220 -40
  4. retriever/processor.py +2 -1
config.yaml CHANGED
@@ -35,7 +35,7 @@ generation:
   temperature: 0.1
   max_new_tokens: 512
   # The model used to Judge the others
-  judge_model: "Llama-3-8B"
+  judge_model: "llama-3.1-8b-instant"
 
   # List of contestants in the tournament
   models:
main.py CHANGED
@@ -71,8 +71,13 @@ def main():
     models = {name: MODEL_MAP[name](token=hf_token) for name in cfg.model_list}
 
     # Setup Evaluator with the designated Judge
-    judge_llm = models[cfg.gen['judge_model']]
-    evaluator = RAGEvaluator(judge_llm, proc.encoder)
+
+    evaluator = RAGEvaluator(
+        judge_model=cfg.gen['judge_model'],
+        embedding_model=proc.encoder,
+        api_key=os.getenv("GROQ_API_KEY")
+    )
+
     tournament_results = {}
 
     # 6. Tournament Loop
@@ -85,8 +90,9 @@
             temperature=cfg.gen['temperature']
         )
 
-        # Batch Evaluation
+        # Faithfulness Evaluation
         faith = evaluator.evaluate_faithfulness(answer, context_chunks)
+        # Relevancy Evaluation
         rel = evaluator.evaluate_relevancy(query, answer)
 
         tournament_results[name] = {
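
Reviewer note: a minimal, self-contained sketch of the new wiring, assuming GROQ_API_KEY is exported and any sentence-transformers model as the encoder (the model name and example strings below are illustrative, not taken from this repo's config):

    import os
    from sentence_transformers import SentenceTransformer
    from retriever.evaluator import RAGEvaluator

    # Illustrative encoder; main.py actually passes proc.encoder
    encoder = SentenceTransformer("all-MiniLM-L6-v2")

    evaluator = RAGEvaluator(
        judge_model="llama-3.1-8b-instant",  # mirrors cfg.gen['judge_model']
        embedding_model=encoder,
        api_key=os.getenv("GROQ_API_KEY"),
    )

    # Toy inputs standing in for one tournament model's output
    faith = evaluator.evaluate_faithfulness(
        answer="The Eiffel Tower is 330 metres tall.",
        context_list=["The Eiffel Tower stands 330 metres (1,083 ft) tall."],
    )
    rel = evaluator.evaluate_relevancy(
        query="How tall is the Eiffel Tower?",
        answer="The Eiffel Tower is 330 metres tall.",
    )
    print(faith["score"], rel["score"])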
retriever/evaluator.py CHANGED
@@ -1,105 +1,285 @@
+import re
 import numpy as np
 from sklearn.metrics.pairwise import cosine_similarity
+from groq import Groq
+
+
+# ------------------------------------------------------------------
+# Groq Judge Wrapper
+# ------------------------------------------------------------------
+
+class GroqJudge:
+    def __init__(self, api_key: str, model: str = "llama-3.1-8b-instant"):
+        """
+        Wraps Groq's chat completions to match the .generate(prompt) interface
+        expected by RAGEvaluator.
+
+        Args:
+            api_key: Your Groq API key (https://console.groq.com)
+            model: Groq model to use. Free tier options:
+                   - "llama-3.1-8b-instant" (fastest)
+                   - "llama-3.3-70b-versatile" (more capable, slower)
+                   - "gemma2-9b-it"
+        """
+        self.client = Groq(api_key=api_key)
+        self.model = model
+
+    def generate(self, prompt: str) -> str:
+        response = self.client.chat.completions.create(
+            model=self.model,
+            messages=[{"role": "user", "content": prompt}],
+            temperature=0.0,  # deterministic for evaluation
+            max_tokens=1024,
+        )
+        return response.choices[0].message.content.strip()
+
+
+# ------------------------------------------------------------------
+# RAG Evaluator
+# ------------------------------------------------------------------
 
 class RAGEvaluator:
-    def __init__(self, judge_model, embedding_model, verbose=True):
+    def __init__(self, judge_model: str, embedding_model, api_key: str, verbose=True):
         """
-        judge_model: An instance of an LLM class.
-        embedding_model: The proc.encoder for similarity checks.
-        verbose: If True, uses internal printer functions to show progress.
+        judge_model: Model name string passed to GroqJudge; must match cfg.gen['judge_model'],
+                     e.g. "llama-3.1-8b-instant", "llama-3.3-70b-versatile", "gemma2-9b-it"
+        embedding_model: The proc.encoder (SentenceTransformer) for similarity checks
+        api_key: Groq API key (https://console.groq.com)
+        verbose: If True, prints progress via internal helpers
         """
-        self.judge = judge_model
+        self.judge = GroqJudge(api_key=api_key, model=judge_model)
         self.encoder = embedding_model
         self.verbose = verbose
 
     # ------------------------------------------------------------------
     # 1. FAITHFULNESS: Claim Extraction & Verification
     # ------------------------------------------------------------------
-    def evaluate_faithfulness(self, answer, context_list):
+
+    def evaluate_faithfulness(self, answer: str, context_list: list[str], strict: bool = True) -> dict:
+        """
+        Args:
+            strict: If True, verifies each claim against chunks individually
+                    (more API calls but catches vague batch verdicts).
+                    If False, uses a single batched verification call.
+        """
         if self.verbose:
-            self._print_extraction_header(len(answer))
+            self._print_extraction_header(len(answer), strict=strict)
 
         # --- Step A: Extraction ---
-        extraction_prompt = f"Extract a list of independent factual claims from the following answer. Respond ONLY with the claims, one per line. Do not include any introductory text.\nAnswer: {answer}"
+        extraction_prompt = (
+            "Extract a list of independent factual claims from the following answer.\n"
+            "Rules:\n"
+            "- Each claim must be specific and verifiable — include numbers, names, or concrete details where present\n"
+            "- Vague claims like 'the model performs well' or 'this improves results' are NOT acceptable\n"
+            "- Do NOT include claims about what the context does or does not contain\n"
+            "- Do NOT include introductory text, numbering, or bullet points\n"
+            "- Do NOT rephrase or merge claims\n"
+            "- One claim per line only\n\n"
+            f"Answer: {answer}"
+        )
         raw_claims = self.judge.generate(extraction_prompt)
-        claims = [c.strip() for c in raw_claims.split('\n') if len(c.strip()) > 5]
 
+        # Filter out short lines, preamble, and lines ending with ':'
+        claims = [
+            c.strip() for c in raw_claims.split('\n')
+            if len(c.strip()) > 20 and not c.strip().endswith(':')
+        ]
+
         if not claims:
             return {"score": 0, "details": []}
 
-        # --- Step B: Batch Verification ---
-        combined_context = "\n".join(context_list)
-        claims_formatted = "\n".join([f"{i+1}. {c}" for i, c in enumerate(claims)])
-
-        batch_prompt = f"Context: {combined_context}\nClaims: {claims_formatted}\nRespond YES/NO for each."
-        raw_verdicts = self.judge.generate(batch_prompt)
-        verdict_lines = [v.strip().upper() for v in raw_verdicts.split('\n') if v.strip()]
+        # --- Step B: Verification ---
+        if strict:
+            # Per-chunk: claim must be explicitly supported by at least one chunk
+            verdicts = {i: self._verify_claim_against_chunks(claim, context_list)
+                        for i, claim in enumerate(claims)}
+        else:
+            # Batch: all chunks joined, strict burden-of-proof prompt
+            combined_context = "\n".join(context_list)
+            if len(combined_context) > 6000:
+                combined_context = combined_context[:6000]
+
+            claims_formatted = "\n".join([f"{i+1}. {c}" for i, c in enumerate(claims)])
+
+            batch_prompt = (
+                f"Context:\n{combined_context}\n\n"
+                f"For each claim, respond YES only if the claim is EXPLICITLY and DIRECTLY "
+                f"supported by the context above. Respond NO if the claim is inferred, assumed, "
+                f"or not clearly stated in the context.\n\n"
+                f"Format strictly as:\n"
+                f"1: YES\n"
+                f"2: NO\n\n"
+                f"Claims:\n{claims_formatted}"
+            )
+            raw_verdicts = self.judge.generate(batch_prompt)
+
+            verdicts = {}
+            for line in raw_verdicts.split('\n'):
+                match = re.match(r'(\d+)\s*:\s*(YES|NO)', line.strip().upper())
+                if match:
+                    verdicts[int(match.group(1)) - 1] = match.group(2) == "YES"
 
         # --- Step C: Scoring & Details ---
         verified_count = 0
         details = []
         for i, claim in enumerate(claims):
-            is_supported = "YES" in verdict_lines[i] if i < len(verdict_lines) else False
-            if is_supported: verified_count += 1
-
+            is_supported = verdicts.get(i, False)
+            if is_supported:
+                verified_count += 1
             details.append({
                 "claim": claim,
                 "verdict": "Supported" if is_supported else "Not Supported"
             })
 
         score = (verified_count / len(claims)) * 100
 
         if self.verbose:
             self._print_faithfulness_results(claims, details, score)
 
         return {"score": score, "details": details}
 
+    def _verify_claim_against_chunks(self, claim: str, context_list: list[str]) -> bool:
+        """Verify a single claim against each chunk individually; True if any chunk supports it."""
+        for chunk in context_list:
+            prompt = (
+                f"Context:\n{chunk}\n\n"
+                f"Claim: {claim}\n\n"
+                f"Is this claim EXPLICITLY and DIRECTLY stated in the context above? "
+                f"Do not infer or assume. Respond with YES or NO only."
+            )
+            result = self.judge.generate(prompt)
+            if "YES" in result.upper():
+                return True
+        return False
+
     # ------------------------------------------------------------------
     # 2. RELEVANCY: Alternate Query Generation
     # ------------------------------------------------------------------
-    def evaluate_relevancy(self, query, answer):
+
+    def evaluate_relevancy(self, query: str, answer: str) -> dict:
         if self.verbose:
             self._print_relevancy_header()
 
         # --- Step A: Generation ---
-        gen_prompt = f"Generate 3 distinct questions this answer addresses.\nAnswer: {answer}"
+        # Explicitly ask the judge NOT to rephrase the original query
+        gen_prompt = (
+            f"Generate 3 distinct questions that the following answer addresses.\n"
+            f"Rules:\n"
+            f"- Do NOT rephrase or repeat this question: '{query}'\n"
+            f"- Each question must end with a '?'\n"
+            f"- One question per line, no numbering or bullet points\n\n"
+            f"Answer: {answer}"
+        )
         raw_gen = self.judge.generate(gen_prompt)
-        gen_queries = [q.strip() for q in raw_gen.split('\n') if '?' in q][:3]
 
+        # Filter by length rather than just '?' presence
+        gen_queries = [
+            q.strip() for q in raw_gen.split('\n')
+            if len(q.strip()) > 10
+        ][:3]
+
         if not gen_queries:
             return {"score": 0, "queries": []}
 
-        # --- Step B: Similarity Logic ---
-        original_vec = self.encoder.encode([query])
-        generated_vecs = self.encoder.encode(gen_queries)
+        # --- Step B: Similarity (single batched encode call) ---
+        all_vecs = self.encoder.encode([query] + gen_queries)
+        original_vec = all_vecs[0:1]
+        generated_vecs = all_vecs[1:]
+
         similarities = cosine_similarity(original_vec, generated_vecs)[0]
-        avg_score = np.mean(similarities)
+        avg_score = float(np.mean(similarities))
 
         if self.verbose:
             self._print_relevancy_results(query, gen_queries, similarities, avg_score)
 
         return {"score": avg_score, "queries": gen_queries}
 
     # ------------------------------------------------------------------
-    # 3. PRINT HELPERS (Keep the logic above clean)
+    # 3. DATASET-LEVEL EVALUATION
     # ------------------------------------------------------------------
-    def _print_extraction_header(self, length):
-        print(f"\n[EVAL] Analyzing Faithfulness...")
+
+    def evaluate_dataset(self, test_cases: list[dict], strict: bool = False) -> dict:
+        """
+        Runs faithfulness + relevancy over a full test set and aggregates results.
+
+        Args:
+            test_cases: List of dicts, each with keys:
+                - "query": str
+                - "answer": str
+                - "contexts": List[str]
+            strict: If True, passes strict=True to evaluate_faithfulness
+                    (per-chunk verification, more API calls, harder to pass)
+
+        Returns:
+            {
+                "avg_faithfulness": float,
+                "avg_relevancy": float,
+                "per_query": List[dict]
+            }
+        """
+        faithfulness_scores = []
+        relevancy_scores = []
+        per_query = []
+
+        for i, case in enumerate(test_cases):
+            if self.verbose:
+                print(f"\n{'='*60}")
+                print(f"Query {i+1}/{len(test_cases)}: {case['query']}")
+                print('='*60)
+
+            f_result = self.evaluate_faithfulness(case['answer'], case['contexts'], strict=strict)
+            r_result = self.evaluate_relevancy(case['query'], case['answer'])
+
+            faithfulness_scores.append(f_result['score'])
+            relevancy_scores.append(r_result['score'])
+            per_query.append({
+                "query": case['query'],
+                "faithfulness": f_result,
+                "relevancy": r_result,
+            })
+
+        results = {
+            "avg_faithfulness": float(np.mean(faithfulness_scores)),
+            "avg_relevancy": float(np.mean(relevancy_scores)),
+            "per_query": per_query,
+        }
+
+        if self.verbose:
+            self._print_dataset_summary(results)
+
+        return results
+
+    # ------------------------------------------------------------------
+    # 4. PRINT HELPERS
+    # ------------------------------------------------------------------
+
+    def _print_extraction_header(self, length, strict=False):
+        mode = "strict per-chunk" if strict else "batch"
+        print(f"\n[EVAL] Analyzing Faithfulness ({mode})...")
         print(f" - Extracting claims from answer ({length} chars)")
 
     def _print_faithfulness_results(self, claims, details, score):
         print(f" - Verifying {len(claims)} claims against context...")
         for i, detail in enumerate(details):
-            status = "✅" if "Supported" in detail['verdict'] else "❌"
+            # Exact match: "Not Supported" also contains the substring "Supported"
+            status = "✅" if detail['verdict'] == "Supported" else "❌"
             print(f" {status} Claim {i+1}: {detail['claim'][:75]}...")
         print(f" 🎯 Faithfulness Score: {score:.1f}%")
 
     def _print_relevancy_header(self):
         print(f"\n[EVAL] Analyzing Relevancy...")
-        print(f" - Generating 3 sample questions addressed by the answer")
+        print(f" - Generating 3 distinct questions addressed by the answer")
 
     def _print_relevancy_results(self, query, gen_queries, similarities, avg):
         print(f" - Comparing to original query: '{query}'")
         for i, (q, sim) in enumerate(zip(gen_queries, similarities)):
             print(f" Q{i+1}: {q} (Sim: {sim:.2f})")
         print(f" 🎯 Average Relevancy: {avg:.2f}")
+
+    def _print_dataset_summary(self, results):
+        print(f"\n{'='*60}")
+        print(f" DATASET EVALUATION SUMMARY")
+        print(f"{'='*60}")
+        print(f" Avg Faithfulness : {results['avg_faithfulness']:.1f}%")
+        print(f" Avg Relevancy    : {results['avg_relevancy']:.2f}")
+        print(f" Queries Evaluated: {len(results['per_query'])}")
+        print(f"{'='*60}")
retriever/processor.py CHANGED
@@ -74,7 +74,8 @@ class ChunkProcessor:
             return SemanticChunker(
                 self.hf_embeddings,
                 breakpoint_threshold_type=kwargs.get('breakpoint_threshold_type', "percentile"),
-                breakpoint_threshold_amount=kwargs.get('breakpoint_threshold_amount', 95)
+                # Using 70 because the default of 95 produced oversized chunks
+                breakpoint_threshold_amount=kwargs.get('breakpoint_threshold_amount', 70)
             )
 
         else:
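
Reviewer note: a quick way to sanity-check the lowered default, assuming langchain_experimental's SemanticChunker and a HuggingFace embeddings wrapper (how self.hf_embeddings is built is not shown in this hunk; the model name and file path below are illustrative). With the "percentile" strategy, a lower threshold declares breakpoints at more sentence gaps, so 70 should yield more, smaller chunks than 95:

    from langchain_experimental.text_splitter import SemanticChunker
    from langchain_huggingface import HuggingFaceEmbeddings

    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")  # illustrative
    text = open("sample_doc.txt").read()  # any multi-paragraph document

    for amount in (95, 70):
        chunker = SemanticChunker(
            embeddings,
            breakpoint_threshold_type="percentile",
            breakpoint_threshold_amount=amount,
        )
        chunks = chunker.split_text(text)
        print(f"threshold={amount}: {len(chunks)} chunks, "
              f"largest {max(len(c) for c in chunks)} chars")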