Elron commited on
Commit
341b917
1 Parent(s): 70d2374

Upload operators.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. operators.py +64 -1
operators.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import importlib
2
  import inspect
3
  import uuid
@@ -18,7 +19,7 @@ from typing import (
18
  )
19
 
20
  from .artifact import Artifact, fetch_artifact
21
- from .dataclass import NonPositionalField
22
  from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
23
  from .operator import (
24
  MultiStream,
@@ -734,3 +735,65 @@ class EncodeLabels(StreamInstanceOperator):
734
  dict_set(instance, field, new_values, use_dpath=True, set_multiple=True)
735
 
736
  return instance
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
  import importlib
3
  import inspect
4
  import uuid
 
19
  )
20
 
21
  from .artifact import Artifact, fetch_artifact
22
+ from .dataclass import NonPositionalField, OptionalField
23
  from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
24
  from .operator import (
25
  MultiStream,
 
735
  dict_set(instance, field, new_values, use_dpath=True, set_multiple=True)
736
 
737
  return instance
738
+
739
+
740
+ class StreamRefiner(SingleStreamOperator):
741
+ max_instances: int = None
742
+
743
+ def process(self, stream: Stream, stream_name: str = None) -> Generator:
744
+ if self.max_instances is not None:
745
+ yield from stream.take(self.max_instances)
746
+ else:
747
+ yield from stream
748
+
749
+
750
+ class DeterministicBalancer(StreamRefiner):
751
+ """
752
+ A class used to balance streams deterministically.
753
+
754
+ Attributes:
755
+ fields (List[str]): A list of field names to be used in determining the signature of an instance.
756
+ streams (List[str]): A list of stream names to be processed by the balancer.
757
+
758
+ Usage:
759
+ balancer = DeterministicBalancer(fields=["field1", "field2"], streams=["stream1", "stream2"])
760
+ balanced_stream = balancer.process(stream)
761
+ """
762
+
763
+ fields: List[str]
764
+
765
+ def signature(self, instance):
766
+ return str(tuple(dict_get(instance, field, use_dpath=True) for field in self.fields))
767
+
768
+ def process(self, stream: Stream, stream_name: str = None) -> Generator:
769
+ counter = collections.Counter()
770
+
771
+ for instance in stream:
772
+ counter[self.signature(instance)] += 1
773
+
774
+ lowest_count = counter.most_common()[-1][-1]
775
+
776
+ max_total_instances_per_sign = lowest_count
777
+ if self.max_instances is not None:
778
+ max_total_instances_per_sign = min(lowest_count, self.max_instances // len(counter))
779
+
780
+ counter = collections.Counter()
781
+
782
+ for instance in stream:
783
+ sign = self.signature(instance)
784
+ if counter[sign] < max_total_instances_per_sign:
785
+ counter[sign] += 1
786
+ yield instance
787
+
788
+
789
+ class LengthBalancer(DeterministicBalancer):
790
+ segments_boundaries: List[int]
791
+
792
+ def signature(self, instance):
793
+ total_len = 0
794
+ for field in self.fields:
795
+ total_len += len(dict_get(instance, field, use_dpath=True))
796
+ for i, val in enumerate(self.segments_boundaries):
797
+ if total_len < val:
798
+ return i
799
+ return i + 1