Elron committed on
Commit
0a1b314
1 Parent(s): b462f85

Upload folder using huggingface_hub

artifact.py CHANGED
@@ -5,7 +5,7 @@ import os
5
  import pkgutil
6
  from abc import abstractmethod
7
  from copy import deepcopy
8
- from typing import Dict, List, Optional, Union, final
9
 
10
  from .dataclass import (
11
  AbstractField,
@@ -129,6 +129,10 @@ class Artifact(Dataclass):
129
  )
130
  __id__: str = InternalField(default=None, required=False, also_positional=False)
131
132
  @classmethod
133
  def is_artifact_dict(cls, d):
134
  return isinstance(d, dict) and "type" in d
@@ -226,6 +230,11 @@ class Artifact(Dataclass):
226
  new_artifact.__id__ = artifact_identifier
227
  return new_artifact
228
229
  def prepare(self):
230
  pass
231
 
@@ -236,6 +245,20 @@ class Artifact(Dataclass):
236
  def __pre_init__(self, **kwargs):
237
  self._init_dict = get_raw(kwargs)
238
239
  @final
240
  def __post_init__(self):
241
  self.type = self.register_class(self.__class__)
@@ -248,6 +271,7 @@ class Artifact(Dataclass):
248
  value = map_values_in_place(value, maybe_recover_artifact)
249
  setattr(self, field.name, value)
250
 
 
251
  if not settings.skip_artifacts_prepare_and_verify:
252
  self.prepare()
253
  self.verify()
@@ -259,6 +283,76 @@ class Artifact(Dataclass):
259
  data = self.to_dict()
260
  save_json(path, data)
261
 
262
 
263
  def get_raw(obj):
264
  if isinstance(obj, Artifact):
@@ -367,3 +461,53 @@ def register_all_artifacts(path):
367
  # Make sure the class is a subclass of Artifact (but not Artifact itself)
368
  if issubclass(obj, Artifact) and obj is not Artifact:
369
  logger.info(obj)
 
5
  import pkgutil
6
  from abc import abstractmethod
7
  from copy import deepcopy
8
+ from typing import Any, Dict, List, Optional, Union, final
9
 
10
  from .dataclass import (
11
  AbstractField,
 
129
  )
130
  __id__: str = InternalField(default=None, required=False, also_positional=False)
131
 
132
+ data_classification_policy: List[str] = NonPositionalField(
133
+ default=None, required=False, also_positional=False
134
+ )
135
+
136
  @classmethod
137
  def is_artifact_dict(cls, d):
138
  return isinstance(d, dict) and "type" in d
 
230
  new_artifact.__id__ = artifact_identifier
231
  return new_artifact
232
 
233
+ def get_pretty_print_name(self):
234
+ if self.__id__ is not None:
235
+ return self.__id__
236
+ return self.__class__.__name__
237
+
238
  def prepare(self):
239
  pass
240
 
 
245
  def __pre_init__(self, **kwargs):
246
  self._init_dict = get_raw(kwargs)
247
 
248
+ @final
249
+ def verify_data_classification_policy(self):
250
+ if self.data_classification_policy is not None:
251
+ if not isinstance(self.data_classification_policy, list) or not all(
252
+ isinstance(data_classification, str)
253
+ for data_classification in self.data_classification_policy
254
+ ):
255
+ raise ValueError(
256
+ f"The 'data_classification_policy' of {self.get_pretty_print_name()} "
257
+ f"must be either None - in case when no policy applies - or a list of "
258
+ f"strings, for example: ['public']. However, '{self.data_classification_policy}' "
259
+ f"of type {type(self.data_classification_policy)} was provided instead."
260
+ )
261
+
262
  @final
263
  def __post_init__(self):
264
  self.type = self.register_class(self.__class__)
 
271
  value = map_values_in_place(value, maybe_recover_artifact)
272
  setattr(self, field.name, value)
273
 
274
+ self.verify_data_classification_policy()
275
  if not settings.skip_artifacts_prepare_and_verify:
276
  self.prepare()
277
  self.verify()
 
283
  data = self.to_dict()
284
  save_json(path, data)
285
 
286
+ def verify_instance(
287
+ self, instance: Dict[str, Any], name: Optional[str] = None
288
+ ) -> Dict[str, Any]:
289
+ """Checks if data classifications of an artifact and instance are compatible.
290
+
291
+ Raises an error if an artifact's data classification policy does not include that of
292
+ processed data. The purpose is to ensure that any sensitive data is handled in a
293
+ proper way (for example when sending it to some external services).
294
+
295
+ Args:
296
+ instance (Dict[str, Any]): data which should contain its allowed data
297
+ classification policies under key 'data_classification_policy'.
298
+ name (Optional[str]): name of artifact which should be used to retrieve
299
+ data classification from env. If not specified, then either __id__ or
300
+ __class__.__name__ is used instead, respectively.
301
+
302
+ Returns:
303
+ Dict[str, Any]: unchanged instance.
304
+
305
+ Examples:
306
+ instance = {"x": "some_text", "data_classification_policy": ["pii"]}
307
+
308
+ # Will raise an error as "pii" is not included in the policy
309
+ metric = Accuracy(data_classification_policy=["public"])
310
+ metric.verify_instance(instance)
311
+
312
+ # Will not raise an error
313
+ template = SpanLabelingTemplate(data_classification_policy=["pii", "proprietary"])
314
+ template.verify_instance(instance)
315
+
316
+ # Will not raise an error since the policy was specified in environment variable:
317
+ UNITXT_DATA_CLASSIFICATION_POLICY = json.dumps({"metrics.accuracy": ["pii"]})
318
+ metric = fetch_artifact("metrics.accuracy")
319
+ metric.verify_instance(instance)
320
+ """
321
+ name = name or self.get_pretty_print_name()
322
+ data_classification_policy = get_artifacts_data_classification(name)
323
+ if not data_classification_policy:
324
+ data_classification_policy = self.data_classification_policy
325
+
326
+ if not data_classification_policy:
327
+ return instance
328
+
329
+ instance_data_classification = instance.get("data_classification_policy")
330
+ if not instance_data_classification:
331
+ get_logger().warning(
332
+ f"The data does not provide information if it can be used by "
333
+ f"'{name}' with the following data classification policy "
334
+ f"'{data_classification_policy}'. This may lead to sending of undesired "
335
+ f"data to an external service. Set the 'data_classification_policy' "
336
+ f"of the data to ensure proper handling of sensitive information."
337
+ )
338
+ return instance
339
+
340
+ if not any(
341
+ data_classification in data_classification_policy
342
+ for data_classification in instance_data_classification
343
+ ):
344
+ raise ValueError(
345
+ f"The instance '{instance}' has the following data classification policy "
346
+ f"'{instance_data_classification}', however, the artifact '{name}' "
347
+ f"is only configured to support the data with classification "
348
+ f"'{data_classification_policy}'. To enable this either change "
349
+ f"the 'data_classification_policy' attribute of the artifact, "
350
+ f"or modify the environment variable "
351
+ f"'UNITXT_DATA_CLASSIFICATION_POLICY' accordingly."
352
+ )
353
+
354
+ return instance
355
+
356
 
357
  def get_raw(obj):
358
  if isinstance(obj, Artifact):
 
461
  # Make sure the class is a subclass of Artifact (but not Artifact itself)
462
  if issubclass(obj, Artifact) and obj is not Artifact:
463
  logger.info(obj)
464
+
465
+
466
+ def get_artifacts_data_classification(artifact: str) -> Optional[List[str]]:
467
+ """Loads given artifact's data classification policy from an environment variable.
468
+
469
+ Args:
470
+ artifact (str): Name of the artifact for which the data classification policy
471
+ should be retrieved. For example "metrics.accuracy".
472
+
473
+ Returns:
474
+ Optional[List[str]] - Data classification policies for the specified artifact
475
+ if they were found, or None otherwise.
476
+ """
477
+ data_classification = settings.data_classification_policy
478
+ if data_classification is None:
479
+ return None
480
+
481
+ error_msg = (
482
+ f"If specified, the value of 'UNITXT_DATA_CLASSIFICATION_POLICY' "
483
+ f"should be a valid json dictionary. Got '{data_classification}' "
484
+ f"instead."
485
+ )
486
+
487
+ try:
488
+ data_classification = json.loads(data_classification)
489
+ except json.decoder.JSONDecodeError as e:
490
+ raise RuntimeError(error_msg) from e
491
+
492
+ if not isinstance(data_classification, dict):
493
+ raise RuntimeError(error_msg)
494
+
495
+ for artifact_name, artifact_data_classifications in data_classification.items():
496
+ if (
497
+ not isinstance(artifact_name, str)
498
+ or not isinstance(artifact_data_classifications, list)
499
+ or not all(
500
+ isinstance(artifact_data_classification, str)
501
+ for artifact_data_classification in artifact_data_classifications
502
+ )
503
+ ):
504
+ raise RuntimeError(
505
+ "'UNITXT_DATA_CLASSIFICATION_POLICY' should be of type "
506
+ "'Dict[str, List[str]]', where an artifact's name is a key, and a "
507
+ "value is a list of data classifications used by that artifact."
508
+ )
509
+
510
+ if artifact not in data_classification.keys():
511
+ return None
512
+
513
+ return data_classification.get(artifact)
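
Taken together, the artifact.py additions give an artifact's policy two possible sources: its own `data_classification_policy` attribute, or an override read from the `UNITXT_DATA_CLASSIFICATION_POLICY` environment variable by `get_artifacts_data_classification`. The standalone sketch below mirrors that precedence and the compatibility check performed by `verify_instance`; it is illustrative only and does not call the unitxt API.

.. code-block:: python

    import json
    import os
    from typing import Dict, List, Optional

    # Illustrative re-implementation of the precedence described above (not the unitxt
    # functions themselves): an env-var entry for the artifact name overrides the
    # artifact's own attribute.
    def resolve_policy(artifact_name: str, attribute_policy: Optional[List[str]]) -> Optional[List[str]]:
        raw = os.environ.get("UNITXT_DATA_CLASSIFICATION_POLICY")
        if raw:
            env_mapping: Dict[str, List[str]] = json.loads(raw)
            if artifact_name in env_mapping:
                return env_mapping[artifact_name]
        return attribute_policy

    def is_instance_allowed(instance: dict, policy: Optional[List[str]]) -> bool:
        if not policy:
            return True  # artifact defines no policy -> nothing to enforce
        instance_policy = instance.get("data_classification_policy")
        if not instance_policy:
            return True  # unlabeled data only triggers a warning in verify_instance()
        return any(classification in policy for classification in instance_policy)

    os.environ["UNITXT_DATA_CLASSIFICATION_POLICY"] = json.dumps({"metrics.accuracy": ["pii"]})
    instance = {"x": "some_text", "data_classification_policy": ["pii"]}
    policy = resolve_policy("metrics.accuracy", attribute_policy=["public"])
    print(is_instance_allowed(instance, policy))  # True: the env var overrides ["public"]
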
blocks.py CHANGED
@@ -31,7 +31,7 @@ from .struct_data_operators import (
31
  TruncateTableCells,
32
  TruncateTableRows,
33
  )
34
- from .task import Task
35
  from .templates import (
36
  InputOutputTemplate,
37
  MultiLabelTemplate,
 
31
  TruncateTableCells,
32
  TruncateTableRows,
33
  )
34
+ from .task import FormTask, Task
35
  from .templates import (
36
  InputOutputTemplate,
37
  MultiLabelTemplate,
collections_operators.py CHANGED
@@ -1,7 +1,7 @@
1
  from copy import deepcopy
2
  from typing import Any, Generator, List, Optional
3
 
4
- from .operators import FieldOperator, SingleStreamOperator
5
  from .stream import Stream
6
 
7
 
@@ -58,7 +58,7 @@ class Get(FieldOperator):
58
  return collection[self.item]
59
 
60
 
61
- class DuplicateByList(SingleStreamOperator):
62
  field: str
63
  to_field: Optional[str] = None
64
  use_deep_copy: bool = False
@@ -80,7 +80,7 @@ class DuplicateByList(SingleStreamOperator):
80
  yield instance_copy
81
 
82
 
83
- class DuplicateBySubLists(SingleStreamOperator):
84
  field: str
85
  to_field: Optional[str] = None
86
  use_deep_copy: bool = False
 
1
  from copy import deepcopy
2
  from typing import Any, Generator, List, Optional
3
 
4
+ from .operators import FieldOperator, StreamOperator
5
  from .stream import Stream
6
 
7
 
 
58
  return collection[self.item]
59
 
60
 
61
+ class DuplicateByList(StreamOperator):
62
  field: str
63
  to_field: Optional[str] = None
64
  use_deep_copy: bool = False
 
80
  yield instance_copy
81
 
82
 
83
+ class DuplicateBySubLists(StreamOperator):
84
  field: str
85
  to_field: Optional[str] = None
86
  use_deep_copy: bool = False
dict_utils.py CHANGED
@@ -1,6 +1,8 @@
1
  import re
2
  from typing import Any, List, Tuple
3
 
 
 
4
  indx = re.compile(r"^(\d+)$")
5
  name = re.compile(r"^[\w. -]+$")
6
 
@@ -395,22 +397,20 @@ def dict_get(
395
  if len(components) > 1:
396
  try:
397
  success, values = get_values(dic, components, -1 * len(components))
398
- if not success:
399
- if not_exist_ok:
400
- return default
401
- raise ValueError(
402
- f'query "{query}" did not match any item in dict: {dic}'
403
- )
404
-
405
- return values
406
-
407
  except Exception as e:
408
- if not_exist_ok:
409
- return default
410
  raise ValueError(
411
- f'query "{query}" did not match any item in dict: {dic}'
412
  ) from e
413
 
 
 
 
414
  # len(components) == 1
415
  if components[0] in dic:
416
  return dic[components[0]]
@@ -418,7 +418,9 @@ def dict_get(
418
  if not_exist_ok:
419
  return default
420
 
421
- raise ValueError(f'query "{query}" did not match any item in dict: {dic}')
 
 
422
 
423
 
424
  # dict_set sets a value, 'value', which by itself, can be a dict or list or scalar, into 'dic', to become the value of
 
1
  import re
2
  from typing import Any, List, Tuple
3
 
4
+ from .text_utils import construct_dict_str
5
+
6
  indx = re.compile(r"^(\d+)$")
7
  name = re.compile(r"^[\w. -]+$")
8
 
 
397
  if len(components) > 1:
398
  try:
399
  success, values = get_values(dic, components, -1 * len(components))
400
+ if success:
401
+ return values
 
 
 
402
  except Exception as e:
 
 
403
  raise ValueError(
404
+ f'query "{query}" did not match any item in dict:\n{construct_dict_str(dic)}'
405
  ) from e
406
 
407
+ if not_exist_ok:
408
+ return default
409
+
410
+ raise ValueError(
411
+ f'query "{query}" did not match any item in dict:\n{construct_dict_str(dic)}'
412
+ )
413
+
414
  # len(components) == 1
415
  if components[0] in dic:
416
  return dic[components[0]]
 
418
  if not_exist_ok:
419
  return default
420
 
421
+ raise ValueError(
422
+ f'query "{query}" did not match any item in dict:\n{construct_dict_str(dic)}'
423
+ )
424
 
425
 
426
  # dict_set sets a value, 'value', which by itself, can be a dict or list or scalar, into 'dic', to become the value of
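
The dict_get change above returns early on a successful nested lookup and routes every miss through the `not_exist_ok`/`default` handling, with failures reported via `construct_dict_str`. A standalone sketch of the resulting control flow (illustrative names, not the dict_utils signature):

.. code-block:: python

    from typing import Any

    # Illustrative only: resolve a "/"-separated query, returning early on success and
    # honoring not_exist_ok/default on any miss, as dict_get now does.
    def nested_get(dic: dict, query: str, not_exist_ok: bool = False, default: Any = None) -> Any:
        current: Any = dic
        for component in query.split("/"):
            try:
                current = current[component]
            except (KeyError, IndexError, TypeError):
                if not_exist_ok:
                    return default
                raise ValueError(f'query "{query}" did not match any item in dict:\n{dic}')
        return current

    data = {"a": {"b": 1}}
    print(nested_get(data, "a/b"))                                # 1
    print(nested_get(data, "a/c", not_exist_ok=True, default=0))  # 0
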
formats.py CHANGED
@@ -7,11 +7,11 @@ from typing import (
7
  )
8
 
9
  from .dataclass import OptionalField
10
- from .operator import StreamInstanceOperator
11
  from .type_utils import isoftype
12
 
13
 
14
- class Format(StreamInstanceOperator):
15
  pass
16
 
17
 
 
7
  )
8
 
9
  from .dataclass import OptionalField
10
+ from .operator import InstanceOperator
11
  from .type_utils import isoftype
12
 
13
 
14
+ class Format(InstanceOperator):
15
  pass
16
 
17
 
inference.py CHANGED
@@ -5,24 +5,20 @@ from typing import Any, Dict, List, Literal, Optional, Union
5
 
6
  from .artifact import Artifact
7
  from .operator import PackageRequirementsMixin
8
- from .settings_utils import get_settings
9
 
10
 
11
  class InferenceEngine(abc.ABC, Artifact):
12
  """Abstract base class for inference."""
13
 
14
  @abc.abstractmethod
15
- def infer(self, dataset):
16
  """Perform inference on the input dataset."""
17
  pass
18
 
19
- @staticmethod
20
- def _assert_allow_passing_data_to_remote_api(remote_api_label: str):
21
- assert get_settings().allow_passing_data_to_remote_api, (
22
- f"LlmAsJudge metric cannot run send data to remote APIs ({remote_api_label}) when"
23
- f" unitxt.settings.allow_passing_data_to_remote_api=False."
24
- f" Set UNITXT_ALLOW_PASSING_DATA_TO_REMOTE_API environment variable, if you want to allow this. "
25
- )
26
 
27
 
28
  class HFPipelineBasedInferenceEngine(InferenceEngine, PackageRequirementsMixin):
@@ -73,7 +69,7 @@ class HFPipelineBasedInferenceEngine(InferenceEngine, PackageRequirementsMixin):
73
  model=self.model_name, trust_remote_code=True, **model_args
74
  )
75
 
76
- def infer(self, dataset):
77
  outputs = []
78
  for output in self.model([instance["source"] for instance in dataset]):
79
  if isinstance(output, list):
@@ -88,7 +84,7 @@ class MockInferenceEngine(InferenceEngine):
88
  def prepare(self):
89
  return
90
 
91
- def infer(self, dataset):
92
  return ["[[10]]" for instance in dataset]
93
 
94
 
@@ -114,6 +110,7 @@ class IbmGenAiInferenceEngine(InferenceEngine, PackageRequirementsMixin):
114
  _requirement = {
115
  "genai": "Install ibm-genai package using 'pip install --upgrade ibm-generative-ai"
116
  }
 
117
 
118
  def prepare(self):
119
  from genai import Client, Credentials
@@ -128,9 +125,7 @@ class IbmGenAiInferenceEngine(InferenceEngine, PackageRequirementsMixin):
128
  credentials = Credentials(api_key=api_key, api_endpoint=api_endpoint)
129
  self.client = Client(credentials=credentials)
130
 
131
- self._assert_allow_passing_data_to_remote_api(self.label)
132
-
133
- def infer(self, dataset):
134
  from genai.schema import TextGenerationParameters
135
 
136
  genai_params = TextGenerationParameters(
@@ -186,9 +181,8 @@ class OpenAiInferenceEngine(InferenceEngine, PackageRequirementsMixin):
186
  )
187
 
188
  self.client = OpenAI(api_key=api_key)
189
- self._assert_allow_passing_data_to_remote_api(self.label)
190
 
191
- def infer(self, dataset):
192
  return [
193
  self.client.chat.completions.create(
194
  messages=[
 
5
 
6
  from .artifact import Artifact
7
  from .operator import PackageRequirementsMixin
 
8
 
9
 
10
  class InferenceEngine(abc.ABC, Artifact):
11
  """Abstract base class for inference."""
12
 
13
  @abc.abstractmethod
14
+ def _infer(self, dataset):
15
  """Perform inference on the input dataset."""
16
  pass
17
 
18
+ def infer(self, dataset):
19
+ """Verifies instances of a dataset and performs inference."""
20
+ [self.verify_instance(instance) for instance in dataset]
21
+ return self._infer(dataset)
 
 
 
22
 
23
 
24
  class HFPipelineBasedInferenceEngine(InferenceEngine, PackageRequirementsMixin):
 
69
  model=self.model_name, trust_remote_code=True, **model_args
70
  )
71
 
72
+ def _infer(self, dataset):
73
  outputs = []
74
  for output in self.model([instance["source"] for instance in dataset]):
75
  if isinstance(output, list):
 
84
  def prepare(self):
85
  return
86
 
87
+ def _infer(self, dataset):
88
  return ["[[10]]" for instance in dataset]
89
 
90
 
 
110
  _requirement = {
111
  "genai": "Install ibm-genai package using 'pip install --upgrade ibm-generative-ai"
112
  }
113
+ data_classification_policy = ["public", "proprietary"]
114
 
115
  def prepare(self):
116
  from genai import Client, Credentials
 
125
  credentials = Credentials(api_key=api_key, api_endpoint=api_endpoint)
126
  self.client = Client(credentials=credentials)
127
 
128
+ def _infer(self, dataset):
 
 
129
  from genai.schema import TextGenerationParameters
130
 
131
  genai_params = TextGenerationParameters(
 
181
  )
182
 
183
  self.client = OpenAI(api_key=api_key)
 
184
 
185
+ def _infer(self, dataset):
186
  return [
187
  self.client.chat.completions.create(
188
  messages=[
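
With the inference.py refactor, engines override the private `_infer` and inherit a public `infer` that first runs `verify_instance` on every instance, so the data classification check happens before anything is sent to a backend. A minimal sketch of a custom engine under that contract (the `EchoInferenceEngine` class is hypothetical and assumes unitxt is installed):

.. code-block:: python

    from unitxt.inference import InferenceEngine

    # Hypothetical engine: echoes each instance's "source" text. Only _infer() is
    # implemented; the inherited infer() verifies every instance against the
    # data_classification_policy before delegating here.
    class EchoInferenceEngine(InferenceEngine):
        data_classification_policy = ["public"]

        def prepare(self):
            pass

        def _infer(self, dataset):
            return [instance["source"] for instance in dataset]

    engine = EchoInferenceEngine()
    predictions = engine.infer(
        [{"source": "hello", "data_classification_policy": ["public"]}]
    )
    print(predictions)  # ['hello']
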
loaders.py CHANGED
@@ -15,16 +15,26 @@ Unitxt catalog contains several loaders for the most popular datasource formats.
15
  All these loaders inherit from Loader, and hence, implementing a loader to expand over a new type of datasource, is
16
  straight forward.
17
 
18
- Operators in Unitxt catalog:
19
- LoadHF : loads from Huggingface dataset.
20
- LoadCSV: loads from csv (comma separated value) files
21
- LoadFromKaggle: loads datasets from the kaggle.com community site
22
- LoadFromIBMCloud: loads a dataset from the IBM cloud.
 
 
 
23
  ------------------------
24
  """
25
  import itertools
26
  import os
27
  import tempfile
 
 
28
  from pathlib import Path
29
  from tempfile import TemporaryDirectory
30
  from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
@@ -37,6 +47,7 @@ from .dataclass import InternalField, OptionalField
37
  from .fusion import FixedFusion
38
  from .logging_utils import get_logger
39
  from .operator import SourceOperator
 
40
  from .settings_utils import get_settings
41
  from .stream import GeneratorStream, MultiStream
42
 
@@ -45,12 +56,22 @@ settings = get_settings()
45
 
46
 
47
  class Loader(SourceOperator):
48
- # The loader_limit an optional parameter used to control the maximum number of instances to load from the the source.
49
- # It is usually provided to the loader via the recipe (see standard.py)
50
- # The loader can use this value to limit the amount of data downloaded from the source
51
- # to reduce loading time. However, this may not always be possible, so the
52
- # loader may ignore this. In any case, the recipe, will limit the number of instances in the returned
53
- # stream, after load is complete.
 
 
 
 
54
  loader_limit: int = None
55
  streaming: bool = False
56
 
@@ -75,8 +96,66 @@ class Loader(SourceOperator):
75
  f"\nLoading limited to {self.get_limit()} instances by setting {self.get_limiter()};"
76
  )
77
 
 
 
 
 
 
78
 
79
  class LoadHF(Loader):
 
 
 
 
80
  path: str
81
  name: Optional[str] = None
82
  data_dir: Optional[str] = None
@@ -187,7 +266,15 @@ class LoadHF(Loader):
187
  }
188
  )
189
 
190
- def process(self):
 
 
 
 
191
  try:
192
  dataset = self.stream_dataset()
193
  except (
@@ -202,6 +289,25 @@ class LoadHF(Loader):
202
 
203
 
204
  class LoadCSV(Loader):
 
 
 
 
205
  files: Dict[str, str]
206
  chunksize: int = 1000
207
  _cache = InternalField(default_factory=dict)
@@ -236,7 +342,10 @@ class LoadCSV(Loader):
236
 
237
  yield from self._cache[file]
238
 
239
- def process(self):
 
 
 
240
  if self.streaming:
241
  return MultiStream(
242
  {
@@ -258,8 +367,25 @@ class LoadCSV(Loader):
258
 
259
 
260
  class LoadFromSklearn(Loader):
 
 
 
 
261
  dataset_name: str
262
  splits: List[str] = ["train", "test"]
 
263
 
264
  _requirements_list: List[str] = ["sklearn", "pandas"]
265
 
@@ -275,7 +401,7 @@ class LoadFromSklearn(Loader):
275
 
276
  self.downloader = getattr(sklearn_datatasets, f"fetch_{self.dataset_name}")
277
 
278
- def process(self):
279
  with TemporaryDirectory() as temp_directory:
280
  for split in self.splits:
281
  split_data = self.downloader(subset=split)
@@ -293,8 +419,25 @@ class MissingKaggleCredentialsError(ValueError):
293
 
294
 
295
  class LoadFromKaggle(Loader):
 
 
 
 
296
  url: str
 
297
  _requirements_list: List[str] = ["opendatasets"]
 
298
 
299
  def verify(self):
300
  super().verify()
@@ -312,7 +455,7 @@ class LoadFromKaggle(Loader):
312
 
313
  self.downloader = download
314
 
315
- def process(self):
316
  with TemporaryDirectory() as temp_directory:
317
  self.downloader(self.url, temp_directory)
318
  dataset = hf_load_dataset(temp_directory, streaming=False)
@@ -321,18 +464,47 @@ class LoadFromKaggle(Loader):
321
 
322
 
323
  class LoadFromIBMCloud(Loader):
 
 
 
 
324
  endpoint_url_env: str
325
  aws_access_key_id_env: str
326
  aws_secret_access_key_env: str
327
  bucket_name: str
328
  data_dir: str = None
329
 
330
- # Can be either:
331
- # 1. a list of file names, the split of each file is determined by the file name pattern
332
- # 2. Mapping: split -> file_name, e.g. {"test" : "test.json", "train": "train.json"}
333
- # 3. Mapping: split -> file_names, e.g. {"test" : ["test1.json", "test2.json"], "train": ["train.json"]}
334
  data_files: Union[Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
335
  caching: bool = True
 
 
336
  _requirements_list: List[str] = ["ibm_boto3"]
337
 
338
  def _download_from_cos(self, cos, bucket_name, item_name, local_file):
@@ -400,7 +572,10 @@ class LoadFromIBMCloud(Loader):
400
  if self.streaming:
401
  raise NotImplementedError("LoadFromKaggle cannot load with streaming.")
402
 
403
- def process(self):
 
 
 
404
  import ibm_boto3
405
 
406
  cos = ibm_boto3.resource(
@@ -458,23 +633,37 @@ class LoadFromIBMCloud(Loader):
458
 
459
 
460
  class MultipleSourceLoader(Loader):
461
- """Allow loading data from multiple sources.
 
 
 
462
 
463
  Examples:
464
- 1) Loading the train split from Huggingface hub and the test set from a local file:
 
 
 
 
465
 
466
- MultipleSourceLoader(loaders = [ LoadHF(path="public/data",split="train"), LoadCSV({"test": "mytest.csv"}) ])
467
 
468
- 2) Loading a test set combined from two files
469
 
470
- MultipleSourceLoader(loaders = [ LoadCSV({"test": "mytest1.csv"}, LoadCSV({"test": "mytest2.csv"}) ])
471
 
 
472
 
 
473
  """
474
 
475
  sources: List[Loader]
476
 
477
- def process(self):
 
 
 
 
478
  return FixedFusion(
479
  origins=self.sources, max_instances_per_origin_split=self.get_limit()
480
  ).process()
@@ -485,19 +674,138 @@ class LoadFromDictionary(Loader):
485
 
486
  The loader can be used, for example, when debugging or working with small datasets.
487
 
488
- Attributes:
489
  data (Dict[str, List[Dict[str, Any]]]): a dictionary of constants from which the data will be loaded
490
 
491
- Examples:
492
- data = {
493
- "train": {"input": "SomeInput1", "output": "SomeResult1"},
494
- "test": {"input": "SomeInput2", "output": "SomeResult2"},
495
- }
496
- loader = LoadFromDictionary(data=data)
497
- multi_stream = loader.process()
 
 
 
498
  """
499
 
500
  data: Dict[str, List[Dict[str, Any]]]
501
 
502
- def process(self) -> MultiStream:
503
- return MultiStream.from_iterables(self.data)
 
 
 
 
 
 
15
  All these loaders inherit from Loader, and hence, implementing a loader to expand over a new type of datasource, is
16
  straight forward.
17
 
18
+ Available Loaders Overview:
19
+ - :ref:`LoadHF <unitxt.loaders.LoadHF>` - Loads data from Huggingface datasets.
20
+ - :ref:`LoadCSV <unitxt.loaders.LoadCSV>` - Imports data from CSV (Comma-Separated Values) files.
21
+ - :ref:`LoadFromKaggle <unitxt.loaders.LoadFromKaggle>` - Retrieves datasets from the Kaggle community site.
22
+ - :ref:`LoadFromIBMCloud <unitxt.loaders.LoadFromIBMCloud>` - Fetches datasets hosted on IBM Cloud.
23
+ - :ref:`LoadFromSklearn <unitxt.loaders.LoadFromSklearn>` - Loads datasets available through the sklearn library.
24
+ - :ref:`MultipleSourceLoader <unitxt.loaders.MultipleSourceLoader>` - Combines data from multiple different sources.
25
+ - :ref:`LoadFromDictionary <unitxt.loaders.LoadFromDictionary>` - Loads data from a user-defined Python dictionary.
26
+ - :ref:`LoadFromHFSpace <unitxt.loaders.LoadFromHFSpace>` - Downloads and loads data from Huggingface Spaces.
27
+
28
+
29
+
30
+
31
  ------------------------
32
  """
33
  import itertools
34
  import os
35
  import tempfile
36
+ from abc import abstractmethod
37
+ from copy import deepcopy
38
  from pathlib import Path
39
  from tempfile import TemporaryDirectory
40
  from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
 
47
  from .fusion import FixedFusion
48
  from .logging_utils import get_logger
49
  from .operator import SourceOperator
50
+ from .operators import AddFields
51
  from .settings_utils import get_settings
52
  from .stream import GeneratorStream, MultiStream
53
 
 
56
 
57
 
58
  class Loader(SourceOperator):
59
+ """A base class for all loaders.
60
+
61
+ A loader is the first component in the Unitxt Recipe,
62
+ responsible for loading data from various sources and preparing it as a MultiStream for processing.
63
+ The loader_limit is an optional parameter used to control the maximum number of instances to load from the data source. It is applied for each split separately.
64
+ It is usually provided to the loader via the recipe (see standard.py)
65
+ The loader can use this value to limit the amount of data downloaded from the source
66
+ to reduce loading time. However, this may not always be possible, so the
67
+ loader may ignore this. In any case, the recipe will limit the number of instances in the returned
68
+ stream, after load is complete.
69
+
70
+ Args:
71
+ loader_limit: Optional integer to specify a limit on the number of records to load.
72
+ streaming: Bool indicating if streaming should be used.
73
+ """
74
+
75
  loader_limit: int = None
76
  streaming: bool = False
77
 
 
96
  f"\nLoading limited to {self.get_limit()} instances by setting {self.get_limiter()};"
97
  )
98
 
99
+ def add_data_classification(self, multi_stream: MultiStream) -> MultiStream:
100
+ if self.data_classification_policy is None:
101
+ get_logger().warning(
102
+ f"The {self.get_pretty_print_name()} loader does not set the `data_classification_policy`. "
103
+ f"This may lead to sending of undesired data to external services.\n"
104
+ f"Set it to a list of classification identifiers. \n"
105
+ f"For example:\n"
106
+ f"data_classification_policy = ['public']\n"
107
+ f" or \n"
108
+ f"data_classification_policy = ['confidential', 'pii']\n"
109
+ )
110
+
111
+ operator = AddFields(
112
+ fields={"data_classification_policy": self.data_classification_policy}
113
+ )
114
+ return operator(multi_stream)
115
+
116
+ def sef_default_data_classification(
117
+ self, default_data_classification_policy, additional_info
118
+ ):
119
+ if self.data_classification_policy is None:
120
+ logger.info(
121
+ f"{self.get_pretty_print_name()} sets 'data_classification_policy' to "
122
+ f"{default_data_classification_policy} by default {additional_info}.\n"
123
+ "To use a different value or remove this message, explicitly set the "
124
+ "`data_classification_policy` attribute of the loader.\n"
125
+ )
126
+ self.data_classification_policy = default_data_classification_policy
127
+
128
+ @abstractmethod
129
+ def load_data(self):
130
+ pass
131
+
132
+ def process(self) -> MultiStream:
133
+ return self.add_data_classification(self.load_data())
134
+
135
 
136
  class LoadHF(Loader):
137
+ """Loads datasets from the Huggingface Hub.
138
+
139
+ It supports loading with or without streaming,
140
+ and can filter datasets upon loading.
141
+
142
+ Args:
143
+ path: The path or identifier of the dataset on the Huggingface Hub.
144
+ name: An optional dataset name.
145
+ data_dir: Optional directory to store downloaded data.
146
+ split: Optional specification of which split to load.
147
+ data_files: Optional specification of particular data files to load.
148
+ streaming: Bool indicating if streaming should be used.
149
+ filtering_lambda: A lambda function for filtering the data after loading.
150
+
151
+ Example:
152
+ Loading glue's mrpc dataset
153
+
154
+ .. code-block:: python
155
+
156
+ load_hf = LoadHF(path='glue', name='mrpc')
157
+ """
158
+
159
  path: str
160
  name: Optional[str] = None
161
  data_dir: Optional[str] = None
 
266
  }
267
  )
268
 
269
+ def load_data(self):
270
+ if os.path.exists(self.path):
271
+ self.sef_default_data_classification(
272
+ ["proprietary"], "when loading from local files"
273
+ )
274
+ else:
275
+ self.sef_default_data_classification(
276
+ ["public"], "when loading from Huggingface hub"
277
+ )
278
  try:
279
  dataset = self.stream_dataset()
280
  except (
 
289
 
290
 
291
  class LoadCSV(Loader):
292
+ """Loads data from CSV files.
293
+
294
+ Supports streaming and can handle large files by loading them in chunks.
295
+
296
+ Args:
297
+ files (Dict[str, str]): A dictionary mapping names to file paths.
298
+ chunksize : Size of the chunks to load at a time.
299
+ loader_limit: Optional integer to specify a limit on the number of records to load.
300
+ streaming: Bool indicating if streaming should be used.
301
+ sep: String specifying the separator used in the CSV files.
302
+
303
+ Example:
304
+ Loading csv
305
+
306
+ .. code-block:: python
307
+
308
+ load_csv = LoadCSV(files={'train': 'path/to/train.csv'}, chunksize=100)
309
+ """
310
+
311
  files: Dict[str, str]
312
  chunksize: int = 1000
313
  _cache = InternalField(default_factory=dict)
 
342
 
343
  yield from self._cache[file]
344
 
345
+ def load_data(self):
346
+ self.sef_default_data_classification(
347
+ ["proprietary"], "when loading from local files"
348
+ )
349
  if self.streaming:
350
  return MultiStream(
351
  {
 
367
 
368
 
369
  class LoadFromSklearn(Loader):
370
+ """Loads datasets from the sklearn library.
371
+
372
+ This loader does not support streaming and is intended for use with sklearn's dataset fetch functions.
373
+
374
+ Args:
375
+ dataset_name: The name of the sklearn dataset to fetch.
376
+ splits: A list of data splits to load, e.g., ['train', 'test'].
377
+
378
+ Example:
379
+ Loading form sklearn
380
+
381
+ .. code-block:: python
382
+
383
+ load_sklearn = LoadFromSklearn(dataset_name='iris', splits=['train', 'test'])
384
+ """
385
+
386
  dataset_name: str
387
  splits: List[str] = ["train", "test"]
388
+ data_classification_policy = ["public"]
389
 
390
  _requirements_list: List[str] = ["sklearn", "pandas"]
391
 
 
401
 
402
  self.downloader = getattr(sklearn_datatasets, f"fetch_{self.dataset_name}")
403
 
404
+ def load_data(self):
405
  with TemporaryDirectory() as temp_directory:
406
  for split in self.splits:
407
  split_data = self.downloader(subset=split)
 
419
 
420
 
421
  class LoadFromKaggle(Loader):
422
+ """Loads datasets from Kaggle.
423
+
424
+ Requires Kaggle API credentials and does not support streaming.
425
+
426
+ Args:
427
+ url: URL to the Kaggle dataset.
428
+
429
+ Example:
430
+ Loading from kaggle
431
+
432
+ .. code-block:: python
433
+
434
+ load_kaggle = LoadFromKaggle(url='kaggle.com/dataset/example')
435
+ """
436
+
437
  url: str
438
+
439
  _requirements_list: List[str] = ["opendatasets"]
440
+ data_classification_policy = ["public"]
441
 
442
  def verify(self):
443
  super().verify()
 
455
 
456
  self.downloader = download
457
 
458
+ def load_data(self):
459
  with TemporaryDirectory() as temp_directory:
460
  self.downloader(self.url, temp_directory)
461
  dataset = hf_load_dataset(temp_directory, streaming=False)
 
464
 
465
 
466
  class LoadFromIBMCloud(Loader):
467
+ """Loads data from IBM Cloud Object Storage.
468
+
469
+ Does not support streaming and requires AWS-style access keys.
470
+ data_files can be either:
471
+ 1. a list of file names, the split of each file is determined by the file name pattern
472
+ 2. Mapping: split -> file_name, e.g. {"test" : "test.json", "train": "train.json"}
473
+ 3. Mapping: split -> file_names, e.g. {"test" : ["test1.json", "test2.json"], "train": ["train.json"]}
474
+
475
+ Args:
476
+ endpoint_url_env: Environment variable name for the IBM Cloud endpoint URL.
477
+ aws_access_key_id_env: Environment variable name for the AWS access key ID.
478
+ aws_secret_access_key_env: Environment variable name for the AWS secret access key.
479
+ bucket_name: Name of the S3 bucket from which to load data.
480
+ data_dir: Optional directory path within the bucket.
481
+ data_files: Union type allowing either a list of file names or a mapping of splits to file names.
482
+ caching: Bool indicating if caching is enabled to avoid re-downloading data.
483
+
484
+ Example:
485
+ Loading from IBM Cloud
486
+
487
+ .. code-block:: python
488
+
489
+ load_ibm_cloud = LoadFromIBMCloud(
490
+ endpoint_url_env='IBM_CLOUD_ENDPOINT',
491
+ aws_access_key_id_env='IBM_AWS_ACCESS_KEY_ID',
492
+ aws_secret_access_key_env='IBM_AWS_SECRET_ACCESS_KEY',
493
+ bucket_name='my-bucket'
494
+ )
495
+ multi_stream = load_ibm_cloud.process()
496
+ """
497
+
498
  endpoint_url_env: str
499
  aws_access_key_id_env: str
500
  aws_secret_access_key_env: str
501
  bucket_name: str
502
  data_dir: str = None
503
 
 
 
 
 
504
  data_files: Union[Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
505
  caching: bool = True
506
+ data_classification_policy = ["proprietary"]
507
+
508
  _requirements_list: List[str] = ["ibm_boto3"]
509
 
510
  def _download_from_cos(self, cos, bucket_name, item_name, local_file):
 
572
  if self.streaming:
573
  raise NotImplementedError("LoadFromKaggle cannot load with streaming.")
574
 
575
+ def load_data(self):
576
+ self.sef_default_data_classification(
577
+ ["proprietary"], "when loading from IBM COS"
578
+ )
579
  import ibm_boto3
580
 
581
  cos = ibm_boto3.resource(
 
633
 
634
 
635
  class MultipleSourceLoader(Loader):
636
+ """Allows loading data from multiple sources, potentially mixing different types of loaders.
637
+
638
+ Args:
639
+ sources: A list of loaders that will be combined to form a unified dataset.
640
 
641
  Examples:
642
+ 1) Loading the train split from Huggingface hub and the test set from a local file:
643
+
644
+ .. code-block:: python
645
+
646
+ MultipleSourceLoader(loaders = [ LoadHF(path="public/data",split="train"), LoadCSV({"test": "mytest.csv"}) ])
647
 
 
648
 
 
649
 
650
+ 2) Loading a test set combined from two files
651
 
652
+ .. code-block:: python
653
 
654
+ MultipleSourceLoader(loaders = [ LoadCSV({"test": "mytest1.csv"}), LoadCSV({"test": "mytest2.csv"}) ])
655
  """
656
 
657
  sources: List[Loader]
658
 
659
+ # MultipleSourceLoaders uses the the data classification from source loaders,
660
+ # so only need to add it, if explicitly requested to override.
661
+ def add_data_classification(self, multi_stream: MultiStream) -> MultiStream:
662
+ if self.data_classification_policy is None:
663
+ return multi_stream
664
+ return super().add_data_classification(multi_stream)
665
+
666
+ def load_data(self):
667
  return FixedFusion(
668
  origins=self.sources, max_instances_per_origin_split=self.get_limit()
669
  ).process()
 
674
 
675
  The loader can be used, for example, when debugging or working with small datasets.
676
 
677
+ Args:
678
  data (Dict[str, List[Dict[str, Any]]]): a dictionary of constants from which the data will be loaded
679
 
680
+ Example:
681
+ Loading dictionary
682
+
683
+ .. code-block:: python
684
+
685
+ data = {
686
+ "train": {"input": "SomeInput1", "output": "SomeResult1"},
687
+ "test": {"input": "SomeInput2", "output": "SomeResult2"},
688
+ }
689
+ loader = LoadFromDictionary(data=data)
690
  """
691
 
692
  data: Dict[str, List[Dict[str, Any]]]
693
 
694
+ def load_data(self) -> MultiStream:
695
+ self.sef_default_data_classification(
696
+ ["proprietary"], "when loading from python dictionary"
697
+ )
698
+ return MultiStream.from_iterables(deepcopy(self.data))
699
+
700
+
701
+ class LoadFromHFSpace(LoadHF):
702
+ """Used to load data from Huggingface spaces.
703
+
704
+ The loader first tries to download all files specified in the 'data_files' parameter
705
+ from the given space and then reads them as a Huggingface dataset.
706
+
707
+ Args:
708
+ space_name (str): Name of the Huggingface space to be accessed.
709
+ data_files (str | Sequence[str] | Mapping[str, str | Sequence[str]]): Relative
710
+ paths to files within a given repository. If given as a mapping, paths should
711
+ be values, while keys should represent the type of respective files
712
+ (training, testing etc.).
713
+ path (str, optional): Absolute path to a directory where data should be downloaded to.
714
+ revision (str, optional): ID of a Git branch or commit to be used. By default, it is
715
+ set to None, thus data is downloaded from the main branch of the accessed
716
+ repository.
717
+ use_token (bool, optional): Whether token used for authentication when accessing
718
+ the Huggingface space - if necessary - should be read from the Huggingface
719
+ config folder.
720
+ token_env (str, optional): Key of an env variable whose value will be used for
721
+ authentication when accessing the Huggingface space - if necessary.
722
+
723
+ Example:
724
+ Loading from Huggingface Space
725
+
726
+ .. code-block:: python
727
+
728
+ loader = LoadFromHFSpace(
729
+ space_name="lmsys/mt-bench",
730
+ data_files={
731
+ "train": [
732
+ "data/mt_bench/model_answer/gpt-3.5-turbo.jsonl",
733
+ "data/mt_bench/model_answer/gpt-4.jsonl",
734
+ ],
735
+ "test": "data/mt_bench/model_answer/tulu-30b.jsonl",
736
+ },
737
+ )
738
+ """
739
+
740
+ space_name: str
741
+ data_files: Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
742
+ path: Optional[str] = None
743
+ revision: Optional[str] = None
744
+ use_token: Optional[bool] = None
745
+ token_env: Optional[str] = None
746
+ requirements_list: List[str] = ["huggingface_hub"]
747
+
748
+ def _get_token(self) -> Optional[Union[bool, str]]:
749
+ if self.token_env:
750
+ token = os.getenv(self.token_env)
751
+ if not token:
752
+ get_logger().warning(
753
+ f"The 'token_env' parameter was specified as '{self.token_env}', "
754
+ f"however, no environment variable under such a name was found. "
755
+ f"Therefore, the loader will not use any tokens for authentication."
756
+ )
757
+ return token
758
+ return self.use_token
759
+
760
+ def _download_file_from_space(self, filename: str) -> str:
761
+ from huggingface_hub import hf_hub_download
762
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
763
+
764
+ token = self._get_token()
765
+
766
+ try:
767
+ file_path = hf_hub_download(
768
+ repo_id=self.space_name,
769
+ filename=filename,
770
+ repo_type="space",
771
+ token=token,
772
+ revision=self.revision,
773
+ local_dir=self.path,
774
+ )
775
+ except EntryNotFoundError as e:
776
+ raise ValueError(
777
+ f"The file '{filename}' was not found in the space '{self.space_name}'. "
778
+ f"Please check if the filename is correct, or if it exists in that "
779
+ f"Huggingface space."
780
+ ) from e
781
+ except RepositoryNotFoundError as e:
782
+ raise ValueError(
783
+ f"The Huggingface space '{self.space_name}' was not found. "
784
+ f"Please check if the name is correct and you have access to the space."
785
+ ) from e
786
+
787
+ return file_path
788
+
789
+ def _download_data(self) -> str:
790
+ if isinstance(self.data_files, str):
791
+ data_files = [self.data_files]
792
+ elif isinstance(self.data_files, Mapping):
793
+ data_files = list(self.data_files.values())
794
+ else:
795
+ data_files = self.data_files
796
+
797
+ for files in data_files:
798
+ if isinstance(files, str):
799
+ files = [files]
800
+ # All files - within the same space - are downloaded into the same base directory:
801
+ paths = [self._download_file_from_space(file) for file in files]
802
+ dir_path = paths[0].replace(files[0], "")
803
+
804
+ return dir_path
805
+
806
+ def load_data(self):
807
+ self.sef_default_data_classification(
808
+ ["public"], "when loading from Huggingface spaces"
809
+ )
810
+ self.path = self._download_data()
811
+ return super().load_data()
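
The loader changes make every loader implement `load_data`, while the shared `process` stamps each instance with a `data_classification_policy` field: the loader's explicit value if set, otherwise a per-source default ("public" for hub, sklearn and Kaggle data, "proprietary" for local files, IBM COS and in-memory dictionaries). A short sketch of how that surfaces downstream, using `LoadFromDictionary` from this commit (the data is illustrative and assumes unitxt is installed):

.. code-block:: python

    from unitxt.loaders import LoadFromDictionary

    # Illustrative in-memory data; with no explicit policy, LoadFromDictionary defaults
    # to ["proprietary"] before add_data_classification() stamps the field on instances.
    data = {
        "test": [
            {"input": "SomeInput1", "output": "SomeResult1"},
            {"input": "SomeInput2", "output": "SomeResult2"},
        ],
    }
    loader = LoadFromDictionary(data=data)
    multi_stream = loader.process()
    for instance in multi_stream["test"]:
        print(instance["data_classification_policy"])  # ['proprietary'] on every instance
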
metrics.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import re
2
  import string
3
  import uuid
@@ -6,6 +7,7 @@ from abc import ABC, abstractmethod
6
  from collections import Counter, defaultdict
7
  from copy import deepcopy
8
  from dataclasses import field
 
9
  from statistics import mean
10
  from typing import Any, Dict, Generator, List, Optional, Tuple
11
 
@@ -20,10 +22,10 @@ from .dataclass import AbstractField, InternalField, NonPositionalField, Optiona
20
  from .logging_utils import get_logger
21
  from .metric_utils import InstanceInput, MetricRequest, MetricResponse
22
  from .operator import (
 
23
  MultiStreamOperator,
24
- SingleStreamOperator,
25
  StreamingOperator,
26
- StreamInstanceOperator,
27
  )
28
  from .operators import CopyFields
29
  from .random_utils import get_seed
@@ -68,7 +70,7 @@ def nan_max(x):
68
  return np.nanmax(x)
69
 
70
 
71
- class UpdateStream(StreamInstanceOperator):
72
  update: dict
73
 
74
  def process(
@@ -94,6 +96,28 @@ class Metric(Artifact):
94
  # parsing on every use
95
  _parsed_prediction_type = None
96
 
 
 
 
 
 
 
 
97
  def _validate_references_and_prediction(self, references, predictions):
98
  if not isoftype(predictions, List[Any]):
99
  raise ValueError(
@@ -151,7 +175,7 @@ class Metric(Artifact):
151
  self._parsed_prediction_type = parse_type_string(self.prediction_type)
152
  except ValueError:
153
  raise ValueError(
154
- "Could convert prediction type '{self.prediction_type}' in {self.get_metric_name()} to known type. To enable type checking for this prediction type, open unitxt issue with this message. Alternatively, set the metric's prediction_type to 'Any'"
155
  ) from None
156
  return self._parsed_prediction_type
157
 
@@ -166,6 +190,7 @@ class Metric(Artifact):
166
  additional_inputs = []
167
  instances = []
168
  for instance in stream:
 
169
  references.append(instance["references"])
170
  predictions.append(instance["prediction"])
171
  additional_inputs.append(
@@ -421,7 +446,7 @@ class MetricWithConfidenceInterval(Metric):
421
  return result
422
 
423
 
424
- class GlobalMetric(SingleStreamOperator, MetricWithConfidenceInterval):
425
  """A class for computing metrics that require joint calculations over all instances and are not just aggregation of scores of individuals instances.
426
 
427
  For example, macro_F1 requires
@@ -445,15 +470,16 @@ class GlobalMetric(SingleStreamOperator, MetricWithConfidenceInterval):
445
  instances = []
446
 
447
  for instance in stream:
 
 
448
  if "score" not in instance:
449
- instance["score"] = {"global": global_score, "instance": {}}
450
- else:
451
- global_score = instance["score"]["global"]
452
 
453
  instance_references, instance_prediction = (
454
  instance["references"],
455
  instance["prediction"],
456
  )
 
457
  references.append(instance_references)
458
  predictions.append(instance_prediction)
459
  instances.append(instance)
@@ -463,6 +489,7 @@ class GlobalMetric(SingleStreamOperator, MetricWithConfidenceInterval):
463
  )
464
  task_data.append(instance_task_data)
465
  instance_score = None
 
466
  # for backward compatibility
467
  no_score_value = np.nan
468
  if self.process_single_instances:
@@ -483,13 +510,14 @@ class GlobalMetric(SingleStreamOperator, MetricWithConfidenceInterval):
483
  if isinstance(self.main_score, str):
484
  instance_score[self.main_score] = no_score_value
485
 
486
- instance["score"]["instance"].update(instance_score)
 
 
487
  self._validate_references_and_prediction(references, predictions)
488
 
489
  result = self._compute(references, predictions, task_data)
490
 
491
- global_score.update(result)
492
-
493
  score_name = global_score["score_name"]
494
  confidence_interval = self.compute_global_confidence_intervals(
495
  references, predictions, task_data, score_name
@@ -497,7 +525,7 @@ class GlobalMetric(SingleStreamOperator, MetricWithConfidenceInterval):
497
  global_score.update(confidence_interval)
498
 
499
  for instance in instances:
500
- instance["score"]["global"] = global_score
501
  yield instance
502
 
503
  def _compute(
@@ -531,11 +559,12 @@ class GlobalMetric(SingleStreamOperator, MetricWithConfidenceInterval):
531
  pass
532
 
533
 
534
- class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
535
  n_resamples: int = OptionalField(
536
  default_factory=lambda: settings.num_resamples_for_instance_metrics
537
  )
538
  main_score: str
 
539
  reduction_map: Dict[str, List[str]]
540
 
541
  implemented_reductions: List[str] = field(default_factory=lambda: ["mean"])
@@ -549,7 +578,9 @@ class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
549
  list,
550
  zip(
551
  *[
552
- (instance["references"], instance["prediction"])
 
 
553
  for instance in stream
554
  ]
555
  ),
@@ -574,12 +605,11 @@ class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
574
 
575
  for instance, score in zip(stream, instance_scores):
576
  if "score" not in instance:
577
- instance["score"] = {"global": global_score, "instance": {}}
578
- else:
579
- global_score = instance["score"]["global"]
580
-
581
- instance["score"]["instance"].update(score)
582
 
 
 
 
583
  instances.append(instance)
584
 
585
  for reduction, fields in self.reduction_map.items():
@@ -589,27 +619,32 @@ class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
589
 
590
  if reduction == "mean":
591
  for field_name in fields:
592
- global_score[field_name] = mean(
 
593
  [
594
- instance["score"]["instance"][field_name]
595
  for instance in instances
596
  ]
597
  )
598
  if field_name == self.main_score:
599
- global_score["score"] = global_score[field_name]
600
- global_score["score_name"] = self.main_score
601
 
602
  ci_fields = (
603
  list(set(self.ci_scores))
604
  if self.ci_scores is not None
605
  else [self.main_score]
606
  )
 
 
 
607
  confidence_interval = self.score_based_confidence_interval(
608
- instances=instances, score_names=ci_fields
609
  )
610
  global_score.update(confidence_interval)
611
 
612
  for instance in instances:
 
613
  yield instance
614
 
615
  @abstractmethod
@@ -622,7 +657,7 @@ class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
622
  pass
623
 
624
 
625
- class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
626
  """Class for metrics for which a global score can be calculated by aggregating the instance scores (possibly with additional instance inputs).
627
 
628
  InstanceMetric currently allows two reductions:
@@ -748,8 +783,8 @@ class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
748
  ), f"each instance task_data dict must have a key {self.subgroup_column}"
749
 
750
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
751
- instances, global_score = self.compute_instance_scores(stream)
752
-
753
  for reduction_type, reduction_params in self.reduction_map.items():
754
  assert (
755
  reduction_type in self.implemented_reductions
@@ -795,7 +830,9 @@ class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
795
 
796
  # calculate global scores for each reduction field
797
  for field_name in reduction_fields:
798
- field_name_full = field_name_full_prefix + field_name
 
 
799
  # if group resampling (3rd element of agg_func parameter) is True, then
800
  # 1. scores_to_resample are the group scores, and
801
  # 2. aggregation_function is to take the raw mean
@@ -804,7 +841,7 @@ class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
804
  # 2. aggregation_function is to apply the group aggregation from the instance scores
805
  # either way, the application of aggregation_function to scores_to_resample yields the global score
806
  global_score[field_name_full] = aggregation_function(
807
- scores_to_resample, field_name
808
  )
809
  if field_name == self.main_score:
810
  global_score["score"] = global_score[field_name_full]
@@ -815,21 +852,26 @@ class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
815
  if self.ci_scores is not None:
816
  confidence_interval = self.score_based_confidence_interval(
817
  instances=scores_to_resample,
818
- score_names=list(set(self.ci_scores)),
 
 
819
  ci_score_prefix=field_name_full_prefix,
820
  aggregation_func=aggregation_function,
821
  )
822
  global_score.update(confidence_interval)
823
 
 
 
824
  yield from instances
825
 
826
  def compute_instance_scores(
827
  self, stream: Stream, stream_name: Optional[str] = None
828
  ):
829
- global_score = {}
830
  instances = []
831
 
832
  for instance in stream:
 
 
833
  task_data = instance["task_data"] if "task_data" in instance else {}
834
 
835
  if self.reference_field == "references":
@@ -849,18 +891,19 @@ class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
849
  instance_score = self.compute(
850
  references=refs, prediction=pred, task_data=task_data
851
  )
 
852
  instance_score["score"] = instance_score[self.main_score]
853
  instance_score["score_name"] = self.main_score
854
  if "score" not in instance:
855
- instance["score"] = {"global": global_score, "instance": {}}
856
- else:
857
- global_score = instance["score"]["global"]
858
 
859
- instance["score"]["instance"].update(instance_score)
 
 
860
 
861
  instances.append(instance)
862
 
863
- return instances, global_score
864
 
865
  def get_group_scores(
866
  self, instances: List[dict], score_names: List[str], group_aggregation_func
@@ -1082,8 +1125,14 @@ class MetricPipeline(MultiStreamOperator, Metric):
1082
  super().prepare()
1083
  self.prepare_score = CopyFields(
1084
  field_to_field=[
1085
- [f"score/instance/{self.main_score}", "score/instance/score"],
1086
- [f"score/global/{self.main_score}", "score/global/score"],
 
 
 
 
 
 
1087
  ],
1088
  )
1089
 
@@ -2098,6 +2147,7 @@ class LlamaIndexCorrectness(InstanceMetric):
2098
  ] = [] # this is here for the sake of documentation for future models
2099
  mock_models: List[str] = ["mock"]
2100
  external_api_models = openai_models + anthropic_models
 
2101
 
2102
  _requirements_list: List[str] = ["llama_index"]
2103
 
@@ -2179,11 +2229,6 @@ class LlamaIndexCorrectness(InstanceMetric):
2179
  # treat the references as the questions and the predictions as answers
2180
  # assume a single reference
2181
 
2182
- assert (
2183
- not self._model_using_extrnal_api()
2184
- or settings.allow_passing_data_to_remote_api
2185
- ), f"Cannot run send data to remote APIs ({self.model_name}) when unitxt.settings.allow_passing_data_to_remote_api=False. Set UNITXT_ALLOW_PASSING_DATA_TO_REMOTE_API environment variable, if you want to allow this."
2186
-
2187
  query = task_data["question"]
2188
 
2189
  contexts = None
@@ -2733,7 +2778,7 @@ class KPA(CustomF1):
2733
  return element == "none"
2734
 
2735
 
2736
- class RemoteMetric(SingleStreamOperator, Metric):
2737
  """A metric that runs another metric remotely.
2738
 
2739
  main_score: the score updated by this metric.
@@ -2746,10 +2791,12 @@ class RemoteMetric(SingleStreamOperator, Metric):
2746
  endpoint: str
2747
  metric_name: str
2748
  api_key: str = None
 
2749
 
2750
  @staticmethod
2751
  def wrap_inner_metric_pipeline_metric(
2752
- metric_pipeline: MetricPipeline, remote_metrics_endpoint: str
 
2753
  ) -> MetricPipeline:
2754
  """Wrap the inner metric in a MetricPipeline with a RemoteMetric.
2755
 
@@ -3662,3 +3709,40 @@ class NormalizedSacrebleu(HuggingfaceMetric):
3662
  "mecab_ko": KO_ERROR_MESSAGE,
3663
  "mecab_ko_dic": KO_ERROR_MESSAGE,
3664
  }
 
 
 
 
 
 
 
1
+ import ast
2
  import re
3
  import string
4
  import uuid
 
7
  from collections import Counter, defaultdict
8
  from copy import deepcopy
9
  from dataclasses import field
10
+ from operator import itemgetter
11
  from statistics import mean
12
  from typing import Any, Dict, Generator, List, Optional, Tuple
13
 
 
22
  from .logging_utils import get_logger
23
  from .metric_utils import InstanceInput, MetricRequest, MetricResponse
24
  from .operator import (
25
+ InstanceOperator,
26
  MultiStreamOperator,
 
27
  StreamingOperator,
28
+ StreamOperator,
29
  )
30
  from .operators import CopyFields
31
  from .random_utils import get_seed
 
70
  return np.nanmax(x)
71
 
72
 
73
+ class UpdateStream(InstanceOperator):
74
  update: dict
75
 
76
  def process(
 
96
  # parsing on every use
97
  _parsed_prediction_type = None
98
 
99
+ #
100
+ # Used to add a prefix to all scores, except the "score_name" and "score" fields.
101
+ # This is used to distinguish two scores of the same metric, operating on different fields of the task.
102
+ #
103
+ score_prefix: str = ""
104
+
105
+ def _add_score_prefix(self, score_name):
106
+ return (
107
+ self.score_prefix + score_name
108
+ if score_name not in ["score", "score_name"]
109
+ else score_name
110
+ )
111
+
112
+ def _add_score_prefixes_to_score_dict(self, scores: Dict[str, Any]):
113
+ new_scores = {}
114
+ for score_name, score in scores.items():
115
+ score_with_prefix = self._add_score_prefix(score_name)
116
+ new_scores[score_with_prefix] = (
117
+ score if score_name not in ["score_name"] else self.score_prefix + score
118
+ )
119
+ return new_scores
120
+
121
  def _validate_references_and_prediction(self, references, predictions):
122
  if not isoftype(predictions, List[Any]):
123
  raise ValueError(
 
175
  self._parsed_prediction_type = parse_type_string(self.prediction_type)
176
  except ValueError:
177
  raise ValueError(
178
+ f"Could not convert prediction type '{self.prediction_type}' in {self.get_metric_name()} to a known type. To enable type checking for this prediction type, open a unitxt issue with this message. Alternatively, set the metric's prediction_type to 'Any'"
179
  ) from None
180
  return self._parsed_prediction_type
181
 
 
190
  additional_inputs = []
191
  instances = []
192
  for instance in stream:
193
+ instance = self.verify_instance(instance)
194
  references.append(instance["references"])
195
  predictions.append(instance["prediction"])
196
  additional_inputs.append(
 
446
  return result
447
 
448
 
449
+ class GlobalMetric(StreamOperator, MetricWithConfidenceInterval):
450
  """A class for computing metrics that require joint calculations over all instances and are not just aggregation of scores of individuals instances.
451
 
452
  For example, macro_F1 requires
 
470
  instances = []
471
 
472
  for instance in stream:
473
+ instance = self.verify_instance(instance)
474
+
475
  if "score" not in instance:
476
+ instance["score"] = {"global": {}, "instance": {}}
 
 
477
 
478
  instance_references, instance_prediction = (
479
  instance["references"],
480
  instance["prediction"],
481
  )
482
+
483
  references.append(instance_references)
484
  predictions.append(instance_prediction)
485
  instances.append(instance)
 
489
  )
490
  task_data.append(instance_task_data)
491
  instance_score = None
492
+
493
  # for backward compatibility
494
  no_score_value = np.nan
495
  if self.process_single_instances:
 
510
  if isinstance(self.main_score, str):
511
  instance_score[self.main_score] = no_score_value
512
 
513
+ instance["score"]["instance"].update(
514
+ self._add_score_prefixes_to_score_dict(instance_score)
515
+ )
516
  self._validate_references_and_prediction(references, predictions)
517
 
518
  result = self._compute(references, predictions, task_data)
519
 
520
+ global_score.update(self._add_score_prefixes_to_score_dict(result))
 
521
  score_name = global_score["score_name"]
522
  confidence_interval = self.compute_global_confidence_intervals(
523
  references, predictions, task_data, score_name
 
525
  global_score.update(confidence_interval)
526
 
527
  for instance in instances:
528
+ instance["score"]["global"].update(global_score)
529
  yield instance
530
 
531
  def _compute(
 
559
  pass
560
 
561
 
562
+ class BulkInstanceMetric(StreamOperator, MetricWithConfidenceInterval):
563
  n_resamples: int = OptionalField(
564
  default_factory=lambda: settings.num_resamples_for_instance_metrics
565
  )
566
  main_score: str
567
+
568
  reduction_map: Dict[str, List[str]]
569
 
570
  implemented_reductions: List[str] = field(default_factory=lambda: ["mean"])
 
578
  list,
579
  zip(
580
  *[
581
+ itemgetter("references", "prediction")(
582
+ self.verify_instance(instance)
583
+ )
584
  for instance in stream
585
  ]
586
  ),
 
605
 
606
  for instance, score in zip(stream, instance_scores):
607
  if "score" not in instance:
608
+ instance["score"] = {"global": {}, "instance": {}}
 
 
 
 
609
 
610
+ instance["score"]["instance"].update(
611
+ self._add_score_prefixes_to_score_dict(score)
612
+ )
613
  instances.append(instance)
614
 
615
  for reduction, fields in self.reduction_map.items():
 
619
 
620
  if reduction == "mean":
621
  for field_name in fields:
622
+ field_name_with_prefix = self._add_score_prefix(field_name)
623
+ global_score[field_name_with_prefix] = mean(
624
  [
625
+ instance["score"]["instance"][field_name_with_prefix]
626
  for instance in instances
627
  ]
628
  )
629
  if field_name == self.main_score:
630
+ global_score["score"] = global_score[field_name_with_prefix]
631
+ global_score["score_name"] = self.score_prefix + self.main_score
632
 
633
  ci_fields = (
634
  list(set(self.ci_scores))
635
  if self.ci_scores is not None
636
  else [self.main_score]
637
  )
638
+ ci_fields_with_prefix = [
639
+ self._add_score_prefix(ci_field) for ci_field in ci_fields
640
+ ]
641
  confidence_interval = self.score_based_confidence_interval(
642
+ instances=instances, score_names=ci_fields_with_prefix
643
  )
644
  global_score.update(confidence_interval)
645
 
646
  for instance in instances:
647
+ instance["score"]["global"].update(global_score)
648
  yield instance
649
 
650
  @abstractmethod
 
657
  pass
658
 
659
 
660
+ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
661
  """Class for metrics for which a global score can be calculated by aggregating the instance scores (possibly with additional instance inputs).
662
 
663
  InstanceMetric currently allows two reductions:
 
783
  ), f"each instance task_data dict must have a key {self.subgroup_column}"
784
 
785
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
786
+ instances = self.compute_instance_scores(stream)
787
+ global_score = {}
788
  for reduction_type, reduction_params in self.reduction_map.items():
789
  assert (
790
  reduction_type in self.implemented_reductions
 
830
 
831
  # calculate global scores for each reduction field
832
  for field_name in reduction_fields:
833
+ field_name_full = (
834
+ field_name_full_prefix + self.score_prefix + field_name
835
+ )
836
  # if group resampling (3rd element of agg_func parameter) is True, then
837
  # 1. scores_to_resample are the group scores, and
838
  # 2. aggregation_function is to take the raw mean
 
841
  # 2. aggregation_function is to apply the group aggregation from the instance scores
842
  # either way, the application of aggregation_function to scores_to_resample yields the global score
843
  global_score[field_name_full] = aggregation_function(
844
+ scores_to_resample, self.score_prefix + field_name
845
  )
846
  if field_name == self.main_score:
847
  global_score["score"] = global_score[field_name_full]
 
852
  if self.ci_scores is not None:
853
  confidence_interval = self.score_based_confidence_interval(
854
  instances=scores_to_resample,
855
+ score_names=[
856
+ self.score_prefix + ci_score for ci_score in set(self.ci_scores)
857
+ ],
858
  ci_score_prefix=field_name_full_prefix,
859
  aggregation_func=aggregation_function,
860
  )
861
  global_score.update(confidence_interval)
862
 
863
+ for instance in instances:
864
+ instance["score"]["global"].update(global_score)
865
  yield from instances
866
 
867
  def compute_instance_scores(
868
  self, stream: Stream, stream_name: Optional[str] = None
869
  ):
 
870
  instances = []
871
 
872
  for instance in stream:
873
+ instance = self.verify_instance(instance)
874
+
875
  task_data = instance["task_data"] if "task_data" in instance else {}
876
 
877
  if self.reference_field == "references":
 
891
  instance_score = self.compute(
892
  references=refs, prediction=pred, task_data=task_data
893
  )
894
+
895
  instance_score["score"] = instance_score[self.main_score]
896
  instance_score["score_name"] = self.main_score
897
  if "score" not in instance:
898
+ instance["score"] = {"global": {}, "instance": {}}
 
 
899
 
900
+ instance["score"]["instance"].update(
901
+ self._add_score_prefixes_to_score_dict(instance_score)
902
+ )
903
 
904
  instances.append(instance)
905
 
906
+ return instances
907
 
908
  def get_group_scores(
909
  self, instances: List[dict], score_names: List[str], group_aggregation_func
 
1125
  super().prepare()
1126
  self.prepare_score = CopyFields(
1127
  field_to_field=[
1128
+ [
1129
+ f"score/instance/{self.metric._add_score_prefix(self.main_score)}",
1130
+ "score/instance/score",
1131
+ ],
1132
+ [
1133
+ f"score/global/{self.metric._add_score_prefix(self.main_score)}",
1134
+ "score/global/score",
1135
+ ],
1136
  ],
1137
  )
1138
 
 
2147
  ] = [] # this is here for the sake of documentation for future models
2148
  mock_models: List[str] = ["mock"]
2149
  external_api_models = openai_models + anthropic_models
2150
+ data_classification_policy = ["public"]
2151
 
2152
  _requirements_list: List[str] = ["llama_index"]
2153
 
 
2229
  # treat the references as the questions and the predictions as answers
2230
  # assume a single reference
2231
 
2232
  query = task_data["question"]
2233
 
2234
  contexts = None
 
2778
  return element == "none"
2779
 
2780
 
2781
+ class RemoteMetric(StreamOperator, Metric):
2782
  """A metric that runs another metric remotely.
2783
 
2784
  main_score: the score updated by this metric.
 
2791
  endpoint: str
2792
  metric_name: str
2793
  api_key: str = None
2794
+ data_classification_policy = ["public", "proprietary"]
2795
 
2796
  @staticmethod
2797
  def wrap_inner_metric_pipeline_metric(
2798
+ metric_pipeline: MetricPipeline,
2799
+ remote_metrics_endpoint: str,
2800
  ) -> MetricPipeline:
2801
  """Wrap the inner metric in a MetricPipeline with a RemoteMetric.
2802
 
 
3709
  "mecab_ko": KO_ERROR_MESSAGE,
3710
  "mecab_ko_dic": KO_ERROR_MESSAGE,
3711
  }
3712
+
3713
+
3714
+ class CustomF1Fuzzy(CustomF1):
3715
+ def calculate_groups_ratio(self, actual_group, total_group):
3716
+ from fuzzywuzzy import fuzz
3717
+
3718
+ tmp = []
3719
+ for actual_key in actual_group.keys():
3720
+ max_score = self.fuzz_ratio
3721
+ best_total_key = None
3722
+
3723
+ for total_key in total_group.keys():
3724
+ tup_ac = ast.literal_eval(actual_key)
3725
+ tup_to = ast.literal_eval(total_key)
3726
+
3727
+ if tup_ac[1] == tup_to[1]:
3728
+ score = fuzz.ratio(tup_ac[0], tup_to[0])
3729
+ if score > max_score:
3730
+ max_score = score
3731
+ best_total_key = total_key
3732
+
3733
+ if best_total_key is not None:
3734
+ tmp.append(min(actual_group[actual_key], total_group[best_total_key]))
3735
+ else:
3736
+ tmp.append(min(actual_group[actual_key], 0))
3737
+ return sum(tmp), sum(actual_group.values())
3738
+
3739
+
3740
+ class FuzzyNer(CustomF1Fuzzy):
3741
+ prediction_type = "List[Tuple[str,str]]"
3742
+ fuzz_ratio = 75
3743
+
3744
+ def get_element_group(self, element, additional_input):
3745
+ return element[1]
3746
+
3747
+ def get_element_representation(self, element, additional_input):
3748
+ return str(element)
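For context, a hedged sketch of the matching rule CustomF1Fuzzy introduces: a predicted span counts toward a reference span of the same entity type only when fuzz.ratio between their surface forms exceeds fuzz_ratio (75 for FuzzyNer). The entity strings below are invented for illustration.

# Illustration only: the fuzzy-matching threshold used by CustomF1Fuzzy / FuzzyNer.
from fuzzywuzzy import fuzz

predicted_span = ("Barack Obma", "PER")   # misspelled surface form
reference_span = ("Barack Obama", "PER")  # same entity type, so the ratio is checked
similarity = fuzz.ratio(predicted_span[0], reference_span[0])
print(similarity)  # ~96, above the default fuzz_ratio of 75, so the spans are matched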
normalizers.py CHANGED
@@ -1,9 +1,9 @@
1
  from typing import Any, Dict, List, Optional
2
 
3
- from .operator import StreamInstanceOperator
4
 
5
 
6
- class NormalizeListFields(StreamInstanceOperator):
7
  fields: List[str]
8
  key_prefix: str = ""
9
  empty_value: str = ""
 
1
  from typing import Any, Dict, List, Optional
2
 
3
+ from .operator import InstanceOperator
4
 
5
 
6
+ class NormalizeListFields(InstanceOperator):
7
  fields: List[str]
8
  key_prefix: str = ""
9
  empty_value: str = ""
operator.py CHANGED
@@ -1,4 +1,3 @@
1
- import re
2
  from abc import abstractmethod
3
  from dataclasses import field
4
  from typing import Any, Dict, Generator, List, Optional, Union
@@ -208,12 +207,13 @@ class MultiStreamOperator(StreamingOperator):
208
  pass
209
 
210
  def process_instance(self, instance, stream_name="tmp"):
 
211
  multi_stream = MultiStream({stream_name: stream_single(instance)})
212
  processed_multi_stream = self(multi_stream)
213
  return next(iter(processed_multi_stream[stream_name]))
214
 
215
 
216
- class SingleStreamOperator(MultiStreamOperator):
217
  """A class representing a single-stream operator in the streaming system.
218
 
219
  A single-stream operator is a type of `MultiStreamOperator` that operates on individual
@@ -236,9 +236,7 @@ class SingleStreamOperator(MultiStreamOperator):
236
  stream = self._process_single_stream(stream, stream_name)
237
  else:
238
  stream = stream
239
- assert isinstance(
240
- stream, Stream
241
- ), "SingleStreamOperator must return a Stream"
242
  result[stream_name] = stream
243
 
244
  return MultiStream(result)
@@ -279,16 +277,21 @@ class SingleStreamOperator(MultiStreamOperator):
279
  pass
280
 
281
  def process_instance(self, instance, stream_name="tmp"):
 
282
  processed_stream = self._process_single_stream(
283
  stream_single(instance), stream_name
284
  )
285
  return next(iter(processed_stream))
286
 
287
 
288
- class PagedStreamOperator(SingleStreamOperator):
 
 
 
 
289
  """A class representing a paged-stream operator in the streaming system.
290
 
291
- A paged-stream operator is a type of `SingleStreamOperator` that operates on a page of instances
292
  in a `Stream` at a time, where a page is a subset of instances.
293
  The `process` method should be implemented by subclasses to define the specific operations
294
  to be performed on each page.
@@ -320,6 +323,7 @@ class PagedStreamOperator(SingleStreamOperator):
320
  pass
321
 
322
  def process_instance(self, instance, stream_name="tmp"):
 
323
  processed_stream = self._process_page([instance], stream_name)
324
  return next(iter(processed_stream))
325
 
@@ -343,10 +347,10 @@ class SingleStreamReducer(StreamingOperator):
343
  pass
344
 
345
 
346
- class StreamInstanceOperator(SingleStreamOperator):
347
  """A class representing a stream instance operator in the streaming system.
348
 
349
- A stream instance operator is a type of `SingleStreamOperator` that operates on individual instances within a `Stream`. It iterates through each instance in the `Stream` and applies the `process` method. The `process` method should be implemented by subclasses to define the specific operations to be performed on each instance.
350
  """
351
 
352
  def _process_stream(
@@ -367,6 +371,7 @@ class StreamInstanceOperator(SingleStreamOperator):
367
  def _process_instance(
368
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
369
  ) -> Dict[str, Any]:
 
370
  return self.process(instance, stream_name)
371
 
372
  @abstractmethod
@@ -379,10 +384,10 @@ class StreamInstanceOperator(SingleStreamOperator):
379
  return self._process_instance(instance, stream_name)
380
 
381
 
382
- class StreamInstanceOperatorValidator(StreamInstanceOperator):
383
  """A class representing a stream instance operator validator in the streaming system.
384
 
385
- A stream instance operator validator is a type of `StreamInstanceOperator` that includes a validation step. It operates on individual instances within a `Stream` and validates the result of processing each instance.
386
  """
387
 
388
  @abstractmethod
@@ -405,20 +410,6 @@ class StreamInstanceOperatorValidator(StreamInstanceOperator):
405
  )
406
 
407
 
408
- class InstanceOperator(Artifact):
409
- """A class representing an instance operator in the streaming system.
410
-
411
- An instance operator is a type of `Artifact` that operates on a single instance (represented as a dict) at a time. It takes an instance as input and produces a transformed instance as output.
412
- """
413
-
414
- def __call__(self, data: dict) -> dict:
415
- return self.process(data)
416
-
417
- @abstractmethod
418
- def process(self, data: dict) -> dict:
419
- pass
420
-
421
-
422
  class BaseFieldOperator(Artifact):
423
  """A class representing a field operator in the streaming system.
424
 
@@ -426,6 +417,7 @@ class BaseFieldOperator(Artifact):
426
  """
427
 
428
  def __call__(self, data: Dict[str, Any], field: str) -> dict:
 
429
  value = self.process(data[field])
430
  data[field] = value
431
  return data
@@ -456,7 +448,10 @@ class InstanceOperatorWithMultiStreamAccess(StreamingOperator):
456
  return MultiStream(result)
457
 
458
  def generator(self, stream, multi_stream):
459
- yield from (self.process(instance, multi_stream) for instance in stream)
 
 
 
460
 
461
  @abstractmethod
462
  def process(self, instance: dict, multi_stream: MultiStream) -> dict:
@@ -488,8 +483,7 @@ class SequentialOperator(MultiStreamOperator):
488
  last_step = (
489
  self.max_steps - 1 if self.max_steps is not None else len(self.steps) - 1
490
  )
491
- description = str(self.steps[last_step])
492
- return re.sub(r"\w+=None, ", "", description)
493
 
494
  def _get_max_steps(self):
495
  return self.max_steps if self.max_steps is not None else len(self.steps)
 
 
1
  from abc import abstractmethod
2
  from dataclasses import field
3
  from typing import Any, Dict, Generator, List, Optional, Union
 
207
  pass
208
 
209
  def process_instance(self, instance, stream_name="tmp"):
210
+ instance = self.verify_instance(instance)
211
  multi_stream = MultiStream({stream_name: stream_single(instance)})
212
  processed_multi_stream = self(multi_stream)
213
  return next(iter(processed_multi_stream[stream_name]))
214
 
215
 
216
+ class StreamOperator(MultiStreamOperator):
217
  """A class representing a single-stream operator in the streaming system.
218
 
219
  A single-stream operator is a type of `MultiStreamOperator` that operates on individual
 
236
  stream = self._process_single_stream(stream, stream_name)
237
  else:
238
  stream = stream
239
+ assert isinstance(stream, Stream), "StreamOperator must return a Stream"
 
 
240
  result[stream_name] = stream
241
 
242
  return MultiStream(result)
 
277
  pass
278
 
279
  def process_instance(self, instance, stream_name="tmp"):
280
+ instance = self.verify_instance(instance)
281
  processed_stream = self._process_single_stream(
282
  stream_single(instance), stream_name
283
  )
284
  return next(iter(processed_stream))
285
 
286
 
287
+ class SingleStreamOperator(StreamOperator):
288
+ pass
289
+
290
+
291
+ class PagedStreamOperator(StreamOperator):
292
  """A class representing a paged-stream operator in the streaming system.
293
 
294
+ A paged-stream operator is a type of `StreamOperator` that operates on a page of instances
295
  in a `Stream` at a time, where a page is a subset of instances.
296
  The `process` method should be implemented by subclasses to define the specific operations
297
  to be performed on each page.
 
323
  pass
324
 
325
  def process_instance(self, instance, stream_name="tmp"):
326
+ instance = self.verify_instance(instance)
327
  processed_stream = self._process_page([instance], stream_name)
328
  return next(iter(processed_stream))
329
 
 
347
  pass
348
 
349
 
350
+ class InstanceOperator(StreamOperator):
351
  """A class representing a stream instance operator in the streaming system.
352
 
353
+ A stream instance operator is a type of `StreamOperator` that operates on individual instances within a `Stream`. It iterates through each instance in the `Stream` and applies the `process` method. The `process` method should be implemented by subclasses to define the specific operations to be performed on each instance.
354
  """
355
 
356
  def _process_stream(
 
371
  def _process_instance(
372
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
373
  ) -> Dict[str, Any]:
374
+ instance = self.verify_instance(instance)
375
  return self.process(instance, stream_name)
376
 
377
  @abstractmethod
 
384
  return self._process_instance(instance, stream_name)
385
 
386
 
387
+ class InstanceOperatorValidator(InstanceOperator):
388
  """A class representing a stream instance operator validator in the streaming system.
389
 
390
+ A stream instance operator validator is a type of `InstanceOperator` that includes a validation step. It operates on individual instances within a `Stream` and validates the result of processing each instance.
391
  """
392
 
393
  @abstractmethod
 
410
  )
411
 
412
 
 
 
 
 
413
  class BaseFieldOperator(Artifact):
414
  """A class representing a field operator in the streaming system.
415
 
 
417
  """
418
 
419
  def __call__(self, data: Dict[str, Any], field: str) -> dict:
420
+ data = self.verify_instance(data)
421
  value = self.process(data[field])
422
  data[field] = value
423
  return data
 
448
  return MultiStream(result)
449
 
450
  def generator(self, stream, multi_stream):
451
+ yield from (
452
+ self.process(self.verify_instance(instance), multi_stream)
453
+ for instance in stream
454
+ )
455
 
456
  @abstractmethod
457
  def process(self, instance: dict, multi_stream: MultiStream) -> dict:
 
483
  last_step = (
484
  self.max_steps - 1 if self.max_steps is not None else len(self.steps) - 1
485
  )
486
+ return self.steps[last_step].__description__
 
487
 
488
  def _get_max_steps(self):
489
  return self.max_steps if self.max_steps is not None else len(self.steps)
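To make the renames concrete, here is a minimal sketch (not part of the commit) of a user-defined operator written against the new names: code that previously subclassed StreamInstanceOperator now subclasses InstanceOperator, while SingleStreamOperator remains available as a thin alias of StreamOperator.

# Hypothetical operator under the renamed hierarchy; the class and field names are invented.
from typing import Any, Dict, Optional

from unitxt.operator import InstanceOperator  # was: StreamInstanceOperator


class AddGreeting(InstanceOperator):
    """Adds a constant 'greeting' field to every instance in the stream."""

    greeting: str = "hello"

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        instance["greeting"] = self.greeting
        return instance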
operators.py CHANGED
@@ -29,9 +29,10 @@ Other specelized operators are used by unitxt internally:
29
 
30
  The rest of this section is dedicated for general operators.
31
 
32
- General Operaotrs List:
33
  ------------------------
34
  """
 
35
  import copy
36
  import operator
37
  import uuid
@@ -60,18 +61,18 @@ from .artifact import Artifact, fetch_artifact
60
  from .dataclass import DeprecatedField, NonPositionalField, OptionalField
61
  from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
62
  from .operator import (
 
63
  MultiStream,
64
  MultiStreamOperator,
65
  PackageRequirementsMixin,
66
  PagedStreamOperator,
67
  SequentialOperator,
68
  SideEffectOperator,
69
- SingleStreamOperator,
70
  SingleStreamReducer,
71
  SourceOperator,
72
  StreamingOperator,
73
  StreamInitializerOperator,
74
- StreamInstanceOperator,
75
  )
76
  from .random_utils import new_random_generator
77
  from .settings_utils import get_settings
@@ -116,10 +117,10 @@ class IterableSource(SourceOperator):
116
  return MultiStream.from_iterables(self.iterables)
117
 
118
 
119
- class MapInstanceValues(StreamInstanceOperator):
120
  """A class used to map instance values into other values.
121
 
122
- This class is a type of StreamInstanceOperator,
123
  it maps values of instances in a stream using predefined mappers.
124
 
125
  Attributes:
@@ -138,7 +139,7 @@ class MapInstanceValues(StreamInstanceOperator):
138
  replaces '1' with 'hi' and '2' with 'bye' in field 'a' in all instances of all streams:
139
  instance {"a":"1", "b": 2} becomes {"a":"hi", "b": 2}.
140
 
141
- MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}}, process_every_element=True)
142
  Assuming field 'a' is a list of values, potentially including "1"-s and "2"-s, this replaces
143
  each such "1" with "hi" and "2" -- with "bye" in all instances of all streams:
144
  instance {"a": ["1", "2"], "b": 2} becomes {"a": ["hi", "bye"], "b": 2}.
@@ -204,7 +205,7 @@ class MapInstanceValues(StreamInstanceOperator):
204
  return val
205
 
206
 
207
- class FlattenInstances(StreamInstanceOperator):
208
  """Flattens each instance in a stream, making nested dictionary entries into top-level entries.
209
 
210
  Args:
@@ -221,7 +222,7 @@ class FlattenInstances(StreamInstanceOperator):
221
  return flatten_dict(instance, parent_key=self.parent_key, sep=self.sep)
222
 
223
 
224
- class AddFields(StreamInstanceOperator):
225
  """Adds specified fields to each instance in a given stream or all streams (default) If fields exist, updates them.
226
 
227
  Args:
@@ -264,7 +265,7 @@ class AddFields(StreamInstanceOperator):
264
  return instance
265
 
266
 
267
- class RemoveFields(StreamInstanceOperator):
268
  """Remove specified fields from each instance in a stream.
269
 
270
  Args:
@@ -281,7 +282,7 @@ class RemoveFields(StreamInstanceOperator):
281
  return instance
282
 
283
 
284
- class InstanceFieldOperator(StreamInstanceOperator):
285
  """A general stream instance operator that processes the values of a field (or multiple ones).
286
 
287
  Args:
@@ -393,6 +394,11 @@ class InstanceFieldOperator(StreamInstanceOperator):
393
  def process(
394
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
395
  ) -> Dict[str, Any]:
 
 
 
 
 
396
  for from_field, to_field in self._field_to_field:
397
  try:
398
  old_value = dict_get(
@@ -485,7 +491,7 @@ class AddConstant(FieldOperator):
485
  return self.add + value
486
 
487
 
488
- class Augmentor(StreamInstanceOperator):
489
  """A stream operator that augments the values of either the task input fields before rendering with the template, or the input passed to the model after rendering of the template.
490
 
491
  Args:
@@ -732,7 +738,7 @@ class JoinStr(FieldOperator):
732
  return self.separator.join(str(x) for x in value)
733
 
734
 
735
- class Apply(StreamInstanceOperator):
736
  """A class used to apply a python function and store the result in a field.
737
 
738
  Args:
@@ -802,7 +808,7 @@ class Apply(StreamInstanceOperator):
802
  return instance
803
 
804
 
805
- class ListFieldValues(StreamInstanceOperator):
806
  """Concatenates values of multiple fields into a list, and assigns it to a new field."""
807
 
808
  fields: List[str]
@@ -824,7 +830,7 @@ class ListFieldValues(StreamInstanceOperator):
824
  return instance
825
 
826
 
827
- class ZipFieldValues(StreamInstanceOperator):
828
  """Zips values of multiple fields in a given instance, similar to list(zip(*fields)).
829
 
830
  The value in each of the specified 'fields' is assumed to be a list. The lists from all 'fields'
@@ -860,7 +866,7 @@ class ZipFieldValues(StreamInstanceOperator):
860
  return instance
861
 
862
 
863
- class InterleaveListsToDialogOperator(StreamInstanceOperator):
864
  """Interleaves two lists, one of user dialog turns and one of assistant dialog turns, into a single list of tuples, alternating between "user" and "assistant".
865
 
866
  The list of tuples if of format (role, turn_content), where the role label is specified by
@@ -905,7 +911,7 @@ class InterleaveListsToDialogOperator(StreamInstanceOperator):
905
  return instance
906
 
907
 
908
- class IndexOf(StreamInstanceOperator):
909
  """For a given instance, finds the offset of value of field 'index_of', within the value of field 'search_in'."""
910
 
911
  search_in: str
@@ -927,7 +933,7 @@ class IndexOf(StreamInstanceOperator):
927
  return instance
928
 
929
 
930
- class TakeByField(StreamInstanceOperator):
931
  """From field 'field' of a given instance, select the member indexed by field 'index', and store to field 'to_field'."""
932
 
933
  field: str
@@ -1034,7 +1040,7 @@ class GetItemByIndex(FieldOperator):
1034
  return self.items_list[value]
1035
 
1036
 
1037
- class AddID(StreamInstanceOperator):
1038
  """Stores a unique id value in the designated 'id_field_name' field of the given instance."""
1039
 
1040
  id_field_name: str = "id"
@@ -1046,7 +1052,7 @@ class AddID(StreamInstanceOperator):
1046
  return instance
1047
 
1048
 
1049
- class CastFields(StreamInstanceOperator):
1050
  """Casts specified fields to specified types.
1051
 
1052
  Args:
@@ -1106,7 +1112,7 @@ class CastFields(StreamInstanceOperator):
1106
  return instance
1107
 
1108
 
1109
- class DivideAllFieldsBy(StreamInstanceOperator):
1110
  """Recursively reach down to all fields that are float, and divide each by 'divisor'.
1111
 
1112
  The given instance is viewed as a tree whose internal nodes are dictionaries and lists, and
@@ -1165,7 +1171,7 @@ class ArtifactFetcherMixin:
1165
  return cls.cache[artifact_identifier]
1166
 
1167
 
1168
- class ApplyOperatorsField(StreamInstanceOperator):
1169
  """Applies value operators to each instance in a stream based on specified fields.
1170
 
1171
  Args:
@@ -1206,7 +1212,7 @@ class ApplyOperatorsField(StreamInstanceOperator):
1206
  return operator.process_instance(instance)
1207
 
1208
 
1209
- class FilterByCondition(SingleStreamOperator):
1210
  """Filters a stream, yielding only instances in which the values in required fields follow the required condition operator.
1211
 
1212
  Raises an error if a required field name is missing from the input instance.
@@ -1322,7 +1328,7 @@ class ComputeExpressionMixin(Artifact):
1322
  )
1323
 
1324
 
1325
- class FilterByExpression(SingleStreamOperator, ComputeExpressionMixin):
1326
  """Filters a stream, yielding only instances which fulfil a condition specified as a string to be python's eval-uated.
1327
 
1328
  Raises an error if a field participating in the specified condition is missing from the instance
@@ -1337,9 +1343,7 @@ class FilterByExpression(SingleStreamOperator, ComputeExpressionMixin):
1337
  FilterByExpression(expression = "a <= 4 and b > 5") will yield only instances where the value of field "a" is not exceeding 4 and in field "b" -- greater than 5
1338
  FilterByExpression(expression = "a in [4, 8]") will yield only instances where "a" is 4 or 8
1339
  FilterByExpression(expression = "a not in [4, 8]") will yield only instances where "a" is neither 4 nor 8
1340
- FilterByExpression(expression = "a['b'] not in [4, 8]") will yield only instances where "a" is a dict in
1341
- which key 'b' is mapped to a value that is neither 4 nor 8
1342
-
1343
  """
1344
 
1345
  error_on_filtered_all: bool = True
@@ -1357,7 +1361,7 @@ class FilterByExpression(SingleStreamOperator, ComputeExpressionMixin):
1357
  )
1358
 
1359
 
1360
- class ExecuteExpression(StreamInstanceOperator, ComputeExpressionMixin):
1361
  """Compute an expression, specified as a string to be eval-uated, over the instance's fields, and store the result in field to_field.
1362
 
1363
  Raises an error if a field mentioned in the query is missing from the instance.
@@ -1651,7 +1655,7 @@ class SplitByNestedGroup(MultiStreamOperator):
1651
  return MultiStream.from_iterables(result)
1652
 
1653
 
1654
- class ApplyStreamOperatorsField(SingleStreamOperator, ArtifactFetcherMixin):
1655
  """Applies stream operators to a stream based on specified fields in each instance.
1656
 
1657
  Args:
@@ -1676,14 +1680,14 @@ class ApplyStreamOperatorsField(SingleStreamOperator, ArtifactFetcherMixin):
1676
  operator = self.get_artifact(operator_name)
1677
  assert isinstance(
1678
  operator, StreamingOperator
1679
- ), f"Operator {operator_name} must be a SingleStreamOperator"
1680
 
1681
  stream = operator(MultiStream({"tmp": stream}))["tmp"]
1682
 
1683
  yield from stream
1684
 
1685
 
1686
- class ApplyMetric(SingleStreamOperator, ArtifactFetcherMixin):
1687
  """Applies metric operators to a stream based on a metric field specified in each instance.
1688
 
1689
  Args:
@@ -1855,7 +1859,7 @@ class FeatureGroupedShuffle(Shuffle):
1855
  return list(itertools.chain(*page_blocks))
1856
 
1857
 
1858
- class EncodeLabels(StreamInstanceOperator):
1859
  """Encode each value encountered in any field in 'fields' into the integers 0,1,...
1860
 
1861
  Encoding is determined by a str->int map that is built on the go, as different values are
@@ -1908,7 +1912,7 @@ class EncodeLabels(StreamInstanceOperator):
1908
  return instance
1909
 
1910
 
1911
- class StreamRefiner(SingleStreamOperator):
1912
  """Discard from the input stream all instances beyond the leading 'max_instances' instances.
1913
 
1914
  Thereby, if the input stream consists of no more than 'max_instances' instances, the resulting stream is the whole of the
@@ -1987,6 +1991,80 @@ class DeterministicBalancer(StreamRefiner):
1987
  yield instance
1988
 
1989
 
 
 
 
 
 
1990
  class LengthBalancer(DeterministicBalancer):
1991
  """Balances by a signature that reflects the total length of the fields' values, quantized into integer segments.
1992
 
@@ -2071,7 +2149,7 @@ class ExtractZipFile(SideEffectOperator):
2071
  zf.extractall(self.target_dir)
2072
 
2073
 
2074
- class DuplicateInstances(SingleStreamOperator):
2075
  """Operator which duplicates each instance in stream a given number of times.
2076
 
2077
  Attributes:
 
29
 
30
  The rest of this section is dedicated for general operators.
31
 
32
+ General Operators List:
33
  ------------------------
34
  """
35
+
36
  import copy
37
  import operator
38
  import uuid
 
61
  from .dataclass import DeprecatedField, NonPositionalField, OptionalField
62
  from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
63
  from .operator import (
64
+ InstanceOperator,
65
  MultiStream,
66
  MultiStreamOperator,
67
  PackageRequirementsMixin,
68
  PagedStreamOperator,
69
  SequentialOperator,
70
  SideEffectOperator,
 
71
  SingleStreamReducer,
72
  SourceOperator,
73
  StreamingOperator,
74
  StreamInitializerOperator,
75
+ StreamOperator,
76
  )
77
  from .random_utils import new_random_generator
78
  from .settings_utils import get_settings
 
117
  return MultiStream.from_iterables(self.iterables)
118
 
119
 
120
+ class MapInstanceValues(InstanceOperator):
121
  """A class used to map instance values into other values.
122
 
123
+ This class is a type of InstanceOperator,
124
  it maps values of instances in a stream using predefined mappers.
125
 
126
  Attributes:
 
139
  replaces '1' with 'hi' and '2' with 'bye' in field 'a' in all instances of all streams:
140
  instance {"a":"1", "b": 2} becomes {"a":"hi", "b": 2}.
141
 
142
+ MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}}, process_every_value=True)
143
  Assuming field 'a' is a list of values, potentially including "1"-s and "2"-s, this replaces
144
  each such "1" with "hi" and "2" -- with "bye" in all instances of all streams:
145
  instance {"a": ["1", "2"], "b": 2} becomes {"a": ["hi", "bye"], "b": 2}.
 
205
  return val
206
 
207
 
208
+ class FlattenInstances(InstanceOperator):
209
  """Flattens each instance in a stream, making nested dictionary entries into top-level entries.
210
 
211
  Args:
 
222
  return flatten_dict(instance, parent_key=self.parent_key, sep=self.sep)
223
 
224
 
225
+ class AddFields(InstanceOperator):
226
  """Adds specified fields to each instance in a given stream or all streams (default) If fields exist, updates them.
227
 
228
  Args:
 
265
  return instance
266
 
267
 
268
+ class RemoveFields(InstanceOperator):
269
  """Remove specified fields from each instance in a stream.
270
 
271
  Args:
 
282
  return instance
283
 
284
 
285
+ class InstanceFieldOperator(InstanceOperator):
286
  """A general stream instance operator that processes the values of a field (or multiple ones).
287
 
288
  Args:
 
394
  def process(
395
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
396
  ) -> Dict[str, Any]:
397
+ # Need to deep copy the instance, because when assigning two dictionary fields,
398
+ # dict_set() assigns the target field a reference to the source dictionary.
399
+ # This means that if this target field was assigned to another field before,
400
+ # that field is updated as well.
401
+ instance = deepcopy(instance)
402
  for from_field, to_field in self._field_to_field:
403
  try:
404
  old_value = dict_get(
 
491
  return self.add + value
492
 
493
 
494
+ class Augmentor(InstanceOperator):
495
  """A stream operator that augments the values of either the task input fields before rendering with the template, or the input passed to the model after rendering of the template.
496
 
497
  Args:
 
738
  return self.separator.join(str(x) for x in value)
739
 
740
 
741
+ class Apply(InstanceOperator):
742
  """A class used to apply a python function and store the result in a field.
743
 
744
  Args:
 
808
  return instance
809
 
810
 
811
+ class ListFieldValues(InstanceOperator):
812
  """Concatenates values of multiple fields into a list, and assigns it to a new field."""
813
 
814
  fields: List[str]
 
830
  return instance
831
 
832
 
833
+ class ZipFieldValues(InstanceOperator):
834
  """Zips values of multiple fields in a given instance, similar to list(zip(*fields)).
835
 
836
  The value in each of the specified 'fields' is assumed to be a list. The lists from all 'fields'
 
866
  return instance
867
 
868
 
869
+ class InterleaveListsToDialogOperator(InstanceOperator):
870
  """Interleaves two lists, one of user dialog turns and one of assistant dialog turns, into a single list of tuples, alternating between "user" and "assistant".
871
 
872
  The list of tuples if of format (role, turn_content), where the role label is specified by
 
911
  return instance
912
 
913
 
914
+ class IndexOf(InstanceOperator):
915
  """For a given instance, finds the offset of value of field 'index_of', within the value of field 'search_in'."""
916
 
917
  search_in: str
 
933
  return instance
934
 
935
 
936
+ class TakeByField(InstanceOperator):
937
  """From field 'field' of a given instance, select the member indexed by field 'index', and store to field 'to_field'."""
938
 
939
  field: str
 
1040
  return self.items_list[value]
1041
 
1042
 
1043
+ class AddID(InstanceOperator):
1044
  """Stores a unique id value in the designated 'id_field_name' field of the given instance."""
1045
 
1046
  id_field_name: str = "id"
 
1052
  return instance
1053
 
1054
 
1055
+ class CastFields(InstanceOperator):
1056
  """Casts specified fields to specified types.
1057
 
1058
  Args:
 
1112
  return instance
1113
 
1114
 
1115
+ class DivideAllFieldsBy(InstanceOperator):
1116
  """Recursively reach down to all fields that are float, and divide each by 'divisor'.
1117
 
1118
  The given instance is viewed as a tree whose internal nodes are dictionaries and lists, and
 
1171
  return cls.cache[artifact_identifier]
1172
 
1173
 
1174
+ class ApplyOperatorsField(InstanceOperator):
1175
  """Applies value operators to each instance in a stream based on specified fields.
1176
 
1177
  Args:
 
1212
  return operator.process_instance(instance)
1213
 
1214
 
1215
+ class FilterByCondition(StreamOperator):
1216
  """Filters a stream, yielding only instances in which the values in required fields follow the required condition operator.
1217
 
1218
  Raises an error if a required field name is missing from the input instance.
 
1328
  )
1329
 
1330
 
1331
+ class FilterByExpression(StreamOperator, ComputeExpressionMixin):
1332
  """Filters a stream, yielding only instances which fulfil a condition specified as a string to be python's eval-uated.
1333
 
1334
  Raises an error if a field participating in the specified condition is missing from the instance
 
1343
  FilterByExpression(expression = "a <= 4 and b > 5") will yield only instances where the value of field "a" is not exceeding 4 and in field "b" -- greater than 5
1344
  FilterByExpression(expression = "a in [4, 8]") will yield only instances where "a" is 4 or 8
1345
  FilterByExpression(expression = "a not in [4, 8]") will yield only instances where "a" is neither 4 nor 8
1346
+ FilterByExpression(expression = "a['b'] not in [4, 8]") will yield only instances where "a" is a dict in which key 'b' is mapped to a value that is neither 4 nor 8
 
 
1347
  """
1348
 
1349
  error_on_filtered_all: bool = True
 
1361
  )
1362
 
1363
 
1364
+ class ExecuteExpression(InstanceOperator, ComputeExpressionMixin):
1365
  """Compute an expression, specified as a string to be eval-uated, over the instance's fields, and store the result in field to_field.
1366
 
1367
  Raises an error if a field mentioned in the query is missing from the instance.
 
1655
  return MultiStream.from_iterables(result)
1656
 
1657
 
1658
+ class ApplyStreamOperatorsField(StreamOperator, ArtifactFetcherMixin):
1659
  """Applies stream operators to a stream based on specified fields in each instance.
1660
 
1661
  Args:
 
1680
  operator = self.get_artifact(operator_name)
1681
  assert isinstance(
1682
  operator, StreamingOperator
1683
+ ), f"Operator {operator_name} must be a StreamOperator"
1684
 
1685
  stream = operator(MultiStream({"tmp": stream}))["tmp"]
1686
 
1687
  yield from stream
1688
 
1689
 
1690
+ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
1691
  """Applies metric operators to a stream based on a metric field specified in each instance.
1692
 
1693
  Args:
 
1859
  return list(itertools.chain(*page_blocks))
1860
 
1861
 
1862
+ class EncodeLabels(InstanceOperator):
1863
  """Encode each value encountered in any field in 'fields' into the integers 0,1,...
1864
 
1865
  Encoding is determined by a str->int map that is built on the go, as different values are
 
1912
  return instance
1913
 
1914
 
1915
+ class StreamRefiner(StreamOperator):
1916
  """Discard from the input stream all instances beyond the leading 'max_instances' instances.
1917
 
1918
  Thereby, if the input stream consists of no more than 'max_instances' instances, the resulting stream is the whole of the
 
1991
  yield instance
1992
 
1993
 
1994
+ class MinimumOneExamplePerLabelRefiner(StreamRefiner):
1995
+ """A class used to return a specified number of instances while ensuring at least one example per label.
1996
+
1997
+ For each instance, a signature value is constructed from the values of the instance in the specified input 'fields'.
1998
+ MinimumOneExamplePerLabelRefiner takes the first instance that appears for each label (each unique signature), and then adds more elements up to the max_instances limit. In general, the refiner takes the first elements in the stream that meet the required conditions.
1999
+ MinimumOneExamplePerLabelRefiner then shuffles the results to avoid having one instance
2000
+ from each class first and then the rest. If max_instances is not set, the original stream will be used.
2001
+
2002
+ Attributes:
2003
+ fields (List[str]): A list of field names to be used in producing the instance's signature.
2004
+ max_instances (Optional, int): Number of elements to select. Note that max_instances of StreamRefiners that are passed to the recipe (e.g. `train_refiner`, `test_refiner`) are overridden by the recipe parameters (`max_train_instances`, `max_test_instances`).
2005
+
2006
+ Usage:
2007
+ balancer = MinimumOneExamplePerLabelRefiner(fields=["field1", "field2"], max_instances=200)
2008
+ balanced_stream = balancer.process(stream)
2009
+
2010
+ Example:
2011
+ When input [{"a": 1, "b": 1},{"a": 1, "b": 2},{"a": 1, "b": 3},{"a": 1, "b": 4},{"a": 2, "b": 5}] is fed into
2012
+ MinimumOneExamplePerLabelRefiner(fields=["a"], max_instances=3)
2013
+ the resulting stream will be:
2014
+ [{'a': 1, 'b': 1}, {'a': 1, 'b': 2}, {'a': 2, 'b': 5}] (order may be different)
2015
+ """
2016
+
2017
+ fields: List[str]
2018
+
2019
+ def signature(self, instance):
2020
+ return str(tuple(dict_get(instance, field) for field in self.fields))
2021
+
2022
+ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
2023
+ if self.max_instances is None:
2024
+ for instance in stream:
2025
+ yield instance
2026
+
2027
+ counter = Counter()
2028
+ for instance in stream:
2029
+ counter[self.signature(instance)] += 1
2030
+ all_keys = counter.keys()
2031
+ if len(counter) == 0:
2032
+ return
2033
+
2034
+ if self.max_instances is not None and len(all_keys) > self.max_instances:
2035
+ raise Exception(
2036
+ f"Can not generate a stream with at least one example per label, because the max instances requested {self.max_instances} is smaller than the number of different labels {len(all_keys)}"
2037
+ f" ({len(all_keys)}"
2038
+ )
2039
+
2040
+ counter = Counter()
2041
+ used_indices = set()
2042
+ selected_elements = []
2043
+ # select at least one per class
2044
+ for idx, instance in enumerate(stream):
2045
+ sign = self.signature(instance)
2046
+ if counter[sign] == 0:
2047
+ counter[sign] += 1
2048
+ used_indices.add(idx)
2049
+ selected_elements.append(
2050
+ instance
2051
+ ) # collect all elements first to allow shuffling of both groups
2052
+
2053
+ # select more to reach self.max_instances examples
2054
+ for idx, instance in enumerate(stream):
2055
+ if idx not in used_indices:
2056
+ if self.max_instances is None or len(used_indices) < self.max_instances:
2057
+ used_indices.add(idx)
2058
+ selected_elements.append(
2059
+ instance
2060
+ ) # collect all elements first to allow shuffling of both groups
2061
+
2062
+ # shuffle elements to avoid having one element from each class appear first
2063
+ random_generator = new_random_generator(sub_seed=selected_elements)
2064
+ random_generator.shuffle(selected_elements)
2065
+ yield from selected_elements
2066
+
2067
+
2068
  class LengthBalancer(DeterministicBalancer):
2069
  """Balances by a signature that reflects the total length of the fields' values, quantized into integer segments.
2070
 
 
2149
  zf.extractall(self.target_dir)
2150
 
2151
 
2152
+ class DuplicateInstances(StreamOperator):
2153
  """Operator which duplicates each instance in stream a given number of times.
2154
 
2155
  Attributes:
schema.py CHANGED
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
4
 
5
  from datasets import Features, Sequence, Value
6
 
7
- from .operator import StreamInstanceOperatorValidator
8
 
9
  UNITXT_DATASET_SCHEMA = Features(
10
  {
@@ -15,20 +15,12 @@ UNITXT_DATASET_SCHEMA = Features(
15
  "group": Value("string"),
16
  "postprocessors": Sequence(Value("string")),
17
  "task_data": Value(dtype="string"),
 
18
  }
19
  )
20
 
21
- # UNITXT_METRIC_SCHEMA = Features({
22
- # "predictions": Value("string", id="sequence"),
23
- # "target": Value("string", id="sequence"),
24
- # "references": Value("string", id="sequence"),
25
- # "metrics": Value("string", id="sequence"),
26
- # 'group': Value('string'),
27
- # 'postprocessors': Value("string", id="sequence"),
28
- # })
29
 
30
-
31
- class ToUnitxtGroup(StreamInstanceOperatorValidator):
32
  group: str
33
  metrics: List[str] = None
34
  postprocessors: List[str] = field(default_factory=lambda: ["to_string_stripped"])
 
4
 
5
  from datasets import Features, Sequence, Value
6
 
7
+ from .operator import InstanceOperatorValidator
8
 
9
  UNITXT_DATASET_SCHEMA = Features(
10
  {
 
15
  "group": Value("string"),
16
  "postprocessors": Sequence(Value("string")),
17
  "task_data": Value(dtype="string"),
18
+ "data_classification_policy": Sequence(Value("string")),
19
  }
20
  )
21
 
 
 
 
 
22
 
23
+ class ToUnitxtGroup(InstanceOperatorValidator):
 
24
  group: str
25
  metrics: List[str] = None
26
  postprocessors: List[str] = field(default_factory=lambda: ["to_string_stripped"])
settings_utils.py CHANGED
@@ -128,12 +128,12 @@ if Settings.is_uninitilized():
128
  settings.default_recipe = "standard_recipe"
129
  settings.default_verbosity = "info"
130
  settings.remote_metrics = []
131
- settings.allow_passing_data_to_remote_api = (bool, False)
132
  settings.test_card_disable = (bool, False)
133
  settings.test_metric_disable = (bool, False)
134
  settings.metrics_master_key_token = None
135
  settings.seed = (int, 42)
136
  settings.skip_artifacts_prepare_and_verify = (bool, False)
 
137
 
138
  if Constants.is_uninitilized():
139
  constants = Constants()
 
128
  settings.default_recipe = "standard_recipe"
129
  settings.default_verbosity = "info"
130
  settings.remote_metrics = []
 
131
  settings.test_card_disable = (bool, False)
132
  settings.test_metric_disable = (bool, False)
133
  settings.metrics_master_key_token = None
134
  settings.seed = (int, 42)
135
  settings.skip_artifacts_prepare_and_verify = (bool, False)
136
+ settings.data_classification_policy = None
137
 
138
  if Constants.is_uninitilized():
139
  constants = Constants()
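A hedged sketch of how the new setting can be adjusted at runtime; the environment-variable spelling in the comment follows the usual UNITXT_* naming convention and is an assumption, not something stated in this diff.

# Sketch: overriding the new setting programmatically.
from unitxt.settings_utils import get_settings

settings = get_settings()
settings.data_classification_policy = None  # the default: no policy restriction is enforced
# Presumably the same value can also be supplied through an environment variable
# (e.g. UNITXT_DATA_CLASSIFICATION_POLICY); that name is an assumption, not from this diff.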
span_lableing_operators.py CHANGED
@@ -1,9 +1,9 @@
1
  from typing import Any, Dict, List, Optional
2
 
3
- from .operator import StreamInstanceOperator
4
 
5
 
6
- class IobExtractor(StreamInstanceOperator):
7
  """A class designed to extract entities from sequences of text using the Inside-Outside-Beginning (IOB) tagging convention. It identifies entities based on IOB tags and categorizes them into predefined labels such as Person, Organization, and Location.
8
 
9
  Attributes:
 
1
  from typing import Any, Dict, List, Optional
2
 
3
+ from .operator import InstanceOperator
4
 
5
 
6
+ class IobExtractor(InstanceOperator):
7
  """A class designed to extract entities from sequences of text using the Inside-Outside-Beginning (IOB) tagging convention. It identifies entities based on IOB tags and categorizes them into predefined labels such as Person, Organization, and Location.
8
 
9
  Attributes:
standard.py CHANGED
@@ -124,11 +124,23 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
124
 
125
  def set_pipelines(self):
126
  self.loading = SequentialOperator()
 
127
  self.metadata = SequentialOperator()
 
 
 
128
  self.standardization = SequentialOperator()
 
 
 
129
  self.processing = SequentialOperator()
 
 
 
130
  self.verblization = SequentialOperator()
 
131
  self.finalize = SequentialOperator()
 
132
 
133
  self.steps = [
134
  self.loading,
@@ -211,7 +223,6 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
211
  AddFields(
212
  fields={
213
  "recipe_metadata": {
214
- "card": self.card,
215
  "template": self.template,
216
  "system_prompt": self.system_prompt,
217
  "format": self.format,
@@ -228,7 +239,7 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
228
  self.augmentor.set_task_input_fields(self.card.task.augmentable_inputs)
229
  self.processing.steps.append(self.augmentor)
230
 
231
- if self.demos_pool_size is not None:
232
  self.processing.steps.append(
233
  CreateDemosPool(
234
  from_split=self.demos_taken_from,
 
124
 
125
  def set_pipelines(self):
126
  self.loading = SequentialOperator()
127
+ self.loading.__description__ = "Loading the data from the data source."
128
  self.metadata = SequentialOperator()
129
+ self.metadata.__description__ = (
130
+ "Adding metadata (e.g. format, system prompt, template) "
131
+ )
132
  self.standardization = SequentialOperator()
133
+ self.standardization.__description__ = (
134
+ "Standardizing the raw dataset fields to task field definition."
135
+ )
136
  self.processing = SequentialOperator()
137
+ self.processing.__description__ = (
138
+ "Setting task fields (and selecting demos per sample if needed)."
139
+ )
140
  self.verblization = SequentialOperator()
141
+ self.verblization.__description__ = "Verbalizing the input to the model and gold references to the 'source', 'target' and 'references' fields."
142
  self.finalize = SequentialOperator()
143
+ self.finalize.__description__ = "Adding post processors. Removing intermediate fields. Creating the final output dataset."
144
 
145
  self.steps = [
146
  self.loading,
 
223
  AddFields(
224
  fields={
225
  "recipe_metadata": {
 
226
  "template": self.template,
227
  "system_prompt": self.system_prompt,
228
  "format": self.format,
 
239
  self.augmentor.set_task_input_fields(self.card.task.augmentable_inputs)
240
  self.processing.steps.append(self.augmentor)
241
 
242
+ if self.demos_pool_size is not None and self.demos_pool_size > 0:
243
  self.processing.steps.append(
244
  CreateDemosPool(
245
  from_split=self.demos_taken_from,
string_operators.py CHANGED
@@ -6,7 +6,7 @@ from typing import (
6
  Optional,
7
  )
8
 
9
- from .operators import FieldOperator, StreamInstanceOperator
10
 
11
 
12
  class Split(FieldOperator):
@@ -44,7 +44,7 @@ class Join(FieldOperator):
44
  return self.by.join(value)
45
 
46
 
47
- class FormatText(StreamInstanceOperator):
48
  to_field: str
49
  text: str
50
 
 
6
  Optional,
7
  )
8
 
9
+ from .operators import FieldOperator, InstanceOperator
10
 
11
 
12
  class Split(FieldOperator):
 
44
  return self.by.join(value)
45
 
46
 
47
+ class FormatText(InstanceOperator):
48
  to_field: str
49
  text: str
50
 
struct_data_operators.py CHANGED
@@ -28,7 +28,7 @@ from typing import (
28
  import pandas as pd
29
 
30
  from .dict_utils import dict_get
31
- from .operators import FieldOperator, StreamInstanceOperator
32
 
33
 
34
  class SerializeTable(ABC, FieldOperator):
@@ -237,7 +237,7 @@ def truncate_cell(cell_value, max_len):
237
  return None
238
 
239
 
240
- class TruncateTableCells(StreamInstanceOperator):
241
  """Limit the maximum length of cell values in a table to reduce the overall length.
242
 
243
  Args:
@@ -318,7 +318,7 @@ class TruncateTableRows(FieldOperator):
318
  return table_content
319
 
320
 
321
- class SerializeTableRowAsText(StreamInstanceOperator):
322
  """Serializes a table row as text.
323
 
324
  Args:
@@ -348,7 +348,7 @@ class SerializeTableRowAsText(StreamInstanceOperator):
348
  return instance
349
 
350
 
351
- class SerializeTableRowAsList(StreamInstanceOperator):
352
  """Serializes a table row as list.
353
 
354
  Args:
@@ -417,7 +417,7 @@ class SerializeKeyValPairs(FieldOperator):
417
  return serialized_str[:-2]
418
 
419
 
420
- class ListToKeyValPairs(StreamInstanceOperator):
421
  """Maps list of keys and values into key:value pairs.
422
 
423
  Sample input in expected format: {"keys": ["name", "age", "sex"], "values": ["Alex", 31, "M"]}
@@ -512,16 +512,16 @@ class ShuffleTableColumns(FieldOperator):
512
  """Shuffles the table columns randomly.
513
 
514
  Sample Input:
515
- {
516
- "header": ["name", "age"],
517
- "rows": [["Alex", 26], ["Raj", 34], ["Donald", 39]],
518
- }
519
 
520
  Sample Output:
521
- {
522
- "header": ["age", "name"],
523
- "rows": [[26, "Alex"], [34, "Raj"], [39, "Donald"]],
524
- }
525
  """
526
 
527
  def process_value(self, table: Any) -> Any:
 
28
  import pandas as pd
29
 
30
  from .dict_utils import dict_get
31
+ from .operators import FieldOperator, InstanceOperator
32
 
33
 
34
  class SerializeTable(ABC, FieldOperator):
 
237
  return None
238
 
239
 
240
+ class TruncateTableCells(InstanceOperator):
241
  """Limit the maximum length of cell values in a table to reduce the overall length.
242
 
243
  Args:
 
318
  return table_content
319
 
320
 
321
+ class SerializeTableRowAsText(InstanceOperator):
322
  """Serializes a table row as text.
323
 
324
  Args:
 
348
  return instance
349
 
350
 
351
+ class SerializeTableRowAsList(InstanceOperator):
352
  """Serializes a table row as list.
353
 
354
  Args:
 
417
  return serialized_str[:-2]
418
 
419
 
420
+ class ListToKeyValPairs(InstanceOperator):
421
  """Maps list of keys and values into key:value pairs.
422
 
423
  Sample input in expected format: {"keys": ["name", "age", "sex"], "values": ["Alex", 31, "M"]}
 
512
  """Shuffles the table columns randomly.
513
 
514
  Sample Input:
515
+ {
516
+ "header": ["name", "age"],
517
+ "rows": [["Alex", 26], ["Raj", 34], ["Donald", 39]],
518
+ }
519
 
520
  Sample Output:
521
+ {
522
+ "header": ["age", "name"],
523
+ "rows": [[26, "Alex"], [34, "Raj"], [39, "Donald"]],
524
+ }
525
  """
526
 
527
  def process_value(self, table: Any) -> Any:
system_prompts.py CHANGED
@@ -2,10 +2,10 @@ from abc import abstractmethod
2
  from typing import Any, Dict, Optional
3
 
4
  from .dataclass import NonPositionalField
5
- from .operator import StreamInstanceOperator
6
 
7
 
8
- class SystemPrompt(StreamInstanceOperator):
9
  """The role of SystemPrompt is to add task-independent opening-text to every instance."""
10
 
11
  skip_rendered_instance: bool = NonPositionalField(default=True)
 
2
  from typing import Any, Dict, Optional
3
 
4
  from .dataclass import NonPositionalField
5
+ from .operator import InstanceOperator
6
 
7
 
8
+ class SystemPrompt(InstanceOperator):
9
  """The role of SystemPrompt is to add task-independent opening-text to every instance."""
10
 
11
  skip_rendered_instance: bool = NonPositionalField(default=True)
task.py CHANGED
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Union
3
 
4
  from .artifact import fetch_artifact
5
  from .logging_utils import get_logger
6
- from .operator import StreamInstanceOperator
7
  from .type_utils import (
8
  get_args,
9
  get_origin,
@@ -13,8 +13,8 @@ from .type_utils import (
13
  )
14
 
15
 
16
- class Task(StreamInstanceOperator):
17
- """FormTask packs the different instance fields into dictionaries by their roles in the task.
18
 
19
  Attributes:
20
  inputs (Union[Dict[str, str], List[str]]):
@@ -81,7 +81,7 @@ class Task(StreamInstanceOperator):
81
  def check_metrics_type(self) -> None:
82
  prediction_type = parse_type_string(self.prediction_type)
83
  for metric_id in self.metrics:
84
- metric_prediction_type = FormTask.get_metric_prediction_type(metric_id)
85
 
86
  if (
87
  prediction_type == metric_prediction_type
@@ -107,11 +107,13 @@ class Task(StreamInstanceOperator):
107
 
108
  inputs = {key: instance[key] for key in self.inputs.keys()}
109
  outputs = {key: instance[key] for key in self.outputs.keys()}
 
110
 
111
  return {
112
  "inputs": inputs,
113
  "outputs": outputs,
114
  "metrics": self.metrics,
 
115
  }
116
 
117
 
 
3
 
4
  from .artifact import fetch_artifact
5
  from .logging_utils import get_logger
6
+ from .operator import InstanceOperator
7
  from .type_utils import (
8
  get_args,
9
  get_origin,
 
13
  )
14
 
15
 
16
+ class Task(InstanceOperator):
17
+ """Task packs the different instance fields into dictionaries by their roles in the task.
18
 
19
  Attributes:
20
  inputs (Union[Dict[str, str], List[str]]):
 
81
  def check_metrics_type(self) -> None:
82
  prediction_type = parse_type_string(self.prediction_type)
83
  for metric_id in self.metrics:
84
+ metric_prediction_type = Task.get_metric_prediction_type(metric_id)
85
 
86
  if (
87
  prediction_type == metric_prediction_type
 
107
 
108
  inputs = {key: instance[key] for key in self.inputs.keys()}
109
  outputs = {key: instance[key] for key in self.outputs.keys()}
110
+ data_classification_policy = instance.get("data_classification_policy", [])
111
 
112
  return {
113
  "inputs": inputs,
114
  "outputs": outputs,
115
  "metrics": self.metrics,
116
+ "data_classification_policy": data_classification_policy,
117
  }
118
 
119
 
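A small illustration (field names and values invented) of what the change to Task.process means for the produced dict: the policy list is read from the incoming instance, defaults to an empty list, and is carried through next to inputs, outputs and metrics.

# Illustration only; the instance content is made up.
instance = {
    "text": "some input",
    "label": "positive",
    "data_classification_policy": ["public"],
}
# After Task.process(instance), the result would look roughly like:
# {
#     "inputs": {...},
#     "outputs": {...},
#     "metrics": [...],
#     "data_classification_policy": ["public"],
# }
# If the incoming instance has no such key, the field defaults to [].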
templates.py CHANGED
@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
6
  from .artifact import Artifact
7
  from .collections import ListCollection
8
  from .dataclass import NonPositionalField
9
- from .operator import StreamInstanceOperator
10
  from .random_utils import new_random_generator
11
  from .type_utils import isoftype
12
 
@@ -20,7 +20,7 @@ class TemplateFormatKeyError(KeyError):
20
  )
21
 
22
 
23
- class Template(StreamInstanceOperator):
24
  """The role of template is to take the fields of every instance and verbalize it.
25
 
26
  Meaning the template is taking the instance and generating source, target and references.
 
6
  from .artifact import Artifact
7
  from .collections import ListCollection
8
  from .dataclass import NonPositionalField
9
+ from .operator import InstanceOperator
10
  from .random_utils import new_random_generator
11
  from .type_utils import isoftype
12
 
 
20
  )
21
 
22
 
23
+ class Template(InstanceOperator):
24
  """The role of template is to take the fields of every instance and verbalize it.
25
 
26
  Meaning the template is taking the instance and generating source, target and references.
text_utils.py CHANGED
@@ -89,6 +89,9 @@ def construct_dict_str(d, indent=0, indent_delta=4, max_chars=None):
89
  res += construct_dict_str(value, indent + indent_delta, max_chars=max_chars)
90
  else:
91
  str_value = str(value)
 
 
 
92
  line_width = max_chars - indent
93
  lines = str_value.split("\n")
94
  res += f"{indent_str}{key} ({type(value).__name__}):\n"
 
89
  res += construct_dict_str(value, indent + indent_delta, max_chars=max_chars)
90
  else:
91
  str_value = str(value)
92
+ str_value = re.sub(r"\w+=None, ", "", str_value)
93
+ str_value = re.sub(r"\w+={}, ", "", str_value)
94
+ str_value = re.sub(r"\w+=\[\], ", "", str_value)
95
  line_width = max_chars - indent
96
  lines = str_value.split("\n")
97
  res += f"{indent_str}{key} ({type(value).__name__}):\n"
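The three added substitutions prune fields that are still at their empty defaults (None, {}, []) from stringified values, so pretty-printed artifacts stay compact. A standalone illustration of the same regexes, assuming 're' is already imported at the top of text_utils.py (the hunk does not show the import) and using a made-up repr string:

    import re

    # Hypothetical repr string for illustration only.
    raw = "SomeTemplate(postprocessors=[], instruction=None, target_prefix={}, input_format='{text}', )"
    cleaned = raw
    for pattern in (r"\w+=None, ", r"\w+={}, ", r"\w+=\[\], "):
        cleaned = re.sub(pattern, "", cleaned)
    # cleaned == "SomeTemplate(input_format='{text}', )" - the defaulted fields are gone.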
type_utils.py CHANGED
@@ -13,23 +13,23 @@ def convert_union_type(type_string: str) -> str:
13
 
14
  Args:
15
  type_string (str): A string representation of a Python type hint. It can be any
16
- valid Python type, which does not contain strings (e.g. 'Literal').
17
- Examples include 'List[int|float]', 'str|float|bool' etc.
18
-
19
- Formally, the function depends on the input string adhering to the following rules.
20
- Assuming that the input is a valid type hint the function does not check that 'word' is
21
- 'str', 'bool', 'List' etc. It just depends on the following general structure (spaces ignored):
22
- type -> word OR type( | type)* OR word[type( , type)*]
23
- word is a sequence of (0 or more) chars, each being any char but: [ ] , |
24
- This implies that if any of these 4 chars shows not as a meta char of the input
25
- type_string, but inside some constant string (of Literal, for example), the scheme
26
- will not work.
27
-
28
- Cases like Literal, that might contain occurrences of the four chars above not as meta chars
29
- in the type string, must be handled as special cases by this function, as shown for Literal,
30
- as an example. Because 'format_type_string' serves as preprocessing for 'parse_type_string',
31
- which has a list of allowed types, of which Literal is not a member, Literal and such are not
32
- relevant at all now; and the case is brought here just for an example for future use.
33
 
34
 
35
  Returns:
 
13
 
14
  Args:
15
  type_string (str): A string representation of a Python type hint. It can be any
16
+ valid Python type that does not contain string constants (as in 'Literal', for example).
17
+ Examples include 'List[int|float]', 'str|float|bool', etc.
18
+
19
+ Formally, the function relies on the input string adhering to the following rules.
20
+ Assuming the input is a valid type hint, the function does not check that 'word' is
21
+ 'str', 'bool', 'List', etc. It depends only on the following general structure (spaces ignored):
22
+ type -> word OR type( | type)* OR word[type( , type)*]
23
+ where word is a sequence of (0 or more) chars, each being any char except: [ ] , |
24
+ This implies that if any of these 4 chars appears not as a meta char of the
25
+ type_string, but inside some string constant (of Literal, for example), the scheme
26
+ will not work.
27
+
28
+ Cases like Literal, which may contain occurrences of the four chars above that are not meta
29
+ chars of the type string, must be handled as special cases by this function, as shown for
30
+ Literal as an example. Because 'format_type_string' serves as preprocessing for 'parse_type_string',
31
+ which has a list of allowed types of which Literal is not a member, Literal and the like are not
32
+ relevant for now; the case is included here only as an example for future use.
33
 
34
 
35
  Returns:
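To make the grammar described in the docstring concrete, here is a simplified, self-contained sketch of the conversion it specifies: split on top-level '|' and rewrite unions into Union[...] form. It is not the library's implementation (in particular it does not special-case Literal), and the function name is made up:

    def convert_union_type_sketch(type_string: str) -> str:
        # Simplified illustration of: type -> word OR type(|type)* OR word[type(, type)*]
        def split_top_level(s, sep):
            parts, depth, start = [], 0, 0
            for i, ch in enumerate(s):
                if ch == "[":
                    depth += 1
                elif ch == "]":
                    depth -= 1
                elif ch == sep and depth == 0:
                    parts.append(s[start:i])
                    start = i + 1
            parts.append(s[start:])
            return parts

        def parse(s):
            s = s.strip()
            members = split_top_level(s, "|")
            if len(members) > 1:
                return "Union[" + ", ".join(parse(m) for m in members) + "]"
            if "[" in s and s.endswith("]"):
                word, inner = s.split("[", 1)
                args = split_top_level(inner[:-1], ",")
                return word + "[" + ", ".join(parse(a) for a in args) + "]"
            return s

        return parse(type_string)

    # convert_union_type_sketch("List[int|float]")  -> "List[Union[int, float]]"
    # convert_union_type_sketch("str|float|bool")   -> "Union[str, float, bool]"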
validate.py CHANGED
@@ -4,14 +4,14 @@ from typing import Any, Dict, Optional
4
 
5
  from datasets import Features, Sequence, Value
6
 
7
- from .operator import StreamInstanceOperator
8
 
9
 
10
  class Validator(ABC):
11
  pass
12
 
13
 
14
- class ValidateSchema(Validator, StreamInstanceOperator):
15
  schema: Features = None
16
 
17
  def verify(self):
 
4
 
5
  from datasets import Features, Sequence, Value
6
 
7
+ from .operator import InstanceOperator
8
 
9
 
10
  class Validator(ABC):
11
  pass
12
 
13
 
14
+ class ValidateSchema(Validator, InstanceOperator):
15
  schema: Features = None
16
 
17
  def verify(self):
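For reference, a ValidateSchema instance is configured with a datasets.Features schema, presumably used to check instances. A sketch of what such a schema might look like; the field names here are hypothetical and not mandated by the diff:

    from datasets import Features, Sequence, Value

    # Hypothetical schema - real schemas depend on the pipeline being validated.
    example_schema = Features(
        {
            "source": Value("string"),
            "target": Value("string"),
            "references": Sequence(Value("string")),
            "metrics": Sequence(Value("string")),
        }
    )

    # A ValidateSchema operator could then be configured with it, e.g.:
    # ValidateSchema(schema=example_schema)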
version.py CHANGED
@@ -1 +1 @@
1
- version = "1.9.0"
 
1
+ version = "1.10.0"