Elron committed
Commit 8fecbbd
1 Parent(s): c8da176

Upload operators.py with huggingface_hub

Files changed (1)
  1. operators.py +321 -34
operators.py CHANGED
@@ -1,22 +1,26 @@
 from dataclasses import field
-from typing import Any, Dict, Generator, Iterable, List, Optional, Union

-from .text_utils import nested_tuple_to_string
 from .artifact import Artifact, fetch_artifact
 from .operator import (
     MultiStream,
     MultiStreamOperator,
     SingleStreamOperator,
     SingleStreamReducer,
-    Stream,
     StreamInitializerOperator,
     StreamInstanceOperator,
-    PagedStreamOperator,
 )
 from .stream import MultiStream, Stream
 from .utils import flatten_dict
-import random
-from .utils import dict_query


 class FromIterables(StreamInitializerOperator):
@@ -26,20 +30,47 @@ class FromIterables(StreamInitializerOperator):
     Args:
         iterables (Dict[str, Iterable]): A dictionary where each key-value pair represents a stream name and its corresponding iterable.
     """
     def process(self, iterables: Dict[str, Iterable]) -> MultiStream:
         return MultiStream.from_iterables(iterables)


-class MapInstanceValues(StreamInstanceOperator):
     """
-    Maps values in each instance of a stream based on the provided mappers.

-    Args:
-        mappers (Dict[str, Dict[str, str]]): A dictionary where each key-value pair represents a field in the instance and a mapper for that field.
-        strict (bool): If True, the operator will raise a KeyError if a value is not in its corresponding mapper. If False, unmapped values will be left unchanged. Defaults to True.
     """
     mappers: Dict[str, Dict[str, str]]
     strict: bool = True

     def verify(self):
         # make sure the mappers are valid
@@ -49,18 +80,16 @@ class MapInstanceValues(StreamInstanceOperator):
             assert isinstance(k, str), f'Key "{k}" in mapper for field "{key}" should be a string, got {type(k)}'

     def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
-        result = {}
-        for key, value in instance.items():
-            str_value = str(value)
-            if key in self.mappers:
-                mapper = self.mappers[key]
                 if self.strict:
-                    value = mapper[str_value]
                 else:
-                    if str_value in mapper:
-                        value = mapper[str_value]
-            result[key] = value
-        return result


 class FlattenInstances(StreamInstanceOperator):
@@ -71,9 +100,10 @@ class FlattenInstances(StreamInstanceOperator):
         parent_key (str): A prefix to use for the flattened keys. Defaults to an empty string.
         sep (str): The separator to use when concatenating nested keys. Defaults to "_".
     """
     parent_key: str = ""
     sep: str = "_"
-
     def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
         return flatten_dict(instance, parent_key=self.parent_key, sep=self.sep)

@@ -85,24 +115,244 @@ class AddFields(StreamInstanceOperator):
     Args:
         fields (Dict[str, object]): The fields to add to each instance.
     """
     fields: Dict[str, object]

     def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
-        instance.update(self.fields)
         return instance


-class MapNestedDictValuesByQueries(StreamInstanceOperator):
-    field_to_query: Dict[str, str]

     def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
-        updates = {}
-        for field, query in self.field_to_query.items():
-            updates[field] = dict_query(instance, query)
-        instance.update(updates)
         return instance


 class ArtifactFetcherMixin:
     """
     Provides a way to fetch and cache artifacts in the system.
@@ -110,6 +360,7 @@ class ArtifactFetcherMixin:
     Args:
         cache (Dict[str, Artifact]): A cache for storing fetched artifacts.
     """
     cache: Dict[str, Artifact] = {}

     @classmethod
@@ -129,6 +380,7 @@ class ApplyValueOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin):
         operators_field (str): The field containing the operators to be applied.
         default_operators (List[str]): A list of default operators to be used if no operators are found in the instance.
     """
     value_field: str
     operators_field: str
     default_operators: List[str] = None
@@ -158,6 +410,7 @@ class FilterByValues(SingleStreamOperator):
     Args:
         values (Dict[str, Any]): The values that instances should match to be included in the output.
     """
     values: Dict[str, Any]

     def process(self, stream: Stream, stream_name: str = None) -> Generator:
@@ -173,6 +426,7 @@ class Unique(SingleStreamReducer):
     Args:
         fields (List[str]): The fields that should be unique in each instance.
     """
     fields: List[str] = field(default_factory=list)

     @staticmethod
@@ -201,6 +455,7 @@ class SplitByValue(MultiStreamOperator):
     Args:
         fields (List[str]): The fields to use when splitting the MultiStream.
     """
     fields: List[str] = field(default_factory=list)

     def process(self, multi_stream: MultiStream) -> MultiStream:
@@ -227,6 +482,7 @@ class ApplyStreamOperatorsField(SingleStreamOperator, ArtifactFetcherMixin):
         field (str): The field containing the operators to be applied.
         reversed (bool): Whether to apply the operators in reverse order.
     """
     field: str
     reversed: bool = False

@@ -242,10 +498,9 @@ class ApplyStreamOperatorsField(SingleStreamOperator, ArtifactFetcherMixin):

         for operator_name in operators:
             operator = self.get_artifact(operator_name)
-            assert isinstance(
-                operator, SingleStreamOperator
-            ), f"Operator {operator_name} must be a SingleStreamOperator"
-            stream = operator.process(stream)

         yield from stream

@@ -257,6 +512,7 @@ class AddFieldNamePrefix(StreamInstanceOperator):
     Args:
         prefix_dict (Dict[str, str]): A dictionary mapping stream names to prefixes.
     """
     prefix_dict: Dict[str, str]

     def prepare(self):
@@ -275,6 +531,7 @@ class MergeStreams(MultiStreamOperator):
         add_origin_stream_name (bool): Whether to add the origin stream name to each instance.
         origin_stream_name_field_name (str): The field name for the origin stream name.
     """
     new_stream_name: str = "all"
     add_origin_stream_name: bool = True
     origin_stream_name_field_name: str = "origin"
@@ -289,6 +546,7 @@ class MergeStreams(MultiStreamOperator):
     def process(self, multi_stream: MultiStream) -> MultiStream:
         return MultiStream({self.new_stream_name: Stream(self.merge, gen_kwargs={"multi_stream": multi_stream})})

 class Shuffle(PagedStreamOperator):
     """
     Shuffles the order of instances in each page of a stream.
@@ -296,6 +554,35 @@ class Shuffle(PagedStreamOperator):
     Args:
         page_size (int): The size of each page in the stream. Defaults to 1000.
     """
     def process(self, page: List[Dict], stream_name: str = None) -> Generator:
         random.shuffle(page)
-        yield from page

@@ -1,22 +1,26 @@
+import uuid
+from abc import abstractmethod
+from copy import deepcopy
 from dataclasses import field
+from itertools import zip_longest
+from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union

 from .artifact import Artifact, fetch_artifact
+from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
 from .operator import (
     MultiStream,
     MultiStreamOperator,
+    PagedStreamOperator,
     SingleStreamOperator,
     SingleStreamReducer,
+    StreamingOperator,
     StreamInitializerOperator,
     StreamInstanceOperator,
 )
+from .random_utils import random
 from .stream import MultiStream, Stream
+from .text_utils import nested_tuple_to_string
 from .utils import flatten_dict


 class FromIterables(StreamInitializerOperator):
@@ -26,20 +30,47 @@ class FromIterables(StreamInitializerOperator):
     Args:
         iterables (Dict[str, Iterable]): A dictionary where each key-value pair represents a stream name and its corresponding iterable.
     """
+
     def process(self, iterables: Dict[str, Iterable]) -> MultiStream:
         return MultiStream.from_iterables(iterables)


+class RenameFields(StreamInstanceOperator):
+    """
+    Renames fields
+    Attributes:
+        mapper (Dict[str, str]): old field names to new field names dict
     """

+    mapper: Dict[str, str]
+
+    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
+        result = {}
+        # passes on all values to preserve ordering
+        for key, value in instance.items():
+            result[self.mapper.get(key, key)] = value
+        # doesn't warn if unnecessary mapping was supplied for efficiency
+        return result
+
+
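A minimal usage sketch of RenameFields (hypothetical field names, using keyword construction as elsewhere in this module):

    op = RenameFields(mapper={"passage": "context"})
    out = op.process({"passage": "some text", "label": 1})
    # out == {"context": "some text", "label": 1}; other keys keep their order
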
+class MapInstanceValues(StreamInstanceOperator):
+    """A class used to map instance values into a stream.
+
+    This class is a type of StreamInstanceOperator,
+    it maps values of instances in a stream using predefined mappers.
+
+    Attributes:
+        mappers (Dict[str, Dict[str, str]]): The mappers to use for mapping instance values.
+            Keys are the names of the fields to be mapped, and values are dictionaries
+            that define the mapping from old values to new values.
+        strict (bool): If True, the mapping is applied strictly. That means if a value
+            does not exist in the mapper, it will raise a KeyError. If False, values
+            that are not present in the mapper are kept as they are.
     """
+
     mappers: Dict[str, Dict[str, str]]
     strict: bool = True
+    use_query = False

     def verify(self):
         # make sure the mappers are valid
@@ -49,18 +80,16 @@ class MapInstanceValues(StreamInstanceOperator):
             assert isinstance(k, str), f'Key "{k}" in mapper for field "{key}" should be a string, got {type(k)}'

     def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
+        for key, mapper in self.mappers.items():
+            value = dict_get(instance, key, use_dpath=self.use_query)
+            if value is not None:
+                value = str(value)  # make sure the value is a string
                 if self.strict:
+                    dict_set(instance, key, mapper[value], use_dpath=self.use_query)
                 else:
+                    if value in mapper:
+                        dict_set(instance, key, mapper[value], use_dpath=self.use_query)
+        return instance


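A minimal usage sketch of MapInstanceValues (hypothetical mapper; values are stringified before lookup):

    op = MapInstanceValues(mappers={"label": {"0": "negative", "1": "positive"}})
    out = op.process({"label": 1})
    # out == {"label": "positive"}; with strict=True, an unmapped value raises KeyError
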
 class FlattenInstances(StreamInstanceOperator):
@@ -71,9 +100,10 @@ class FlattenInstances(StreamInstanceOperator):
         parent_key (str): A prefix to use for the flattened keys. Defaults to an empty string.
         sep (str): The separator to use when concatenating nested keys. Defaults to "_".
     """
+
     parent_key: str = ""
     sep: str = "_"
+
     def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
         return flatten_dict(instance, parent_key=self.parent_key, sep=self.sep)

@@ -85,24 +115,244 @@ class AddFields(StreamInstanceOperator):
     Args:
         fields (Dict[str, object]): The fields to add to each instance.
     """
+
     fields: Dict[str, object]
+    use_query: bool = False
+    use_deepcopy: bool = False
+
+    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
+        if self.use_query:
+            for key, value in self.fields.items():
+                if self.use_deepcopy:
+                    value = deepcopy(value)
+                dict_set(instance, key, value, use_dpath=self.use_query)
+        else:
+            if self.use_deepcopy:
+                self.fields = deepcopy(self.fields)
+            instance.update(self.fields)
+        return instance
+
+
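A minimal usage sketch of AddFields (hypothetical field; use_deepcopy guards against instances aliasing one shared mutable value):

    op = AddFields(fields={"source": "my_dataset"})
    out = op.process({"text": "hello"})
    # out == {"text": "hello", "source": "my_dataset"}
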
+class FieldOperator(StreamInstanceOperator):
+    """
+    A general stream operator that processes the values of a field (or multiple ones)
+    Args:
+        field (Optional[str]): The field to process, if only a single one is passed. Defaults to None
+        to_field (Optional[str]): Field name to save the result to, if only one field is to be saved. If None is passed, the operator runs in-place and replaces "field". Defaults to None
+        field_to_field (Optional[Union[List[Tuple[str, str]], Dict[str, str]]]): Mapping from fields to process to their names after this process, duplicates are allowed. Defaults to None
+        process_every_value (bool): Processes the values in a list instead of the list as a value, similar to *var. Defaults to False
+        use_query (bool): Whether to use dpath style queries. Defaults to False
+    """
+
+    field: Optional[str] = None
+    to_field: Optional[str] = None
+    field_to_field: Optional[Union[List[Tuple[str, str]], Dict[str, str]]] = None
+    process_every_value: bool = False
+    use_query: bool = False
+
+    def verify(self):
+        super().verify()
+
+        assert self.field is not None or self.field_to_field is not None, "Must supply a field to work on"
+        assert (
+            self.to_field is None or self.field_to_field is None
+        ), f"Can not apply operator to create both on {self.to_field} and on the mapping from fields to fields {self.field_to_field}"
+        assert (
+            self.field is None or self.field_to_field is None
+        ), f"Can not apply operator both on {self.field} and on the mapping from fields to fields {self.field_to_field}"
+        assert self._field_to_field, f"the from and to fields must be defined got: {self._field_to_field}"
+
+    @abstractmethod
+    def process_value(self, value: Any) -> Any:
+        pass
+
+    def prepare(self):
+        if self.to_field is None:
+            self.to_field = self.field
+        if self.field_to_field is None:
+            self._field_to_field = [(self.field, self.to_field)]
+        else:
+            try:
+                self._field_to_field = [(k, v) for k, v in self.field_to_field.items()]
+            except AttributeError:
+                self._field_to_field = self.field_to_field
+
+    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
+        for from_field, to_field in self._field_to_field:
+            old_value = dict_get(instance, from_field, use_dpath=self.use_query)
+            if self.process_every_value:
+                new_value = [self.process_value(value) for value in old_value]
+            else:
+                new_value = self.process_value(old_value)
+            if self.use_query and is_subpath(from_field, to_field):
+                dict_delete(instance, from_field)
+            dict_set(instance, to_field, new_value, use_dpath=self.use_query, not_exist_ok=True)
+        return instance
+
+
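FieldOperator is abstract; a concrete subclass only supplies process_value. A minimal sketch with a hypothetical subclass (assuming prepare() and verify() run at construction time, as with other Artifact classes in this module):

    class Lowercase(FieldOperator):
        def process_value(self, value: Any) -> Any:
            return str(value).lower()

    op = Lowercase(field="answer", to_field="answer_lower")
    out = op.process({"answer": "YES"})
    # out == {"answer": "YES", "answer_lower": "yes"}
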
+class JoinStr(FieldOperator):
+    """
+    Joins a list of strings (contents of a field), similar to str.join()
+    Args:
+        separator (str): text to put between values
+    """
+
+    separator: str = ","
+
+    def process_value(self, value: Any) -> Any:
+        return self.separator.join(str(x) for x in value)
+
+
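A minimal usage sketch of JoinStr (in-place, since no to_field is given):

    op = JoinStr(field="choices", separator=", ")
    out = op.process({"choices": ["yes", "no"]})
    # out == {"choices": "yes, no"}
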
+class ZipFieldValues(StreamInstanceOperator):
+    """
+    Zips values of multiple fields similar to list(zip(*fields))
+    """
+
+    fields: str
+    to_field: str
+    longest: bool = False
+    use_query: bool = False

     def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
+        values = []
+        for field in self.fields:
+            values.append(dict_get(instance, field, use_dpath=self.use_query))
+        if self.longest:
+            zipped = zip_longest(*values)
+        else:
+            zipped = zip(*values)
+        instance[self.to_field] = list(zipped)
         return instance


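A minimal usage sketch of ZipFieldValues (hypothetical parallel list fields):

    op = ZipFieldValues(fields=["ids", "names"], to_field="pairs")
    out = op.process({"ids": [1, 2], "names": ["a", "b"]})
    # out["pairs"] == [(1, "a"), (2, "b")]
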
+class IndexOf(StreamInstanceOperator):
+    """
+    Finds the location of one value in another (iterable) value similar to to_field=search_in.index(index_of)
+    """
+
+    search_in: str
+    index_of: str
+    to_field: str
+    use_query: bool = False

     def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
+        lst = dict_get(instance, self.search_in, use_dpath=self.use_query)
+        item = dict_get(instance, self.index_of, use_dpath=self.use_query)
+        instance[self.to_field] = lst.index(item)
         return instance


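A minimal usage sketch of IndexOf (hypothetical fields):

    op = IndexOf(search_in="choices", index_of="answer", to_field="answer_index")
    out = op.process({"choices": ["yes", "no"], "answer": "no"})
    # out["answer_index"] == 1
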
+class TakeByField(StreamInstanceOperator):
+    """
+    Takes value from one field based on another field similar to field[index]
+    """
+
+    field: str
+    index: str
+    to_field: str = None
+    use_query: bool = False
+
+    def prepare(self):
+        if self.to_field is None:
+            self.to_field = self.field
+
+    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
+        value = dict_get(instance, self.field, use_dpath=self.use_query)
+        index_value = dict_get(instance, self.index, use_dpath=self.use_query)
+        instance[self.to_field] = value[index_value]
+        return instance
+
+
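A minimal usage sketch of TakeByField (hypothetical fields):

    op = TakeByField(field="choices", index="answer_index", to_field="answer")
    out = op.process({"choices": ["yes", "no"], "answer_index": 1})
    # out["answer"] == "no"
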
+class CopyFields(FieldOperator):
+    """
+    Copies specified fields from one field to another.
+
+    Args:
+        field_to_field (Union[List[List], Dict[str, str]]): A list of lists, where each sublist contains the source field and the destination field, or a dictionary mapping source fields to destination fields.
+        use_dpath (bool): Whether to use dpath for accessing fields. Defaults to False.
+    """
+
+    def process_value(self, value: Any) -> Any:
+        return value
+
+
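A minimal usage sketch of CopyFields (note the docstring says use_dpath while the inherited attribute is named use_query):

    op = CopyFields(field_to_field={"question": "question_copy"})
    out = op.process({"question": "Why?"})
    # out == {"question": "Why?", "question_copy": "Why?"}
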
+class AddID(StreamInstanceOperator):
+    id_field_name: str = "id"
+
+    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
+        instance[self.id_field_name] = str(uuid.uuid4()).replace("-", "")
+        return instance
+
+
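A minimal usage sketch of AddID:

    op = AddID()
    out = op.process({"text": "hello"})
    # out["id"] is a 32-character hex string derived from uuid4
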
+class CastFields(StreamInstanceOperator):
+    """
+    Casts specified fields to specified types.
+
+    Args:
+        types (Dict[str, str]): A dictionary mapping fields to their new types.
+        nested (bool): Whether to cast nested fields. Defaults to False.
+        fields (Dict[str, str]): A dictionary mapping fields to their new types.
+        defaults (Dict[str, object]): A dictionary mapping types to their default values for cases of casting failure.
+    """
+
+    types = {
+        "int": int,
+        "float": float,
+        "str": str,
+        "bool": bool,
+    }
+    fields: Dict[str, str] = field(default_factory=dict)
+    failure_defaults: Dict[str, object] = field(default_factory=dict)
+    use_nested_query: bool = False
+    cast_multiple: bool = False
+
+    def _cast_single(self, value, type, field):
+        try:
+            return self.types[type](value)
+        except:
+            if field not in self.failure_defaults:
+                raise ValueError(
+                    f'Failed to cast field "{field}" with value {value} to type "{type}", and no default value is provided.'
+                )
+            return self.failure_defaults[field]
+
+    def _cast_multiple(self, values, type, field):
+        values = [self._cast_single(value, type, field) for value in values]
+
+    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
+        for field, type in self.fields.items():
+            value = dict_get(instance, field, use_dpath=self.use_nested_query)
+            if self.cast_multiple:
+                casted_value = self._cast_multiple(value, type, field)
+            else:
+                casted_value = self._cast_single(value, type, field)
+            dict_set(instance, field, casted_value, use_dpath=self.use_nested_query)
+        return instance
+
+
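A minimal sketch of the single-value path of CastFields (note that _cast_multiple, as committed, has no return statement, so cast_multiple=True would store None):

    op = CastFields(fields={"score": "float"}, failure_defaults={"score": 0.0})
    out = op.process({"score": "3.14"})
    # out["score"] == 3.14; an uncastable value would fall back to 0.0
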
+def recursive_divide(instance, divisor, strict=False):
+    if isinstance(instance, dict):
+        for key, value in instance.items():
+            instance[key] = recursive_divide(value, divisor, strict=strict)
+    elif isinstance(instance, list):
+        for i, value in enumerate(instance):
+            instance[i] = recursive_divide(value, divisor, strict=strict)
+    elif isinstance(instance, float):
+        instance /= divisor
+    elif strict:
+        raise ValueError(f"Cannot divide instance of type {type(instance)}")
+    return instance
+
+
+class DivideAllFieldsBy(StreamInstanceOperator):
+    divisor: float = 1.0
+    strict: bool = False
+    recursive: bool = True
+
+    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
+        return recursive_divide(instance, self.divisor, strict=self.strict)
+
+
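A minimal usage sketch of DivideAllFieldsBy (recursive_divide walks dicts and lists; only float values are divided):

    op = DivideAllFieldsBy(divisor=100.0)
    out = op.process({"scores": [50.0, 25.0], "meta": {"p": 1.0}})
    # out == {"scores": [0.5, 0.25], "meta": {"p": 0.01}}
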
 class ArtifactFetcherMixin:
     """
     Provides a way to fetch and cache artifacts in the system.
@@ -110,6 +360,7 @@ class ArtifactFetcherMixin:
     Args:
         cache (Dict[str, Artifact]): A cache for storing fetched artifacts.
     """
+
     cache: Dict[str, Artifact] = {}

     @classmethod
@@ -129,6 +380,7 @@ class ApplyValueOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin):
         operators_field (str): The field containing the operators to be applied.
         default_operators (List[str]): A list of default operators to be used if no operators are found in the instance.
     """
+
     value_field: str
     operators_field: str
     default_operators: List[str] = None
@@ -158,6 +410,7 @@ class FilterByValues(SingleStreamOperator):
     Args:
         values (Dict[str, Any]): The values that instances should match to be included in the output.
     """
+
     values: Dict[str, Any]

     def process(self, stream: Stream, stream_name: str = None) -> Generator:
@@ -173,6 +426,7 @@ class Unique(SingleStreamReducer):
     Args:
         fields (List[str]): The fields that should be unique in each instance.
     """
+
     fields: List[str] = field(default_factory=list)

     @staticmethod
@@ -201,6 +455,7 @@ class SplitByValue(MultiStreamOperator):
     Args:
         fields (List[str]): The fields to use when splitting the MultiStream.
     """
+
     fields: List[str] = field(default_factory=list)

     def process(self, multi_stream: MultiStream) -> MultiStream:
@@ -227,6 +482,7 @@ class ApplyStreamOperatorsField(SingleStreamOperator, ArtifactFetcherMixin):
         field (str): The field containing the operators to be applied.
         reversed (bool): Whether to apply the operators in reverse order.
     """
+
     field: str
     reversed: bool = False

@@ -242,10 +498,9 @@ class ApplyStreamOperatorsField(SingleStreamOperator, ArtifactFetcherMixin):

         for operator_name in operators:
             operator = self.get_artifact(operator_name)
+            assert isinstance(operator, StreamingOperator), f"Operator {operator_name} must be a SingleStreamOperator"
+
+            stream = operator(MultiStream({"tmp": stream}))["tmp"]

         yield from stream

@@ -257,6 +512,7 @@ class AddFieldNamePrefix(StreamInstanceOperator):
     Args:
         prefix_dict (Dict[str, str]): A dictionary mapping stream names to prefixes.
     """
+
     prefix_dict: Dict[str, str]

     def prepare(self):
@@ -275,6 +531,7 @@ class MergeStreams(MultiStreamOperator):
         add_origin_stream_name (bool): Whether to add the origin stream name to each instance.
         origin_stream_name_field_name (str): The field name for the origin stream name.
     """
+
     new_stream_name: str = "all"
     add_origin_stream_name: bool = True
     origin_stream_name_field_name: str = "origin"
@@ -289,6 +546,7 @@ class MergeStreams(MultiStreamOperator):
     def process(self, multi_stream: MultiStream) -> MultiStream:
         return MultiStream({self.new_stream_name: Stream(self.merge, gen_kwargs={"multi_stream": multi_stream})})

+
 class Shuffle(PagedStreamOperator):
     """
     Shuffles the order of instances in each page of a stream.
 
@@ -296,6 +554,35 @@ class Shuffle(PagedStreamOperator):
     Args:
         page_size (int): The size of each page in the stream. Defaults to 1000.
     """
+
     def process(self, page: List[Dict], stream_name: str = None) -> Generator:
         random.shuffle(page)
+        yield from page
+
+
+class EncodeLabels(StreamInstanceOperator):
+    """
+    Encodes labels of the specified fields together into integers.
+
+    Args:
+        fields (List[str]): The fields to encode together.
+    """
+
+    fields: List[str]
+
+    def _process_multi_stream(self, multi_stream: MultiStream) -> MultiStream:
+        self.encoder = {}
+        return super()._process_multi_stream(multi_stream)
+
+    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
+        for field in self.fields:
+            values = dict_get(instance, field, use_dpath=True)
+            if not isinstance(values, list):
+                values = [values]
+            for value in values:
+                if value not in self.encoder:
+                    self.encoder[value] = len(self.encoder)
+            new_values = [self.encoder[value] for value in values]
+            dict_set(instance, field, new_values, use_dpath=True, set_multiple=True)
+
+        return instance
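
A sketch of the intended encoding of EncodeLabels (hypothetical labels; the encoder dict is normally created in _process_multi_stream before any instance is processed):

    op = EncodeLabels(fields=["label"])
    op.encoder = {}  # normally initialized per multi-stream
    op.process({"label": "positive"})  # encoder becomes {"positive": 0}
    op.process({"label": "negative"})  # encoder becomes {"positive": 0, "negative": 1}
    # each instance's field is rewritten through dict_set with the encoded value(s)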