# Source code for datasets.iterable_dataset

import copy
from dataclasses import dataclass
from itertools import cycle, islice, repeat
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Union

import numpy as np
import pyarrow as pa

from .arrow_dataset import DatasetInfoMixin
from .features import Features
from .formatting import PythonFormatter
from .info import DatasetInfo
from .splits import NamedSplit


def _infer_features_from_batch(batch: Dict[str, list], try_features: Optional[Features] = None) -> Features:
    """Infer the :class:`Features` of a batch from its Arrow schema.

    The batch is loaded into a ``pyarrow.Table``. When ``try_features`` is given, the table
    is additionally cast to the schema built from ``try_features.type``; if that cast raises
    ``ArrowInvalid`` or ``ArrowNotImplementedError`` the schema inferred by pyarrow is kept.

    Args:
        batch: mapping from column name to the list of values of that column.
        try_features: optional features to try to cast the batch to.

    Returns:
        The features built from the resulting Arrow schema.
    """
    table = pa.Table.from_pydict(batch)
    schema = table.schema
    if try_features is not None:
        try:
            schema = table.cast(pa.schema(try_features.type)).schema
        except (pa.ArrowInvalid, pa.ArrowNotImplementedError):
            # best-effort cast: keep the schema pyarrow inferred on its own
            pass
    return Features.from_arrow_schema(schema)


def _examples_to_batch(examples: List[Dict[str, Any]]) -> Dict[str, list]:
    cols = sorted(examples[0].keys())
    arrays = []
    for col in cols:
        arrays.append([example[col] for example in examples])
    return dict(zip(cols, arrays))


def _batch_to_examples(batch: Dict[str, list]) -> List[Dict[str, Any]]:
    """Convert a batch (dict of examples) to examples list"""
    n_examples = len(batch[next(iter(batch))])
    for i in range(n_examples):
        yield {col: array[i] for col, array in batch.items()}


class _BaseExamplesIterable:
    """Base class for the examples iterable used by an IterableDataset"""

    def __iter__(self):
        """An examples iterable should yield tuples (example_key, example) of type (int/str, dict)"""
        raise NotImplementedError()

    def shuffle_data_sources(self, seed: Optional[int]) -> "_BaseExamplesIterable":
        """
        Either shuffle the shards/sources of the dataset, or propagate the shuffling to the underlying iterable.
        If the order of the shards must stay fixed (when using .skip or .take for example), then this method returns self.
        """
        raise NotImplementedError()

    @property
    def n_shards(self) -> int:
        raise NotImplementedError()


def _shuffle_kwargs(rng: np.random.Generator, kwargs: dict) -> dict:
    shuffled_kwargs = {}
    for key, value in sorted(kwargs.items()):
        if isinstance(value, list):
            value = list(value)
            rng.shuffle(value)
            shuffled_kwargs[key] = value
        else:
            shuffled_kwargs[key] = value
    return shuffled_kwargs


class ExamplesIterable(_BaseExamplesIterable):
    """Examples iterable backed by a generator function and the kwargs to call it with."""

    def __init__(self, generate_examples_fn: Callable, kwargs: dict):
        self.generate_examples_fn = generate_examples_fn
        self.kwargs = kwargs

    def __iter__(self):
        """Call the generator function with the stored kwargs and yield its (key, example) pairs."""
        generated = self.generate_examples_fn(**self.kwargs)
        for example_key, example in generated:
            yield example_key, example

    def shuffle_data_sources(self, seed: Optional[int]) -> "ExamplesIterable":
        """Return a ShardShuffledExamplesIterable over the same function and kwargs, using ``seed``."""
        return ShardShuffledExamplesIterable(self.generate_examples_fn, self.kwargs, seed)

    @property
    def n_shards(self) -> int:
        """Length of the longest list among the kwargs values, and at least 1."""
        list_lengths = [len(value) for value in self.kwargs.values() if isinstance(value, list)]
        longest = max(list_lengths) if list_lengths else 0
        return longest if longest > 1 else 1


