Elron committed on
Commit
f60252a
1 Parent(s): fd97850

Upload operators.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. operators.py +11 -1
operators.py CHANGED
@@ -29,6 +29,7 @@ from .operator import (
29
  StreamingOperator,
30
  StreamInitializerOperator,
31
  StreamInstanceOperator,
 
32
  )
33
  from .random_utils import random
34
  from .stream import MultiStream, Stream
@@ -48,6 +49,13 @@ class FromIterables(StreamInitializerOperator):
48
  return MultiStream.from_iterables(iterables)
49
 
50
 
 
 
 
 
 
 
 
51
  class MapInstanceValues(StreamInstanceOperator):
52
  """A class used to map instance values into a stream.
53
 
@@ -499,8 +507,10 @@ class ApplyOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin):
499
  """
500
 
501
  inputs_fields: str
 
502
  operators_field: str
503
  default_operators: List[str] = None
 
504
 
505
  def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
506
  operator_names = instance.get(self.operators_field)
@@ -517,7 +527,7 @@ class ApplyOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin):
517
  operator = self.get_artifact(name)
518
  for field in self.inputs_fields:
519
  value = instance[field]
520
- if isinstance(value, list):
521
  instance[field] = [operator.process(v) for v in value]
522
  else:
523
  instance[field] = operator.process(instance[field])
 
29
  StreamingOperator,
30
  StreamInitializerOperator,
31
  StreamInstanceOperator,
32
+ StreamSource,
33
  )
34
  from .random_utils import random
35
  from .stream import MultiStream, Stream
 
49
  return MultiStream.from_iterables(iterables)
50
 
51
 
52
+ class IterableSource(StreamSource):
53
+ iterables: Dict[str, Iterable]
54
+
55
+ def __call__(self) -> MultiStream:
56
+ return MultiStream.from_iterables(self.iterables)
57
+
58
+
59
  class MapInstanceValues(StreamInstanceOperator):
60
  """A class used to map instance values into a stream.
61
 
 
507
  """
508
 
509
  inputs_fields: str
510
+
511
  operators_field: str
512
  default_operators: List[str] = None
513
+ fields_to_treat_as_list: List[str] = NonPositionalField(default_factory=list)
514
 
515
  def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
516
  operator_names = instance.get(self.operators_field)
 
527
  operator = self.get_artifact(name)
528
  for field in self.inputs_fields:
529
  value = instance[field]
530
+ if field in self.fields_to_treat_as_list:
531
  instance[field] = [operator.process(v) for v in value]
532
  else:
533
  instance[field] = operator.process(instance[field])