Elron commited on
Commit
ff375eb
1 Parent(s): a350a45

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. metrics.py +7 -3
  2. task.py +10 -4
  3. version.py +1 -1
metrics.py CHANGED
@@ -1166,7 +1166,9 @@ class HuggingfaceBulkMetric(BulkInstanceMetric):
1166
 
1167
  def prepare(self):
1168
  super().prepare()
1169
- self.metric = evaluate.load(self.hf_metric_name)
 
 
1170
 
1171
  def compute(
1172
  self,
@@ -1213,7 +1215,7 @@ class F1(GlobalMetric):
1213
 
1214
  def prepare(self):
1215
  super().prepare()
1216
- self._metric = evaluate.load(self.metric)
1217
 
1218
  def get_str_id(self, str):
1219
  if str not in self.str_to_id:
@@ -1337,7 +1339,9 @@ class F1MultiLabel(GlobalMetric):
1337
 
1338
  def prepare(self):
1339
  super().prepare()
1340
- self._metric = evaluate.load(self.metric, "multilabel")
 
 
1341
 
1342
  def add_str_to_id(self, str):
1343
  if str not in self.str_to_id:
 
1166
 
1167
  def prepare(self):
1168
  super().prepare()
1169
+ self.metric = evaluate.load(
1170
+ self.hf_metric_name, experiment_id=str(uuid.uuid4())
1171
+ )
1172
 
1173
  def compute(
1174
  self,
 
1215
 
1216
  def prepare(self):
1217
  super().prepare()
1218
+ self._metric = evaluate.load(self.metric, experiment_id=str(uuid.uuid4()))
1219
 
1220
  def get_str_id(self, str):
1221
  if str not in self.str_to_id:
 
1339
 
1340
  def prepare(self):
1341
  super().prepare()
1342
+ self._metric = evaluate.load(
1343
+ self.metric, "multilabel", experiment_id=str(uuid.uuid4())
1344
+ )
1345
 
1346
  def add_str_to_id(self, str):
1347
  if str not in self.str_to_id:
task.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from typing import Any, Dict, List, Optional, Union
2
 
3
  from .artifact import fetch_artifact
@@ -75,11 +76,16 @@ class FormTask(Tasker, StreamInstanceOperator):
75
  augmentable_input in self.inputs
76
  ), f"augmentable_input {augmentable_input} is not part of {self.inputs}"
77
 
 
 
 
 
 
 
78
  def check_metrics_type(self) -> None:
79
  prediction_type = parse_type_string(self.prediction_type)
80
- for metric_name in self.metrics:
81
- metric = fetch_artifact(metric_name)[0]
82
- metric_prediction_type = metric.get_prediction_type()
83
 
84
  if (
85
  prediction_type == metric_prediction_type
@@ -93,7 +99,7 @@ class FormTask(Tasker, StreamInstanceOperator):
93
  continue
94
 
95
  raise ValueError(
96
- f"The task's prediction type ({prediction_type}) and '{metric_name}' "
97
  f"metric's prediction type ({metric_prediction_type}) are different."
98
  )
99
 
 
1
+ from functools import lru_cache
2
  from typing import Any, Dict, List, Optional, Union
3
 
4
  from .artifact import fetch_artifact
 
76
  augmentable_input in self.inputs
77
  ), f"augmentable_input {augmentable_input} is not part of {self.inputs}"
78
 
79
+ @staticmethod
80
+ @lru_cache(maxsize=None)
81
+ def get_metric_prediction_type(metric_id: str):
82
+ metric = fetch_artifact(metric_id)[0]
83
+ return metric.get_prediction_type()
84
+
85
  def check_metrics_type(self) -> None:
86
  prediction_type = parse_type_string(self.prediction_type)
87
+ for metric_id in self.metrics:
88
+ metric_prediction_type = FormTask.get_metric_prediction_type(metric_id)
 
89
 
90
  if (
91
  prediction_type == metric_prediction_type
 
99
  continue
100
 
101
  raise ValueError(
102
+ f"The task's prediction type ({prediction_type}) and '{metric_id}' "
103
  f"metric's prediction type ({metric_prediction_type}) are different."
104
  )
105
 
version.py CHANGED
@@ -1 +1 @@
1
- version = "1.8.0"
 
1
+ version = "1.8.1"