# Source file: eval_framework/pipeline/qa_runner.py
# Hosting-page header (not part of the module): uploader LCZZZZ,
# commit 85b19cf (verified), message "Upload eval_framework source code".
"""Shared checkpoint QA: retrieval via adapter + answer from an injected callable.
``AnswerFn`` may return either a plain ``str`` (legacy) or a
``(str, list[str])`` tuple of ``(answer, cited_memories)``.
"""
from __future__ import annotations
from collections.abc import Callable
from typing import Union
from eval_framework.datasets.domain_a_v2 import NormalizedCheckpoint, NormalizedCheckpointQuestion
from eval_framework.datasets.schemas import RetrievalRecord
from eval_framework.memory_adapters.base import MemoryAdapter
from eval_framework.pipeline.records import PipelineCheckpointQARecord
# Value an answer callable may return: either a bare answer string (legacy
# form) or an ``(answer, cited_memories)`` pair whose second item is a list of
# strings.
AnswerResult = Union[str, tuple[str, list[str]]]
# Signature of the injected answer callable: it is given the checkpoint
# question and the retrieval record obtained for that question, and returns
# an ``AnswerResult``.
AnswerFn = Callable[[NormalizedCheckpointQuestion, RetrievalRecord], AnswerResult]
def run_checkpoint_qa_records(
    adapter: MemoryAdapter,
    *,
    sample_id: str,
    sample_uuid: str,
    checkpoint: NormalizedCheckpoint,
    top_k: int,
    answer_fn: AnswerFn,
) -> tuple[PipelineCheckpointQARecord, ...]:
    """Build one QA record per question of ``checkpoint``.

    Each question is first passed to ``adapter.retrieve`` and the resulting
    retrieval is then handed to ``answer_fn``; ``adapter.answer`` is never
    called here.

    Args:
        adapter: Memory adapter whose ``retrieve(question_text, top_k)`` is
            called once per question.
        sample_id: Copied unchanged into every record.
        sample_uuid: Copied unchanged into every record.
        checkpoint: Supplies ``checkpoint_id`` and the ``questions`` to run.
        top_k: Forwarded as the second argument of ``adapter.retrieve``.
        answer_fn: Called as ``answer_fn(question, retrieval)``. A tuple
            result is unpacked as ``(answer, cited_memories)``; any other
            result is taken as the answer with no cited memories.

    Returns:
        The records as a tuple, in the iteration order of
        ``checkpoint.questions``.

    Raises:
        ValueError: If ``answer_fn`` returns a tuple that does not unpack
            into exactly two items.
    """

    def _record_for(
        question: NormalizedCheckpointQuestion,
    ) -> PipelineCheckpointQARecord:
        # Retrieval always runs before answering so answer_fn sees its result.
        hits = adapter.retrieve(question.question, top_k)
        raw = answer_fn(question, hits)
        # Normalise the legacy bare-answer form to an (answer, citations) pair.
        answer, citations = raw if isinstance(raw, tuple) else (raw, [])
        return PipelineCheckpointQARecord(
            sample_id=sample_id,
            sample_uuid=sample_uuid,
            checkpoint_id=checkpoint.checkpoint_id,
            question=question.question,
            gold_answer=question.gold_answer,
            gold_evidence_memory_ids=question.gold_evidence_memory_ids,
            gold_evidence_contents=question.gold_evidence_contents,
            question_type=question.question_type,
            question_type_abbrev=question.question_type_abbrev,
            difficulty=question.difficulty,
            retrieval=hits,
            generated_answer=answer,
            cited_memories=tuple(citations),
        )

    # A list comprehension (not a generator expression) keeps exception
    # propagation from the callees identical to a plain for-loop.
    records = [_record_for(item) for item in checkpoint.questions]
    return tuple(records)