Elron committed on
Commit
dc6018c
·
verified ·
1 Parent(s): 803d9a3

Upload metrics.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. metrics.py +356 -49
metrics.py CHANGED
@@ -1,4 +1,3 @@
1
- import logging
2
  import re
3
  import string
4
  import uuid
@@ -14,6 +13,7 @@ from scipy.stats import bootstrap
14
 
15
  from .artifact import Artifact
16
  from .dataclass import InternalField, OptionalField
 
17
  from .operator import (
18
  MultiStreamOperator,
19
  SingleStreamOperator,
@@ -23,7 +23,9 @@ from .operator import (
23
  from .operators import CopyFields
24
  from .random_utils import get_seed
25
  from .stream import MultiStream, Stream
 
26
 
 
27
  # The default number of resamples used to estimate the confidence intervals
28
  # global and instances metrics. Use None to disable confidence interval computation by default.
29
  _N_RESAMPLES_DEFAULT_FOR_INSTANCE_METRICS = 1000
@@ -61,6 +63,7 @@ class MetricWithConfidenceInterval(Metric):
61
  # Use None to disable confidence interval computation.
62
  n_resamples: int = None
63
  confidence_level: float = 0.95
 
64
 
65
  @staticmethod
66
  def new_random_generator():
@@ -79,7 +82,7 @@ class MetricWithConfidenceInterval(Metric):
79
  and num_predictions > 1
80
  )
81
 
82
- def score_based_confidence_interval(self, score_names: List[str], instances):
83
  """Compute confidence intervals based on existing scores, already computed on the input instances.
84
 
85
  score_names: List[str]
@@ -94,6 +97,10 @@ class MetricWithConfidenceInterval(Metric):
94
  if not self._can_compute_confidence_intervals(num_predictions=len(instances)):
95
  return result
96
 
 
 
 
 
97
  for score_name in score_names:
98
  scores = [
99
  instance["score"]["instance"][score_name] for instance in instances
@@ -131,7 +138,7 @@ class MetricWithConfidenceInterval(Metric):
131
  except Exception as e:
132
  # this happens in edge cases, for example, when the sampling creates a
133
  # sample where all strings are empty and this fails bleu.
134
- logging.info(f"Warning in {self.__class__.__name__}", e)
135
  return np.nan
136
 
137
  scores = numpy.apply_along_axis(
@@ -341,7 +348,7 @@ class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
341
  global_score["score_name"] = self.main_score
342
 
343
  confidence_interval = self.score_based_confidence_interval(
344
- score_names=[self.main_score], instances=instances
345
  )
346
  global_score.update(confidence_interval)
347
 
@@ -411,7 +418,7 @@ class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
411
  global_score["score_name"] = self.main_score
412
 
413
  confidence_interval = self.score_based_confidence_interval(
414
- score_names=[self.main_score], instances=instances
415
  )
416
  global_score.update(confidence_interval)
417
 
@@ -473,6 +480,23 @@ class Accuracy(InstanceMetric):
473
  return result
474
 
475
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
476
  class MetricPipeline(MultiStreamOperator, Metric):
477
  main_score: str = None
478
  preprocess_steps: Optional[List[StreamingOperator]] = field(default_factory=list)
@@ -512,9 +536,29 @@ class HuggingfaceMetric(GlobalMetric):
512
 
513
  scale: float = 1.0 # optional scaling of main results
514
  scaled_fields: list = None
 
515
  hf_compute_args: Dict[str, Any] = OptionalField(default_factory=dict)
 
 
 
 
 
 
 
516
  experiment_id: str = OptionalField(default_factory=lambda: str(uuid.uuid4()))
517
 
 
 
 
 
 
 
 
 
 
 
 
 
518
  def prepare(self):
519
  super().prepare()
520
  self.metric = evaluate.load(
@@ -527,8 +571,36 @@ class HuggingfaceMetric(GlobalMetric):
527
  predictions: List[Any],
528
  additional_inputs: List[Dict],
529
  ) -> dict:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
530
  result = self.metric.compute(
531
- predictions=predictions, references=references, **self.hf_compute_args
 
 
 
532
  )
533
  if self.hf_main_score:
534
  result[self.main_score] = result[self.hf_main_score]
@@ -559,6 +631,7 @@ class HuggingfaceBulkMetric(BulkInstanceMetric):
559
 
560
  hf_metric_fields: List[str]
561
  hf_compute_args: dict = {}
 
562
 
563
  def prepare(self):
564
  super().prepare()
@@ -570,8 +643,23 @@ class HuggingfaceBulkMetric(BulkInstanceMetric):
570
  predictions: List[str],
571
  additional_inputs: List[Any],
572
  ) -> List[Dict[str, Any]]:
 
 
 
 
 
 
 
 
 
 
 
 
573
  scores = self.metric.compute(
574
- predictions=predictions, references=references, **self.hf_compute_args
 
 
 
575
  )
576
 
577
  # convert dict of lists to a list of dicts
@@ -656,10 +744,11 @@ class F1MultiLabel(GlobalMetric):
656
  main_score = "f1_macro"
657
  average = None # Report per class then aggregate by mean
658
  classes_to_ignore = ["none"]
 
659
 
660
  def prepare(self):
661
  super().prepare()
662
- self._metric = evaluate.load("f1", "multilabel")
663
 
664
  def add_str_to_id(self, str):
665
  if str not in self.str_to_id:
@@ -683,22 +772,10 @@ class F1MultiLabel(GlobalMetric):
683
  ) -> dict:
684
  self.str_to_id = {}
685
  self.id_to_str = {}
686
- assert all(
687
- len(reference) == 1 for reference in references
688
- ), "Only a single reference per prediction is allowed in F1 multi label metric"
689
 
 
690
  references = [reference[0] for reference in references]
691
 
692
- for reference in references:
693
- assert isinstance(
694
- references, list
695
- ), f"Each reference is expected to list of strings in F1 multi label metric. Received reference: {reference}"
696
-
697
- for prediction in predictions:
698
- assert isinstance(
699
- prediction, list
700
- ), f"Each prediction is expected to list of strings in F1 multi label metric. Received prediction: {prediction}"
701
-
702
  labels = [
703
  lbl
704
  for lbl in {label for reference in references for label in reference}
@@ -732,19 +809,60 @@ class F1MultiLabel(GlobalMetric):
732
  average=self.average,
733
  labels=labels_param,
734
  )
735
- if isinstance(result["f1"], numpy.ndarray):
736
  from statistics import mean
737
 
738
- assert len(result["f1"]) == len(
739
- labels
740
- ), f'F1 result ({result["f1"]}) has more entries than labels ({labels})'
741
- final_result = {self.main_score: mean(result["f1"])}
742
  for i, label in enumerate(labels):
743
- final_result["f1_" + label] = result["f1"][i]
744
  else:
745
- final_result = {self.main_score: result["f1"]}
746
  return final_result
747
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
748
 
749
  class F1MicroMultiLabel(F1MultiLabel):
750
  main_score = "f1_micro"
@@ -868,27 +986,36 @@ class MatthewsCorrelation(HuggingfaceMetric):
868
 
869
  class CustomF1(GlobalMetric):
870
  main_score = "f1_micro"
871
- classes = None
872
  zero_division = 0.0
873
 
874
  @abstractmethod
875
- def get_element_group(self, element):
876
  pass
877
 
878
  @abstractmethod
879
- def get_element_representation(self, element):
880
  pass
881
 
882
- def group_elements(self, elements_list):
 
 
 
 
 
883
  return {
884
  k: Counter(
885
  [
886
- self.get_element_representation(value)
887
  for value in elements_list
888
- if self.get_element_group(value) == k
889
  ]
890
  )
891
- for k in {self.get_element_group(e) for e in elements_list}
 
 
 
 
892
  }
893
 
894
  def calculate_groups_ratio(self, actual_group, total_group):
@@ -910,30 +1037,46 @@ class CustomF1(GlobalMetric):
910
  except ZeroDivisionError:
911
  return self.zero_division
912
 
 
 
 
 
 
 
 
 
 
913
  def compute(
914
  self,
915
- references: List[Any],
916
  predictions: List[Any],
917
  additional_inputs: List[Dict],
918
  ) -> dict:
919
  # in case reference are List[List[List[Any]]] and predictions are List[List[Any]]:
920
- if isinstance(references[0], list) and isinstance(references[0][0], list):
 
 
 
 
921
  references = [element[0] for element in references]
922
 
923
  assert len(references) == len(predictions), (
924
  f"references size ({len(references)})"
925
  f" doesn't mach predictions sise ({len(references)})."
926
  )
927
- if self.classes is None:
928
- classes = {
929
- self.get_element_group(e) for sublist in references for e in sublist
930
- }
931
  else:
932
- classes = self.classes
933
  groups_statistics = {}
934
- for references_batch, predictions_batch in zip(references, predictions):
935
- grouped_references = self.group_elements(references_batch)
936
- grouped_predictions = self.group_elements(predictions_batch)
 
 
 
 
937
  all_groups = set(grouped_references.keys()).union(
938
  grouped_predictions.keys()
939
  )
@@ -976,7 +1119,7 @@ class CustomF1(GlobalMetric):
976
  rn_total + rn,
977
  rd_total + rd,
978
  )
979
- if group in classes:
980
  f1_result[f"f1_{group}"] = self.f1(pn, pd, rn, rd)
981
  recall_result[f"recall_{group}"] = self.recall(pn, pd, rn, rd)
982
  precision_result[f"precision_{group}"] = self.precision(pn, pd, rn, rd)
@@ -995,7 +1138,7 @@ class CustomF1(GlobalMetric):
995
  except ZeroDivisionError:
996
  result["f1_macro"] = self.zero_division
997
  result["recall_macro"] = self.zero_division
998
- result["micro_macro"] = self.zero_division
999
 
1000
  amount_of_predictions = pd_total
1001
  if amount_of_predictions == 0:
@@ -1013,10 +1156,10 @@ class CustomF1(GlobalMetric):
1013
 
1014
 
1015
  class NER(CustomF1):
1016
- def get_element_group(self, element):
1017
  return element[1]
1018
 
1019
- def get_element_representation(self, element):
1020
  return str(element)
1021
 
1022
 
@@ -1042,6 +1185,7 @@ def normalize_answer(s):
1042
  class TokenOverlap(InstanceMetric):
1043
  reduction_map = {"mean": ["f1", "precision", "recall"]}
1044
  main_score = "f1"
 
1045
 
1046
  def compute(
1047
  self, references: List[Any], prediction: Any, additional_inputs: List[Dict]
@@ -1075,6 +1219,7 @@ class BertScore(HuggingfaceBulkMetric):
1075
  main_score = "f1"
1076
  reduction_map = {"mean": ["f1", "precision", "recall"]}
1077
  hf_metric_fields = ["f1", "precision", "recall"]
 
1078
  model_name: str
1079
 
1080
  def prepare(self):
@@ -1223,3 +1368,165 @@ class NDCG(GlobalMetric):
1223
  ]
1224
  scores.append(self.eval([q_references], [q_predictions]))
1225
  return {self.main_score: mean(scores) if len(scores) > 0 else np.nan}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import re
2
  import string
3
  import uuid
 
13
 
14
  from .artifact import Artifact
15
  from .dataclass import InternalField, OptionalField
16
+ from .logging_utils import get_logger
17
  from .operator import (
18
  MultiStreamOperator,
19
  SingleStreamOperator,
 
23
  from .operators import CopyFields
24
  from .random_utils import get_seed
25
  from .stream import MultiStream, Stream
26
+ from .type_utils import isoftype
27
 
28
+ logger = get_logger()
29
  # The default number of resamples used to estimate the confidence intervals
30
  # global and instances metrics. Use None to disable confidence interval computation by default.
31
  _N_RESAMPLES_DEFAULT_FOR_INSTANCE_METRICS = 1000
 
63
  # Use None to disable confidence interval computation.
64
  n_resamples: int = None
65
  confidence_level: float = 0.95
66
+ ci_scores: List[str] = None
67
 
68
  @staticmethod
69
  def new_random_generator():
 
82
  and num_predictions > 1
83
  )
84
 
85
+ def score_based_confidence_interval(self, instances):
86
  """Compute confidence intervals based on existing scores, already computed on the input instances.
87
 
88
  score_names: List[str]
 
97
  if not self._can_compute_confidence_intervals(num_predictions=len(instances)):
98
  return result
99
 
100
+ score_names = (
101
+ self.ci_scores if self.ci_scores is not None else [self.main_score]
102
+ )
103
+
104
  for score_name in score_names:
105
  scores = [
106
  instance["score"]["instance"][score_name] for instance in instances
 
138
  except Exception as e:
139
  # this happens in edge cases, for example, when the sampling creates a
140
  # sample where all strings are empty and this fails bleu.
141
+ logger.info(f"Warning in {self.__class__.__name__}", e)
142
  return np.nan
143
 
144
  scores = numpy.apply_along_axis(
 
348
  global_score["score_name"] = self.main_score
349
 
350
  confidence_interval = self.score_based_confidence_interval(
351
+ instances=instances
352
  )
353
  global_score.update(confidence_interval)
354
 
 
418
  global_score["score_name"] = self.main_score
419
 
420
  confidence_interval = self.score_based_confidence_interval(
421
+ instances=instances
422
  )
423
  global_score.update(confidence_interval)
424
 
 
480
  return result
481
 
482
 
483
class StringContainment(InstanceMetric):
    """Instance metric that is 1.0 when any reference occurs inside the prediction.

    Every reference is converted with ``str`` and looked up in ``prediction``
    using the ``in`` operator; the score is 0.0 when none of them is found.
    """

    reduction_map = {"mean": ["string_containment"]}
    main_score = "string_containment"

    def compute(
        self, references: List[Any], prediction: Any, additional_inputs: List[Dict]
    ) -> dict:
        """Return the containment score under the main score name and under "score"."""
        found = False
        for reference in references:
            if str(reference) in prediction:
                # One contained reference is enough; stop scanning.
                found = True
                break
        score = float(found)
        return {
            self.main_score: score,
            "score": score,
            "score_name": self.main_score,
        }
498
+
499
+
500
  class MetricPipeline(MultiStreamOperator, Metric):
501
  main_score: str = None
502
  preprocess_steps: Optional[List[StreamingOperator]] = field(default_factory=list)
 
536
 
537
  scale: float = 1.0 # optional scaling of main results
538
  scaled_fields: list = None
539
+ # These are fixed arguments passed to the compute method
540
  hf_compute_args: Dict[str, Any] = OptionalField(default_factory=dict)
541
+ # These are additional input fields passed to HF compute method (a list with one value per instance)
542
+ hf_additional_input_fields: List = OptionalField(default_factory=list)
543
+ # These are additional input fields that are passed as one value
544
+ hf_additional_input_fields_pass_one_value: List = OptionalField(
545
+ default_factory=list
546
+ )
547
+
548
  experiment_id: str = OptionalField(default_factory=lambda: str(uuid.uuid4()))
549
 
550
def verify(self):
    """Check the additional-input field settings, then run the parent verification.

    Both ``hf_additional_input_fields`` and
    ``hf_additional_input_fields_pass_one_value`` must be either ``None`` or a
    list of strings; an AssertionError naming the offending argument is raised
    otherwise.
    """
    per_instance_fields = self.hf_additional_input_fields
    assert per_instance_fields is None or isoftype(
        per_instance_fields, List[str]
    ), f"Argument hf_additional_input_fields should be either None or List[str]. It is now: {self.hf_additional_input_fields}."

    single_value_fields = self.hf_additional_input_fields_pass_one_value
    assert single_value_fields is None or isoftype(
        single_value_fields, List[str]
    ), f"Argument hf_additional_input_fields_pass_one_value should be either None or List[str]. It is now: {self.hf_additional_input_fields_pass_one_value}."

    return super().verify()
561
+
562
  def prepare(self):
563
  super().prepare()
564
  self.metric = evaluate.load(
 
571
  predictions: List[Any],
572
  additional_inputs: List[Dict],
573
  ) -> dict:
574
+ passed_additional_inputs = {}
575
+ for additional_input_field in self.hf_additional_input_fields:
576
+ assert (
577
+ additional_input_field in additional_inputs[0]
578
+ ), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in additional inputs: {additional_inputs[0]}"
579
+ passed_additional_inputs[additional_input_field] = [
580
+ additional_input[additional_input_field]
581
+ for additional_input in additional_inputs
582
+ ]
583
+ for additional_input_field in self.hf_additional_input_fields_pass_one_value:
584
+ assert (
585
+ additional_input_field in additional_inputs[0]
586
+ ), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in additional inputs: {additional_inputs[0]}"
587
+
588
+ values = {
589
+ additional_input[additional_input_field]
590
+ for additional_input in additional_inputs
591
+ }
592
+ assert (
593
+ len(values) == 1
594
+ ), f"Values of '{additional_input_field}' field required by {__class__.__name__} should all be the same, but have multiple values {values}"
595
+
596
+ passed_additional_inputs[additional_input_field] = next(iter(values))
597
+
598
+ # add check that all required fields in self.metrics are in passed_additional_inputs print(passed_additional_inputs)
599
  result = self.metric.compute(
600
+ predictions=predictions,
601
+ references=references,
602
+ **passed_additional_inputs,
603
+ **self.hf_compute_args,
604
  )
605
  if self.hf_main_score:
606
  result[self.main_score] = result[self.hf_main_score]
 
631
 
632
  hf_metric_fields: List[str]
633
  hf_compute_args: dict = {}
634
+ hf_additional_input_fields: List = OptionalField(default_factory=list)
635
 
636
  def prepare(self):
637
  super().prepare()
 
643
  predictions: List[str],
644
  additional_inputs: List[Any],
645
  ) -> List[Dict[str, Any]]:
646
+ passed_additional_inputs = {}
647
+ passed_additional_inputs = {}
648
+ for additional_input_field in self.hf_additional_input_fields:
649
+ assert (
650
+ additional_input_field in additional_inputs[0]
651
+ ), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in additional inputs: {additional_inputs[0]}"
652
+ passed_additional_inputs[additional_input_field] = [
653
+ additional_input[additional_input_field]
654
+ for additional_input in additional_inputs
655
+ ]
656
+ # add check that all required fields in self.metrics are in passed_additional_inputs
657
+
658
  scores = self.metric.compute(
659
+ predictions=predictions,
660
+ references=references,
661
+ **passed_additional_inputs,
662
+ **self.hf_compute_args,
663
  )
664
 
665
  # convert dict of lists to a list of dicts
 
744
  main_score = "f1_macro"
745
  average = None # Report per class then aggregate by mean
746
  classes_to_ignore = ["none"]
747
+ metric = "f1"
748
 
749
  def prepare(self):
750
  super().prepare()
751
+ self._metric = evaluate.load(self.metric, "multilabel")
752
 
753
  def add_str_to_id(self, str):
754
  if str not in self.str_to_id:
 
772
  ) -> dict:
773
  self.str_to_id = {}
774
  self.id_to_str = {}
 
 
 
775
 
776
+ self._validate_references_and_prediction(references, predictions)
777
  references = [reference[0] for reference in references]
778
 
 
 
 
 
 
 
 
 
 
 
779
  labels = [
780
  lbl
781
  for lbl in {label for reference in references for label in reference}
 
809
  average=self.average,
810
  labels=labels_param,
811
  )
812
+ if isinstance(result[self.metric], numpy.ndarray):
813
  from statistics import mean
814
 
815
+ assert (
816
+ len(result[self.metric]) == len(labels)
817
+ ), f"F1 result ({result[self.metric]}) has more entries than labels ({labels})"
818
+ final_result = {self.main_score: mean(result[self.metric])}
819
  for i, label in enumerate(labels):
820
+ final_result[self.metric + "_" + label] = result[self.metric][i]
821
  else:
822
+ final_result = {self.main_score: result[self.metric]}
823
  return final_result
824
 
825
+ def _validate_references_and_prediction(self, references, predictions):
826
+ for reference in references:
827
+ if not len(reference) == 1:
828
+ raise ValueError(
829
+ f"Only a single reference per prediction is allowed in F1 multi label metric. Received reference: {reference}"
830
+ )
831
+ if not isoftype(reference[0], List[str]):
832
+ raise ValueError(
833
+ f"Each reference is expected to be a list of strings in F1 multi label metric. Received reference: '{reference[0]}'"
834
+ )
835
+
836
+ for prediction in predictions:
837
+ if not isoftype(prediction, List[str]):
838
+ raise ValueError(
839
+ f"Each prediction is expected to be a list of strings in F1 multi label metric. Received prediction: '{prediction}'"
840
+ )
841
+
842
+
843
class PrecisionMacroMultiLabel(F1MultiLabel):
    """Multi-label precision with macro averaging, reported as "precision_macro"."""

    # Name under which the final score is reported.
    main_score = "precision_macro"
    # Name of the metric the parent class loads and reads results from.
    metric = "precision"
    # Averaging mode handed to the metric computation.
    average = "macro"
847
+
848
+
849
class PrecisionMicroMultiLabel(F1MultiLabel):
    """Multi-label precision with micro averaging, reported as "precision_micro"."""

    # Name under which the final score is reported.
    main_score = "precision_micro"
    # Name of the metric the parent class loads and reads results from.
    metric = "precision"
    # Averaging mode handed to the metric computation.
    average = "micro"
853
+
854
+
855
class RecallMacroMultiLabel(F1MultiLabel):
    """Multi-label recall with macro averaging, reported as "recall_macro"."""

    # Name under which the final score is reported.
    main_score = "recall_macro"
    # Name of the metric the parent class loads and reads results from.
    metric = "recall"
    # Averaging mode handed to the metric computation.
    average = "macro"
859
+
860
+
861
class RecallMicroMultiLabel(F1MultiLabel):
    """Multi-label recall with micro averaging, reported as "recall_micro"."""

    # Name under which the final score is reported.
    main_score = "recall_micro"
    # Name of the metric the parent class loads and reads results from.
    metric = "recall"
    # Averaging mode handed to the metric computation.
    average = "micro"
865
+
866
 
867
  class F1MicroMultiLabel(F1MultiLabel):
868
  main_score = "f1_micro"
 
986
 
987
  class CustomF1(GlobalMetric):
988
  main_score = "f1_micro"
989
+ groups = None
990
  zero_division = 0.0
991
 
992
  @abstractmethod
993
def get_element_group(self, element, additional_input):
    """Return the group that ``element`` is scored under.

    Abstract hook with no behavior of its own; concrete metrics override it.
    ``additional_input`` is the additional-input entry paired with the instance
    the element came from.
    """
995
 
996
  @abstractmethod
997
def get_element_representation(self, element, additional_input):
    """Return the value used to count ``element`` within its group.

    Abstract hook with no behavior of its own; concrete metrics override it.
    ``additional_input`` is the additional-input entry paired with the instance
    the element came from.
    """
999
 
1000
def should_ignore_element(self, element, additional_input):
    """Tell whether ``element`` should be left out of the grouping.

    The default keeps every element; subclasses may override to filter some out.
    """
    return False
1002
+
1003
def group_elements(self, elements_list, additional_input):
    """Bucket elements by group and count their representations.

    Args:
        elements_list: the elements of one instance; a non-list value is
            treated as a single-element list.
        additional_input: the additional-input entry of the same instance,
            forwarded unchanged to the per-element hooks.

    Returns:
        A dict mapping each group to a ``Counter`` of the representations of
        the elements in that group.

    Elements for which ``should_ignore_element`` is true are skipped entirely.
    Previously they were excluded only from the set of group keys, so they were
    still counted when they shared a group with a kept element. The grouping is
    also done in one pass instead of rescanning the list once per group.
    """
    if not isinstance(elements_list, list):
        elements_list = [elements_list]

    grouped = {}
    for element in elements_list:
        if self.should_ignore_element(element, additional_input):
            continue
        group = self.get_element_group(element, additional_input)
        representation = self.get_element_representation(element, additional_input)
        grouped.setdefault(group, Counter())[representation] += 1
    return grouped
1020
 
1021
  def calculate_groups_ratio(self, actual_group, total_group):
 
1037
  except ZeroDivisionError:
1038
  return self.zero_division
1039
 
1040
def get_groups(self, elements, additional_inputs):
    """Collect the set of groups seen across all instances.

    ``elements`` and ``additional_inputs`` are walked in lockstep; every element
    that ``should_ignore_element`` does not reject contributes its group.
    """
    return {
        self.get_element_group(item, extra)
        for items, extra in zip(elements, additional_inputs)
        for item in items
        if not self.should_ignore_element(item, extra)
    }
1048
+
1049
  def compute(
1050
  self,
1051
+ references: List[List[Any]],
1052
  predictions: List[Any],
1053
  additional_inputs: List[Dict],
1054
  ) -> dict:
1055
  # in case reference are List[List[List[Any]]] and predictions are List[List[Any]]:
1056
+ if (
1057
+ isinstance(references[0], list)
1058
+ and len(references[0]) > 0
1059
+ and isinstance(references[0][0], list)
1060
+ ):
1061
  references = [element[0] for element in references]
1062
 
1063
  assert len(references) == len(predictions), (
1064
  f"references size ({len(references)})"
1065
  f" doesn't mach predictions sise ({len(references)})."
1066
  )
1067
+
1068
+ if self.groups is None:
1069
+ groups = self.get_groups(references, additional_inputs)
 
1070
  else:
1071
+ groups = self.groups
1072
  groups_statistics = {}
1073
+ for references_batch, predictions_batch, additional_input in zip(
1074
+ references, predictions, additional_inputs
1075
+ ):
1076
+ grouped_references = self.group_elements(references_batch, additional_input)
1077
+ grouped_predictions = self.group_elements(
1078
+ predictions_batch, additional_input
1079
+ )
1080
  all_groups = set(grouped_references.keys()).union(
1081
  grouped_predictions.keys()
1082
  )
 
1119
  rn_total + rn,
1120
  rd_total + rd,
1121
  )
