Elron commited on
Commit
cf45ebb
1 Parent(s): 2e870c5

Upload operators.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. operators.py +70 -54
operators.py CHANGED
@@ -32,7 +32,6 @@ The rest of this section is dedicated for general operators.
32
  General Operaotrs List:
33
  ------------------------
34
  """
35
- import collections
36
  import copy
37
  import operator
38
  import uuid
@@ -58,7 +57,7 @@ from typing import (
58
  import requests
59
 
60
  from .artifact import Artifact, fetch_artifact
61
- from .dataclass import NonPositionalField, OptionalField
62
  from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
63
  from .operator import (
64
  MultiStream,
@@ -157,7 +156,6 @@ class MapInstanceValues(StreamInstanceOperator):
157
 
158
  mappers: Dict[str, Dict[str, str]]
159
  strict: bool = True
160
- use_query: bool = False
161
  process_every_value: bool = False
162
 
163
  def verify(self):
@@ -175,7 +173,7 @@ class MapInstanceValues(StreamInstanceOperator):
175
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
176
  ) -> Dict[str, Any]:
177
  for key, mapper in self.mappers.items():
178
- value = dict_get(instance, key, use_dpath=self.use_query)
179
  if value is not None:
180
  if (self.process_every_value is True) and (not isinstance(value, list)):
181
  raise ValueError(
@@ -190,7 +188,6 @@ class MapInstanceValues(StreamInstanceOperator):
190
  instance,
191
  key,
192
  value,
193
- use_dpath=self.use_query,
194
  )
195
 
196
  return instance
@@ -229,7 +226,7 @@ class AddFields(StreamInstanceOperator):
229
 
230
  Args:
231
  fields (Dict[str, object]): The fields to add to each instance.
232
- use_query (bool) : Use '/' to access inner fields
233
  use_deepcopy (bool) : Deep copy the input value to avoid later modifications
234
 
235
  Examples:
@@ -249,21 +246,21 @@ class AddFields(StreamInstanceOperator):
249
  """
250
 
251
  fields: Dict[str, object]
252
- use_query: bool = False
 
 
 
 
 
253
  use_deepcopy: bool = False
254
 
255
  def process(
256
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
257
  ) -> Dict[str, Any]:
258
- if self.use_query:
259
- for key, value in self.fields.items():
260
- if self.use_deepcopy:
261
- value = deepcopy(value)
262
- dict_set(instance, key, value, use_dpath=self.use_query)
263
- else:
264
  if self.use_deepcopy:
265
- self.fields = deepcopy(self.fields)
266
- instance.update(self.fields)
267
  return instance
268
 
269
 
@@ -302,17 +299,21 @@ class InstanceFieldOperator(StreamInstanceOperator):
302
  The operator throws an AssertionError in either of these cases.
303
  field_to_field defaults to None
304
  process_every_value (bool): Processes the values in a list instead of the list as a value, similar to *var. Defaults to False
305
- use_query (bool): Whether to use dpath style queries. Defaults to False.
306
 
307
  Note: if 'field' and 'to_field' (or both members of a pair in 'field_to_field') are equal (or share a common
308
- prefix if 'use_query'=True), then the result of the operation is saved within 'field'
309
  """
310
 
311
  field: Optional[str] = None
312
  to_field: Optional[str] = None
313
  field_to_field: Optional[Union[List[List[str]], Dict[str, str]]] = None
 
 
 
 
 
 
314
  process_every_value: bool = False
315
- use_query: bool = False
316
  get_default: Any = None
317
  not_exist_ok: bool = False
318
 
@@ -397,7 +398,6 @@ class InstanceFieldOperator(StreamInstanceOperator):
397
  old_value = dict_get(
398
  instance,
399
  from_field,
400
- use_dpath=self.use_query,
401
  default=self.get_default,
402
  not_exist_ok=self.not_exist_ok,
403
  )
@@ -421,7 +421,6 @@ class InstanceFieldOperator(StreamInstanceOperator):
421
  instance,
422
  to_field,
423
  new_value,
424
- use_dpath=self.use_query,
425
  not_exist_ok=True,
426
  )
427
  return instance
@@ -439,20 +438,20 @@ class FieldOperator(InstanceFieldOperator):
439
  class RenameFields(FieldOperator):
440
  """Renames fields.
441
 
