|
from dataclasses import field |
|
from typing import Any, Dict, Generator, Iterable, List, Optional, Union |
|
|
|
import datasets |
|
import evaluate |
|
from datasets import Features, Sequence, Value |
|
|
|
from .artifact import __file__ as _ |
|
from .blocks import __file__ as _ |
|
from .card import __file__ as _ |
|
from .catalog import __file__ as _ |
|
from .collections import __file__ as _ |
|
from .common import __file__ as _ |
|
from .dataclass import __file__ as _ |
|
from .file_utils import __file__ as _ |
|
from .fusion import __file__ as _ |
|
from .generator_utils import __file__ as _ |
|
from .instructions import __file__ as _ |
|
from .load import __file__ as _ |
|
from .loaders import __file__ as _ |
|
from .metrics import __file__ as _ |
|
from .normalizers import __file__ as _ |
|
from .operator import ( |
|
MultiStreamOperator, |
|
SequntialOperator, |
|
SequntialOperatorInitilizer, |
|
StreamInitializerOperator, |
|
) |
|
from .operator import __file__ as _ |
|
from .operators import ( |
|
ApplyStreamOperatorsField, |
|
ApplyValueOperatorsField, |
|
FlattenInstances, |
|
MergeStreams, |
|
SplitByValue, |
|
) |
|
from .operators import __file__ as _ |
|
from .processors import __file__ as _ |
|
from .recipe import __file__ as _ |
|
from .register import __file__ as _ |
|
from .register import register_all_artifacts |
|
from .schema import __file__ as _ |
|
from .split_utils import __file__ as _ |
|
from .splitters import __file__ as _ |
|
from .stream import MultiStream, Stream |
|
from .stream import __file__ as _ |
|
from .task import __file__ as _ |
|
from .templates import __file__ as _ |
|
from .text_utils import __file__ as _ |
|
from .utils import __file__ as _ |
|
from .validate import __file__ as _ |
|
|
|
|
|
class MultiStreamScoreMean(MultiStreamOperator): |
|
def aggegate_results(self, multi_stream: MultiStream): |
|
scores = [] |
|
for stream in multi_stream.values(): |
|
instance = stream.peak() |
|
scores.append(instance["score"]["global"]["score"]) |
|
|
|
from statistics import mean |
|
|
|
return mean(scores) |
|
|
|
def spread_results(self, stream: Stream, score: float): |
|
for instance in stream: |
|
instance["score"]["global"]["groups_mean_score"] = score |
|
yield instance |
|
|
|
def process(self, multi_stream: MultiStream) -> MultiStream: |
|
mean_score = self.aggegate_results(multi_stream) |
|
|
|
result = {} |
|
for stream_name, stream in multi_stream.items(): |
|
result[stream_name] = Stream(self.spread_results, gen_kwargs={"stream": stream, "score": mean_score}) |
|
|
|
return MultiStream(result) |
|
|
|
|
|
class FromPredictionsAndOriginalData(StreamInitializerOperator): |
|
def zip(self, predictions, references): |
|
for prediction, original in zip(predictions, references): |
|
yield {**original, "prediction": prediction} |
|
|
|
def process(self, predictions: List[str], references: Iterable, split_name: str = "all") -> MultiStream: |
|
return MultiStream( |
|
{split_name: Stream(self.zip, gen_kwargs={"predictions": predictions, "references": references})} |
|
) |
|
|
|
|
|
from .schema import UNITXT_DATASET_SCHEMA |
|
|
|
|
|
class MetricRecipe(SequntialOperatorInitilizer): |
|
def prepare(self): |
|
register_all_artifacts() |
|
self.steps = [ |
|
FromPredictionsAndOriginalData(), |
|
ApplyValueOperatorsField( |
|
value_field="prediction", operators_field="processors", default_operators=["processors.to_string"] |
|
), |
|
SplitByValue(["group"]), |
|
ApplyStreamOperatorsField( |
|
"metrics", |
|
reversed=True, |
|
), |
|
MultiStreamScoreMean(), |
|
MergeStreams(), |
|
] |
|
|
|
|
|
UNITXT_METRIC_SCHEMA = Features({"predictions": Value("string"), "references": dict(UNITXT_DATASET_SCHEMA)}) |
|
|
|
|
|
|
|
class Metric(evaluate.Metric): |
|
def _info(self): |
|
return evaluate.MetricInfo( |
|
description="_DESCRIPTION", |
|
citation="_CITATION", |
|
|
|
features=UNITXT_METRIC_SCHEMA, |
|
codebase_urls=["https://"], |
|
reference_urls=[ |
|
"https://", |
|
"https://", |
|
], |
|
) |
|
|
|
def _compute(self, predictions: List[str], references: Iterable, flatten: bool = False, split_name: str = "all"): |
|
recipe = MetricRecipe() |
|
|
|
multi_stream = recipe(predictions=predictions, references=references, split_name=split_name) |
|
|
|
if flatten: |
|
operator = FlattenInstances() |
|
multi_stream = operator(multi_stream) |
|
|
|
stream = multi_stream[split_name] |
|
|
|
return list(stream) |
|
|