# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import logging
import os
from argparse import ArgumentError
from collections import OrderedDict, defaultdict

from fairseq import options, models
from fairseq.data import (
    data_utils,
    Dictionary,
    LanguagePairDataset,
    IndexedDataset,
    FairseqDataset,
)
from fairseq.tasks import LegacyFairseqTask, register_task

from .multitask_data_utils import (
    MultitaskDatasetWrapper,
    MultidatasetEpochBatchIterator,
)


logger = logging.getLogger(__name__)


@register_task("laser")
class LaserTask(LegacyFairseqTask):
    @staticmethod
    def add_args(parser):
| """Add task-specific arguments to the parser.""" | |
| parser.add_argument( | |
| "configfile", metavar="PATH", help="dataset configuration file in json" | |
| ) | |
| parser.add_argument( | |
| "--weighting-alpha", | |
| type=float, | |
| default=None, | |
| help="alpha for automatic weighting", | |
| ) | |
| parser.add_argument( | |
| "--raw-text", action="store_true", help="load raw text dataset" | |
| ) | |
| parser.add_argument( | |
| "--left-pad-source", | |
| default="True", | |
| type=str, | |
| metavar="BOOL", | |
| help="pad the source on the left (default: True)", | |
| ) | |
| parser.add_argument( | |
| "--left-pad-target", | |
| default="False", | |
| type=str, | |
| metavar="BOOL", | |
| help="pad the target on the left (default: False)", | |
| ) | |
| try: | |
| parser.add_argument( | |
| "--max-source-positions", | |
| default=1024, | |
| type=int, | |
| metavar="N", | |
| help="max number of tokens in the source sequence", | |
| ) | |
| parser.add_argument( | |
| "--max-target-positions", | |
| default=1024, | |
| type=int, | |
| metavar="N", | |
| help="max number of tokens in the target sequence", | |
| ) | |
| except ArgumentError: | |
| # this might have already been defined. Once we transition this to hydra it should be fine to add it here. | |
| pass | |

    def __init__(self, args, config, src_dictionary, tgt_dictionary, num_tasks):
        super().__init__(args)
        self.config = config
        self.src_dictionary = src_dictionary
        self.tgt_dictionary = tgt_dictionary
        self.num_tasks = num_tasks

    @classmethod
    def setup_task(cls, args, **kwargs):
        with open(args.configfile, "r") as f:
            config = json.load(f)
        num_tasks = max(dataset["id"] for dataset in config["train"]) + 1

        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        src_dictionary = Dictionary.load(config["src_vocab"])
        tgt_dictionary = Dictionary.load(config["tgt_vocab"])

        logger.info(
            "| src Dictionary {} : {} types".format(
                config["src_vocab"], len(src_dictionary)
            )
        )
        logger.info(
            "| tgt Dictionary {} : {} types".format(
                config["tgt_vocab"], len(tgt_dictionary)
            )
        )

        return cls(args, config, src_dictionary, tgt_dictionary, num_tasks)
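
    # A hedged sketch of the JSON configuration this task consumes, inferred only
    # from how setup_task() and load_dataset() read it (the keys "src_vocab",
    # "tgt_vocab", "train", "src", "tgt", "id" and "sample" come from the code in
    # this file; the paths and values below are made-up placeholders, not a
    # documented format):
    #
    # {
    #   "src_vocab": "/data/dict.src.txt",
    #   "tgt_vocab": "/data/dict.tgt.txt",
    #   "train": [
    #     {
    #       "id": 0,
    #       "src": "/data/<corpus>/<lang-pair>/train.src",
    #       "tgt": "/data/<corpus>/<lang-pair>/train.tgt",
    #       "sample": 2
    #     }
    #   ]
    # }
    #
    # "src"/"tgt" point at fairseq-binarized data loaded via IndexedDataset; the
    # corpus and language-pair names are recovered from the two trailing directory
    # components of os.path.dirname(src), and "sample" is an optional manual
    # upsampling factor (see load_dataset below).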

    # Experimental overriding for backtranslation
    def build_model(self, args):
        model = models.build_model(args, self)
        return model

    def dataset(self, split):
        if split not in self.datasets:
            raise KeyError("Dataset not loaded: " + split)
        return self.datasets[split]

    def load_dataset(self, split, epoch=1, **kwargs):
        """Load a dataset split."""

        def indexed_dataset(path, dictionary):
            if self.args.raw_text:
                raise Exception("Unable to handle raw text.")
            dataset = IndexedDataset(path, fix_lua_indexing=True)
            return dataset

        pair_datasets = OrderedDict()

        if split == "valid":
            self.datasets[split] = pair_datasets
            return

        if split not in self.config:
            raise FileNotFoundError(
                "Dataset not found in config file: {}".format(split)
            )

        size_by_corpus = defaultdict(int)
        size_sum = 0
        size_sum_with_subsampling = 0
        init_pair_datasets = {}

        for dataset_config in self.config[split]:
            src_path = os.path.dirname(dataset_config["src"])
            corpus_name = src_path.split("/")[-2]
            language_pair_name = src_path.split("/")[-1]
            pair_datasets_key = corpus_name + "-" + language_pair_name

            logger.info(f"loading... {pair_datasets_key}")
            if "src" in dataset_config:
                src_dataset = indexed_dataset(
                    dataset_config["src"], self.src_dictionary
                )
            else:
                src_dataset = None

            if "tgt" in dataset_config:
                tgt_dataset = indexed_dataset(
                    dataset_config["tgt"], self.tgt_dictionary
                )
            else:
                tgt_dataset = None

            dataset = LanguagePairDataset(
                src_dataset,
                src_dataset.sizes,
                self.src_dictionary,
                tgt_dataset,
                tgt_dataset.sizes,
                self.tgt_dictionary,
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            )

            if pair_datasets_key in init_pair_datasets:
                logger.warning(
                    f"Ignoring already added {pair_datasets_key}. "
                    f"Consider using `sample` key in order to upsample."
                )
            else:
                init_pair_datasets[pair_datasets_key] = {
                    "dataset": dataset,
                    "sample": dataset_config.get("sample", None),
                    "id": dataset_config.get("id", None),
                    "len": len(dataset),
                }

        length_sum = 0
        weighted_freqs_sum = 0
        freq_per_dataset = {}
        vmax = 0
        vmin = 1
        weighted_freq_per_dataset = {}

        if self.args.weighting_alpha:
            for key in init_pair_datasets:
                if init_pair_datasets[key]["sample"] is None:
                    length_sum += len(init_pair_datasets[key]["dataset"])

            for key in init_pair_datasets:
                if init_pair_datasets[key]["sample"] is None:
                    val = float(init_pair_datasets[key]["len"]) / length_sum
                    freq_per_dataset[key] = val
                    weighted_freqs_sum += val ** self.args.weighting_alpha

            for key in freq_per_dataset:
                val = (
                    freq_per_dataset[key] ** self.args.weighting_alpha
                    / weighted_freqs_sum
                )
                vmin = min(vmin, val)
                vmax = max(vmax, val)
                weighted_freq_per_dataset[key] = val
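
        # Illustrative example (made-up sizes, not from the original code): with two
        # unsampled datasets of 90k and 10k pairs and --weighting-alpha 0.5, the raw
        # frequencies are 0.9 and 0.1; after the alpha power and normalization the
        # weights are 0.9**0.5 / Z = 0.75 and 0.1**0.5 / Z = 0.25, where
        # Z = 0.9**0.5 + 0.1**0.5. vmax / val then gives upsampling factors of 1 and 3,
        # so the loop below adds the smaller dataset three times, suffixing the
        # repeated keys with "-up".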
        for pair_datasets_key in init_pair_datasets:
            dataset_config = init_pair_datasets[pair_datasets_key]
            dataset = dataset_config["dataset"]
            sample = dataset_config["sample"]
            if sample is None:
                sample = 1.0

            if pair_datasets_key in weighted_freq_per_dataset:
                w = vmax / weighted_freq_per_dataset[pair_datasets_key]
                sample = w

            sample = round(sample)

            initial_sample = sample
            initial_pair_datasets_key = pair_datasets_key

            while sample >= 1.0:
                assert (
                    pair_datasets_key not in pair_datasets
                ), f"{pair_datasets_key} already in"
                size_sum_with_subsampling += len(dataset)
                pair_datasets[pair_datasets_key] = MultitaskDatasetWrapper(
                    dataset, dataset_config.get("id", 0), 1.0, name=pair_datasets_key
                )
                size_sum += len(dataset)
                sample -= 1.0
                pair_datasets_key += "-up"

            assert sample < 1e-6, f"sample remains > 0 {pair_datasets_key}"

            logger.info(
                f"added pair {initial_pair_datasets_key} length {len(dataset)} new_length = {len(dataset)*initial_sample}"
            )
            size_by_corpus[corpus_name] += len(dataset)

        self.datasets[split] = pair_datasets
        logger.info(
            f"Datasets number = {len(self.datasets[split])} size = {size_sum} size_sum_with_subsampling = {size_sum_with_subsampling}"
        )

    @property
    def source_dictionary(self):
        return self.src_dictionary

    @property
    def target_dictionary(self):
        return self.tgt_dictionary

    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
    ):
        assert isinstance(dataset, OrderedDict)
        assert len(dataset)
        assert isinstance(dataset[next(iter(dataset))], FairseqDataset)

        # initialize the dataset with the correct starting epoch
        for _, dt in dataset.items():
            dt.set_epoch(epoch)

        indices = OrderedDict()
        batch_sampler = OrderedDict()

        with data_utils.numpy_seed(seed + epoch):
            for key, dt in dataset.items():
                logger.info(f"\t ordered_indices {key}")
                indices[key] = dt.ordered_indices()

        # filter examples that are too large
        if max_positions is not None:
            for key, dt in dataset.items():
                logger.info(f"\t filter_by_size {key}")
                indices[key], ignored = dt.filter_indices_by_size(
                    indices[key], max_positions
                )

        for key, dt in dataset.items():
            logger.info(f"\t batch_by_size {key}")
            batch_sampler[key] = data_utils.batch_by_size(
                indices[key],
                dt.num_tokens,
                max_tokens=max_tokens,
                max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )

        epoch_iter = MultidatasetEpochBatchIterator(
            dataset=dataset,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
        )

        return epoch_iter
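

# Hedged usage sketch (not part of the original module): assuming the task is
# registered as "laser" above, training would typically be launched through
# fairseq-train, passing the JSON configuration as the positional argument
# defined in add_args(), e.g. something along the lines of:
#
#   fairseq-train /path/to/config.json \
#       --task laser \
#       --weighting-alpha 0.5 \
#       --left-pad-source True --left-pad-target False \
#       ...
#
# The remaining model and optimization flags depend on the surrounding LASER
# example and are not specified here.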