class ShardShuffledExamplesIterable(ExamplesIterable):
    """ExamplesIterable that shuffles the list-valued kwargs before generating examples."""

    def __init__(self, generate_examples_fn: Callable, kwargs: dict, seed: Optional[int]):
        super().__init__(generate_examples_fn, kwargs)
        self.seed = seed

    def __iter__(self):
        """Shuffle the kwargs order to shuffle shards, then yield the generated (key, example) pairs."""
        # a new generator is seeded on every iteration
        shuffled_kwargs = _shuffle_kwargs(np.random.default_rng(self.seed), self.kwargs)
        generated = self.generate_examples_fn(**shuffled_kwargs)
        for example_key, example in generated:
            yield example_key, example


class CyclingMultiSourcesExamplesIterable(_BaseExamplesIterable):
    """Examples iterable that picks one example from each source in turn (round-robin)."""

    def __init__(self, ex_iterables: List[_BaseExamplesIterable]):
        self.ex_iterables = ex_iterables

    def __iter__(self):
        """Alternate between the sources; stop as soon as the source whose turn it is is exhausted."""
        sources = list(map(iter, self.ex_iterables))
        # endless round-robin over the positions of the sources
        for position in cycle(range(len(sources))):
            try:
                yield next(sources[position])
            except StopIteration:
                # one source ran out of examples: the whole iteration ends
                return

    def shuffle_data_sources(self, seed: Optional[int]) -> "CyclingMultiSourcesExamplesIterable":
        """Shuffle each underlying examples iterable."""
        shuffled_sources = []
        for source in self.ex_iterables:
            shuffled_sources.append(source.shuffle_data_sources(seed))
        return CyclingMultiSourcesExamplesIterable(shuffled_sources)

    @property
    def n_shards(self) -> int:
        """Total number of shards over all the sources."""
        total = 0
        for source in self.ex_iterables:
            total += source.n_shards
        return total


class RandomlyCyclingMultiSourcesExamplesIterable(CyclingMultiSourcesExamplesIterable):
    """Examples iterable that picks each example from a randomly chosen source."""

    def __init__(self, ex_iterables, seed: Optional[int] = None, probabilities: Optional[List[float]] = None):
        super().__init__(ex_iterables)
        self.seed = seed
        self.probabilities = probabilities

    @staticmethod
    def _iter_random_indices(
        rng: np.random.Generator,
        num_sources: int,
        random_batch_size=1000,
        p: Optional[List[float]] = None,
    ) -> Iterator[int]:
        """Get an infinite iterator that randomly samples the index of the source to pick examples from.

        Indices are drawn ``random_batch_size`` at a time: with ``rng.integers`` when ``p`` is None,
        otherwise with ``rng.choice`` using ``p`` as the probabilities.
        """
        while True:
            if p is None:
                drawn = rng.integers(0, num_sources, size=random_batch_size)
            else:
                drawn = rng.choice(num_sources, size=random_batch_size, p=p)
            for index in drawn:
                yield int(index)

    def __iter__(self):
        """Yield examples from randomly drawn sources; stop as soon as a drawn source is exhausted."""
        rng = np.random.default_rng(self.seed)
        sources = list(map(iter, self.ex_iterables))
        # endless stream of randomly sampled source positions
        for position in self._iter_random_indices(rng, len(sources), p=self.probabilities):
            try:
                yield next(sources[position])
            except StopIteration:
                # the drawn source ran out of examples: the whole iteration ends
                return

    def shuffle_data_sources(self, seed: Optional[int]) -> "RandomlyCyclingMultiSourcesExamplesIterable":
        """Shuffle the data sources of each wrapped examples iterable."""
        shuffled_sources = []
        for source in self.ex_iterables:
            shuffled_sources.append(source.shuffle_data_sources(seed))
        return RandomlyCyclingMultiSourcesExamplesIterable(
            shuffled_sources, seed=seed, probabilities=self.probabilities
        )


