|
import itertools |
|
import re |
|
from typing import Dict |
|
|
|
from .generator_utils import ReusableGenerator |
|
from .logging_utils import get_logger |
|
from .random_utils import new_random_generator |
|
from .stream import MissingStreamError, Stream |
|
|
|
logger = get_logger() |
|
|
|
|
|
def parse_random_mix_string(input_str):
    """Parse a mix string such as "source1[90%]+source2[0.7]+source3" into proportions.

    Args:
        input_str (str): Source names joined by "+". Each name may be followed by a
            bracketed proportion, written either as a percentage ("[90%]") or as a
            plain number ("[0.7]"). A name without brackets gets a proportion of 1.0.

    Returns:
        dict: Maps each source name to its proportion as a float. Percentages are
            divided by 100. If a name appears more than once, the last occurrence wins.

    Raises:
        ValueError: If the string does not follow the format above. This includes
            empty or digit-less proportions such as "a[]", "a[.]" or "a[%]".

    Example:
        >>> parse_random_mix_string("dale[90%]+oren[0.7]+mike")
        {'dale': 0.9, 'oren': 0.7, 'mike': 1.0}
    """
    # A proportion must contain at least one digit; otherwise float() below would
    # fail with an unrelated "could not convert string to float" message.
    number = r"(?:\d+\.?\d*|\.\d+)"
    term = rf"\w+(?:\[{number}%?\])?"
    if not re.fullmatch(rf"(?:{term}\+)*{term}", input_str):
        raise ValueError(f"Invalid input format for split '{input_str}'")

    proportions = {}
    for match in re.finditer(rf"(\w+)(?:\[({number})(%)?\])?", input_str):
        name, amount, percent_sign = match.groups()
        if amount is None:
            # No brackets at all: the whole source is taken.
            proportions[name] = 1.0
        elif percent_sign:
            proportions[name] = float(amount) / 100
        else:
            proportions[name] = float(amount)
    return proportions
|
|
|
|
|
def parse_slices_string(input_str):
    """Parse a slices string such as "source1[a:b] + source2[c:] + source3" into a dict.

    The result has the shape
    {"source1": [(a, b)], "source2": [(c, None)], "source3": [(None, None)]}.
    If a source appears several times, its (start, end) pairs are collected in order
    of appearance.

    Args:
        input_str (str): Source names joined by "+". Each name may be followed by a
            slice "[start:end]" in which either bound may be omitted. Whitespace
            around each "+"-separated term is ignored.

    Returns:
        dict: Maps each source name to a list of (start, end) tuples. A bound that was
            omitted, or a source given without brackets, is represented by None.

    Raises:
        ValueError: If any term is neither a bare name nor a name followed by
            "[start:end]" with non-negative integer bounds.

    Example:
        >>> parse_slices_string("oren[:50]+jake[24:]+test+oren[5:10]")
        {'oren': [(None, 50), (5, 10)], 'jake': [(24, None)], 'test': [(None, None)]}
    """
    sliced_source = re.compile(r"(\w+)\[(\d*):(\d*)\]")
    bare_source = re.compile(r"\w+")

    result_dict = {}
    for term in re.split(r"\+", input_str):
        # The documented format allows spaces around "+", so strip them off.
        source = term.strip()

        match = sliced_source.fullmatch(source)
        if match:
            name, start, end = match.groups()
            start = int(start) if start else None
            end = int(end) if end else None
        elif bare_source.fullmatch(source):
            # No brackets: the whole source is taken.
            name, start, end = source, None, None
        else:
            raise ValueError(
                f'The input string "{input_str}" is not in the correct format.'
            )

        result_dict.setdefault(name, []).append((start, end))

    return result_dict
|
|
|
|
|
def slice_stream(stream, start, end):
    """Yield the items of ``stream`` whose positions fall in ``[start, end)``.

    Args:
        stream: Any iterable.
        start (int or None): Index of the first item to yield; None means from the
            beginning.
        end (int or None): Index at which to stop (exclusive), counted from the
            beginning of ``stream``; None means until the stream is exhausted.

    Yields:
        The selected items, in their original order.
    """
    # A single islice keeps both bounds relative to the start of the stream.
    # Chaining islice(stream, start, None) with islice(..., end) would instead
    # yield ``end`` items after skipping ``start`` of them.
    yield from itertools.islice(stream, start, end)
