OFA-Visual_Grounding / fairseq /fairseq /data /multi_corpus_sampled_dataset.py
JustinLin610
update
10b0761
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
from typing import Callable, Dict, List
import numpy as np
from . import FairseqDataset
def uniform_sampler(x):
    """Pick a single element of ``x`` uniformly at random.

    The draw goes through ``np.random.choice`` without a probability
    vector, so every element is equally likely and NumPy's global random
    state is consumed. The one-element result array is unwrapped with
    ``.item()`` before being returned.
    """
    drawn = np.random.choice(x, size=1)
    return drawn.item()
class MultiCorpusSampledDataset(FairseqDataset):
    """
    Wrapper around several FairseqDataset instances that are batched together.

    Every item of this dataset holds one example from *each* wrapped
    dataset; when a mini-batch is collated, a single dataset key is drawn
    with ``sampling_func`` and only that dataset's examples are collated.

    Args:
        datasets: an OrderedDict mapping a key to a FairseqDataset instance.
        sampling_func: callable that receives the list of dataset keys and
            returns the key to use for the current batch. When omitted,
            ``uniform_sampler`` is used.
    """

    def __init__(
        self,
        datasets: Dict[str, FairseqDataset],
        sampling_func: Callable[[List], int] = None,
    ):
        super().__init__()
        assert isinstance(datasets, OrderedDict)
        self.datasets = datasets
        self.sampling_func = (
            uniform_sampler if sampling_func is None else sampling_func
        )

        # The reported length is the sum of the wrapped datasets' lengths.
        instance_count = 0
        for member in datasets.values():
            assert isinstance(member, FairseqDataset)
            instance_count += len(member)
        self.total_num_instances = instance_count

        # Per-dataset orderings; filled in lazily by ordered_indices().
        self._ordered_indices = None

    def __len__(self):
        """
        Total number of instances, summed over all wrapped datasets.
        """
        return self.total_num_instances

    def ordered_indices(self):
        """
        Return the index ordering used for batching.

        On the first call, each wrapped dataset's own ordered_indices() is
        computed and cached under its key so that later index lookups follow
        the ordering that dataset would have produced on its own. The value
        returned for this dataset is simply 0..len(self)-1.
        """
        if self._ordered_indices is None:
            per_dataset_order = OrderedDict()
            for name, member in self.datasets.items():
                per_dataset_order[name] = member.ordered_indices()
            self._ordered_indices = per_dataset_order
        return np.arange(len(self))

    def _map_index_to_dataset(self, key: int, index: int):
        """
        Translate an index of this dataset into an index of dataset ``key``.

        ``key`` is used as a lookup key into ``self.datasets`` (whatever type
        those keys have, despite the ``int`` annotation). Because the wrapped
        datasets differ in length, ``index`` is wrapped around with a modulo
        by that dataset's length, and the result is then looked up in the
        ordering cached by ordered_indices(), which must have been called
        beforehand.
        """
        assert (
            self._ordered_indices is not None
        ), "Must call MultiCorpusSampledDataset.ordered_indices() first"
        wrapped = index % len(self.datasets[key])
        order_for_key = self._ordered_indices[key]
        return order_for_key[wrapped]

    def __getitem__(self, index: int):
        """
        Fetch one example from every wrapped dataset.

        ``index`` is an index of this composite dataset; it is mapped to each
        wrapped dataset separately. The result is an OrderedDict from dataset
        key to that dataset's item.
        """
        items = OrderedDict()
        for name, member in self.datasets.items():
            items[name] = member[self._map_index_to_dataset(name, index)]
        return items

    def collater(self, samples: List[Dict]):
        """
        Build a mini-batch from a list of items produced by __getitem__.

        Steps:
            1. Return None right away when ``samples`` is empty.
            2. Draw one dataset key by passing the list of keys to
               ``sampling_func``.
            3. Keep only that key's entry of every sample and hand the
               resulting list to the chosen dataset's collater.
        """
        if len(samples) == 0:
            return None
        candidate_keys = list(self.datasets.keys())
        chosen = self.sampling_func(candidate_keys)
        chosen_samples = [item[chosen] for item in samples]
        return self.datasets[chosen].collater(chosen_samples)

    def num_tokens(self, index: int):
        """
        Number of tokens of the example at ``index``, used for batching.

        Every wrapped dataset is queried at its mapped index and the largest
        value is returned.
        """
        token_counts = [
            member.num_tokens(self._map_index_to_dataset(name, index))
            for name, member in self.datasets.items()
        ]
        return max(token_counts)

    def size(self, index: int):
        """
        Size of the example at ``index`` (a float or tuple), used when
        filtering with max-positions.

        Every wrapped dataset is queried at its mapped index and the largest
        value is returned.
        """
        sizes = [
            member.size(self._map_index_to_dataset(name, index))
            for name, member in self.datasets.items()
        ]
        return max(sizes)

    @property
    def supports_prefetch(self):
        # Prefetching is only possible when every wrapped dataset offers it;
        # a dataset lacking the attribute counts as not supporting it.
        for member in self.datasets.values():
            if not getattr(member, "supports_prefetch", False):
                return False
        return True

    def prefetch(self, indices):
        # Forward the request to each wrapped dataset, using that dataset's
        # own (mapped) indices.
        for name, member in self.datasets.items():
            mapped = [self._map_index_to_dataset(name, i) for i in indices]
            member.prefetch(mapped)

    @property
    def supports_fetch_outside_dataloader(self):
        # True only if all wrapped datasets report support.
        return all(
            member.supports_fetch_outside_dataloader
            for member in self.datasets.values()
        )