442
- Move value from one field to another, potentially, if 'use_query'=True, from one branch into another.
443
- Remove the from field, potentially part of it in case of use_query.
444
 
445
  Examples:
446
  RenameFields(field_to_field={"b": "c"})
447
  will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": 2}, {"a": 2, "c": 3}]
448
 
449
- RenameFields(field_to_field={"b": "c/d"}, use_query=True)
450
  will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": {"d": 2}}, {"a": 2, "c": {"d": 3}}]
451
 
452
- RenameFields(field_to_field={"b": "b/d"}, use_query=True)
453
  will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "b": {"d": 2}}, {"a": 2, "b": {"d": 3}}]
454
 
455
- RenameFields(field_to_field={"b/c/e": "b/d"}, use_query=True)
456
  will change inputs [{"a": 1, "b": {"c": {"e": 2, "f": 20}}}] to [{"a": 1, "b": {"c": {"f": 20}, "d": 2}}]
457
 
458
  """
@@ -539,7 +538,6 @@ class Augmentor(StreamInstanceOperator):
539
  old_value = dict_get(
540
  instance,
541
  field_name,
542
- use_dpath=True,
543
  default="",
544
  not_exist_ok=False,
545
  )
@@ -552,7 +550,7 @@ class Augmentor(StreamInstanceOperator):
552
  raise RuntimeError(
553
  f"Error augmenting value '{old_value}' from '{field_name}' in instance: {instance}"
554
  ) from e
555
- dict_set(instance, field_name, new_value, use_dpath=True, not_exist_ok=True)
556
  return instance
557
 
558
 
@@ -809,14 +807,19 @@ class ListFieldValues(StreamInstanceOperator):
809
 
810
  fields: List[str]
811
  to_field: str
812
- use_query: bool = False
 
 
 
 
 
813
 
814
  def process(
815
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
816
  ) -> Dict[str, Any]:
817
  values = []
818
  for field_name in self.fields:
819
- values.append(dict_get(instance, field_name, use_dpath=self.use_query))
820
  instance[self.to_field] = values
821
  return instance
822
 
@@ -836,14 +839,19 @@ class ZipFieldValues(StreamInstanceOperator):
836
  fields: List[str]
837
  to_field: str
838
  longest: bool = False
839
- use_query: bool = False
 
 
 
 
 
840
 
841
  def process(
842
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
843
  ) -> Dict[str, Any]:
844
  values = []
845
  for field_name in self.fields:
846
- values.append(dict_get(instance, field_name, use_dpath=self.use_query))
847
  if self.longest:
848
  zipped = zip_longest(*values)
849
  else:
@@ -858,13 +866,18 @@ class IndexOf(StreamInstanceOperator):
858
  search_in: str
859
  index_of: str
860
  to_field: str
861
- use_query: bool = False
 
 
 
 
 
862
 
863
  def process(
864
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
865
  ) -> Dict[str, Any]:
866
- lst = dict_get(instance, self.search_in, use_dpath=self.use_query)
867
- item = dict_get(instance, self.index_of, use_dpath=self.use_query)
868
  instance[self.to_field] = lst.index(item)
869
  return instance
870
 
@@ -875,7 +888,12 @@ class TakeByField(StreamInstanceOperator):
875
  field: str
876
  index: str
877
  to_field: str = None
878
- use_query: bool = False
 
 
 
 
 
879
 
880
  def prepare(self):
881
  if self.to_field is None:
@@ -884,8 +902,8 @@ class TakeByField(StreamInstanceOperator):
884
  def process(
885
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
886
  ) -> Dict[str, Any]:
887
- value = dict_get(instance, self.field, use_dpath=self.use_query)
888
- index_value = dict_get(instance, self.index, use_dpath=self.use_query)
889
  instance[self.to_field] = value[index_value]
890
  return instance
891
 
@@ -943,7 +961,6 @@ class CopyFields(FieldOperator):
943
 
944
  Args (of parent class):
945
  field_to_field (Union[List[List], Dict[str, str]]): A list of lists, where each sublist contains the source field and the destination field, or a dictionary mapping source fields to destination fields.
946
- use_query (bool): Whether to use dpath for accessing fields. Defaults to False.
947
 
948
  Examples:
949
  An input instance {"a": 2, "b": 3}, when processed by
@@ -952,8 +969,8 @@ class CopyFields(FieldOperator):
952
  CopyField(field_to_field={"a": "c"} would yield
953
  {"a": 2, "b": 3, "c": 2}
954
 
955
- with use_query=True, we can also copy inside the field:
956
- CopyFields(field_to_field={"a/0": "a"}, use_query=True)
957
  would process instance {"a": [1, 3]} into {"a": 1}
958
 
959
 
@@ -1031,7 +1048,7 @@ class CastFields(StreamInstanceOperator):
1031
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
1032
  ) -> Dict[str, Any]:
1033
  for field_name, type in self.fields.items():
1034
- value = dict_get(instance, field_name, use_dpath=self.use_nested_query)
1035
  if self.process_every_value:
1036
  assert isinstance(
1037
  value, list
@@ -1039,9 +1056,8 @@ class CastFields(StreamInstanceOperator):
1039
  casted_value = self._cast_multiple(value, type, field_name)
1040
  else:
1041
  casted_value = self._cast_single(value, type, field_name)
1042
- dict_set(
1043
- instance, field_name, casted_value, use_dpath=self.use_nested_query
1044
- )
1045
  return instance
1046
 
1047
 
@@ -1709,7 +1725,7 @@ class EncodeLabels(StreamInstanceOperator):
1709
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
1710
  ) -> Dict[str, Any]:
1711
  for field_name in self.fields:
1712
- values = dict_get(instance, field_name, use_dpath=True)
1713
  values_was_a_list = isinstance(values, list)
1714
  if not isinstance(values, list):
1715
  values = [values]
@@ -1723,8 +1739,10 @@ class EncodeLabels(StreamInstanceOperator):
1723
  instance,
1724
  field_name,
1725
  new_values,
1726
- use_dpath=True,
1727
- set_multiple="*" in field_name,
 
 
1728
  )
1729
 
1730
  return instance
@@ -1781,12 +1799,10 @@ class DeterministicBalancer(StreamRefiner):
1781
  fields: List[str]
1782
 
1783
  def signature(self, instance):
1784
- return str(
1785
- tuple(dict_get(instance, field, use_dpath=True) for field in self.fields)
1786
- )
1787
 
1788
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
1789
- counter = collections.Counter()
1790
 
1791
  for instance in stream:
1792
  counter[self.signature(instance)] += 1
@@ -1802,7 +1818,7 @@ class DeterministicBalancer(StreamRefiner):
1802
  lowest_count, self.max_instances // len(counter)
1803
  )
1804
 
1805
- counter = collections.Counter()
1806
 
1807
  for instance in stream:
1808
  sign = self.signature(instance)
@@ -1817,7 +1833,7 @@ class LengthBalancer(DeterministicBalancer):
1817
  Args:
1818
  segments_boundaries (List[int]): distinct integers sorted in increasing order, that maps a given total length
1819
  into the index of the least of them that exceeds the total length. (If none exceeds -- into one index
1820
- beyond, namely, the length of segments_boudaries)
1821
 
1822
  fields (Optional, List[str])
1823
 
@@ -1837,7 +1853,7 @@ class LengthBalancer(DeterministicBalancer):
1837
  def signature(self, instance):
1838
  total_len = 0
1839
  for field_name in self.fields:
1840
- total_len += len(dict_get(instance, field_name, use_dpath=True))
1841
  for i, val in enumerate(self.segments_boundaries):
1842
  if total_len < val:
1843
  return i
 
32
  General Operaotrs List:
33
  ------------------------
34
  """
 
