Elron committed on
Commit dcd3b86 · verified · 1 Parent(s): a66b8be

Upload operators.py with huggingface_hub

Files changed (1)
  1. operators.py +625 -235
operators.py CHANGED
@@ -1,11 +1,47 @@
1
  import collections
2
  import importlib
 
 
3
  import uuid
4
  from abc import abstractmethod
5
  from collections import Counter
6
  from copy import deepcopy
7
  from dataclasses import field
8
  from itertools import zip_longest
 
9
  from typing import (
10
  Any,
11
  Callable,
@@ -32,17 +68,20 @@ from .operator import (
32
  StreamInstanceOperator,
33
  StreamSource,
34
  )
35
- from .random_utils import get_random, nested_seed
36
  from .stream import Stream
37
  from .text_utils import nested_tuple_to_string
 
38
  from .utils import flatten_dict
39
 
40
 
41
  class FromIterables(StreamInitializerOperator):
42
- """Creates a MultiStream from iterables.
 
 
 
 
43
 
44
- Args:
45
- iterables (Dict[str, Iterable]): A dictionary where each key-value pair represents a stream name and its corresponding iterable.
46
  """
47
 
48
  def process(self, iterables: Dict[str, Iterable]) -> MultiStream:
@@ -50,6 +89,19 @@ class FromIterables(StreamInitializerOperator):
50
 
51
 
52
 class IterableSource(StreamSource):
53
  iterables: Dict[str, Iterable]
54
 
55
  def __call__(self) -> MultiStream:
@@ -57,7 +109,7 @@ class IterableSource(StreamSource):
57
 
58
 
59
  class MapInstanceValues(StreamInstanceOperator):
60
- """A class used to map instance values into a stream.
61
 
62
  This class is a type of StreamInstanceOperator,
63
  it maps values of instances in a stream using predefined mappers.
@@ -87,6 +139,11 @@ class MapInstanceValues(StreamInstanceOperator):
87
  To ensure that all values of field 'a' are mapped in every instance, use strict=True.
88
  Input instance {"a":"3", "b": 2} will raise an exception per the above call,
89
  because "3" is not a key in the mapper of "a".
 
 
 
 
 
90
  """
91
 
92
  mappers: Dict[str, Dict[str, str]]
@@ -115,34 +172,31 @@ class MapInstanceValues(StreamInstanceOperator):
115
  raise ValueError(
116
  f"'process_every_field' == True is allowed only when all fields which have mappers, i.e., {list(self.mappers.keys())} are lists. Instace = {instance}"
117
  )
118
- if isinstance(value, list):
119
- if self.process_every_value:
120
- for i, val in enumerate(value):
121
- val = str(val) # make sure the value is a string
122
- if self.strict and (val not in mapper):
123
- raise KeyError(
124
- f"value '{val}' in instance '{instance}' is not found in mapper '{mapper}', associated with field '{key}'."
125
- )
126
- if val in mapper:
127
- # replace just that member of value (value is a list)
128
- value[i] = mapper[val]
129
- dict_set(instance, key, value, use_dpath=self.use_query)
130
- else: # field is a list, and process_every_value == False
131
- if self.strict: # whole lists can not be mapped by a string-to-something mapper
132
- raise KeyError(
133
- f"A whole list ({value}) in the instance can not be mapped by a field mapper."
134
- )
135
- else: # value is not a list, implying process_every_value == False
136
- value = str(value) # make sure the value is a string
137
- if self.strict and (value not in mapper):
138
- raise KeyError(
139
- f"value '{value}' in instance '{instance}' is not found in mapper '{mapper}', associated with field '{key}'."
140
- )
141
- if value in mapper:
142
- dict_set(instance, key, mapper[value], use_dpath=self.use_query)
143
 
144
  return instance
145
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  class FlattenInstances(StreamInstanceOperator):
148
  """Flattens each instance in a stream, making nested dictionary entries into top-level entries.
@@ -182,6 +236,7 @@ class AddFields(StreamInstanceOperator):
182
  # Add a 'classes' field on a given list, prevent modification of original list
183
  # from changing the instance.
184
  AddFields(fields={"classes": alist}), use_deepcopy=True)
 
185
  """
186
 
187
  fields: Dict[str, object]
@@ -204,7 +259,7 @@ class AddFields(StreamInstanceOperator):
204
 
205
 
206
  class RemoveFields(StreamInstanceOperator):
207
- """Remove specified fields to each instance in a stream.
208
 
209
  Args:
210
  fields (List[str]): The fields to remove from each instance.
@@ -221,19 +276,32 @@ class RemoveFields(StreamInstanceOperator):
221
 
222
 
223
  class FieldOperator(StreamInstanceOperator):
224
- """A general stream that processes the values of a field (or multiple ones.
225
 
226
  Args:
227
- field (Optional[str]): The field to process, if only a single one is passed Defaults to None
228
- to_field (Optional[str]): Field name to save, if only one field is to be saved, if None is passed the operator would happen in-place and replace "field" Defaults to None
229
- field_to_field (Optional[Union[List[Tuple[str, str]], Dict[str, str]]]): Mapping from fields to process to their names after this process, duplicates are allowed. Defaults to None
230
  process_every_value (bool): Processes the values in a list instead of the list as a value, similar to *var. Defaults to False
231
  use_query (bool): Whether to use dpath style queries. Defaults to False.
 
 
 
232
  """
233
 
234
  field: Optional[str] = None
235
  to_field: Optional[str] = None
236
- field_to_field: Optional[Union[List[Tuple[str, str]], Dict[str, str]]] = None
237
  process_every_value: bool = False
238
  use_query: bool = False
239
  get_default: Any = None
@@ -250,25 +318,67 @@ class FieldOperator(StreamInstanceOperator):
250
  ), f"Can not apply operator to create both on {self.to_field} and on the mapping from fields to fields {self.field_to_field}"
251
  assert (
252
  self.field is None or self.field_to_field is None
253
- ), f"Can not apply operator both on {self.field} and on the mapping from fields to fields {self.field_to_field}"
 
254
  assert (
255
- self._field_to_field
256
- ), f"the from and to fields must be defined got: {self._field_to_field}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
 
258
  @abstractmethod
259
  def process_value(self, value: Any) -> Any:
260
  pass
261
 
262
  def prepare(self):
263
- if self.to_field is None:
264
- self.to_field = self.field
265
  if self.field_to_field is None:
266
- self._field_to_field = [(self.field, self.to_field)]
 
 
267
  else:
268
- try:
269
- self._field_to_field = list(self.field_to_field.items())
270
- except AttributeError:
271
- self._field_to_field = self.field_to_field
 
272
 
273
  def process(
274
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
@@ -295,7 +405,7 @@ class FieldOperator(StreamInstanceOperator):
295
  raise ValueError(
296
  f"Failed to process '{from_field}' from {instance} due to : {e}"
297
  ) from e
298
- if self.use_query and is_subpath(from_field, to_field):
299
  dict_delete(instance, from_field)
300
  dict_set(
301
  instance,
@@ -308,7 +418,25 @@ class FieldOperator(StreamInstanceOperator):
308
 
309
 
310
  class RenameFields(FieldOperator):
311
- """Renames fields."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
 
313
  def process_value(self, value: Any) -> Any:
314
  return value
@@ -317,20 +445,31 @@ class RenameFields(FieldOperator):
317
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
318
  ) -> Dict[str, Any]:
319
  res = super().process(instance=instance, stream_name=stream_name)
320
- vals = [x[1] for x in self._field_to_field]
321
- for key, _ in self._field_to_field:
322
- if self.use_query and "/" in key:
323
- continue
324
- if key not in vals:
325
- res.pop(key)
326
  return res
327
 
328
 
329
  class AddConstant(FieldOperator):
330
- """Adds a value, similar to add + field.
331
 
332
  Args:
333
- add: sum to add.
334
  """
335
 
336
  add: Any
@@ -396,19 +535,15 @@ class Augmentor(StreamInstanceOperator):
396
  default="",
397
  not_exist_ok=False,
398
  )
399
- except TypeError as e:
400
  raise TypeError(f"Failed to get {field_name} from {instance}") from e
401
 
402
- # We are setting a nested seed based on the value processed, to ensure that
403
- # the augmentation randomizations do not effect other randomization choices and
404
- # to make the augmentation randomization choices different for each text.
405
- with nested_seed(str(hash(old_value))):
406
- try:
407
- new_value = self.process_value(old_value)
408
- except Exception as e:
409
- raise RuntimeError(
410
- f"Error augmenting value '{old_value}' from '{field_name}' in instance: {instance}"
411
- ) from e
412
  dict_set(instance, field_name, new_value, use_dpath=True, not_exist_ok=True)
413
  return instance
414
 
@@ -433,90 +568,146 @@ class AugmentWhitespace(Augmentor):
433
  words = re.split(r"(\s+)", value)
434
  new_value = ""
435
 
 
436
  for word in words:
437
  if word.isspace():
438
- new_value += get_random().choice(
439
  ["\n", "\t", " "]
440
- ) * get_random().randint(1, 3)
441
  else:
442
  new_value += word
443
  return new_value
444
 
445
 
446
- class AugmentSuffix(Augmentor):
447
- r"""Augments the input by appending to it a randomly selected (typically, whitespace) pattern.
448
 
449
  Args:
450
- suffixes : the potential (typically, whitespace) patterns to select from.
451
  The dictionary version allows to specify relative weights of the different patterns.
452
- remove_existing_trailing_whitespaces : allows to first clean existing trailing whitespaces.
453
- The selected pattern is then appended to the potentially trimmed at its end input.
454
-
 
 
 
455
 
456
  Examples:
457
- to append a '\n' or a '\t' to the end of the input, employ
458
- AugmentSuffix(augment_model_input=True, suffixes=['\n','\t'])
459
- If '\n' is preferred over '\t', at 2:1 ratio, employ
460
- AugmentSuffix(augment_model_input=True, suffixes={'\n':2,'\t':1})
461
- which will append '\n' twice as often as '\t'.
 
462
 
463
  """
464
 
465
- suffixes: Optional[Union[List[str], Dict[str, int]]] = [" ", "\n", "\t"]
466
- remove_existing_trailing_whitespaces: Optional[bool] = False
467
 
468
  def verify(self):
469
  assert (
470
- isinstance(self.suffixes, list) or isinstance(self.suffixes, dict)
471
- ), f"Argument 'suffixes' should be either a list or a dictionary, whereas it is of type {type(self.suffixes)}"
472
-
473
- if isinstance(self.suffixes, dict):
474
- for k, v in self.suffixes.items():
475
- assert isinstance(
476
- k, str
477
- ), f"suffixes should map strings, whereas key {k!s} is of type {type(k)}"
478
- assert isinstance(
479
- v, int
480
- ), f"suffixes should map to ints, whereas value {v!s} is of type {type(v)}"
481
- else:
482
- for k in self.suffixes:
483
- assert isinstance(
484
- k, str
485
- ), f"suffixes should be a list of strings, whereas member {k!s} is of type {type(k)}"
486
 
487
- self.pats = (
488
- self.suffixes
489
- if isinstance(self.suffixes, list)
490
- else [k for k, v in self.suffixes.items()]
 
 
 
491
  )
492
  total_weight = (
493
- len(self.pats)
494
- if isinstance(self.suffixes, list)
495
- else sum([v for k, v in self.suffixes.items()])
496
  )
497
- self.weights = (
498
- [1.0 / total_weight] * len(self.pats)
499
- if isinstance(self.suffixes, list)
500
- else [float(self.suffixes[p]) / total_weight for p in self.pats]
501
  )
502
- super().verify()
503
 
504
  def process_value(self, value: Any) -> Any:
505
  assert value is not None, "input value should not be None"
506
  new_value = str(value)
507
- if self.remove_existing_trailing_whitespaces:
508
- new_value = new_value.rstrip()
509
- new_value += get_random().choices(self.pats, self.weights, k=1)[0]
510
-
511
- return new_value
512
 
513
 
514
  class ShuffleFieldValues(FieldOperator):
515
- """Shuffles an iterable value."""
516
 
517
  def process_value(self, value: Any) -> Any:
518
  res = list(value)
519
- get_random().shuffle(res)
 
520
  return res
521
 
522
 
@@ -621,9 +812,18 @@ class ListFieldValues(StreamInstanceOperator):
621
 
622
 
623
  class ZipFieldValues(StreamInstanceOperator):
624
- """Zips values of multiple fields similar to list(zip(*fields))."""
 
 
 
625
 
626
- fields: str
627
  to_field: str
628
  longest: bool = False
629
  use_query: bool = False
@@ -643,7 +843,7 @@ class ZipFieldValues(StreamInstanceOperator):
643
 
644
 
645
  class IndexOf(StreamInstanceOperator):
646
- """Finds the location of one value in another (iterable) value similar to to_field=search_in.index(index_of)."""
647
 
648
  search_in: str
649
  index_of: str
@@ -660,7 +860,7 @@ class IndexOf(StreamInstanceOperator):
660
 
661
 
662
  class TakeByField(StreamInstanceOperator):
663
- """Takes value from one field based on another field similar to field[index]."""
664
 
665
  field: str
666
  index: str
@@ -681,11 +881,24 @@ class TakeByField(StreamInstanceOperator):
681
 
682
 
683
  class CopyFields(FieldOperator):
684
- """Copies specified fields from one field to another.
685
 
686
- Args:
687
  field_to_field (Union[List[List], Dict[str, str]]): A list of lists, where each sublist contains the source field and the destination field, or a dictionary mapping source fields to destination fields.
688
- use_dpath (bool): Whether to use dpath for accessing fields. Defaults to False.
689
  """
690
 
691
  def process_value(self, value: Any) -> Any:
@@ -693,6 +906,8 @@ class CopyFields(FieldOperator):
693
 
694
 
695
  class AddID(StreamInstanceOperator):
 
 
696
  id_field_name: str = "id"
697
 
698
  def process(
@@ -706,22 +921,31 @@ class CastFields(StreamInstanceOperator):
706
  """Casts specified fields to specified types.
707
 
708
  Args:
709
- types (Dict[str, str]): A dictionary mapping fields to their new types.
710
- nested (bool): Whether to cast nested fields. Defaults to False.
711
- fields (Dict[str, str]): A dictionary mapping fields to their new types.
712
- defaults (Dict[str, object]): A dictionary mapping types to their default values for cases of casting failure.
713
  """
714
 
715
- types = {
716
- "int": int,
717
- "float": float,
718
- "str": str,
719
- "bool": bool,
720
- }
721
  fields: Dict[str, str] = field(default_factory=dict)
722
  failure_defaults: Dict[str, object] = field(default_factory=dict)
723
  use_nested_query: bool = False
724
- cast_multiple: bool = False
 
 
 
725
 
726
  def _cast_single(self, value, type, field):
727
  try:
@@ -734,14 +958,17 @@ class CastFields(StreamInstanceOperator):
734
  return self.failure_defaults[field]
735
 
736
  def _cast_multiple(self, values, type, field):
737
- values = [self._cast_single(value, type, field) for value in values]
738
 
739
  def process(
740
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
741
  ) -> Dict[str, Any]:
742
  for field_name, type in self.fields.items():
743
  value = dict_get(instance, field_name, use_dpath=self.use_nested_query)
744
- if self.cast_multiple:
 
 
 
745
  casted_value = self._cast_multiple(value, type, field_name)
746
  else:
747
  casted_value = self._cast_single(value, type, field_name)
@@ -751,29 +978,46 @@ class CastFields(StreamInstanceOperator):
751
  return instance
752
 
753
 
754
- def recursive_divide(instance, divisor, strict=False):
755
- if isinstance(instance, dict):
756
- for key, value in instance.items():
757
- instance[key] = recursive_divide(value, divisor, strict=strict)
758
- elif isinstance(instance, list):
759
- for i, value in enumerate(instance):
760
- instance[i] = recursive_divide(value, divisor, strict=strict)
761
- elif isinstance(instance, float):
762
- instance /= divisor
763
- elif strict:
764
- raise ValueError(f"Cannot divide instance of type {type(instance)}")
765
- return instance
766
767
 
768
- class DivideAllFieldsBy(StreamInstanceOperator):
769
  divisor: float = 1.0
770
  strict: bool = False
771
- recursive: bool = True
772
 
773
  def process(
774
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
775
  ) -> Dict[str, Any]:
776
- return recursive_divide(instance, self.divisor, strict=self.strict)
777
 
778
 
779
  class ArtifactFetcherMixin:
@@ -797,13 +1041,21 @@ class ApplyOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin):
797
  """Applies value operators to each instance in a stream based on specified fields.
798
 
799
  Args:
800
- value_field (str): The field containing the value to be operated on.
801
- operators_field (str): The field containing the operators to be applied.
 
 
 
802
  default_operators (List[str]): A list of default operators to be used if no operators are found in the instance.
803
- """
804
 
805
- inputs_fields: str
806
807
  operators_field: str
808
  default_operators: List[str] = None
809
  fields_to_treat_as_list: List[str] = NonPositionalField(default_factory=list)
@@ -815,7 +1067,7 @@ class ApplyOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin):
815
  if operator_names is None:
816
  assert (
817
  self.default_operators is not None
818
- ), f"No operators found in {self.field} field and no default operators provided"
819
  operator_names = self.default_operators
820
 
821
  if isinstance(operator_names, str):
@@ -828,35 +1080,155 @@ class ApplyOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin):
828
  if field_name in self.fields_to_treat_as_list:
829
  instance[field_name] = [operator.process(v) for v in value]
830
  else:
831
- instance[field_name] = operator.process(instance[field_name])
832
 
833
  return instance
834
 
835
 
836
- class FilterByValues(SingleStreamOperator):
837
- """Filters a stream, yielding only instances that match specified values in the provided fields.
 
 
838
 
839
  Args:
840
- values (Dict[str, Any]): For each field, the values that instances should match to be included in the output.
841
  """
842
 
843
- required_values: Dict[str, Any]
844
 
845
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
 
846
  for instance in stream:
847
- filter = False
848
- for key, value in self.required_values.items():
849
- if key not in instance:
850
  raise ValueError(
851
- f"Required filter field ('{key}') in FilterByValues is not found in {instance}"
852
  )
853
- if instance[key] != value:
854
- filter = True
855
- if not filter:
856
 yield instance
857
858
 
859
- class ExtractFieldValues(MultiStreamOperator):
860
  field: str
861
  stream_name: str
862
  overall_top_frequency_percent: Optional[int] = 100
@@ -877,21 +1249,21 @@ class ExtractFieldValues(MultiStreamOperator):
877
 
878
  Examples:
879
 
880
- ExtractFieldValues(stream_name="train", field="label", to_field="classes") - extracts all the unique values of
881
  field 'label', sorts them by decreasing frequency, and stores the resulting list in field 'classes' of each and
882
  every instance in all streams.
883
 
884
- ExtractFieldValues(stream_name="train", field="labels", to_field="classes", process_every_value=True) -
885
  in case that field 'labels' contains a list of values (and not a single value) - track the occurrences of all the possible
886
  value members in these lists, and report the most frequent values.
887
  if process_every_value=False, track the most frequent whole lists, and report those (as a list of lists) in field
888
  'to_field' of each instance of all streams.
889
 
890
- ExtractFieldValues(stream_name="train", field="label", to_field="classes",overall_top_frequency_percent=80) -
891
  extracts the most frequent possible values of field 'label' that together cover at least 80% of the instances of stream_name,
892
  and stores them in field 'classes' of each instance of all streams.
893
 
894
- ExtractFieldValues(stream_name="train", field="label", to_field="classes",min_frequency_percent=5) -
895
  extracts all possible values of field 'label' that cover, each, at least 5% of the instances.
896
  Stores these values, sorted by decreasing order of frequency, in field 'classes' of each instance in all streams.
897
  """
@@ -952,41 +1324,18 @@ class ExtractFieldValues(MultiStreamOperator):
952
  [*ele[0]] if isinstance(ele[0], tuple) else ele[0]
953
  for ele in values_and_counts
954
  ]
955
- for name in multi_stream:
956
- for instance in multi_stream[name]:
957
- instance[self.to_field] = values_to_keep
958
- return multi_stream
959
 
 
 
960
 
961
- class FilterByListsOfValues(SingleStreamOperator):
962
- """Filters a stream, yielding only instances that whose field values are included in the specified value lists.
963
-
964
- Args:
965
- required_values (Dict[str, List]): For each field, the list of values that instances should match to be included in the output.
966
- """
967
-
968
- required_values: Dict[str, List]
969
 
 
970
  def verify(self):
971
  super().verify()
972
- for key, value in self.required_values.items():
973
- if not isinstance(value, list):
974
- raise ValueError(
975
- f"The filter for key ('{key}') in FilterByListsOfValues is not a list but '{value}'"
976
- )
977
 
978
- def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
979
- for instance in stream:
980
- filter = False
981
- for key, value in self.required_values.items():
982
- if key not in instance:
983
- raise ValueError(
984
- f"Required filter field ('{key}') in FilterByListsOfValues is not found in {instance}"
985
- )
986
- if instance[key] not in value:
987
- filter = True
988
- if not filter:
989
- yield instance
990
 
991
 
992
  class Intersect(FieldOperator):
@@ -1011,6 +1360,7 @@ class Intersect(FieldOperator):
1011
  )
1012
 
1013
  def process_value(self, value: Any) -> Any:
 
1014
  if not isinstance(value, list):
1015
  raise ValueError(f"The value in field is not a list but '{value}'")
1016
  return [e for e in value if e in self.allowed_values]
@@ -1020,7 +1370,7 @@ class RemoveValues(FieldOperator):
1020
  """Removes elements in a field, which must be a list, using a given list of unallowed.
1021
 
1022
  Args:
1023
- unallowed_values (list) - removed_values.
1024
  """
1025
 
1026
  unallowed_values: List[Any]
@@ -1089,8 +1439,8 @@ class SplitByValue(MultiStreamOperator):
1089
  stream_unique_values = uniques[stream_name]
1090
  for unique_values in stream_unique_values:
1091
  filtering_values = dict(zip(self.fields, unique_values))
1092
- filtered_streams = FilterByValues(
1093
- required_values=filtering_values
1094
  )._process_single_stream(stream)
1095
  filtered_stream_name = (
1096
  stream_name + "_" + nested_tuple_to_string(unique_values)
@@ -1112,7 +1462,7 @@ class ApplyStreamOperatorsField(SingleStreamOperator, ArtifactFetcherMixin):
1112
  reversed: bool = False
1113
 
1114
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
1115
- first_instance = stream.peak()
1116
 
1117
  operators = first_instance.get(self.field, [])
1118
  if isinstance(operators, str):
@@ -1146,7 +1496,7 @@ class ApplyMetric(SingleStreamOperator, ArtifactFetcherMixin):
1146
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
1147
  from .metrics import Metric, MetricPipeline, MetricWithConfidenceInterval
1148
 
1149
- first_instance = stream.peak()
1150
 
1151
  metric_names = first_instance.get(self.metric_field, [])
1152
  if not metric_names:
@@ -1182,27 +1532,6 @@ class ApplyMetric(SingleStreamOperator, ArtifactFetcherMixin):
1182
  yield from stream
1183
 
1184
 
1185
- class AddFieldNamePrefix(StreamInstanceOperator):
1186
- """Adds a prefix to each field name in each instance of a stream.
1187
-
1188
- Args:
1189
- prefix_dict (Dict[str, str]): A dictionary mapping stream names to prefixes.
1190
- """
1191
-
1192
- prefix_dict: Dict[str, str]
1193
-
1194
- def prepare(self):
1195
- return super().prepare()
1196
-
1197
- def process(
1198
- self, instance: Dict[str, Any], stream_name: Optional[str] = None
1199
- ) -> Dict[str, Any]:
1200
- return {
1201
- self.prefix_dict[stream_name] + key: value
1202
- for key, value in instance.items()
1203
- }
1204
-
1205
-
1206
  class MergeStreams(MultiStreamOperator):
1207
  """Merges multiple streams into a single stream.
1208
 
@@ -1238,20 +1567,39 @@ class MergeStreams(MultiStreamOperator):
1238
  class Shuffle(PagedStreamOperator):
1239
  """Shuffles the order of instances in each page of a stream.
1240
 
1241
- Args:
1242
  page_size (int): The size of each page in the stream. Defaults to 1000.
1243
  """
1244
 
 
 
 
 
 
 
1245
  def process(self, page: List[Dict], stream_name: Optional[str] = None) -> Generator:
1246
- get_random().shuffle(page)
1247
  yield from page
1248
 
1249
 
1250
  class EncodeLabels(StreamInstanceOperator):
1251
- """Encode labels of specified fields together a into integers.
 
 
 
1252
 
1253
  Args:
1254
 fields (List[str]): The fields to encode together.
1255
  """
1256
 
1257
  fields: List[str]
@@ -1279,7 +1627,23 @@ class EncodeLabels(StreamInstanceOperator):
1279
 
1280
 
1281
 class StreamRefiner(SingleStreamOperator):
1282
  max_instances: int = None
 
1283
 
1284
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
1285
  if self.max_instances is not None:
@@ -1291,13 +1655,23 @@ class StreamRefiner(SingleStreamOperator):
1291
  class DeterministicBalancer(StreamRefiner):
1292
  """A class used to balance streams deterministically.
1293
 
 
 
 
 
 
1294
  Attributes:
1295
- fields (List[str]): A list of field names to be used in determining the signature of an instance.
1296
- streams (List[str]): A list of stream names to be processed by the balancer.
1297
 
1298
  Usage:
1299
- balancer = DeterministicBalancer(fields=["field1", "field2"], streams=["stream1", "stream2"])
1300
 balanced_stream = balancer.process(stream)
1301
  """
1302
 
1303
  fields: List[str]
@@ -1334,7 +1708,23 @@ class DeterministicBalancer(StreamRefiner):
1334
 
1335
 
1336
 class LengthBalancer(DeterministicBalancer):
1337
  segments_boundaries: List[int]
 
1338
 
1339
  def signature(self, instance):
1340
  total_len = 0
 
1
+ """This section describes unitxt operators.
2
+
3
+ Operators: Building Blocks of Unitxt Processing Pipelines
4
+ ==============================================================
5
+
6
+ Within the Unitxt framework, operators serve as the foundational elements used to assemble processing pipelines.
7
+ Each operator is designed to perform specific manipulations on dictionary structures within a stream.
8
+ These operators are callable entities that receive a MultiStream as input.
9
+ The output is a MultiStream, augmented with the operator's manipulations, which are then systematically applied to each instance in the stream when pulled.
10
+
11
+ Creating Custom Operators
12
+ -------------------------------
13
+ To enhance the functionality of Unitxt, users are encouraged to develop custom operators.
14
+ This can be achieved by inheriting from any of the existing operators listed below or from one of the fundamental :class:`base operators<unitxt.operator>`.
15
+ The primary task in any operator development is to implement the `process` function, which defines the unique manipulations the operator will perform.
16
+
17
+ General or Specialized Operators
18
+ --------------------------------
19
+ Some operators are specialized in specific tasks, such as:
20
+
21
+ - :class:`loaders<unitxt.loaders>` for loading data.
22
+ - :class:`splitters<unitxt.splitters>` for fixing data splits.
23
+
24
+ Other specialized operators are used by unitxt internally:
25
+
26
+ - :class:`templates<unitxt.templates>` for verbalizing data examples.
27
+ - :class:`formats<unitxt.formats>` for preparing data for models.
28
+
29
+ The rest of this section is dedicated to general operators.
30
+
31
+ General Operators List:
32
+ ------------------------
33
+ """
34
  import collections
35
  import importlib
36
+ import operator
37
+ import os
38
  import uuid
39
  from abc import abstractmethod
40
  from collections import Counter
41
  from copy import deepcopy
42
  from dataclasses import field
43
  from itertools import zip_longest
44
+ from random import Random
45
  from typing import (
46
  Any,
47
  Callable,
 
68
  StreamInstanceOperator,
69
  StreamSource,
70
  )
71
+ from .random_utils import new_random_generator
72
  from .stream import Stream
73
  from .text_utils import nested_tuple_to_string
74
+ from .type_utils import isoftype
75
  from .utils import flatten_dict
76
 
77
 
78
  class FromIterables(StreamInitializerOperator):
79
+ """Creates a MultiStream from a dict of named iterables.
80
+
81
+ Example:
82
+ operator = FromIterables()
83
+ ms = operator.process(iterables)
84
 
 
 
85
  """
86
 
87
  def process(self, iterables: Dict[str, Iterable]) -> MultiStream:
 
89
 
90
 
91
  class IterableSource(StreamSource):
92
+ """Creates a MultiStream from a dict of named iterables.
93
+
94
+ It is a callable.
95
+
96
+ Args:
97
+ iterables (Dict[str, Iterable]): A dictionary mapping stream names to iterables.
98
+
99
+ Example:
100
+ operator = IterableSource(input_dict)
101
+ ms = operator()
102
+
103
+ """
104
+
105
  iterables: Dict[str, Iterable]
106
 
107
  def __call__(self) -> MultiStream:
 
109
 
110
 
111
  class MapInstanceValues(StreamInstanceOperator):
112
+ """A class used to map instance values into other values.
113
 
114
  This class is a type of StreamInstanceOperator,
115
  it maps values of instances in a stream using predefined mappers.
 
139
  To ensure that all values of field 'a' are mapped in every instance, use strict=True.
140
  Input instance {"a":"3", "b": 2} will raise an exception per the above call,
141
  because "3" is not a key in the mapper of "a".
142
+
143
+ MapInstanceValues(mappers={"a": {str([1,2,3,4]): 'All', str([]): 'None'}}, strict=True)
144
+ replaces a list [1,2,3,4] with the string 'All' and an empty list by string 'None'.
145
+ Note that values to be mapped are matched by their string representation, so the
146
+ keys of each mapper must be strings.
147
  """
148
 
149
  mappers: Dict[str, Dict[str, str]]
 
172
  raise ValueError(
173
  f"'process_every_field' == True is allowed only when all fields which have mappers, i.e., {list(self.mappers.keys())} are lists. Instace = {instance}"
174
  )
175
+ if isinstance(value, list) and self.process_every_value:
176
+ for i, val in enumerate(value):
177
+ value[i] = self.get_mapped_value(instance, key, mapper, val)
178
+ else:
179
+ value = self.get_mapped_value(instance, key, mapper, value)
180
+ dict_set(
181
+ instance,
182
+ key,
183
+ value,
184
+ use_dpath=self.use_query,
185
+ )
186
 
187
  return instance
188
 
189
+ def get_mapped_value(self, instance, key, mapper, val):
190
+ val_as_str = str(val) # make sure the value is a string
191
+ if self.strict and (val_as_str not in mapper):
192
+ raise KeyError(
193
+ f"value '{val}' in instance '{instance}' is not found in mapper '{mapper}', associated with field '{key}'."
194
+ )
195
+ # By default deep copy the value in mapper to avoid shared modifications
196
+ if val_as_str in mapper:
197
+ return deepcopy(mapper[val_as_str])
198
+ return val
199
+
200
 
201
  class FlattenInstances(StreamInstanceOperator):
202
  """Flattens each instance in a stream, making nested dictionary entries into top-level entries.
 
236
  # Add a 'classes' field on a given list, prevent modification of original list
237
  # from changing the instance.
238
  AddFields(fields={"classes": alist}), use_deepcopy=True)
239
+ # if alist is later modified, the instances remain intact.
240
  """
241
 
242
  fields: Dict[str, object]
 
259
 
260
 
261
  class RemoveFields(StreamInstanceOperator):
262
+ """Remove specified fields from each instance in a stream.
263
 
264
  Args:
265
  fields (List[str]): The fields to remove from each instance.
 
276
 
277
 
278
  class FieldOperator(StreamInstanceOperator):
279
+ """A general stream instance operator that processes the values of a field (or multiple ones).
280
 
281
  Args:
282
+ field (Optional[str]): The field to process, if only a single one is passed. Defaults to None
283
+ to_field (Optional[str]): Field name to save result into, if only one field is processed, if None is passed the
284
+ operation would happen in-place and its result would replace the value of "field". Defaults to None
285
+ field_to_field (Optional[Union[List[List[str]], Dict[str, str]]]): Mapping from names of fields to process,
286
+ to names of fields to save the results into. Inner List, if used, should be of length 2.
287
+ A field is processed by feeding its value into method 'process_value' and storing the result in to_field that
288
+ is mapped to the field.
289
+ When the type of argument 'field_to_field' is List, the order by which the fields are processed is their order
290
+ in the (outer) List. But when the type of argument 'field_to_field' is Dict, there is no uniquely determined
291
+ order. The end result might depend on that order if either (1) two different fields are mapped to the same
292
+ to_field, or (2) a field shows both as a key and as a value in different mappings.
293
+ The operator throws an AssertionError in either of these cases.
294
+ field_to_field defaults to None
295
  process_every_value (bool): Processes the values in a list instead of the list as a value, similar to *var. Defaults to False
296
  use_query (bool): Whether to use dpath style queries. Defaults to False.
297
+
298
+ Note: if 'field' and 'to_field' (or both members of a pair in 'field_to_field') are equal (or share a common
299
+ prefix if 'use_query'=True), then the result of the operation is saved within 'field'
300
  """
301
 
302
  field: Optional[str] = None
303
  to_field: Optional[str] = None
304
+ field_to_field: Optional[Union[List[List[str]], Dict[str, str]]] = None
305
  process_every_value: bool = False
306
  use_query: bool = False
307
  get_default: Any = None
 
318
  ), f"Can not apply operator to create both on {self.to_field} and on the mapping from fields to fields {self.field_to_field}"
319
  assert (
320
  self.field is None or self.field_to_field is None
321
+ ), f"Can not apply operator both on {self.field} and on the from fields in the mapping {self.field_to_field}"
322
+ assert self._field_to_field, f"the from and to fields must be defined or implied from the other inputs, got: {self._field_to_field}"
323
  assert (
324
+ len(self._field_to_field) > 0
325
+ ), f"'input argument 'field_to_field' should convey at least one field to process. Got {self.field_to_field}"
326
+ # self._field_to_field is built explicitly by pairs, or copied from argument 'field_to_field'
327
+ if self.field_to_field is None:
328
+ return
329
+ # for backward compatibility, also allow a list of tuples of two strings
330
+ if isoftype(self.field_to_field, List[List[str]]) or isoftype(
331
+ self.field_to_field, List[Tuple[str, str]]
332
+ ):
333
+ for pair in self._field_to_field:
334
+ assert (
335
+ len(pair) == 2
336
+ ), f"when 'field_to_field' is defined as a list of lists, the inner lists should all be of length 2. {self.field_to_field}"
337
+ # order of field processing is uniquely determined by the input field_to_field when a list
338
+ return
339
+ if isoftype(self.field_to_field, Dict[str, str]):
340
+ if len(self.field_to_field) < 2:
341
+ return
342
+ for ff, tt in self.field_to_field.items():
343
+ for f, t in self.field_to_field.items():
344
+ if f == ff:
345
+ continue
346
+ assert (
347
+ t != ff
348
+ ), f"In input argument 'field_to_field': {self.field_to_field}, field {f} is mapped to field {t}, while the latter is mapped to {tt}. Whether {f} or {t} is processed first might impact end result."
349
+ assert (
350
+ tt != t
351
+ ), f"In input argument 'field_to_field': {self.field_to_field}, two different fields: {ff} and {f} are mapped to field {tt}. Whether {ff} or {f} is processed last might impact end result."
352
+ return
353
+ raise ValueError(
354
+ "Input argument 'field_to_field': {self.field_to_field} is neither of type List{List[str]] nor of type Dict[str, str]."
355
+ )
356
 
357
  @abstractmethod
358
  def process_value(self, value: Any) -> Any:
359
  pass
360
 
361
  def prepare(self):
362
+ super().prepare()
363
+
364
+ # prepare is invoked before verify, hence we must make some checks here, before the changes below
365
+ assert (
366
+ (self.field is None) != (self.field_to_field is None)
367
+ ), "Must uniquely define the field to work on, through exactly one of either 'field' or 'field_to_field'"
368
+ assert (
369
+ self.to_field is None or self.field_to_field is None
370
+ ), f"Can not apply operator to create both {self.to_field} and the to fields in the mapping {self.field_to_field}"
371
+
372
  if self.field_to_field is None:
373
+ self._field_to_field = [
374
+ (self.field, self.to_field if self.to_field is not None else self.field)
375
+ ]
376
  else:
377
+ self._field_to_field = (
378
+ list(self.field_to_field.items())
379
+ if isinstance(self.field_to_field, dict)
380
+ else self.field_to_field
381
+ )
382
 
383
  def process(
384
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
 
405
  raise ValueError(
406
  f"Failed to process '{from_field}' from {instance} due to : {e}"
407
  ) from e
408
+ if is_subpath(from_field, to_field) or is_subpath(to_field, from_field):
409
  dict_delete(instance, from_field)
410
  dict_set(
411
  instance,
 
418
 
419
 
420
  class RenameFields(FieldOperator):
421
+ """Renames fields.
422
+
423
+ Moves the value from one field to another, potentially (if 'use_query'=True) from one branch into another.
424
+ The from-field is removed; with use_query=True, ancestor dictionaries left empty by the removal are deleted as well.
425
+
426
+ Examples:
427
+ RenameFields(field_to_field={"b": "c"})
428
+ will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": 2}, {"a": 2, "c": 3}]
429
+
430
+ RenameFields(field_to_field={"b": "c/d"}, use_query=True)
431
+ will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": {"d": 2}}, {"a": 2, "c": {"d": 3}}]
432
+
433
+ RenameFields(field_to_field={"b": "b/d"}, use_query=True)
434
+ will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "b": {"d": 2}}, {"a": 2, "b": {"d": 3}}]
435
+
436
+ RenameFields(field_to_field={"b/c/e": "b/d"}, use_query=True)
437
+ will change inputs [{"a": 1, "b": {"c": {"e": 2, "f": 20}}}] to [{"a": 1, "b": {"c": {"f": 20}, "d": 2}}]
438
+
439
+ """
440
 
441
  def process_value(self, value: Any) -> Any:
442
  return value
 
445
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
446
  ) -> Dict[str, Any]:
447
  res = super().process(instance=instance, stream_name=stream_name)
448
+ for from_field, to_field in self._field_to_field:
449
+ if (not is_subpath(from_field, to_field)) and (
450
+ not is_subpath(to_field, from_field)
451
+ ):
452
+ dict_delete(res, from_field)
453
+ if self.use_query:
454
+ from_field_components = list(
455
+ os.path.normpath(from_field).split(os.path.sep)
456
+ )
457
+ while len(from_field_components) > 1:
458
+ from_field_components.pop()
459
+ parent = dict_get(res, os.path.sep.join(from_field_components))
460
+ if isinstance(parent, dict) and not parent:
461
+ dict_delete(res, os.path.sep.join(from_field_components))
462
+ else:
463
+ break
464
+
465
  return res
466
 
467
 
468
  class AddConstant(FieldOperator):
469
+ """Adds a constant, being argument 'add', to the processed value.
470
 
471
  Args:
472
+ add: the constant to add.
473
  """
474
 
475
  add: Any
 
535
  default="",
536
  not_exist_ok=False,
537
  )
538
+ except ValueError as e:
539
  raise TypeError(f"Failed to get {field_name} from {instance}") from e
540
 
541
+ try:
542
+ new_value = self.process_value(old_value)
543
+ except Exception as e:
544
+ raise RuntimeError(
545
+ f"Error augmenting value '{old_value}' from '{field_name}' in instance: {instance}"
546
+ ) from e
 
 
 
 
547
  dict_set(instance, field_name, new_value, use_dpath=True, not_exist_ok=True)
548
  return instance
549
 
 
568
  words = re.split(r"(\s+)", value)
569
  new_value = ""
570
 
571
+ random_generator = new_random_generator(sub_seed=value)
572
  for word in words:
573
  if word.isspace():
574
+ new_value += random_generator.choice(
575
  ["\n", "\t", " "]
576
+ ) * random_generator.randint(1, 3)
577
  else:
578
  new_value += word
579
  return new_value
580
 
581
 
582
+ class AugmentPrefixSuffix(Augmentor):
583
+ r"""Augments the input by prepending and appending to it a randomly selected (typically, whitespace) patterns.
584
 
585
  Args:
586
+ prefixes, suffixes (list or dict) : the potential (typically, whitespace) patterns to select from.
587
  The dictionary version allows to specify relative weights of the different patterns.
588
+ prefix_len, suffix_len (positive int) : The added prefix or suffix will be of length
589
+ prefix_len or suffix_len, respectively, repetitions of the randomly selected patterns.
590
+ remove_existing_whitespaces : allows to first clean any existing leading and trailing whitespaces.
591
+ The strings made of repetitions of the selected pattern(s) are then prepended and/or appended to the potentially
592
+ trimmed input.
593
+ If only one of prefixes/suffixes is needed, set the other to None.
594
 
595
  Examples:
596
+ To prepend the input with a prefix made of 4 '\n'-s or '\t'-s, employ
597
+ AugmentPrefixSuffix(augment_model_input=True, prefixes=['\n','\t'], prefix_len=4, suffixes = None)
598
+ To append the input with a suffix made of 3 '\n'-s or '\t'-s, with triple '\n' suffixes
599
+ being preferred over triple '\t', at 2:1 ratio, employ
600
+ AugmentPrefixSuffix(augment_model_input=True, suffixes={'\n':2,'\t':1}, suffix_len=3, prefixes = None)
601
+ which will append '\n'-s twice as often as '\t'-s.
602
 
603
  """
604
 
605
+ prefixes: Optional[Union[List[str], Dict[str, int]]] = {
606
+ " ": 20,
607
+ "\\t": 10,
608
+ "\\n": 40,
609
+ "": 30,
610
+ }
611
+ prefix_len: Optional[int] = 3
612
+ suffixes: Optional[Union[List[str], Dict[str, int]]] = {
613
+ " ": 20,
614
+ "\\t": 10,
615
+ "\\n": 40,
616
+ "": 30,
617
+ }
618
+ suffix_len: Optional[int] = 3
619
+ remove_existing_whitespaces: Optional[bool] = False
620
 
621
  def verify(self):
622
  assert (
623
+ self.prefixes or self.suffixes
624
+ ), "At least one of prefixes/suffixes should be not None."
625
+ for arg, arg_name in zip(
626
+ [self.prefixes, self.suffixes], ["prefixes", "suffixes"]
627
+ ):
628
+ assert (
629
+ arg is None or isoftype(arg, List[str]) or isoftype(arg, Dict[str, int])
630
+ ), f"Argument {arg_name} should be either None or a list of strings or a dictionary str->int. {arg} is none of the above."
631
+ assert (
632
+ self.prefix_len > 0
633
+ ), f"prefix_len must be positive, got {self.prefix_len}"
634
+ assert (
635
+ self.suffix_len > 0
636
+ ), f"suffix_len must be positive, got {self.suffix_len}"
637
+ super().verify()
 
638
 
639
+ def _calculate_distributions(self, prefs_or_suffs):
640
+ if prefs_or_suffs is None:
641
+ return None, None
642
+ patterns = (
643
+ prefs_or_suffs
644
+ if isinstance(prefs_or_suffs, list)
645
+ else [k for k, v in prefs_or_suffs.items()]
646
  )
647
  total_weight = (
648
+ len(patterns)
649
+ if isinstance(prefs_or_suffs, list)
650
+ else sum([v for k, v in prefs_or_suffs.items()])
651
  )
652
+ weights = (
653
+ [1.0 / total_weight] * len(patterns)
654
+ if isinstance(prefs_or_suffs, list)
655
+ else [float(prefs_or_suffs[p]) / total_weight for p in patterns]
656
  )
657
+ return patterns, weights
658
+
659
+ def prepare(self):
660
+ # Being an artifact, prepare is invoked before verify. Here we need to verify before acting
661
+ self.verify()
662
+ self._prefix_pattern_distribution = {"length": self.prefix_len}
663
+ self._suffix_pattern_distribution = {"length": self.suffix_len}
664
+
665
+ (
666
+ self._prefix_pattern_distribution["patterns"],
667
+ self._prefix_pattern_distribution["weights"],
668
+ ) = self._calculate_distributions(self.prefixes)
669
+ (
670
+ self._suffix_pattern_distribution["patterns"],
671
+ self._suffix_pattern_distribution["weights"],
672
+ ) = self._calculate_distributions(self.suffixes)
673
+ super().prepare()
674
+
675
+ def _get_random_pattern(
676
+ self, pattern_distribution, random_generator: Random
677
+ ) -> str:
678
+ string_to_add = ""
679
+ if pattern_distribution["patterns"]:
680
+ string_to_add = "".join(
681
+ random_generator.choices(
682
+ pattern_distribution["patterns"],
683
+ pattern_distribution["weights"],
684
+ k=pattern_distribution["length"],
685
+ )
686
+ )
687
+ return string_to_add
688
 
689
  def process_value(self, value: Any) -> Any:
690
  assert value is not None, "input value should not be None"
691
  new_value = str(value)
692
+ if self.remove_existing_whitespaces:
693
+ new_value = new_value.strip()
694
+ random_generator = new_random_generator(sub_seed=value)
695
+ prefix = self._get_random_pattern(
696
+ self._prefix_pattern_distribution, random_generator
697
+ )
698
+ suffix = self._get_random_pattern(
699
+ self._suffix_pattern_distribution, random_generator
700
+ )
701
+ return prefix + new_value + suffix
702
 
703
 
704
  class ShuffleFieldValues(FieldOperator):
705
+ """Shuffles a list of values found in a field."""
706
 
707
  def process_value(self, value: Any) -> Any:
708
  res = list(value)
709
+ random_generator = new_random_generator(sub_seed=res)
710
+ random_generator.shuffle(res)
711
  return res
712
 
713
 
 
812
 
813
 
814
  class ZipFieldValues(StreamInstanceOperator):
815
+ """Zips values of multiple fields in a given instance, similar to list(zip(*fields)).
816
+
817
+ The value in each of the specified 'fields' is assumed to be a list. The lists from all 'fields'
818
+ are zipped, and stored into 'to_field'.
819
 
820
+ If 'longest'=False, the length of the zipped result is determined by the shortest input value.
821
+ If 'longest'=True, the length of the zipped result is determined by the longest input, padding shorter
822
+ inputs with None-s.
823
+
824
+ """
825
+
826
+ fields: List[str]
827
  to_field: str
828
  longest: bool = False
829
  use_query: bool = False
 
843
 
844
 
845
  class IndexOf(StreamInstanceOperator):
846
+ """For a given instance, finds the offset of value of field 'index_of', within the value of field 'search_in'."""
847
 
848
  search_in: str
849
  index_of: str
 
860
 
861
 
862
  class TakeByField(StreamInstanceOperator):
863
+ """From field 'field' of a given instance, select the member indexed by field 'index', and store to field 'to_field'."""
864
 
865
  field: str
866
  index: str
 
881
 
882
 
883
  class CopyFields(FieldOperator):
884
+ """Copies values from specified fields to specified fields.
885
 
886
+ Args (of parent class):
887
  field_to_field (Union[List[List], Dict[str, str]]): A list of lists, where each sublist contains the source field and the destination field, or a dictionary mapping source fields to destination fields.
888
+ use_query (bool): Whether to use dpath for accessing fields. Defaults to False.
889
+
890
+ Examples:
891
+ An input instance {"a": 2, "b": 3}, when processed by
892
+ CopyField(field_to_field={"a": "b"}
893
+ would yield {"a": 2, "b": 2}, and when processed by
894
+ CopyField(field_to_field={"a": "c"} would yield
895
+ {"a": 2, "b": 3, "c": 2}
896
+
897
+ with use_query=True, we can also copy inside the field:
898
+ CopyFields(field_to_field={"a/0": "a"}, use_query=True)
899
+ would process instance {"a": [1, 3]} into {"a": 1}
900
+
901
+
902
  """
903
 
904
  def process_value(self, value: Any) -> Any:
 
906
 
907
 
908
  class AddID(StreamInstanceOperator):
909
+ """Stores a unique id value in the designated 'id_field_name' field of the given instance."""
910
+
911
  id_field_name: str = "id"
912
 
913
  def process(
 
921
  """Casts specified fields to specified types.
922
 
923
  Args:
924
+ use_nested_query (bool): Whether to cast nested fields, expressed in dpath. Defaults to False.
925
+ fields (Dict[str, str]): A dictionary mapping field names to the names of the types to cast the fields to.
926
+ e.g: "int", "str", "float", "bool". Basic names of types
927
+ defaults (Dict[str, object]): A dictionary mapping field names to default values for cases of casting failure.
928
+ process_every_value (bool): If true, all fields involved must contain lists, and each value in the list is then casted. Defaults to False.
929
+
930
+ Examples:
931
+ CastFields(
932
+ fields={"a/d": "float", "b": "int"},
933
+ failure_defaults={"a/d": 0.0, "b": 0},
934
+ process_every_value=True,
935
+ use_nested_query=True
936
+ )
937
+ would process the input instance: {"a": {"d": ["half", "0.6", 1, 12]}, "b": ["2"]}
938
+ into {"a": {"d": [0.0, 0.6, 1.0, 12.0]}, "b": [2]}
939
+
940
  """
941
 
 
 
 
 
 
 
942
  fields: Dict[str, str] = field(default_factory=dict)
943
  failure_defaults: Dict[str, object] = field(default_factory=dict)
944
  use_nested_query: bool = False
945
+ process_every_value: bool = False
946
+
947
+ def prepare(self):
948
+ self.types = {"int": int, "float": float, "str": str, "bool": bool}
949
 
950
  def _cast_single(self, value, type, field):
951
  try:
 
958
  return self.failure_defaults[field]
959
 
960
  def _cast_multiple(self, values, type, field):
961
+ return [self._cast_single(value, type, field) for value in values]
962
 
963
  def process(
964
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
965
  ) -> Dict[str, Any]:
966
  for field_name, type in self.fields.items():
967
  value = dict_get(instance, field_name, use_dpath=self.use_nested_query)
968
+ if self.process_every_value:
969
+ assert isinstance(
970
+ value, list
971
+ ), f"'process_every_value' can be set to True only for fields that contain lists, whereas in instance {instance}, the contents of field '{field_name}' is of type '{type(value)}'"
972
  casted_value = self._cast_multiple(value, type, field_name)
973
  else:
974
  casted_value = self._cast_single(value, type, field_name)
 
978
  return instance
979
 
980
 
981
+ class DivideAllFieldsBy(StreamInstanceOperator):
982
+ """Recursively reach down to all fields that are float, and divide each by 'divisor'.
 
 
 
 
 
 
 
 
 
 
983
 
984
+ The given instance is viewed as a tree whose internal nodes are dictionaries and lists, and
985
+ the leaves are either 'float' and then divided, or other basic type, in which case, a ValueError is raised
986
+ if input flag 'strict' is True, or -- left alone, if 'strict' is False.
987
+
988
+ Args:
989
+ divisor (float): the value to divide by
990
+ strict (bool): whether to raise an error upon visiting a leaf that is not float. Defaults to False.
991
+
992
+ Example:
993
+ when instance {"a": 10.0, "b": [2.0, 4.0, 7.0], "c": 5} is processed by operator:
994
+ operator = DivideAllFieldsBy(divisor=2.0)
995
+ the output is: {"a": 5.0, "b": [1.0, 2.0, 3.5], "c": 5}
996
+ If the operator were defined with strict=True, through:
997
+ operator = DivideAllFieldsBy(divisor=2.0, strict=True),
998
+ the processing of the above instance would raise a ValueError, for the integer at "c".
999
+ """
1000
 
 
1001
  divisor: float = 1.0
1002
  strict: bool = False
1003
+
1004
+ def _recursive_divide(self, instance, divisor):
1005
+ if isinstance(instance, dict):
1006
+ for key, value in instance.items():
1007
+ instance[key] = self._recursive_divide(value, divisor)
1008
+ elif isinstance(instance, list):
1009
+ for i, value in enumerate(instance):
1010
+ instance[i] = self._recursive_divide(value, divisor)
1011
+ elif isinstance(instance, float):
1012
+ instance /= divisor
1013
+ elif self.strict:
1014
+ raise ValueError(f"Cannot divide instance of type {type(instance)}")
1015
+ return instance
1016
 
1017
  def process(
1018
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
1019
  ) -> Dict[str, Any]:
1020
+ return self._recursive_divide(instance, self.divisor)
1021
 
1022
 
1023
  class ArtifactFetcherMixin:
 
1041
  """Applies value operators to each instance in a stream based on specified fields.
1042
 
1043
  Args:
1044
+ inputs_fields (List[str]): list of field names, the values in which are to be processed
1045
+ fields_to_treat_as_list (List[str]): sublist of input_fields, each member of this sublist is supposed to contain
1046
+ a list of values, each of which is to be processed.
1047
+ operators_field (str): name of the field that contains the list of names of the operators to be applied,
1048
+ one after the other, for the processing.
1049
  default_operators (List[str]): A list of default operators to be used if no operators are found in the instance.
 
1050
 
1051
+ Example:
1052
+ when instance {"a": 111, "b": 2, "c": ["processors.to_string", "processors.first_character"]} is processed by operator:
1053
+ operator = ApplyOperatorsField(inputs_fields=["a"], operators_field="c", default_operators=["add"]),
1054
+ the resulting instance is: {"a": "1", "b": 2, "c": ["processors.to_string", "processors.first_character"]}
1055
 
1056
+ """
1057
+
1058
+ inputs_fields: List[str]
1059
  operators_field: str
1060
  default_operators: List[str] = None
1061
  fields_to_treat_as_list: List[str] = NonPositionalField(default_factory=list)
 
1067
  if operator_names is None:
1068
  assert (
1069
  self.default_operators is not None
1070
+ ), f"No operators found in field '{self.operators_field}', and no default operators provided."
1071
  operator_names = self.default_operators
1072
 
1073
  if isinstance(operator_names, str):
 
1080
  if field_name in self.fields_to_treat_as_list:
1081
  instance[field_name] = [operator.process(v) for v in value]
1082
  else:
1083
+ instance[field_name] = operator.process(value)
1084
 
1085
  return instance
1086
 
1087
 
1088
+ class FilterByCondition(SingleStreamOperator):
1089
+ """Filters a stream, yielding only instances for which the required values follows the required condition operator.
1090
+
1091
+ Raises an error if a required key is missing.
1092
 
1093
  Args:
1094
+ values (Dict[str, Any]): Values that instances must match using the condition to be included in the output.
1095
+ condition: the name of the desired condition operator between the key and the value in values ("gt", "ge", "lt", "le", "ne", "eq", "in", "not in")
1096
+ error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True.
1097
+
1098
+ Examples:
1099
+ FilterByCondition(values = {"a":4}, condition = "gt") will yield only instances where "a">4
1100
+ FilterByCondition(values = {"a":4}, condition = "le") will yield only instances where "a"<=4
1101
+ FilterByCondition(values = {"a":[4,8]}, condition = "in") will yield only instances where "a" is 4 or 8
1102
+ FilterByCondition(values = {"a":[4,8]}, condition = "not in") will yield only instances where "a" different from 4 or 8
1103
+
1104
  """
1105
 
1106
+ values: Dict[str, Any]
1107
+ condition: str
1108
+ condition_to_func = {
1109
+ "gt": operator.gt,
1110
+ "ge": operator.ge,
1111
+ "lt": operator.lt,
1112
+ "le": operator.le,
1113
+ "eq": operator.eq,
1114
+ "ne": operator.ne,
1115
+ "in": None, # Handled as special case
1116
+ "not in": None, # Handled as special case
1117
+ }
1118
+ error_on_filtered_all: bool = True
1119
 
1120
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
1121
+ yielded = False
1122
  for instance in stream:
1123
+ if self._is_required(instance):
1124
+ yielded = True
1125
+ yield instance
1126
+
1127
+ if not yielded and self.error_on_filtered_all:
1128
+ raise RuntimeError(
1129
+ f"{self.__class__.__name__} filtered out every instance in stream '{stream_name}'. If this is intended set error_on_filtered_all=False"
1130
+ )
1131
+
1132
+ def verify(self):
1133
+ if self.condition not in self.condition_to_func:
1134
+ raise ValueError(
1135
+ f"Unsupported condition operator '{self.condition}', supported {list(self.condition_to_func.keys())}"
1136
+ )
1137
+
1138
+ for key, value in self.values.items():
1139
+ if self.condition in ["in", "not it"] and not isinstance(value, list):
1140
+ raise ValueError(
1141
+ f"The filter for key ('{key}') in FilterByCondition with condition '{self.condition}' must be list but is not : '{value}'"
1142
+ )
1143
+ return super().verify()
+
+ def _is_required(self, instance: dict) -> bool:
+ for key, value in self.values.items():
+ if key not in instance:
+ raise ValueError(
+ f"Required filter field ('{key}') in FilterByCondition is not found in {instance}"
+ )
+ if self.condition == "in":
+ if instance[key] not in value:
+ return False
+ elif self.condition == "not in":
+ if instance[key] in value:
+ return False
+ else:
+ func = self.condition_to_func[self.condition]
+ if func is None:
  raise ValueError(
+ f"Function not defined for condition '{self.condition}'"
  )
+ if not func(instance[key], value):
+ return False
+ return True
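+ # A minimal usage sketch (kept as a comment so the module stays side-effect
+ # free; the stream contents are illustrative):
+ #
+ #     ms = MultiStream.from_iterables({"train": [{"a": 3}, {"a": 5}]})
+ #     op = FilterByCondition(values={"a": 4}, condition="gt")
+ #     list(op(ms)["train"])  # -> [{"a": 5}]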
+
+
+ class FilterByQuery(SingleStreamOperator):
+ """Filters a stream, yielding only instances that fulfill a condition specified as a string to be evaluated with Python's eval().
+
+ Raises an error if a field participating in the specified condition is missing from the instance.
+
+ Args:
+ query (str): a condition over fields of the instance, to be processed by Python's eval()
+ error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True.
+
+ Examples:
+ FilterByQuery(query = "a > 4") will yield only instances where "a">4
+ FilterByQuery(query = "a <= 4 and b > 5") will yield only instances where the value of field "a" does not exceed 4 and the value of field "b" is greater than 5
+ FilterByQuery(query = "a in [4, 8]") will yield only instances where "a" is 4 or 8
+ FilterByQuery(query = "a not in [4, 8]") will yield only instances where "a" is neither 4 nor 8
+
+ """
+
+ query: str
+ error_on_filtered_all: bool = True
+
+ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
+ yielded = False
+ for instance in stream:
+ if eval(self.query, None, instance):
+ yielded = True
  yield instance
 
+ if not yielded and self.error_on_filtered_all:
+ raise RuntimeError(
+ f"{self.__class__.__name__} filtered out every instance in stream '{stream_name}'. If this is intended, set error_on_filtered_all=False"
+ )
+
+
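+ # A minimal usage sketch (illustrative; the query is handed to Python's
+ # eval(), so it should come only from trusted configuration):
+ #
+ #     ms = MultiStream.from_iterables({"train": [{"a": 3, "b": 9}, {"a": 5, "b": 2}]})
+ #     op = FilterByQuery(query="a <= 4 and b > 5")
+ #     list(op(ms)["train"])  # -> [{"a": 3, "b": 9}]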
+ class ExecuteQuery(StreamInstanceOperator):
+ """Compute an expression (query), expressed as a string to be evaluated with Python's eval(), over the instance's fields, and store the result in field to_field.
+
+ Raises an error if a field mentioned in the query is missing from the instance.
+
+ Args:
+ query (str): an expression to be evaluated over the fields of the instance
+ to_field (str): the field into which the result is to be stored
+
+ Examples:
+ When instance {"a": 2, "b": 3} is processed by operator
+ ExecuteQuery(query="a+b", to_field = "c")
+ the result is {"a": 2, "b": 3, "c": 5}
+
+ When instance {"a": "hello", "b": "world"} is processed by operator
+ ExecuteQuery(query = "a+' '+b", to_field = "c")
+ the result is {"a": "hello", "b": "world", "c": "hello world"}
+
+ """
+
+ query: str
+ to_field: str
+
+ def process(
+ self, instance: Dict[str, Any], stream_name: Optional[str] = None
+ ) -> Dict[str, Any]:
+ instance[self.to_field] = eval(self.query, None, instance)
+ return instance
+
 
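+ # A minimal usage sketch (illustrative values; the same eval() trust caveat
+ # as in FilterByQuery applies):
+ #
+ #     op = ExecuteQuery(query="a + b", to_field="c")
+ #     op.process({"a": 2, "b": 3})  # -> {"a": 2, "b": 3, "c": 5}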
+ class ExtractMostCommonFieldValues(MultiStreamOperator):
  field: str
  stream_name: str
  overall_top_frequency_percent: Optional[int] = 100
 
 
  Examples:
 
+ ExtractMostCommonFieldValues(stream_name="train", field="label", to_field="classes") - extracts all the unique values of
  field 'label', sorts them by decreasing frequency, and stores the resulting list in field 'classes' of each and
  every instance in all streams.
 
+ ExtractMostCommonFieldValues(stream_name="train", field="labels", to_field="classes", process_every_value=True) -
  in case field 'labels' contains a list of values (rather than a single value) - track the occurrences of all the possible
  value members in these lists, and report the most frequent values.
  if process_every_value=False, track the most frequent whole lists, and report those (as a list of lists) in field
  'to_field' of each instance of all streams.
 
+ ExtractMostCommonFieldValues(stream_name="train", field="label", to_field="classes",overall_top_frequency_percent=80) -
  extracts the most frequent possible values of field 'label' that together cover at least 80% of the instances of stream_name,
  and stores them in field 'classes' of each instance of all streams.
 
+ ExtractMostCommonFieldValues(stream_name="train", field="label", to_field="classes",min_frequency_percent=5) -
  extracts all possible values of field 'label' that each cover at least 5% of the instances.
  Stores these values, sorted by decreasing order of frequency, in field 'classes' of each instance in all streams.
  """
 
  [*ele[0]] if isinstance(ele[0], tuple) else ele[0]
  for ele in values_and_counts
  ]
 
+ addmostcommons = AddFields(fields={self.to_field: values_to_keep})
+ return addmostcommons(multi_stream)
 
 
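+ # A minimal usage sketch (kept as a comment; the stream name and fields are
+ # illustrative). Frequencies are counted over the named stream only, and the
+ # kept values are written into 'to_field' of every instance in all streams:
+ #
+ #     op = ExtractMostCommonFieldValues(
+ #         stream_name="train", field="label", to_field="classes",
+ #         overall_top_frequency_percent=80,
+ #     )
+ #     multi_stream = op(multi_stream)  # every instance now carries "classes"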
+ class ExtractFieldValues(ExtractMostCommonFieldValues):
  def verify(self):
  super().verify()
 
+ def prepare(self):
+ self.overall_top_frequency_percent = 100
+ self.min_frequency_percent = 0
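+ # With overall_top_frequency_percent=100 and min_frequency_percent=0, no value
+ # is dropped by frequency, so ExtractFieldValues collects all unique values of
+ # 'field', still sorted by decreasing frequency.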
 
 
  class Intersect(FieldOperator):
 
  )
 
  def process_value(self, value: Any) -> Any:
+ super().process_value(value)
  if not isinstance(value, list):
  raise ValueError(f"The value in field is not a list but '{value}'")
  return [e for e in value if e in self.allowed_values]
 
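+ # A minimal usage sketch (illustrative field name and values, assuming
+ # Intersect is configured like other FieldOperators with 'field' and the
+ # 'allowed_values' list referenced above):
+ #
+ #     op = Intersect(field="labels", allowed_values=["positive", "negative"])
+ #     op.process_value(["positive", "neutral"])  # -> ["positive"]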
  """Removes elements in a field, which must be a list, using a given list of unallowed values.
 
  Args:
+ unallowed_values (list) - values to be removed.
  """
 
  unallowed_values: List[Any]
 
  stream_unique_values = uniques[stream_name]
  for unique_values in stream_unique_values:
  filtering_values = dict(zip(self.fields, unique_values))
+ filtered_streams = FilterByCondition(
+ values=filtering_values, condition="eq"
  )._process_single_stream(stream)
  filtered_stream_name = (
  stream_name + "_" + nested_tuple_to_string(unique_values)
 
  reversed: bool = False
 
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
+ first_instance = stream.peek()
 
  operators = first_instance.get(self.field, [])
  if isinstance(operators, str):
 
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
  from .metrics import Metric, MetricPipeline, MetricWithConfidenceInterval
 
+ first_instance = stream.peek()
 
  metric_names = first_instance.get(self.metric_field, [])
  if not metric_names:
 
  yield from stream
 
 
  class MergeStreams(MultiStreamOperator):
  """Merges multiple streams into a single stream.
 
 
  class Shuffle(PagedStreamOperator):
  """Shuffles the order of instances in each page of a stream.
 
+ Args (of superclass):
  page_size (int): The size of each page in the stream. Defaults to 1000.
  """
 
+ random_generator: Random = None
+
+ def before_process_multi_stream(self):
+ super().before_process_multi_stream()
+ self.random_generator = new_random_generator(sub_seed="shuffle")
+
  def process(self, page: List[Dict], stream_name: Optional[str] = None) -> Generator:
+ self.random_generator.shuffle(page)
  yield from page
 
 
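+ # Note: shuffling is reproducible across runs, since each page is shuffled with
+ # a dedicated generator derived via new_random_generator(sub_seed="shuffle"),
+ # rather than with the shared global random state.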
  class EncodeLabels(StreamInstanceOperator):
+ """Encode each value encountered in any field in 'fields' into the integers 0,1,...
+
+ Encoding is determined by a str->int map that is built on the fly, as different values are
+ first encountered in the stream, either as list members or as values in single-value fields.
 
  Args:
  fields (List[str]): The fields to encode together.
+
+ Example: applying
+ EncodeLabels(fields = ["a", "b/*"])
+ on input stream = [{"a": "red", "b": ["red", "blue"], "c":"bread"},
+ {"a": "blue", "b": ["green"], "c":"water"}] will yield the
+ output stream = [{'a': 0, 'b': [0, 1], 'c': 'bread'}, {'a': 1, 'b': [2], 'c': 'water'}]
+
+ Note: dpath is applied here, and hence fields that are lists should be included in
+ input 'fields' with the suffix "/*", as in the above example.
+
  """
 
  fields: List[str]
 
 
 
  class StreamRefiner(SingleStreamOperator):
+ """Discard from the input stream all instances beyond the leading 'max_instances' instances.
+
+ Thereby, if the input stream consists of no more than 'max_instances' instances, the resulting stream is the whole input
+ stream. If the input stream consists of more than 'max_instances' instances, the resulting stream consists of exactly
+ the leading 'max_instances' instances of the input stream.
+
+ Args:
+ max_instances (int)
+ apply_to_streams (optional, list(str)): names of streams to refine.
+
+ Examples:
+ when input = [{"a": 1},{"a": 2},{"a": 3},{"a": 4},{"a": 5},{"a": 6}] is fed into
+ StreamRefiner(max_instances=4)
+ the resulting stream is [{"a": 1},{"a": 2},{"a": 3},{"a": 4}]
+ """
+
  max_instances: int = None
+ apply_to_streams: Optional[List[str]] = None
 
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
  if self.max_instances is not None:
 
  class DeterministicBalancer(StreamRefiner):
  """A class used to balance streams deterministically.
 
+ For each instance, a signature is constructed from the values of the instance in the specified input 'fields'.
+ By discarding instances from the input stream, DeterministicBalancer maintains an equal number of instances for all signatures.
+ When 'max_instances' is also specified, DeterministicBalancer keeps the total instance count from exceeding
+ 'max_instances', while discarding as few instances as possible.
+
  Attributes:
+ fields (List[str]): A list of field names to be used in producing the instance's signature.
+ max_instances (Optional, int)
 
  Usage:
+ balancer = DeterministicBalancer(fields=["field1", "field2"], max_instances=200)
  balanced_stream = balancer.process(stream)
+
+ Example:
+ When input [{"a": 1, "b": 1},{"a": 1, "b": 2},{"a": 2},{"a": 3},{"a": 4}] is fed into
+ DeterministicBalancer(fields=["a"])
+ the resulting stream will be: [{"a": 1, "b": 1},{"a": 2},{"a": 3},{"a": 4}]
  """
 
  fields: List[str]
 
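+ # A minimal usage sketch (illustrative): with fields=["a"], instances are
+ # grouped by the value of "a", and each group is truncated to the size of the
+ # smallest group, so all signatures end up equally represented:
+ #
+ #     op = DeterministicBalancer(fields=["a"])
+ #     ms = MultiStream.from_iterables({"train": [{"a": 1, "b": 1}, {"a": 1, "b": 2}, {"a": 2}]})
+ #     list(op(ms)["train"])  # -> [{"a": 1, "b": 1}, {"a": 2}]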
 
 
  class LengthBalancer(DeterministicBalancer):
+ """Balances by a signature that reflects the total length of the fields' values, quantized into integer segments.
+
+ Args:
+ segments_boundaries (List[int]): distinct integers sorted in increasing order. A given total length is mapped
+ into the index of the smallest boundary that exceeds it, or, if no boundary exceeds it, into one index
+ beyond, namely, the length of segments_boundaries.
+
+ fields (Optional, List[str])
+
+ Example:
+ when input [{"a": [1, 3], "b": 0, "id": 0}, {"a": [1, 3], "b": 0, "id": 1}, {"a": [], "b": "a", "id": 2}] is fed into
+ LengthBalancer(fields=["a"], segments_boundaries=[1])
+ input instances will be counted and balanced across two categories: empty total length (less than 1), and non-empty.
+ """
+
  segments_boundaries: List[int]
+ fields: Optional[List[str]]
 
  def signature(self, instance):
  total_len = 0