# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict
from typing import Callable, Dict, List

import numpy as np

from . import FairseqDataset
def uniform_sampler(x):
    """Pick a single element of ``x`` uniformly at random and return it
    as a plain Python scalar."""
    drawn = np.random.choice(x, 1)
    return drawn.item()
class MultiCorpusSampledDataset(FairseqDataset):
    """
    Stores multiple instances of FairseqDataset together and in every iteration
    creates a batch by first sampling a dataset according to a specified
    probability distribution and then getting instances from that dataset.

    Args:
        datasets: an OrderedDict of FairseqDataset instances.
        sampling_func: A function for sampling over list of dataset keys.
            The default strategy is to sample uniformly.
    """

    def __init__(
        self,
        datasets: Dict[str, FairseqDataset],
        sampling_func: Callable[[List], int] = None,
    ):
        super().__init__()
        assert isinstance(datasets, OrderedDict)
        self.datasets = datasets
        if sampling_func is None:
            # Default: sample a dataset key uniformly at random.
            sampling_func = uniform_sampler
        self.sampling_func = sampling_func

        # Total length is the sum of the lengths of all underlying datasets.
        self.total_num_instances = 0
        for dataset in datasets.values():
            assert isinstance(dataset, FairseqDataset)
            self.total_num_instances += len(dataset)

        # Per-dataset ordered indices, populated lazily by ordered_indices().
        # Must be set before _map_index_to_dataset() is used.
        self._ordered_indices = None

    def __len__(self):
        """
        Length of this dataset is the sum of individual datasets
        """
        return self.total_num_instances

    def ordered_indices(self):
        """
        Ordered indices for batching. Here we call the underlying
        dataset's ordered_indices() so that we get the same random ordering
        as we would have from using the underlying dataset directly.
        """
        if self._ordered_indices is None:
            # Cache each underlying dataset's ordering; computed only once.
            self._ordered_indices = OrderedDict(
                [
                    (key, dataset.ordered_indices())
                    for key, dataset in self.datasets.items()
                ]
            )
        # The outer ordering is just [0, total_num_instances); the per-dataset
        # orderings cached above are applied in _map_index_to_dataset().
        return np.arange(len(self))

    def _map_index_to_dataset(self, key: int, index: int):
        """
        Different underlying datasets have different lengths. In order to ensure
        we are not accessing an index outside the range of the current dataset
        size, we wrap around. This function should be called after we have
        created an ordering for this and all underlying datasets.
        """
        assert (
            self._ordered_indices is not None
        ), "Must call MultiCorpusSampledDataset.ordered_indices() first"
        # Wrap around so any index in [0, total_num_instances) is valid for
        # every underlying dataset, regardless of its individual length.
        mapped_index = index % len(self.datasets[key])
        return self._ordered_indices[key][mapped_index]

    def __getitem__(self, index: int):
        """
        Get the item associated with index from each underlying dataset.
        Since index is in the range of [0, TotalNumInstances], we need to
        map the index to the dataset before retrieving the item.
        """
        return OrderedDict(
            [
                (key, dataset[self._map_index_to_dataset(key, index)])
                for key, dataset in self.datasets.items()
            ]
        )

    def collater(self, samples: List[Dict]):
        """
        Generate a mini-batch for this dataset.
        To convert this into a regular mini-batch we use the following
        logic:
            1. Select a dataset using the specified probability distribution.
            2. Call the collater function of the selected dataset.
        """
        if len(samples) == 0:
            return None
        selected_key = self.sampling_func(list(self.datasets.keys()))
        selected_samples = [sample[selected_key] for sample in samples]
        return self.datasets[selected_key].collater(selected_samples)

    def num_tokens(self, index: int):
        """
        Return an example's length (number of tokens), used for batching. Here
        we return the max across all examples at index across all underlying
        datasets.
        """
        return max(
            dataset.num_tokens(self._map_index_to_dataset(key, index))
            for key, dataset in self.datasets.items()
        )

    def size(self, index: int):
        """
        Return an example's size as a float or tuple. Here we return the max
        across all underlying datasets. This value is used when filtering a
        dataset with max-positions.
        """
        return max(
            dataset.size(self._map_index_to_dataset(key, index))
            for key, dataset in self.datasets.items()
        )

    # NOTE(review): upstream FairseqDataset exposes supports_prefetch as a
    # @property; the decorator may have been lost in this copy. Kept as a
    # plain method here to preserve the visible interface — confirm against
    # the base class before changing.
    def supports_prefetch(self):
        """True only if every underlying dataset supports prefetching."""
        return all(
            getattr(dataset, "supports_prefetch", False)
            for dataset in self.datasets.values()
        )

    def prefetch(self, indices):
        """Prefetch, in each underlying dataset, the mapped counterparts of
        ``indices``."""
        for key, dataset in self.datasets.items():
            dataset.prefetch(
                [self._map_index_to_dataset(key, index) for index in indices]
            )

    # NOTE(review): likely also a @property upstream — see note above.
    def supports_fetch_outside_dataloader(self):
        """True only if every underlying dataset supports fetching outside
        the dataloader."""
        return all(
            self.datasets[key].supports_fetch_outside_dataloader
            for key in self.datasets
        )