Upload metrics.py with huggingface_hub
Browse files- metrics.py +1338 -199
metrics.py
CHANGED
@@ -1,19 +1,24 @@
|
|
1 |
import re
|
2 |
import string
|
3 |
import uuid
|
|
|
4 |
from abc import ABC, abstractmethod
|
5 |
from collections import Counter
|
|
|
6 |
from dataclasses import field
|
|
|
7 |
from typing import Any, Dict, Generator, List, Optional, Tuple
|
8 |
|
9 |
import evaluate
|
10 |
import numpy
|
11 |
import numpy as np
|
12 |
from scipy.stats import bootstrap
|
|
|
13 |
|
14 |
from .artifact import Artifact
|
15 |
from .dataclass import InternalField, OptionalField
|
16 |
from .logging_utils import get_logger
|
|
|
17 |
from .operator import (
|
18 |
MultiStreamOperator,
|
19 |
SingleStreamOperator,
|
@@ -22,14 +27,17 @@ from .operator import (
|
|
22 |
)
|
23 |
from .operators import CopyFields
|
24 |
from .random_utils import get_seed
|
|
|
25 |
from .stream import MultiStream, Stream
|
26 |
-
from .type_utils import isoftype
|
27 |
|
28 |
logger = get_logger()
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
|
|
|
|
33 |
|
34 |
|
35 |
def abstract_factory():
|
@@ -40,6 +48,18 @@ def abstract_field():
|
|
40 |
return field(default_factory=abstract_factory)
|
41 |
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
class UpdateStream(StreamInstanceOperator):
|
44 |
update: dict
|
45 |
|
@@ -57,6 +77,48 @@ class Metric(Artifact):
|
|
57 |
def main_score(self):
|
58 |
pass
|
59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
class MetricWithConfidenceInterval(Metric):
|
62 |
# The number of resamples used to estimate the confidence intervals of this metric.
|
@@ -73,7 +135,12 @@ class MetricWithConfidenceInterval(Metric):
|
|
73 |
return np.random.default_rng(hash(get_seed()) & _max_32bit)
|
74 |
|
75 |
def disable_confidence_interval_calculation(self):
|
|
|
76 |
self.n_resamples = None
|
|
|
|
|
|
|
|
|
77 |
|
78 |
def _can_compute_confidence_intervals(self, num_predictions):
|
79 |
return (
|
@@ -82,45 +149,117 @@ class MetricWithConfidenceInterval(Metric):
|
|
82 |
and num_predictions > 1
|
83 |
)
|
84 |
|
85 |
-
|
86 |
-
|
|
|
87 |
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
The instances for which the confidence intervals are computed.
|
92 |
"""
|
93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
|
|
|
|
|
|
95 |
result = {}
|
96 |
|
97 |
if not self._can_compute_confidence_intervals(num_predictions=len(instances)):
|
98 |
return result
|
99 |
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
|
|
|
|
104 |
for score_name in score_names:
|
105 |
-
|
106 |
-
|
107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
ci = bootstrap(
|
109 |
-
(
|
110 |
-
statistic=
|
111 |
n_resamples=self.n_resamples,
|
112 |
confidence_level=self.confidence_level,
|
113 |
random_state=self.new_random_generator(),
|
114 |
).confidence_interval
|
115 |
-
|
116 |
-
result[f"{
|
|
|
117 |
if score_name == self.main_score:
|
118 |
result["score_ci_low"] = ci.low
|
119 |
result["score_ci_high"] = ci.high
|
120 |
return result
|
121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
def compute_global_confidence_intervals(
|
123 |
-
self, references, predictions,
|
124 |
):
|
125 |
"""Computed confidence intervals for a set of references and predictions."""
|
126 |
random_gen = self.new_random_generator()
|
@@ -128,12 +267,12 @@ class MetricWithConfidenceInterval(Metric):
|
|
128 |
def statistic(arr, axis):
|
129 |
# arr is a 2d array where each row is a resampling, so we
|
130 |
# iterate over the rows and compute the metric on each resampling
|
131 |
-
def metric(sample_refs, sample_preds,
|
132 |
try:
|
133 |
return self._compute(
|
134 |
references=sample_refs,
|
135 |
predictions=sample_preds,
|
136 |
-
|
137 |
)["score"]
|
138 |
except Exception as e:
|
139 |
# this happens in edge cases, for example, when the sampling creates a
|
@@ -141,40 +280,21 @@ class MetricWithConfidenceInterval(Metric):
|
|
141 |
logger.info(f"Warning in {self.__class__.__name__}", e)
|
142 |
return np.nan
|
143 |
|
|
|
144 |
scores = numpy.apply_along_axis(
|
145 |
lambda x: metric(
|
146 |
sample_refs=[references[i] for i in x],
|
147 |
sample_preds=[predictions[i] for i in x],
|
148 |
-
|
149 |
),
|
150 |
axis=axis,
|
151 |
arr=arr,
|
152 |
)
|
153 |
|
154 |
-
#
|
155 |
-
#
|
156 |
-
|
157 |
-
|
158 |
-
# edge cases - for example, when the sample contains only empty strings.
|
159 |
-
# CI is about the distribution around the statistic (e.g. mean), it doesn't deal with
|
160 |
-
# cases in which the metric is not computable. Therefore, we ignore these edge cases
|
161 |
-
# as part of the computation of CI. The question is how to implement this policy.
|
162 |
-
# Options:
|
163 |
-
# 1. skip the errors and return a shorter array => this fails because Scipy demans
|
164 |
-
# this callback (i.e. the statistic() callback) to return an array of the same size
|
165 |
-
# as the number of resamples
|
166 |
-
# 2. Put np.nan for the errors => this fails because in such case the ci itself
|
167 |
-
# becomes np.nan. So one edge case can fail the whole CI computation.
|
168 |
-
# 3. Replace the errors with a sampling from the successful cases => this is what
|
169 |
-
# is implemented.
|
170 |
-
error_indices = numpy.isnan(scores)
|
171 |
-
n_errors = sum(error_indices)
|
172 |
-
if n_errors > 0:
|
173 |
-
new_scores = random_gen.choice(scores, n_errors, replace=True)
|
174 |
-
scores = scores[~error_indices]
|
175 |
-
scores = np.concatenate([scores, new_scores])
|
176 |
-
|
177 |
-
return scores
|
178 |
|
179 |
result = {}
|
180 |
num_predictions = len(predictions)
|
@@ -202,12 +322,15 @@ class GlobalMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
|
202 |
need to be considered. Accuracy, on the other hand, is just an average of the accuracy of all the instances.
|
203 |
"""
|
204 |
|
205 |
-
n_resamples =
|
|
|
|
|
|
|
206 |
|
207 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
208 |
references = []
|
209 |
predictions = []
|
210 |
-
|
211 |
global_score = {}
|
212 |
|
213 |
instances = []
|
@@ -226,31 +349,40 @@ class GlobalMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
|
226 |
predictions.append(instance_prediction)
|
227 |
instances.append(instance)
|
228 |
|
229 |
-
|
230 |
-
instance["
|
231 |
)
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
|
242 |
if isinstance(self.main_score, str):
|
243 |
-
instance_score[self.main_score] =
|
244 |
|
245 |
instance["score"]["instance"].update(instance_score)
|
246 |
|
247 |
-
result = self._compute(references, predictions,
|
248 |
|
249 |
global_score.update(result)
|
250 |
|
251 |
score_name = global_score["score_name"]
|
252 |
confidence_interval = self.compute_global_confidence_intervals(
|
253 |
-
references, predictions,
|
254 |
)
|
255 |
global_score.update(confidence_interval)
|
256 |
|
@@ -262,9 +394,9 @@ class GlobalMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
|
262 |
self,
|
263 |
references: List[List[str]],
|
264 |
predictions: List[str],
|
265 |
-
|
266 |
) -> dict:
|
267 |
-
result = self.compute(references, predictions,
|
268 |
result["score"] = result[self.main_score]
|
269 |
result["score_name"] = self.main_score
|
270 |
return result
|
@@ -274,13 +406,25 @@ class GlobalMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
|
274 |
self,
|
275 |
references: List[List[Any]],
|
276 |
predictions: List[Any],
|
277 |
-
|
278 |
) -> dict:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
279 |
pass
|
280 |
|
281 |
|
282 |
class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
283 |
-
n_resamples =
|
|
|
|
|
284 |
main_score: str
|
285 |
reduction_map: Dict[str, List[str]]
|
286 |
|
@@ -301,8 +445,8 @@ class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
|
301 |
),
|
302 |
)
|
303 |
|
304 |
-
|
305 |
-
instance["
|
306 |
for instance in stream
|
307 |
]
|
308 |
|
@@ -310,7 +454,7 @@ class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
|
310 |
instance_scores = self.compute(
|
311 |
references=references,
|
312 |
predictions=predictions,
|
313 |
-
|
314 |
)
|
315 |
|
316 |
# add the score and score_name fields
|
@@ -334,8 +478,6 @@ class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
|
334 |
), f"Reduction {reduction} is not implemented, use one of {self.implemented_reductions}"
|
335 |
|
336 |
if reduction == "mean":
|
337 |
-
from statistics import mean
|
338 |
-
|
339 |
for field_name in fields:
|
340 |
global_score[field_name] = mean(
|
341 |
[
|
@@ -347,8 +489,13 @@ class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
|
347 |
global_score["score"] = global_score[field_name]
|
348 |
global_score["score_name"] = self.main_score
|
349 |
|
|
|
|
|
|
|
|
|
|
|
350 |
confidence_interval = self.score_based_confidence_interval(
|
351 |
-
instances=instances
|
352 |
)
|
353 |
global_score.update(confidence_interval)
|
354 |
|
@@ -360,33 +507,217 @@ class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
|
360 |
self,
|
361 |
references: List[List[Any]],
|
362 |
predictions: List[Any],
|
363 |
-
|
364 |
) -> List[Dict[str, Any]]:
|
365 |
pass
|
366 |
|
367 |
|
368 |
class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
369 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
370 |
|
371 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
372 |
|
373 |
@property
|
374 |
@abstractmethod
|
375 |
def reduction_map(self) -> dict:
|
376 |
pass
|
377 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
378 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
379 |
global_score = {}
|
380 |
instances = []
|
381 |
|
382 |
for instance in stream:
|
383 |
refs, pred = instance["references"], instance["prediction"]
|
384 |
-
|
385 |
-
instance["additional_inputs"] if "additional_inputs" in instance else {}
|
386 |
-
)
|
387 |
|
388 |
instance_score = self.compute(
|
389 |
-
references=refs, prediction=pred,
|
390 |
)
|
391 |
instance_score["score"] = instance_score[self.main_score]
|
392 |
instance_score["score_name"] = self.main_score
|
@@ -399,36 +730,100 @@ class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
|
399 |
|
400 |
instances.append(instance)
|
401 |
|
402 |
-
|
403 |
-
assert (
|
404 |
-
reduction in self.implemented_reductions
|
405 |
-
), f"Reduction {reduction} is not implemented, use one of {self.implemented_reductions}"
|
406 |
|
407 |
-
|
408 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
409 |
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
global_score[field_name] = mean(scores)
|
416 |
-
if field_name == self.main_score:
|
417 |
-
global_score["score"] = global_score[field_name]
|
418 |
-
global_score["score_name"] = self.main_score
|
419 |
|
420 |
-
|
421 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
422 |
)
|
423 |
-
global_score.update(confidence_interval)
|
424 |
|
425 |
-
|
426 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
427 |
|
428 |
@abstractmethod
|
429 |
-
def compute(
|
430 |
-
self, references: List[Any], prediction: Any, additional_inputs: Dict
|
431 |
-
) -> dict:
|
432 |
pass
|
433 |
|
434 |
|
@@ -445,7 +840,7 @@ class Squad(GlobalMetric):
|
|
445 |
self,
|
446 |
references: List[List[str]],
|
447 |
predictions: List[str],
|
448 |
-
|
449 |
) -> dict:
|
450 |
ids = [str(uuid.uuid4()).replace("-", "") for _ in range(len(predictions))]
|
451 |
formatted_predictions = [
|
@@ -466,9 +861,10 @@ class Squad(GlobalMetric):
|
|
466 |
class Accuracy(InstanceMetric):
|
467 |
reduction_map = {"mean": ["accuracy"]}
|
468 |
main_score = "accuracy"
|
|
|
469 |
|
470 |
def compute(
|
471 |
-
self, references: List[Any], prediction: Any,
|
472 |
) -> dict:
|
473 |
result = {
|
474 |
self.main_score: float(
|
@@ -483,13 +879,14 @@ class Accuracy(InstanceMetric):
|
|
483 |
class StringContainment(InstanceMetric):
|
484 |
reduction_map = {"mean": ["string_containment"]}
|
485 |
main_score = "string_containment"
|
|
|
486 |
|
487 |
def compute(
|
488 |
-
self, references: List[Any], prediction: Any,
|
489 |
) -> dict:
|
490 |
result = {
|
491 |
self.main_score: float(
|
492 |
-
any(str(reference) in prediction for reference in references)
|
493 |
)
|
494 |
}
|
495 |
result["score"] = result[self.main_score]
|
@@ -505,6 +902,13 @@ class MetricPipeline(MultiStreamOperator, Metric):
|
|
505 |
)
|
506 |
metric: Metric = None
|
507 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
508 |
def verify(self):
|
509 |
assert self.main_score is not None, "main_score is not set"
|
510 |
|
@@ -569,37 +973,37 @@ class HuggingfaceMetric(GlobalMetric):
|
|
569 |
self,
|
570 |
references: List[List[Any]],
|
571 |
predictions: List[Any],
|
572 |
-
|
573 |
) -> dict:
|
574 |
-
|
575 |
for additional_input_field in self.hf_additional_input_fields:
|
576 |
assert (
|
577 |
-
additional_input_field in
|
578 |
-
), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in
|
579 |
-
|
580 |
additional_input[additional_input_field]
|
581 |
-
for additional_input in
|
582 |
]
|
583 |
for additional_input_field in self.hf_additional_input_fields_pass_one_value:
|
584 |
assert (
|
585 |
-
additional_input_field in
|
586 |
-
), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in
|
587 |
|
588 |
values = {
|
589 |
additional_input[additional_input_field]
|
590 |
-
for additional_input in
|
591 |
}
|
592 |
assert (
|
593 |
len(values) == 1
|
594 |
), f"Values of '{additional_input_field}' field required by {__class__.__name__} should all be the same, but have multiple values {values}"
|
595 |
|
596 |
-
|
597 |
|
598 |
-
# add check that all required fields in self.metrics are in
|
599 |
result = self.metric.compute(
|
600 |
predictions=predictions,
|
601 |
references=references,
|
602 |
-
**
|
603 |
**self.hf_compute_args,
|
604 |
)
|
605 |
if self.hf_main_score:
|
@@ -641,23 +1045,23 @@ class HuggingfaceBulkMetric(BulkInstanceMetric):
|
|
641 |
self,
|
642 |
references: List[List[str]],
|
643 |
predictions: List[str],
|
644 |
-
|
645 |
) -> List[Dict[str, Any]]:
|
646 |
-
|
647 |
for additional_input_field in self.hf_additional_input_fields:
|
648 |
assert (
|
649 |
-
additional_input_field in
|
650 |
-
), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in
|
651 |
-
|
652 |
additional_input[additional_input_field]
|
653 |
-
for additional_input in
|
654 |
]
|
655 |
-
# add check that all required fields in self.metrics are in
|
656 |
|
657 |
scores = self.metric.compute(
|
658 |
predictions=predictions,
|
659 |
references=references,
|
660 |
-
**
|
661 |
**self.hf_compute_args,
|
662 |
)
|
663 |
|
@@ -692,7 +1096,7 @@ class F1(GlobalMetric):
|
|
692 |
self,
|
693 |
references: List[List[str]],
|
694 |
predictions: List[str],
|
695 |
-
|
696 |
) -> dict:
|
697 |
assert all(
|
698 |
len(reference) == 1 for reference in references
|
@@ -714,8 +1118,6 @@ class F1(GlobalMetric):
|
|
714 |
average=self.average,
|
715 |
)
|
716 |
if isinstance(result["f1"], numpy.ndarray):
|
717 |
-
from statistics import mean
|
718 |
-
|
719 |
final_result = {self.main_score: mean(result["f1"])}
|
720 |
for i, label in enumerate(labels):
|
721 |
final_result["f1_" + self.id_to_str[label]] = result["f1"][i]
|
@@ -742,7 +1144,6 @@ class F1MultiLabel(GlobalMetric):
|
|
742 |
_metric = None
|
743 |
main_score = "f1_macro"
|
744 |
average = None # Report per class then aggregate by mean
|
745 |
-
classes_to_ignore = ["none"]
|
746 |
metric = "f1"
|
747 |
|
748 |
def prepare(self):
|
@@ -767,7 +1168,7 @@ class F1MultiLabel(GlobalMetric):
|
|
767 |
self,
|
768 |
references: List[List[str]],
|
769 |
predictions: List[List[str]],
|
770 |
-
|
771 |
) -> dict:
|
772 |
self.str_to_id = {}
|
773 |
self.id_to_str = {}
|
@@ -775,13 +1176,9 @@ class F1MultiLabel(GlobalMetric):
|
|
775 |
self._validate_references_and_prediction(references, predictions)
|
776 |
references = [reference[0] for reference in references]
|
777 |
|
778 |
-
labels =
|
779 |
-
|
780 |
-
for lbl in {label for reference in references for label in reference}
|
781 |
-
if lbl not in self.classes_to_ignore
|
782 |
-
]
|
783 |
# if no classes are left then F1 is not defined
|
784 |
-
# (e.g. only "none" in references)
|
785 |
if len(labels) == 0:
|
786 |
return {self.main_score: float("nan")}
|
787 |
|
@@ -809,8 +1206,6 @@ class F1MultiLabel(GlobalMetric):
|
|
809 |
labels=labels_param,
|
810 |
)
|
811 |
if isinstance(result[self.metric], numpy.ndarray):
|
812 |
-
from statistics import mean
|
813 |
-
|
814 |
assert (
|
815 |
len(result[self.metric]) == len(labels)
|
816 |
), f"F1 result ({result[self.metric]}) has more entries than labels ({labels})"
|
@@ -883,6 +1278,8 @@ class Rouge(HuggingfaceMetric):
|
|
883 |
|
884 |
sent_split_newline: bool = True
|
885 |
|
|
|
|
|
886 |
def prepare(self):
|
887 |
super().prepare()
|
888 |
|
@@ -895,7 +1292,7 @@ class Rouge(HuggingfaceMetric):
|
|
895 |
nltk.download("punkt")
|
896 |
self.sent_tokenize = nltk.sent_tokenize
|
897 |
|
898 |
-
def compute(self, references, predictions,
|
899 |
if self.sent_split_newline:
|
900 |
predictions = [
|
901 |
"\n".join(self.sent_tokenize(prediction.strip()))
|
@@ -905,13 +1302,16 @@ class Rouge(HuggingfaceMetric):
|
|
905 |
["\n".join(self.sent_tokenize(r.strip())) for r in reference]
|
906 |
for reference in references
|
907 |
]
|
908 |
-
return super().compute(references, predictions,
|
909 |
|
910 |
|
911 |
# Computes char edit distance, ignoring whitespace
|
912 |
class CharEditDistanceAccuracy(InstanceMetric):
|
913 |
reduction_map = {"mean": ["char_edit_dist_accuracy"]}
|
914 |
main_score = "char_edit_dist_accuracy"
|
|
|
|
|
|
|
915 |
|
916 |
def prepare(self):
|
917 |
super().prepare()
|
@@ -919,9 +1319,7 @@ class CharEditDistanceAccuracy(InstanceMetric):
|
|
919 |
|
920 |
self.eval = editdistance.eval
|
921 |
|
922 |
-
def compute(
|
923 |
-
self, references, prediction: str, additional_inputs: List[Dict]
|
924 |
-
) -> dict:
|
925 |
assert (
|
926 |
len(references) == 1
|
927 |
), f"Expected only one reference , but received: {references}"
|
@@ -939,11 +1337,13 @@ class Wer(HuggingfaceMetric):
|
|
939 |
hf_metric_name = "wer"
|
940 |
main_score = "wer"
|
941 |
|
|
|
|
|
942 |
def compute(
|
943 |
self,
|
944 |
references: List[List[str]],
|
945 |
predictions: List[str],
|
946 |
-
|
947 |
) -> dict:
|
948 |
assert all(
|
949 |
len(reference) == 1 for reference in references
|
@@ -955,6 +1355,43 @@ class Wer(HuggingfaceMetric):
|
|
955 |
return {self.main_score: result}
|
956 |
|
957 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
958 |
class MatthewsCorrelation(HuggingfaceMetric):
|
959 |
hf_metric_name = "matthews_correlation"
|
960 |
main_score = "matthews_correlation"
|
@@ -970,7 +1407,7 @@ class MatthewsCorrelation(HuggingfaceMetric):
|
|
970 |
self,
|
971 |
references: List[List[str]],
|
972 |
predictions: List[str],
|
973 |
-
|
974 |
) -> dict:
|
975 |
formatted_references = [
|
976 |
self.get_str_id(reference[0]) for reference in references
|
@@ -983,6 +1420,33 @@ class MatthewsCorrelation(HuggingfaceMetric):
|
|
983 |
)
|
984 |
|
985 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
986 |
class CustomF1(GlobalMetric):
|
987 |
main_score = "f1_micro"
|
988 |
groups = None
|
@@ -1036,9 +1500,9 @@ class CustomF1(GlobalMetric):
|
|
1036 |
except ZeroDivisionError:
|
1037 |
return self.zero_division
|
1038 |
|
1039 |
-
def get_groups(self, elements,
|
1040 |
groups = set()
|
1041 |
-
for sublist, additional_input in zip(elements,
|
1042 |
for e in sublist:
|
1043 |
if self.should_ignore_element(e, additional_input):
|
1044 |
continue
|
@@ -1049,7 +1513,7 @@ class CustomF1(GlobalMetric):
|
|
1049 |
self,
|
1050 |
references: List[List[Any]],
|
1051 |
predictions: List[Any],
|
1052 |
-
|
1053 |
) -> dict:
|
1054 |
# in case reference are List[List[List[Any]]] and predictions are List[List[Any]]:
|
1055 |
if (
|
@@ -1065,12 +1529,12 @@ class CustomF1(GlobalMetric):
|
|
1065 |
)
|
1066 |
|
1067 |
if self.groups is None:
|
1068 |
-
groups = self.get_groups(references,
|
1069 |
else:
|
1070 |
groups = self.groups
|
1071 |
groups_statistics = {}
|
1072 |
for references_batch, predictions_batch, additional_input in zip(
|
1073 |
-
references, predictions,
|
1074 |
):
|
1075 |
grouped_references = self.group_elements(references_batch, additional_input)
|
1076 |
grouped_predictions = self.group_elements(
|
@@ -1187,10 +1651,11 @@ class TokenOverlap(InstanceMetric):
|
|
1187 |
ci_scores = ["f1", "precision", "recall"]
|
1188 |
|
1189 |
def compute(
|
1190 |
-
self, references: List[Any], prediction: Any,
|
1191 |
) -> dict:
|
1192 |
results = [
|
1193 |
-
self._compute_single_ref(reference, prediction)
|
|
|
1194 |
]
|
1195 |
return {
|
1196 |
measure: max(r[i] for r in results)
|
@@ -1200,8 +1665,8 @@ class TokenOverlap(InstanceMetric):
|
|
1200 |
def _compute_single_ref(
|
1201 |
self, reference: Any, prediction: Any
|
1202 |
) -> Tuple[float, float, float]:
|
1203 |
-
prediction_tokens = normalize_answer(prediction).split()
|
1204 |
-
reference_tokens = normalize_answer(reference).split()
|
1205 |
common = Counter(prediction_tokens) & Counter(reference_tokens)
|
1206 |
num_same = sum(common.values())
|
1207 |
if num_same == 0:
|
@@ -1221,9 +1686,11 @@ class BertScore(HuggingfaceBulkMetric):
|
|
1221 |
ci_scores = ["f1", "precision", "recall"]
|
1222 |
model_name: str
|
1223 |
|
|
|
|
|
1224 |
def prepare(self):
|
1225 |
super().prepare()
|
1226 |
-
self.hf_compute_args = {"model_type": self.model_name}
|
1227 |
|
1228 |
|
1229 |
class SentenceBert(BulkInstanceMetric):
|
@@ -1233,19 +1700,23 @@ class SentenceBert(BulkInstanceMetric):
|
|
1233 |
|
1234 |
model_name: str
|
1235 |
|
|
|
|
|
1236 |
def prepare(self):
|
1237 |
super().prepare()
|
|
|
1238 |
from sentence_transformers import SentenceTransformer
|
1239 |
from sentence_transformers import util as sbert_util
|
1240 |
|
1241 |
-
self.
|
|
|
1242 |
self.util = sbert_util
|
1243 |
|
1244 |
def compute(
|
1245 |
self,
|
1246 |
references: List[List[Any]],
|
1247 |
predictions: List[Any],
|
1248 |
-
|
1249 |
) -> List[Dict[str, Any]]:
|
1250 |
scores = []
|
1251 |
|
@@ -1260,9 +1731,9 @@ class SentenceBert(BulkInstanceMetric):
|
|
1260 |
count += len(ref_group)
|
1261 |
|
1262 |
# compute s-bert embeddings
|
1263 |
-
preds_emb = self.model.encode(predictions)
|
1264 |
refs_emb = self.model.encode(
|
1265 |
-
[ref for ref_group in references for ref in ref_group]
|
1266 |
)
|
1267 |
|
1268 |
# for each candidate, pick the reference with the highest score
|
@@ -1280,17 +1751,23 @@ class Reward(BulkInstanceMetric):
|
|
1280 |
|
1281 |
model_name: str
|
1282 |
|
|
|
|
|
1283 |
def prepare(self):
|
1284 |
super().prepare()
|
|
|
1285 |
from transformers import pipeline
|
1286 |
|
1287 |
-
|
|
|
|
|
|
|
1288 |
|
1289 |
def compute(
|
1290 |
self,
|
1291 |
references: List[List[Any]],
|
1292 |
predictions: List[Any],
|
1293 |
-
|
1294 |
) -> List[Dict[str, Any]]:
|
1295 |
# treat the references as the questions and the predictions as answers
|
1296 |
# assume a single reference
|
@@ -1316,25 +1793,27 @@ class Perplexity(BulkInstanceMetric):
|
|
1316 |
batch_size: int = 32
|
1317 |
model_name: str
|
1318 |
|
|
|
|
|
1319 |
def compute(
|
1320 |
self,
|
1321 |
references: List[List[Any]],
|
1322 |
predictions: List[Any],
|
1323 |
-
|
1324 |
) -> List[Dict[str, Any]]:
|
1325 |
"""Computes the likelihood of generating text Y after text X - P(Y|X).
|
1326 |
|
1327 |
-
:param
|
1328 |
-
:param
|
1329 |
|
1330 |
-
:return: the likelihood of generating text Y_i after text
|
1331 |
"""
|
1332 |
sources = []
|
1333 |
targets = []
|
1334 |
for prediction, instance_references in zip(predictions, references):
|
1335 |
for instance_reference in instance_references:
|
1336 |
-
sources.append(f"{self.perplexity_prompt} {
|
1337 |
-
targets.append(
|
1338 |
|
1339 |
from transformers import AutoConfig
|
1340 |
|
@@ -1375,9 +1854,11 @@ class Perplexity(BulkInstanceMetric):
|
|
1375 |
from transformers import AutoTokenizer
|
1376 |
|
1377 |
self.model_name = model_name
|
|
|
|
|
|
|
|
|
1378 |
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
1379 |
-
self.model = self.model_class().from_pretrained(self.model_name)
|
1380 |
-
self.is_cuda = torch.cuda.is_available()
|
1381 |
|
1382 |
def compute_lm(
|
1383 |
self, source: List[str], target: List[str], batch_size: int
|
@@ -1470,16 +1951,9 @@ class Perplexity(BulkInstanceMetric):
|
|
1470 |
return AutoModelForSeq2SeqLM
|
1471 |
|
1472 |
def compute_batch(self, tokens_source, tokens_target):
|
1473 |
-
tokens_docs_ids = tokens_source["input_ids"]
|
1474 |
-
attention = tokens_source["attention_mask"]
|
1475 |
-
labels = tokens_target["input_ids"]
|
1476 |
-
|
1477 |
-
if self.is_cuda:
|
1478 |
-
tokens_docs_ids, attention, labels = (
|
1479 |
-
tokens_docs_ids.cuda(),
|
1480 |
-
attention.cuda(),
|
1481 |
-
labels.cuda(),
|
1482 |
-
)
|
1483 |
|
1484 |
logits = self.model(
|
1485 |
input_ids=tokens_docs_ids.long(),
|
@@ -1519,12 +1993,9 @@ class Perplexity(BulkInstanceMetric):
|
|
1519 |
# replace the padding token in the labels by -100
|
1520 |
labels[labels == self.tokenizer.pad_token_id] = -100
|
1521 |
|
1522 |
-
|
1523 |
-
|
1524 |
-
|
1525 |
-
attention.cuda(),
|
1526 |
-
labels.cuda(),
|
1527 |
-
)
|
1528 |
|
1529 |
# no need to pass labels as we calculate the loss below per document
|
1530 |
model_output = self.model(
|
@@ -1558,6 +2029,8 @@ class NDCG(GlobalMetric):
|
|
1558 |
|
1559 |
main_score = "nDCG"
|
1560 |
|
|
|
|
|
1561 |
def prepare(self):
|
1562 |
from sklearn.metrics import ndcg_score
|
1563 |
|
@@ -1568,15 +2041,12 @@ class NDCG(GlobalMetric):
|
|
1568 |
self,
|
1569 |
references: List[List[Any]],
|
1570 |
predictions: List[Any],
|
1571 |
-
|
1572 |
) -> dict:
|
1573 |
from collections import defaultdict
|
1574 |
-
from statistics import mean
|
1575 |
|
1576 |
query_to_predictions_and_references = defaultdict(lambda: [[], []])
|
1577 |
-
for reference, pred, inputs_dict in zip(
|
1578 |
-
references, predictions, additional_inputs
|
1579 |
-
):
|
1580 |
query = inputs_dict.get("query")
|
1581 |
query_to_predictions_and_references[query][0].append(pred)
|
1582 |
query_to_predictions_and_references[query][1].append(reference)
|
@@ -1606,9 +2076,7 @@ class NDCG(GlobalMetric):
|
|
1606 |
|
1607 |
|
1608 |
class RetrievalMetric(InstanceMetric):
|
1609 |
-
def compute(
|
1610 |
-
self, references: List[Any], prediction: Any, additional_inputs: Dict
|
1611 |
-
) -> dict:
|
1612 |
# digest input
|
1613 |
pred_ids: List[Any] = prediction
|
1614 |
ref_ids: List[Any] = list(dict.fromkeys(references))
|
@@ -1681,6 +2149,7 @@ class RetrievalMetric(InstanceMetric):
|
|
1681 |
class MRR(RetrievalMetric):
|
1682 |
reduction_map = {"mean": ["mrr"]}
|
1683 |
main_score = "mrr"
|
|
|
1684 |
|
1685 |
def _compute(
|
1686 |
self,
|
@@ -1697,6 +2166,7 @@ class MRR(RetrievalMetric):
|
|
1697 |
class MAP(RetrievalMetric):
|
1698 |
reduction_map = {"mean": ["map"]}
|
1699 |
main_score = "map"
|
|
|
1700 |
|
1701 |
def _compute(
|
1702 |
self,
|
@@ -1765,3 +2235,672 @@ class KPA(CustomF1):
|
|
1765 |
|
1766 |
def should_ignore_element(self, element, additional_input):
|
1767 |
return element == "none"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import re
|
2 |
import string
|
3 |
import uuid
|
4 |
+
import warnings
|
5 |
from abc import ABC, abstractmethod
|
6 |
from collections import Counter
|
7 |
+
from copy import deepcopy
|
8 |
from dataclasses import field
|
9 |
+
from statistics import mean
|
10 |
from typing import Any, Dict, Generator, List, Optional, Tuple
|
11 |
|
12 |
import evaluate
|
13 |
import numpy
|
14 |
import numpy as np
|
15 |
from scipy.stats import bootstrap
|
16 |
+
from scipy.stats._warnings_errors import DegenerateDataWarning
|
17 |
|
18 |
from .artifact import Artifact
|
19 |
from .dataclass import InternalField, OptionalField
|
20 |
from .logging_utils import get_logger
|
21 |
+
from .metric_utils import InstanceInput, MetricRequest, MetricResponse
|
22 |
from .operator import (
|
23 |
MultiStreamOperator,
|
24 |
SingleStreamOperator,
|
|
|
27 |
)
|
28 |
from .operators import CopyFields
|
29 |
from .random_utils import get_seed
|
30 |
+
from .settings_utils import get_settings
|
31 |
from .stream import MultiStream, Stream
|
32 |
+
from .type_utils import isoftype, to_float_or_default
|
33 |
|
34 |
logger = get_logger()
|
35 |
+
settings = get_settings()
|
36 |
+
|
37 |
+
warnings.filterwarnings("ignore", category=DegenerateDataWarning)
|
38 |
+
|
39 |
+
|
40 |
+
warnings.filterwarnings("ignore", category=DegenerateDataWarning)
|
41 |
|
42 |
|
43 |
def abstract_factory():
|
|
|
48 |
return field(default_factory=abstract_factory)
|
49 |
|
50 |
|
51 |
+
def nan_mean(x):
|
52 |
+
import warnings
|
53 |
+
|
54 |
+
with warnings.catch_warnings():
|
55 |
+
# final mean should be mean of scores, ignoring NaN, hence nanmean
|
56 |
+
# but if the group function values is NaN for ALL values, nanmean throws a
|
57 |
+
# RuntimeWarning that it is calculating the mean of an empty slice (with no non-Nans)
|
58 |
+
# this is the desired behavior, but we want to avoid the warning here
|
59 |
+
warnings.simplefilter("ignore", category=RuntimeWarning)
|
60 |
+
return np.nanmean(x)
|
61 |
+
|
62 |
+
|
63 |
class UpdateStream(StreamInstanceOperator):
|
64 |
update: dict
|
65 |
|
|
|
77 |
def main_score(self):
|
78 |
pass
|
79 |
|
80 |
+
def consume_stream(self, stream: Stream):
|
81 |
+
references = []
|
82 |
+
predictions = []
|
83 |
+
additional_inputs = []
|
84 |
+
instances = []
|
85 |
+
for instance in stream:
|
86 |
+
references.append(instance["references"])
|
87 |
+
predictions.append(instance["prediction"])
|
88 |
+
additional_inputs.append(
|
89 |
+
instance["additional_inputs"] if "additional_inputs" in instance else {}
|
90 |
+
)
|
91 |
+
instances.append(instance)
|
92 |
+
return predictions, references, additional_inputs, instances
|
93 |
+
|
94 |
+
@staticmethod
|
95 |
+
def update_instance_scores(instances, instances_scores: List[Dict[str, Any]]):
|
96 |
+
for instance, new_scores in zip(instances, instances_scores):
|
97 |
+
if "score" not in instance:
|
98 |
+
instance["score"] = {}
|
99 |
+
scores = instance["score"]
|
100 |
+
if "instance" not in scores:
|
101 |
+
scores["instance"] = {}
|
102 |
+
scores["instance"].update(new_scores)
|
103 |
+
|
104 |
+
@staticmethod
|
105 |
+
def set_global_score(instances, global_score: Dict[str, Any]):
|
106 |
+
for instance in instances:
|
107 |
+
if "score" not in instance:
|
108 |
+
instance["score"] = {}
|
109 |
+
scores = instance["score"]
|
110 |
+
if "global" not in scores:
|
111 |
+
scores["global"] = {}
|
112 |
+
scores["global"] = global_score
|
113 |
+
|
114 |
+
@abstractmethod
|
115 |
+
def disable_confidence_interval_calculation(self):
|
116 |
+
pass
|
117 |
+
|
118 |
+
@abstractmethod
|
119 |
+
def set_n_resamples(self, n_resample):
|
120 |
+
pass
|
121 |
+
|
122 |
|
123 |
class MetricWithConfidenceInterval(Metric):
|
124 |
# The number of resamples used to estimate the confidence intervals of this metric.
|
|
|
135 |
return np.random.default_rng(hash(get_seed()) & _max_32bit)
|
136 |
|
137 |
def disable_confidence_interval_calculation(self):
|
138 |
+
n = self.n_resamples
|
139 |
self.n_resamples = None
|
140 |
+
return n
|
141 |
+
|
142 |
+
def set_n_resamples(self, n_resamples):
|
143 |
+
self.n_resamples = n_resamples
|
144 |
|
145 |
def _can_compute_confidence_intervals(self, num_predictions):
|
146 |
return (
|
|
|
149 |
and num_predictions > 1
|
150 |
)
|
151 |
|
152 |
+
@staticmethod
|
153 |
+
def average_item_scores(instances: List[dict], score_name: str):
|
154 |
+
"""Calculate mean of a set of instance scores (given by score_name), omitting NaN values.
|
155 |
|
156 |
+
Args:
|
157 |
+
instances: list of dicts of each instance's instance scores.
|
158 |
+
score_name: score field names to compute the mean for.
|
|
|
159 |
"""
|
160 |
+
return nan_mean(
|
161 |
+
[instance["score"]["instance"][score_name] for instance in instances]
|
162 |
+
)
|
163 |
+
|
164 |
+
def score_based_confidence_interval(
|
165 |
+
self,
|
166 |
+
instances: List[dict],
|
167 |
+
score_names: List[str],
|
168 |
+
aggregation_func=None,
|
169 |
+
ci_score_prefix="",
|
170 |
+
):
|
171 |
+
"""Compute confidence intervals based on existing scores, already computed on the input instances.
|
172 |
+
|
173 |
+
Unlike GlobalMetric, this is simply a function of the instance scores (possibly taking into account task_data field),
|
174 |
+
so they don't need to be recomputed after every bootstrap draw.
|
175 |
+
|
176 |
+
Args:
|
177 |
+
instances: The instances for which the confidence intervals are computed; should already have the relevant instance scores calculated.
|
178 |
+
score_names: List of instance score field names to compute a confidence interval for.
|
179 |
+
aggregation_func: A function with arguments instances, field_name; is applied on list of instances (which may include task_data
|
180 |
+
field, as well as the prediction and references), and the field_name; default is simply to take the mean field_name from
|
181 |
+
instances after resampling, if argument is None.
|
182 |
+
ci_score_prefix: An optional string prefix to the score_name in the CI. Useful in cases where the
|
183 |
+
aggregation_func is something other than the mean
|
184 |
|
185 |
+
Returns:
|
186 |
+
Dict of confidence interval values
|
187 |
+
"""
|
188 |
result = {}
|
189 |
|
190 |
if not self._can_compute_confidence_intervals(num_predictions=len(instances)):
|
191 |
return result
|
192 |
|
193 |
+
ci_score_prefix = str(ci_score_prefix)
|
194 |
+
if aggregation_func is None:
|
195 |
+
# if aggregation_func is None, we simply take the mean of the resampled instance scores
|
196 |
+
# otherwise, the aggregation_func needs to be applied AFTER resampling the instances;
|
197 |
+
# that is, re-form the groups, calculate the function, and take the mean of the group scores
|
198 |
+
aggregation_func = self.average_item_scores
|
199 |
for score_name in score_names:
|
200 |
+
# need to redefine the statistic function within the loop because score_name is a loop variable
|
201 |
+
def statistic(arr, axis, score_name=score_name):
|
202 |
+
# arr is a 2d array where each row is a resampling, so we
|
203 |
+
# iterate over the rows and compute the metric on each resampling
|
204 |
+
scores = numpy.apply_along_axis(
|
205 |
+
lambda resampled_instances: aggregation_func(
|
206 |
+
resampled_instances, score_name
|
207 |
+
),
|
208 |
+
axis=axis,
|
209 |
+
arr=arr,
|
210 |
+
)
|
211 |
+
return self.resample_from_non_nan(scores)
|
212 |
+
|
213 |
+
# apply bootstrap only on the relevant field
|
214 |
ci = bootstrap(
|
215 |
+
(instances,),
|
216 |
+
statistic=statistic,
|
217 |
n_resamples=self.n_resamples,
|
218 |
confidence_level=self.confidence_level,
|
219 |
random_state=self.new_random_generator(),
|
220 |
).confidence_interval
|
221 |
+
full_score_name = ci_score_prefix + score_name
|
222 |
+
result[f"{full_score_name}_ci_low"] = ci.low
|
223 |
+
result[f"{full_score_name}_ci_high"] = ci.high
|
224 |
if score_name == self.main_score:
|
225 |
result["score_ci_low"] = ci.low
|
226 |
result["score_ci_high"] = ci.high
|
227 |
return result
|
228 |
|
229 |
+
def resample_from_non_nan(self, values):
|
230 |
+
"""Given an array values, will replace any NaN values with elements resampled with replacement from the non-NaN ones.
|
231 |
+
|
232 |
+
here we deal with samples on which the metric could not be computed. These are
|
233 |
+
edge cases - for example, when the sample contains only empty strings.
|
234 |
+
CI is about the distribution around the statistic (e.g. mean), it doesn't deal with
|
235 |
+
cases in which the metric is not computable. Therefore, we ignore these edge cases
|
236 |
+
as part of the computation of CI.
|
237 |
+
|
238 |
+
In theory there would be several ways to deal with this:
|
239 |
+
1. skip the errors and return a shorter array => this fails because Scipy requires
|
240 |
+
this callback (i.e. the statistic() callback) to return an array of the same size
|
241 |
+
as the number of resamples
|
242 |
+
2. Put np.nan for the errors => this fails because in such case the ci itself
|
243 |
+
becomes np.nan. So one edge case can fail the whole CI computation.
|
244 |
+
3. Replace the errors with a sampling from the successful cases => this is what is implemented.
|
245 |
+
|
246 |
+
This resampling makes it so that, if possible, the bca confidence interval returned by bootstrap will not be NaN, since
|
247 |
+
bootstrap does not ignore NaNs. However, if there are 0 or 1 non-NaN values, or all non-NaN values are equal,
|
248 |
+
the resulting distribution will be degenerate (only one unique value) so the CI will still be NaN since there is
|
249 |
+
no variability. In this case, the CI is essentially an interval of length 0 equaling the mean itself.
|
250 |
+
"""
|
251 |
+
if values.size > 1:
|
252 |
+
error_indices = numpy.isnan(values)
|
253 |
+
n_errors = sum(error_indices)
|
254 |
+
if 0 < n_errors < values.size:
|
255 |
+
# replace NaN aggregate scores with random draws from non-NaN scores, so that confidence interval isn't NaN itself
|
256 |
+
values[error_indices] = self.new_random_generator().choice(
|
257 |
+
values[~error_indices], n_errors, replace=True
|
258 |
+
)
|
259 |
+
return values
|
260 |
+
|
261 |
def compute_global_confidence_intervals(
|
262 |
+
self, references, predictions, task_data, score_name
|
263 |
):
|
264 |
"""Computed confidence intervals for a set of references and predictions."""
|
265 |
random_gen = self.new_random_generator()
|
|
|
267 |
def statistic(arr, axis):
|
268 |
# arr is a 2d array where each row is a resampling, so we
|
269 |
# iterate over the rows and compute the metric on each resampling
|
270 |
+
def metric(sample_refs, sample_preds, sample_task_data):
|
271 |
try:
|
272 |
return self._compute(
|
273 |
references=sample_refs,
|
274 |
predictions=sample_preds,
|
275 |
+
task_data=sample_task_data,
|
276 |
)["score"]
|
277 |
except Exception as e:
|
278 |
# this happens in edge cases, for example, when the sampling creates a
|
|
|
280 |
logger.info(f"Warning in {self.__class__.__name__}", e)
|
281 |
return np.nan
|
282 |
|
283 |
+
# resample the instance scores, and then return the global score each time
|
284 |
scores = numpy.apply_along_axis(
|
285 |
lambda x: metric(
|
286 |
sample_refs=[references[i] for i in x],
|
287 |
sample_preds=[predictions[i] for i in x],
|
288 |
+
sample_task_data=[task_data[i] for i in x],
|
289 |
),
|
290 |
axis=axis,
|
291 |
arr=arr,
|
292 |
)
|
293 |
|
294 |
+
# in some resamplings of instances, the global score may be NaN since it cannot be computed;
|
295 |
+
# in these cases, the bca confidence interval will be NaN because it does not ignore these values,
|
296 |
+
# so we replace any NaN values with those resampled from the non-NaN ones.
|
297 |
+
return self.resample_from_non_nan(scores)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
298 |
|
299 |
result = {}
|
300 |
num_predictions = len(predictions)
|
|
|
322 |
need to be considered. Accuracy, on the other hand, is just an average of the accuracy of all the instances.
|
323 |
"""
|
324 |
|
325 |
+
n_resamples: int = OptionalField(
|
326 |
+
default_factory=lambda: settings.num_resamples_for_global_metrics
|
327 |
+
)
|
328 |
+
process_single_instances = True
|
329 |
|
330 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
331 |
references = []
|
332 |
predictions = []
|
333 |
+
task_data = []
|
334 |
global_score = {}
|
335 |
|
336 |
instances = []
|
|
|
349 |
predictions.append(instance_prediction)
|
350 |
instances.append(instance)
|
351 |
|
352 |
+
instance_task_data = (
|
353 |
+
instance["task_data"] if "task_data" in instance else {}
|
354 |
)
|
355 |
+
task_data.append(instance_task_data)
|
356 |
+
instance_score = None
|
357 |
+
# for backward compatibility
|
358 |
+
no_score_value = np.nan
|
359 |
+
if self.process_single_instances:
|
360 |
+
try:
|
361 |
+
instance_score = self._compute(
|
362 |
+
[instance_references],
|
363 |
+
[instance_prediction],
|
364 |
+
[instance_task_data],
|
365 |
+
)
|
366 |
+
except:
|
367 |
+
no_score_value = None
|
368 |
+
if not instance_score:
|
369 |
+
instance_score = {
|
370 |
+
"score": no_score_value,
|
371 |
+
"score_name": self.main_score,
|
372 |
+
}
|
373 |
|
374 |
if isinstance(self.main_score, str):
|
375 |
+
instance_score[self.main_score] = no_score_value
|
376 |
|
377 |
instance["score"]["instance"].update(instance_score)
|
378 |
|
379 |
+
result = self._compute(references, predictions, task_data)
|
380 |
|
381 |
global_score.update(result)
|
382 |
|
383 |
score_name = global_score["score_name"]
|
384 |
confidence_interval = self.compute_global_confidence_intervals(
|
385 |
+
references, predictions, task_data, score_name
|
386 |
)
|
387 |
global_score.update(confidence_interval)
|
388 |
|
|
|
394 |
self,
|
395 |
references: List[List[str]],
|
396 |
predictions: List[str],
|
397 |
+
task_data: List[Any],
|
398 |
) -> dict:
|
399 |
+
result = self.compute(references, predictions, task_data)
|
400 |
result["score"] = result[self.main_score]
|
401 |
result["score_name"] = self.main_score
|
402 |
return result
|
|
|
406 |
self,
|
407 |
references: List[List[Any]],
|
408 |
predictions: List[Any],
|
409 |
+
task_data: List[Any],
|
410 |
) -> dict:
|
411 |
+
"""Computes a scores dictionary on a list of references, predictions and input.
|
412 |
+
|
413 |
+
This function is called once per instance, and then another time
|
414 |
+
over all data instances.
|
415 |
+
|
416 |
+
Returns:
|
417 |
+
a dictionary of scores that is set as:
|
418 |
+
the instance scores when called on a single data instance
|
419 |
+
the global score when called on the all data instances
|
420 |
+
"""
|
421 |
pass
|
422 |
|
423 |
|
424 |
class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
425 |
+
n_resamples: int = OptionalField(
|
426 |
+
default_factory=lambda: settings.num_resamples_for_instance_metrics
|
427 |
+
)
|
428 |
main_score: str
|
429 |
reduction_map: Dict[str, List[str]]
|
430 |
|
|
|
445 |
),
|
446 |
)
|
447 |
|
448 |
+
task_data = [
|
449 |
+
instance["task_data"] if "task_data" in instance else {}
|
450 |
for instance in stream
|
451 |
]
|
452 |
|
|
|
454 |
instance_scores = self.compute(
|
455 |
references=references,
|
456 |
predictions=predictions,
|
457 |
+
task_data=task_data,
|
458 |
)
|
459 |
|
460 |
# add the score and score_name fields
|
|
|
478 |
), f"Reduction {reduction} is not implemented, use one of {self.implemented_reductions}"
|
479 |
|
480 |
if reduction == "mean":
|
|
|
|
|
481 |
for field_name in fields:
|
482 |
global_score[field_name] = mean(
|
483 |
[
|
|
|
489 |
global_score["score"] = global_score[field_name]
|
490 |
global_score["score_name"] = self.main_score
|
491 |
|
492 |
+
ci_fields = (
|
493 |
+
list(set(self.ci_scores))
|
494 |
+
if self.ci_scores is not None
|
495 |
+
else [self.main_score]
|
496 |
+
)
|
497 |
confidence_interval = self.score_based_confidence_interval(
|
498 |
+
instances=instances, score_names=ci_fields
|
499 |
)
|
500 |
global_score.update(confidence_interval)
|
501 |
|
|
|
507 |
self,
|
508 |
references: List[List[Any]],
|
509 |
predictions: List[Any],
|
510 |
+
task_data: List[Dict],
|
511 |
) -> List[Dict[str, Any]]:
|
512 |
pass
|
513 |
|
514 |
|
515 |
class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
516 |
+
"""Class for metrics for which a global score can be calculated by aggregating the instance scores (possibly with additional instance inputs).
|
517 |
+
|
518 |
+
InstanceMetric currently allows two reductions:
|
519 |
+
1. 'mean', which calculates the mean of instance scores,
|
520 |
+
2. 'group_mean', which first applies an aggregation function specified in the reduction_map
|
521 |
+
to instance scores grouped by the field grouping_field (which must not be None), and returns the mean
|
522 |
+
of the group scores; if grouping_field is None, grouping is disabled.
|
523 |
+
See _validate_group_mean_reduction for formatting instructions.
|
524 |
+
"""
|
525 |
|
526 |
+
n_resamples: int = OptionalField(
|
527 |
+
default_factory=lambda: settings.num_resamples_for_instance_metrics
|
528 |
+
)
|
529 |
+
|
530 |
+
# some group_mean aggregation functions (3rd element of "agg_func" list in the reduction)
|
531 |
+
# only require a list of instance scores (e.g., mean, median, etc.). Others aggregation functions
|
532 |
+
# require an additional column (e.g., a subgroup identifier) by which the instance scores will be grouped
|
533 |
+
# if subgroup_column is not None, a column by the specified name will be required in task_data
|
534 |
+
subgroup_column = None
|
535 |
+
implemented_reductions: List[str] = field(
|
536 |
+
default_factory=lambda: ["mean", "group_mean"]
|
537 |
+
)
|
538 |
|
539 |
@property
|
540 |
@abstractmethod
|
541 |
def reduction_map(self) -> dict:
|
542 |
pass
|
543 |
|
544 |
+
def _validate_group_mean_reduction(self, instances: List[dict]):
|
545 |
+
"""Ensure that group_mean reduction_map is properly formatted.
|
546 |
+
|
547 |
+
Example: Apply the variance (np.var) to group Accuracy instance scores. This class would be specified as follows:
|
548 |
+
|
549 |
+
class GroupVarianceAccuracy(Accuracy):
|
550 |
+
reduction_map = {'group_mean': {'agg_func': ['variance', np.var, True]}}
|
551 |
+
|
552 |
+
reduction_map must be a dict with values containing
|
553 |
+
- an 'agg_func' field with value being a 3-element list where
|
554 |
+
- 1st element is a string name of the aggregation function (used in naming the CI report)
|
555 |
+
- 2nd element is the callable aggregation function
|
556 |
+
- 3rd element is a Boolean indicator of whether, during boostrap CI calculation, the groups are to be sampled as single units.
|
557 |
+
If True, the group scores are calculated and then resampled. This treats the group units as the unit of
|
558 |
+
interest for which the CI is being compared.
|
559 |
+
If False, the instances are resampled individually, and the groups determined
|
560 |
+
(meaning the groups may be of slightly different size or composition from the original
|
561 |
+
depending on the resampling of the instances).
|
562 |
+
- Optional: 'score_fields' key with list value containing the string names of fields to apply the aggregation to
|
563 |
+
- If not present, the parent class main_score is used.
|
564 |
+
|
565 |
+
The aggregation function (2nd element of agg_func) can be one of two types:
|
566 |
+
1. simple: calculate a summary statistic from a single group of values (e.g. mean, median, etc.).
|
567 |
+
This is best suited for cases where the instances are independent of each other, other than belonging to the same group
|
568 |
+
2. comparison: requires subgroup_column to be specified. This function conducts
|
569 |
+
a comparison between scores for differing values of subgroup_column (e.g., 'original' vs 'paraphrase').
|
570 |
+
An example is where the original instance is a question, and the others are various paraphrases
|
571 |
+
or perturbations of this question. Here, the function would return, say, a comparison of the instance accuracies
|
572 |
+
rather than, say, the average instance accuracy.
|
573 |
+
In these cases, we recommend setting the 3rd parameter to be True so that the groups are resampled together.
|
574 |
+
|
575 |
+
Example:
|
576 |
+
class GroupVsBaselineDiffAccuracy(Accuracy):
|
577 |
+
subgroup_column = 'variant_type'
|
578 |
+
reduction_map = {'group_mean': {'agg_func': ['accuracy_diff', accuracy_diff, True],}}
|
579 |
+
|
580 |
+
# where the function is defined as
|
581 |
+
def accuracy_diff(subgroup_scores_dict, expected_subgroup_types=['original', 'paraphrase']):
|
582 |
+
validate_subgroup_types(subgroup_scores_dict, expected_subgroup_types)
|
583 |
+
from statistics import mean
|
584 |
+
return mean(subgroup_scores_dict['paraphrase']) - mean(subgroup_scores_dict['original'])
|
585 |
+
The input dataset should look like:
|
586 |
+
|
587 |
+
'group_id' 'question' 'variant_type'
|
588 |
+
1 'How do you fix a car engine?' 'original'
|
589 |
+
1 'What is the best way to fix an engine?' 'paraphrase'
|
590 |
+
1 'How do you repair a car engine?' 'paraphrase'
|
591 |
+
1 'How do I repair my engine?' 'paraphrase'
|
592 |
+
2 'Why are ants eating my food?' 'original'
|
593 |
+
"""
|
594 |
+
# instances need to all have task_data field with field group_id
|
595 |
+
assert all(
|
596 |
+
"task_data" in instance for instance in instances
|
597 |
+
), "each instance must have an task_data field"
|
598 |
+
assert all(
|
599 |
+
isinstance(instance["task_data"], dict) for instance in instances
|
600 |
+
), "each instance must have an task_data field that is a dict"
|
601 |
+
assert all(
|
602 |
+
"group_id" in instance["task_data"] for instance in instances
|
603 |
+
), "each instance task_data dict must have a key group_id"
|
604 |
+
|
605 |
+
# validate the reduction_map
|
606 |
+
assert (
|
607 |
+
"group_mean" in self.reduction_map
|
608 |
+
), "reduction_map must have a 'group_mean' key"
|
609 |
+
fields = self.reduction_map["group_mean"]
|
610 |
+
# for group_mean, expects a dict
|
611 |
+
assert isinstance(fields, dict)
|
612 |
+
assert (
|
613 |
+
"agg_func" in fields
|
614 |
+
), "fields should have a key 'agg_func' whose value is a 3-element list of a function name, function definition, and a boolean indicator"
|
615 |
+
assert isinstance(
|
616 |
+
fields["agg_func"], list
|
617 |
+
), "fields['agg_func'] should be a list"
|
618 |
+
assert (
|
619 |
+
len(fields["agg_func"]) == 3
|
620 |
+
), "fields['agg_func'] should be a 3-element list"
|
621 |
+
assert isinstance(
|
622 |
+
fields["agg_func"][0], str
|
623 |
+
), "first item in fields['agg_func'] should be a string name of a function"
|
624 |
+
assert callable(
|
625 |
+
fields["agg_func"][1]
|
626 |
+
), "second item in fields['agg_func'] should be a callable function"
|
627 |
+
assert isinstance(
|
628 |
+
fields["agg_func"][2], bool
|
629 |
+
), "third item in fields['agg_func'] should be a boolean value"
|
630 |
+
if "score_fields" in fields:
|
631 |
+
assert isinstance(fields["score_fields"], list)
|
632 |
+
|
633 |
+
# for aggregation functions that use the subgroup_column (expect a dict of lists), check that
|
634 |
+
# this field exists
|
635 |
+
if self.subgroup_column is not None:
|
636 |
+
assert all(
|
637 |
+
self.subgroup_column in instance["task_data"] for instance in instances
|
638 |
+
), f"each instance task_data dict must have a key {self.subgroup_column}"
|
639 |
+
|
640 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
641 |
+
instances, global_score = self.compute_instance_scores(stream)
|
642 |
+
|
643 |
+
for reduction_type, reduction_params in self.reduction_map.items():
|
644 |
+
assert (
|
645 |
+
reduction_type in self.implemented_reductions
|
646 |
+
), f"Reduction {reduction_type} is not implemented, use one of {self.implemented_reductions}"
|
647 |
+
|
648 |
+
field_name_full_prefix = ""
|
649 |
+
# used for passing to the bootstrapping, depends on whether the groups are fixed or not
|
650 |
+
aggregation_function = self.average_item_scores
|
651 |
+
if reduction_type == "mean":
|
652 |
+
reduction_fields = list(set(reduction_params))
|
653 |
+
# no group reduction, so resample instances individually
|
654 |
+
scores_to_resample = instances
|
655 |
+
elif reduction_type == "group_mean":
|
656 |
+
self._validate_group_mean_reduction(instances=instances)
|
657 |
+
reduction_fields = (
|
658 |
+
[self.main_score]
|
659 |
+
if "score_fields" not in reduction_params
|
660 |
+
else list(set(reduction_params["score_fields"]))
|
661 |
+
)
|
662 |
+
aggregation_function_name = str(reduction_params["agg_func"][0])
|
663 |
+
field_name_full_prefix = "group_" + aggregation_function_name + "_"
|
664 |
+
do_resample_as_group = reduction_params["agg_func"][2]
|
665 |
+
if do_resample_as_group:
|
666 |
+
# append fixed_ to name because resamples the groups as fixed units
|
667 |
+
field_name_full_prefix = "fixed_" + field_name_full_prefix
|
668 |
+
(
|
669 |
+
scores_to_resample,
|
670 |
+
aggregation_function,
|
671 |
+
) = self._set_up_group_mean_aggregation(
|
672 |
+
instances, reduction_params, reduction_fields
|
673 |
+
)
|
674 |
+
else:
|
675 |
+
raise ValueError(
|
676 |
+
f"Reduction {reduction_type} is not supported, please specify a valid reduction method in reduction_map {self.reduction_map}."
|
677 |
+
)
|
678 |
+
|
679 |
+
# calculate global scores for each reduction field
|
680 |
+
for field_name in reduction_fields:
|
681 |
+
field_name_full = field_name_full_prefix + field_name
|
682 |
+
# if group resampling (3rd element of agg_func parameter) is True, then
|
683 |
+
# 1. scores_to_resample are the group scores, and
|
684 |
+
# 2. aggregation_function is to take the raw mean
|
685 |
+
# if no group resampling (3rd element of agg_func parameter) is False, then
|
686 |
+
# 1. scores_to_resample are the original instance scores, and
|
687 |
+
# 2. aggregation_function is to apply the group aggregation from the instance scores
|
688 |
+
# either way, the application of aggregation_function to scores_to_resample yields the global score
|
689 |
+
global_score[field_name_full] = aggregation_function(
|
690 |
+
scores_to_resample, field_name
|
691 |
+
)
|
692 |
+
if field_name == self.main_score:
|
693 |
+
global_score["score"] = global_score[field_name_full]
|
694 |
+
global_score["score_name"] = field_name_full
|
695 |
+
|
696 |
+
# need to specify which fields should have CIs calculated for them through ci_scores
|
697 |
+
# (will not automatically calculate CIs for fields in reduction map)
|
698 |
+
if self.ci_scores is not None:
|
699 |
+
confidence_interval = self.score_based_confidence_interval(
|
700 |
+
instances=scores_to_resample,
|
701 |
+
score_names=list(set(self.ci_scores)),
|
702 |
+
ci_score_prefix=field_name_full_prefix,
|
703 |
+
aggregation_func=aggregation_function,
|
704 |
+
)
|
705 |
+
global_score.update(confidence_interval)
|
706 |
+
|
707 |
+
yield from instances
|
708 |
+
|
709 |
+
def compute_instance_scores(
|
710 |
+
self, stream: Stream, stream_name: Optional[str] = None
|
711 |
+
):
|
712 |
global_score = {}
|
713 |
instances = []
|
714 |
|
715 |
for instance in stream:
|
716 |
refs, pred = instance["references"], instance["prediction"]
|
717 |
+
task_data = instance["task_data"] if "task_data" in instance else {}
|
|
|
|
|
718 |
|
719 |
instance_score = self.compute(
|
720 |
+
references=refs, prediction=pred, task_data=task_data
|
721 |
)
|
722 |
instance_score["score"] = instance_score[self.main_score]
|
723 |
instance_score["score_name"] = self.main_score
|
|
|
730 |
|
731 |
instances.append(instance)
|
732 |
|
733 |
+
return instances, global_score
|
|
|
|
|
|
|
734 |
|
735 |
+
def get_group_scores(
|
736 |
+
self, instances: List[dict], score_names: List[str], group_aggregation_func
|
737 |
+
):
|
738 |
+
"""Group scores by the group_id and subgroup_type fields of each instance, and compute group_aggregation_func by group.
|
739 |
+
|
740 |
+
Args:
|
741 |
+
instances: List of observation instances with instance-level scores (fields) computed.
|
742 |
+
score_names: List of instance score names in each instance to apply the aggregation function.
|
743 |
+
group_aggregation_func: Callable aggregation function accepting a list of numeric scores;
|
744 |
+
or, if self.subgroup_column is not None, a dict of subgroup types scores by subgroup_column value.
|
745 |
+
callable function returns a single score for the group
|
746 |
+
|
747 |
+
Returns:
|
748 |
+
List of dicts, each corresponding to a group of instances (defined by 'group_id'),
|
749 |
+
with an aggregate group score for each score_name
|
750 |
+
"""
|
751 |
+
from collections import defaultdict
|
752 |
|
753 |
+
# three-level defaultdict:
|
754 |
+
# first is the grouping, second is the field name, the third is the subgroup_type (by default 'default')
|
755 |
+
group_to_instance_scores = defaultdict(
|
756 |
+
lambda: defaultdict(lambda: defaultdict(list))
|
757 |
+
)
|
|
|
|
|
|
|
|
|
758 |
|
759 |
+
# check if function has fields for subgroup_column
|
760 |
+
uses_subgroups = self.subgroup_column is not None
|
761 |
+
default_subgroup_name = "default"
|
762 |
+
# loop through the instances and group the scores
|
763 |
+
for instance in instances:
|
764 |
+
task_data = instance["task_data"]
|
765 |
+
group_key = task_data["group_id"]
|
766 |
+
# for functions that do comparisons between subgroup_column groups
|
767 |
+
# if function doesn't use subgroup_column, or none is present, set "default" as default value, and pass all scores
|
768 |
+
subgroup_type = (
|
769 |
+
task_data[self.subgroup_column]
|
770 |
+
if uses_subgroups
|
771 |
+
else default_subgroup_name
|
772 |
+
)
|
773 |
+
for score_name in score_names:
|
774 |
+
group_to_instance_scores[group_key][score_name][subgroup_type].append(
|
775 |
+
instance["score"]["instance"][score_name]
|
776 |
)
|
|
|
777 |
|
778 |
+
# if group_aggregation_func expects a subgroup-types score dict, pass it; otherwise pass the default type list of scores
|
779 |
+
return [
|
780 |
+
{
|
781 |
+
"score": {
|
782 |
+
"instance": {
|
783 |
+
score_name: group_aggregation_func(
|
784 |
+
score_dict
|
785 |
+
if uses_subgroups
|
786 |
+
else score_dict[default_subgroup_name]
|
787 |
+
)
|
788 |
+
for score_name, score_dict in group_scores.items()
|
789 |
+
}
|
790 |
+
}
|
791 |
+
}
|
792 |
+
for group_scores in group_to_instance_scores.values()
|
793 |
+
]
|
794 |
+
|
795 |
+
def _set_up_group_mean_aggregation(
|
796 |
+
self, instances, reduction_params, reduction_fields
|
797 |
+
):
|
798 |
+
group_aggregation_func = reduction_params["agg_func"][1]
|
799 |
+
# if treat groups as units
|
800 |
+
do_resample_as_group = reduction_params["agg_func"][2]
|
801 |
+
if do_resample_as_group:
|
802 |
+
# pass the group aggregate---not instance---scores to resample as usual
|
803 |
+
aggregation_function = self.average_item_scores
|
804 |
+
scores_to_resample = self.get_group_scores(
|
805 |
+
instances, reduction_fields, group_aggregation_func
|
806 |
+
)
|
807 |
+
else:
|
808 |
+
# pass the instance scores to resample, and calculate the group aggregation on the resamplings
|
809 |
+
scores_to_resample = instances
|
810 |
+
|
811 |
+
def aggregation_function(
|
812 |
+
instances,
|
813 |
+
field_name,
|
814 |
+
group_aggregation_func=group_aggregation_func,
|
815 |
+
):
|
816 |
+
group_scores = self.get_group_scores(
|
817 |
+
instances, [field_name], group_aggregation_func
|
818 |
+
)
|
819 |
+
return nan_mean(
|
820 |
+
[group["score"]["instance"][field_name] for group in group_scores]
|
821 |
+
)
|
822 |
+
|
823 |
+
return scores_to_resample, aggregation_function
|
824 |
|
825 |
@abstractmethod
|
826 |
+
def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
|
|
|
|
|
827 |
pass
|
828 |
|
829 |
|
|
|
840 |
self,
|
841 |
references: List[List[str]],
|
842 |
predictions: List[str],
|
843 |
+
task_data: List[Dict],
|
844 |
) -> dict:
|
845 |
ids = [str(uuid.uuid4()).replace("-", "") for _ in range(len(predictions))]
|
846 |
formatted_predictions = [
|
|
|
861 |
class Accuracy(InstanceMetric):
|
862 |
reduction_map = {"mean": ["accuracy"]}
|
863 |
main_score = "accuracy"
|
864 |
+
ci_scores = ["accuracy"]
|
865 |
|
866 |
def compute(
|
867 |
+
self, references: List[Any], prediction: Any, task_data: List[Dict]
|
868 |
) -> dict:
|
869 |
result = {
|
870 |
self.main_score: float(
|
|
|
879 |
class StringContainment(InstanceMetric):
|
880 |
reduction_map = {"mean": ["string_containment"]}
|
881 |
main_score = "string_containment"
|
882 |
+
ci_scores = ["string_containment"]
|
883 |
|
884 |
def compute(
|
885 |
+
self, references: List[Any], prediction: Any, task_data: List[Dict]
|
886 |
) -> dict:
|
887 |
result = {
|
888 |
self.main_score: float(
|
889 |
+
any(str(reference) in str(prediction) for reference in references)
|
890 |
)
|
891 |
}
|
892 |
result["score"] = result[self.main_score]
|
|
|
902 |
)
|
903 |
metric: Metric = None
|
904 |
|
905 |
+
def disable_confidence_interval_calculation(self):
|
906 |
+
return self.metric.disable_confidence_interval_calculation()
|
907 |
+
|
908 |
+
def set_n_resamples(self, n_resample):
|
909 |
+
if isinstance(self.metric, MetricWithConfidenceInterval):
|
910 |
+
self.metric.set_n_resamples(n_resample)
|
911 |
+
|
912 |
def verify(self):
|
913 |
assert self.main_score is not None, "main_score is not set"
|
914 |
|
|
|
973 |
self,
|
974 |
references: List[List[Any]],
|
975 |
predictions: List[Any],
|
976 |
+
task_data: List[Dict],
|
977 |
) -> dict:
|
978 |
+
passed_task_data = {}
|
979 |
for additional_input_field in self.hf_additional_input_fields:
|
980 |
assert (
|
981 |
+
additional_input_field in task_data[0]
|
982 |
+
), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in task_data: {task_data[0]}"
|
983 |
+
passed_task_data[additional_input_field] = [
|
984 |
additional_input[additional_input_field]
|
985 |
+
for additional_input in task_data
|
986 |
]
|
987 |
for additional_input_field in self.hf_additional_input_fields_pass_one_value:
|
988 |
assert (
|
989 |
+
additional_input_field in task_data[0]
|
990 |
+
), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in task_data: {task_data[0]}"
|
991 |
|
992 |
values = {
|
993 |
additional_input[additional_input_field]
|
994 |
+
for additional_input in task_data
|
995 |
}
|
996 |
assert (
|
997 |
len(values) == 1
|
998 |
), f"Values of '{additional_input_field}' field required by {__class__.__name__} should all be the same, but have multiple values {values}"
|
999 |
|
1000 |
+
passed_task_data[additional_input_field] = next(iter(values))
|
1001 |
|
1002 |
+
# add check that all required fields in self.metrics are in passed_task_data print(passed_task_data)
|
1003 |
result = self.metric.compute(
|
1004 |
predictions=predictions,
|
1005 |
references=references,
|
1006 |
+
**passed_task_data,
|
1007 |
**self.hf_compute_args,
|
1008 |
)
|
1009 |
if self.hf_main_score:
|
|
|
1045 |
self,
|
1046 |
references: List[List[str]],
|
1047 |
predictions: List[str],
|
1048 |
+
task_data: List[Any],
|
1049 |
) -> List[Dict[str, Any]]:
|
1050 |
+
passed_task_data = {}
|
1051 |
for additional_input_field in self.hf_additional_input_fields:
|
1052 |
assert (
|
1053 |
+
additional_input_field in task_data[0]
|
1054 |
+
), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in task_data: {task_data[0]}"
|
1055 |
+
passed_task_data[additional_input_field] = [
|
1056 |
additional_input[additional_input_field]
|
1057 |
+
for additional_input in task_data
|
1058 |
]
|
1059 |
+
# add check that all required fields in self.metrics are in passed_task_data
|
1060 |
|
1061 |
scores = self.metric.compute(
|
1062 |
predictions=predictions,
|
1063 |
references=references,
|
1064 |
+
**passed_task_data,
|
1065 |
**self.hf_compute_args,
|
1066 |
)
|
1067 |
|
|
|
1096 |
self,
|
1097 |
references: List[List[str]],
|
1098 |
predictions: List[str],
|
1099 |
+
task_data: List[Dict],
|
1100 |
) -> dict:
|
1101 |
assert all(
|
1102 |
len(reference) == 1 for reference in references
|
|
|
1118 |
average=self.average,
|
1119 |
)
|
1120 |
if isinstance(result["f1"], numpy.ndarray):
|
|
|
|
|
1121 |
final_result = {self.main_score: mean(result["f1"])}
|
1122 |
for i, label in enumerate(labels):
|
1123 |
final_result["f1_" + self.id_to_str[label]] = result["f1"][i]
|
|
|
1144 |
_metric = None
|
1145 |
main_score = "f1_macro"
|
1146 |
average = None # Report per class then aggregate by mean
|
|
|
1147 |
metric = "f1"
|
1148 |
|
1149 |
def prepare(self):
|
|
|
1168 |
self,
|
1169 |
references: List[List[str]],
|
1170 |
predictions: List[List[str]],
|
1171 |
+
task_data: List[Dict],
|
1172 |
) -> dict:
|
1173 |
self.str_to_id = {}
|
1174 |
self.id_to_str = {}
|
|
|
1176 |
self._validate_references_and_prediction(references, predictions)
|
1177 |
references = [reference[0] for reference in references]
|
1178 |
|
1179 |
+
labels = list({label for reference in references for label in reference})
|
1180 |
+
|
|
|
|
|
|
|
1181 |
# if no classes are left then F1 is not defined
|
|
|
1182 |
if len(labels) == 0:
|
1183 |
return {self.main_score: float("nan")}
|
1184 |
|
|
|
1206 |
labels=labels_param,
|
1207 |
)
|
1208 |
if isinstance(result[self.metric], numpy.ndarray):
|
|
|
|
|
1209 |
assert (
|
1210 |
len(result[self.metric]) == len(labels)
|
1211 |
), f"F1 result ({result[self.metric]}) has more entries than labels ({labels})"
|
|
|
1278 |
|
1279 |
sent_split_newline: bool = True
|
1280 |
|
1281 |
+
_requirements_list: List[str] = ["nltk", "rouge_score"]
|
1282 |
+
|
1283 |
def prepare(self):
|
1284 |
super().prepare()
|
1285 |
|
|
|
1292 |
nltk.download("punkt")
|
1293 |
self.sent_tokenize = nltk.sent_tokenize
|
1294 |
|
1295 |
+
def compute(self, references, predictions, task_data: List[Dict]):
|
1296 |
if self.sent_split_newline:
|
1297 |
predictions = [
|
1298 |
"\n".join(self.sent_tokenize(prediction.strip()))
|
|
|
1302 |
["\n".join(self.sent_tokenize(r.strip())) for r in reference]
|
1303 |
for reference in references
|
1304 |
]
|
1305 |
+
return super().compute(references, predictions, task_data)
|
1306 |
|
1307 |
|
1308 |
# Computes char edit distance, ignoring whitespace
|
1309 |
class CharEditDistanceAccuracy(InstanceMetric):
|
1310 |
reduction_map = {"mean": ["char_edit_dist_accuracy"]}
|
1311 |
main_score = "char_edit_dist_accuracy"
|
1312 |
+
ci_scores = ["char_edit_dist_accuracy"]
|
1313 |
+
|
1314 |
+
_requirements_list: List[str] = ["editdistance"]
|
1315 |
|
1316 |
def prepare(self):
|
1317 |
super().prepare()
|
|
|
1319 |
|
1320 |
self.eval = editdistance.eval
|
1321 |
|
1322 |
+
def compute(self, references, prediction: str, task_data: List[Dict]) -> dict:
|
|
|
|
|
1323 |
assert (
|
1324 |
len(references) == 1
|
1325 |
), f"Expected only one reference , but received: {references}"
|
|
|
1337 |
hf_metric_name = "wer"
|
1338 |
main_score = "wer"
|
1339 |
|
1340 |
+
_requirements_list: List[str] = ["jiwer"]
|
1341 |
+
|
1342 |
def compute(
|
1343 |
self,
|
1344 |
references: List[List[str]],
|
1345 |
predictions: List[str],
|
1346 |
+
task_data: List[Dict],
|
1347 |
) -> dict:
|
1348 |
assert all(
|
1349 |
len(reference) == 1 for reference in references
|
|
|
1355 |
return {self.main_score: result}
|
1356 |
|
1357 |
|
1358 |
+
class Spearmanr(HuggingfaceMetric):
|
1359 |
+
hf_metric_name = "spearmanr"
|
1360 |
+
main_score = "spearmanr"
|
1361 |
+
process_single_instances = False
|
1362 |
+
|
1363 |
+
|
1364 |
+
class KendallTauMetric(GlobalMetric):
|
1365 |
+
main_score = "kendalltau_b"
|
1366 |
+
variant = "b"
|
1367 |
+
process_single_instances = False
|
1368 |
+
|
1369 |
+
_requirements_list: List[str] = ["scipy"]
|
1370 |
+
|
1371 |
+
def prepare(self):
|
1372 |
+
from scipy.stats import kendalltau
|
1373 |
+
|
1374 |
+
self.kendalltau = kendalltau
|
1375 |
+
|
1376 |
+
def compute(
|
1377 |
+
self,
|
1378 |
+
references: List[List[str]],
|
1379 |
+
predictions: List[str],
|
1380 |
+
task_data: List[Dict],
|
1381 |
+
) -> dict:
|
1382 |
+
if isinstance(references[0], list):
|
1383 |
+
references = [reference[0] for reference in references]
|
1384 |
+
references = [to_float_or_default(r) for r in references]
|
1385 |
+
predictions = [to_float_or_default(p) for p in predictions]
|
1386 |
+
|
1387 |
+
kendall_results = self.kendalltau(references, predictions, variant=self.variant)
|
1388 |
+
corr = kendall_results.correlation
|
1389 |
+
return {
|
1390 |
+
self.main_score: corr,
|
1391 |
+
f"{self.main_score}_p_val": kendall_results.pvalue,
|
1392 |
+
}
|
1393 |
+
|
1394 |
+
|
1395 |
class MatthewsCorrelation(HuggingfaceMetric):
|
1396 |
hf_metric_name = "matthews_correlation"
|
1397 |
main_score = "matthews_correlation"
|
|
|
1407 |
self,
|
1408 |
references: List[List[str]],
|
1409 |
predictions: List[str],
|
1410 |
+
task_data: List[Dict],
|
1411 |
) -> dict:
|
1412 |
formatted_references = [
|
1413 |
self.get_str_id(reference[0]) for reference in references
|
|
|
1420 |
)
|
1421 |
|
1422 |
|
1423 |
+
class RocAuc(GlobalMetric):
|
1424 |
+
main_score = "roc_auc"
|
1425 |
+
process_single_instances = False
|
1426 |
+
_requirements_list: List[str] = ["sklearn"]
|
1427 |
+
|
1428 |
+
def prepare(self):
|
1429 |
+
from sklearn import metrics
|
1430 |
+
|
1431 |
+
self.roc_curve = metrics.roc_curve
|
1432 |
+
self.auc = metrics.auc
|
1433 |
+
|
1434 |
+
def compute(
|
1435 |
+
self,
|
1436 |
+
references: List[List[str]],
|
1437 |
+
predictions: List[str],
|
1438 |
+
task_data: List[Dict],
|
1439 |
+
) -> dict:
|
1440 |
+
if isinstance(references[0], list):
|
1441 |
+
references = [reference[0] for reference in references]
|
1442 |
+
references = [to_float_or_default(r) for r in references]
|
1443 |
+
predictions = [to_float_or_default(p) for p in predictions]
|
1444 |
+
|
1445 |
+
fpr, tpr, thrs = self.roc_curve(y_true=references, y_score=predictions)
|
1446 |
+
roc_auc = self.auc(fpr, tpr)
|
1447 |
+
return {self.main_score: roc_auc}
|
1448 |
+
|
1449 |
+
|
1450 |
class CustomF1(GlobalMetric):
|
1451 |
main_score = "f1_micro"
|
1452 |
groups = None
|
|
|
1500 |
except ZeroDivisionError:
|
1501 |
return self.zero_division
|
1502 |
|
1503 |
+
def get_groups(self, elements, task_data):
|
1504 |
groups = set()
|
1505 |
+
for sublist, additional_input in zip(elements, task_data):
|
1506 |
for e in sublist:
|
1507 |
if self.should_ignore_element(e, additional_input):
|
1508 |
continue
|
|
|
1513 |
self,
|
1514 |
references: List[List[Any]],
|
1515 |
predictions: List[Any],
|
1516 |
+
task_data: List[Dict],
|
1517 |
) -> dict:
|
1518 |
# in case reference are List[List[List[Any]]] and predictions are List[List[Any]]:
|
1519 |
if (
|
|
|
1529 |
)
|
1530 |
|
1531 |
if self.groups is None:
|
1532 |
+
groups = self.get_groups(references, task_data)
|
1533 |
else:
|
1534 |
groups = self.groups
|
1535 |
groups_statistics = {}
|
1536 |
for references_batch, predictions_batch, additional_input in zip(
|
1537 |
+
references, predictions, task_data
|
1538 |
):
|
1539 |
grouped_references = self.group_elements(references_batch, additional_input)
|
1540 |
grouped_predictions = self.group_elements(
|
|
|
1651 |
ci_scores = ["f1", "precision", "recall"]
|
1652 |
|
1653 |
def compute(
|
1654 |
+
self, references: List[Any], prediction: Any, task_data: List[Dict]
|
1655 |
) -> dict:
|
1656 |
results = [
|
1657 |
+
self._compute_single_ref(str(reference), str(prediction))
|
1658 |
+
for reference in references
|
1659 |
]
|
1660 |
return {
|
1661 |
measure: max(r[i] for r in results)
|
|
|
1665 |
def _compute_single_ref(
|
1666 |
self, reference: Any, prediction: Any
|
1667 |
) -> Tuple[float, float, float]:
|
1668 |
+
prediction_tokens = normalize_answer(str(prediction)).split()
|
1669 |
+
reference_tokens = normalize_answer(str(reference)).split()
|
1670 |
common = Counter(prediction_tokens) & Counter(reference_tokens)
|
1671 |
num_same = sum(common.values())
|
1672 |
if num_same == 0:
|
|
|
1686 |
ci_scores = ["f1", "precision", "recall"]
|
1687 |
model_name: str
|
1688 |
|
1689 |
+
_requirements_list: List[str] = ["bert_score"]
|
1690 |
+
|
1691 |
def prepare(self):
|
1692 |
super().prepare()
|
1693 |
+
self.hf_compute_args = {"model_type": self.model_name, "batch_size": 16}
|
1694 |
|
1695 |
|
1696 |
class SentenceBert(BulkInstanceMetric):
|
|
|
1700 |
|
1701 |
model_name: str
|
1702 |
|
1703 |
+
_requirements_list: List[str] = ["sentence_transformers"]
|
1704 |
+
|
1705 |
def prepare(self):
|
1706 |
super().prepare()
|
1707 |
+
import torch
|
1708 |
from sentence_transformers import SentenceTransformer
|
1709 |
from sentence_transformers import util as sbert_util
|
1710 |
|
1711 |
+
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
1712 |
+
self.model = SentenceTransformer(self.model_name, device=self.device)
|
1713 |
self.util = sbert_util
|
1714 |
|
1715 |
def compute(
|
1716 |
self,
|
1717 |
references: List[List[Any]],
|
1718 |
predictions: List[Any],
|
1719 |
+
task_data: List[Dict],
|
1720 |
) -> List[Dict[str, Any]]:
|
1721 |
scores = []
|
1722 |
|
|
|
1731 |
count += len(ref_group)
|
1732 |
|
1733 |
# compute s-bert embeddings
|
1734 |
+
preds_emb = self.model.encode(predictions, device=self.device)
|
1735 |
refs_emb = self.model.encode(
|
1736 |
+
[ref for ref_group in references for ref in ref_group], device=self.device
|
1737 |
)
|
1738 |
|
1739 |
# for each candidate, pick the reference with the highest score
|
|
|
1751 |
|
1752 |
model_name: str
|
1753 |
|
1754 |
+
_requirements_list: List[str] = ["transformers"]
|
1755 |
+
|
1756 |
def prepare(self):
|
1757 |
super().prepare()
|
1758 |
+
import torch
|
1759 |
from transformers import pipeline
|
1760 |
|
1761 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
1762 |
+
self.pipe = pipeline(
|
1763 |
+
"text-classification", model=self.model_name, device=device
|
1764 |
+
)
|
1765 |
|
1766 |
def compute(
|
1767 |
self,
|
1768 |
references: List[List[Any]],
|
1769 |
predictions: List[Any],
|
1770 |
+
task_data: List[Dict],
|
1771 |
) -> List[Dict[str, Any]]:
|
1772 |
# treat the references as the questions and the predictions as answers
|
1773 |
# assume a single reference
|
|
|
1793 |
batch_size: int = 32
|
1794 |
model_name: str
|
1795 |
|
1796 |
+
_requirements_list: List[str] = ["transformers"]
|
1797 |
+
|
1798 |
def compute(
|
1799 |
self,
|
1800 |
references: List[List[Any]],
|
1801 |
predictions: List[Any],
|
1802 |
+
task_data: List[Dict],
|
1803 |
) -> List[Dict[str, Any]]:
|
1804 |
"""Computes the likelihood of generating text Y after text X - P(Y|X).
|
1805 |
|
1806 |
+
:param predictions: the list of Y texts = the targets of the generation
|
1807 |
+
:param references: the list of list of X texts = the sources of the generation
|
1808 |
|
1809 |
+
:return: the likelihood of generating text Y_i after each text X_i_j = P(Y_i|X_i_1), ..., P(Y_i|X_i_n) for every i.
|
1810 |
"""
|
1811 |
sources = []
|
1812 |
targets = []
|
1813 |
for prediction, instance_references in zip(predictions, references):
|
1814 |
for instance_reference in instance_references:
|
1815 |
+
sources.append(f"{self.perplexity_prompt} {instance_reference}")
|
1816 |
+
targets.append(prediction)
|
1817 |
|
1818 |
from transformers import AutoConfig
|
1819 |
|
|
|
1854 |
from transformers import AutoTokenizer
|
1855 |
|
1856 |
self.model_name = model_name
|
1857 |
+
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
1858 |
+
self.model = (
|
1859 |
+
self.model_class().from_pretrained(self.model_name).to(self.device)
|
1860 |
+
)
|
1861 |
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
|
|
|
|
1862 |
|
1863 |
def compute_lm(
|
1864 |
self, source: List[str], target: List[str], batch_size: int
|
|
|
1951 |
return AutoModelForSeq2SeqLM
|
1952 |
|
1953 |
def compute_batch(self, tokens_source, tokens_target):
|
1954 |
+
tokens_docs_ids = tokens_source["input_ids"].to(self.device)
|
1955 |
+
attention = tokens_source["attention_mask"].to(self.device)
|
1956 |
+
labels = tokens_target["input_ids"].to(self.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1957 |
|
1958 |
logits = self.model(
|
1959 |
input_ids=tokens_docs_ids.long(),
|
|
|
1993 |
# replace the padding token in the labels by -100
|
1994 |
labels[labels == self.tokenizer.pad_token_id] = -100
|
1995 |
|
1996 |
+
tokens = tokens.to(self.device)
|
1997 |
+
attention = attention.to(self.device)
|
1998 |
+
labels = labels.to(self.device)
|
|
|
|
|
|
|
1999 |
|
2000 |
# no need to pass labels as we calculate the loss below per document
|
2001 |
model_output = self.model(
|
|
|
2029 |
|
2030 |
main_score = "nDCG"
|
2031 |
|
2032 |
+
_requirements_list: List[str] = ["sklearn"]
|
2033 |
+
|
2034 |
def prepare(self):
|
2035 |
from sklearn.metrics import ndcg_score
|
2036 |
|
|
|
2041 |
self,
|
2042 |
references: List[List[Any]],
|
2043 |
predictions: List[Any],
|
2044 |
+
task_data: List[Any],
|
2045 |
) -> dict:
|
2046 |
from collections import defaultdict
|
|
|
2047 |
|
2048 |
query_to_predictions_and_references = defaultdict(lambda: [[], []])
|
2049 |
+
for reference, pred, inputs_dict in zip(references, predictions, task_data):
|
|
|
|
|
2050 |
query = inputs_dict.get("query")
|
2051 |
query_to_predictions_and_references[query][0].append(pred)
|
2052 |
query_to_predictions_and_references[query][1].append(reference)
|
|
|
2076 |
|
2077 |
|
2078 |
class RetrievalMetric(InstanceMetric):
|
2079 |
+
def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
|
|
|
|
|
2080 |
# digest input
|
2081 |
pred_ids: List[Any] = prediction
|
2082 |
ref_ids: List[Any] = list(dict.fromkeys(references))
|
|
|
2149 |
class MRR(RetrievalMetric):
|
2150 |
reduction_map = {"mean": ["mrr"]}
|
2151 |
main_score = "mrr"
|
2152 |
+
ci_scores = ["mrr"]
|
2153 |
|
2154 |
def _compute(
|
2155 |
self,
|
|
|
2166 |
class MAP(RetrievalMetric):
|
2167 |
reduction_map = {"mean": ["map"]}
|
2168 |
main_score = "map"
|
2169 |
+
ci_scores = ["map"]
|
2170 |
|
2171 |
def _compute(
|
2172 |
self,
|
|
|
2235 |
|
2236 |
def should_ignore_element(self, element, additional_input):
|
2237 |
return element == "none"
|
2238 |
+
|
2239 |
+
|
2240 |
+
class RemoteMetric(SingleStreamOperator, Metric):
|
2241 |
+
"""A metric that runs another metric remotely.
|
2242 |
+
|
2243 |
+
main_score: the score updated by this metric.
|
2244 |
+
endpoint: the remote host that supports the remote metric execution.
|
2245 |
+
metric_name: the name of the metric that is executed remotely.
|
2246 |
+
api_key: optional, passed to the remote metric with the input, allows secure authentication.
|
2247 |
+
"""
|
2248 |
+
|
2249 |
+
main_score: str = None
|
2250 |
+
endpoint: str
|
2251 |
+
metric_name: str
|
2252 |
+
api_key: str = None
|
2253 |
+
|
2254 |
+
@staticmethod
|
2255 |
+
def wrap_inner_metric_pipeline_metric(
|
2256 |
+
metric_pipeline: MetricPipeline, remote_metrics_endpoint: str
|
2257 |
+
) -> MetricPipeline:
|
2258 |
+
"""Wrap the inner metric in a MetricPipeline with a RemoteMetric.
|
2259 |
+
|
2260 |
+
When executing the returned MetricPipeline, the inner metric will be computed
|
2261 |
+
remotely (pre and post processing steps in the MetricPipeline will be computed locally).
|
2262 |
+
"""
|
2263 |
+
local_inner_metric = metric_pipeline.metric
|
2264 |
+
metric_pipeline = deepcopy(
|
2265 |
+
metric_pipeline
|
2266 |
+
) # To avoid unintentional changes to the catalog contents
|
2267 |
+
metric_pipeline.metric = RemoteMetric(
|
2268 |
+
main_score=local_inner_metric.main_score,
|
2269 |
+
metric_name=local_inner_metric.artifact_identifier,
|
2270 |
+
endpoint=remote_metrics_endpoint,
|
2271 |
+
)
|
2272 |
+
return metric_pipeline
|
2273 |
+
|
2274 |
+
def get_metric_url(self) -> str:
|
2275 |
+
return f"{self.endpoint}/{self.metric_name}"
|
2276 |
+
|
2277 |
+
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
2278 |
+
predictions, references, additional_inputs, instances = self.consume_stream(
|
2279 |
+
stream
|
2280 |
+
)
|
2281 |
+
metric_request = self.create_metric_request(
|
2282 |
+
predictions, references, additional_inputs
|
2283 |
+
)
|
2284 |
+
metric_response = self.get_metric_response(metric_request)
|
2285 |
+
self.update_instance_scores(instances, metric_response.instances_scores)
|
2286 |
+
self.set_global_score(instances, metric_response.global_score)
|
2287 |
+
yield from instances
|
2288 |
+
|
2289 |
+
@staticmethod
|
2290 |
+
def create_metric_request(predictions, references, additional_inputs):
|
2291 |
+
instance_inputs = [
|
2292 |
+
InstanceInput(
|
2293 |
+
prediction=prediction,
|
2294 |
+
references=reference,
|
2295 |
+
additional_inputs=additional_input,
|
2296 |
+
)
|
2297 |
+
for prediction, reference, additional_input in zip(
|
2298 |
+
predictions, references, additional_inputs
|
2299 |
+
)
|
2300 |
+
]
|
2301 |
+
return MetricRequest(instance_inputs=instance_inputs)
|
2302 |
+
|
2303 |
+
def get_metric_response(self, metric_request: MetricRequest) -> MetricResponse:
|
2304 |
+
import requests
|
2305 |
+
|
2306 |
+
response = requests.post(
|
2307 |
+
url=self.get_metric_url(),
|
2308 |
+
json=metric_request.to_dict(),
|
2309 |
+
headers={"Authorization": f"Bearer {self.api_key}"},
|
2310 |
+
)
|
2311 |
+
response.raise_for_status()
|
2312 |
+
response_json = response.json()
|
2313 |
+
return MetricResponse(**response_json)
|
2314 |
+
|
2315 |
+
def disable_confidence_interval_calculation(self):
|
2316 |
+
"""Confidence intervals are always disabled for RemoteMetric.
|
2317 |
+
|
2318 |
+
No need to do anything.
|
2319 |
+
"""
|
2320 |
+
pass
|
2321 |
+
|
2322 |
+
def set_n_resamples(self, n_resample):
|
2323 |
+
"""Since confidence intervals are always disabled for remote metrics, this is a no-op."""
|
2324 |
+
pass
|
2325 |
+
|
2326 |
+
|
2327 |
+
def validate_subgroup_types(
|
2328 |
+
subgroup_scores_dict: Dict[str, List],
|
2329 |
+
control_subgroup_types: List[str],
|
2330 |
+
comparison_subgroup_types: List[str],
|
2331 |
+
):
|
2332 |
+
"""Validate a dict of subgroup type instance score lists, and subgroup type lists.
|
2333 |
+
|
2334 |
+
Args:
|
2335 |
+
subgroup_scores_dict: dict where keys are subgroup types and values are lists of instance scores.
|
2336 |
+
control_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the control (baseline) group
|
2337 |
+
comparison_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the group
|
2338 |
+
to be compared to the control group.
|
2339 |
+
|
2340 |
+
Returns:
|
2341 |
+
dict with all NaN scores removed; control_subgroup_types and comparison_subgroup_types will have non-unique elements removed
|
2342 |
+
"""
|
2343 |
+
# note: subgroup_scores_dict is already a defaultdict of lists, so don't need to check that keys in control_ and comparison_subgroup_types exist in it
|
2344 |
+
# remove any NaNs
|
2345 |
+
subgroup_scores_dict.update(
|
2346 |
+
{
|
2347 |
+
subgroup_name: [score for score in score_list if not np.isnan(score)]
|
2348 |
+
for subgroup_name, score_list in subgroup_scores_dict.items()
|
2349 |
+
}
|
2350 |
+
)
|
2351 |
+
assert isinstance(
|
2352 |
+
control_subgroup_types, list
|
2353 |
+
), "control_subgroup_types must be a list"
|
2354 |
+
assert isinstance(
|
2355 |
+
comparison_subgroup_types, list
|
2356 |
+
), "comparison_subgroup_types must be a list"
|
2357 |
+
# make sure each list is unique, so that labels aren't double-counted
|
2358 |
+
control_subgroup_types = list(set(control_subgroup_types))
|
2359 |
+
comparison_subgroup_types = list(set(comparison_subgroup_types))
|
2360 |
+
|
2361 |
+
return subgroup_scores_dict, control_subgroup_types, comparison_subgroup_types
|
2362 |
+
|
2363 |
+
|
2364 |
+
def performance_drop_rate(
|
2365 |
+
subgroup_scores_dict: Dict[str, List],
|
2366 |
+
control_subgroup_types: List[str],
|
2367 |
+
comparison_subgroup_types: List[str],
|
2368 |
+
):
|
2369 |
+
"""Percentage decrease of mean performance on test elements relative to that on a baseline (control).
|
2370 |
+
|
2371 |
+
from https://arxiv.org/pdf/2306.04528.pdf.
|
2372 |
+
|
2373 |
+
Args:
|
2374 |
+
subgroup_scores_dict: dict where keys are subgroup types and values are lists of instance scores.
|
2375 |
+
control_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the control (baseline) group
|
2376 |
+
comparison_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the group
|
2377 |
+
to be compared to the control group.
|
2378 |
+
|
2379 |
+
Returns:
|
2380 |
+
numeric PDR metric.
|
2381 |
+
If only one element (no test set) or the first is 0 (percentage change is undefined) return NaN
|
2382 |
+
otherwise, calculate PDR
|
2383 |
+
"""
|
2384 |
+
(
|
2385 |
+
subgroup_scores_dict,
|
2386 |
+
control_subgroup_types,
|
2387 |
+
comparison_subgroup_types,
|
2388 |
+
) = validate_subgroup_types(
|
2389 |
+
subgroup_scores_dict, control_subgroup_types, comparison_subgroup_types
|
2390 |
+
)
|
2391 |
+
|
2392 |
+
# combine all scores from each label (if there are more than 1 in each group) into a list
|
2393 |
+
group_scores_list = [
|
2394 |
+
np.concatenate(
|
2395 |
+
[subgroup_scores_dict[subgroup_name] for subgroup_name in name_list]
|
2396 |
+
)
|
2397 |
+
for name_list in [control_subgroup_types, comparison_subgroup_types]
|
2398 |
+
]
|
2399 |
+
if any(len(scores) == 0 for scores in group_scores_list):
|
2400 |
+
# no comparison can be made since there is not at least one score per type
|
2401 |
+
return np.nan
|
2402 |
+
control_mean = mean(group_scores_list[0])
|
2403 |
+
comparison_mean = mean(group_scores_list[1])
|
2404 |
+
if control_mean == 0:
|
2405 |
+
# return 0 if comparison is also 0
|
2406 |
+
if comparison_mean == 0:
|
2407 |
+
return 0
|
2408 |
+
return np.nan
|
2409 |
+
# otherwise, take the percentage change (which may also be 0)
|
2410 |
+
return 1 - comparison_mean / control_mean
|
2411 |
+
|
2412 |
+
|
2413 |
+
def interpret_effect_size(x: float):
|
2414 |
+
"""Return a string rule-of-thumb interpretation of an effect size value, as defined by Cohen/Sawilowsky.
|
2415 |
+
|
2416 |
+
See https://en.wikipedia.org/wiki/Effect_size;
|
2417 |
+
Cohen, Jacob (1988). Statistical Power Analysis for the Behavioral Sciences; and
|
2418 |
+
Sawilowsky, S (2009). "New effect size rules of thumb". Journal of Modern Applied Statistical Methods. 8 (2): 467-474.
|
2419 |
+
|
2420 |
+
Value has interpretation of
|
2421 |
+
- essentially 0 if |x| < 0.01
|
2422 |
+
- very small if 0.01 <= |x| < 0.2
|
2423 |
+
- small difference if 0.2 <= |x| < 0.5
|
2424 |
+
- a medium difference if 0.5 <= |x| < 0.8
|
2425 |
+
- a large difference if 0.8 <= |x| < 1.2
|
2426 |
+
- a very large difference if 1.2 <= |x| < 2.0
|
2427 |
+
- a huge difference if 2.0 <= |x|
|
2428 |
+
|
2429 |
+
Args:
|
2430 |
+
x: float effect size value
|
2431 |
+
|
2432 |
+
Returns:
|
2433 |
+
string interpretation
|
2434 |
+
"""
|
2435 |
+
import pandas as pd
|
2436 |
+
|
2437 |
+
# assign a label according to threshold of the absolute value
|
2438 |
+
return pd.cut(
|
2439 |
+
x=[np.abs(x)],
|
2440 |
+
right=False,
|
2441 |
+
bins=[-1, 0.01, 0.2, 0.5, 0.8, 1.2, 2.0, np.Inf],
|
2442 |
+
labels=[
|
2443 |
+
"essentially zero",
|
2444 |
+
"very small",
|
2445 |
+
"small",
|
2446 |
+
"medium",
|
2447 |
+
"large",
|
2448 |
+
"very large",
|
2449 |
+
"huge",
|
2450 |
+
],
|
2451 |
+
)[0]
|
2452 |
+
|
2453 |
+
|
2454 |
+
def normalized_cohens_h(
|
2455 |
+
subgroup_scores_dict: Dict[str, List],
|
2456 |
+
control_subgroup_types: List[str],
|
2457 |
+
comparison_subgroup_types: List[str],
|
2458 |
+
interpret=False,
|
2459 |
+
):
|
2460 |
+
"""Cohen's h effect size between two proportions, normalized to interval [-1,1].
|
2461 |
+
|
2462 |
+
Allows for change-type metric when the baseline is 0 (percentage change, and thus PDR, is undefined)
|
2463 |
+
https://en.wikipedia.org/wiki/Cohen%27s_h
|
2464 |
+
|
2465 |
+
Cohen's h effect size metric between two proportions p2 and p1 is 2 * (arcsin(sqrt(p2)) - arcsin(sqrt(p1))).
|
2466 |
+
h in -pi, pi, with +/-pi representing the largest increase/decrease (p1=0, p2=1), or (p1=1, p2=0).
|
2467 |
+
h=0 is no change. Unlike percentage change, h is defined even if the baseline (p1) is 0.
|
2468 |
+
Assumes the scores are in [0,1], either continuous or binary; hence taking the average of a group of scores yields a proportion..
|
2469 |
+
Calculates the change in the average of the other_scores relative to the average of the baseline_scores. We rescale this to [-1,1] from [-pi,pi] for clarity, where +- 1 are the most extreme changes, and 0 is no change
|
2470 |
+
|
2471 |
+
Interpretation: the original unscaled Cohen's h can be interpreted according to function interpret_effect_size
|
2472 |
+
|
2473 |
+
Thus, the rule of interpreting the effect of the normalized value is to use the same thresholds divided by pi
|
2474 |
+
- essentially 0 if |norm h| < 0.0031831
|
2475 |
+
- very small if 0.0031831 <= |norm h| < 0.06366198
|
2476 |
+
- small difference if 0.06366198 <= |norm h| < 0.15915494
|
2477 |
+
- a medium difference if 0.15915494 <= |norm h| < 0.25464791
|
2478 |
+
- a large difference if 0.25464791 <= |norm h| < 0.38197186
|
2479 |
+
- a very large difference if 0.38197186 <= |norm h| < 0.63661977
|
2480 |
+
- a huge difference if 0.63661977 <= |norm h|
|
2481 |
+
Args:
|
2482 |
+
subgroup_scores_dict: dict where keys are subgroup types and values are lists of instance scores.
|
2483 |
+
control_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the control (baseline) group
|
2484 |
+
comparison_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the group
|
2485 |
+
to be compared to the control group.
|
2486 |
+
interpret: boolean, whether to interpret the significance of the score or not
|
2487 |
+
Returns:
|
2488 |
+
float score between -1 and 1, and a string interpretation if interpret=True
|
2489 |
+
"""
|
2490 |
+
(
|
2491 |
+
subgroup_scores_dict,
|
2492 |
+
control_subgroup_types,
|
2493 |
+
comparison_subgroup_types,
|
2494 |
+
) = validate_subgroup_types(
|
2495 |
+
subgroup_scores_dict, control_subgroup_types, comparison_subgroup_types
|
2496 |
+
)
|
2497 |
+
|
2498 |
+
# requires scores to be in [0,1]
|
2499 |
+
for subgroup_name, score_list in subgroup_scores_dict.items():
|
2500 |
+
assert all(
|
2501 |
+
0 <= score <= 1 for score in score_list
|
2502 |
+
), f"all {subgroup_name} scores must be in [0,1]"
|
2503 |
+
|
2504 |
+
# combine all scores from each label (if there are more than 1 in each group) into a list
|
2505 |
+
group_scores_list = [
|
2506 |
+
np.concatenate(
|
2507 |
+
[subgroup_scores_dict[subgroup_name] for subgroup_name in name_list]
|
2508 |
+
)
|
2509 |
+
for name_list in [control_subgroup_types, comparison_subgroup_types]
|
2510 |
+
]
|
2511 |
+
|
2512 |
+
if any(len(scores) == 0 for scores in group_scores_list):
|
2513 |
+
# no comparison can be made since there is not at least one score per type
|
2514 |
+
h, norm_h = np.nan, np.nan
|
2515 |
+
else:
|
2516 |
+
control_mean = mean(group_scores_list[0])
|
2517 |
+
comparison_mean = mean(group_scores_list[1])
|
2518 |
+
h = 2 * (np.arcsin(np.sqrt(comparison_mean)) - np.arcsin(np.sqrt(control_mean)))
|
2519 |
+
norm_h = np.clip(a=h / np.pi, a_min=-1, a_max=1)
|
2520 |
+
|
2521 |
+
if not interpret:
|
2522 |
+
return norm_h
|
2523 |
+
|
2524 |
+
return norm_h, interpret_effect_size(h)
|
2525 |
+
|
2526 |
+
|
2527 |
+
def normalized_hedges_g(
|
2528 |
+
subgroup_scores_dict: Dict[str, List[float]],
|
2529 |
+
control_subgroup_types: List[str],
|
2530 |
+
comparison_subgroup_types: List[str],
|
2531 |
+
interpret=False,
|
2532 |
+
):
|
2533 |
+
"""Hedge's g effect size between mean of two samples, normalized to interval [-1,1]. Better than Cohen's d for small sample sizes.
|
2534 |
+
|
2535 |
+
Takes into account the variances within the samples, not just the means.
|
2536 |
+
|
2537 |
+
Args:
|
2538 |
+
subgroup_scores_dict: dict where keys are subgroup types and values are lists of instance scores.
|
2539 |
+
control_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the control (baseline) group
|
2540 |
+
comparison_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the group
|
2541 |
+
to be compared to the control group.
|
2542 |
+
interpret: boolean, whether to interpret the significance of the score or not
|
2543 |
+
Returns:
|
2544 |
+
float score between -1 and 1, and a string interpretation if interpret=True
|
2545 |
+
"""
|
2546 |
+
(
|
2547 |
+
subgroup_scores_dict,
|
2548 |
+
control_subgroup_types,
|
2549 |
+
comparison_subgroup_types,
|
2550 |
+
) = validate_subgroup_types(
|
2551 |
+
subgroup_scores_dict, control_subgroup_types, comparison_subgroup_types
|
2552 |
+
)
|
2553 |
+
|
2554 |
+
# combine all scores from each label (if there are more than 1 in each group) into a list
|
2555 |
+
group_scores_list = [
|
2556 |
+
np.concatenate(
|
2557 |
+
[subgroup_scores_dict[subgroup_name] for subgroup_name in name_list]
|
2558 |
+
)
|
2559 |
+
for name_list in [control_subgroup_types, comparison_subgroup_types]
|
2560 |
+
]
|
2561 |
+
|
2562 |
+
group_n = [len(scores) for scores in group_scores_list]
|
2563 |
+
if any(nn == 0 for nn in group_n) or all(nn <= 1 for nn in group_n):
|
2564 |
+
# if at least one sample size is 0 for one type, no comparison can be made at all
|
2565 |
+
# if both sample sizes are 1, then the denominator is undefined since divide by n1 + n2 - 2
|
2566 |
+
# so require at least one sample to have > 1 observation, and both to have >= 1.
|
2567 |
+
g, norm_g = np.nan, np.nan
|
2568 |
+
else:
|
2569 |
+
# otherwise, calculate the variances
|
2570 |
+
group_mean = [mean(scores) for scores in group_scores_list]
|
2571 |
+
# sample variance with 1 degree of freedom (denominator n-1); if n=1, return 0 since otherwise throws an error
|
2572 |
+
group_var = [
|
2573 |
+
0.0 if nn == 1 else np.var(scores, ddof=1)
|
2574 |
+
for scores, nn in zip(group_scores_list, group_n)
|
2575 |
+
]
|
2576 |
+
var_total = sum([(nn - 1) * vv for vv, nn in zip(group_var, group_n)])
|
2577 |
+
pooled_sd = np.sqrt(var_total / (sum(group_n) - 2))
|
2578 |
+
|
2579 |
+
max_absolute_value = 5
|
2580 |
+
gmd = float(group_mean[1] - group_mean[0])
|
2581 |
+
|
2582 |
+
if gmd == 0:
|
2583 |
+
# if exactly the same, return 0
|
2584 |
+
g = 0.0
|
2585 |
+
else:
|
2586 |
+
try:
|
2587 |
+
g = gmd / pooled_sd
|
2588 |
+
except ZeroDivisionError:
|
2589 |
+
# return a large effect size to avoid explosion if there is zero variance
|
2590 |
+
g = np.sign(gmd) * max_absolute_value
|
2591 |
+
|
2592 |
+
n = sum(group_n)
|
2593 |
+
if 3 < n < 50:
|
2594 |
+
# small sample adjustment see https://www.itl.nist.gov/div898/software/dataplot/refman2/auxillar/hedgeg.htm
|
2595 |
+
# the multiplier is 0 if n <= 3
|
2596 |
+
g *= ((n - 3) / (n - 2.25)) * np.sqrt((n - 2) / n)
|
2597 |
+
# clip it at a very large value so it doesn't become infinite if the variance (denominator) is very small or 0
|
2598 |
+
g = float(np.clip(a=g, a_min=-1 * max_absolute_value, a_max=max_absolute_value))
|
2599 |
+
norm_g = g / max_absolute_value
|
2600 |
+
|
2601 |
+
if not interpret:
|
2602 |
+
return norm_g
|
2603 |
+
return norm_g, interpret_effect_size(g)
|
2604 |
+
|
2605 |
+
|
2606 |
+
def mean_subgroup_score(
|
2607 |
+
subgroup_scores_dict: Dict[str, List], subgroup_types: List[str]
|
2608 |
+
):
|
2609 |
+
"""Return the mean instance score for a subset (possibly a single type) of variants (not a comparison).
|
2610 |
+
|
2611 |
+
Args:
|
2612 |
+
subgroup_scores_dict: dict where keys are subgroup types and values are lists of instance scores.
|
2613 |
+
subgroup_types: the keys (subgroup types) for which the average will be computed.
|
2614 |
+
|
2615 |
+
Returns:
|
2616 |
+
float score
|
2617 |
+
"""
|
2618 |
+
subgroup_scores_dict, subgroup_types, _ = validate_subgroup_types(
|
2619 |
+
subgroup_scores_dict, subgroup_types, []
|
2620 |
+
)
|
2621 |
+
|
2622 |
+
# combine all desired subgroup scores
|
2623 |
+
score_list = np.concatenate(
|
2624 |
+
[subgroup_scores_dict[subgroup_name] for subgroup_name in subgroup_types]
|
2625 |
+
)
|
2626 |
+
if len(score_list) == 0:
|
2627 |
+
# no scores to use
|
2628 |
+
return np.nan
|
2629 |
+
return mean(score_list)
|
2630 |
+
|
2631 |
+
|
2632 |
+
# metrics using mean reduction
|
2633 |
+
class GroupMeanAccuracy(Accuracy):
|
2634 |
+
reduction_map = {"group_mean": {"agg_func": ["mean", nan_mean, False]}}
|
2635 |
+
|
2636 |
+
|
2637 |
+
class FixedGroupMeanAccuracy(Accuracy):
|
2638 |
+
# the same as GroupMeanAccuracy, except the groups are fixed and are resampled together
|
2639 |
+
reduction_map = {"group_mean": {"agg_func": ["mean", nan_mean, True]}}
|
2640 |
+
|
2641 |
+
|
2642 |
+
# same as above, now using StringContainment
|
2643 |
+
class GroupMeanStringContainment(StringContainment):
|
2644 |
+
reduction_map = {"group_mean": {"agg_func": ["mean", nan_mean, False]}}
|
2645 |
+
|
2646 |
+
|
2647 |
+
class FixedGroupMeanStringContainment(StringContainment):
|
2648 |
+
# the same as GroupMeanStringContainment, except the groups are fixed and are resampled together
|
2649 |
+
reduction_map = {"group_mean": {"agg_func": ["mean", nan_mean, True]}}
|
2650 |
+
|
2651 |
+
|
2652 |
+
# take only the (fixed) group mean of baseline or other (paraphrases) scores
|
2653 |
+
class FixedGroupMeanBaselineAccuracy(Accuracy):
|
2654 |
+
subgroup_column = "variant_type"
|
2655 |
+
# take mean of "original" variants only
|
2656 |
+
reduction_map = {
|
2657 |
+
"group_mean": {
|
2658 |
+
"agg_func": [
|
2659 |
+
"mean_baseline",
|
2660 |
+
lambda scd: mean_subgroup_score(
|
2661 |
+
subgroup_scores_dict=scd, subgroup_types=["original"]
|
2662 |
+
),
|
2663 |
+
True,
|
2664 |
+
],
|
2665 |
+
}
|
2666 |
+
}
|
2667 |
+
|
2668 |
+
|
2669 |
+
class FixedGroupMeanParaphraseAccuracy(Accuracy):
|
2670 |
+
subgroup_column = "variant_type"
|
2671 |
+
# take mean of "paraphrase" variants only
|
2672 |
+
reduction_map = {
|
2673 |
+
"group_mean": {
|
2674 |
+
"agg_func": [
|
2675 |
+
"mean_paraphrase",
|
2676 |
+
lambda scd: mean_subgroup_score(
|
2677 |
+
subgroup_scores_dict=scd, subgroup_types=["paraphrase"]
|
2678 |
+
),
|
2679 |
+
True,
|
2680 |
+
],
|
2681 |
+
}
|
2682 |
+
}
|
2683 |
+
|
2684 |
+
|
2685 |
+
# same as above but using StringContainment
|
2686 |
+
class FixedGroupMeanBaselineStringContainment(StringContainment):
|
2687 |
+
subgroup_column = "variant_type"
|
2688 |
+
# take mean of "original" variants only
|
2689 |
+
reduction_map = {
|
2690 |
+
"group_mean": {
|
2691 |
+
"agg_func": [
|
2692 |
+
"mean_baseline",
|
2693 |
+
lambda scd: mean_subgroup_score(
|
2694 |
+
subgroup_scores_dict=scd, subgroup_types=["original"]
|
2695 |
+
),
|
2696 |
+
True,
|
2697 |
+
],
|
2698 |
+
}
|
2699 |
+
}
|
2700 |
+
|
2701 |
+
|
2702 |
+
class FixedGroupMeanParaphraseStringContainment(StringContainment):
|
2703 |
+
subgroup_column = "variant_type"
|
2704 |
+
# take mean of "paraphrase" variants only
|
2705 |
+
reduction_map = {
|
2706 |
+
"group_mean": {
|
2707 |
+
"agg_func": [
|
2708 |
+
"mean_paraphrase",
|
2709 |
+
lambda scd: mean_subgroup_score(
|
2710 |
+
subgroup_scores_dict=scd, subgroup_types=["paraphrase"]
|
2711 |
+
),
|
2712 |
+
True,
|
2713 |
+
],
|
2714 |
+
}
|
2715 |
+
}
|
2716 |
+
|
2717 |
+
|
2718 |
+
# using PDR
|
2719 |
+
class FixedGroupPDRParaphraseAccuracy(Accuracy):
|
2720 |
+
subgroup_column = "variant_type"
|
2721 |
+
reduction_map = {
|
2722 |
+
"group_mean": {
|
2723 |
+
"agg_func": [
|
2724 |
+
"pdr_paraphrase",
|
2725 |
+
lambda scd: performance_drop_rate(
|
2726 |
+
subgroup_scores_dict=scd,
|
2727 |
+
control_subgroup_types=["original"],
|
2728 |
+
comparison_subgroup_types=["paraphrase"],
|
2729 |
+
),
|
2730 |
+
True,
|
2731 |
+
],
|
2732 |
+
}
|
2733 |
+
}
|
2734 |
+
|
2735 |
+
|
2736 |
+
class FixedGroupPDRParaphraseStringContainment(StringContainment):
|
2737 |
+
subgroup_column = "variant_type"
|
2738 |
+
reduction_map = {
|
2739 |
+
"group_mean": {
|
2740 |
+
"agg_func": [
|
2741 |
+
"pdr_paraphrase",
|
2742 |
+
lambda scd: performance_drop_rate(
|
2743 |
+
subgroup_scores_dict=scd,
|
2744 |
+
control_subgroup_types=["original"],
|
2745 |
+
comparison_subgroup_types=["paraphrase"],
|
2746 |
+
),
|
2747 |
+
True,
|
2748 |
+
],
|
2749 |
+
}
|
2750 |
+
}
|
2751 |
+
|
2752 |
+
|
2753 |
+
class GroupMeanTokenOverlap(TokenOverlap):
|
2754 |
+
reduction_map = {
|
2755 |
+
"group_mean": {
|
2756 |
+
"agg_func": ["mean", nan_mean, False],
|
2757 |
+
"score_fields": ["f1", "precision", "recall"],
|
2758 |
+
}
|
2759 |
+
}
|
2760 |
+
|
2761 |
+
|
2762 |
+
# using Cohens's h for proportions
|
2763 |
+
class FixedGroupNormCohensHParaphraseAccuracy(Accuracy):
|
2764 |
+
subgroup_column = "variant_type"
|
2765 |
+
reduction_map = {
|
2766 |
+
"group_mean": {
|
2767 |
+
"agg_func": [
|
2768 |
+
"norm_cohens_h_paraphrase",
|
2769 |
+
lambda scd: normalized_cohens_h(
|
2770 |
+
subgroup_scores_dict=scd,
|
2771 |
+
control_subgroup_types=["original"],
|
2772 |
+
comparison_subgroup_types=["paraphrase"],
|
2773 |
+
),
|
2774 |
+
True,
|
2775 |
+
],
|
2776 |
+
}
|
2777 |
+
}
|
2778 |
+
|
2779 |
+
|
2780 |
+
class FixedGroupNormCohensHParaphraseStringContainment(StringContainment):
|
2781 |
+
subgroup_column = "variant_type"
|
2782 |
+
reduction_map = {
|
2783 |
+
"group_mean": {
|
2784 |
+
"agg_func": [
|
2785 |
+
"norm_cohens_h_paraphrase",
|
2786 |
+
lambda scd: normalized_cohens_h(
|
2787 |
+
subgroup_scores_dict=scd,
|
2788 |
+
control_subgroup_types=["original"],
|
2789 |
+
comparison_subgroup_types=["paraphrase"],
|
2790 |
+
),
|
2791 |
+
True,
|
2792 |
+
],
|
2793 |
+
}
|
2794 |
+
}
|
2795 |
+
|
2796 |
+
|
2797 |
+
# using Hedges' g (takes into account internal variation in group scores)
|
2798 |
+
class FixedGroupNormHedgesGParaphraseAccuracy(Accuracy):
|
2799 |
+
subgroup_column = "variant_type"
|
2800 |
+
reduction_map = {
|
2801 |
+
"group_mean": {
|
2802 |
+
"agg_func": [
|
2803 |
+
"norm_hedges_g_paraphrase",
|
2804 |
+
lambda scd: normalized_hedges_g(
|
2805 |
+
subgroup_scores_dict=scd,
|
2806 |
+
control_subgroup_types=["original"],
|
2807 |
+
comparison_subgroup_types=["paraphrase"],
|
2808 |
+
),
|
2809 |
+
True,
|
2810 |
+
],
|
2811 |
+
}
|
2812 |
+
}
|
2813 |
+
|
2814 |
+
|
2815 |
+
class FixedGroupNormHedgesGParaphraseStringContainment(StringContainment):
|
2816 |
+
subgroup_column = "variant_type"
|
2817 |
+
reduction_map = {
|
2818 |
+
"group_mean": {
|
2819 |
+
"agg_func": [
|
2820 |
+
"norm_hedges_g_paraphrase",
|
2821 |
+
lambda scd: normalized_hedges_g(
|
2822 |
+
subgroup_scores_dict=scd,
|
2823 |
+
control_subgroup_types=["original"],
|
2824 |
+
comparison_subgroup_types=["paraphrase"],
|
2825 |
+
),
|
2826 |
+
True,
|
2827 |
+
],
|
2828 |
+
}
|
2829 |
+
}
|
2830 |
+
|
2831 |
+
|
2832 |
+
# for above metrics, take absolute value of group score first; this measures variation in either direction
|
2833 |
+
class FixedGroupAbsvalNormCohensHParaphraseAccuracy(Accuracy):
|
2834 |
+
subgroup_column = "variant_type"
|
2835 |
+
reduction_map = {
|
2836 |
+
"group_mean": {
|
2837 |
+
"agg_func": [
|
2838 |
+
"absval_norm_cohens_h_paraphrase",
|
2839 |
+
lambda scd: np.abs(
|
2840 |
+
normalized_cohens_h(
|
2841 |
+
subgroup_scores_dict=scd,
|
2842 |
+
control_subgroup_types=["original"],
|
2843 |
+
comparison_subgroup_types=["paraphrase"],
|
2844 |
+
)
|
2845 |
+
),
|
2846 |
+
True,
|
2847 |
+
],
|
2848 |
+
}
|
2849 |
+
}
|
2850 |
+
|
2851 |
+
|
2852 |
+
class FixedGroupAbsvalNormCohensHParaphraseStringContainment(StringContainment):
|
2853 |
+
subgroup_column = "variant_type"
|
2854 |
+
reduction_map = {
|
2855 |
+
"group_mean": {
|
2856 |
+
"agg_func": [
|
2857 |
+
"absval_norm_cohens_h_paraphrase",
|
2858 |
+
lambda scd: np.abs(
|
2859 |
+
normalized_cohens_h(
|
2860 |
+
subgroup_scores_dict=scd,
|
2861 |
+
control_subgroup_types=["original"],
|
2862 |
+
comparison_subgroup_types=["paraphrase"],
|
2863 |
+
)
|
2864 |
+
),
|
2865 |
+
True,
|
2866 |
+
],
|
2867 |
+
}
|
2868 |
+
}
|
2869 |
+
|
2870 |
+
|
2871 |
+
class FixedGroupAbsvalNormHedgesGParaphraseAccuracy(Accuracy):
|
2872 |
+
subgroup_column = "variant_type"
|
2873 |
+
reduction_map = {
|
2874 |
+
"group_mean": {
|
2875 |
+
"agg_func": [
|
2876 |
+
"absval_norm_hedges_g_paraphrase",
|
2877 |
+
lambda scd: np.abs(
|
2878 |
+
normalized_hedges_g(
|
2879 |
+
subgroup_scores_dict=scd,
|
2880 |
+
control_subgroup_types=["original"],
|
2881 |
+
comparison_subgroup_types=["paraphrase"],
|
2882 |
+
)
|
2883 |
+
),
|
2884 |
+
True,
|
2885 |
+
],
|
2886 |
+
}
|
2887 |
+
}
|
2888 |
+
|
2889 |
+
|
2890 |
+
class FixedGroupAbsvalNormHedgesGParaphraseStringContainment(StringContainment):
|
2891 |
+
subgroup_column = "variant_type"
|
2892 |
+
reduction_map = {
|
2893 |
+
"group_mean": {
|
2894 |
+
"agg_func": [
|
2895 |
+
"absval_norm_hedges_g_paraphrase",
|
2896 |
+
lambda scd: np.abs(
|
2897 |
+
normalized_hedges_g(
|
2898 |
+
subgroup_scores_dict=scd,
|
2899 |
+
control_subgroup_types=["original"],
|
2900 |
+
comparison_subgroup_types=["paraphrase"],
|
2901 |
+
)
|
2902 |
+
),
|
2903 |
+
True,
|
2904 |
+
],
|
2905 |
+
}
|
2906 |
+
}
|