class MappedExamplesIterable(_BaseExamplesIterable):
    """Examples iterable that applies ``function`` on the fly to the examples of a wrapped iterable."""

    def __init__(
        self, ex_iterable: _BaseExamplesIterable, function: Callable, batched: bool = False, batch_size: int = 1000
    ):
        self.ex_iterable = ex_iterable
        self.function = function
        self.batched = batched
        self.batch_size = batch_size

    def __iter__(self):
        """Yield the transformed examples.

        Not batched: each example is passed to ``function`` and yielded under its original key.
        Batched: up to ``batch_size`` examples are gathered into a batch, ``function`` is applied
        to the batch, and every example of the returned batch is yielded under a key made of the
        keys of the input examples joined by "_".
        """
        source = iter(self.ex_iterable)
        for first_key, first_example in source:
            if not self.batched:
                yield first_key, self.function(first_example)
                continue
            # complete the batch with the next examples of the same iterator
            pairs = [(first_key, first_example)]
            pairs.extend((next_key, next_example) for next_key, next_example in islice(source, self.batch_size - 1))
            batch_keys, batch_examples = zip(*pairs)
            transformed_batch = self.function(_examples_to_batch(batch_examples))
            joined_key = "_".join(map(str, batch_keys))
            for transformed_example in _batch_to_examples(transformed_batch):
                yield joined_key, transformed_example

    def shuffle_data_sources(self, seed: Optional[int]) -> "MappedExamplesIterable":
        """Return the same mapping applied on top of the shuffled wrapped iterable."""
        shuffled_source = self.ex_iterable.shuffle_data_sources(seed)
        return MappedExamplesIterable(
            shuffled_source, function=self.function, batched=self.batched, batch_size=self.batch_size
        )

    @property
    def n_shards(self) -> int:
        """Number of shards of the wrapped iterable."""
        return self.ex_iterable.n_shards


class BufferShuffledExamplesIterable(_BaseExamplesIterable):
    """Examples iterable that shuffles a wrapped iterable through an in-memory buffer."""

    def __init__(self, ex_iterable: _BaseExamplesIterable, buffer_size: int, seed: Optional[int]):
        self.ex_iterable = ex_iterable
        self.buffer_size = buffer_size
        self.seed = seed

    @staticmethod
    def _iter_random_indices(rng: np.random.Generator, buffer_size: int, random_batch_size=1000) -> Iterator[int]:
        """Endlessly yield random positions in ``[0, buffer_size)``, drawn ``random_batch_size`` at a time."""
        while True:
            for index in rng.integers(0, buffer_size, size=random_batch_size):
                yield int(index)

    def __iter__(self):
        """Yield the examples of the wrapped iterable in buffer-shuffled order.

        The buffer is first filled with ``buffer_size`` examples. Then every incoming example
        takes the place of a randomly picked buffered example, which is yielded. When the wrapped
        iterable is exhausted, the examples left in the buffer are shuffled and yielded.
        """
        capacity = self.buffer_size
        rng = np.random.default_rng(self.seed)
        random_slots = self._iter_random_indices(rng, capacity)
        # the shuffle buffer kept in memory
        buffered = []
        for incoming in self.ex_iterable:
            if len(buffered) != capacity:
                # still filling the buffer
                buffered.append(incoming)
                continue
            # buffer is full: emit a random example and put the new one in its slot
            slot = next(random_slots)
            yield buffered[slot]
            buffered[slot] = incoming
        # no more incoming examples: flush the buffer in random order
        rng.shuffle(buffered)
        yield from buffered

    def shuffle_data_sources(self, seed: Optional[int]) -> "BufferShuffledExamplesIterable":
        """Shuffle the wrapped examples iterable as well as the shuffling buffer."""
        shuffled_source = self.ex_iterable.shuffle_data_sources(seed)
        return BufferShuffledExamplesIterable(shuffled_source, buffer_size=self.buffer_size, seed=seed)

    @property
    def n_shards(self) -> int:
        """Number of shards of the wrapped iterable."""
        return self.ex_iterable.n_shards