35
  import copy
36
  import operator
37
  import uuid
 
57
  import requests
58
 
59
  from .artifact import Artifact, fetch_artifact
60
+ from .dataclass import DeprecatedField, NonPositionalField, OptionalField
61
  from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
62
  from .operator import (
63
  MultiStream,
 
156
 
157
  mappers: Dict[str, Dict[str, str]]
158
  strict: bool = True
 
159
  process_every_value: bool = False
160
 
161
  def verify(self):
 
173
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
174
  ) -> Dict[str, Any]:
175
  for key, mapper in self.mappers.items():
176
+ value = dict_get(instance, key)
177
  if value is not None:
178
  if (self.process_every_value is True) and (not isinstance(value, list)):
179
  raise ValueError(
 
188
  instance,
189
  key,
190
  value,
 
191
  )
192
 
193
  return instance
 
226
 
227
  Args:
228
  fields (Dict[str, object]): The fields to add to each instance.
229
+ Use '/' to access inner fields
230
  use_deepcopy (bool) : Deep copy the input value to avoid later modifications
231
 
232
  Examples:
 
246
  """
247
 
248
  fields: Dict[str, object]
249
+ use_query: bool = DeprecatedField(
250
+ metadata={
251
+ "deprecation_msg": "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. "
252
+ "Please remove this field from your code."
253
+ }
254
+ )
255
  use_deepcopy: bool = False
256
 
257
  def process(
258
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
259
  ) -> Dict[str, Any]:
260
+ for key, value in self.fields.items():
 
 
 
 
 
261
  if self.use_deepcopy:
262
+ value = deepcopy(value)
263
+ dict_set(instance, key, value)
264
  return instance
265
 
266
 
 
299
  The operator throws an AssertionError in either of these cases.
300
  field_to_field defaults to None
301
  process_every_value (bool): Processes the values in a list instead of the list as a value, similar to *var. Defaults to False
 
302
 
303
  Note: if 'field' and 'to_field' (or both members of a pair in 'field_to_field') are equal (or share a common
304
+ prefix if 'field' and 'to_field' contain a /), then the result of the operation is saved within 'field'
305
  """
