TomatoCocotree
上传
6a62ffb
raw
history blame
1.56 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import FairseqDataset
class ConcatSentencesDataset(FairseqDataset):
def __init__(self, *datasets):
super().__init__()
self.datasets = datasets
assert all(
len(ds) == len(datasets[0]) for ds in datasets
), "datasets must have the same length"
def __getitem__(self, index):
return torch.cat([ds[index] for ds in self.datasets])
def __len__(self):
return len(self.datasets[0])
def collater(self, samples):
return self.datasets[0].collater(samples)
@property
def sizes(self):
return sum(ds.sizes for ds in self.datasets)
def num_tokens(self, index):
return sum(ds.num_tokens(index) for ds in self.datasets)
def size(self, index):
return sum(ds.size(index) for ds in self.datasets)
def ordered_indices(self):
return self.datasets[0].ordered_indices()
@property
def supports_prefetch(self):
return any(getattr(ds, "supports_prefetch", False) for ds in self.datasets)
def prefetch(self, indices):
for ds in self.datasets:
if getattr(ds, "supports_prefetch", False):
ds.prefetch(indices)
def set_epoch(self, epoch):
super().set_epoch(epoch)
for ds in self.datasets:
if hasattr(ds, "set_epoch"):
ds.set_epoch(epoch)