from typing import List, Dict, Any
from collections import defaultdict
import statistics

import datasets
import evaluate
from FLD_task import build_metrics

_DESCRIPTION = ""

_KWARGS_DESCRIPTION = ""

_CITATION = ""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class FLDMetrics(evaluate.Metric):

    def __init__(self, *args, log_samples=False, **kwargs):
        super().__init__(*args, **kwargs)
        # Two FLD_task metric variants are computed side by side:
        # 'strct'     -> strict evaluation of the generated proof,
        # 'extr_stps' -> evaluation that tolerates extra proof steps.
        self._metric_funcs = {
            'strct': build_metrics('strict'),
            'extr_stps': build_metrics('allow_extra_steps'),
        }
        self.log_samples = log_samples

    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                    "references": datasets.Sequence(datasets.Value("string")),
                    "contexts": datasets.Value("string"),
                }
            ),
            # reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html"],
        )

    def _compute(self, predictions, references, contexts):
        # Each example is a (prediction, gold proofs, context) triple.
        if contexts is None:
            contexts = [None] * len(predictions)

        # Collect per-example scores, keyed as "<metric_type>.<metric_name>".
        metrics: Dict[str, List[Any]] = defaultdict(list)
        for pred, golds, context in zip(predictions, references, contexts):
            for metric_type, calc_metrics in self._metric_funcs.items():
                _metrics = calc_metrics(
                    golds,
                    pred,
                    context=context,
                )
                for metric_name, metric_val in _metrics.items():
                    metrics[f"{metric_type}.{metric_name}"].append(metric_val)

        # Average each metric over all examples.
        results = {}
        for metric_name, metric_vals in metrics.items():
            results[metric_name] = statistics.mean(metric_vals)
        return results
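

# --- Usage sketch (hypothetical inputs) -------------------------------------
# A minimal example of how this metric might be invoked, assuming the FLD_task
# package is installed. The prediction, reference, and context strings below
# are placeholders rather than real FLD proofs; they only illustrate the
# expected shapes: one prediction string, a list of gold strings, and one
# context string per example.
if __name__ == "__main__":
    metric = FLDMetrics()
    results = metric.compute(
        predictions=["<hypothetical predicted proof>"],
        references=[["<hypothetical gold proof>"]],
        contexts=["<hypothetical context with facts and hypothesis>"],
    )
    # `results` maps "<metric_type>.<metric_name>" to the mean over examples,
    # i.e. keys prefixed with "strct." and "extr_stps.".
    print(results)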