Elron committed on
Commit
b7c4274
1 Parent(s): 2636a15

Upload split_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. split_utils.py +285 -0
split_utils.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import random
3
+ import itertools
4
+
5
+ from .generator_utils import ReusableGenerator
6
+
7
+
8
def parse_random_mix_string(input_str):
    """
    Parse a random-mix specification such as "source1[90%]+source2[0.7]+source3".

    Args:
        input_str (str): Source names joined by "+". Each name may be followed by a
            bracketed proportion, either a percentage ("[90%]") or a plain number
            ("[0.7]"). A name without brackets gets a proportion of 1.0. Names may
            contain letters, digits and underscores (the same ``\\w+`` rule that
            parse_slices_string applies to its source names).

    Returns:
        dict: Maps each source name to its proportion as a float. Percentages are
            divided by 100. If a name appears more than once, the last proportion wins.

    Raises:
        ValueError: If the input string is not in the correct format. This includes
            brackets that hold no digits, e.g. "a[]", "a[%]" or "a[.]".

    Example:
        >>> parse_random_mix_string("dale[90%]+oren[0.7]+mike")
        {'dale': 0.9, 'oren': 0.7, 'mike': 1.0}
    """
    # Require at least one digit so the float() conversion below can never fail
    # with an unrelated "could not convert string to float" error.
    number = r"\d+\.?\d*|\.\d+"
    term_pattern = re.compile(rf"(\w+)(?:\[({number})(%?)\])?")

    proportions = {}
    for term in input_str.split("+"):
        match = term_pattern.fullmatch(term)
        if match is None:
            raise ValueError(f'Invalid input format: "{input_str}"')

        name, value, percent = match.groups()
        if value is None:
            # No bracket at all: take the whole source.
            proportions[name] = 1.0
        elif percent:
            proportions[name] = float(value) / 100
        else:
            proportions[name] = float(value)

    return proportions
40
+
41
+
42
def parse_slices_string(input_str):
    """
    Parse a string of format "source1[value1:value2] + source2[value2:] + source3 + ..."
    into a dictionary of slice bounds per source:
    {"source1": [(value1, value2)], "source2": [(value2, None)], "source3": [(None, None)], ...}

    If a source appears multiple times, all of its index pairs are collected in order.

    Args:
        input_str (str): Source names separated by "+". Each name may be followed by a
            slice of the form "[start:end]", where either bound may be omitted.
            Whitespace around each "+"-separated term is ignored.

    Returns:
        dict: Maps each source name to a list of (start, end) tuples. An omitted bound
            is represented as None; a source with no slice yields (None, None).

    Raises:
        ValueError: If the input string is not in the correct format.

    Example:
        >>> parse_slices_string("oren[:50]+jake[24:]+test+oren[5:10]")
        {'oren': [(None, 50), (5, 10)], 'jake': [(24, None)], 'test': [(None, None)]}
    """
    # A name with an optional "[start:end]" suffix; both bounds may be empty.
    term_pattern = re.compile(r"(\w+)(?:\[(\d*):(\d*)\])?")

    result_dict = {}
    for source in input_str.split("+"):
        # The documented format puts spaces around "+", so tolerate them.
        match = term_pattern.fullmatch(source.strip())
        if match is None:
            raise ValueError(f'The input string "{input_str}" is not in the correct format.')

        name, start, end = match.groups()
        # Empty or absent bounds (start/end is "" or None) both mean "open-ended".
        bounds = (int(start) if start else None, int(end) if end else None)
        result_dict.setdefault(name, []).append(bounds)

    return result_dict
91
+
92
+
93
def slice_stream(stream, start, end):
    """
    Lazily yield the items of ``stream`` whose positions fall in ``[start, end)``.

    Args:
        stream: Any iterable.
        start (int or None): Index of the first item to yield; None means start
            from the beginning.
        end (int or None): Index one past the last item to yield, counted from the
            beginning of ``stream`` (not from ``start``); None means run until the
            stream is exhausted.

    Yields:
        The selected items, in their original order.
    """
    # A single islice keeps ``end`` an absolute index. Chaining two islice calls
    # would turn ``end`` into a count of items taken after skipping ``start``.
    yield from itertools.islice(stream, start, end)
104
+
105
+
106
def slice_streams(input_streams, mapping):
    """
    Build new streams by chaining together slices of existing streams.

    Args:
        input_streams (dict): Maps each existing stream name to its stream object.
        mapping (dict): Maps each new stream name to a dictionary of
            {old stream name: [(start, end), ...]} describing which slices of which
            old streams make up the new stream, in order.

    Returns:
        dict: Maps each new stream name to a ReusableGenerator wrapping a generator
            that yields the requested slices one after another. Nothing is read
            from the input streams until the result is iterated.

    Note:
        Every slice iterates ``input_streams[name]`` afresh through slice_stream, so
        the inputs presumably need to be re-iterable (lists, ReusableGenerator, ...)
        rather than one-shot iterators -- verify against callers.

    Example:
        >>> old_streams = {"train": [1, 2, 3, 4, 5, 6, 7, 8, 9], "test": [10, 11, 12, 13, 14]}
        >>> mapping = {"new_train": {"train": [(None, 5)]}, "new_test": {"test": [(2, None)]}}
        >>> {name: list(stream) for name, stream in slice_streams(old_streams, mapping).items()}
        {'new_train': [1, 2, 3, 4, 5], 'new_test': [12, 13, 14]}
    """

    def generator(new_stream, sources):
        # ``new_stream`` is unused here; it is accepted because it is passed in
        # through gen_kwargs.
        for old_stream in sources:
            content = input_streams[old_stream]
            for lower, upper in sources[old_stream]:
                yield from slice_stream(content, lower, upper)

    return {
        new_stream: ReusableGenerator(
            generator, gen_kwargs={"new_stream": new_stream, "sources": sources}
        )
        for new_stream, sources in mapping.items()
    }
