Elron committed on
Commit
6e6d8af
1 Parent(s): a8234ba

Upload metric_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. metric_utils.py +146 -0
metric_utils.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from statistics import mean
from typing import Dict, Iterable, List

from datasets import Features, Value

from .operator import (
    MultiStreamOperator,
    SequentialOperatorInitilizer,
    StreamInitializerOperator,
)
from .operators import (
    Apply,
    ApplyMetric,
    ApplyOperatorsField,
    FlattenInstances,
    MergeStreams,
    SplitByValue,
)
from .register import _reset_env_local_catalogs, register_all_artifacts
from .schema import UNITXT_DATASET_SCHEMA
from .stream import MultiStream, Stream
21
+
22
+
23
class MultiStreamScoreMean(MultiStreamOperator):
    """Compute the mean of the per-stream global scores and write it into every
    instance under ``instance["score"]["global"]["groups_mean_score"]``.

    When the multi-stream holds a single stream, the mean is trivially that
    stream's own score, so the aggregation pass is skipped entirely.
    """

    # NOTE(review): the method name is misspelled ("aggegate" -> "aggregate")
    # but is kept as-is for backward compatibility with existing callers and
    # subclasses.
    def aggegate_results(self, multi_stream: MultiStream):
        """Return the mean of the global 'score' value, taken from the first
        instance (``peek()``) of every stream in *multi_stream*."""
        scores = [
            stream.peek()["score"]["global"]["score"]
            for stream in multi_stream.values()
        ]
        return mean(scores)

    def spread_results(self, stream: Stream, score: float):
        """Yield each instance of *stream* with the shared mean *score* written
        under its global ``groups_mean_score``."""
        for instance in stream:
            instance["score"]["global"]["groups_mean_score"] = score
            yield instance

    def spread_results_one_stream(self, stream: Stream):
        """Single-stream shortcut: the group mean equals the stream's own global
        score, so copy it without aggregating."""
        for instance in stream:
            instance["score"]["global"]["groups_mean_score"] = instance["score"][
                "global"
            ]["score"]
            yield instance

    def process(self, multi_stream: MultiStream) -> MultiStream:
        # Optimization to avoid double calculation of metrics when aggregating
        # results, if there is only one stream.
        if len(multi_stream) == 1:
            return MultiStream(
                {
                    stream_name: Stream(
                        self.spread_results_one_stream, gen_kwargs={"stream": stream}
                    )
                    for stream_name, stream in multi_stream.items()
                }
            )

        mean_score = self.aggegate_results(multi_stream)
        return MultiStream(
            {
                stream_name: Stream(
                    self.spread_results,
                    gen_kwargs={"stream": stream, "score": mean_score},
                )
                for stream_name, stream in multi_stream.items()
            }
        )
66
+
67
+
68
class FromPredictionsAndOriginalData(StreamInitializerOperator):
    """Build the initial multi-stream by pairing each prediction with its
    corresponding original instance (all original fields are preserved and a
    ``"prediction"`` field is added)."""

    def zip(self, predictions, references):
        # Walk both sequences in lockstep; each yielded instance is the
        # original dict plus its matching prediction.
        for prediction, original in zip(predictions, references):
            merged = {**original, "prediction": prediction}
            yield merged

    def process(
        self, predictions: List[str], references: Iterable, split_name: str = "all"
    ) -> MultiStream:
        paired_stream = Stream(
            self.zip,
            gen_kwargs={"predictions": predictions, "references": references},
        )
        return MultiStream({split_name: paired_stream})
84
+
85
+
86
# The additional_inputs field in the schema is defined as
# Sequence({"key": Value(dtype="string"), "value": Value("string")})
# When receiving instances from this schema, the keys and values arrive as two
# separate parallel lists and are converted back into a dictionary.
90
+
91
+
92
+ def _from_key_value_pairs(key_value_list: Dict[str, list]) -> Dict[str, str]:
93
+ return dict(zip(key_value_list["key"], key_value_list["value"]))
94
+
95
+
96
class MetricRecipe(SequentialOperatorInitilizer):
    """Sequential pipeline that scores predictions against references.

    The steps (in order): pair predictions with their original instances,
    rebuild the additional_inputs dict from key/value lists, run the
    per-instance postprocessors, split into groups, apply the metrics,
    average the group scores, and merge everything back into one stream.
    """

    # Whether ApplyMetric should also compute confidence intervals.
    calc_confidence_intervals: bool = True

    def prepare(self):
        register_all_artifacts()
        self.steps = [
            FromPredictionsAndOriginalData(),
            # additional_inputs arrives as parallel key/value lists
            # (see the schema note above _from_key_value_pairs).
            Apply(
                "additional_inputs",
                function=_from_key_value_pairs,
                to_field="additional_inputs",
            ),
            ApplyOperatorsField(
                operators_field="postprocessors",
            ),
            SplitByValue(["group"]),
            ApplyMetric(
                "metrics",
                calc_confidence_intervals=self.calc_confidence_intervals,
            ),
            MultiStreamScoreMean(),
            MergeStreams(),
        ]
119
+
120
+
121
# Features schema expected by the metric entry point: a string prediction plus
# a reference record carrying the full dataset-schema fields.
UNITXT_METRIC_SCHEMA = Features(
    {"predictions": Value("string"), "references": dict(UNITXT_DATASET_SCHEMA)}
)
124
+
125
+
126
def _compute(
    predictions: List[str],
    references: Iterable,
    flatten: bool = False,
    split_name: str = "all",
    calc_confidence_intervals: bool = True,
):
    """Run the metric pipeline and return the scored instances as a list.

    Args:
        predictions: model outputs, one string per instance.
        references: the original instances (dataset-schema records).
        flatten: if True, flatten each scored instance before returning.
        split_name: name of the (single) stream to build and read back.
        calc_confidence_intervals: forwarded to MetricRecipe / ApplyMetric.
    """
    _reset_env_local_catalogs()
    register_all_artifacts()

    recipe = MetricRecipe(calc_confidence_intervals=calc_confidence_intervals)
    multi_stream = recipe(
        predictions=predictions, references=references, split_name=split_name
    )

    if flatten:
        multi_stream = FlattenInstances()(multi_stream)

    return list(multi_stream[split_name])