|
|
|
|
|
|
|
def slice_streams(input_streams, mapping):
    """Build new streams by chaining slices taken from the input streams.

    Args:
        input_streams (dict): Maps the name of each existing stream to the stream
            itself.
        mapping (dict): Maps the name of each new stream to a dictionary that maps
            old stream names to lists of (start, end) slice tuples.

    Returns:
        dict: Maps each new stream name to a ReusableGenerator. Iterating it chains
            the requested slices in the order given by ``mapping``.

    Raises:
        MissingStreamError: Raised while a new stream is being iterated (not when
            this function is called) if it refers to a name absent from
            ``input_streams``.

    Example:
        With
            old_streams = {"train": [1, 2, 3, 4, 5, 6, 7, 8, 9], "test": [10, 11, 12, 13, 14]}
            mapping = {"new_train": {"train": [(None, 5), (7, 9)]}, "new_test": {"test": [(2, None)]}}
        iterating the streams returned by slice_streams(old_streams, mapping) gives
            new_train -> 1, 2, 3, 4, 5, 8, 9
            new_test  -> 12, 13, 14
    """

    def chained_slices(new_stream, sources):
        # ``new_stream`` is unused here; it is accepted because it is one of the
        # keys passed through gen_kwargs below.
        for source_name, ranges in sources.items():
            if source_name not in input_streams:
                raise MissingStreamError(
                    f"'{source_name}' is not available in input streams, but need to slice there from"
                )
            content = input_streams[source_name]
            for start, end in ranges:
                yield from slice_stream(content, start, end)

    return {
        target: ReusableGenerator(
            chained_slices, gen_kwargs={"new_stream": target, "sources": parts}
        )
        for target, parts in mapping.items()
    }
|
|
|
|
|
def build_stream_routing(mapping):
    """Invert a new-stream -> old-stream weight mapping into per-old-stream routing.

    For every old stream, the result lists the new streams that draw from it together
    with their weights. When the weights of an old stream add up to less than one, a
    null target (None) carrying the remaining weight is added. Totals above one are
    not checked or rejected here.

    Args:
        mapping (dict): Maps each new stream name to a dictionary of
            {old stream name: weight}.

    Returns:
        dict: Maps each old stream name to a tuple (targets, weights) of two parallel
            lists. Entries keep insertion order: new streams appear in the order
            they occur in ``mapping``, and the None entry, when present, comes right
            after the first new stream that uses this old stream.

    Example:
        >>> mapping = {
        ...     'my_new_stream': {'my_old_stream1': 0.6, 'my_old_stream2': 0.2},
        ...     'my_new_stream2': {'my_old_stream1': 0.4, 'my_old_stream2': 0.8},
        ... }
        >>> build_stream_routing(mapping)
        {'my_old_stream1': (['my_new_stream', 'my_new_stream2'], [0.6, 0.4]), 'my_old_stream2': (['my_new_stream', 'my_new_stream2'], [0.2, 0.8])}
    """
    # Pass 1: add up the weights requested from each old stream.
    weight_sums = {}
    for sources in mapping.values():
        for source_name, share in sources.items():
            if source_name in weight_sums:
                weight_sums[source_name] += share
            else:
                weight_sums[source_name] = share

    # Pass 2: group targets by old stream, adding the null target where needed.
    routing = {}
    for target_name, sources in mapping.items():
        for source_name, share in sources.items():
            targets = routing.setdefault(source_name, {})
            targets[target_name] = share

            # Re-assigning on later visits keeps the key's original position.
            if weight_sums[source_name] < 1:
                targets[None] = 1 - weight_sums[source_name]

    return {
        source_name: (list(targets), list(targets.values()))
        for source_name, targets in routing.items()
    }