306
 
307
  field: Optional[str] = None
308
  to_field: Optional[str] = None
309
  field_to_field: Optional[Union[List[List[str]], Dict[str, str]]] = None
310
+ use_query: bool = DeprecatedField(
311
+ metadata={
312
+ "deprecation_msg": "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. "
313
+ "Please remove this field from your code."
314
+ }
315
+ )
316
  process_every_value: bool = False
 
317
  get_default: Any = None
318
  not_exist_ok: bool = False
319
 
 
398
  old_value = dict_get(
399
  instance,
400
  from_field,
 
401
  default=self.get_default,
402
  not_exist_ok=self.not_exist_ok,
403
  )
 
421
  instance,
422
  to_field,
423
  new_value,
 
424
  not_exist_ok=True,
425
  )
426
  return instance
 
438
  class RenameFields(FieldOperator):
439
  """Renames fields.
440
 
441
+ Move value from one field to another, potentially, if field name contains a /, from one branch into another.
442
+ Remove the from field, potentially part of it in case of / in from_field.
443
 
444
  Examples:
445
  RenameFields(field_to_field={"b": "c"})
446
  will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": 2}, {"a": 2, "c": 3}]
447
 
448
+ RenameFields(field_to_field={"b": "c/d"})
449
  will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": {"d": 2}}, {"a": 2, "c": {"d": 3}}]
450
 
451
+ RenameFields(field_to_field={"b": "b/d"})
452
  will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "b": {"d": 2}}, {"a": 2, "b": {"d": 3}}]
453
 
454
+ RenameFields(field_to_field={"b/c/e": "b/d"})
455
  will change inputs [{"a": 1, "b": {"c": {"e": 2, "f": 20}}}] to [{"a": 1, "b": {"c": {"f": 20}, "d": 2}}]
456
 
457
  """
 
538
  old_value = dict_get(
539
  instance,
540
  field_name,
 
541
  default="",
542
  not_exist_ok=False,
543
  )
 
550
  raise RuntimeError(
551
  f"Error augmenting value '{old_value}' from '{field_name}' in instance: {instance}"
552
  ) from e
553
+ dict_set(instance, field_name, new_value, not_exist_ok=True)
554
  return instance
555
 
556
 
 
807
 
808
  fields: List[str]
809
  to_field: str
810
+ use_query: bool = DeprecatedField(
811
+ metadata={
812
+ "deprecation_msg": "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. "
813
+ "Please remove this field from your code."
814
+ }
815
+ )
816
 
817
  def process(
818
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
819
  ) -> Dict[str, Any]:
820
  values = []
821
  for field_name in self.fields:
822
+ values.append(dict_get(instance, field_name))
823
  instance[self.to_field] = values
824
  return instance
825
 
 
839
  fields: List[str]
840
  to_field: str
841
  longest: bool = False
842
+ use_query: bool = DeprecatedField(
843
+ metadata={
844
+ "deprecation_msg": "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. "
845
+ "Please remove this field from your code."
846
+ }
847
+ )
848
 
