# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict

import torch
from torch.utils.data.dataloader import default_collate

from . import FairseqDataset


def _flatten(dico, prefix=None):
    """Flatten a nested dictionary."""
    new_dico = OrderedDict()
    if isinstance(dico, dict):
        prefix = prefix + "." if prefix is not None else ""
        for k, v in dico.items():
            if v is None:
                continue
            new_dico.update(_flatten(v, prefix + k))
    elif isinstance(dico, list):
        for i, v in enumerate(dico):
            new_dico.update(_flatten(v, prefix + ".[" + str(i) + "]"))
    else:
        new_dico = OrderedDict({prefix: dico})
    return new_dico


def _unflatten(dico):
    """Unflatten a flattened dictionary into a nested dictionary."""
    new_dico = OrderedDict()
    for full_k, v in dico.items():
        full_k = full_k.split(".")
        node = new_dico
        for k in full_k[:-1]:
            if k.startswith("[") and k.endswith("]"):
                k = int(k[1:-1])
            if k not in node:
                node[k] = OrderedDict()
            node = node[k]
        node[full_k[-1]] = v
    return new_dico


class NestedDictionaryDataset(FairseqDataset):
    """Combine a nested dictionary (or list) of Datasets into a single
    dataset.

    The nested definition is flattened internally; ``collater()`` restores
    the nested structure when building a mini-batch. All inner datasets
    must have the same length.
    """

    def __init__(self, defn, sizes=None):
        super().__init__()
        self.defn = _flatten(defn)
        self.sizes = [sizes] if not isinstance(sizes, (list, tuple)) else sizes

        first = None
        for v in self.defn.values():
            if not isinstance(
                v,
                (
                    FairseqDataset,
                    torch.utils.data.Dataset,
                ),
            ):
                raise ValueError("Expected Dataset but found: {}".format(v.__class__))
            first = first or v
            if len(v) > 0:
                assert len(v) == len(first), "dataset lengths must match"

        self._len = len(first)

    def __getitem__(self, index):
        return OrderedDict((k, ds[index]) for k, ds in self.defn.items())

    def __len__(self):
        return self._len

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        if len(samples) == 0:
            return {}
        sample = OrderedDict()
        for k, ds in self.defn.items():
            try:
                sample[k] = ds.collater([s[k] for s in samples])
            except NotImplementedError:
                sample[k] = default_collate([s[k] for s in samples])
        return _unflatten(sample)

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return max(s[index] for s in self.sizes)

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used
        when filtering a dataset with ``--max-positions``."""
        if len(self.sizes) == 1:
            return self.sizes[0][index]
        else:
            # materialize as a tuple (not a generator) so the result can be
            # inspected more than once, matching the docstring
            return tuple(s[index] for s in self.sizes)

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return any(ds.supports_prefetch for ds in self.defn.values())

    def prefetch(self, indices):
        """Prefetch the data required for this epoch."""
        for ds in self.defn.values():
            if getattr(ds, "supports_prefetch", False):
                ds.prefetch(indices)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return all(ds.can_reuse_epoch_itr_across_epochs for ds in self.defn.values())

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        for ds in self.defn.values():
            ds.set_epoch(epoch)
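

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the fairseq API).
    # Run with ``python -m fairseq.data.nested_dictionary_dataset``, assuming
    # this module's usual location, so the relative import above resolves.
    # ``_ListDataset`` is a hypothetical helper defined here just for the demo.
    class _ListDataset(FairseqDataset):
        def __init__(self, items):
            self.items = items

        def __getitem__(self, index):
            return self.items[index]

        def __len__(self):
            return len(self.items)

        def collater(self, samples):
            # stack same-shaped tensors along a new batch dimension
            return torch.stack(samples)

    src = _ListDataset([torch.tensor([1, 2]), torch.tensor([3, 4])])
    tgt = _ListDataset([torch.tensor([5]), torch.tensor([6])])
    ds = NestedDictionaryDataset({"net_input": {"src": src}, "target": tgt})

    # items are flattened: ds[0] == {"net_input.src": ..., "target": ...}
    batch = ds.collater([ds[0], ds[1]])
    # the collated batch mirrors the nested definition:
    #   batch["net_input"]["src"] -> tensor([[1, 2], [3, 4]])
    #   batch["target"]           -> tensor([[5], [6]])
    print(batch)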