Elron committed on
Commit
40135fa
1 Parent(s): cb669f3

Upload split_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. split_utils.py +39 -24
split_utils.py CHANGED
@@ -1,8 +1,10 @@
1
  import itertools
2
- import random
3
  import re
 
4
 
5
  from .generator_utils import ReusableGenerator
 
 
6
 
7
 
8
  def parse_random_mix_string(input_str):
@@ -67,7 +69,7 @@ def parse_slices_string(input_str):
67
  result_dict = {}
68
 
69
  # Split the input string into a list of sources
70
- sources = re.split("\+", input_str)
71
  for source in sources:
72
  # If the source has a slice, parse it
73
  match = re.fullmatch(r"(\w+)\[(\d*):(\d*)\]", source)
@@ -119,7 +121,7 @@ def slice_streams(input_streams, mapping):
119
  the new streams, which consist of parts of the old streams chained together.
120
 
121
  Raises:
122
- ValueError: If a stream is supposed to be sliced at an index greater than its length.
123
 
124
  Example:
125
  >>> old_streams = {"train": [1, 2, 3, 4, 5, 6, 7, 8, 9], "test": [10, 11, 12, 13, 14]}
@@ -205,15 +207,30 @@ def build_stream_routing(mapping):
205
  return stream_mapping
206
 
207
 
208
- def random_mix_generator(new_stream_name, new_stream_sources, stream_routing, rand, input_streams):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  for old_stream_name in new_stream_sources:
210
  optinal_streams, weights = stream_routing[old_stream_name]
211
- rand.seed(old_stream_name)
212
-
213
- for item in input_streams[old_stream_name]:
214
- choice = rand.choices(optinal_streams, weights=weights, k=1)[0]
215
- if choice == new_stream_name:
216
- yield item
217
 
218
 
219
  def random_mix_streams(input_streams, mapping):
@@ -263,20 +280,18 @@ def random_mix_streams(input_streams, mapping):
263
  # Build stream routing
264
  stream_routing = build_stream_routing(mapping)
265
 
266
- rand = random.Random()
267
-
268
- # Create new stream generators
269
- for new_stream_name, new_stream_sources in mapping.items():
270
- new_streams[new_stream_name] = ReusableGenerator(
271
- random_mix_generator,
272
- gen_kwargs={
273
- "new_stream_name": new_stream_name,
274
- "new_stream_sources": new_stream_sources,
275
- "stream_routing": stream_routing,
276
- "rand": rand,
277
- "input_streams": input_streams,
278
- },
279
- )
280
 
281
  return new_streams
282
 
 
1
  import itertools
 
2
  import re
3
+ from typing import Dict
4
 
5
  from .generator_utils import ReusableGenerator
6
+ from .random_utils import nested_seed
7
+ from .stream import Stream
8
 
9
 
10
  def parse_random_mix_string(input_str):
 
69
  result_dict = {}
70
 
71
  # Split the input string into a list of sources
72
+ sources = re.split(r"\+", input_str)
73
  for source in sources:
74
  # If the source has a slice, parse it
75
  match = re.fullmatch(r"(\w+)\[(\d*):(\d*)\]", source)
 
121
  the new streams, which consist of parts of the old streams chained together.
122
 
123
  Raises:
124
+ ValueError: If a stream is supposed to be sliced at an index greater than its length or a negative one.
125
 
126
  Example:
127
  >>> old_streams = {"train": [1, 2, 3, 4, 5, 6, 7, 8, 9], "test": [10, 11, 12, 13, 14]}
 
207
  return stream_mapping
208
 
209
 
210
+ def rename_split(input_streams: Dict[str, Stream], mapping: Dict[str, str]):
211
+ """
212
+ Renames the streams
213
+ Args:
214
+ input_streams (dict): A dictionary containing the input streams, where each key is
215
+ the name of the stream and the value is an iterable or generator
216
+ representing the stream.
217
+
218
+ mapping (dict): A dictionary specifying the mapping of old streams to new streams.
219
+
220
+ Returns:
221
+ dict: A dictionary containing the generated new streams, where each key is the name
222
+ of the new stream and the value is a generator representing the stream."""
223
+ return {mapping.get(key, key): val for key, val in input_streams.items()}
224
+
225
+
226
+ def random_mix_generator(new_stream_name, new_stream_sources, stream_routing, input_streams):
227
  for old_stream_name in new_stream_sources:
228
  optinal_streams, weights = stream_routing[old_stream_name]
229
+ with nested_seed(old_stream_name) as rand:
230
+ for item in input_streams[old_stream_name]:
231
+ choice = rand.choices(optinal_streams, weights=weights, k=1)[0]
232
+ if choice == new_stream_name:
233
+ yield item
 
234
 
235
 
236
  def random_mix_streams(input_streams, mapping):
 
280
  # Build stream routing
281
  stream_routing = build_stream_routing(mapping)
282
 
283
+ with nested_seed():
284
+ # Create new stream generators
285
+ for new_stream_name, new_stream_sources in mapping.items():
286
+ new_streams[new_stream_name] = ReusableGenerator(
287
+ random_mix_generator,
288
+ gen_kwargs={
289
+ "new_stream_name": new_stream_name,
290
+ "new_stream_sources": new_stream_sources,
291
+ "stream_routing": stream_routing,
292
+ "input_streams": input_streams,
293
+ },
294
+ )
 
 
295
 
296
  return new_streams
297