|
from typing import Dict, Iterable, List |
|
|
|
from datasets import Features, Value |
|
|
|
from .operator import ( |
|
MultiStreamOperator, |
|
SequentialOperatorInitilizer, |
|
StreamInitializerOperator, |
|
) |
|
from .operators import ( |
|
Apply, |
|
ApplyMetric, |
|
ApplyOperatorsField, |
|
FlattenInstances, |
|
MergeStreams, |
|
SplitByValue, |
|
) |
|
from .register import _reset_env_local_catalogs, register_all_artifacts |
|
from .schema import UNITXT_DATASET_SCHEMA |
|
from .stream import MultiStream, Stream |
|
|
|
|
|
class MultiStreamScoreMean(MultiStreamOperator):
    """Annotate every instance with a ``groups_mean_score`` global score.

    The value is written to ``instance["score"]["global"]["groups_mean_score"]``.
    With a single stream it is a copy of each instance's own global score;
    with several streams it is the mean of the global scores read from one
    peeked instance per stream.
    """

    # NOTE: the method name keeps its original spelling ("aggegate") so that
    # existing callers and subclasses continue to work.
    def aggegate_results(self, multi_stream: MultiStream):
        """Return the mean of the global ``score`` taken from each stream.

        One instance per stream is obtained through ``stream.peek()`` and its
        ``["score"]["global"]["score"]`` value is collected.
        """
        per_stream_scores = [
            stream.peek()["score"]["global"]["score"]
            for stream in multi_stream.values()
        ]

        from statistics import mean

        return mean(per_stream_scores)

    def spread_results(self, stream: Stream, score: float):
        """Yield each instance of ``stream`` after storing ``score`` on it."""
        for instance in stream:
            global_scores = instance["score"]["global"]
            global_scores["groups_mean_score"] = score
            yield instance

    def spread_results_one_stream(self, stream: Stream):
        """Yield each instance with its own global score copied over.

        Used when there is only one stream, so no cross-stream mean is
        computed.
        """
        for instance in stream:
            own_score = instance["score"]["global"]["score"]
            instance["score"]["global"]["groups_mean_score"] = own_score
            yield instance

    def process(self, multi_stream: MultiStream) -> MultiStream:
        """Return a new ``MultiStream`` whose instances carry the mean score.

        Every input stream is wrapped in a ``Stream`` built from one of the
        generator methods above; stream names are preserved.
        """
        # Single stream: no aggregation (and therefore no peek) is performed.
        if len(multi_stream) == 1:
            return MultiStream(
                {
                    name: Stream(
                        self.spread_results_one_stream,
                        gen_kwargs={"stream": stream},
                    )
                    for name, stream in multi_stream.items()
                }
            )

        mean_score = self.aggegate_results(multi_stream)
        return MultiStream(
            {
                name: Stream(
                    self.spread_results,
                    gen_kwargs={"stream": stream, "score": mean_score},
                )
                for name, stream in multi_stream.items()
            }
        )
|
|
|
|
|
class FromPredictionsAndOriginalData(StreamInitializerOperator):
    """Build a one-split ``MultiStream`` pairing predictions with references."""

    def zip(self, predictions, references):
        """Yield each reference record extended with its ``prediction``.

        Items are paired positionally with the builtin ``zip``, so iteration
        stops at the shorter of the two inputs. A ``"prediction"`` key that
        already exists in a reference record is overwritten.
        """
        for predicted, reference in zip(predictions, references):
            yield {**reference, "prediction": predicted}

    def process(
        self, predictions: List[str], references: Iterable, split_name: str = "all"
    ) -> MultiStream:
        """Return a ``MultiStream`` with one stream named ``split_name``."""
        paired_stream = Stream(
            self.zip,
            gen_kwargs={"predictions": predictions, "references": references},
        )
        return MultiStream({split_name: paired_stream})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _from_key_value_pairs(key_value_list: Dict[str, list]) -> Dict[str, str]: |
|
return dict(zip(key_value_list["key"], key_value_list["value"])) |
|
|
|
|
|
class MetricRecipe(SequentialOperatorInitilizer):
    """Sequential pipeline that scores predictions against original data."""

    # Forwarded to ``ApplyMetric`` when the steps are built in ``prepare``.
    calc_confidence_intervals: bool = True

    def prepare(self):
        """Register artifacts and assemble the ordered list of steps."""
        register_all_artifacts()

        # Operators are created in the same order in which they run.
        attach_predictions = FromPredictionsAndOriginalData()
        unpack_additional_inputs = Apply(
            "additional_inputs",
            function=_from_key_value_pairs,
            to_field="additional_inputs",
        )
        run_postprocessors = ApplyOperatorsField(
            operators_field="postprocessors",
        )
        split_by_group = SplitByValue(["group"])
        run_metrics = ApplyMetric(
            "metrics",
            calc_confidence_intervals=self.calc_confidence_intervals,
        )
        add_groups_mean = MultiStreamScoreMean()
        merge_streams = MergeStreams()

        self.steps = [
            attach_predictions,
            unpack_additional_inputs,
            run_postprocessors,
            split_by_group,
            run_metrics,
            add_groups_mean,
            merge_streams,
        ]
|
|
|
|
|
# Feature schema for metric inputs: a string prediction together with a
# reference record laid out according to UNITXT_DATASET_SCHEMA.
UNITXT_METRIC_SCHEMA = Features(
    {
        "predictions": Value("string"),
        "references": dict(UNITXT_DATASET_SCHEMA),
    }
)
|
|
|
|
|
def _compute(
    predictions: List[str],
    references: Iterable,
    flatten: bool = False,
    split_name: str = "all",
    calc_confidence_intervals: bool = True,
):
    """Run ``MetricRecipe`` over the inputs and return the result as a list.

    Args:
        predictions: Predictions, paired positionally with ``references``.
        references: Original data records matching the predictions.
        flatten: When true, the result is passed through ``FlattenInstances``
            before being collected.
        split_name: Name of the stream that is created and then read back.
        calc_confidence_intervals: Forwarded to ``MetricRecipe``.

    Returns:
        A list with every instance of the ``split_name`` stream.
    """
    _reset_env_local_catalogs()
    register_all_artifacts()

    recipe = MetricRecipe(calc_confidence_intervals=calc_confidence_intervals)
    scored = recipe(
        predictions=predictions, references=references, split_name=split_name
    )

    if flatten:
        scored = FlattenInstances()(scored)

    return list(scored[split_name])
|
|