Elron commited on
Commit
fbd19c3
1 Parent(s): 429723e

Upload metric.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. metric.py +20 -2
metric.py CHANGED
@@ -42,7 +42,7 @@ from .processors import __file__ as _
42
  from .random_utils import __file__ as _
43
  from .recipe import __file__ as _
44
  from .register import __file__ as _
45
- from .register import register_all_artifacts
46
  from .renderers import __file__ as _
47
  from .schema import __file__ as _
48
  from .split_utils import __file__ as _
@@ -124,6 +124,8 @@ UNITXT_METRIC_SCHEMA = Features({"predictions": Value("string"), "references": d
124
 
125
 
126
  def _compute(predictions: List[str], references: Iterable, flatten: bool = False, split_name: str = "all"):
 
 
127
  recipe = MetricRecipe()
128
 
129
  multi_stream = recipe(predictions=predictions, references=references, split_name=split_name)
@@ -154,4 +156,20 @@ class Metric(evaluate.Metric):
154
  )
155
 
156
  def _compute(self, predictions: List[str], references: Iterable, flatten: bool = False, split_name: str = "all"):
157
- return _compute(predictions=predictions, references=references, flatten=flatten, split_name=split_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  from .random_utils import __file__ as _
43
  from .recipe import __file__ as _
44
  from .register import __file__ as _
45
+ from .register import _reset_env_local_catalogs, register_all_artifacts
46
  from .renderers import __file__ as _
47
  from .schema import __file__ as _
48
  from .split_utils import __file__ as _
 
124
 
125
 
126
  def _compute(predictions: List[str], references: Iterable, flatten: bool = False, split_name: str = "all"):
127
+ _reset_env_local_catalogs()
128
+ register_all_artifacts()
129
  recipe = MetricRecipe()
130
 
131
  multi_stream = recipe(predictions=predictions, references=references, split_name=split_name)
 
156
  )
157
 
158
  def _compute(self, predictions: List[str], references: Iterable, flatten: bool = False, split_name: str = "all"):
159
+ try:
160
+ from unitxt.dataset import (
161
+ get_dataset_artifact as get_dataset_artifact_installed,
162
+ )
163
+
164
+ unitxt_installed = True
165
+ except ImportError:
166
+ unitxt_installed = False
167
+
168
+ if unitxt_installed:
169
+ from unitxt.metric import _compute as _compute_installed
170
+
171
+ return _compute_installed(
172
+ predictions=predictions, references=references, flatten=flatten, split_name=split_name
173
+ )
174
+ else:
175
+ return _compute(predictions=predictions, references=references, flatten=flatten, split_name=split_name)