# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import logging
import os

import numpy as np
from fairseq import tokenizer, utils
from fairseq.data import ConcatDataset, Dictionary, data_utils, indexed_dataset
from fairseq.data.legacy.block_pair_dataset import BlockPairDataset
from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset
from fairseq.data.legacy.masked_lm_dictionary import BertDictionary
from fairseq.tasks import LegacyFairseqTask, register_task


logger = logging.getLogger(__name__)


@register_task("legacy_masked_lm")
class LegacyMaskedLMTask(LegacyFairseqTask):
    """
    Task for training Masked LM (BERT) model.

    Args:
        args: parsed command-line arguments; ``args.seed`` is stored on the
            task, and ``data``, ``dataset_impl``, ``tokens_per_sample``,
            ``break_mode`` and ``shuffle_dataset`` are read when loading data
        dictionary (Dictionary): the dictionary for the input of the task
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "data",
            # Implicit concatenation instead of a backslash continuation, so
            # no stray indentation ends up inside the help text.
            help="colon separated path to data directories list, "
            "will be iterated upon during epochs in round-robin manner",
        )
        parser.add_argument(
            "--tokens-per-sample",
            default=512,
            type=int,
            help="max number of total tokens over all segments"
            " per sample for BERT dataset",
        )
        parser.add_argument(
            "--break-mode", default="doc", type=str, help="mode for breaking sentence"
        )
        parser.add_argument("--shuffle-dataset", action="store_true", default=False)

    def __init__(self, args, dictionary):
        """Store the task dictionary and the random seed taken from ``args``."""
        super().__init__(args)
        self.dictionary = dictionary
        self.seed = args.seed

    @classmethod
    def load_dictionary(cls, filename):
        """Load a :class:`BertDictionary` from ``filename``."""
        return BertDictionary.load(filename)

    @classmethod
    def build_dictionary(
        cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8
    ):
        """Build a :class:`BertDictionary` from a list of text files.

        Args:
            filenames (list): paths of the text files to add to the dictionary
            workers (int): passed to ``Dictionary.add_file_to_dictionary``
            threshold (int): forwarded to ``BertDictionary.finalize``
            nwords (int): forwarded to ``BertDictionary.finalize``
            padding_factor (int): forwarded to ``BertDictionary.finalize``

        Returns:
            the finalized dictionary
        """
        d = BertDictionary()
        for filename in filenames:
            Dictionary.add_file_to_dictionary(
                filename, d, tokenizer.tokenize_line, workers
            )
        d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
        return d

    @property
    def target_dictionary(self):
        """The task dictionary (the same object is used for input and target)."""
        return self.dictionary

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task.

        Loads ``dict.txt`` from the first directory listed in ``args.data``
        and returns a task instance built around it.
        """
        paths = utils.split_paths(args.data)
        assert len(paths) > 0
        dictionary = BertDictionary.load(os.path.join(paths[0], "dict.txt"))
        logger.info("dictionary: %d types", len(dictionary))

        return cls(args, dictionary)

    def load_dataset(self, split, epoch=1, combine=False):
        """Load a given dataset split.

        The result is stored in ``self.datasets[split]``; nothing is returned.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch number, used to pick one of the data
                directories in round-robin order
            combine (bool): if True, also load the numbered shards
                (``split1``, ``split2``, ...) until one is missing and
                concatenate them; if False, only the first shard is loaded

        Raises:
            FileNotFoundError: if the first shard of ``split`` is not found
        """
        loaded_datasets = []

        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        # The message needs a placeholder: passing ``data_path`` as an extra
        # argument without one makes logging fail when formatting the record.
        logger.info("data_path: %s", data_path)

        for k in itertools.count():
            # Shard 0 has no numeric suffix; later shards are split1, split2, ...
            split_k = split + (str(k) if k > 0 else "")
            path = os.path.join(data_path, split_k)
            ds = indexed_dataset.make_dataset(
                path,
                impl=self.args.dataset_impl,
                fix_lua_indexing=True,
                dictionary=self.dictionary,
            )

            if ds is None:
                if k > 0:
                    # Ran past the last existing shard.
                    break
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, data_path)
                )

            # Seed is offset by the shard index, so each shard is built under
            # a different (but reproducible) numpy seed.
            with data_utils.numpy_seed(self.seed + k):
                loaded_datasets.append(
                    BlockPairDataset(
                        ds,
                        self.dictionary,
                        ds.sizes,
                        self.args.tokens_per_sample,
                        break_mode=self.args.break_mode,
                        doc_break_size=1,
                    )
                )

            logger.info(
                "%s %s %d examples", data_path, split_k, len(loaded_datasets[-1])
            )

            if not combine:
                break

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        self.datasets[split] = MaskedLMDataset(
            dataset=dataset,
            sizes=sizes,
            vocab=self.dictionary,
            pad_idx=self.dictionary.pad(),
            mask_idx=self.dictionary.mask(),
            classif_token_idx=self.dictionary.cls(),
            sep_token_idx=self.dictionary.sep(),
            shuffle=self.args.shuffle_dataset,
            seed=self.seed,
        )