import logging
from collections import defaultdict
from inspect import signature
from typing import Callable, Optional, Sequence, Union

from ..llm.client import LLMClient, get_default_client
from ..utils.analytics_collector import analytics
from .knowledge_base import KnowledgeBase
from .metrics import CorrectnessMetric, Metric
from .question_generators.utils import maybe_tqdm
from .recommendation import get_rag_recommendation
from .report import RAGReport
from .testset import QATestset
from .testset_generation import generate_testset

logger = logging.getLogger(__name__)

ANSWER_FN_HISTORY_PARAM = "history"


def evaluate(
    answer_fn: Union[Callable, Sequence[str]],
    testset: Optional[QATestset] = None,
    knowledge_base: Optional[KnowledgeBase] = None,
    llm_client: Optional[LLMClient] = None,
    agent_description: str = "This agent is a chatbot that answers questions from users.",
    metrics: Optional[Sequence[Callable]] = None,
) -> RAGReport:
"""Evaluate an agent by comparing its answers on a QATestset. | |
Parameters | |
---------- | |
answers_fn : Union[Callable, Sequence[str]] | |
The prediction function of the agent to evaluate or a list of precalculated answers on the testset. | |
testset : QATestset, optional | |
The test set to evaluate the agent on. If not provided, a knowledge base must be provided and a default testset will be created from the knowledge base. | |
Note that if the answers_fn is a list of answers, the testset is required. | |
knowledge_base : KnowledgeBase, optional | |
The knowledge base of the agent to evaluate. If not provided, a testset must be provided. | |
llm_client : LLMClient, optional | |
The LLM client to use for the evaluation. If not provided, a default openai client will be used. | |
agent_description : str, optional | |
Description of the agent to be tested. | |
metrics : Optional[Sequence[Callable]], optional | |
Metrics to compute on the test set. | |
Returns | |
------- | |
RAGReport | |
The report of the evaluation. | |
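
    Examples
    --------
    A minimal usage sketch. The ``get_answer`` callable and the ``my_agent`` object are placeholders for your
    own RAG pipeline, and the testset path assumes a testset previously saved with ``QATestset.save``::

        testset = QATestset.load("my_testset.jsonl")

        def get_answer(question: str) -> str:
            # `my_agent` stands in for your own RAG agent
            return my_agent.chat(question)

        report = evaluate(get_answer, testset=testset)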
""" | |
    validate_inputs(answer_fn, knowledge_base, testset)

    testset = testset or generate_testset(knowledge_base)
    answers = retrieve_answers(answer_fn, testset)

    llm_client = llm_client or get_default_client()
    metrics = get_metrics(metrics, llm_client, agent_description)
    metrics_results = compute_metrics(metrics, testset, answers)

    report = get_report(testset, answers, metrics_results, knowledge_base)
    add_recommendation(report, llm_client, metrics)
    track_analytics(report, testset, knowledge_base, agent_description, metrics)

    return report


def validate_inputs(answer_fn, knowledge_base, testset):
    if testset is None:
        if knowledge_base is None:
            raise ValueError("At least one of testset or knowledge base must be provided to the evaluate function.")
        if isinstance(answer_fn, Sequence):
            raise ValueError(
                "If the testset is not provided, answer_fn must be a callable rather than a list of precomputed "
                "answers, otherwise the answers cannot be matched with the generated questions."
            )

    # Check basic types, in case the user passed the params in the wrong order
    if knowledge_base is not None and not isinstance(knowledge_base, KnowledgeBase):
        raise ValueError(
            f"knowledge_base must be a KnowledgeBase object (got {type(knowledge_base)} instead). "
            "Are you sure you passed the parameters in the right order?"
        )
    if testset is not None and not isinstance(testset, QATestset):
        raise ValueError(
            f"testset must be a QATestset object (got {type(testset)} instead). "
            "Are you sure you passed the parameters in the right order?"
        )


def retrieve_answers(answer_fn, testset):
    return answer_fn if isinstance(answer_fn, Sequence) else _compute_answers(answer_fn, testset)


def get_metrics(metrics, llm_client, agent_description):
    metrics = list(metrics) if metrics is not None else []
    if not any(isinstance(metric, CorrectnessMetric) for metric in metrics):
        # By default only correctness is computed as it is required to build the report
        metrics.insert(
            0, CorrectnessMetric(name="correctness", llm_client=llm_client, agent_description=agent_description)
        )
    return metrics


def compute_metrics(metrics, testset, answers):
    metrics_results = defaultdict(dict)
    for metric in metrics:
        metric_name = getattr(
            metric, "name", metric.__class__.__name__ if isinstance(metric, Metric) else metric.__name__
        )
        for sample, answer in maybe_tqdm(
            zip(testset.to_pandas().to_records(index=True), answers),
            desc=f"{metric_name} evaluation",
            total=len(answers),
        ):
            metrics_results[sample["id"]].update(metric(sample, answer))
    return metrics_results
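
# Note: as compute_metrics shows, each metric is called as ``metric(sample, answer)`` and must
# return a mapping of result values, which is merged into the per-sample results via dict.update.
# A hypothetical custom metric passed through the ``metrics`` argument could therefore look like:
#
#     def answer_length(sample, answer) -> dict:
#         return {"answer_length": len(answer)}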


def get_report(testset, answers, metrics_results, knowledge_base):
    return RAGReport(testset, answers, metrics_results, knowledge_base)


def add_recommendation(report, llm_client, metrics):
    recommendation = get_rag_recommendation(
        report.topics,
        report.correctness_by_question_type().to_dict()[metrics[0].name],
        report.correctness_by_topic().to_dict()[metrics[0].name],
        llm_client,
    )
    report._recommendation = recommendation


def track_analytics(report, testset, knowledge_base, agent_description, metrics):
    analytics.track(
        "raget:evaluation",
        {
            "testset_size": len(testset),
            "knowledge_base_size": len(knowledge_base) if knowledge_base else -1,
            "agent_description": agent_description,
            "num_metrics": len(metrics),
            "correctness": report.correctness,
        },
    )


def _compute_answers(answer_fn, testset):
    answers = []
    answer_fn_params = signature(answer_fn).parameters
    needs_history = len(answer_fn_params) > 1 and ANSWER_FN_HISTORY_PARAM in answer_fn_params

    for sample in maybe_tqdm(testset.samples, desc="Asking questions to the agent", total=len(testset)):
        kwargs = {}
        if needs_history:
            kwargs[ANSWER_FN_HISTORY_PARAM] = sample.conversation_history
        answers.append(answer_fn(sample.question, **kwargs))
    return answers
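
# Illustrative only: an answer function may optionally accept a ``history`` keyword argument
# (named by ANSWER_FN_HISTORY_PARAM above). When it does, _compute_answers passes the sample's
# conversation history along with the question, e.g.:
#
#     def answer_with_history(question: str, history: Optional[list] = None) -> str:
#         # `history` holds the previous messages of a conversational test sample
#         ...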