Elron commited on
Commit
e7c76e5
1 Parent(s): 3dfd311

Upload splitters.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. splitters.py +123 -0
splitters.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .stream import MultiStream
2
+ from .operator import MultiStreamOperator, InstanceOperatorWithGlobalAccess
3
+ from .generator_utils import ReusableGenerator
4
+ from .artifact import Artifact
5
+
6
+
7
+ from typing import Optional, Dict, List
8
+ from dataclasses import field
9
+
10
+
11
+ class Splitter(MultiStreamOperator):
12
+ pass
13
+
14
+
15
+ import random
16
+
17
+ from .split_utils import (
18
+ parse_random_mix_string,
19
+ random_mix_streams,
20
+ parse_slices_string,
21
+ slice_streams,
22
+ )
23
+
24
+
25
+ class SplitRandomMix(Splitter):
26
+ mix: Dict[str, str]
27
+
28
+ def process(self, multi_stream: MultiStream) -> MultiStream:
29
+ mapping = {k: parse_random_mix_string(v) for k, v in self.mix.items()}
30
+ generators = random_mix_streams(multi_stream, mapping)
31
+ return MultiStream.from_generators(generators, streaming=True)
32
+
33
+
34
+ class SliceSplit(Splitter):
35
+ slices: Dict[str, str]
36
+
37
+ def process(self, multi_stream: MultiStream) -> MultiStream:
38
+ mapping = {k: parse_slices_string(v) for k, v in self.slices.items()}
39
+ generators = slice_streams(multi_stream, mapping)
40
+ return MultiStream.from_generators(generators, streaming=True)
41
+
42
+
43
+ class Sampler(Artifact):
44
+ sample_size: int
45
+
46
+
47
+ class RandomSampler(Sampler):
48
+ def sample(self, instances_pool: List[Dict[str, object]]) -> List[Dict[str, object]]:
49
+ instances_pool = list(instances_pool)
50
+ return random.sample(instances_pool, self.sample_size)
51
+
52
+
53
+ class SpreadSplit(InstanceOperatorWithGlobalAccess):
54
+ source_stream: str = None
55
+ target_field: str = None
56
+ sampler: Sampler = None
57
+
58
+ def prepare(self):
59
+ self.accessible_streams = [self.source_stream]
60
+ self.cache_accessible_streams = True
61
+ self.local_cache = None
62
+
63
+ def verify(self):
64
+ assert self.source_stream is not None, "Source stream must be specified"
65
+ assert self.target_field is not None, "Target field must be specified"
66
+ assert self.sampler is not None, "Sampler must be specified"
67
+ return super().verify()
68
+
69
+ def process(self, instance: Dict[str, object], multi_stream: MultiStream) -> Dict[str, object]:
70
+ if self.local_cache is None:
71
+ self.local_cache = list(multi_stream[self.source_stream])
72
+
73
+ source_stream = self.local_cache
74
+
75
+ sampled_instances = self.sampler.sample(source_stream)
76
+ instance[self.target_field] = sampled_instances
77
+ return instance
78
+
79
+
80
+ if __name__ == "__main__":
81
+ # some tests
82
+ import random
83
+
84
+ random.seed(0)
85
+ splitter = SplitRandomMix(
86
+ mix={
87
+ "train": "train[90%]+validation[50%]",
88
+ "validation": "train[10%]+validation[50%]",
89
+ "test": "test",
90
+ }
91
+ )
92
+
93
+ def generator(name, size):
94
+ for i in range(size):
95
+ yield {"text": f"{name}_{i}"}
96
+
97
+ stream = MultiStream.from_generators(
98
+ {
99
+ "train": ReusableGenerator(generator, gen_kwargs={"name": "train", "size": 10}),
100
+ "validation": ReusableGenerator(generator, gen_kwargs={"name": "validation", "size": 10}),
101
+ "test": ReusableGenerator(generator, gen_kwargs={"name": "test", "size": 10}),
102
+ }
103
+ )
104
+
105
+ ds = splitter(stream)
106
+ for key, value in ds.items():
107
+ print(key)
108
+ for item in value:
109
+ print(item)
110
+
111
+ splitter = SliceSplit(
112
+ slices={
113
+ "train": "train[:2]+train[2:4]",
114
+ "validation": "train[4:6]",
115
+ "test": "train[6:]+test",
116
+ }
117
+ )
118
+
119
+ ds = splitter(stream)
120
+ for key, value in ds.items():
121
+ print(key)
122
+ for item in value:
123
+ print(item)