1122
+ if group in groups:
1123
  f1_result[f"f1_{group}"] = self.f1(pn, pd, rn, rd)
1124
  recall_result[f"recall_{group}"] = self.recall(pn, pd, rn, rd)
1125
  precision_result[f"precision_{group}"] = self.precision(pn, pd, rn, rd)
 
1138
  except ZeroDivisionError:
1139
  result["f1_macro"] = self.zero_division
1140
  result["recall_macro"] = self.zero_division
1141
+ result["precision_macro"] = self.zero_division
1142
 
1143
  amount_of_predictions = pd_total
1144
  if amount_of_predictions == 0:
 
1156
 
1157
 
1158
  class NER(CustomF1):
1159
+ def get_element_group(self, element, additional_input):
1160
  return element[1]
1161
 
1162
+ def get_element_representation(self, element, additional_input):
1163
  return str(element)
1164
 
1165
 
 
1185
  class TokenOverlap(InstanceMetric):
1186
  reduction_map = {"mean": ["f1", "precision", "recall"]}
1187
  main_score = "f1"
1188
+ ci_scores = ["f1", "precision", "recall"]
1189
 
1190
  def compute(
1191
  self, references: List[Any], prediction: Any, additional_inputs: List[Dict]
 
1219
  main_score = "f1"
1220
  reduction_map = {"mean": ["f1", "precision", "recall"]}
1221
  hf_metric_fields = ["f1", "precision", "recall"]
1222
+ ci_scores = ["f1", "precision", "recall"]
1223
  model_name: str
1224
 
1225
  def prepare(self):
 
1368
  ]