class SkipExamplesIterable(_BaseExamplesIterable):
    """Examples iterable that drops the first ``n`` examples of a wrapped iterable."""

    def __init__(self, ex_iterable: _BaseExamplesIterable, n: int):
        self.ex_iterable = ex_iterable
        self.n = n

    def __iter__(self):
        """Consume the first ``n`` examples without yielding them, then yield all the others."""
        remaining = iter(self.ex_iterable)
        skipped = islice(remaining, self.n)
        for _skipped_example in skipped:
            pass
        yield from remaining

    def shuffle_data_sources(self, seed: Optional[int]) -> "SkipExamplesIterable":
        """Doesn't shuffle the wrapped examples iterable since it would skip examples from other shards instead."""
        return self

    @property
    def n_shards(self) -> int:
        """Number of shards of the wrapped iterable."""
        return self.ex_iterable.n_shards


class TakeExamplesIterable(_BaseExamplesIterable):
    """Examples iterable that keeps only the first ``n`` examples of a wrapped iterable."""

    def __init__(self, ex_iterable: _BaseExamplesIterable, n: int):
        self.ex_iterable = ex_iterable
        self.n = n

    def __iter__(self):
        """Yield at most the first ``n`` examples of the wrapped iterable."""
        first_examples = islice(self.ex_iterable, self.n)
        yield from first_examples

    def shuffle_data_sources(self, seed: Optional[int]) -> "TakeExamplesIterable":
        """Doesn't shuffle the wrapped examples iterable since it would take examples from other shards instead."""
        return self

    @property
    def n_shards(self) -> int:
        """Number of shards of the wrapped iterable."""
        return self.ex_iterable.n_shards


def _generate_examples_from_tables_wrapper(generate_tables_fn):
    """Turn a function that generates (key, table) pairs into one that generates (key, example) pairs.

    Every table is formatted as a batch with a ``PythonFormatter`` and split into examples.
    The key of an example is the key of its table followed by "_" and the position of the
    example in the table.
    """

    def wrapper(**kwargs):
        formatter = PythonFormatter()
        for table_key, table in generate_tables_fn(**kwargs):
            rows = _batch_to_examples(formatter.format_batch(table))
            for row_idx, row in enumerate(rows):
                yield f"{table_key}_{row_idx}", row

    return wrapper


@dataclass
class ShufflingConfig:
    """Shuffling settings stored on an IterableDataset once shuffling has been requested."""

    # Base random seed. IterableDataset adds its current epoch to it when it is not None,
    # and passes None along unchanged otherwise.
    seed: Optional[int] = None


