# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict, defaultdict
import json
import os
import logging
from argparse import ArgumentError

from fairseq import options, models
from fairseq.data import (
    data_utils,
    Dictionary,
    LanguagePairDataset,
    IndexedDataset,
    FairseqDataset,
)
from .multitask_data_utils import (
    MultitaskDatasetWrapper,
    MultidatasetEpochBatchIterator,
)

from fairseq.tasks import LegacyFairseqTask, register_task

logger = logging.getLogger(__name__)


@register_task("laser")
class LaserTask(LegacyFairseqTask):
    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "configfile", metavar="PATH", help="dataset configuration file in json"
        )
        parser.add_argument(
            "--weighting-alpha",
            type=float,
            default=None,
            help="alpha for automatic weighting",
        )
        parser.add_argument(
            "--raw-text", action="store_true", help="load raw text dataset"
        )
        parser.add_argument(
            "--left-pad-source",
            default="True",
            type=str,
            metavar="BOOL",
            help="pad the source on the left (default: True)",
        )
        parser.add_argument(
            "--left-pad-target",
            default="False",
            type=str,
            metavar="BOOL",
            help="pad the target on the left (default: False)",
        )
        try:
            parser.add_argument(
                "--max-source-positions",
                default=1024,
                type=int,
                metavar="N",
                help="max number of tokens in the source sequence",
            )
            parser.add_argument(
                "--max-target-positions",
                default=1024,
                type=int,
                metavar="N",
                help="max number of tokens in the target sequence",
            )
        except ArgumentError:
            # These might have already been defined. Once we transition this
            # to hydra it should be fine to add them here.
            pass

    def __init__(self, args, config, src_dictionary, tgt_dictionary, num_tasks):
        super().__init__(args)
        self.config = config
        self.src_dictionary = src_dictionary
        self.tgt_dictionary = tgt_dictionary
        self.num_tasks = num_tasks

    @classmethod
    def setup_task(cls, args, **kwargs):
        with open(args.configfile, "r") as f:
            config = json.load(f)
        # Dataset ids are assumed to start at 0, so the number of tasks is
        # the highest id in the training config plus one.
        num_tasks = max(dataset["id"] for dataset in config["train"]) + 1

        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        src_dictionary = Dictionary.load(config["src_vocab"])
        tgt_dictionary = Dictionary.load(config["tgt_vocab"])

        logger.info(
            "| src Dictionary {} : {} types".format(
                config["src_vocab"], len(src_dictionary)
            )
        )
        logger.info(
            "| tgt Dictionary {} : {} types".format(
                config["tgt_vocab"], len(tgt_dictionary)
            )
        )

        return cls(args, config, src_dictionary, tgt_dictionary, num_tasks)

    # Experimental overriding for backtranslation
    def build_model(self, args):
        model = models.build_model(args, self)
        return model

    def dataset(self, split):
        if split not in self.datasets:
            raise KeyError("Dataset not loaded: " + split)
        return self.datasets[split]

    def load_dataset(self, split, epoch=1, **kwargs):
        """Load a dataset split."""

        def indexed_dataset(path, dictionary):
            if self.args.raw_text:
                raise Exception("Unable to handle raw text.")
            dataset = IndexedDataset(path, fix_lua_indexing=True)
            return dataset

        pair_datasets = OrderedDict()

        if split == "valid":
            self.datasets[split] = pair_datasets
            return

        if split not in self.config:
            raise FileNotFoundError(
                "Dataset not found in config file: {}".format(split)
            )

        size_by_corpus = defaultdict(int)
        size_sum = 0
        size_sum_with_subsampling = 0
        init_pair_datasets = {}

        for dataset_config in self.config[split]:
            src_path = os.path.dirname(dataset_config["src"])
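            # Assumed (but not enforced) path convention:
            # .../<corpus>/<language_pair>/<file>, so a hypothetical entry
            # like /data/europarl/de-en/train.de would yield corpus_name
            # "europarl" and language_pair_name "de-en" below.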
            corpus_name = src_path.split("/")[-2]
            language_pair_name = src_path.split("/")[-1]
            pair_datasets_key = corpus_name + "-" + language_pair_name

            logger.info(f"loading... {pair_datasets_key}")
            if "src" in dataset_config:
                src_dataset = indexed_dataset(
                    dataset_config["src"], self.src_dictionary
                )
            else:
                src_dataset = None

            if "tgt" in dataset_config:
                tgt_dataset = indexed_dataset(
                    dataset_config["tgt"], self.tgt_dictionary
                )
            else:
                tgt_dataset = None

            dataset = LanguagePairDataset(
                src_dataset,
                src_dataset.sizes,
                self.src_dictionary,
                tgt_dataset,
                tgt_dataset.sizes,
                self.tgt_dictionary,
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            )

            if pair_datasets_key in init_pair_datasets:
                logger.warning(
                    f"Ignoring already added {pair_datasets_key}. "
                    f"Consider using `sample` key in order to upsample."
                )
            else:
                init_pair_datasets[pair_datasets_key] = {
                    "dataset": dataset,
                    "sample": dataset_config.get("sample", None),
                    "id": dataset_config.get("id", None),
                    "corpus_name": corpus_name,
                    "len": len(dataset),
                }

        length_sum = 0
        weighted_freqs_sum = 0
        freq_per_dataset = {}
        vmax = 0
        vmin = 1
        weighted_freq_per_dataset = {}

        if self.args.weighting_alpha:
            # Temperature-style weighting: compute each dataset's empirical
            # frequency, raise it to the power alpha, and renormalize. Each
            # dataset is later upsampled by vmax / weighted_freq, so the
            # dataset with the largest weighted frequency keeps a sample
            # factor of 1 and smaller datasets are upsampled relative to it.
            for key in init_pair_datasets:
                if init_pair_datasets[key]["sample"] is None:
                    length_sum += len(init_pair_datasets[key]["dataset"])

            for key in init_pair_datasets:
                if init_pair_datasets[key]["sample"] is None:
                    val = float(init_pair_datasets[key]["len"]) / length_sum
                    freq_per_dataset[key] = val
                    weighted_freqs_sum += val ** self.args.weighting_alpha

            for key in freq_per_dataset:
                val = (
                    freq_per_dataset[key] ** self.args.weighting_alpha
                    / weighted_freqs_sum
                )
                vmin = min(vmin, val)
                vmax = max(vmax, val)
                weighted_freq_per_dataset[key] = val

        for pair_datasets_key in init_pair_datasets:
            dataset_config = init_pair_datasets[pair_datasets_key]
            dataset = dataset_config["dataset"]
            sample = dataset_config["sample"]
            if sample is None:
                sample = 1.0

            if pair_datasets_key in weighted_freq_per_dataset:
                w = vmax / weighted_freq_per_dataset[pair_datasets_key]
                sample = w

            sample = round(sample)

            initial_sample = sample
            initial_pair_datasets_key = pair_datasets_key

            # Upsample by registering the dataset `sample` times, giving each
            # extra copy a distinct "-up"-suffixed key.
            while sample >= 1.0:
                assert (
                    pair_datasets_key not in pair_datasets
                ), f"{pair_datasets_key} already in"
                size_sum_with_subsampling += len(dataset)
                pair_datasets[pair_datasets_key] = MultitaskDatasetWrapper(
                    dataset, dataset_config.get("id", 0), 1.0, name=pair_datasets_key
                )
                size_sum += len(dataset)
                sample -= 1.0
                pair_datasets_key += "-up"

            assert sample < 1e-6, f"sample remains > 0 {pair_datasets_key}"

            logger.info(
                f"added pair {initial_pair_datasets_key} length {len(dataset)} "
                f"new_length = {len(dataset) * initial_sample}"
            )
            # Tally per-corpus size using the corpus recorded for this entry
            # at load time.
            size_by_corpus[dataset_config["corpus_name"]] += len(dataset)

        self.datasets[split] = pair_datasets
        logger.info(
            f"Datasets number = {len(self.datasets[split])} size = {size_sum} "
            f"size_sum_with_subsampling = {size_sum_with_subsampling}"
        )

    @property
    def source_dictionary(self):
        return self.src_dictionary

    @property
    def target_dictionary(self):
        return self.tgt_dictionary

    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
    ):
        assert isinstance(dataset, OrderedDict)
        assert len(dataset)
        assert isinstance(dataset[next(iter(dataset))], FairseqDataset)

        # initialize the dataset with the correct starting epoch
        for _, dt in dataset.items():
            dt.set_epoch(epoch)

        indices = OrderedDict()
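        # Ordering, size filtering, and batching below are done per wrapped
        # dataset; MultidatasetEpochBatchIterator (see multitask_data_utils)
        # then combines the per-dataset batch samplers into a single epoch
        # iterator.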
        batch_sampler = OrderedDict()

        with data_utils.numpy_seed(seed + epoch):
            for key, dt in dataset.items():
                logger.info(f"\t ordered_indices {key}")
                indices[key] = dt.ordered_indices()

        # filter examples that are too large
        if max_positions is not None:
            for key, dt in dataset.items():
                logger.info(f"\t filter_by_size {key}")
                indices[key], ignored = dt.filter_indices_by_size(
                    indices[key], max_positions
                )

        for key, dt in dataset.items():
            logger.info(f"\t batch_by_size {key}")
            batch_sampler[key] = data_utils.batch_by_size(
                indices[key],
                dt.num_tokens,
                max_tokens=max_tokens,
                max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )

        epoch_iter = MultidatasetEpochBatchIterator(
            dataset=dataset,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
        )

        return epoch_iter
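
# A hypothetical `configfile` sketch, inferred from the keys this task reads
# ("src_vocab", "tgt_vocab", and per-split entries with "src"/"tgt" dataset
# paths plus optional "id" and "sample"); all paths and values below are
# illustrative only, and dataset paths follow the <corpus>/<language_pair>
# directory layout that load_dataset parses:
#
# {
#     "src_vocab": "/data/laser/dict.src.txt",
#     "tgt_vocab": "/data/laser/dict.tgt.txt",
#     "train": [
#         {
#             "id": 0,
#             "src": "/data/laser/europarl/de-en/train.de-en.de",
#             "tgt": "/data/laser/europarl/de-en/train.de-en.en",
#             "sample": 2
#         },
#         {
#             "id": 1,
#             "src": "/data/laser/opensubs/fr-en/train.fr-en.fr",
#             "tgt": "/data/laser/opensubs/fr-en/train.fr-en.en"
#         }
#     ]
# }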