1369
  scores.append(self.eval([q_references], [q_predictions]))
1370
  return {self.main_score: mean(scores) if len(scores) > 0 else np.nan}
1371
+
1372
+
1373
class RetrievalMetric(InstanceMetric):
    """Base class for metrics over a ranked list of retrieved ids.

    ``compute`` derives per-position statistics from the prediction and the
    de-duplicated references, then hands them to ``_compute``, which each
    subclass implements to produce its scores.
    """

    def compute(
        self, references: List[Any], prediction: Any, additional_inputs: Dict
    ) -> dict:
        """Build the per-position statistics and delegate to ``_compute``.

        All dictionaries are keyed by the 1-based position ``k`` in the
        prediction list:

        * ``relevance_at_k``: 1 if the id at position k is a reference id, else 0.
          With hits at positions 2, 4 and 5: {1: 0, 2: 1, 3: 0, 4: 1, 5: 1, ...}.
        * ``relevance_sum_at_k``: number of hits among the first k positions.
        * ``precision_at_k``: hits among the first k positions divided by k.
        * ``recall_at_k``: hits among the first k positions divided by the number
          of distinct references (0 when there are no references).
        * ``match_at_k``: 1.0 once at least one hit was seen up to k, else 0.0.

        ``rank`` is the 1-based position of the first hit, or 0 when there is none.
        """
        retrieved_ids: List[Any] = prediction
        # dict.fromkeys drops duplicate references while keeping their order.
        gold_ids: List[Any] = list(dict.fromkeys(references))
        gold_count = len(gold_ids)

        relevance_at_k = {}
        relevance_sum_at_k = {}
        precision_at_k = {}
        recall_at_k = {}
        match_at_k = {}
        rank = 0
        hits_so_far = 0

        for position, doc_id in enumerate(retrieved_ids, start=1):
            is_hit = 1 if doc_id in gold_ids else 0
            hits_so_far += is_hit

            relevance_at_k[position] = is_hit
            relevance_sum_at_k[position] = hits_so_far
            precision_at_k[position] = hits_so_far / position
            recall_at_k[position] = hits_so_far / gold_count if gold_count > 0 else 0
            match_at_k[position] = 1.0 if hits_so_far > 0 else 0.0

            # Remember only the first position at which a reference id shows up.
            if is_hit == 1 and rank == 0:
                rank = position

        return self._compute(
            relevance_at_k,
            relevance_sum_at_k,
            precision_at_k,
            recall_at_k,
            match_at_k,
            rank,
        )

    @abstractmethod
    def _compute(
        self,
        relevance_at_k,
        relevance_sum_at_k,
        precision_at_k,
        recall_at_k,
        match_at_k,
        rank,
    ) -> dict:
        """Turn the per-position statistics into the metric's score dict."""
