# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from functools import lru_cache

import numpy as np
import torch
from fairseq.data import Dictionary, data_utils

from . import BaseWrapperDataset, LRUCacheDataset


class MaskTokensDataset(BaseWrapperDataset):
    """
    A wrapper Dataset for masked language modeling.

    Input items are masked according to the specified masking probability.

    Args:
        dataset: Dataset to wrap.
        vocab: Dictionary with the vocabulary and special tokens.
        pad_idx: Id of pad token in vocab
        mask_idx: Id of mask token in vocab
        return_masked_tokens: controls whether to return the masked input
            (the default) or to return a tensor with the original token IDs
            at masked positions (and *pad_idx* elsewhere). The latter is
            useful as targets for masked LM training.
        seed: Seed for random number generator for reproducibility.
        mask_prob: probability of replacing a token with *mask_idx*.
        leave_unmasked_prob: probability that a masked token is unmasked.
        random_token_prob: probability of replacing a masked token with a
            random token from the vocabulary.
        freq_weighted_replacement: sample random replacement words based on
            word frequencies in the vocab.
        mask_whole_words: only mask whole words. This should be a byte mask
            over vocab indices, indicating whether it is the beginning of a
            word. We will extend any mask to encompass the whole word.
        mask_multiple_length : repeat each mask index multiple times, i.e.
            each sampled start index is extended into a span of this many
            consecutive positions. Default value is 1.
        mask_stdev : standard deviation of masks distribution in case of
            multiple masking. Default value is 0.
    """

    @classmethod
    def apply_mask(cls, dataset: torch.utils.data.Dataset, *args, **kwargs):
        """Return the source and target datasets for masked LM training.

        The source dataset yields masked inputs (``return_masked_tokens=False``)
        and the target dataset yields original IDs at masked positions
        (``return_masked_tokens=True``). Both are built from the same
        arguments and seed their RNG from ``(seed, epoch, index)``, so --
        assuming ``data_utils.numpy_seed`` seeds deterministically from its
        arguments and both wrappers see the same epoch -- they select the same
        positions for masking.
        """
        # Wrap once so both views read through the same LRUCacheDataset
        # (presumably avoiding a second load of the same item; TODO confirm).
        dataset = LRUCacheDataset(dataset)
        return (
            LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=False)),
            LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=True)),
        )

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        vocab: Dictionary,
        pad_idx: int,
        mask_idx: int,
        return_masked_tokens: bool = False,
        seed: int = 1,
        mask_prob: float = 0.15,
        leave_unmasked_prob: float = 0.1,
        random_token_prob: float = 0.1,
        freq_weighted_replacement: bool = False,
        mask_whole_words: torch.Tensor = None,
        mask_multiple_length: int = 1,
        mask_stdev: float = 0.0,
    ):
        assert 0.0 < mask_prob < 1.0
        assert 0.0 <= random_token_prob <= 1.0
        assert 0.0 <= leave_unmasked_prob <= 1.0
        # Both are fractions of the masked positions, so together they must
        # not exceed 1.
        assert random_token_prob + leave_unmasked_prob <= 1.0
        assert mask_multiple_length >= 1
        assert mask_stdev >= 0.0

        self.dataset = dataset
        self.vocab = vocab
        self.pad_idx = pad_idx
        self.mask_idx = mask_idx
        self.return_masked_tokens = return_masked_tokens
        self.seed = seed
        self.mask_prob = mask_prob
        self.leave_unmasked_prob = leave_unmasked_prob
        self.random_token_prob = random_token_prob
        self.mask_whole_words = mask_whole_words
        self.mask_multiple_length = mask_multiple_length
        self.mask_stdev = mask_stdev

        # Sampling distribution for random replacement tokens; only needed
        # (and only defined) when random replacement can happen.
        if random_token_prob > 0.0:
            if freq_weighted_replacement:
                weights = np.array(self.vocab.count)
            else:
                weights = np.ones(len(self.vocab))
            # The first ``nspecial`` vocab entries get zero weight, so they
            # are never drawn as random replacements.
            weights[: self.vocab.nspecial] = 0
            self.weights = weights / weights.sum()

        self.epoch = 0

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True  # only the noise changes, not item sizes

    def set_epoch(self, epoch, **unused):
        """Record *epoch*; it is part of the RNG seed for every item."""
        super().set_epoch(epoch)
        self.epoch = epoch

    def __getitem__(self, index: int):
        return self.__getitem_cached__(self.seed, self.epoch, index)

    # NOTE: lru_cache on a method is shared at class level and holds strong
    # references to recently used instances (flake8-bugbear B019). It is kept
    # because a per-instance cached bound method would not pickle cleanly for
    # DataLoader worker processes.
    @lru_cache(maxsize=8)
    def __getitem_cached__(self, seed: int, epoch: int, index: int):
        """Build the masked input (or the MLM target) tensor for *index*.

        Results are memoized on ``(self, seed, epoch, index)``; the RNG is
        seeded from exactly those arguments so the output depends only on the
        cache key.

        Returns:
            A tensor the same length as ``self.dataset[index]``. With
            ``return_masked_tokens`` it holds the original token IDs at every
            position selected for masking and ``pad_idx`` elsewhere; otherwise
            it is the input with selected positions replaced by ``mask_idx``,
            a random vocab token, or left unchanged.
        """
        # Seed from the arguments (the cache key), not from self.seed /
        # self.epoch, so a cached value can never disagree with its key.
        with data_utils.numpy_seed(seed, epoch, index):
            item = self.dataset[index]
            sz = len(item)

            assert (
                self.mask_idx not in item
            ), "Dataset contains mask_idx (={}), this is not expected!".format(
                self.mask_idx,
            )

            if self.mask_whole_words is not None:
                # Per-token flag (looked up by token ID): does this token
                # begin a word?
                word_begins_mask = self.mask_whole_words.gather(0, item)
                word_begins_idx = word_begins_mask.nonzero().view(-1)
                # From here on ``sz`` counts words, and masking decisions are
                # made per word; ``word_lens`` expands them back to tokens via
                # np.repeat below.
                sz = len(word_begins_idx)
                words = np.split(word_begins_mask, word_begins_idx)[1:]
                assert len(words) == sz
                word_lens = list(map(len, words))

            # decide elements to mask
            mask = np.full(sz, False)
            num_mask = int(
                # add a random number for probabilistic rounding
                self.mask_prob * sz / float(self.mask_multiple_length)
                + np.random.rand()
            )

            # multiple masking as described in the vq-wav2vec paper (https://arxiv.org/abs/1910.05453)
            mask_idc = np.random.choice(sz, num_mask, replace=False)
            if self.mask_stdev > 0.0:
                # Span lengths vary per start index around
                # mask_multiple_length; negative draws are clipped to 0.
                lengths = np.random.normal(
                    self.mask_multiple_length, self.mask_stdev, size=num_mask
                )
                lengths = [max(0, int(round(x))) for x in lengths]
                mask_idc = np.asarray(
                    [
                        start + offset
                        for start, length in zip(mask_idc, lengths)
                        for offset in range(length)
                    ],
                    dtype=np.int64,
                )
            else:
                mask_idc = np.concatenate(
                    [mask_idc + i for i in range(self.mask_multiple_length)]
                )
            # Spans may run past the end of the sequence; drop those indices.
            mask_idc = mask_idc[mask_idc < len(mask)]
            try:
                mask[mask_idc] = True
            except Exception:  # something wrong; report context and re-raise
                print(
                    "Assigning mask indexes {} to mask {} failed!".format(
                        mask_idc, mask
                    )
                )
                raise

            if self.return_masked_tokens:
                # exit early if we're just returning the masked tokens
                # (i.e., the targets for masked LM training). The target covers
                # every selected position, including those the input view will
                # later leave unmasked or replace with a random token.
                if self.mask_whole_words is not None:
                    mask = np.repeat(mask, word_lens)
                new_item = np.full(len(mask), self.pad_idx)
                new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1]
                return torch.from_numpy(new_item)

            # decide unmasking and random replacement
            rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
            if rand_or_unmask_prob > 0.0:
                rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
                if self.random_token_prob == 0.0:
                    unmask = rand_or_unmask
                    rand_mask = None
                elif self.leave_unmasked_prob == 0.0:
                    unmask = None
                    rand_mask = rand_or_unmask
                else:
                    # Conditional probability of "leave unmasked" given that
                    # the position was picked for either unmask or random.
                    unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
                    decision = np.random.rand(sz) < unmask_prob
                    unmask = rand_or_unmask & decision
                    rand_mask = rand_or_unmask & (~decision)
            else:
                unmask = rand_mask = None

            if unmask is not None:
                # ``unmask`` is a subset of ``mask``, so XOR clears exactly
                # those positions.
                mask = mask ^ unmask

            if self.mask_whole_words is not None:
                mask = np.repeat(mask, word_lens)

            new_item = np.copy(item)
            new_item[mask] = self.mask_idx
            if rand_mask is not None:
                # Random-replacement positions are still in ``mask`` and were
                # just set to mask_idx; overwrite them with sampled tokens.
                num_rand = rand_mask.sum()
                if num_rand > 0:
                    if self.mask_whole_words is not None:
                        rand_mask = np.repeat(rand_mask, word_lens)
                        num_rand = rand_mask.sum()

                    new_item[rand_mask] = np.random.choice(
                        len(self.vocab),
                        num_rand,
                        p=self.weights,
                    )

            return torch.from_numpy(new_item)