import logging
from collections import defaultdict
from inspect import signature
from typing import Callable, Optional, Sequence, Union

from ..llm.client import LLMClient, get_default_client
from ..utils.analytics_collector import analytics
from .knowledge_base import KnowledgeBase
from .metrics import CorrectnessMetric, Metric
from .question_generators.utils import maybe_tqdm
from .recommendation import get_rag_recommendation
from .report import RAGReport
from .testset import QATestset
from .testset_generation import generate_testset

logger = logging.getLogger(__name__)

ANSWER_FN_HISTORY_PARAM = "history"


def evaluate(
answer_fn: Union[Callable, Sequence[str]],
testset: Optional[QATestset] = None,
knowledge_base: Optional[KnowledgeBase] = None,
llm_client: Optional[LLMClient] = None,
    agent_description: str = "This agent is a chatbot that answers questions from users.",
metrics: Optional[Sequence[Callable]] = None,
) -> RAGReport:
"""Evaluate an agent by comparing its answers on a QATestset.
Parameters
----------
answers_fn : Union[Callable, Sequence[str]]
The prediction function of the agent to evaluate or a list of precalculated answers on the testset.
testset : QATestset, optional
The test set to evaluate the agent on. If not provided, a knowledge base must be provided and a default testset will be created from the knowledge base.
Note that if the answers_fn is a list of answers, the testset is required.
knowledge_base : KnowledgeBase, optional
The knowledge base of the agent to evaluate. If not provided, a testset must be provided.
llm_client : LLMClient, optional
The LLM client to use for the evaluation. If not provided, a default openai client will be used.
agent_description : str, optional
Description of the agent to be tested.
metrics : Optional[Sequence[Callable]], optional
Metrics to compute on the test set.
Returns
-------
RAGReport
The report of the evaluation.
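
    Examples
    --------
    A minimal usage sketch; ``answer_my_question`` and ``my_testset`` are placeholders
    for your own prediction function and an existing ``QATestset``::

        def answer_my_question(question, history=None):
            # Call your own RAG agent here (``my_rag_agent`` is hypothetical).
            return my_rag_agent.chat(question, history=history)

        report = evaluate(answer_my_question, testset=my_testset)

    A list of precalculated answers can be passed instead, as long as the testset they
    were computed on is provided::

        report = evaluate(precomputed_answers, testset=my_testset)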
"""
validate_inputs(answer_fn, knowledge_base, testset)
testset = testset or generate_testset(knowledge_base)
answers = retrieve_answers(answer_fn, testset)
llm_client = llm_client or get_default_client()
metrics = get_metrics(metrics, llm_client, agent_description)
metrics_results = compute_metrics(metrics, testset, answers)
report = get_report(testset, answers, metrics_results, knowledge_base)
add_recommendation(report, llm_client, metrics)
track_analytics(report, testset, knowledge_base, agent_description, metrics)
    return report


def validate_inputs(answer_fn, knowledge_base, testset):
    if testset is None:
        if knowledge_base is None:
            raise ValueError("At least one of testset or knowledge base must be provided to the evaluate function.")
        if isinstance(answer_fn, Sequence):
            raise ValueError(
                "If the testset is not provided, answer_fn must be a callable and not a list of precalculated "
                "answers, since precalculated answers cannot be matched to the generated questions."
            )
# Check basic types, in case the user passed the params in the wrong order
if knowledge_base is not None and not isinstance(knowledge_base, KnowledgeBase):
raise ValueError(
f"knowledge_base must be a KnowledgeBase object (got {type(knowledge_base)} instead). Are you sure you passed the parameters in the right order?"
)
if testset is not None and not isinstance(testset, QATestset):
raise ValueError(
f"testset must be a QATestset object (got {type(testset)} instead). Are you sure you passed the parameters in the right order?"
        )


def retrieve_answers(answer_fn, testset):
    return answer_fn if isinstance(answer_fn, Sequence) else _compute_answers(answer_fn, testset)


def get_metrics(metrics, llm_client, agent_description):
metrics = list(metrics) if metrics is not None else []
if not any(isinstance(metric, CorrectnessMetric) for metric in metrics):
# By default only correctness is computed as it is required to build the report
metrics.insert(
0, CorrectnessMetric(name="correctness", llm_client=llm_client, agent_description=agent_description)
)
    return metrics


def compute_metrics(metrics, testset, answers):
metrics_results = defaultdict(dict)
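    # Results are keyed by sample id; each metric is expected to return a dict of
    # values, which is merged into the corresponding sample entry.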
for metric in metrics:
metric_name = getattr(
metric, "name", metric.__class__.__name__ if isinstance(metric, Metric) else metric.__name__
)
for sample, answer in maybe_tqdm(
zip(testset.to_pandas().to_records(index=True), answers),
desc=f"{metric_name} evaluation",
total=len(answers),
):
metrics_results[sample["id"]].update(metric(sample, answer))
    return metrics_results


def get_report(testset, answers, metrics_results, knowledge_base):
    return RAGReport(testset, answers, metrics_results, knowledge_base)


def add_recommendation(report, llm_client, metrics):
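    # metrics[0] is expected to be the correctness metric: get_metrics inserts one at
    # index 0 whenever the user did not provide any.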
recommendation = get_rag_recommendation(
report.topics,
report.correctness_by_question_type().to_dict()[metrics[0].name],
report.correctness_by_topic().to_dict()[metrics[0].name],
llm_client,
)
    report._recommendation = recommendation


def track_analytics(report, testset, knowledge_base, agent_description, metrics):
analytics.track(
"raget:evaluation",
{
"testset_size": len(testset),
"knowledge_base_size": len(knowledge_base) if knowledge_base else -1,
"agent_description": agent_description,
"num_metrics": len(metrics),
"correctness": report.correctness,
},
    )


def _compute_answers(answer_fn, testset):
    answers = []
    # Only pass the conversation history if the agent function declares a "history" parameter.
    answer_fn_params = signature(answer_fn).parameters
    needs_history = len(answer_fn_params) > 1 and ANSWER_FN_HISTORY_PARAM in answer_fn_params
    for sample in maybe_tqdm(testset.samples, desc="Asking questions to the agent", total=len(testset)):
        kwargs = {}
        if needs_history:
            kwargs[ANSWER_FN_HISTORY_PARAM] = sample.conversation_history
        answers.append(answer_fn(sample.question, **kwargs))
    return answers