1444
+
1445
+
1446
class MRR(RetrievalMetric):
    """Reciprocal rank of the first hit, averaged over instances as "mrr"."""

    reduction_map = {"mean": ["mrr"]}
    main_score = "mrr"

    def _compute(
        self,
        relevance_at_k,
        relevance_sum_at_k,
        precision_at_k,
        recall_at_k,
        match_at_k,
        rank,
    ) -> dict:
        """Return 1/rank for the first hit, or 0 when ``rank`` is 0 (no hit)."""
        if rank > 0:
            reciprocal_rank = 1 / rank
        else:
            reciprocal_rank = 0
        return {self.main_score: reciprocal_rank}
1460
+
1461
+
1462
class MAP(RetrievalMetric):
    """Average precision over the hit positions, averaged over instances as "map"."""

    reduction_map = {"mean": ["map"]}
    main_score = "map"

    def _compute(
        self,
        relevance_at_k,
        relevance_sum_at_k,
        precision_at_k,
        recall_at_k,
        match_at_k,
        rank,
    ) -> dict:
        """Return the mean of precision@k taken at the positions that are hits.

        The result is 0 when nothing was retrieved or no retrieved id is a hit.
        """
        hit_count = sum(relevance_at_k.values())
        if hit_count > 0:
            # Positions that are not hits have relevance 0 and add nothing.
            weighted_precision = sum(
                relevance * precision_at_k[k] for k, relevance in relevance_at_k.items()
            )
            average_precision = weighted_precision / hit_count
        else:
            average_precision = 0
        return {self.main_score: average_precision}