class IterableDataset(DatasetInfoMixin):
    """A Dataset backed by an iterable."""

    def __init__(
        self,
        ex_iterable: _BaseExamplesIterable,
        info: Optional[DatasetInfo] = None,
        split: Optional[NamedSplit] = None,
        format_type: Optional[str] = None,
        shuffling: Optional[ShufflingConfig] = None,
    ):
        """
        Args:
            ex_iterable (:obj:`_BaseExamplesIterable`): iterable of (key, example) pairs backing the dataset.
            info (:obj:`DatasetInfo`, optional): dataset info; it is copied, and a new empty one is used if None.
            split (:obj:`NamedSplit`, optional): split passed on to ``DatasetInfoMixin``.
            format_type (:obj:`str`, optional): format of the dataset, see :func:`with_format`.
            shuffling (:obj:`ShufflingConfig`, optional): if set, the data sources are shuffled when iterating.
        """
        info = info.copy() if info is not None else DatasetInfo()
        DatasetInfoMixin.__init__(self, info=info, split=split)

        self._ex_iterable = ex_iterable
        self._format_type = format_type
        self._shuffling = shuffling
        self._epoch = 0

    def _head(self, n=5):
        """Return the first ``n`` examples (before feature encoding/decoding) as a batch."""
        return _examples_to_batch([example for _, example in islice(self._iter(), n)])

    @property
    def _effective_seed(self):
        """Seed used to shuffle the data sources: the shuffling seed plus the current epoch.

        It is None when shuffling is not enabled or when no seed was given.
        """
        if self._shuffling:
            return self._shuffling.seed + self._epoch if self._shuffling.seed is not None else None
        else:
            return None

    @property
    def n_shards(self) -> int:
        """Number of shards of the underlying examples iterable."""
        return self._ex_iterable.n_shards

    def _iter(self):
        """Yield the raw (key, example) pairs, shuffling the data sources first if shuffling is enabled."""
        if self._shuffling:
            ex_iterable = self._ex_iterable.shuffle_data_sources(self._effective_seed)
        else:
            ex_iterable = self._ex_iterable
        yield from ex_iterable

    def __iter__(self):
        """Yield the examples (without their keys), encoded then decoded with the features if there are any."""
        for key, example in self._iter():
            if self.features:
                # we encode the example for ClassLabel feature types for example
                encoded_example = self.features.encode_example(example)
                # Decode example for Audio feature, e.g.
                decoded_example = self.features.decode_example(encoded_example)
                yield decoded_example
            else:
                yield example

    def with_format(
        self,
        type: Optional[str] = None,
    ) -> "IterableDataset":
        """
        Return a dataset with the specified format.
        This method only supports the "torch" format for now.

        Args:
            type (:obj:`str`, optional, default None): if set to "torch", the returned dataset
                will be a subclass of torch.utils.data.IterableDataset to be used in a DataLoader
        """
        # TODO(QL): add examples formatting to get tensors when using the "torch" format
        # TODO(QL): add format_kwargs
        # TODO(QL): add format_columns and return_all_columns
        # TODO(QL): add pandas, numpy and tf formats
        return iterable_dataset(
            ex_iterable=self._ex_iterable,
            info=copy.deepcopy(self._info),
            split=self._split,
            format_type=type,
            shuffling=copy.deepcopy(self._shuffling),
        )

    def map(self, function: Callable, batched: bool = False, batch_size: int = 1000):
        """
        Return a dataset with the specified map function. The function is applied on-the-fly on the examples
        when iterating over the dataset. The features of the returned dataset are reset to None.

        You can specify whether the function should be batched or not with the ``batched`` parameter:

        - If batched is False, then the function takes 1 example in and should return 1 example.
          An example is a dictionary, e.g. {"text": "Hello there !"}
        - If batched is True and batch_size is 1, then the function takes a batch of 1 example as input
          and can return a batch with 1 or more examples.
          A batch is a dictionary, e.g. a batch of 1 example is {"text": ["Hello there !"]}
        - If batched is True and batch_size is ``n`` > 1, then the function takes a batch of ``n`` examples
          as input and can return a batch with ``n`` examples, or with an arbitrary number of examples.
          Note that the last batch may have less than ``n`` examples.
          A batch is a dictionary, e.g. a batch of ``n`` examples is {"text": ["Hello there !"] * n}

        Args:
            function (:obj:`Callable`): function applied on-the-fly on the examples
                when you iterate on the dataset.
            batched (:obj:`bool`, default `False`): Provide batch of examples to `function`.
            batch_size (:obj:`int`, optional, default ``1000``): Number of examples per batch
                provided to `function` if `batched=True`.
        """
        info = copy.deepcopy(self._info)
        # the function may change the columns, so the previous features are not kept
        info.features = None
        ex_iterable = MappedExamplesIterable(
            self._ex_iterable, function=function, batched=batched, batch_size=batch_size
        )
        return iterable_dataset(
            ex_iterable=ex_iterable,
            info=info,
            split=self._split,
            format_type=self._format_type,
            shuffling=copy.deepcopy(self._shuffling),
        )

    def shuffle(self, buffer_size, seed=None) -> "IterableDataset":
        """
        Randomly shuffles the elements of this dataset.

        This dataset fills a buffer with buffer_size elements, then randomly samples elements from this buffer,
        replacing the selected elements with new elements. For perfect shuffling, a buffer size greater than
        or equal to the full size of the dataset is required.

        For instance, if your dataset contains 10,000 elements but ``buffer_size`` is set to 1,000, then shuffle
        will initially select a random element from only the first 1,000 elements in the buffer. Once an element
        is selected, its space in the buffer is replaced by the next (i.e. 1,001-st) element, maintaining the
        1,000 element buffer.

        If the dataset is made of several shards, it also does shuffle the order of the shards.
        However if the order has been fixed by using :func:`datasets.IterableDataset.skip` or
        :func:`datasets.IterableDataset.take` then the order of the shards is kept unchanged.

        Args:
            buffer_size (:obj:`int`): size of the buffer.
            seed (:obj:`int`, optional, default None): random seed that will be used to create the distribution.
        """
        shuffling = ShufflingConfig(seed=seed)
        return iterable_dataset(
            ex_iterable=BufferShuffledExamplesIterable(self._ex_iterable, buffer_size, seed=seed).shuffle_data_sources(
                seed=seed
            ),
            info=copy.deepcopy(self._info),
            split=self._split,
            format_type=self._format_type,
            shuffling=shuffling,
        )

    def set_epoch(self, epoch: int):
        """Set the current epoch; it is added to the shuffling seed when a seed was given."""
        self._epoch = epoch

    def skip(self, n) -> "IterableDataset":
        """
        Create a new IterableDataset that skips the first ``n`` elements.

        Args:
            n (:obj:`int`): number of elements to skip.
        """
        ex_iterable = SkipExamplesIterable(self._ex_iterable, n)
        return iterable_dataset(
            ex_iterable=ex_iterable,
            info=copy.deepcopy(self._info),
            split=self._split,
            format_type=self._format_type,
            shuffling=copy.deepcopy(self._shuffling),
        )

    def take(self, n) -> "IterableDataset":
        """
        Create a new IterableDataset with only the first ``n`` elements.

        Args:
            n (:obj:`int`): number of elements to take.
        """
        ex_iterable = TakeExamplesIterable(self._ex_iterable, n)
        return iterable_dataset(
            ex_iterable=ex_iterable,
            info=copy.deepcopy(self._info),
            split=self._split,
            format_type=self._format_type,
            shuffling=copy.deepcopy(self._shuffling),
        )

    def remove_columns(self, column_names: Union[str, List[str]]) -> "IterableDataset":
        """
        Remove one or several column(s) in the dataset.
        The removal is done on-the-fly on the examples when iterating over the dataset.
        It is implemented with :func:`map`, so the features of the returned dataset are reset to None.

        Args:
            column_names (:obj:`Union[str, List[str]]`): Name of the column(s) to remove.

        Returns:
            :class:`IterableDataset`: A copy of the dataset object without the columns to remove.
        """
        if isinstance(column_names, str):
            column_names = [column_names]

        def remove_fn(example):
            return {k: v for k, v in example.items() if k not in column_names}

        return self.map(remove_fn)
