Elron committed
Commit 78663de
1 Parent(s): 161e5a1

Upload operators.py with huggingface_hub

Files changed (1)
  operators.py  +709 -162
operators.py CHANGED
@@ -1,8 +1,8 @@
  import collections
  import importlib
- import inspect
  import uuid
  from abc import abstractmethod
  from copy import deepcopy
  from dataclasses import field
  from itertools import zip_longest
@@ -19,7 +19,7 @@ from typing import (
  )
 
  from .artifact import Artifact, fetch_artifact
- from .dataclass import NonPositionalField, OptionalField
  from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
  from .operator import (
      MultiStream,
@@ -32,15 +32,14 @@ from .operator import (
      StreamInstanceOperator,
      StreamSource,
  )
- from .random_utils import random
- from .stream import MultiStream, Stream
  from .text_utils import nested_tuple_to_string
  from .utils import flatten_dict
 
 
  class FromIterables(StreamInitializerOperator):
-     """
-     Creates a MultiStream from iterables.
 
      Args:
          iterables (Dict[str, Iterable]): A dictionary where each key-value pair represents a stream name and its corresponding iterable.
@@ -70,35 +69,83 @@ class MapInstanceValues(StreamInstanceOperator):
          strict (bool): If True, the mapping is applied strictly. That means if a value
              does not exist in the mapper, it will raise a KeyError. If False, values
              that are not present in the mapper are kept as they are.
      """
 
      mappers: Dict[str, Dict[str, str]]
      strict: bool = True
-     use_query = False
 
      def verify(self):
          # make sure the mappers are valid
          for key, mapper in self.mappers.items():
-             assert isinstance(mapper, dict), f"Mapper for given field {key} should be a dict, got {type(mapper)}"
-             for k, v in mapper.items():
-                 assert isinstance(k, str), f'Key "{k}" in mapper for field "{key}" should be a string, got {type(k)}'
-
-     def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
          for key, mapper in self.mappers.items():
              value = dict_get(instance, key, use_dpath=self.use_query)
              if value is not None:
-                 value = str(value)  # make sure the value is a string
-                 if self.strict:
-                     dict_set(instance, key, mapper[value], use_dpath=self.use_query)
-                 else:
                      if value in mapper:
                          dict_set(instance, key, mapper[value], use_dpath=self.use_query)
          return instance
 
 
  class FlattenInstances(StreamInstanceOperator):
-     """
-     Flattens each instance in a stream, making nested dictionary entries into top-level entries.
 
      Args:
          parent_key (str): A prefix to use for the flattened keys. Defaults to an empty string.
@@ -108,23 +155,42 @@ class FlattenInstances(StreamInstanceOperator):
      parent_key: str = ""
      sep: str = "_"
 
-     def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
          return flatten_dict(instance, parent_key=self.parent_key, sep=self.sep)
 
 
  class AddFields(StreamInstanceOperator):
-     """
-     Adds specified fields to each instance in a stream.
 
      Args:
          fields (Dict[str, object]): The fields to add to each instance.
      """
 
      fields: Dict[str, object]
      use_query: bool = False
      use_deepcopy: bool = False
 
-     def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
          if self.use_query:
              for key, value in self.fields.items():
                  if self.use_deepcopy:
@@ -138,30 +204,31 @@ class AddFields(StreamInstanceOperator):
 
 
  class RemoveFields(StreamInstanceOperator):
-     """
-     Adds specified fields to each instance in a stream.
 
      Args:
-         fields (Dict[str, object]): The fields to add to each instance.
      """
 
      fields: List[str]
 
-     def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
-         for field in self.fields:
-             del instance[field]
          return instance
 
 
  class FieldOperator(StreamInstanceOperator):
-     """
-     A general stream that processes the values of a field (or multiple ones
      Args:
          field (Optional[str]): The field to process, if only a single one is passed Defaults to None
          to_field (Optional[str]): Field name to save, if only one field is to be saved, if None is passed the operator would happen in-place and replace "field" Defaults to None
          field_to_field (Optional[Union[List[Tuple[str, str]], Dict[str, str]]]): Mapping from fields to process to their names after this process, duplicates are allowed. Defaults to None
          process_every_value (bool): Processes the values in a list instead of the list as a value, similar to *var. Defaults to False
-         use_query (bool): Whether to use dpath style queries. Defaults to False
      """
 
      field: Optional[str] = None
@@ -175,14 +242,18 @@ class FieldOperator(StreamInstanceOperator):
      def verify(self):
          super().verify()
 
-         assert self.field is not None or self.field_to_field is not None, "Must supply a field to work on"
          assert (
              self.to_field is None or self.field_to_field is None
          ), f"Can not apply operator to create both on {self.to_field} and on the mapping from fields to fields {self.field_to_field}"
          assert (
              self.field is None or self.field_to_field is None
          ), f"Can not apply operator both on {self.field} and on the mapping from fields to fields {self.field_to_field}"
-         assert self._field_to_field, f"the from and to fields must be defined got: {self._field_to_field}"
 
      @abstractmethod
      def process_value(self, value: Any) -> Any:
@@ -195,11 +266,13 @@ class FieldOperator(StreamInstanceOperator):
              self._field_to_field = [(self.field, self.to_field)]
          else:
              try:
-                 self._field_to_field = [(k, v) for k, v in self.field_to_field.items()]
              except AttributeError:
                  self._field_to_field = self.field_to_field
 
-     def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
          for from_field, to_field in self._field_to_field:
              try:
                  old_value = dict_get(
@@ -209,27 +282,40 @@ class FieldOperator(StreamInstanceOperator):
                      default=self.get_default,
                      not_exist_ok=self.not_exist_ok,
                  )
-             except TypeError as e:
-                 raise TypeError(f"Failed to get {from_field} from {instance}")
-             if self.process_every_value:
-                 new_value = [self.process_value(value) for value in old_value]
-             else:
-                 new_value = self.process_value(old_value)
              if self.use_query and is_subpath(from_field, to_field):
                  dict_delete(instance, from_field)
-             dict_set(instance, to_field, new_value, use_dpath=self.use_query, not_exist_ok=True)
          return instance
 
 
  class RenameFields(FieldOperator):
-     """
-     Renames fields
-     """
 
      def process_value(self, value: Any) -> Any:
          return value
 
-     def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
          res = super().process(instance=instance, stream_name=stream_name)
          vals = [x[1] for x in self._field_to_field]
          for key, _ in self._field_to_field:
@@ -241,32 +327,202 @@ class RenameFields(FieldOperator):
 
 
  class AddConstant(FieldOperator):
      """
-     Adds a number, similar to field + add
      Args:
-         add (float): sum to add
      """
 
-     add: float
 
      def process_value(self, value: Any) -> Any:
-         return value + self.add
 
- class ShuffleFieldValues(FieldOperator):
      """
-     Shuffles an iterable value
      """
 
      def process_value(self, value: Any) -> Any:
          res = list(value)
-         random.shuffle(res)
          return res
 
 
  class JoinStr(FieldOperator):
-     """
-     Joins a list of strings (contents of a field), similar to str.join()
      Args:
          separator (str): text to put between values
      """
@@ -278,6 +534,25 @@ class JoinStr(FieldOperator):
 
 
  class Apply(StreamInstanceOperator):
      __allow_unexpected_arguments__ = True
      function: Callable = NonPositionalField(required=True)
      to_field: str = NonPositionalField(required=True)
@@ -292,25 +567,23 @@ class Apply(StreamInstanceOperator):
          else:
              parts.append(function.__name__)
 
-         result = ".".join(parts)
-
-         return result
 
      def str_to_function(self, function_str: str) -> Callable:
          splitted = function_str.split(".", 1)
          if len(splitted) == 1:
-             return __builtins__[module_name]
          else:
-             module_name, function_name = splitted
-             if module_name in __builtins__:
-                 obj = __builtins__[module_name]
-             elif module_name in globals():
-                 obj = globals()[module_name]
-             else:
-                 obj = importlib.import_module(module_name)
-             for part in function_name.split("."):
-                 obj = getattr(obj, part)
-             return obj
 
      def prepare(self):
          super().prepare()
@@ -318,7 +591,9 @@ class Apply(StreamInstanceOperator):
          self.function = self.str_to_function(self.function)
          self._init_dict["function"] = self.function_to_str(self.function)
 
-     def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
          argv = [instance[arg] for arg in self._argv]
          kwargs = {key: instance[val] for key, val in self._kwargs}
 
@@ -329,36 +604,36 @@ class Apply(StreamInstanceOperator):
 
 
  class ListFieldValues(StreamInstanceOperator):
-     """
-     Concatanates values of multiple fields into a list to list(fields)
-     """
 
-     fields: str
      to_field: str
      use_query: bool = False
 
-     def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
          values = []
-         for field in self.fields:
-             values.append(dict_get(instance, field, use_dpath=self.use_query))
          instance[self.to_field] = values
          return instance
 
 
  class ZipFieldValues(StreamInstanceOperator):
-     """
-     Zips values of multiple fields similar to list(zip(*fields))
-     """
 
      fields: str
      to_field: str
      longest: bool = False
      use_query: bool = False
 
-     def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
          values = []
-         for field in self.fields:
-             values.append(dict_get(instance, field, use_dpath=self.use_query))
          if self.longest:
              zipped = zip_longest(*values)
          else:
@@ -368,16 +643,16 @@ class ZipFieldValues(StreamInstanceOperator):
 
 
  class IndexOf(StreamInstanceOperator):
-     """
-     Finds the location of one value in another (iterable) value similar to to_field=search_in.index(index_of)
-     """
 
      search_in: str
      index_of: str
      to_field: str
      use_query: bool = False
 
-     def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
          lst = dict_get(instance, self.search_in, use_dpath=self.use_query)
          item = dict_get(instance, self.index_of, use_dpath=self.use_query)
          instance[self.to_field] = lst.index(item)
@@ -385,9 +660,7 @@ class IndexOf(StreamInstanceOperator):
 
 
  class TakeByField(StreamInstanceOperator):
-     """
-     Takes value from one field based on another field similar to field[index]
-     """
 
      field: str
      index: str
@@ -398,7 +671,9 @@ class TakeByField(StreamInstanceOperator):
          if self.to_field is None:
              self.to_field = self.field
 
-     def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
          value = dict_get(instance, self.field, use_dpath=self.use_query)
          index_value = dict_get(instance, self.index, use_dpath=self.use_query)
          instance[self.to_field] = value[index_value]
@@ -406,8 +681,7 @@ class TakeByField(StreamInstanceOperator):
 
 
  class CopyFields(FieldOperator):
-     """
-     Copies specified fields from one field to another.
 
      Args:
          field_to_field (Union[List[List], Dict[str, str]]): A list of lists, where each sublist contains the source field and the destination field, or a dictionary mapping source fields to destination fields.
@@ -421,14 +695,15 @@ class CopyFields(FieldOperator):
  class AddID(StreamInstanceOperator):
      id_field_name: str = "id"
 
-     def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
          instance[self.id_field_name] = str(uuid.uuid4()).replace("-", "")
          return instance
 
 
  class CastFields(StreamInstanceOperator):
-     """
-     Casts specified fields to specified types.
 
      Args:
          types (Dict[str, str]): A dictionary mapping fields to their new types.
@@ -451,24 +726,28 @@ class CastFields(StreamInstanceOperator):
      def _cast_single(self, value, type, field):
          try:
              return self.types[type](value)
-         except:
              if field not in self.failure_defaults:
                  raise ValueError(
                      f'Failed to cast field "{field}" with value {value} to type "{type}", and no default value is provided.'
-                 )
              return self.failure_defaults[field]
 
      def _cast_multiple(self, values, type, field):
          values = [self._cast_single(value, type, field) for value in values]
 
-     def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
-         for field, type in self.fields.items():
-             value = dict_get(instance, field, use_dpath=self.use_nested_query)
              if self.cast_multiple:
-                 casted_value = self._cast_multiple(value, type, field)
              else:
-                 casted_value = self._cast_single(value, type, field)
-             dict_set(instance, field, casted_value, use_dpath=self.use_nested_query)
          return instance
 
 
@@ -491,13 +770,14 @@ class DivideAllFieldsBy(StreamInstanceOperator):
      strict: bool = False
      recursive: bool = True
 
-     def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
          return recursive_divide(instance, self.divisor, strict=self.strict)
 
 
  class ArtifactFetcherMixin:
-     """
-     Provides a way to fetch and cache artifacts in the system.
 
      Args:
          cache (Dict[str, Artifact]): A cache for storing fetched artifacts.
@@ -514,8 +794,7 @@ class ArtifactFetcherMixin:
 
 
  class ApplyOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin):
-     """
-     Applies value operators to each instance in a stream based on specified fields.
 
      Args:
          value_field (str): The field containing the value to be operated on.
@@ -529,7 +808,9 @@ class ApplyOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin):
      default_operators: List[str] = None
      fields_to_treat_as_list: List[str] = NonPositionalField(default_factory=list)
 
-     def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
          operator_names = instance.get(self.operators_field)
          if operator_names is None:
              assert (
@@ -542,35 +823,228 @@ class ApplyOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin):
 
          for name in operator_names:
              operator = self.get_artifact(name)
-             for field in self.inputs_fields:
-                 value = instance[field]
-                 if field in self.fields_to_treat_as_list:
-                     instance[field] = [operator.process(v) for v in value]
                  else:
-                     instance[field] = operator.process(instance[field])
 
          return instance
 
 
  class FilterByValues(SingleStreamOperator):
      """
-     Filters a stream, yielding only instances that match specified values.
 
      Args:
-         values (Dict[str, Any]): The values that instances should match to be included in the output.
      """
 
-     values: Dict[str, Any]
 
-     def process(self, stream: Stream, stream_name: str = None) -> Generator:
          for instance in stream:
-             if all(instance[key] == value for key, value in self.values.items()):
                  yield instance
 
 
- class Unique(SingleStreamReducer):
      """
-     Reduces a stream to unique instances based on specified fields.
 
      Args:
          fields (List[str]): The fields that should be unique in each instance.
@@ -581,8 +1055,8 @@ class Unique(SingleStreamReducer):
      @staticmethod
      def to_tuple(instance: dict, fields: List[str]) -> tuple:
          result = []
-         for field in fields:
-             value = instance[field]
          if isinstance(value, list):
              value = tuple(value)
          result.append(value)
@@ -598,8 +1072,7 @@ class Unique(SingleStreamReducer):
 
 
  class SplitByValue(MultiStreamOperator):
-     """
-     Splits a MultiStream into multiple streams based on unique values in specified fields.
 
      Args:
          fields (List[str]): The fields to use when splitting the MultiStream.
@@ -615,17 +1088,20 @@ class SplitByValue(MultiStreamOperator):
          for stream_name, stream in multi_stream.items():
              stream_unique_values = uniques[stream_name]
              for unique_values in stream_unique_values:
-                 filtering_values = {field: value for field, value in zip(self.fields, unique_values)}
-                 filtered_streams = FilterByValues(values=filtering_values)._process_single_stream(stream)
-                 filtered_stream_name = stream_name + "_" + nested_tuple_to_string(unique_values)
                  result[filtered_stream_name] = filtered_streams
 
          return MultiStream(result)
 
 
  class ApplyStreamOperatorsField(SingleStreamOperator, ArtifactFetcherMixin):
-     """
-     Applies stream operators to a stream based on specified fields in each instance.
 
      Args:
          field (str): The field containing the operators to be applied.
@@ -635,7 +1111,7 @@ class ApplyStreamOperatorsField(SingleStreamOperator, ArtifactFetcherMixin):
      field: str
      reversed: bool = False
 
-     def process(self, stream: Stream, stream_name: str = None) -> Generator:
          first_instance = stream.peak()
 
          operators = first_instance.get(self.field, [])
@@ -647,16 +1123,67 @@ class ApplyStreamOperatorsField(SingleStreamOperator, ArtifactFetcherMixin):
 
          for operator_name in operators:
              operator = self.get_artifact(operator_name)
-             assert isinstance(operator, StreamingOperator), f"Operator {operator_name} must be a SingleStreamOperator"
 
              stream = operator(MultiStream({"tmp": stream}))["tmp"]
 
          yield from stream
 
 
- class AddFieldNamePrefix(StreamInstanceOperator):
      """
-     Adds a prefix to each field name in each instance of a stream.
 
      Args:
          prefix_dict (Dict[str, str]): A dictionary mapping stream names to prefixes.
@@ -667,13 +1194,17 @@ class AddFieldNamePrefix(StreamInstanceOperator):
      def prepare(self):
          return super().prepare()
 
-     def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
-         return {self.prefix_dict[stream_name] + key: value for key, value in instance.items()}
 
 
  class MergeStreams(MultiStreamOperator):
-     """
-     Merges multiple streams into a single stream.
 
      Args:
          new_stream_name (str): The name of the new stream resulting from the merge.
@@ -681,37 +1212,43 @@ class MergeStreams(MultiStreamOperator):
          origin_stream_name_field_name (str): The field name for the origin stream name.
      """
 
      new_stream_name: str = "all"
      add_origin_stream_name: bool = True
      origin_stream_name_field_name: str = "origin"
 
      def merge(self, multi_stream):
          for stream_name, stream in multi_stream.items():
-             for instance in stream:
-                 if self.add_origin_stream_name:
-                     instance[self.origin_stream_name_field_name] = stream_name
-                 yield instance
 
      def process(self, multi_stream: MultiStream) -> MultiStream:
-         return MultiStream({self.new_stream_name: Stream(self.merge, gen_kwargs={"multi_stream": multi_stream})})
 
 
  class Shuffle(PagedStreamOperator):
-     """
-     Shuffles the order of instances in each page of a stream.
 
      Args:
          page_size (int): The size of each page in the stream. Defaults to 1000.
      """
 
-     def process(self, page: List[Dict], stream_name: str = None) -> Generator:
-         random.shuffle(page)
          yield from page
 
 
  class EncodeLabels(StreamInstanceOperator):
-     """
-     Encode labels of specified fields together a into integers.
 
      Args:
          fields (List[str]): The fields to encode together.
@@ -723,16 +1260,20 @@ class EncodeLabels(StreamInstanceOperator):
          self.encoder = {}
          return super()._process_multi_stream(multi_stream)
 
-     def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
-         for field in self.fields:
-             values = dict_get(instance, field, use_dpath=True)
          if not isinstance(values, list):
              values = [values]
          for value in values:
              if value not in self.encoder:
                  self.encoder[value] = len(self.encoder)
          new_values = [self.encoder[value] for value in values]
-         dict_set(instance, field, new_values, use_dpath=True, set_multiple=True)
 
          return instance
 
@@ -740,7 +1281,7 @@ class EncodeLabels(StreamInstanceOperator):
  class StreamRefiner(SingleStreamOperator):
      max_instances: int = None
 
-     def process(self, stream: Stream, stream_name: str = None) -> Generator:
          if self.max_instances is not None:
              yield from stream.take(self.max_instances)
          else:
@@ -748,8 +1289,7 @@ class StreamRefiner(SingleStreamOperator):
 
 
  class DeterministicBalancer(StreamRefiner):
-     """
-     A class used to balance streams deterministically.
 
      Attributes:
          fields (List[str]): A list of field names to be used in determining the signature of an instance.
@@ -763,19 +1303,26 @@ class DeterministicBalancer(StreamRefiner):
      fields: List[str]
 
      def signature(self, instance):
-         return str(tuple(dict_get(instance, field, use_dpath=True) for field in self.fields))
 
-     def process(self, stream: Stream, stream_name: str = None) -> Generator:
          counter = collections.Counter()
 
          for instance in stream:
              counter[self.signature(instance)] += 1
 
          lowest_count = counter.most_common()[-1][-1]
 
         max_total_instances_per_sign = lowest_count
         if self.max_instances is not None:
-             max_total_instances_per_sign = min(lowest_count, self.max_instances // len(counter))
 
          counter = collections.Counter()
 
@@ -791,8 +1338,8 @@ class LengthBalancer(DeterministicBalancer):
 
      def signature(self, instance):
          total_len = 0
-         for field in self.fields:
-             total_len += len(dict_get(instance, field, use_dpath=True))
          for i, val in enumerate(self.segments_boundaries):
              if total_len < val:
                  return i
  import collections
  import importlib
  import uuid
  from abc import abstractmethod
+ from collections import Counter
  from copy import deepcopy
  from dataclasses import field
  from itertools import zip_longest
 
  )
 
  from .artifact import Artifact, fetch_artifact
+ from .dataclass import NonPositionalField
  from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
  from .operator import (
      MultiStream,
 
      StreamInstanceOperator,
      StreamSource,
  )
+ from .random_utils import get_random, nested_seed
+ from .stream import Stream
  from .text_utils import nested_tuple_to_string
  from .utils import flatten_dict
 
 
  class FromIterables(StreamInitializerOperator):
+     """Creates a MultiStream from iterables.
 
      Args:
          iterables (Dict[str, Iterable]): A dictionary where each key-value pair represents a stream name and its corresponding iterable.
 
          strict (bool): If True, the mapping is applied strictly. That means if a value
              does not exist in the mapper, it will raise a KeyError. If False, values
              that are not present in the mapper are kept as they are.
+         process_every_value (bool): If True, all fields to be mapped should be lists, and the mapping
+             is to be applied to their individual elements. If False, mapping is only applied to a field
+             containing a single value.
+
+     Examples:
+         MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}})
+             replaces '1' with 'hi' and '2' with 'bye' in field 'a' in all instances of all streams:
+             instance {"a": "1", "b": 2} becomes {"a": "hi", "b": 2}.
+
+         MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}}, process_every_value=True)
+             Assuming field 'a' is a list of values, potentially including "1"-s and "2"-s, this replaces
+             each such "1" with "hi" and each "2" with "bye" in all instances of all streams:
+             instance {"a": ["1", "2"], "b": 2} becomes {"a": ["hi", "bye"], "b": 2}.
+
+         MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}}, strict=True)
+             To ensure that all values of field 'a' are mapped in every instance, use strict=True.
+             Input instance {"a": "3", "b": 2} will raise an exception per the above call,
+             because "3" is not a key in the mapper of "a".
      """
 
      mappers: Dict[str, Dict[str, str]]
      strict: bool = True
+     use_query: bool = False
+     process_every_value: bool = False
 
      def verify(self):
          # make sure the mappers are valid
          for key, mapper in self.mappers.items():
+             assert isinstance(
+                 mapper, dict
+             ), f"Mapper for given field {key} should be a dict, got {type(mapper)}"
+             for k in mapper.keys():
+                 assert isinstance(
+                     k, str
+                 ), f'Key "{k}" in mapper for field "{key}" should be a string, got {type(k)}'
+
+     def process(
+         self, instance: Dict[str, Any], stream_name: Optional[str] = None
+     ) -> Dict[str, Any]:
          for key, mapper in self.mappers.items():
              value = dict_get(instance, key, use_dpath=self.use_query)
              if value is not None:
+                 if (self.process_every_value is True) and (not isinstance(value, list)):
+                     raise ValueError(
+                         f"'process_every_value' == True is allowed only when all fields which have mappers, i.e., {list(self.mappers.keys())} are lists. Instance = {instance}"
+                     )
+                 if isinstance(value, list):
+                     if self.process_every_value:
+                         for i, val in enumerate(value):
+                             val = str(val)  # make sure the value is a string
+                             if self.strict and (val not in mapper):
+                                 raise KeyError(
+                                     f"value '{val}' in instance '{instance}' is not found in mapper '{mapper}', associated with field '{key}'."
+                                 )
+                             if val in mapper:
+                                 # replace just that member of value (value is a list)
+                                 value[i] = mapper[val]
+                         dict_set(instance, key, value, use_dpath=self.use_query)
+                     else:  # field is a list, and process_every_value == False
+                         if self.strict:  # whole lists can not be mapped by a string-to-something mapper
+                             raise KeyError(
+                                 f"A whole list ({value}) in the instance can not be mapped by a field mapper."
+                             )
+                 else:  # value is not a list, implying process_every_value == False
+                     value = str(value)  # make sure the value is a string
+                     if self.strict and (value not in mapper):
+                         raise KeyError(
+                             f"value '{value}' in instance '{instance}' is not found in mapper '{mapper}', associated with field '{key}'."
+                         )
                      if value in mapper:
                          dict_set(instance, key, mapper[value], use_dpath=self.use_query)
+
          return instance
 
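# Usage sketch (illustrative, plain Python; names are assumptions): the mapping
# semantics documented above, for the non-list, non-strict case.
mappers = {"a": {"1": "hi", "2": "bye"}}
instance = {"a": "1", "b": 2}
for key, mapper in mappers.items():
    value = str(instance[key])  # values are compared as strings
    if value in mapper:
        instance[key] = mapper[value]
assert instance == {"a": "hi", "b": 2}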
  class FlattenInstances(StreamInstanceOperator):
+     """Flattens each instance in a stream, making nested dictionary entries into top-level entries.
 
      Args:
          parent_key (str): A prefix to use for the flattened keys. Defaults to an empty string.
 
      parent_key: str = ""
      sep: str = "_"
 
+     def process(
+         self, instance: Dict[str, Any], stream_name: Optional[str] = None
+     ) -> Dict[str, Any]:
          return flatten_dict(instance, parent_key=self.parent_key, sep=self.sep)
 
 
  class AddFields(StreamInstanceOperator):
+     """Adds specified fields to each instance in a given stream or all streams (default). If fields exist, updates them.
 
      Args:
          fields (Dict[str, object]): The fields to add to each instance.
+         use_query (bool): Use '/' to access inner fields.
+         use_deepcopy (bool): Deep copy the input value to avoid later modifications.
+
+     Examples:
+         # Add a 'classes' field with a value of a list "positive" and "negative" to all streams
+         AddFields(fields={"classes": ["positive", "negative"]})
+
+         # Add a 'start' field under the 'span' field with a value of 0 to all streams
+         AddFields(fields={"span/start": 0})
+
+         # Add a 'classes' field with a value of a list "positive" and "negative" to 'train' stream
+         AddFields(fields={"classes": ["positive", "negative"]}, apply_to_stream=["train"])
+
+         # Add a 'classes' field on a given list, prevent modification of original list
+         # from changing the instance.
+         AddFields(fields={"classes": alist}, use_deepcopy=True)
      """
 
      fields: Dict[str, object]
      use_query: bool = False
      use_deepcopy: bool = False
 
+     def process(
+         self, instance: Dict[str, Any], stream_name: Optional[str] = None
+     ) -> Dict[str, Any]:
          if self.use_query:
              for key, value in self.fields.items():
                  if self.use_deepcopy:
 
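# Why use_deepcopy matters (illustrative sketch, plain Python): without a deep
# copy, every instance shares the same list object, so later mutations leak.
from copy import deepcopy

alist = ["positive", "negative"]
inst1 = {"classes": alist}            # aliases alist
inst2 = {"classes": deepcopy(alist)}  # independent copy
alist.append("neutral")
assert inst1["classes"] == ["positive", "negative", "neutral"]
assert inst2["classes"] == ["positive", "negative"]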
  class RemoveFields(StreamInstanceOperator):
+     """Removes specified fields from each instance in a stream.
 
      Args:
+         fields (List[str]): The fields to remove from each instance.
      """
 
      fields: List[str]
 
+     def process(
+         self, instance: Dict[str, Any], stream_name: Optional[str] = None
+     ) -> Dict[str, Any]:
+         for field_name in self.fields:
+             del instance[field_name]
          return instance
 
 
223
  class FieldOperator(StreamInstanceOperator):
224
+ """A general stream that processes the values of a field (or multiple ones.
225
+
226
  Args:
227
  field (Optional[str]): The field to process, if only a single one is passed Defaults to None
228
  to_field (Optional[str]): Field name to save, if only one field is to be saved, if None is passed the operator would happen in-place and replace "field" Defaults to None
229
  field_to_field (Optional[Union[List[Tuple[str, str]], Dict[str, str]]]): Mapping from fields to process to their names after this process, duplicates are allowed. Defaults to None
230
  process_every_value (bool): Processes the values in a list instead of the list as a value, similar to *var. Defaults to False
231
+ use_query (bool): Whether to use dpath style queries. Defaults to False.
232
  """
233
 
234
  field: Optional[str] = None
 
242
  def verify(self):
243
  super().verify()
244
 
245
+ assert (
246
+ self.field is not None or self.field_to_field is not None
247
+ ), "Must supply a field to work on"
248
  assert (
249
  self.to_field is None or self.field_to_field is None
250
  ), f"Can not apply operator to create both on {self.to_field} and on the mapping from fields to fields {self.field_to_field}"
251
  assert (
252
  self.field is None or self.field_to_field is None
253
  ), f"Can not apply operator both on {self.field} and on the mapping from fields to fields {self.field_to_field}"
254
+ assert (
255
+ self._field_to_field
256
+ ), f"the from and to fields must be defined got: {self._field_to_field}"
257
 
258
  @abstractmethod
259
  def process_value(self, value: Any) -> Any:
 
266
  self._field_to_field = [(self.field, self.to_field)]
267
  else:
268
  try:
269
+ self._field_to_field = list(self.field_to_field.items())
270
  except AttributeError:
271
  self._field_to_field = self.field_to_field
272
 
273
+ def process(
274
+ self, instance: Dict[str, Any], stream_name: Optional[str] = None
275
+ ) -> Dict[str, Any]:
276
  for from_field, to_field in self._field_to_field:
277
  try:
278
  old_value = dict_get(
 
282
  default=self.get_default,
283
  not_exist_ok=self.not_exist_ok,
284
  )
285
+ except Exception as e:
286
+ raise ValueError(
287
+ f"Failed to get '{from_field}' from {instance} due to : {e}"
288
+ ) from e
289
+ try:
290
+ if self.process_every_value:
291
+ new_value = [self.process_value(value) for value in old_value]
292
+ else:
293
+ new_value = self.process_value(old_value)
294
+ except Exception as e:
295
+ raise ValueError(
296
+ f"Failed to process '{from_field}' from {instance} due to : {e}"
297
+ ) from e
298
  if self.use_query and is_subpath(from_field, to_field):
299
  dict_delete(instance, from_field)
300
+ dict_set(
301
+ instance,
302
+ to_field,
303
+ new_value,
304
+ use_dpath=self.use_query,
305
+ not_exist_ok=True,
306
+ )
307
  return instance
308
 
309
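# A minimal sketch of extending FieldOperator (hypothetical subclass, not in
# the diff; assumes the base class shown above): only process_value is needed.
class Lowercase(FieldOperator):
    def process_value(self, value: Any) -> Any:
        return str(value).lower()

# Lowercase(field="text", to_field="text_lower") would read 'text', lowercase
# its value, and store the result in 'text_lower' for every instance.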
 
  class RenameFields(FieldOperator):
+     """Renames fields."""
 
      def process_value(self, value: Any) -> Any:
          return value
 
+     def process(
+         self, instance: Dict[str, Any], stream_name: Optional[str] = None
+     ) -> Dict[str, Any]:
          res = super().process(instance=instance, stream_name=stream_name)
          vals = [x[1] for x in self._field_to_field]
          for key, _ in self._field_to_field:
 
 
  class AddConstant(FieldOperator):
+     """Adds a value, similar to add + field.
+
+     Args:
+         add: sum to add.
      """
+
+     add: Any
+
+     def process_value(self, value: Any) -> Any:
+         return self.add + value
+
+
+ class Augmentor(StreamInstanceOperator):
+     """A stream operator that augments the values of either the task input fields before rendering with the template, or the input passed to the model after rendering of the template.
+
+     Args:
+         augment_model_input: Whether to augment the input to the model.
+         augment_task_input: Whether to augment the task input fields. The specific fields are defined in the FormTask operator.
+
+     """
 
+     augment_task_input: bool = False
+     augment_model_input: bool = False
+
+     def verify(self):
+         assert not (
+             self.augment_task_input and self.augment_model_input
+         ), "Augmentor must set either 'augment_task_input' or 'augment_model_input' but not both"
+         assert (
+             self.augment_task_input or self.augment_model_input
+         ), "Augmentor must set either 'augment_task_input' or 'augment_model_input'"
+
+         super().verify()
 
+     @abstractmethod
      def process_value(self, value: Any) -> Any:
+         pass
+
+     def prepare(self):
+         pass
 
+     def set_task_input_fields(self, task_input_fields: List[str]):
+         self._task_input_fields = [
+             "inputs/" + task_input_field for task_input_field in task_input_fields
+         ]
 
+     def process(
+         self, instance: Dict[str, Any], stream_name: Optional[str] = None
+     ) -> Dict[str, Any]:
+         if self.augment_task_input:
+             assert (
+                 len(self._task_input_fields) > 0
+             ), "No augmentable input fields were defined in FormTask, and augmentation was requested. Specify the fields to augment in 'argumentable_inputs' attribute of the FormTask."
+             fields = self._task_input_fields
+             assert not self.augment_model_input
+
+         if self.augment_model_input:
+             fields = ["source"]
+             assert not self.augment_task_input
+
+         for field_name in fields:
+             try:
+                 old_value = dict_get(
+                     instance,
+                     field_name,
+                     use_dpath=True,
+                     default="",
+                     not_exist_ok=False,
+                 )
+             except TypeError as e:
+                 raise TypeError(f"Failed to get {field_name} from {instance}") from e
+
+             # We are setting a nested seed based on the value processed, to ensure that
+             # the augmentation randomizations do not affect other randomization choices and
+             # to make the augmentation randomization choices different for each text.
+             with nested_seed(str(hash(old_value))):
+                 try:
+                     new_value = self.process_value(old_value)
+                 except Exception as e:
+                     raise RuntimeError(
+                         f"Error augmenting value '{old_value}' from '{field_name}' in instance: {instance}"
+                     ) from e
+             dict_set(instance, field_name, new_value, use_dpath=True, not_exist_ok=True)
+         return instance
+
+
+ class NullAugmentor(Augmentor):
+     def verify(self):
+         pass
+
+     def process_value(self, value: Any) -> Any:
+         return value
+
+
+ class AugmentWhitespace(Augmentor):
+     """Augments the inputs by replacing existing whitespace with other whitespace.
+
+     Currently each whitespace is replaced by a random choice of 1-3 whitespace characters (space, tab, newline).
      """
+
+     def process_value(self, value: Any) -> Any:
+         import re
+
+         words = re.split(r"(\s+)", value)
+         new_value = ""
+
+         for word in words:
+             if word.isspace():
+                 new_value += get_random().choice(
+                     ["\n", "\t", " "]
+                 ) * get_random().randint(1, 3)
+             else:
+                 new_value += word
+         return new_value
+
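# Illustrative sketch of the splitting trick above (plain Python): re.split
# with a capturing group keeps the whitespace separators, so the text can be
# reassembled with each whitespace run replaced independently.
import random
import re

text = "a b\tc"
parts = re.split(r"(\s+)", text)  # ['a', ' ', 'b', '\t', 'c']
rebuilt = "".join(
    random.choice(["\n", "\t", " "]) * random.randint(1, 3) if p.isspace() else p
    for p in parts
)
# 'rebuilt' keeps the words of 'text', with each whitespace run re-randomized.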
+
+ class AugmentSuffix(Augmentor):
+     r"""Augments the input by appending to it a randomly selected (typically, whitespace) pattern.
+
+     Args:
+         suffixes : the potential (typically, whitespace) patterns to select from.
+             The dictionary version allows to specify relative weights of the different patterns.
+         remove_existing_trailing_whitespaces : allows to first clean existing trailing whitespaces.
+             The selected pattern is then appended to the input, which is potentially trimmed at its end.
+
+
+     Examples:
+         to append a '\n' or a '\t' to the end of the input, employ
+         AugmentSuffix(augment_model_input=True, suffixes=['\n','\t'])
+         If '\n' is preferred over '\t', at 2:1 ratio, employ
+         AugmentSuffix(augment_model_input=True, suffixes={'\n':2,'\t':1})
+         which will append '\n' twice as often as '\t'.
+
      """
 
+     suffixes: Optional[Union[List[str], Dict[str, int]]] = [" ", "\n", "\t"]
+     remove_existing_trailing_whitespaces: Optional[bool] = False
+
+     def verify(self):
+         assert (
+             isinstance(self.suffixes, list) or isinstance(self.suffixes, dict)
+         ), f"Argument 'suffixes' should be either a list or a dictionary, whereas it is of type {type(self.suffixes)}"
+
+         if isinstance(self.suffixes, dict):
+             for k, v in self.suffixes.items():
+                 assert isinstance(
+                     k, str
+                 ), f"suffixes should map strings, whereas key {k!s} is of type {type(k)}"
+                 assert isinstance(
+                     v, int
+                 ), f"suffixes should map to ints, whereas value {v!s} is of type {type(v)}"
+         else:
+             for k in self.suffixes:
+                 assert isinstance(
+                     k, str
+                 ), f"suffixes should be a list of strings, whereas member {k!s} is of type {type(k)}"
+
+         self.pats = (
+             self.suffixes
+             if isinstance(self.suffixes, list)
+             else [k for k, v in self.suffixes.items()]
+         )
+         total_weight = (
+             len(self.pats)
+             if isinstance(self.suffixes, list)
+             else sum([v for k, v in self.suffixes.items()])
+         )
+         self.weights = (
+             [1.0 / total_weight] * len(self.pats)
+             if isinstance(self.suffixes, list)
+             else [float(self.suffixes[p]) / total_weight for p in self.pats]
+         )
+         super().verify()
+
+     def process_value(self, value: Any) -> Any:
+         assert value is not None, "input value should not be None"
+         new_value = str(value)
+         if self.remove_existing_trailing_whitespaces:
+             new_value = new_value.rstrip()
+         new_value += get_random().choices(self.pats, self.weights, k=1)[0]
+
+         return new_value
+
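# Weighted selection sketch (plain Python): this is the effect of the
# dictionary form of 'suffixes', via random.choices with normalized weights.
import random

pats = ["\n", "\t"]
weights = [2 / 3, 1 / 3]  # from suffixes={'\n': 2, '\t': 1}
picked = random.choices(pats, weights, k=1)[0]  # '\n' about twice as often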
+
+ class ShuffleFieldValues(FieldOperator):
+     """Shuffles an iterable value."""
+
      def process_value(self, value: Any) -> Any:
          res = list(value)
+         get_random().shuffle(res)
          return res
 
 
  class JoinStr(FieldOperator):
+     """Joins a list of strings (contents of a field), similar to str.join().
+
      Args:
          separator (str): text to put between values
      """
 
  class Apply(StreamInstanceOperator):
+     """A class used to apply a python function and store the result in a field.
+
+     Args:
+         function (str): name of function.
+         to_field (str): the field to store the result
+         additional arguments are field names passed to the function
+
+     Examples:
+         Store in field "b" the uppercase string of the value in field "a"
+         Apply("a", function=str.upper, to_field="b")
+
+         Dump the json representation of field "t" and store back in the same field.
+         Apply("t", function=json.dumps, to_field="t")
+
+         Set the time in a field 'b'.
+         Apply(function=time.time, to_field="b")
+
+     """
+
      __allow_unexpected_arguments__ = True
      function: Callable = NonPositionalField(required=True)
      to_field: str = NonPositionalField(required=True)
 
          else:
              parts.append(function.__name__)
 
+         return ".".join(parts)
 
      def str_to_function(self, function_str: str) -> Callable:
          splitted = function_str.split(".", 1)
          if len(splitted) == 1:
+             return __builtins__[splitted[0]]
+
+         module_name, function_name = splitted
+         if module_name in __builtins__:
+             obj = __builtins__[module_name]
+         elif module_name in globals():
+             obj = globals()[module_name]
          else:
+             obj = importlib.import_module(module_name)
+         for part in function_name.split("."):
+             obj = getattr(obj, part)
+         return obj
 
      def prepare(self):
          super().prepare()
 
          self.function = self.str_to_function(self.function)
          self._init_dict["function"] = self.function_to_str(self.function)
 
+     def process(
+         self, instance: Dict[str, Any], stream_name: Optional[str] = None
+     ) -> Dict[str, Any]:
          argv = [instance[arg] for arg in self._argv]
          kwargs = {key: instance[val] for key, val in self._kwargs}
 
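# Sketch of the dotted-name resolution used by str_to_function (plain Python,
# illustrative and simplified; 'resolve' is a hypothetical helper): a name
# such as "json.dumps" becomes a module import plus attribute walks.
import importlib

def resolve(function_str: str):
    module_name, _, function_name = function_str.partition(".")
    obj = importlib.import_module(module_name)
    for part in function_name.split("."):
        obj = getattr(obj, part)
    return obj

assert resolve("json.dumps")({"a": 1}) == '{"a": 1}'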
 
  class ListFieldValues(StreamInstanceOperator):
+     """Concatenates values of multiple fields into a list, and assigns it to a new field."""
 
+     fields: List[str]
      to_field: str
      use_query: bool = False
 
+     def process(
+         self, instance: Dict[str, Any], stream_name: Optional[str] = None
+     ) -> Dict[str, Any]:
          values = []
+         for field_name in self.fields:
+             values.append(dict_get(instance, field_name, use_dpath=self.use_query))
          instance[self.to_field] = values
          return instance
 
 
  class ZipFieldValues(StreamInstanceOperator):
+     """Zips values of multiple fields similar to list(zip(*fields))."""
 
      fields: str
      to_field: str
      longest: bool = False
      use_query: bool = False
 
+     def process(
+         self, instance: Dict[str, Any], stream_name: Optional[str] = None
+     ) -> Dict[str, Any]:
          values = []
+         for field_name in self.fields:
+             values.append(dict_get(instance, field_name, use_dpath=self.use_query))
          if self.longest:
              zipped = zip_longest(*values)
          else:
 
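# zip vs. zip_longest (plain Python, illustrative of the 'longest' flag):
from itertools import zip_longest

a, b = [1, 2, 3], ["x", "y"]
assert list(zip(a, b)) == [(1, "x"), (2, "y")]
assert list(zip_longest(a, b)) == [(1, "x"), (2, "y"), (3, None)]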
  class IndexOf(StreamInstanceOperator):
+     """Finds the location of one value in another (iterable) value similar to to_field=search_in.index(index_of)."""
 
      search_in: str
      index_of: str
      to_field: str
      use_query: bool = False
 
+     def process(
+         self, instance: Dict[str, Any], stream_name: Optional[str] = None
+     ) -> Dict[str, Any]:
          lst = dict_get(instance, self.search_in, use_dpath=self.use_query)
          item = dict_get(instance, self.index_of, use_dpath=self.use_query)
          instance[self.to_field] = lst.index(item)
 
 
  class TakeByField(StreamInstanceOperator):
+     """Takes value from one field based on another field similar to field[index]."""
 
      field: str
      index: str
 
          if self.to_field is None:
              self.to_field = self.field
 
+     def process(
+         self, instance: Dict[str, Any], stream_name: Optional[str] = None
+     ) -> Dict[str, Any]:
          value = dict_get(instance, self.field, use_dpath=self.use_query)
          index_value = dict_get(instance, self.index, use_dpath=self.use_query)
          instance[self.to_field] = value[index_value]
 
 
  class CopyFields(FieldOperator):
+     """Copies specified fields from one field to another.
 
      Args:
          field_to_field (Union[List[List], Dict[str, str]]): A list of lists, where each sublist contains the source field and the destination field, or a dictionary mapping source fields to destination fields.
 
  class AddID(StreamInstanceOperator):
      id_field_name: str = "id"
 
+     def process(
+         self, instance: Dict[str, Any], stream_name: Optional[str] = None
+     ) -> Dict[str, Any]:
          instance[self.id_field_name] = str(uuid.uuid4()).replace("-", "")
          return instance
 
 
  class CastFields(StreamInstanceOperator):
+     """Casts specified fields to specified types.
 
      Args:
          types (Dict[str, str]): A dictionary mapping fields to their new types.
 
      def _cast_single(self, value, type, field):
          try:
              return self.types[type](value)
+         except Exception as e:
              if field not in self.failure_defaults:
                  raise ValueError(
                      f'Failed to cast field "{field}" with value {value} to type "{type}", and no default value is provided.'
+                 ) from e
              return self.failure_defaults[field]
 
      def _cast_multiple(self, values, type, field):
          values = [self._cast_single(value, type, field) for value in values]
 
+     def process(
+         self, instance: Dict[str, Any], stream_name: Optional[str] = None
+     ) -> Dict[str, Any]:
+         for field_name, type in self.fields.items():
+             value = dict_get(instance, field_name, use_dpath=self.use_nested_query)
              if self.cast_multiple:
+                 casted_value = self._cast_multiple(value, type, field_name)
              else:
+                 casted_value = self._cast_single(value, type, field_name)
+             dict_set(
+                 instance, field_name, casted_value, use_dpath=self.use_nested_query
+             )
          return instance
 
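# Cast-with-fallback sketch (plain Python, illustrative of _cast_single):
types = {"int": int, "float": float}
failure_defaults = {"age": 0}

def cast_single(value, type_name, field):
    try:
        return types[type_name](value)
    except Exception:
        if field not in failure_defaults:
            raise
        return failure_defaults[field]

assert cast_single("7", "int", "age") == 7
assert cast_single("n/a", "int", "age") == 0  # falls back to the default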
 
 
      strict: bool = False
      recursive: bool = True
 
+     def process(
+         self, instance: Dict[str, Any], stream_name: Optional[str] = None
+     ) -> Dict[str, Any]:
          return recursive_divide(instance, self.divisor, strict=self.strict)
 
 
  class ArtifactFetcherMixin:
+     """Provides a way to fetch and cache artifacts in the system.
 
      Args:
          cache (Dict[str, Artifact]): A cache for storing fetched artifacts.
 
 
  class ApplyOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin):
+     """Applies value operators to each instance in a stream based on specified fields.
 
      Args:
          value_field (str): The field containing the value to be operated on.
 
      default_operators: List[str] = None
      fields_to_treat_as_list: List[str] = NonPositionalField(default_factory=list)
 
+     def process(
+         self, instance: Dict[str, Any], stream_name: Optional[str] = None
+     ) -> Dict[str, Any]:
          operator_names = instance.get(self.operators_field)
          if operator_names is None:
              assert (
 
          for name in operator_names:
              operator = self.get_artifact(name)
+             for field_name in self.inputs_fields:
+                 value = instance[field_name]
+                 if field_name in self.fields_to_treat_as_list:
+                     instance[field_name] = [operator.process(v) for v in value]
                  else:
+                     instance[field_name] = operator.process(instance[field_name])
 
          return instance
 
 
  class FilterByValues(SingleStreamOperator):
+     """Filters a stream, yielding only instances that match specified values in the provided fields.
+
+     Args:
+         required_values (Dict[str, Any]): For each field, the value that instances should match to be included in the output.
      """
+
+     required_values: Dict[str, Any]
+
+     def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
+         for instance in stream:
+             filter = False
+             for key, value in self.required_values.items():
+                 if key not in instance:
+                     raise ValueError(
+                         f"Required filter field ('{key}') in FilterByValues is not found in {instance}"
+                     )
+                 if instance[key] != value:
+                     filter = True
+             if not filter:
+                 yield instance
+
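# Usage sketch (illustrative): keep only instances whose 'label' is "positive".
# FilterByValues(required_values={"label": "positive"}) has the same effect as:
stream = [{"label": "positive", "x": 1}, {"label": "negative", "x": 2}]
kept = [inst for inst in stream if inst["label"] == "positive"]
assert kept == [{"label": "positive", "x": 1}]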
+
+ class ExtractFieldValues(MultiStreamOperator):
+     """Extract the unique values of a field ('field') of a given stream ('stream_name') and store (the most frequent of) them
+     as a list in a new field ('to_field') in all streams.
+
+     More specifically, sort all the unique values encountered in field 'field' by decreasing order of frequency.
+     When 'overall_top_frequency_percent' is smaller than 100, trim the list from bottom, so that the total frequency of
+     the remaining values makes 'overall_top_frequency_percent' of the total number of instances in the stream.
+     When 'min_frequency_percent' is larger than 0, remove from the list any value whose relative frequency makes
+     less than 'min_frequency_percent' of the total number of instances in the stream.
+     At most one of 'overall_top_frequency_percent' and 'min_frequency_percent' is allowed to move from their default values.
+
+     Examples:
+
+     ExtractFieldValues(stream_name="train", field="label", to_field="classes") - extracts all the unique values of
+     field 'label', sorts them by decreasing frequency, and stores the resulting list in field 'classes' of each and
+     every instance in all streams.
+
+     ExtractFieldValues(stream_name="train", field="labels", to_field="classes", process_every_value=True) -
+     in case that field 'labels' contains a list of values (and not a single value) - track the occurrences of all the possible
+     value members in these lists, and report the most frequent values.
+     if process_every_value=False, track the most frequent whole lists, and report those (as a list of lists) in field
+     'to_field' of each instance of all streams.
+
+     ExtractFieldValues(stream_name="train", field="label", to_field="classes", overall_top_frequency_percent=80) -
+     extracts the most frequent possible values of field 'label' that together cover at least 80% of the instances of stream_name,
+     and stores them in field 'classes' of each instance of all streams.
+
+     ExtractFieldValues(stream_name="train", field="label", to_field="classes", min_frequency_percent=5) -
+     extracts all possible values of field 'label' that cover, each, at least 5% of the instances.
+     Stores these values, sorted by decreasing order of frequency, in field 'classes' of each instance in all streams.
+     """
+
+     field: str
+     stream_name: str
+     overall_top_frequency_percent: Optional[int] = 100
+     min_frequency_percent: Optional[int] = 0
+     to_field: str
+     process_every_value: Optional[bool] = False
+
+     def verify(self):
+         assert (
+             self.overall_top_frequency_percent <= 100
+             and self.overall_top_frequency_percent >= 0
+         ), "'overall_top_frequency_percent' must be between 0 and 100"
+         assert (
+             self.min_frequency_percent <= 100 and self.min_frequency_percent >= 0
+         ), "'min_frequency_percent' must be between 0 and 100"
+         assert not (
+             self.overall_top_frequency_percent < 100 and self.min_frequency_percent > 0
+         ), "At most one of 'overall_top_frequency_percent' and 'min_frequency_percent' is allowed to move from their default value"
+         super().verify()
+
+     def process(self, multi_stream: MultiStream) -> MultiStream:
+         stream = multi_stream[self.stream_name]
+         all_values = []
+         for instance in stream:
+             if (not isinstance(instance[self.field], list)) and (
+                 self.process_every_value is True
+             ):
+                 raise ValueError(
+                     "'process_every_value' is allowed to change to 'True' only for fields whose contents are lists"
+                 )
+             if (not isinstance(instance[self.field], list)) or (
+                 self.process_every_value is False
+             ):
+                 # either not a list, or is a list but process_every_value == False: view contents of 'field' as one entity whose occurrences are counted.
+                 all_values.append(
+                     (*instance[self.field],)
+                     if isinstance(instance[self.field], list)
+                     else instance[self.field]
+                 )  # convert to a tuple if list, to enable the use of Counter which would not accept
+                 # a list as an entity to count its occurrences
+             else:
+                 # content of 'field' is a list and process_every_value == True: add one occurrence on behalf of each individual value
+                 all_values.extend(instance[self.field])
+         counter = Counter(
+             all_values
+         )  # here all_values is a list of individual values, or tuples. Hence, Counter is feasible
+         values_and_counts = counter.most_common()
+         if self.overall_top_frequency_percent < 100:
+             top_frequency = len(all_values) * self.overall_top_frequency_percent / 100.0
+             sum_counts = 0
+             for _i, p in enumerate(values_and_counts):
+                 sum_counts += p[1]
+                 if sum_counts >= top_frequency:
+                     break
+             values_and_counts = counter.most_common(_i + 1)
+         if self.min_frequency_percent > 0:
+             min_frequency = self.min_frequency_percent * len(all_values) / 100.0
+             while values_and_counts[-1][1] < min_frequency:
+                 values_and_counts.pop()
+         values_to_keep = [
+             [*ele[0]] if isinstance(ele[0], tuple) else ele[0]
+             for ele in values_and_counts
+         ]
+         for name in multi_stream:
+             for instance in multi_stream[name]:
+                 instance[self.to_field] = values_to_keep
+         return multi_stream
+
+
960
+
961
+ class FilterByListsOfValues(SingleStreamOperator):
962
+ """Filters a stream, yielding only instances that whose field values are included in the specified value lists.
963
 
964
  Args:
965
+ required_values (Dict[str, List]): For each field, the list of values that instances should match to be included in the output.
966
  """
967
 
968
+ required_values: Dict[str, List]
969
+
970
+ def verify(self):
971
+ super().verify()
972
+ for key, value in self.required_values.items():
973
+ if not isinstance(value, list):
974
+ raise ValueError(
975
+ f"The filter for key ('{key}') in FilterByListsOfValues is not a list but '{value}'"
976
+ )
977
 
978
+ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
979
  for instance in stream:
980
+ filter = False
981
+ for key, value in self.required_values.items():
982
+ if key not in instance:
983
+ raise ValueError(
984
+ f"Required filter field ('{key}') in FilterByListsOfValues is not found in {instance}"
985
+ )
986
+ if instance[key] not in value:
987
+ filter = True
988
+ if not filter:
989
  yield instance
990
 
991
 
+class Intersect(FieldOperator):
+    """Intersects the value of a field, which must be a list, with a given list.
+
+    Args:
+        allowed_values (list): The list to intersect with.
+    """
+
+    allowed_values: List[Any]
+
+    def verify(self):
+        super().verify()
+        if self.process_every_value:
+            raise ValueError(
+                "'process_every_value=True' is not supported in Intersect operator"
+            )
+
+        if not isinstance(self.allowed_values, list):
+            raise ValueError(
+                f"The allowed_values is not a list but '{self.allowed_values}'"
+            )
+
+    def process_value(self, value: Any) -> Any:
+        if not isinstance(value, list):
+            raise ValueError(f"The value in field is not a list but '{value}'")
+        return [e for e in value if e in self.allowed_values]
+
+
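Note that the intersection keeps the field value's own order and duplicates, as this standalone sketch shows (values are illustrative):

allowed_values = ["a", "b"]
value = ["b", "c", "a", "a"]
print([e for e in value if e in allowed_values])  # ['b', 'a', 'a']
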
+class RemoveValues(FieldOperator):
+    """Removes elements of a field, which must be a list, that appear in a given list of unallowed values.
+
+    Args:
+        unallowed_values (list): The values to remove.
     """
+
+    unallowed_values: List[Any]
+
+    def verify(self):
+        super().verify()
+        if self.process_every_value:
+            raise ValueError(
+                "'process_every_value=True' is not supported in RemoveValues operator"
+            )
+
+        if not isinstance(self.unallowed_values, list):
+            raise ValueError(
+                f"The unallowed_values is not a list but '{self.unallowed_values}'"
+            )
+
+    def process_value(self, value: Any) -> Any:
+        if not isinstance(value, list):
+            raise ValueError(f"The value in field is not a list but '{value}'")
+        return [e for e in value if e not in self.unallowed_values]
+
+
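A matching sketch for RemoveValues, the complement of Intersect (toy values):

unallowed_values = ["n/a", ""]
value = ["x", "n/a", "y", ""]
print([e for e in value if e not in unallowed_values])  # ['x', 'y']
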
+class Unique(SingleStreamReducer):
+    """Reduces a stream to unique instances based on specified fields.
 
     Args:
         fields (List[str]): The fields that should be unique in each instance.
 
     @staticmethod
     def to_tuple(instance: dict, fields: List[str]) -> tuple:
         result = []
+        for field_name in fields:
+            value = instance[field_name]
             if isinstance(value, list):
                 value = tuple(value)
             result.append(value)
 
 
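The list-to-tuple conversion in to_tuple makes the per-instance key hashable, for example (toy instance):

instance = {"labels": ["pos", "neg"], "source": "train"}
result = []
for field_name in ["labels", "source"]:
    value = instance[field_name]
    if isinstance(value, list):
        value = tuple(value)  # lists are unhashable; tuples can go into a set
    result.append(value)
print(tuple(result))  # (('pos', 'neg'), 'train')
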
 class SplitByValue(MultiStreamOperator):
+    """Splits a MultiStream into multiple streams based on unique values in specified fields.
 
     Args:
         fields (List[str]): The fields to use when splitting the MultiStream.
 
         for stream_name, stream in multi_stream.items():
             stream_unique_values = uniques[stream_name]
             for unique_values in stream_unique_values:
+                filtering_values = dict(zip(self.fields, unique_values))
+                filtered_streams = FilterByValues(
+                    required_values=filtering_values
+                )._process_single_stream(stream)
+                filtered_stream_name = (
+                    stream_name + "_" + nested_tuple_to_string(unique_values)
+                )
                 result[filtered_stream_name] = filtered_streams
 
         return MultiStream(result)
 
 
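Illustrative only: how each split stream is named above, assuming nested_tuple_to_string joins tuple elements with underscores (a guess about its behavior; the plain join below merely stands in for it):

stream_name = "train"
unique_values = ("en", "qa")
filtered_stream_name = stream_name + "_" + "_".join(unique_values)
print(filtered_stream_name)  # train_en_qa
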
 class ApplyStreamOperatorsField(SingleStreamOperator, ArtifactFetcherMixin):
+    """Applies stream operators to a stream based on specified fields in each instance.
 
     Args:
         field (str): The field containing the operators to be applied.
 
     field: str
     reversed: bool = False
 
+    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
         first_instance = stream.peak()
 
         operators = first_instance.get(self.field, [])
 
         for operator_name in operators:
             operator = self.get_artifact(operator_name)
+            assert isinstance(
+                operator, StreamingOperator
+            ), f"Operator {operator_name} must be a StreamingOperator"
 
             stream = operator(MultiStream({"tmp": stream}))["tmp"]
 
         yield from stream
 
 
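A standalone sketch of the chaining idea: operators named in a field are fetched and applied in order. Plain functions stand in for fetched artifacts, and all names are made up:

operators = {"strip": str.strip, "lower": str.lower}  # stand-ins for artifacts
instance = {"text": "  Hello ", "postprocessors": ["strip", "lower"]}
value = instance["text"]
for name in instance["postprocessors"]:
    value = operators[name](value)
print(value)  # 'hello'
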
+class ApplyMetric(SingleStreamOperator, ArtifactFetcherMixin):
+    """Applies metric operators to a stream based on a metric field specified in each instance.
+
+    Args:
+        metric_field (str): The field containing the metrics to be applied.
+        calc_confidence_intervals (bool): Whether the applied metric should calculate confidence intervals or not.
     """
+
+    metric_field: str
+    calc_confidence_intervals: bool
+
+    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
+        from .metrics import Metric, MetricPipeline, MetricWithConfidenceInterval
+
+        first_instance = stream.peak()
+
+        metric_names = first_instance.get(self.metric_field, [])
+        if not metric_names:
+            raise RuntimeError(
+                f"Missing metric names in field '{self.metric_field}' and instance '{first_instance}'."
+            )
+
+        if isinstance(metric_names, str):
+            metric_names = [metric_names]
+
+        # Each metric operator computes its score and then sets the main score, overwriting
+        # the previous main score value (if any). So, we need to reverse the order of the listed metrics.
+        # This will cause the first listed metric to run last, and the main score will be set
+        # by the first listed metric (as desired).
+        metric_names = list(reversed(metric_names))
+
+        for metric_name in metric_names:
+            metric = self.get_artifact(metric_name)
+            assert isinstance(
+                metric, Metric
+            ), f"Operator {metric_name} must be a Metric"
+
+            if not self.calc_confidence_intervals:
+                if isinstance(metric, MetricWithConfidenceInterval):
+                    metric.disable_confidence_interval_calculation()
+                elif isinstance(metric, MetricPipeline) and isinstance(
+                    metric.metric, MetricWithConfidenceInterval
+                ):
+                    metric.metric.disable_confidence_interval_calculation()
+
+            stream = metric(MultiStream({"tmp": stream}))["tmp"]
+
+        yield from stream
+
+
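The reversal above means the last metric to run, and hence the one that sets the final main score, is the first one listed. A minimal sketch with made-up metric names:

metric_names = ["accuracy", "f1"]  # 'accuracy' should set the main score
for name in reversed(metric_names):
    print("applying", name)  # applies 'f1' first, then 'accuracy' last
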
+class AddFieldNamePrefix(StreamInstanceOperator):
+    """Adds a prefix to each field name in each instance of a stream.
 
     Args:
         prefix_dict (Dict[str, str]): A dictionary mapping stream names to prefixes.
 
     def prepare(self):
         return super().prepare()
 
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
+        return {
+            self.prefix_dict[stream_name] + key: value
+            for key, value in instance.items()
+        }
 
 
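The renaming comprehension above, applied to a toy instance from a stream named 'train':

prefix_dict = {"train": "tr_"}
instance = {"text": "abc", "label": 1}
print({prefix_dict["train"] + key: value for key, value in instance.items()})
# {'tr_text': 'abc', 'tr_label': 1}
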
 class MergeStreams(MultiStreamOperator):
+    """Merges multiple streams into a single stream.
 
     Args:
         new_stream_name (str): The name of the new stream resulting from the merge.
 
         origin_stream_name_field_name (str): The field name for the origin stream name.
     """
 
+    streams_to_merge: List[str] = None
     new_stream_name: str = "all"
     add_origin_stream_name: bool = True
     origin_stream_name_field_name: str = "origin"
 
     def merge(self, multi_stream):
         for stream_name, stream in multi_stream.items():
+            if self.streams_to_merge is None or stream_name in self.streams_to_merge:
+                for instance in stream:
+                    if self.add_origin_stream_name:
+                        instance[self.origin_stream_name_field_name] = stream_name
+                    yield instance
 
     def process(self, multi_stream: MultiStream) -> MultiStream:
+        return MultiStream(
+            {
+                self.new_stream_name: Stream(
+                    self.merge, gen_kwargs={"multi_stream": multi_stream}
+                )
+            }
+        )
 
 
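A plain-Python sketch of the merge semantics, including the origin tagging (toy streams, with add_origin_stream_name left at its default of True):

streams = {"train": [{"x": 1}], "test": [{"x": 2}]}
merged = []
for stream_name, stream in streams.items():
    for instance in stream:
        instance["origin"] = stream_name
        merged.append(instance)
print(merged)  # [{'x': 1, 'origin': 'train'}, {'x': 2, 'origin': 'test'}]
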
 class Shuffle(PagedStreamOperator):
+    """Shuffles the order of instances in each page of a stream.
 
     Args:
         page_size (int): The size of each page in the stream. Defaults to 1000.
     """
 
+    def process(self, page: List[Dict], stream_name: Optional[str] = None) -> Generator:
+        get_random().shuffle(page)
         yield from page
 
 
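The same in-place page shuffle using only the standard library (get_random() above supplies the library's own generator; the seeded random.Random here is merely a stand-in):

import random

page = [{"id": i} for i in range(5)]
random.Random(42).shuffle(page)  # shuffles the page in place
print([p["id"] for p in page])
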
 class EncodeLabels(StreamInstanceOperator):
+    """Encodes labels of specified fields together into integers.
 
     Args:
         fields (List[str]): The fields to encode together.
 
         self.encoder = {}
         return super()._process_multi_stream(multi_stream)
 
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
+        for field_name in self.fields:
+            values = dict_get(instance, field_name, use_dpath=True)
             if not isinstance(values, list):
                 values = [values]
             for value in values:
                 if value not in self.encoder:
                     self.encoder[value] = len(self.encoder)
             new_values = [self.encoder[value] for value in values]
+            dict_set(
+                instance, field_name, new_values, use_dpath=True, set_multiple=True
+            )
 
         return instance
 
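The shared encoder assigns consecutive integer ids in order of first appearance, for example (toy labels):

encoder = {}
labels = ["pos", "neg", "pos", "neu"]
for value in labels:
    if value not in encoder:
        encoder[value] = len(encoder)
print([encoder[v] for v in labels])  # [0, 1, 0, 2]
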
 class StreamRefiner(SingleStreamOperator):
     max_instances: int = None
 
+    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
         if self.max_instances is not None:
             yield from stream.take(self.max_instances)
         else:
 
 
 class DeterministicBalancer(StreamRefiner):
+    """A class used to balance streams deterministically.
 
     Attributes:
         fields (List[str]): A list of field names to be used in determining the signature of an instance.
 
     fields: List[str]
 
     def signature(self, instance):
+        return str(
+            tuple(dict_get(instance, field, use_dpath=True) for field in self.fields)
+        )
 
+    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
         counter = collections.Counter()
 
         for instance in stream:
             counter[self.signature(instance)] += 1
 
+        if len(counter) == 0:
+            return
+
         lowest_count = counter.most_common()[-1][-1]
 
         max_total_instances_per_sign = lowest_count
         if self.max_instances is not None:
+            max_total_instances_per_sign = min(
+                lowest_count, self.max_instances // len(counter)
+            )
 
         counter = collections.Counter()
 
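How the per-signature cap interacts with max_instances above (toy counts):

import collections

counter = collections.Counter({"sig_a": 10, "sig_b": 4, "sig_c": 7})
lowest_count = counter.most_common()[-1][-1]  # 4, the rarest signature's count
cap = min(lowest_count, 9 // len(counter))  # with max_instances = 9: min(4, 3) = 3
print(cap)  # each signature contributes at most 3 instances
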
 
 
     def signature(self, instance):
         total_len = 0
+        for field_name in self.fields:
+            total_len += len(dict_get(instance, field_name, use_dpath=True))
         for i, val in enumerate(self.segments_boundaries):
             if total_len < val:
                 return i
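This signature buckets instances by total field length; the hunk ends here, and presumably a fallback index is returned when no boundary exceeds the length. A standalone sketch of the bucketing (boundaries and the fallback are assumptions):

segments_boundaries = [100, 500]
total_len = 230
signature = next(
    (i for i, val in enumerate(segments_boundaries) if total_len < val),
    len(segments_boundaries),  # assumed fallback past the last boundary
)
print(signature)  # 1, i.e. the middle bucket [100, 500)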