# eval_framework/judges/__init__.py
"""Judge stack: batch LLM evaluation.
Session: 2 calls (recall + correctness) + per-item calls for update/interference.
QA: 2 calls (answer + evidence).
"""
from __future__ import annotations
from eval_framework.judges.llm_client import llm_request_for_json
from eval_framework.judges.prompts import (
CORRECTNESS_BATCH_PROMPT,
EVIDENCE_BATCH_PROMPT,
INTERFERENCE_EVAL_PROMPT,
QA_EVALUATION_PROMPT,
RECALL_BATCH_PROMPT,
UPDATE_EVAL_PROMPT,
)
__all__ = [
"evaluate_recall_batch",
"evaluate_correctness_batch",
"evaluate_update_single",
"evaluate_interference_single",
"evaluate_evidence_batch",
"evaluate_qa_llm",
"llm_request_for_json",
]
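
# Hypothetical end-to-end flow (variable names such as update_items and
# interference_items are illustrative placeholders, not part of this module):
# a session is judged with one recall call and one correctness call, plus one
# call per tagged item; a QA turn is judged with one answer call and one
# evidence call.
#
#   recall = evaluate_recall_batch(extracted_memories_str, gold_points_tagged)
#   correctness = evaluate_correctness_batch(snapshot_memories, gold_points_tagged, interference_total)
#   for upd in update_items:                      # one call per [update] item
#       evaluate_update_single(delta_memories_str, upd["new"], upd["old"])
#   for intf in interference_items:               # one call per [interference] item
#       evaluate_interference_single(delta_memories_str, intf)
#   qa = evaluate_qa_llm(question, reference_answer, key_memory_points, system_response)
#   evidence = evaluate_evidence_batch(retrieved_memories_str, evidence_points)
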
def evaluate_recall_batch(
extracted_memories_str: str,
gold_points_tagged: list[str],
) -> dict[str, object]:
"""One LLM call: how many gold points are covered? Distinguishes update sub-score.
gold_points_tagged: list of "[normal] content" or "[update] content" strings.
Returns {covered_count, update_covered_count, total, update_total, reasoning}.
"""
if not extracted_memories_str.strip():
update_total = sum(1 for p in gold_points_tagged if p.startswith("[update]"))
return {
"covered_count": 0, "update_covered_count": 0,
"total": len(gold_points_tagged), "update_total": update_total,
"reasoning": "No extracted memories.",
}
if not gold_points_tagged:
return {
"covered_count": 0, "update_covered_count": 0,
"total": 0, "update_total": 0, "reasoning": "No gold points.",
}
numbered = "\n".join(f"[{i+1}] {p}" for i, p in enumerate(gold_points_tagged))
update_total = sum(1 for p in gold_points_tagged if p.startswith("[update]"))
prompt = RECALL_BATCH_PROMPT.format(memories=extracted_memories_str, gold_points=numbered)
try:
result = llm_request_for_json(prompt)
covered = int(result.get("covered_count", 0))
upd_covered = int(result.get("update_covered_count", 0))
return {
"covered_count": min(covered, len(gold_points_tagged)),
"update_covered_count": min(upd_covered, update_total),
"total": len(gold_points_tagged),
"update_total": update_total,
"reasoning": result.get("reasoning", ""),
}
except Exception as e:
return {
"covered_count": None, "update_covered_count": None,
"total": len(gold_points_tagged), "update_total": update_total,
"reasoning": f"LLM error: {e}",
}
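
# Illustrative call (made-up data; shows the tagged gold-point format and the return shape):
#   evaluate_recall_batch(
#       "User moved to Berlin.\nUser is vegetarian.",
#       ["[normal] User is vegetarian.", "[update] User now lives in Berlin."],
#   )
#   # -> {"covered_count": 2, "update_covered_count": 1, "total": 2,
#   #     "update_total": 1, "reasoning": "..."}
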
def evaluate_correctness_batch(
snapshot_memories: list[str],
gold_points_tagged: list[str],
interference_total: int,
) -> dict[str, object]:
"""One LLM call: is each snapshot memory correct? Includes interference detection.
gold_points_tagged: list of "[normal] content", "[update] content", "[interference] content".
Returns {results: [{id, label}], interference_memorized_count, interference_total, reasoning}.
"""
if not snapshot_memories:
return {
"results": [],
"interference_memorized_count": 0,
"interference_total": interference_total,
"reasoning": "No snapshot memories.",
}
numbered_memories = "\n".join(f"[{i+1}] {m}" for i, m in enumerate(snapshot_memories))
numbered_golds = "\n".join(f"- {p}" for p in gold_points_tagged) if gold_points_tagged else "(no ground-truth)"
prompt = CORRECTNESS_BATCH_PROMPT.format(memories=numbered_memories, gold_points=numbered_golds)
try:
result = llm_request_for_json(prompt)
raw_results = result.get("results", [])
valid_labels = {"correct", "hallucination", "irrelevant"}
cleaned = []
for r in raw_results:
label = str(r.get("label", "irrelevant")).lower().strip()
if label not in valid_labels:
label = "irrelevant"
cleaned.append({"id": r.get("id"), "label": label})
interf_mem = int(result.get("interference_memorized_count", 0))
return {
"results": cleaned,
"interference_memorized_count": min(interf_mem, interference_total),
"interference_total": interference_total,
"reasoning": result.get("reasoning", ""),
}
except Exception as e:
return {
"results": [],
"interference_memorized_count": None,
"interference_total": interference_total,
"reasoning": f"LLM error: {e}",
}
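
# Illustrative result shape (made-up data): each snapshot memory gets one of the
# labels "correct" / "hallucination" / "irrelevant", and the judge also counts how
# many [interference] gold points were wrongly memorized.
#   {
#       "results": [{"id": 1, "label": "correct"}, {"id": 2, "label": "hallucination"}],
#       "interference_memorized_count": 0,
#       "interference_total": 1,
#       "reasoning": "...",
#   }
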
def evaluate_update_single(
delta_memories_str: str,
new_content: str,
old_contents: list[str],
) -> dict[str, object]:
"""One LLM call: how did the system handle a single memory update?
Returns {label: "updated"|"both"|"outdated", reasoning}.
"""
old_str = "\n".join(f"- {o}" for o in old_contents) if old_contents else "(none)"
prompt = UPDATE_EVAL_PROMPT.format(
memories=delta_memories_str,
new_content=new_content,
old_contents=old_str,
)
try:
result = llm_request_for_json(prompt)
label = str(result.get("label", "outdated")).lower().strip()
if label not in ("updated", "both", "outdated"):
label = "outdated"
return {"label": label, "reasoning": result.get("reasoning", "")}
except Exception as e:
return {"label": None, "reasoning": f"LLM error: {e}"}
def evaluate_interference_single(
delta_memories_str: str,
interference_content: str,
) -> dict[str, object]:
"""One LLM call: did the system incorrectly memorize an interference point?
Returns {label: "rejected"|"memorized", reasoning}.
"""
prompt = INTERFERENCE_EVAL_PROMPT.format(
memories=delta_memories_str,
interference_content=interference_content,
)
try:
result = llm_request_for_json(prompt)
label = str(result.get("label", "memorized")).lower().strip()
if label not in ("rejected", "memorized"):
label = "memorized"
return {"label": label, "reasoning": result.get("reasoning", "")}
except Exception as e:
return {"label": None, "reasoning": f"LLM error: {e}"}
def evaluate_evidence_batch(
retrieved_memories_str: str,
evidence_points: list[str],
) -> dict[str, object]:
"""One LLM call: how many gold evidence points are covered by retrieval?"""
if not retrieved_memories_str.strip():
return {"covered_count": 0, "total": len(evidence_points), "reasoning": "No retrieved memories."}
if not evidence_points:
return {"covered_count": 0, "total": 0, "reasoning": "No evidence points."}
numbered = "\n".join(f"[{i+1}] {p}" for i, p in enumerate(evidence_points))
prompt = EVIDENCE_BATCH_PROMPT.format(retrieved_memories=retrieved_memories_str, gold_evidence_points=numbered)
try:
result = llm_request_for_json(prompt)
covered = int(result.get("covered_count", 0))
return {
"covered_count": min(covered, len(evidence_points)),
"total": len(evidence_points),
"reasoning": result.get("reasoning", ""),
}
except Exception as e:
return {"covered_count": None, "total": len(evidence_points), "reasoning": f"LLM error: {e}"}
def evaluate_qa_llm(
question: str,
reference_answer: str,
key_memory_points: str,
system_response: str,
) -> dict[str, object]:
"""LLM judge: classify the QA response as Correct/Hallucination/Omission."""
if not system_response.strip():
return {"evaluation_result": "Omission", "reasoning": "Empty system response."}
prompt = QA_EVALUATION_PROMPT.format(
question=question, reference_answer=reference_answer,
key_memory_points=key_memory_points, response=system_response,
)
try:
result = llm_request_for_json(prompt)
        label = str(result.get("evaluation_result", "Omission")).strip()
if label not in ("Correct", "Hallucination", "Omission"):
label = "Omission"
return {"evaluation_result": label, "reasoning": result.get("reasoning", "")}
except Exception as e:
return {"evaluation_result": None, "reasoning": f"LLM judge error: {e}"}