def iterable_dataset(
    ex_iterable: Iterable,
    info: Optional[DatasetInfo] = None,
    split: Optional[NamedSplit] = None,
    format_type: Optional[str] = None,
    shuffling: Optional[ShufflingConfig] = None,
):
    """Instantiate an :class:`IterableDataset`, or a torch-compatible subclass of it.

    Args:
        ex_iterable: iterable of (key, example) pairs backing the dataset.
        info (:obj:`DatasetInfo`, optional): dataset info passed to the dataset constructor.
        split (:obj:`NamedSplit`, optional): split passed to the dataset constructor.
        format_type (:obj:`str`, optional): if "torch", the returned dataset is an instance of a class
            that inherits from both :class:`IterableDataset` and ``torch.utils.data.IterableDataset``.
        shuffling (:obj:`ShufflingConfig`, optional): shuffling settings passed to the dataset constructor.

    Returns:
        The new dataset.
    """
    if format_type == "torch":
        # torch is imported lazily so that it is only needed for the "torch" format
        import torch

        # a new subclass is created on each call
        class TorchIterableDataset(IterableDataset, torch.utils.data.IterableDataset):
            pass

        cls = TorchIterableDataset
    else:
        cls = IterableDataset
    return cls(
        ex_iterable=ex_iterable,
        info=info,
        split=split,
        format_type=format_type,
        shuffling=shuffling,
    )