metric / stream.py
Elron's picture
Upload stream.py with huggingface_hub
fa27fcf verified
raw
history blame
6.04 kB
import tempfile
from typing import Dict, Iterable
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from .dataclass import Dataclass, OptionalField
from .generator_utils import CopyingReusableGenerator, ReusableGenerator
class Stream(Dataclass):
"""A class for handling streaming data in a customizable way.
This class provides methods for generating, caching, and manipulating streaming data.
Attributes:
generator (function): A generator function for streaming data. :no-index:
gen_kwargs (dict, optional): A dictionary of keyword arguments for the generator function. :no-index:
caching (bool): Whether the data is cached or not. :no-index:
"""
generator: callable
gen_kwargs: Dict[str, any] = OptionalField(default_factory=dict)
caching: bool = False
copying: bool = False
def _get_initiator(self):
"""Private method to get the correct initiator based on the streaming and caching attributes.
Returns:
function: The correct initiator function.
"""
if self.caching:
return Dataset.from_generator
if self.copying:
return CopyingReusableGenerator
return ReusableGenerator
def _get_stream(self):
"""Private method to get the stream based on the initiator function.
Returns:
object: The stream object.
"""
return self._get_initiator()(self.generator, gen_kwargs=self.gen_kwargs)
def __iter__(self):
return iter(self._get_stream())
def peek(self):
return next(iter(self))
def take(self, n):
for i, instance in enumerate(self):
if i >= n:
break
yield instance
class MultiStream(dict):
"""A class for handling multiple streams of data in a dictionary-like format.
This class extends dict and its values should be instances of the Stream class.
Attributes:
data (dict): A dictionary of Stream objects.
"""
def __init__(self, data=None):
"""Initializes the MultiStream with the provided data.
Args:
data (dict, optional): A dictionary of Stream objects. Defaults to None.
Raises:
AssertionError: If the values are not instances of Stream or keys are not strings.
"""
for key, value in data.items():
isinstance(value, Stream), "MultiStream values must be Stream"
isinstance(key, str), "MultiStream keys must be strings"
super().__init__(data)
def get_generator(self, key):
"""Gets a generator for a specified key.
Args:
key (str): The key for the generator.
Yields:
object: The next value in the stream.
"""
yield from self[key]
def set_caching(self, caching: bool):
for stream in self.values():
stream.caching = caching
def set_copying(self, copying: bool):
for stream in self.values():
stream.copying = copying
def to_dataset(self, disable_cache=True, cache_dir=None) -> DatasetDict:
with tempfile.TemporaryDirectory() as dir_to_be_deleted:
cache_dir = dir_to_be_deleted if disable_cache else cache_dir
return DatasetDict(
{
key: Dataset.from_generator(
self.get_generator,
keep_in_memory=disable_cache,
cache_dir=cache_dir,
gen_kwargs={"key": key},
)
for key in self.keys()
}
)
def to_iterable_dataset(self) -> IterableDatasetDict:
return IterableDatasetDict(
{
key: IterableDataset.from_generator(
self.get_generator, gen_kwargs={"key": key}
)
for key in self.keys()
}
)
def __setitem__(self, key, value):
assert isinstance(value, Stream), "StreamDict values must be Stream"
assert isinstance(key, str), "StreamDict keys must be strings"
super().__setitem__(key, value)
@classmethod
def from_generators(
cls, generators: Dict[str, ReusableGenerator], caching=False, copying=False
):
"""Creates a MultiStream from a dictionary of ReusableGenerators.
Args:
generators (Dict[str, ReusableGenerator]): A dictionary of ReusableGenerators.
caching (bool, optional): Whether the data should be cached or not. Defaults to False.
copying (bool, optional): Whether the data should be copied or not. Defaults to False.
Returns:
MultiStream: A MultiStream object.
"""
assert all(isinstance(v, ReusableGenerator) for v in generators.values())
return cls(
{
key: Stream(
generator.generator,
gen_kwargs=generator.gen_kwargs,
caching=caching,
copying=copying,
)
for key, generator in generators.items()
}
)
@classmethod
def from_iterables(
cls, iterables: Dict[str, Iterable], caching=False, copying=False
):
"""Creates a MultiStream from a dictionary of iterables.
Args:
iterables (Dict[str, Iterable]): A dictionary of iterables.
caching (bool, optional): Whether the data should be cached or not. Defaults to False.
copying (bool, optional): Whether the data should be copied or not. Defaults to False.
Returns:
MultiStream: A MultiStream object.
"""
return cls(
{
key: Stream(
iterable.__iter__,
caching=caching,
copying=copying,
)
for key, iterable in iterables.items()
}
)