Elron commited on
Commit
3157b84
1 Parent(s): 50b5364

Upload metric_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. metric_utils.py +97 -9
metric_utils.py CHANGED
@@ -1,7 +1,9 @@
1
- from typing import Dict, Iterable, List
 
2
 
3
  from datasets import Features, Value
4
 
 
5
  from .operator import (
6
  MultiStreamOperator,
7
  SequentialOperatorInitilizer,
@@ -17,6 +19,7 @@ from .operators import (
17
  )
18
  from .register import _reset_env_local_catalogs, register_all_artifacts
19
  from .schema import UNITXT_DATASET_SCHEMA
 
20
  from .stream import MultiStream, Stream
21
 
22
 
@@ -83,16 +86,12 @@ class FromPredictionsAndOriginalData(StreamInitializerOperator):
83
  )
84
 
85
 
86
- # The additional_inputs field in the schema is defined as
87
  # Sequence({"key": Value(dtype="string"), "value": Value("string")})
88
  # When receiving instances from this scheme, the keys and values are returned as two separate
89
  # lists, and are converted to a dictionary.
90
 
91
 
92
- def _from_key_value_pairs(key_value_list: Dict[str, list]) -> Dict[str, str]:
93
- return dict(zip(key_value_list["key"], key_value_list["value"]))
94
-
95
-
96
  class MetricRecipe(SequentialOperatorInitilizer):
97
  calc_confidence_intervals: bool = True
98
 
@@ -101,9 +100,9 @@ class MetricRecipe(SequentialOperatorInitilizer):
101
  self.steps = [
102
  FromPredictionsAndOriginalData(),
103
  Apply(
104
- "additional_inputs",
105
- function=_from_key_value_pairs,
106
- to_field="additional_inputs",
107
  ),
108
  ApplyOperatorsField(
109
  operators_field="postprocessors",
@@ -144,3 +143,92 @@ def _compute(
144
 
145
  stream = multi_stream[split_name]
146
  return list(stream)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import Any, Dict, Iterable, List, Optional
3
 
4
  from datasets import Features, Value
5
 
6
+ from .dataclass import Dataclass
7
  from .operator import (
8
  MultiStreamOperator,
9
  SequentialOperatorInitilizer,
 
19
  )
20
  from .register import _reset_env_local_catalogs, register_all_artifacts
21
  from .schema import UNITXT_DATASET_SCHEMA
22
+ from .settings_utils import get_settings
23
  from .stream import MultiStream, Stream
24
 
25
 
 
86
  )
87
 
88
 
89
+ # The task_data field in the schema is defined as
90
  # Sequence({"key": Value(dtype="string"), "value": Value("string")})
91
  # When receiving instances from this scheme, the keys and values are returned as two separate
92
  # lists, and are converted to a dictionary.
93
 
94
 
 
 
 
 
95
  class MetricRecipe(SequentialOperatorInitilizer):
96
  calc_confidence_intervals: bool = True
97
 
 
100
  self.steps = [
101
  FromPredictionsAndOriginalData(),
102
  Apply(
103
+ "task_data",
104
+ function="json.loads",
105
+ to_field="task_data",
106
  ),
107
  ApplyOperatorsField(
108
  operators_field="postprocessors",
 
143
 
144
  stream = multi_stream[split_name]
145
  return list(stream)
146
+
147
+
148
+ """
149
+ The API of a metric service:
150
+ - MetricRequest: A single input request to the metrics service.
151
+ - MetricResponse: A response returned from a metrics service.
152
+ """
153
+
154
+
155
+ class InstanceInput(Dataclass):
156
+ """A single instance inputted to a metric service."""
157
+
158
+ prediction: Any
159
+ references: List[Any]
160
+ additional_inputs: Optional[Dict] = None
161
+
162
+
163
+ class MetricRequest(Dataclass):
164
+ """A request to a metrics service, includes a list of input instances."""
165
+
166
+ instance_inputs: List[InstanceInput]
167
+
168
+
169
+ class MetricResponse(Dataclass):
170
+ """A response produced by a metrics service, includes the computed scores."""
171
+
172
+ # A list of instance score dictionaries. Each dictionary contains the
173
+ # score names and score values for a single instance.
174
+ instances_scores: List[Dict[str, Any]]
175
+ # The global scores dictionary, containing global score names and values.
176
+ # These are scores computed over the entire set of input instances, e.g.
177
+ # an average over a score computed per instance.
178
+ global_score: Dict[str, Any]
179
+
180
+
181
+ """
182
+ Functionality for loading the remote metrics configuration from local environment variables.
183
+ """
184
+
185
+ # A list of metrics to be executed remotely.
186
+ # For example: '["metrics.rag.context_relevance","metrics.rag.bert_k_precision"]'
187
+ # This value should be a valid json list
188
+ UNITXT_REMOTE_METRICS = "UNITXT_REMOTE_METRICS"
189
+
190
+ # The remote endpoint on which the remote metrics are available.
191
+ # For example, 'http://127.0.0.1:8000/compute'
192
+ UNITXT_REMOTE_METRICS_ENDPOINT = "UNITXT_REMOTE_METRICS_ENDPOINT"
193
+
194
+
195
+ def get_remote_metrics_names() -> List[str]:
196
+ """Load the remote metrics names from an environment variable.
197
+
198
+ Returns:
199
+ List[str] - names of metrics to be executed remotely.
200
+ """
201
+ settings = get_settings()
202
+ remote_metrics = settings.remote_metrics
203
+ if remote_metrics:
204
+ remote_metrics = json.loads(remote_metrics)
205
+ if not isinstance(remote_metrics, list):
206
+ raise RuntimeError(
207
+ f"Unexpected value {remote_metrics} for the '{UNITXT_REMOTE_METRICS}' environment variable. "
208
+ f"The value is expected to be a list of metric names in json format."
209
+ )
210
+ for remote_metric in remote_metrics:
211
+ if not isinstance(remote_metric, str):
212
+ raise RuntimeError(
213
+ f"Unexpected value {remote_metric} within the '{UNITXT_REMOTE_METRICS}' environment variable. "
214
+ f"The value is expected to be a string but its type is {type(remote_metric)}."
215
+ )
216
+ return remote_metrics
217
+
218
+
219
+ def get_remote_metrics_endpoint() -> str:
220
+ """Load the remote metrics endpoint from an environment variable.
221
+
222
+ Returns:
223
+ str - The remote endpoint on which the remote metrics are available.
224
+ """
225
+ settings = get_settings()
226
+ try:
227
+ remote_metrics_endpoint = settings.remote_metrics_endpoint
228
+ except AttributeError as e:
229
+ raise RuntimeError(
230
+ f"Unexpected None value for '{UNITXT_REMOTE_METRICS_ENDPOINT}'. "
231
+ f"Running remote metrics requires defining an "
232
+ f"endpoint in the environment variable '{UNITXT_REMOTE_METRICS_ENDPOINT}'."
233
+ ) from e
234
+ return remote_metrics_endpoint