# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------

import itertools
import logging
import os

import numpy as np
import torch

from fairseq.data import (
    AppendTokenDataset,
    ConcatDataset,
    FairseqDataset,
    PrependTokenDataset,
    data_utils,
    indexed_dataset,
)

logger = logging.getLogger(__name__)
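

# Builds a LanguageTripleDataset from three parallel token streams stored in
# fairseq indexed/binarized format: a source, a "reference" (a second
# source-side stream collated alongside the source), and an optional target.
# Files are expected to be named "{split}.{src}-{ref}-{tgt}.{lang}" under
# `data_path`.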
def load_langtriple_dataset(
    data_path,
    split,
    src,
    src_dict,
    ref,
    ref_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
    prepend_bos_src=None,
    lang_format="[{}]",
):
    assert not truncate_source

    def split_exists(split, src, ref, tgt, lang, data_path):
        filename = os.path.join(
            data_path, "{}.{}-{}-{}.{}".format(split, src, ref, tgt, lang)
        )
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    ref_datasets = []
    tgt_datasets = []
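
    # fairseq-style sharded splits: look for "{split}", "{split}1", "{split}2", ...
    # and keep loading shards as long as `combine` is True and the files exist.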
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")

        # infer langcode
        if split_exists(split_k, src, ref, tgt, src, data_path):
            prefix = os.path.join(
                data_path, "{}.{}-{}-{}.".format(split_k, src, ref, tgt)
            )
        elif split_exists(split_k, tgt, ref, src, src, data_path):
            prefix = os.path.join(
                data_path, "{}.{}-{}-{}.".format(split_k, tgt, ref, src)
            )
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, data_path)
                )

        src_dataset = data_utils.load_indexed_dataset(
            prefix + src, src_dict, dataset_impl
        )
        src_datasets.append(src_dataset)
        ref_dataset = data_utils.load_indexed_dataset(
            prefix + ref, ref_dict, dataset_impl
        )
        ref_datasets.append(ref_dataset)
        tgt_dataset = data_utils.load_indexed_dataset(
            prefix + tgt, tgt_dict, dataset_impl
        )
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        logger.info(
            "{} {} {}-{}-{} {} examples".format(
                data_path, split_k, src, ref, tgt, len(src_datasets[-1])
            )
        )

        if not combine:
            break

    assert len(src_datasets) == len(ref_datasets)
    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        ref_dataset = ref_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        ref_dataset = ConcatDataset(ref_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert (
            hasattr(src_dict, "bos_index")
            and hasattr(ref_dict, "bos_index")
            and hasattr(tgt_dict, "bos_index")
        )
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        ref_dataset = PrependTokenDataset(ref_dataset, ref_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
    elif prepend_bos_src is not None:
        logger.info(f"prepending src bos: {prepend_bos_src}")
        src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)
        ref_dataset = PrependTokenDataset(ref_dataset, prepend_bos_src)

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(
            src_dataset, src_dict.index(lang_format.format(src))
        )
        ref_dataset = AppendTokenDataset(
            ref_dataset, ref_dict.index(lang_format.format(ref))
        )
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(
                tgt_dataset, tgt_dict.index(lang_format.format(tgt))
            )
        eos = tgt_dict.index(lang_format.format(tgt))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(
                align_path, None, dataset_impl
            )

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguageTripleDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        ref_dataset,
        ref_dataset.sizes,
        ref_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
    )
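

# Collates a list of samples produced by LanguageTripleDataset.__getitem__ into
# a padded mini-batch.  Samples are sorted by descending source length, and the
# "reference" stream is padded and reordered alongside the source.  Note that
# `merge` passes eos_idx=None to collate_tokens, so when targets are shifted
# for input feeding the last token of each target is moved to the front.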
def collate(
    samples,
    pad_idx,
    eos_idx,
    left_pad_source=True,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
    pad_to_multiple=1,
):
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            None,
            left_pad,
            move_eos_to_beginning,
            pad_to_length=pad_to_length,
            pad_to_multiple=pad_to_multiple,
        )
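
    # Drop per-sentence alignments that are empty or that index the final
    # (eos) position or beyond in either the source or the target.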
    def check_alignment(alignment, src_len, tgt_len):
        if alignment is None or len(alignment) == 0:
            return False
        if (
            alignment[:, 0].max().item() >= src_len - 1
            or alignment[:, 1].max().item() >= tgt_len - 1
        ):
            logger.warning("alignment size mismatch found, skipping alignment!")
            return False
        return True

    def compute_alignment_weights(alignments):
        """
        Given a tensor of shape [:, 2] containing the source-target indices
        corresponding to the alignments, a weight vector containing the
        inverse frequency of each target index is computed.
        For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then
        a tensor containing [1., 0.5, 0.5, 1] should be returned (since target
        index 3 is repeated twice)
        """
        align_tgt = alignments[:, 1]
        _, align_tgt_i, align_tgt_c = torch.unique(
            align_tgt, return_inverse=True, return_counts=True
        )
        align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]]
        return 1.0 / align_weights.float()

    id = torch.LongTensor([s["id"] for s in samples])
    src_tokens = merge(
        "source",
        left_pad=left_pad_source,
        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
    )
    ref_tokens = merge(
        "reference",
        left_pad=left_pad_source,
        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
    )
    # sort by descending source length
    src_lengths = torch.LongTensor(
        [s["source"].ne(pad_idx).long().sum() for s in samples]
    )
    ref_lengths = torch.LongTensor(
        [s["reference"].ne(pad_idx).long().sum() for s in samples]
    )
    src_lengths, sort_order = src_lengths.sort(descending=True)
    id = id.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)
    ref_lengths = ref_lengths.index_select(0, sort_order)
    ref_tokens = ref_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge(
            "target",
            left_pad=left_pad_target,
            pad_to_length=pad_to_length["target"]
            if pad_to_length is not None
            else None,
        )
        target = target.index_select(0, sort_order)
        tgt_lengths = torch.LongTensor(
            [s["target"].ne(pad_idx).long().sum() for s in samples]
        ).index_select(0, sort_order)
        ntokens = tgt_lengths.sum().item()

        if samples[0].get("prev_output_tokens", None) is not None:
            prev_output_tokens = merge("prev_output_tokens", left_pad=left_pad_target)
        elif input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                "target",
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_to_length=pad_to_length["target"]
                if pad_to_length is not None
                else None,
            )
    else:
        ntokens = src_lengths.sum().item()

    batch = {
        "id": id,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        },
        "target": target,
        "ref_tokens": ref_tokens,
        "ref_lengths": ref_lengths,
    }
    if prev_output_tokens is not None:
        batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select(
            0, sort_order
        )
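
    # Per-sentence alignments are (src_pos, tgt_pos) pairs.  Offset the target
    # positions into the flattened (bsz * tgt_sz) batch and, when the
    # corresponding left_pad flag is set, correct each column for left padding
    # so the pairs index the padded tensors directly.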
if samples[0].get("alignment", None) is not None: | |
bsz, tgt_sz = batch["target"].shape | |
src_sz = batch["net_input"]["src_tokens"].shape[1] | |
offsets = torch.zeros((len(sort_order), 2), dtype=torch.long) | |
offsets[:, 1] += torch.arange(len(sort_order), dtype=torch.long) * tgt_sz | |
if left_pad_source: | |
offsets[:, 0] += src_sz - src_lengths | |
if left_pad_target: | |
offsets[:, 1] += tgt_sz - tgt_lengths | |
alignments = [ | |
alignment + offset | |
for align_idx, offset, src_len, tgt_len in zip( | |
sort_order, offsets, src_lengths, tgt_lengths | |
) | |
for alignment in [samples[align_idx]["alignment"].view(-1, 2)] | |
if check_alignment(alignment, src_len, tgt_len) | |
] | |
if len(alignments) > 0: | |
alignments = torch.cat(alignments, dim=0) | |
align_weights = compute_alignment_weights(alignments) | |
batch["alignments"] = alignments | |
batch["align_weights"] = align_weights | |
if samples[0].get("constraints", None) is not None: | |
# Collate the packed constraints across the samples, padding to | |
# the length of the longest sample. | |
lens = [sample.get("constraints").size(0) for sample in samples] | |
max_len = max(lens) | |
constraints = torch.zeros((len(samples), max(lens))).long() | |
for i, sample in enumerate(samples): | |
constraints[i, 0 : lens[i]] = samples[i].get("constraints") | |
batch["constraints"] = constraints.index_select(0, sort_order) | |
return batch | |


class LanguageTripleDataset(FairseqDataset):
    """
    A triple of torch.utils.data.Datasets: a source, a "reference" stream that
    is padded and collated alongside the source, and an optional target.

    Args:
        src (torch.utils.data.Dataset): source dataset to wrap
        src_sizes (List[int]): source sentence lengths
        src_dict (~fairseq.data.Dictionary): source vocabulary
        ref (torch.utils.data.Dataset): reference dataset to wrap
        ref_sizes (List[int]): reference sentence lengths
        ref_dict (~fairseq.data.Dictionary): reference vocabulary
        tgt (torch.utils.data.Dataset, optional): target dataset to wrap
        tgt_sizes (List[int], optional): target sentence lengths
        tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
        left_pad_source (bool, optional): pad source tensors on the left side
            (default: True).
        left_pad_target (bool, optional): pad target tensors on the left side
            (default: False).
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        input_feeding (bool, optional): create a shifted version of the targets
            to be passed into the model for teacher forcing (default: True).
        remove_eos_from_source (bool, optional): if set, removes eos from end
            of source if it's present (default: False).
        append_eos_to_target (bool, optional): if set, appends eos to end of
            target if it's absent (default: False).
        align_dataset (torch.utils.data.Dataset, optional): dataset
            containing alignments.
        constraints (Tensor, optional): 2d tensor with a concatenated, zero-
            delimited list of constraints for each sentence.
        append_bos (bool, optional): if set, appends bos to the beginning of
            source/target sentence.
        num_buckets (int, optional): if set to a value greater than 0, then
            batches will be bucketed into the given number of batch shapes.
        src_lang_id (int, optional): source language ID, if set, the collated
            batch will contain a field 'src_lang_id' in 'net_input' which
            indicates the source language of the samples.
        tgt_lang_id (int, optional): target language ID, if set, the collated
            batch will contain a field 'tgt_lang_id' which indicates the target
            language of the samples.
    """

    def __init__(
        self,
        src,
        src_sizes,
        src_dict,
        ref,
        ref_sizes,
        ref_dict,
        tgt=None,
        tgt_sizes=None,
        tgt_dict=None,
        left_pad_source=True,
        left_pad_target=False,
        shuffle=True,
        input_feeding=True,
        remove_eos_from_source=False,
        append_eos_to_target=False,
        align_dataset=None,
        constraints=None,
        append_bos=False,
        eos=None,
        num_buckets=0,
        src_lang_id=None,
        tgt_lang_id=None,
        pad_to_multiple=1,
    ):
        if tgt_dict is not None:
            assert src_dict.pad() == tgt_dict.pad()
            assert src_dict.eos() == tgt_dict.eos()
            assert src_dict.unk() == tgt_dict.unk()
        if tgt is not None:
            assert len(src) == len(
                tgt
            ), "Source and target must contain the same number of examples"
            assert len(src) == len(
                ref
            ), "Source and reference must contain the same number of examples"
        self.src = src
        self.ref = ref
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.ref_sizes = np.array(ref_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        self.sizes = (
            np.vstack((self.src_sizes, self.tgt_sizes)).T
            if self.tgt_sizes is not None
            else self.src_sizes
        )
        self.src_dict = src_dict
        self.ref_dict = ref_dict
        self.tgt_dict = tgt_dict
        self.left_pad_source = left_pad_source
        self.left_pad_target = left_pad_target
        self.shuffle = shuffle
        self.input_feeding = input_feeding
        self.remove_eos_from_source = remove_eos_from_source
        self.append_eos_to_target = append_eos_to_target
        self.align_dataset = align_dataset
        if self.align_dataset is not None:
            assert (
                self.tgt_sizes is not None
            ), "Both source and target needed when alignments are provided"
        self.constraints = constraints
        self.append_bos = append_bos
        self.eos = eos if eos is not None else src_dict.eos()
        self.src_lang_id = src_lang_id
        self.tgt_lang_id = tgt_lang_id
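
        # Optionally pad every example up to one of `num_buckets` fixed lengths
        # so that batches fall into a small, fixed set of shapes.  The reference
        # stream is bucketed with the same settings as the source.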
        if num_buckets > 0:
            from fairseq.data import BucketPadLengthDataset

            self.src = BucketPadLengthDataset(
                self.src,
                sizes=self.src_sizes,
                num_buckets=num_buckets,
                pad_idx=self.src_dict.pad(),
                left_pad=self.left_pad_source,
            )
            self.src_sizes = self.src.sizes
            logger.info("bucketing source lengths: {}".format(list(self.src.buckets)))
            self.ref = BucketPadLengthDataset(
                self.ref,
                sizes=self.ref_sizes,
                num_buckets=num_buckets,
                pad_idx=self.ref_dict.pad(),
                left_pad=self.left_pad_source,
            )
            self.ref_sizes = self.ref.sizes
            logger.info("bucketing reference lengths: {}".format(list(self.ref.buckets)))
            if self.tgt is not None:
                self.tgt = BucketPadLengthDataset(
                    self.tgt,
                    sizes=self.tgt_sizes,
                    num_buckets=num_buckets,
                    pad_idx=self.tgt_dict.pad(),
                    left_pad=self.left_pad_target,
                )
                self.tgt_sizes = self.tgt.sizes
                logger.info(
                    "bucketing target lengths: {}".format(list(self.tgt.buckets))
                )

            # determine bucket sizes using self.num_tokens, which will return
            # the padded lengths (thanks to BucketPadLengthDataset)
            num_tokens = np.vectorize(self.num_tokens, otypes=[np.int64])
            self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
            self.buckets = [
                (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens)
            ]
        else:
            self.buckets = None
        self.pad_to_multiple = pad_to_multiple

    def get_batch_shapes(self):
        return self.buckets

    def __getitem__(self, index):
        tgt_item = self.tgt[index] if self.tgt is not None else None
        src_item = self.src[index]
        ref_item = self.ref[index]
        # Append EOS to end of tgt sentence if it does not have an EOS and remove
        # EOS from end of src sentence if it exists. This is useful when we want
        # to use existing datasets for opposite directions, i.e., when we want to
        # use tgt_dataset as src_dataset and vice versa.
        if self.append_eos_to_target:
            eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
            if self.tgt and self.tgt[index][-1] != eos:
                tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])])

        if self.append_bos:
            bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
            if self.tgt and self.tgt[index][0] != bos:
                tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]])

            bos = self.src_dict.bos()
            if self.src[index][0] != bos:
                src_item = torch.cat([torch.LongTensor([bos]), self.src[index]])
            if self.ref[index][0] != bos:
                ref_item = torch.cat([torch.LongTensor([bos]), self.ref[index]])

        if self.remove_eos_from_source:
            eos = self.src_dict.eos()
            if self.src[index][-1] == eos:
                src_item = self.src[index][:-1]
            if self.ref[index][-1] == eos:
                ref_item = self.ref[index][:-1]

        example = {
            "id": index,
            "source": src_item,
            "reference": ref_item,
            "target": tgt_item,
        }
        if self.align_dataset is not None:
            example["alignment"] = self.align_dataset[index]
        if self.constraints is not None:
            example["constraints"] = self.constraints[index]
        return example

    def __len__(self):
        return len(self.src)

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): a dictionary of
                {'source': source_pad_to_length, 'target': target_pad_to_length}
                to indicate the max length to pad to in source and target respectively.

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the source sentence of shape `(bsz, src_len)`. Padding will
                    appear on the left if *left_pad_source* is ``True``.
                  - `src_lengths` (LongTensor): 1D Tensor of the unpadded
                    lengths of each source sentence of shape `(bsz)`
                  - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
                    tokens in the target sentence, shifted right by one
                    position for teacher forcing, of shape `(bsz, tgt_len)`.
                    This key will not be present if *input_feeding* is
                    ``False``. Padding will appear on the left if
                    *left_pad_target* is ``True``.
                  - `src_lang_id` (LongTensor): a long Tensor which contains source
                    language IDs of each sample in the batch

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will appear
                  on the left if *left_pad_target* is ``True``.
                - `ref_tokens` (LongTensor): a padded 2D Tensor of tokens in the
                  reference sentence, padded like the source.
                - `ref_lengths` (LongTensor): 1D Tensor of the unpadded lengths
                  of each reference sentence of shape `(bsz)`
                - `tgt_lang_id` (LongTensor): a long Tensor which contains target
                  language IDs of each sample in the batch
        """
        res = collate(
            samples,
            pad_idx=self.src_dict.pad(),
            eos_idx=self.eos,
            left_pad_source=self.left_pad_source,
            left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding,
            pad_to_length=pad_to_length,
            pad_to_multiple=self.pad_to_multiple,
        )
        if self.src_lang_id is not None or self.tgt_lang_id is not None:
            src_tokens = res["net_input"]["src_tokens"]
            bsz = src_tokens.size(0)
            if self.src_lang_id is not None:
                res["net_input"]["src_lang_id"] = (
                    torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens)
                )
            if self.tgt_lang_id is not None:
                res["tgt_lang_id"] = (
                    torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens)
                )
        return res
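
    # NOTE: batching and length filtering below consider only source and target
    # lengths; reference lengths do not affect num_tokens / size / filtering.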
    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return max(
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def num_tokens_vec(self, indices):
        """Return the number of tokens for a set of positions defined by indices.
        This value is used to enforce ``--max-tokens`` during batching."""
        sizes = self.src_sizes[indices]
        if self.tgt_sizes is not None:
            sizes = np.maximum(sizes, self.tgt_sizes[indices])
        return sizes

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self)).astype(np.int64)
        else:
            indices = np.arange(len(self), dtype=np.int64)
        if self.buckets is None:
            # sort by target length, then source length
            if self.tgt_sizes is not None:
                indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
            return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
        else:
            # sort by bucketed_num_tokens, which is:
            # max(padded_src_len, padded_tgt_len)
            return indices[
                np.argsort(self.bucketed_num_tokens[indices], kind="mergesort")
            ]

    @property
    def supports_prefetch(self):
        return getattr(self.src, "supports_prefetch", False) and (
            getattr(self.tgt, "supports_prefetch", False) or self.tgt is None
        )

    def prefetch(self, indices):
        self.src.prefetch(indices)
        if self.tgt is not None:
            self.tgt.prefetch(indices)
        if self.align_dataset is not None:
            self.align_dataset.prefetch(indices)

    def filter_indices_by_size(self, indices, max_sizes):
        """Filter a list of sample indices. Remove those that are longer
        than specified in max_sizes.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        return data_utils.filter_paired_dataset_indices_by_size(
            self.src_sizes,
            self.tgt_sizes,
            indices,
            max_sizes,
        )
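

# The block below is an illustrative sketch, not part of the original module:
# a minimal, self-contained way to exercise LanguageTripleDataset and its
# collater on a couple of toy sentences, assuming `fairseq` and `torch` are
# installed.  The vocabulary, sentences, and tensors here are hypothetical
# stand-ins for the binarized datasets normally produced by fairseq-preprocess
# and loaded via load_langtriple_dataset.
if __name__ == "__main__":
    from fairseq.data import Dictionary

    # One shared toy vocabulary for source, reference, and target.
    vocab = Dictionary()
    for w in ["hello", "world", "HH", "AH0", "L", "OW1", "W", "ER1", "D"]:
        vocab.add_symbol(w)

    def encode(words):
        # Token ids followed by eos, as the binarized datasets would contain.
        return torch.LongTensor([vocab.index(w) for w in words] + [vocab.eos()])

    src = [encode(["hello", "world"]), encode(["hello"])]
    ref = [encode(["HH", "AH0", "L", "OW1", "W", "ER1", "D"]), encode(["HH", "AH0", "L", "OW1"])]
    tgt = [encode(["world", "hello"]), encode(["world"])]

    dataset = LanguageTripleDataset(
        src, [len(t) for t in src], vocab,
        ref, [len(t) for t in ref], vocab,
        tgt, [len(t) for t in tgt], vocab,
    )
    batch = dataset.collater([dataset[i] for i in range(len(dataset))])
    print(batch["net_input"]["src_tokens"].shape, batch["ref_tokens"].shape)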