"""`evaluate`-compatible metric for the FLD (Formal Logic Deduction) proof-generation task."""
from typing import List, Dict, Any
from collections import defaultdict
import statistics

import datasets
import evaluate
from FLD_task import build_metrics

_DESCRIPTION = ""
_KWARGS_DESCRIPTION = ""
_CITATION = ""
class FLDMetrics(evaluate.Metric):

    def __init__(self, *args, log_samples=False, **kwargs):
        super().__init__(*args, **kwargs)
        # Two scoring variants from FLD_task: "strict" exact matching, and
        # "allow_extra_steps", which (as the name suggests) also accepts
        # predictions containing redundant extra proof steps.
        self._metric_funcs = {
            'strct': build_metrics('strict'),
            'extr_stps': build_metrics('allow_extra_steps'),
        }
        self.log_samples = log_samples
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                    "references": datasets.Sequence(datasets.Value("string")),
                    "contexts": datasets.Value("string"),
                }
            ),
        )
    def _compute(self, predictions, references, contexts=None):
        if contexts is None:
            contexts = [None] * len(predictions)
        # Collect per-example scores, keyed as "<metric_type>.<metric_name>".
        metrics: Dict[str, List[Any]] = defaultdict(list)
        for pred, golds, context in zip(predictions, references, contexts):
            for metric_type, calc_metrics in self._metric_funcs.items():
                _metrics = calc_metrics(
                    golds,
                    pred,
                    context=context,
                )
                for metric_name, metric_val in _metrics.items():
                    metrics[f"{metric_type}.{metric_name}"].append(metric_val)

        # Report the mean of each score over all evaluated examples.
        results = {}
        for metric_name, metric_vals in metrics.items():
            results[metric_name] = statistics.mean(metric_vals)
        return results
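

# A minimal usage sketch, not part of the metric itself. Assumptions: the
# class can be instantiated directly rather than loaded through
# `evaluate.load`, and the prediction/reference strings below are
# placeholders, not well-formed FLD proofs.
if __name__ == "__main__":
    metric = FLDMetrics()
    results = metric.compute(
        predictions=["sent1 -> int1: ...; int1 -> hypothesis;"],
        references=[["sent1 -> int1: ...; int1 -> hypothesis;"]],
        contexts=["sent1: ... sent2: ..."],
    )
    # Keys are "<metric_type>.<metric_name>" (i.e. "strct.*" and
    # "extr_stps.*"); each value is the mean over all examples.
    print(results)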