1482
+
1483
+
1484
class RetrievalAtK(RetrievalMetric):
    """Precision, recall and match at each cut-off listed in ``k_list``.

    Scores are named ``<measure>_at_<k>``. The main score is the match score at
    the first cut-off in ``k_list``, and confidence intervals are requested for
    every produced score.
    """

    k_list: List[int]
    main_score: str = None
    reduction_map: Dict[str, List[str]] = None

    def prepare(self):
        """Derive the main score, the CI score list and the reduction map from ``k_list``."""
        super().prepare()
        self.main_score = self.score_name("match", self.k_list[0])
        self.ci_scores = [
            self.score_name(measure, k)
            for measure in ["precision", "recall", "match"]
            for k in self.k_list
        ]
        self.reduction_map = {"mean": self.ci_scores}

    @staticmethod
    def score_name(measure: str, k: int):
        """Return the score key for ``measure`` at cut-off ``k``."""
        return f"{measure}_at_{k}"

    def _compute(
        self,
        relevance_at_k,
        relevance_sum_at_k,
        precision_at_k,
        recall_at_k,
        match_at_k,
        rank,
    ) -> dict:
        """Read each measure at every requested cut-off.

        A cut-off larger than the number of retrieved items is clamped to the
        last available position. When nothing was retrieved the per-position
        dictionaries are empty; every score is then reported as 0.0 instead of
        failing on ``max()`` of an empty sequence.
        """
        result = {}
        for measure_array, measure_name in [
            (precision_at_k, "precision"),
            (recall_at_k, "recall"),
            (match_at_k, "match"),
        ]:
            # None marks "no retrieved items at all".
            max_k = max(measure_array, default=None)
            for k in self.k_list:
                if max_k is None:
                    value = 0.0
                else:
                    value = measure_array[min(k, max_k)]
                result[self.score_name(measure_name, k)] = value
        return result
1522
+
1523
+
1524
class KPA(CustomF1):
    """CustomF1 variant keyed by the "keypoint" entry of the additional input.

    Both the group and the counted representation of an element are taken from
    ``additional_input["keypoint"]``; elements equal to "none" are ignored.
    """

    def get_element_group(self, element, additional_input):
        """Return the instance's keypoint as the group."""
        keypoint = additional_input["keypoint"]
        return keypoint

    def get_element_representation(self, element, additional_input):
        """Return the instance's keypoint as the counted representation."""
        keypoint = additional_input["keypoint"]
        return keypoint

    def should_ignore_element(self, element, additional_input):
        """Ignore elements whose value is the literal string "none"."""
        is_none_marker = element == "none"
        return is_none_marker