# Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import logging import numpy as np import torch from fairseq.data import FairseqDataset, data_utils logger = logging.getLogger(__name__) def collate( samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False, input_feeding=True, pad_to_length=None, pad_to_multiple=1, ): if len(samples) == 0: return {} def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None): return data_utils.collate_tokens( [s[key] for s in samples], pad_idx, eos_idx, left_pad, move_eos_to_beginning, pad_to_length=pad_to_length, pad_to_multiple=pad_to_multiple, ) def check_alignment(alignment, src_len, tgt_len): if alignment is None or len(alignment) == 0: return False if ( alignment[:, 0].max().item() >= src_len - 1 or alignment[:, 1].max().item() >= tgt_len - 1 ): logger.warning("alignment size mismatch found, skipping alignment!") return False return True def compute_alignment_weights(alignments): """ Given a tensor of shape [:, 2] containing the source-target indices corresponding to the alignments, a weight vector containing the inverse frequency of each target index is computed. For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then a tensor containing [1., 0.5, 0.5, 1] should be returned (since target index 3 is repeated twice) """ align_tgt = alignments[:, 1] _, align_tgt_i, align_tgt_c = torch.unique( align_tgt, return_inverse=True, return_counts=True ) align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]] return 1.0 / align_weights.float() id = torch.LongTensor([s["id"] for s in samples]) src_tokens = merge( "source", left_pad=left_pad_source, pad_to_length=pad_to_length["source"] if pad_to_length is not None else None, ) # sort by descending source length src_lengths = torch.LongTensor( [s["source"].ne(pad_idx).long().sum() for s in samples] ) src_lengths, sort_order = src_lengths.sort(descending=True) id = id.index_select(0, sort_order) src_tokens = src_tokens.index_select(0, sort_order) prev_output_tokens = None target = None if samples[0].get("target", None) is not None: target = merge( "target", left_pad=left_pad_target, pad_to_length=pad_to_length["target"] if pad_to_length is not None else None, ) target = target.index_select(0, sort_order) tgt_lengths = torch.LongTensor( [s["target"].ne(pad_idx).long().sum() for s in samples] ).index_select(0, sort_order) ntokens = tgt_lengths.sum().item() if samples[0].get("prev_output_tokens", None) is not None: prev_output_tokens = merge("prev_output_tokens", left_pad=left_pad_target) elif input_feeding: # we create a shifted version of targets for feeding the # previous output token(s) into the next decoder step prev_output_tokens = merge( "target", left_pad=left_pad_target, move_eos_to_beginning=True, pad_to_length=pad_to_length["target"] if pad_to_length is not None else None, ) else: ntokens = src_lengths.sum().item() batch = { "id": id, "nsentences": len(samples), "ntokens": ntokens, "net_input": {"src_tokens": src_tokens, "src_lengths": src_lengths,}, "target": target, } if prev_output_tokens is not None: batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select( 0, sort_order ) if samples[0].get("alignment", None) is not None: bsz, tgt_sz = batch["target"].shape src_sz = batch["net_input"]["src_tokens"].shape[1] offsets = torch.zeros((len(sort_order), 2), dtype=torch.long) offsets[:, 1] += torch.arange(len(sort_order), dtype=torch.long) * tgt_sz if left_pad_source: offsets[:, 0] += src_sz - src_lengths if left_pad_target: offsets[:, 1] += tgt_sz - tgt_lengths alignments = [ alignment + offset for align_idx, offset, src_len, tgt_len in zip( sort_order, offsets, src_lengths, tgt_lengths ) for alignment in [samples[align_idx]["alignment"].view(-1, 2)] if check_alignment(alignment, src_len, tgt_len) ] if len(alignments) > 0: alignments = torch.cat(alignments, dim=0) align_weights = compute_alignment_weights(alignments) batch["alignments"] = alignments batch["align_weights"] = align_weights if samples[0].get("constraints", None) is not None: # Collate the packed constraints across the samples, padding to # the length of the longest sample. lens = [sample.get("constraints").size(0) for sample in samples] max_len = max(lens) constraints = torch.zeros((len(samples), max(lens))).long() for i, sample in enumerate(samples): constraints[i, 0 : lens[i]] = samples[i].get("constraints") batch["constraints"] = constraints.index_select(0, sort_order) return batch class LanguagePairDataset(FairseqDataset): """ A pair of torch.utils.data.Datasets. Args: src (torch.utils.data.Dataset): source dataset to wrap src_sizes (List[int]): source sentence lengths src_dict (~fairseq.data.Dictionary): source vocabulary tgt (torch.utils.data.Dataset, optional): target dataset to wrap tgt_sizes (List[int], optional): target sentence lengths tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary left_pad_source (bool, optional): pad source tensors on the left side (default: True). left_pad_target (bool, optional): pad target tensors on the left side (default: False). shuffle (bool, optional): shuffle dataset elements before batching (default: True). input_feeding (bool, optional): create a shifted version of the targets to be passed into the model for teacher forcing (default: True). remove_eos_from_source (bool, optional): if set, removes eos from end of source if it's present (default: False). append_eos_to_target (bool, optional): if set, appends eos to end of target if it's absent (default: False). align_dataset (torch.utils.data.Dataset, optional): dataset containing alignments. constraints (Tensor, optional): 2d tensor with a concatenated, zero- delimited list of constraints for each sentence. append_bos (bool, optional): if set, appends bos to the beginning of source/target sentence. num_buckets (int, optional): if set to a value greater than 0, then batches will be bucketed into the given number of batch shapes. src_lang_id (int, optional): source language ID, if set, the collated batch will contain a field 'src_lang_id' in 'net_input' which indicates the source language of the samples. tgt_lang_id (int, optional): target language ID, if set, the collated batch will contain a field 'tgt_lang_id' which indicates the target language of the samples. """ def __init__( self, src, src_sizes, src_dict, tgt=None, tgt_sizes=None, tgt_dict=None, left_pad_source=True, left_pad_target=False, shuffle=True, input_feeding=True, remove_eos_from_source=False, append_eos_to_target=False, align_dataset=None, constraints=None, append_bos=False, eos=None, num_buckets=0, src_lang_id=None, tgt_lang_id=None, pad_to_multiple=1, ): if tgt_dict is not None: assert src_dict.pad() == tgt_dict.pad() assert src_dict.eos() == tgt_dict.eos() assert src_dict.unk() == tgt_dict.unk() if tgt is not None: assert len(src) == len( tgt ), "Source and target must contain the same number of examples" self.src = src self.tgt = tgt self.src_sizes = np.array(src_sizes) self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None self.sizes = ( np.vstack((self.src_sizes, self.tgt_sizes)).T if self.tgt_sizes is not None else self.src_sizes ) self.src_dict = src_dict self.tgt_dict = tgt_dict self.left_pad_source = left_pad_source self.left_pad_target = left_pad_target self.shuffle = shuffle self.input_feeding = input_feeding self.remove_eos_from_source = remove_eos_from_source self.append_eos_to_target = append_eos_to_target self.align_dataset = align_dataset if self.align_dataset is not None: assert ( self.tgt_sizes is not None ), "Both source and target needed when alignments are provided" self.constraints = constraints self.append_bos = append_bos self.eos = eos if eos is not None else src_dict.eos() self.src_lang_id = src_lang_id self.tgt_lang_id = tgt_lang_id if num_buckets > 0: from fairseq.data import BucketPadLengthDataset self.src = BucketPadLengthDataset( self.src, sizes=self.src_sizes, num_buckets=num_buckets, pad_idx=self.src_dict.pad(), left_pad=self.left_pad_source, ) self.src_sizes = self.src.sizes logger.info("bucketing source lengths: {}".format(list(self.src.buckets))) if self.tgt is not None: self.tgt = BucketPadLengthDataset( self.tgt, sizes=self.tgt_sizes, num_buckets=num_buckets, pad_idx=self.tgt_dict.pad(), left_pad=self.left_pad_target, ) self.tgt_sizes = self.tgt.sizes logger.info( "bucketing target lengths: {}".format(list(self.tgt.buckets)) ) # determine bucket sizes using self.num_tokens, which will return # the padded lengths (thanks to BucketPadLengthDataset) num_tokens = np.vectorize(self.num_tokens, otypes=[np.compat.long]) self.bucketed_num_tokens = num_tokens(np.arange(len(self.src))) self.buckets = [ (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens) ] else: self.buckets = None self.pad_to_multiple = pad_to_multiple def get_batch_shapes(self): return self.buckets def __getitem__(self, index): tgt_item = self.tgt[index] if self.tgt is not None else None src_item = self.src[index] # Append EOS to end of tgt sentence if it does not have an EOS and remove # EOS from end of src sentence if it exists. This is useful when we use # use existing datasets for opposite directions i.e., when we want to # use tgt_dataset as src_dataset and vice versa if self.append_eos_to_target: eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos() if self.tgt and self.tgt[index][-1] != eos: tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])]) if self.append_bos: bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos() if self.tgt and self.tgt[index][0] != bos: tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]]) bos = self.src_dict.bos() if self.src[index][0] != bos: src_item = torch.cat([torch.LongTensor([bos]), self.src[index]]) if self.remove_eos_from_source: eos = self.src_dict.eos() if self.src[index][-1] == eos: src_item = self.src[index][:-1] example = { "id": index, "source": src_item, "target": tgt_item, } if self.align_dataset is not None: example["alignment"] = self.align_dataset[index] if self.constraints is not None: example["constraints"] = self.constraints[index] return example def __len__(self): return len(self.src) def collater(self, samples, pad_to_length=None): """Merge a list of samples to form a mini-batch. Args: samples (List[dict]): samples to collate pad_to_length (dict, optional): a dictionary of {'source': source_pad_to_length, 'target': target_pad_to_length} to indicate the max length to pad to in source and target respectively. Returns: dict: a mini-batch with the following keys: - `id` (LongTensor): example IDs in the original input order - `ntokens` (int): total number of tokens in the batch - `net_input` (dict): the input to the Model, containing keys: - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in the source sentence of shape `(bsz, src_len)`. Padding will appear on the left if *left_pad_source* is ``True``. - `src_lengths` (LongTensor): 1D Tensor of the unpadded lengths of each source sentence of shape `(bsz)` - `prev_output_tokens` (LongTensor): a padded 2D Tensor of tokens in the target sentence, shifted right by one position for teacher forcing, of shape `(bsz, tgt_len)`. This key will not be present if *input_feeding* is ``False``. Padding will appear on the left if *left_pad_target* is ``True``. - `src_lang_id` (LongTensor): a long Tensor which contains source language IDs of each sample in the batch - `target` (LongTensor): a padded 2D Tensor of tokens in the target sentence of shape `(bsz, tgt_len)`. Padding will appear on the left if *left_pad_target* is ``True``. - `tgt_lang_id` (LongTensor): a long Tensor which contains target language IDs of each sample in the batch """ res = collate( samples, pad_idx=self.src_dict.pad(), eos_idx=self.eos, left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target, input_feeding=self.input_feeding, pad_to_length=pad_to_length, pad_to_multiple=self.pad_to_multiple, ) if self.src_lang_id is not None or self.tgt_lang_id is not None: src_tokens = res["net_input"]["src_tokens"] bsz = src_tokens.size(0) if self.src_lang_id is not None: res["net_input"]["src_lang_id"] = ( torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens) ) if self.tgt_lang_id is not None: res["tgt_lang_id"] = ( torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens) ) return res def num_tokens(self, index): """Return the number of tokens in a sample. This value is used to enforce ``--max-tokens`` during batching.""" return max( self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0, ) def num_tokens_vec(self, indices): """Return the number of tokens for a set of positions defined by indices. This value is used to enforce ``--max-tokens`` during batching.""" sizes = self.src_sizes[indices] if self.tgt_sizes is not None: sizes = np.maximum(sizes, self.tgt_sizes[indices]) return sizes def size(self, index): """Return an example's size as a float or tuple. This value is used when filtering a dataset with ``--max-positions``.""" return ( self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0, ) def ordered_indices(self): """Return an ordered list of indices. Batches will be constructed based on this order.""" if self.shuffle: indices = np.random.permutation(len(self)).astype(np.int64) else: indices = np.arange(len(self), dtype=np.int64) if self.buckets is None: # sort by target length, then source length if self.tgt_sizes is not None: indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")] return indices[np.argsort(self.src_sizes[indices], kind="mergesort")] else: # sort by bucketed_num_tokens, which is: # max(padded_src_len, padded_tgt_len) return indices[ np.argsort(self.bucketed_num_tokens[indices], kind="mergesort") ] @property def supports_prefetch(self): return getattr(self.src, "supports_prefetch", False) and ( getattr(self.tgt, "supports_prefetch", False) or self.tgt is None ) def prefetch(self, indices): self.src.prefetch(indices) if self.tgt is not None: self.tgt.prefetch(indices) if self.align_dataset is not None: self.align_dataset.prefetch(indices) def filter_indices_by_size(self, indices, max_sizes): """Filter a list of sample indices. Remove those that are longer than specified in max_sizes. Args: indices (np.array): original array of sample indices max_sizes (int or list[int] or tuple[int]): max sample size, can be defined separately for src and tgt (then list or tuple) Returns: np.array: filtered sample array list: list of removed indices """ return data_utils.filter_paired_dataset_indices_by_size( self.src_sizes, self.tgt_sizes, indices, max_sizes, )