|
|
|
|
|
def rename_split(input_streams: Dict[str, Stream], mapping: Dict[str, str]):
    """Return the input streams with some of them renamed.

    Args:
        input_streams (dict): Maps each stream name to its stream. This dictionary
            is not modified.
        mapping (dict): Maps each old stream name to the new name it should get.

    Returns:
        dict: A new dictionary holding the streams that were not renamed under their
            original names, followed by the renamed streams under their new names.
            If a new name collides with a stream that was not renamed, the renamed
            stream takes precedence.

    Raises:
        ValueError: If a key of ``mapping`` is not a stream name in ``input_streams``.
    """
    # Work on a shallow copy so the caller's dictionary is left untouched, even
    # when a later key turns out to be missing.
    remaining = dict(input_streams)
    renamed = {}
    for old_name, new_name in mapping.items():
        if old_name not in input_streams:
            raise ValueError(
                f"Stream '{old_name}' is not in input_streams '{input_streams.keys()}'"
            )
        renamed[new_name] = remaining.pop(old_name)
    return {**remaining, **renamed}
|
|
|
|
|
def random_mix_generator(
    new_stream_name, new_stream_sources, stream_routing, input_streams
):
    """Yield the items of the source streams that are randomly routed to one new stream.

    For every source stream, each item triggers one weighted draw among the candidate
    targets of that source; the item is yielded only when the draw equals
    ``new_stream_name``.

    Args:
        new_stream_name: Name of the new stream being generated.
        new_stream_sources: Iterable of old stream names feeding the new stream.
        stream_routing (dict): Maps each old stream name to a pair
            (candidate targets, weights), as built by build_stream_routing.
        input_streams (dict): Maps each old stream name to its stream.

    Yields:
        Items from the source streams, in source order, for which the draw selected
        ``new_stream_name``.

    Raises:
        KeyError: If a source name has no entry in ``stream_routing``.
        AssertionError: If a source name is absent from ``input_streams``.
    """
    for source_name in new_stream_sources:
        candidates, weights = stream_routing[source_name]
        # NOTE(review): the generator is seeded only by the source name, so presumably
        # every new stream replays the same draw sequence for a given source and each
        # item lands in at most one target. Confirm against new_random_generator.
        rng = new_random_generator(sub_seed=source_name)
        assert (
            source_name in input_streams
        ), f"'{source_name}' split not found. Possibles options: {input_streams.keys()}"
        for item in input_streams[source_name]:
            (drawn,) = rng.choices(candidates, weights=weights, k=1)
            if drawn != new_stream_name:
                continue
            yield item
|
|
|
|
|
def random_mix_streams(input_streams, mapping):
    """Create new streams by randomly routing items of the input streams.

    The routing table is computed once with build_stream_routing. Each new stream is
    then wrapped in a ReusableGenerator around random_mix_generator, so items are
    selected only when the new stream is iterated.

    Args:
        input_streams (dict): Maps each existing stream name to an iterable or
            generator representing the stream.
        mapping (dict): Maps each new stream name to a dictionary of
            {old stream name: weight}.

    Returns:
        dict: Maps each new stream name to a ReusableGenerator over its items.

    Example:
        input_streams = {
            'my_old_stream1': gen1(),
            'my_old_stream2': gen2(),
        }
        mapping = {
            'my_new_stream': {'my_old_stream1': 0.6, 'my_old_stream2': 0.2},
            'my_new_stream2': {'my_old_stream1': 0.4, 'my_old_stream2': 0.8},
        }
        new_streams = random_mix_streams(input_streams, mapping)
        for new_stream_name, new_stream in new_streams.items():
            logger.info(f"{new_stream_name}:")
            for _, item in zip(range(10), new_stream):
                logger.info(item)
    """
    routing = build_stream_routing(mapping)

    return {
        target_name: ReusableGenerator(
            random_mix_generator,
            gen_kwargs={
                "new_stream_name": target_name,
                "new_stream_sources": target_sources,
                "stream_routing": routing,
                "input_streams": input_streams,
            },
        )
        for target_name, target_sources in mapping.items()
    }
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke check: log the parsed form of a sample mix string.
    sample_mix = parse_random_mix_string("dale[90%]+oren[0.7]+mike")
    logger.info(sample_mix)
|
|