from typing import Callable, Optional, Sequence, Union

import logging
from collections import defaultdict
from inspect import signature

from ..llm.client import LLMClient, get_default_client
from ..utils.analytics_collector import analytics
from .knowledge_base import KnowledgeBase
from .metrics import CorrectnessMetric, Metric
from .question_generators.utils import maybe_tqdm
from .recommendation import get_rag_recommendation
from .report import RAGReport
from .testset import QATestset
from .testset_generation import generate_testset

logger = logging.getLogger(__name__)

ANSWER_FN_HISTORY_PARAM = "history"


def evaluate(
    answer_fn: Union[Callable, Sequence[str]],
    testset: Optional[QATestset] = None,
    knowledge_base: Optional[KnowledgeBase] = None,
    llm_client: Optional[LLMClient] = None,
    agent_description: str = "This agent is a chatbot that answers questions from users.",
    metrics: Optional[Sequence[Callable]] = None,
) -> RAGReport:
    """Evaluate an agent by comparing its answers on a QATestset.

    Parameters
    ----------
    answer_fn : Union[Callable, Sequence[str]]
        The prediction function of the agent to evaluate, or a list of precalculated answers on the testset.
    testset : QATestset, optional
        The test set to evaluate the agent on. If not provided, a knowledge base must be provided
        and a default testset will be created from the knowledge base. Note that if answer_fn is a
        list of answers, the testset is required.
    knowledge_base : KnowledgeBase, optional
        The knowledge base of the agent to evaluate. If not provided, a testset must be provided.
    llm_client : LLMClient, optional
        The LLM client to use for the evaluation. If not provided, a default OpenAI client will be used.
    agent_description : str, optional
        Description of the agent to be tested.
    metrics : Optional[Sequence[Callable]], optional
        Metrics to compute on the test set.

    Returns
    -------
    RAGReport
        The report of the evaluation.
    """
    validate_inputs(answer_fn, knowledge_base, testset)

    testset = testset or generate_testset(knowledge_base)
    answers = retrieve_answers(answer_fn, testset)

    llm_client = llm_client or get_default_client()
    metrics = get_metrics(metrics, llm_client, agent_description)
    metrics_results = compute_metrics(metrics, testset, answers)

    report = get_report(testset, answers, metrics_results, knowledge_base)
    add_recommendation(report, llm_client, metrics)
    track_analytics(report, testset, knowledge_base, agent_description, metrics)

    return report


def validate_inputs(answer_fn, knowledge_base, testset):
    if testset is None:
        if knowledge_base is None:
            raise ValueError("At least one of testset or knowledge base must be provided to the evaluate function.")
        if isinstance(answer_fn, Sequence):
            # Precalculated answers cannot be matched to a testset generated on the fly.
            raise ValueError(
                "If the testset is not provided, answer_fn must be a callable, not a list of precalculated answers, to ensure the matching between questions and answers."
            )

    # Check basic types, in case the user passed the params in the wrong order
    if knowledge_base is not None and not isinstance(knowledge_base, KnowledgeBase):
        raise ValueError(
            f"knowledge_base must be a KnowledgeBase object (got {type(knowledge_base)} instead). Are you sure you passed the parameters in the right order?"
        )
    if testset is not None and not isinstance(testset, QATestset):
        raise ValueError(
            f"testset must be a QATestset object (got {type(testset)} instead). Are you sure you passed the parameters in the right order?"
        )


def retrieve_answers(answer_fn, testset):
    return answer_fn if isinstance(answer_fn, Sequence) else _compute_answers(answer_fn, testset)


def get_metrics(metrics, llm_client, agent_description):
    metrics = list(metrics) if metrics is not None else []

    if not any(isinstance(metric, CorrectnessMetric) for metric in metrics):
        # By default only correctness is computed as it is required to build the report
        metrics.insert(
            0, CorrectnessMetric(name="correctness", llm_client=llm_client, agent_description=agent_description)
        )

    return metrics


def compute_metrics(metrics, testset, answers):
    metrics_results = defaultdict(dict)
    for metric in metrics:
        metric_name = getattr(
            metric, "name", metric.__class__.__name__ if isinstance(metric, Metric) else metric.__name__
        )
        for sample, answer in maybe_tqdm(
            zip(testset.to_pandas().to_records(index=True), answers),
            desc=f"{metric_name} evaluation",
            total=len(answers),
        ):
            metrics_results[sample["id"]].update(metric(sample, answer))

    return metrics_results


def get_report(testset, answers, metrics_results, knowledge_base):
    return RAGReport(testset, answers, metrics_results, knowledge_base)


def add_recommendation(report, llm_client, metrics):
    recommendation = get_rag_recommendation(
        report.topics,
        report.correctness_by_question_type().to_dict()[metrics[0].name],
        report.correctness_by_topic().to_dict()[metrics[0].name],
        llm_client,
    )
    report._recommendation = recommendation


def track_analytics(report, testset, knowledge_base, agent_description, metrics):
    analytics.track(
        "raget:evaluation",
        {
            "testset_size": len(testset),
            "knowledge_base_size": len(knowledge_base) if knowledge_base else -1,
            "agent_description": agent_description,
            "num_metrics": len(metrics),
            "correctness": report.correctness,
        },
    )


def _compute_answers(answer_fn, testset):
    answers = []
    needs_history = (
        len(signature(answer_fn).parameters) > 1 and ANSWER_FN_HISTORY_PARAM in signature(answer_fn).parameters
    )
    for sample in maybe_tqdm(testset.samples, desc="Asking questions to the agent", total=len(testset)):
        kwargs = {}
        if needs_history:
            kwargs[ANSWER_FN_HISTORY_PARAM] = sample.conversation_history

        answers.append(answer_fn(sample.question, **kwargs))
    return answers
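

# Example usage (a minimal sketch kept as a comment; `my_agent` and `my_knowledge_base`
# are hypothetical placeholders for an existing agent and KnowledgeBase instance):
#
#     def answer_fn(question, history=None):
#         # Because the function declares a `history` parameter, _compute_answers
#         # forwards each sample's conversation history to it.
#         return my_agent.chat(question, history=history)
#
#     report = evaluate(answer_fn, knowledge_base=my_knowledge_base)
#     print(report.correctness)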