849
  def process(
850
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
851
  ) -> Dict[str, Any]:
852
  values = []
853
  for field_name in self.fields:
854
+ values.append(dict_get(instance, field_name))
855
  if self.longest:
856
  zipped = zip_longest(*values)
857
  else:
 
866
  search_in: str
867
  index_of: str
868
  to_field: str
869
+ use_query: bool = DeprecatedField(
870
+ metadata={
871
+ "deprecation_msg": "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. "
872
+ "Please remove this field from your code."
873
+ }
874
+ )
875
 
876
  def process(
877
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
878
  ) -> Dict[str, Any]:
879
+ lst = dict_get(instance, self.search_in)
880
+ item = dict_get(instance, self.index_of)
881
  instance[self.to_field] = lst.index(item)
882
  return instance
883
 
 
888
  field: str
889
  index: str
890
  to_field: str = None
891
+ use_query: bool = DeprecatedField(
892
+ metadata={
893
+ "deprecation_msg": "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. "
894
+ "Please remove this field from your code."
895
+ }
896
+ )
897
 
898
  def prepare(self):
899
  if self.to_field is None:
 
902
  def process(
903
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
904
  ) -> Dict[str, Any]:
905
+ value = dict_get(instance, self.field)
906
+ index_value = dict_get(instance, self.index)
907
  instance[self.to_field] = value[index_value]
908
  return instance
909
 
 
961
 
962
  Args (of parent class):
963
  field_to_field (Union[List[List], Dict[str, str]]): A list of lists, where each sublist contains the source field and the destination field, or a dictionary mapping source fields to destination fields.
 
964
 
965
  Examples:
966
  An input instance {"a": 2, "b": 3}, when processed by
 
969
  CopyField(field_to_field={"a": "c"} would yield
970
  {"a": 2, "b": 3, "c": 2}
971
 
972
+ with field names containing / , we can also copy inside the field:
973
+ CopyFields(field_to_field={"a/0": "a"})
974
  would process instance {"a": [1, 3]} into {"a": 1}
975
 
976
 
 
1048
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
1049
  ) -> Dict[str, Any]:
1050
  for field_name, type in self.fields.items():
1051
+ value = dict_get(instance, field_name)
1052
  if self.process_every_value:
1053
  assert isinstance(
1054
  value, list
 
1056
  casted_value = self._cast_multiple(value, type, field_name)
1057
  else:
1058
  casted_value = self._cast_single(value, type, field_name)
1059
+
1060
+ dict_set(instance, field_name, casted_value)
 
1061
  return instance
1062
 
1063
 
 
1725
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
1726
  ) -> Dict[str, Any]:
1727
  for field_name in self.fields:
1728
+ values = dict_get(instance, field_name)
1729
  values_was_a_list = isinstance(values, list)
1730
  if not isinstance(values, list):
1731
  values = [values]
 
1739
  instance,
1740
  field_name,
1741
  new_values,
1742
+ not_exist_ok=False, # the values to encode where just taken from there
1743
+ set_multiple="*" in field_name
1744
+ and isinstance(new_values, list)
1745
+ and len(new_values) > 0,
1746
  )
1747
 
1748
  return instance
 
1799
  fields: List[str]
1800
 
1801
  def signature(self, instance):
1802
+ return str(tuple(dict_get(instance, field) for field in self.fields))
 
 
1803
 
1804
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
1805
+ counter = Counter()
1806
 
1807
  for instance in stream:
1808
  counter[self.signature(instance)] += 1
 
1818
  lowest_count, self.max_instances // len(counter)
1819
  )
1820
 
1821
+ counter = Counter()
1822
 
1823
  for instance in stream:
1824
  sign = self.signature(instance)
 
1833
  Args:
1834
  segments_boundaries (List[int]): distinct integers sorted in increasing order, that maps a given total length
1835
  into the index of the least of them that exceeds the total length. (If none exceeds -- into one index
1836
+ beyond, namely, the length of segments_boundaries)
1837
 
1838
  fields (Optional, List[str])
1839
 
 
1853
  def signature(self, instance):
1854
  total_len = 0
1855
  for field_name in self.fields:
1856
+ total_len += len(dict_get(instance, field_name))
1857
  for i, val in enumerate(self.segments_boundaries):
1858
  if total_len < val:
1859
  return i