145
+
146
+
147
def build_stream_routing(mapping):
    """
    Invert a {new stream: {old stream: weight}} mapping into per-old-stream routing.

    For every old stream the result lists the new streams it feeds together with
    their weights. When the weights of an old stream add up to less than 1, a
    ``None`` destination carrying the remaining weight is added, representing items
    that go to no new stream. Totals above 1 are left as given; nothing is raised
    or rescaled here.

    Args:
        mapping (dict): Maps each new stream name to a dictionary of
            {old stream name: weight}.

    Returns:
        dict: Maps each old stream name to a tuple ``(destinations, weights)`` of two
            parallel lists.

    Example:
        >>> mapping = {
        ...     'my_new_stream': {'my_old_stream1': 0.6, 'my_old_stream2': 0.2},
        ...     'my_new_stream2': {'my_old_stream1': 0.4, 'my_old_stream2': 0.8},
        ... }
        >>> build_stream_routing(mapping)
        {'my_old_stream1': (['my_new_stream', 'my_new_stream2'], [0.6, 0.4]), 'my_old_stream2': (['my_new_stream', 'my_new_stream2'], [0.2, 0.8])}
    """
    # First pass: total weight requested from each old stream.
    totals = {}
    for sources in mapping.values():
        for origin, share in sources.items():
            totals[origin] = totals[origin] + share if origin in totals else share

    # Second pass: group destinations by old stream. The None entry is inserted the
    # first time an under-allocated old stream is seen, which fixes its position in
    # the destination list.
    routing = {}
    for target, sources in mapping.items():
        for origin, share in sources.items():
            destinations = routing.setdefault(origin, {})
            destinations[target] = share

            allocated = totals[origin]
            if allocated < 1:
                destinations[None] = 1 - allocated

    return {
        origin: (list(destinations), list(destinations.values()))
        for origin, destinations in routing.items()
    }
206
+
207
+
208
def random_mix_generator(new_stream_name, new_stream_sources, stream_routing, rand, input_streams):
    """
    Yield the items of the source streams that are randomly routed to one new stream.

    Each item of an old stream is assigned to one destination by a weighted random
    draw. The draws are seeded with the old stream's name, so every new stream's
    generator reproduces the same sequence of assignments for that old stream and
    keeps only the items assigned to itself.

    Args:
        new_stream_name: Name of the new stream this generator produces.
        new_stream_sources: Iterable of old stream names feeding the new stream.
        stream_routing (dict): Maps each old stream name to a
            ``(destinations, weights)`` tuple, as produced by build_stream_routing.
        rand: Kept for interface compatibility. It is no longer reseeded or drawn
            from, because one instance shared between several lazily consumed
            generators would have its state interleaved whenever those generators
            are iterated in alternation.
        input_streams (dict): Maps each old stream name to its iterable of items.

    Yields:
        Items from the old streams whose random draw selected ``new_stream_name``.
    """
    for old_stream_name in new_stream_sources:
        destinations, weights = stream_routing[old_stream_name]
        # A private generator seeded with the stream name gives the same draw
        # sequence as reseeding ``rand`` did, without sharing mutable state.
        stream_rand = random.Random(old_stream_name)

        for item in input_streams[old_stream_name]:
            choice = stream_rand.choices(destinations, weights=weights, k=1)[0]
            if choice == new_stream_name:
                yield item
217
+
218
+
219
def random_mix_streams(input_streams, mapping):
    """
    Create new streams by randomly routing items of the input streams.

    The routing table is derived from ``mapping`` with build_stream_routing, and each
    new stream is wrapped in a ReusableGenerator around random_mix_generator. Per the
    routing weights, an item is intended to end up in at most one new stream.

    Args:
        input_streams (dict): Maps each old stream name to an iterable of its items.
        mapping (dict): Maps each new stream name to a dictionary of
            {old stream name: probability}.

    Returns:
        dict: Maps each new stream name to a ReusableGenerator producing its items.

    Example:
        >>> input_streams = {'my_old_stream1': gen1(), 'my_old_stream2': gen2()}
        >>> mapping = {
        ...     'my_new_stream': {'my_old_stream1': 0.6, 'my_old_stream2': 0.2},
        ...     'my_new_stream2': {'my_old_stream1': 0.4, 'my_old_stream2': 0.8},
        ... }
        >>> new_streams = random_mix_streams(input_streams, mapping)
        >>> for new_stream_name, new_stream in new_streams.items():
        ...     print(f"{new_stream_name}:")
        ...     for _, item in zip(range(10), new_stream):
        ...         print(item)
    """
    routing_table = build_stream_routing(mapping)

    # One Random instance is handed to every generator.
    shared_rand = random.Random()

    return {
        target_name: ReusableGenerator(
            random_mix_generator,
            gen_kwargs={
                "new_stream_name": target_name,
                "new_stream_sources": target_sources,
                "stream_routing": routing_table,
                "rand": shared_rand,
                "input_streams": input_streams,
            },
        )
        for target_name, target_sources in mapping.items()
    }
282
+
283
+
284
if __name__ == "__main__":
    # Manual smoke check: parse the docstring example and show the result.
    demo_proportions = parse_random_mix_string("dale[90%]+oren[0.7]+mike")
    print(demo_proportions)