Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import numpy as np | |
from fairseq.data import data_utils | |
from . import BaseWrapperDataset | |
class TruncateDataset(BaseWrapperDataset):
    """Truncate a sequence by returning the first truncation_length tokens.

    Args:
        dataset: the dataset to wrap. Its items must support ``size(0)`` and
            slicing, and it must expose ``sizes`` as an attribute.
        truncation_length: maximum number of leading tokens kept per item;
            must not be ``None``.
    """

    def __init__(self, dataset, truncation_length):
        super().__init__(dataset)
        assert truncation_length is not None
        self.truncation_length = truncation_length
        self.dataset = dataset

    def __getitem__(self, index):
        """Return item ``index``, cut to at most ``truncation_length`` tokens."""
        item = self.dataset[index]
        item_len = item.size(0)
        # Items that already fit are returned untouched (no slicing).
        if item_len > self.truncation_length:
            item = item[: self.truncation_length]
        return item

    @property
    def sizes(self):
        """Sizes of the wrapped items, each capped at ``truncation_length``."""
        # Exposed as a property because the wrapped dataset's ``sizes`` is
        # read as an attribute below; a plain method here would make this
        # wrapper unusable wherever ``dataset.sizes`` is expected (including
        # when one truncating wrapper wraps another).
        return np.minimum(self.dataset.sizes, self.truncation_length)

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)
class RandomCropDataset(TruncateDataset):
    """Truncate a sequence by returning a random crop of truncation_length tokens.

    Args:
        dataset: the dataset to wrap; items must support ``size(0)`` and
            slicing.
        truncation_length: number of tokens in each crop.
        seed: base seed passed to ``data_utils.numpy_seed`` together with
            the current epoch and the item index.
    """

    def __init__(self, dataset, truncation_length, seed=1):
        super().__init__(dataset, truncation_length)
        self.seed = seed
        self.epoch = 0

    def can_reuse_epoch_itr_across_epochs(self):
        # NOTE(review): if the base class exposes this as a property, this
        # override should be decorated the same way -- confirm.
        return True  # only the crop changes, not item sizes

    def set_epoch(self, epoch, **unused):
        """Record ``epoch`` so later crops are drawn with it in the seed."""
        super().set_epoch(epoch)
        self.epoch = epoch

    def __getitem__(self, index):
        """Return item ``index``, cropped to ``truncation_length`` tokens.

        Items no longer than ``truncation_length`` are returned unchanged.
        """
        # The offset is drawn inside numpy_seed(seed, epoch, index);
        # presumably this makes the crop reproducible per (seed, epoch,
        # index) -- confirm against data_utils.
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            item = self.dataset[index]
            item_len = item.size(0)
            excess = item_len - self.truncation_length
            if excess > 0:
                # Valid start offsets are 0..excess inclusive, and randint's
                # upper bound is exclusive, hence ``excess + 1``. Without it
                # the crop ending on the item's last token is never chosen.
                start_idx = np.random.randint(0, excess + 1)
                item = item[start_idx : start_idx + self.truncation_length]
            return item
def maybe_shorten_dataset(
    dataset,
    split,
    shorten_data_split_list,
    shorten_method,
    tokens_per_sample,
    seed,
):
    """Optionally wrap ``dataset`` so its items are shortened.

    Args:
        dataset: the dataset to (maybe) wrap.
        split: name of the split ``dataset`` belongs to.
        shorten_data_split_list: comma-separated split names to shorten;
            an empty string selects every split.
        shorten_method: ``"truncate"`` or ``"random_crop"``; any other value
            leaves ``dataset`` unwrapped.
        tokens_per_sample: length passed to the wrapper as its truncation
            length.
        seed: seed forwarded to ``RandomCropDataset``.

    Returns:
        A ``TruncateDataset`` or ``RandomCropDataset`` around ``dataset``
        when the split is selected and the method matches, otherwise
        ``dataset`` itself.
    """
    selected_splits = shorten_data_split_list.split(",")

    # Not listed, and the list is non-empty: this split is left alone.
    if split not in selected_splits and len(shorten_data_split_list) != 0:
        return dataset

    if shorten_method == "truncate":
        return TruncateDataset(dataset, tokens_per_sample)
    if shorten_method == "random_crop":
        return RandomCropDataset(dataset, tokens_per_sample, seed)
    return dataset