Upload metrics.py with huggingface_hub

metrics.py CHANGED (+265 -35)
@@ -1,3 +1,4 @@
+import itertools
 import re
 import string
 import uuid
@@ -49,8 +50,6 @@ def abstract_field():


 def nan_mean(x):
-    import warnings
-
     with warnings.catch_warnings():
         # final mean should be mean of scores, ignoring NaN, hence nanmean
         # but if the group function values is NaN for ALL values, nanmean throws a
@@ -70,7 +69,6 @@ class UpdateStream(StreamInstanceOperator):
         return instance


-# TODO: currently we have two classes with this name. metric.Metric and matrics.Metric...
 class Metric(Artifact):
     @property
     @abstractmethod
@@ -115,10 +113,6 @@ class Metric(Artifact):
     def disable_confidence_interval_calculation(self):
         pass

-    @abstractmethod
-    def set_n_resamples(self, n_resample):
-        pass
-

 class MetricWithConfidenceInterval(Metric):
     # The number of resamples used to estimate the confidence intervals of this metric.
@@ -135,12 +129,7 @@ class MetricWithConfidenceInterval(Metric):
         return np.random.default_rng(hash(get_seed()) & _max_32bit)

     def disable_confidence_interval_calculation(self):
-        n = self.n_resamples
         self.n_resamples = None
-        return n
-
-    def set_n_resamples(self, n_resamples):
-        self.n_resamples = n_resamples

     def _can_compute_confidence_intervals(self, num_predictions):
         return (
@@ -161,6 +150,17 @@ class MetricWithConfidenceInterval(Metric):
             [instance["score"]["instance"][score_name] for instance in instances]
         )

+    @staticmethod
+    def _all_instance_scores_equal(instances, score_name):
+        instance_scores = [
+            instance["score"]["instance"][score_name] for instance in instances
+        ]
+        non_nan_instance_scores = [
+            score for score in instance_scores if score is not np.nan
+        ]
+        num_unique_scores = len(set(non_nan_instance_scores))
+        return num_unique_scores == 1
+
     def score_based_confidence_interval(
         self,
         instances: List[dict],
@@ -197,6 +197,11 @@ class MetricWithConfidenceInterval(Metric):
             # that is, re-form the groups, calculate the function, and take the mean of the group scores
             aggregation_func = self.average_item_scores
         for score_name in score_names:
+            # If all computed instance level scores are the same, there is no point in computing
+            # confidence intervals. So skip to the next score.
+            if self._all_instance_scores_equal(instances, score_name):
+                continue
+
             # need to redefine the statistic function within the loop because score_name is a loop variable
             def statistic(arr, axis, score_name=score_name):
                 # arr is a 2d array where each row is a resampling, so we
@@ -300,13 +305,18 @@ class MetricWithConfidenceInterval(Metric):
         num_predictions = len(predictions)
         if self._can_compute_confidence_intervals(num_predictions=num_predictions):
             identifiers = list(range(num_predictions))
-            ci = bootstrap(
-                (identifiers,),
-                statistic=statistic,
-                n_resamples=self.n_resamples,
-                confidence_level=self.confidence_level,
-                random_state=random_gen,
-            ).confidence_interval
+
+            with warnings.catch_warnings():
+                # Avoid RuntimeWarning in bootstrap computation. This happens on small datasets where
+                # the value of the computed global metric is the same on all resamplings.
+                warnings.simplefilter("ignore", category=RuntimeWarning)
+                ci = bootstrap(
+                    (identifiers,),
+                    statistic=statistic,
+                    n_resamples=self.n_resamples,
+                    confidence_level=self.confidence_level,
+                    random_state=random_gen,
+                ).confidence_interval
             result["score_ci_low"] = ci.low
             result["score_ci_high"] = ci.high
             result[f"{score_name}_ci_low"] = ci.low
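Note on the hunk above: the added catch_warnings block guards scipy's bootstrap against the RuntimeWarning it can emit when every resample yields the same statistic. A minimal, standalone sketch of the same pattern (the scores and statistic here are invented for illustration, not taken from metrics.py):

import warnings

import numpy as np
from scipy.stats import bootstrap

scores = np.array([0.0, 1.0, 1.0, 1.0])  # a small sample, as in the comment above

with warnings.catch_warnings():
    # Suppress the RuntimeWarning raised when the bootstrap distribution is degenerate,
    # i.e. when the statistic comes out identical on all resamplings.
    warnings.simplefilter("ignore", category=RuntimeWarning)
    ci = bootstrap(
        (scores,),
        statistic=np.mean,
        n_resamples=1000,
        confidence_level=0.95,
        random_state=np.random.default_rng(0),
    ).confidence_interval

print(ci.low, ci.high)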
@@ -553,7 +563,7 @@ class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
        - an 'agg_func' field with value being a 3-element list where
            - 1st element is a string name of the aggregation function (used in naming the CI report)
            - 2nd element is the callable aggregation function
-            - 3rd element is a Boolean indicator of whether, during
+            - 3rd element is a Boolean indicator of whether, during bootstrap CI calculation, the groups are to be sampled as single units.
        If True, the group scores are calculated and then resampled. This treats the group units as the unit of
        interest for which the CI is being compared.
        If False, the instances are resampled individually, and the groups determined
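The docstring change above only sketches the 'agg_func' contract; as an illustration (this exact dict is not part of the commit, and the simplified nan_mean below merely mirrors the helper defined earlier in metrics.py), a group-mean reduction that resamples whole groups during bootstrap CI calculation could be declared like:

import numpy as np


def nan_mean(x):
    # simplified stand-in for the nan_mean helper in metrics.py: ignore NaN scores
    return np.nanmean(x) if len(x) > 0 else np.nan


# hypothetical reduction_map of an InstanceMetric subclass:
# ["mean", nan_mean, True] = (name used in the CI report, callable, resample groups as single units)
reduction_map = {
    "group_mean": {
        "agg_func": ["mean", nan_mean, True],
    }
}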
@@ -903,11 +913,7 @@ class MetricPipeline(MultiStreamOperator, Metric):
     metric: Metric = None

     def disable_confidence_interval_calculation(self):
-
-
-    def set_n_resamples(self, n_resample):
-        if isinstance(self.metric, MetricWithConfidenceInterval):
-            self.metric.set_n_resamples(n_resample)
+        self.metric.disable_confidence_interval_calculation()

     def verify(self):
         assert self.main_score is not None, "main_score is not set"
@@ -1092,6 +1098,11 @@ class F1(GlobalMetric):
         self.id_to_str[id] = str
         return self.str_to_id[str]

+    def _labels_match_average_format(
+        self, references: List[List[str]], predictions: List[str]
+    ):
+        return True
+
     def compute(
         self,
         references: List[List[str]],
@@ -1101,6 +1112,9 @@ class F1(GlobalMetric):
         assert all(
             len(reference) == 1 for reference in references
         ), "Only a single reference per prediction is allowed in F1 metric"
+        if not self._labels_match_average_format(references, predictions):
+            return {self.main_score: np.nan}
+
         self.str_to_id = {}
         self.id_to_str = {}
         formatted_references = [
@@ -1111,18 +1125,21 @@ class F1(GlobalMetric):
             self.get_str_id(prediction) for prediction in predictions
         ]
         labels = list(set(formatted_references))
+
         result = self._metric.compute(
             predictions=formatted_predictions,
             references=formatted_references,
             labels=labels,
             average=self.average,
         )
-        if isinstance(result["f1"], numpy.ndarray):
-            final_result = {self.main_score: mean(result["f1"])}
+        if isinstance(result[self.metric], numpy.ndarray):
+            final_result = {self.main_score: mean(result[self.metric])}
             for i, label in enumerate(labels):
-                final_result["f1_" + self.id_to_str[label]] = result["f1"][i]
+                final_result[f"{self.metric}_" + self.id_to_str[label]] = result[
+                    self.metric
+                ][i]
         else:
-            final_result = {self.main_score: result["f1"]}
+            final_result = {self.main_score: result[self.metric]}
         return final_result


@@ -1131,6 +1148,40 @@ class F1Micro(F1):
     average = "micro"


+class F1Binary(F1):
+    process_single_instances = False
+    main_score = "f1_binary"
+    average = "binary"
+    pos_classes = {"1", "1.0", "yes", "true"}
+
+    def get_str_id(self, str):
+        if str.lower() in self.pos_classes:
+            return 1
+        return 0
+
+    # References and predictions must include up to 2 unique values, one of them in pos_classes
+    def _labels_match_average_format(
+        self, references: List[List[str]], predictions: List[str]
+    ):
+        classes = set(predictions + list(itertools.chain(*references)))
+        n_classes = len(classes)
+        if n_classes > 2:
+            return False
+        if n_classes == 2 and len(set(classes).difference(self.pos_classes)) == 0:
+            return False
+        return True
+
+
+class RecallBinary(F1Binary):
+    main_score = "recall_binary"
+    metric = "recall"
+
+
+class PrecisionBinary(F1Binary):
+    main_score = "precision_binary"
+    metric = "precision"
+
+
 class F1Macro(F1):
     main_score = "f1_macro"

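To see what the new binary guard accepts, here is a self-contained sketch that re-implements the two checks of F1Binary._labels_match_average_format outside the class (purely illustrative, not part of the committed file):

import itertools
from typing import List

POS_CLASSES = {"1", "1.0", "yes", "true"}


def labels_match_binary_format(references: List[List[str]], predictions: List[str]) -> bool:
    # same rule as in the diff: at most two distinct labels, and if there are
    # exactly two, at least one of them must fall outside the positive set
    classes = set(predictions + list(itertools.chain(*references)))
    if len(classes) > 2:
        return False
    if len(classes) == 2 and len(classes.difference(POS_CLASSES)) == 0:
        return False
    return True


print(labels_match_binary_format([["yes"], ["no"]], ["yes", "no"]))  # True
print(labels_match_binary_format([["yes"], ["1"]], ["yes", "1"]))    # False: two positive labels
print(labels_match_binary_format([["a"], ["b"]], ["c", "a"]))        # False: three labels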
@@ -1442,8 +1493,10 @@ class RocAuc(GlobalMetric):
         references = [to_float_or_default(r) for r in references]
         predictions = [to_float_or_default(p) for p in predictions]

-
-
+        false_positive_rates, true_positive_rates, _ = self.roc_curve(
+            y_true=references, y_score=predictions
+        )
+        roc_auc = self.auc(false_positive_rates, true_positive_rates)
         return {self.main_score: roc_auc}


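The RocAuc change routes the computation through self.roc_curve and self.auc; assuming those attributes wrap the scikit-learn functions of the same names (an assumption, since their binding is not shown in this diff), the equivalent standalone computation is:

from sklearn.metrics import auc, roc_curve

references = [0.0, 0.0, 1.0, 1.0]  # already converted to floats, as in the hunk above
predictions = [0.1, 0.4, 0.35, 0.8]

false_positive_rates, true_positive_rates, _ = roc_curve(y_true=references, y_score=predictions)
roc_auc = auc(false_positive_rates, true_positive_rates)
print(roc_auc)  # 0.75 on this toy data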
@@ -1525,7 +1578,7 @@ class CustomF1(GlobalMetric):

         assert len(references) == len(predictions), (
             f"references size ({len(references)})"
-            f" doesn't mach predictions
+            f" doesn't mach predictions size ({len(references)})."
         )

         if self.groups is None:
@@ -1700,7 +1753,7 @@ class SentenceBert(BulkInstanceMetric):

     model_name: str

-    _requirements_list: List[str] = ["sentence_transformers"]
+    _requirements_list: List[str] = ["sentence_transformers", "torch", "transformers"]

     def prepare(self):
         super().prepare()
@@ -1751,7 +1804,7 @@ class Reward(BulkInstanceMetric):

     model_name: str

-    _requirements_list: List[str] = ["transformers"]
+    _requirements_list: List[str] = ["transformers", "torch"]

     def prepare(self):
         super().prepare()
@@ -1782,6 +1835,134 @@ class Reward(BulkInstanceMetric):
         return self.pipe(inputs, batch_size=self.batch_size)


+class LlamaIndexCorrectness(InstanceMetric):
+    """LlamaIndex based metric class for evaluating correctness.
+
+    Attributes:
+        reduction_map (dict): A dictionary specifying the reduction method for the metric.
+        main_score (str): The main score used for evaluation.
+        _requirements_list (List[str]): A list specifying any additional requirements for the metric.
+
+    Methods:
+        prepare(self): Initialization method for the metric.
+        compute(self, references, predictions, additional_inputs): Method to compute the metric.
+
+    Usage:
+        metric = LlamaIndexCorrectnessMetric()
+        scores = metric.compute(references, prediction, additional_inputs)
+    """
+
+    model_name: str = ""
+    main_score: str = ""
+
+    reduction_map: Dict[str, List[str]] = None
+    openai_models: List[str] = ["gpt-3.5-turbo"]
+    anthropic_models: List[
+        str
+    ] = []  # this is here for the sake of documentation for future models
+    mock_models: List[str] = ["mock"]
+    external_api_models = openai_models + anthropic_models
+
+    _requirements_list: List[str] = ["llama_index"]
+
+    @staticmethod
+    def _custom_parser(eval_response: str):
+        """Default parser function for evaluation response.
+
+        Args:
+            eval_response (str): The response string from the evaluation.
+
+        Returns:
+            Tuple[float, str]: A tuple containing the score as a float and the reasoning as a string.
+        """
+        score_str = eval_response.split("\n")[0]
+        reasoning_str = "\n".join(eval_response.split("\n")[1:])
+        score = float(score_str)
+        reasoning = reasoning_str.lstrip("\n")
+        return score, reasoning
+
+    def _model_using_extrnal_api(self):
+        return self.model_name in self.external_api_models
+
+    def prepare(self):
+        """Initialization method for the metric. Initializes the CorrectnessEvaluator with the OpenAI model."""
+        super().prepare()
+
+        self.model_name_normalized = self.model_name.replace(".", "_").replace("-", "_")
+        self.main_score: str = (
+            f"correctness_llama_index_by_{self.model_name_normalized}_judge"
+        )
+
+        self.reduction_map: Dict[str, List[str]] = {"mean": [self.main_score]}
+
+        from llama_index.core.evaluation import CorrectnessEvaluator
+
+        if self.model_name in self.openai_models:
+            from llama_index.llms.openai import OpenAI
+
+            llm = OpenAI("gpt-3.5-turbo")
+        elif self.model_name in self.mock_models:
+            from llama_index.core.llms.mock import MockLLM
+
+            llm = MockLLM(system_prompt="5")  # perfect score
+        else:
+            raise NotImplementedError(
+                f"LlamaIndexCorrectnessMetric does not support {self.model_name}, currently only gpt-3.5-turbo is supported"
+            )
+
+        self.evaluator = CorrectnessEvaluator(
+            llm=llm, parser_function=self._custom_parser
+        )
+
+    def compute(
+        self,
+        references: List[str],
+        prediction: str,
+        task_data: Dict,
+    ) -> Dict[str, Any]:
+        """Method to compute the correctness metric.
+
+        Args:
+            references (List[str]): List of reference instances.
+            prediction (str): List of predicted instances.
+            task_data (Dict): List of additional input data.
+
+        Returns:
+            Dict[str, Any]: List of computed scores and feedback.
+
+        Raises:
+            AssertionError: If the input does not meet the expected format.
+        """
+        # treat the references as the questions and the predictions as answers
+        # assume a single reference
+
+        assert (
+            not self._model_using_extrnal_api()
+            or settings.allow_passing_data_to_remote_api
+        ), f"Cannot run send data to remote APIs ({self.model_name}) when unitxt.settings.allow_passing_data_to_remote_api=False. Set UNITXT_ALLOW_PASSING_DATA_TO_REMOTE_API environment variable, if you want to allow this."
+
+        query = task_data["question"]
+        contexts = task_data["contexts"]
+
+        per_reference_results = []
+        for reference_response in references:
+            per_reference_results.append(
+                self.evaluator.evaluate(
+                    query=query,
+                    response=prediction,
+                    contexts=contexts,
+                    reference=reference_response,
+                )
+            )
+        result = max([results.score for results in per_reference_results])
+
+        return {
+            self.main_score: result / 5,
+            # "score_name": self.main_score,
+            # "feedback": result.feedback,  # removed since this cannot be tested
+        }
+
+
 class Perplexity(BulkInstanceMetric):
     """Computes the likelihood of generating text Y after text X - P(Y|X)."""

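The _custom_parser added above assumes the LLM judge returns the numeric score on the first line and free-text reasoning afterwards; a standalone illustration of that contract (the response string is invented):

eval_response = "4.0\nThe answer matches the reference but omits one supporting detail."

score_str = eval_response.split("\n")[0]
reasoning_str = "\n".join(eval_response.split("\n")[1:])
score = float(score_str)                # 4.0
reasoning = reasoning_str.lstrip("\n")  # "The answer matches the reference ..."
print(score, reasoning)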
@@ -1793,7 +1974,7 @@ class Perplexity(BulkInstanceMetric):
     batch_size: int = 32
     model_name: str

-    _requirements_list: List[str] = ["transformers"]
+    _requirements_list: List[str] = ["transformers", "torch"]

     def compute(
         self,
@@ -2904,3 +3085,52 @@ class FixedGroupAbsvalNormHedgesGParaphraseStringContainment(StringContainment):
             ],
         }
     }
+
+
+class BinaryMaxF1(F1Binary):
+    main_score = "max_f1_binary"
+
+    def compute(
+        self,
+        references: List[List[str]],
+        predictions: List[List[str]],
+        task_data: List[Dict],
+    ) -> dict:
+        assert all(
+            len(reference) == 1 for reference in references
+        ), "Only a single reference per prediction is allowed in F1 metric"
+        classes = set(itertools.chain(*references))
+        n_clases = len(classes)
+        assert len(classes) <= 2, "References of BinaryMaxF1 must be binary"
+        pos_classes = classes.intersection(self.pos_classes)
+        neg_classes = classes.difference(self.pos_classes)
+        n_pos_classes = len(pos_classes)
+        if n_clases == 2:
+            assert (
+                n_pos_classes == 1
+            ), "Only one positive class is allowed in BinaryMaxF1"
+        pos_class = next(iter(pos_classes)) if n_pos_classes > 0 else "1.0"
+        neg_class = next(iter(neg_classes)) if len(neg_classes) > 0 else "0.0"
+
+        float_predictions = []
+        for prediction in predictions:
+            try:
+                float_predictions.append(float(prediction))
+            except Exception:
+                float_predictions.append(0)
+
+        best_thr = -1
+        best_f1 = -1
+        for thr in set(float_predictions):
+            new_predictions = [
+                pos_class if float_prediction >= thr else neg_class
+                for float_prediction in float_predictions
+            ]
+            f1 = super().compute(references, new_predictions, task_data)[
+                self.main_score
+            ]
+            if f1 > best_f1:
+                best_f1 = f1
+                best_thr = thr
+
+        return {self.main_score: best_f1, "best_thr_maxf1": best_thr}
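BinaryMaxF1 above sweeps every predicted probability as a candidate decision threshold and keeps the best F1. The same idea in a standalone sketch using scikit-learn (illustrative only; the data is made up and the dict keys simply echo the metric's score names):

from sklearn.metrics import f1_score

references = [1, 0, 1, 1, 0]
probabilities = [0.9, 0.4, 0.6, 0.2, 0.1]

best_f1, best_thr = -1.0, -1.0
for thr in set(probabilities):
    # classify as positive when the predicted probability reaches the threshold
    predictions = [1 if p >= thr else 0 for p in probabilities]
    f1 = f1_score(references, predictions)
    if f1 > best_f1:
        best_f1, best_thr = f1, thr

print({"max_f1_binary": best_f1, "best_thr_maxf1": best_thr})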