Elron committed on
Commit 18db0da
Parent: 7da4ddb

Upload metrics.py with huggingface_hub

Files changed (1): metrics.py (+76 -13)
metrics.py CHANGED
@@ -16,7 +16,7 @@ from scipy.stats import bootstrap
 from scipy.stats._warnings_errors import DegenerateDataWarning
 
 from .artifact import Artifact
-from .dataclass import InternalField, OptionalField
+from .dataclass import AbstractField, InternalField, OptionalField
 from .logging_utils import get_logger
 from .metric_utils import InstanceInput, MetricRequest, MetricResponse
 from .operator import (
@@ -58,6 +58,16 @@ def nan_mean(x):
         return np.nanmean(x)
 
 
+def nan_max(x):
+    with warnings.catch_warnings():
+        # the final score should be the max of the instance scores, ignoring NaN, hence nanmax;
+        # but if the values are NaN for ALL instances, nanmax throws a
+        # RuntimeWarning that it is computing the max of an empty slice (with no non-NaNs).
+        # this is the desired behavior, but we want to avoid the warning here
+        warnings.simplefilter("ignore", category=RuntimeWarning)
+        return np.nanmax(x)
+
+
 class UpdateStream(StreamInstanceOperator):
     update: dict
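For reference, a minimal standalone sketch of how the new nan_max helper behaves (plain numpy, outside the diff; the sample values are made up):

import warnings

import numpy as np


def nan_max(x):
    with warnings.catch_warnings():
        # np.nanmax warns when every value is NaN; here the NaN result is intended,
        # so the RuntimeWarning is suppressed, mirroring the helper above.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        return np.nanmax(x)


print(nan_max([0.2, float("nan"), 0.9]))      # 0.9 -- NaN entries are ignored
print(nan_max([float("nan"), float("nan")]))  # nan, without a RuntimeWarning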
 
@@ -69,11 +79,7 @@ class UpdateStream(StreamInstanceOperator):
 
 
 class Metric(Artifact):
-    @property
-    @abstractmethod
-    def main_score(self):
-        pass
-
+    main_score: str = AbstractField()
     # Override 'prediction_type' with the expected type of predictions
     # and references. Example: "List[str]", "List[Dict]"", "string".
     # If left with default None, a warning will be displayed.
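The abstract main_score property is replaced by a declarative AbstractField. Its real implementation lives in unitxt's .dataclass module; the snippet below is only a hypothetical stand-in to illustrate the pattern of a class-level field that concrete subclasses must override:

class _AbstractFieldSketch:
    """Hypothetical stand-in, not unitxt's AbstractField."""

    def __set_name__(self, owner, name):
        self.name = name

    def __get__(self, obj, objtype=None):
        raise NotImplementedError(f"{objtype.__name__} must define '{self.name}'")


class MetricSketch:
    main_score: str = _AbstractFieldSketch()


class AccuracySketch(MetricSketch):
    main_score = "accuracy"  # concrete subclasses simply assign a value


print(AccuracySketch().main_score)  # "accuracy"
# MetricSketch().main_score would raise NotImplementedError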
@@ -229,6 +235,18 @@ class MetricWithConfidenceInterval(Metric):
             [instance["score"]["instance"][score_name] for instance in instances]
         )
 
+    @staticmethod
+    def max_item_scores(instances: List[dict], score_name: str):
+        """Calculate max of a set of instance scores (given by score_name), omitting NaN values.
+
+        Args:
+            instances: list of dicts of each instance's instance scores.
+            score_name: score field name to compute the max for.
+        """
+        return nan_max(
+            [instance["score"]["instance"][score_name] for instance in instances]
+        )
+
     @staticmethod
     def _all_instance_scores_equal(instances, score_name):
         instance_scores = [
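The list comprehension shows the record shape max_item_scores expects; a small illustration with made-up values (standalone, not unitxt's code path):

import numpy as np

instances = [
    {"score": {"instance": {"accuracy": 0.0}}},
    {"score": {"instance": {"accuracy": 1.0}}},
    {"score": {"instance": {"accuracy": float("nan")}}},  # e.g. an undefined score
]

scores = [inst["score"]["instance"]["accuracy"] for inst in instances]
print(np.nanmax(scores))  # 1.0 -- the NaN entry is skipped, as in nan_max()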
@@ -625,13 +643,10 @@ class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
     # if subgroup_column is not None, a column by the specified name will be required in task_data
     subgroup_column = None
     implemented_reductions: List[str] = field(
-        default_factory=lambda: ["mean", "group_mean"]
+        default_factory=lambda: ["mean", "group_mean", "max"]
     )
 
-    @property
-    @abstractmethod
-    def reduction_map(self) -> dict:
-        pass
+    reduction_map: Dict[str, List[str]] = AbstractField()
 
     def _validate_group_mean_reduction(self, instances: List[dict]):
         """Ensure that group_mean reduction_map is properly formatted.
@@ -739,12 +754,19 @@ class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
 
             field_name_full_prefix = ""
             # used for passing to the bootstrapping, depends on whether the groups are fixed or not
-            aggregation_function = self.average_item_scores
+            aggregation_function = None
             if reduction_type == "mean":
+                aggregation_function = self.average_item_scores
+                reduction_fields = list(set(reduction_params))
+                # no group reduction, so resample instances individually
+                scores_to_resample = instances
+            elif reduction_type == "max":
+                aggregation_function = self.max_item_scores
                 reduction_fields = list(set(reduction_params))
                 # no group reduction, so resample instances individually
                 scores_to_resample = instances
             elif reduction_type == "group_mean":
+                aggregation_function = self.average_item_scores
                 self._validate_group_mean_reduction(instances=instances)
                 reduction_fields = (
                     [self.main_score]
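The aggregation_function chosen here is later handed to the confidence-interval bootstrapping. As a rough, standalone illustration of resampling a max statistic with the already-imported scipy.stats.bootstrap (the parameters and scores are made up, and this is not unitxt's actual CI code path):

import numpy as np
from scipy.stats import bootstrap

instance_scores = np.array([0.2, 0.4, 0.9, 0.5, 0.7])  # hypothetical per-instance scores

result = bootstrap(
    (instance_scores,),      # one-sample data, passed as a tuple
    np.nanmax,               # statistic: max over the resampled instances
    n_resamples=1000,
    confidence_level=0.95,
    method="percentile",
    random_state=42,
)
print(result.confidence_interval)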
@@ -941,6 +963,12 @@ class Accuracy(InstanceMetric):
         return result
 
 
+class MaxAccuracy(Accuracy):
+    """Calculate the maximal accuracy over all instances as the global score."""
+
+    reduction_map = {"max": ["accuracy"]}
+
+
 class UnsortedListExactMatch(InstanceMetric):
     reduction_map = {"mean": ["unsorted_list_exact_match"]}
     main_score = "unsorted_list_exact_match"
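A standalone sketch of the semantics (illustrative only, simplified to a single reference per instance): per-instance accuracy is an exact match, and the "max" reduction makes the global score 1.0 as soon as any single prediction is correct:

predictions = ["cat", "dog", "bird"]
references = ["cat", "cow", "fish"]

# per-instance accuracy as exact match against one reference (simplified)
instance_accuracy = [float(p == r) for p, r in zip(predictions, references)]
print(instance_accuracy)       # [1.0, 0.0, 0.0]
print(max(instance_accuracy))  # 1.0 -- the "max" reduction of the accuracy scores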
@@ -988,7 +1016,15 @@ class MetricPipeline(MultiStreamOperator, Metric):
         self.metric.disable_confidence_interval_calculation()
 
     def verify(self):
-        assert self.main_score is not None, "main_score is not set"
+        assert (
+            self.metric is not None
+        ), f"'metric' is not set in {self.get_metric_name()}"
+        assert (
+            self.main_score is not None
+        ), f"'main_score' is not set in {self.get_metric_name()}"
+        assert isinstance(
+            self.metric, Metric
+        ), f"'metric' is not set to a Metric class in {self.get_metric_name()} (type{self.metric})"
 
     def prepare(self):
         super().prepare()
@@ -3266,3 +3302,30 @@ class BinaryMaxAccuracy(GlobalMetric):
                 best_thr = thr
 
         return {self.main_score: best_acc, "best_thr_max_acc": best_thr}
+
+
+KO_ERROR_MESSAGE = """
+
+Additional dependencies required. To install them, run:
+`pip install "sacrebleu[ko]"`.
+
+For MacOS: if an error about 'mecab-config' shows up during installation, run:
+
+`brew install mecab`
+`pip install "sacrebleu[ko]"`
+
+"""
+
+
+class NormalizedSacrebleu(HuggingfaceMetric):
+    hf_metric_name = "sacrebleu"
+    hf_main_score = "score"
+    prediction_type = "str"
+    main_score = "sacrebleu"
+    scale = 100.0
+    scaled_fields = ["sacrebleu", "precisions"]
+    hf_additional_input_fields_pass_one_value = ["tokenize"]
+    _requirements_list = {
+        "mecab_ko": KO_ERROR_MESSAGE,
+        "mecab_ko_dic": KO_ERROR_MESSAGE,
+    }
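NormalizedSacrebleu wraps the Hugging Face sacrebleu metric and, via scale and scaled_fields, rescales the 0-100 BLEU score and n-gram precisions into the 0-1 range; the Korean tokenizer (passed through the "tokenize" input) needs the sacrebleu[ko] extras, hence KO_ERROR_MESSAGE. A hedged sketch of that normalization using the evaluate library directly (the field handling here is an assumption for illustration):

import evaluate

sacrebleu = evaluate.load("sacrebleu")
raw = sacrebleu.compute(
    predictions=["the cat sat on the mat"],
    references=[["the cat sat on the mat"]],
    # for Korean, tokenize="ko-mecab" would be passed here, which needs `pip install "sacrebleu[ko]"`
)

scale = 100.0
normalized = {
    "sacrebleu": raw["score"] / scale,
    "precisions": [p / scale for p in raw["precisions"]],
}
print(normalized)  # scores rescaled from 0-100 to 0-1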