# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import datetime
import logging
import time

import torch

from fairseq.data import (
    FairseqDataset,
    LanguagePairDataset,
    ListDataset,
    data_utils,
    iterators,
)
from fairseq.data.multilingual.multilingual_data_manager import (
    MultilingualDatasetManager,
)
from fairseq.data.multilingual.sampling_method import SamplingMethod
from fairseq.tasks import LegacyFairseqTask, register_task
from fairseq.utils import FileContentsAction


###
def get_time_gap(s, e):
    """Return the elapsed wall-clock time between two timestamps as a string."""
    return str(
        datetime.datetime.fromtimestamp(e) - datetime.datetime.fromtimestamp(s)
    )
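
# Example: get_time_gap(t0, time.time()) returns a human-readable delta such
# as "0:00:03.141590" (the str() of a datetime.timedelta).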


###


logger = logging.getLogger(__name__)


@register_task("translation_multi_simple_epoch")
class TranslationMultiSimpleEpochTask(LegacyFairseqTask):
    """
    Translate from one (source) language to another (target) language.

    Args:
        langs (List[str]): a list of languages that are being supported
        dicts (Dict[str, fairseq.data.Dictionary]): mapping from supported
            languages to their dictionaries
        training (bool): whether the task should be configured for training or not

    .. note::

        The translation task is compatible with :mod:`fairseq-train`,
        :mod:`fairseq-generate` and :mod:`fairseq-interactive`.

    The translation task provides the following additional command-line
    arguments:

    .. argparse::
        :ref: fairseq.tasks.translation_parser
        :prog:
    """
    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
                            help='inference source language')
        parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
                            help='inference target language')
        parser.add_argument('--lang-pairs', default=None, metavar='PAIRS',
                            help='comma-separated list of language pairs (in training order): en-de,en-fr,de-fr',
                            action=FileContentsAction)
        parser.add_argument('--keep-inference-langtok', action='store_true',
                            help='keep language tokens in inference output (e.g. for analysis or debugging)')

        SamplingMethod.add_arguments(parser)
        MultilingualDatasetManager.add_args(parser)
        # fmt: on

    def __init__(self, args, langs, dicts, training):
        super().__init__(args)
        self.langs = langs
        self.dicts = dicts
        self.training = training
        if training:
            self.lang_pairs = args.lang_pairs
        else:
            self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)]
        # eval_lang_pairs for multilingual translation is usually all of the
        # lang_pairs. However, for other multitask settings, or when we want to
        # optimize for certain languages, we want to use a different subset. Thus
        # the eval_lang_pairs class variable is provided for classes that extend
        # this class.
        self.eval_lang_pairs = self.lang_pairs
        # model_lang_pairs will be used to build encoder-decoder model pairs in
        # models.build_model(). This allows multitask sub-classes to build
        # models for pairs other than the input lang_pairs.
        self.model_lang_pairs = self.lang_pairs
        self.source_langs = [d.split("-")[0] for d in self.lang_pairs]
        self.target_langs = [d.split("-")[1] for d in self.lang_pairs]
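        # e.g. lang_pairs = ["en-de", "en-fr"] gives
        # source_langs = ["en", "en"] and target_langs = ["de", "fr"]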
        self.check_dicts(self.dicts, self.source_langs, self.target_langs)
        self.sampling_method = SamplingMethod.build_sampler(args, self)
        self.data_manager = MultilingualDatasetManager.setup_data_manager(
            args, self.lang_pairs, langs, dicts, self.sampling_method
        )

    def check_dicts(self, dicts, source_langs, target_langs):
        if self.args.source_dict is not None or self.args.target_dict is not None:
            # no need to check whether the source side and target side are sharing dictionaries
            return
        src_dict = dicts[source_langs[0]]
        tgt_dict = dicts[target_langs[0]]
        for src_lang in source_langs:
            assert src_dict == dicts[src_lang], (
                "Different dictionaries are specified for different source languages; "
                "TranslationMultiSimpleEpochTask only supports one shared dictionary across all source languages"
            )
        for tgt_lang in target_langs:
            assert tgt_dict == dicts[tgt_lang], (
                "Different dictionaries are specified for different target languages; "
                "TranslationMultiSimpleEpochTask only supports one shared dictionary across all target languages"
            )

    @classmethod
    def setup_task(cls, args, **kwargs):
        langs, dicts, training = MultilingualDatasetManager.prepare(
            cls.load_dictionary, args, **kwargs
        )
        return cls(args, langs, dicts, training)

    def has_sharded_data(self, split):
        return self.data_manager.has_sharded_data(split)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        shard_epoch = None  # stays None when reloading a non-sharded split
        if split in self.datasets:
            dataset = self.datasets[split]
            if self.has_sharded_data(split):
                if self.args.virtual_epoch_size is not None:
                    if dataset.load_next_shard:
                        shard_epoch = dataset.shard_epoch
                    else:
                        # no need to load the next shard, so skip loading;
                        # this also avoids always reloading from the beginning of the data
                        return
                else:
                    shard_epoch = epoch
        else:
            # estimate the shard epoch from the virtual data size and virtual epoch size
            shard_epoch = self.data_manager.estimate_global_pass_epoch(epoch)
        logger.info(f"loading data for {split} epoch={epoch}/{shard_epoch}")
        logger.info(f"mem usage: {data_utils.get_mem_usage()}")
        if split in self.datasets:
            del self.datasets[split]
            logger.info("old dataset deleted manually")
            logger.info(f"mem usage: {data_utils.get_mem_usage()}")
        self.datasets[split] = self.data_manager.load_dataset(
            split,
            self.training,
            epoch=epoch,
            combine=combine,
            shard_epoch=shard_epoch,
            **kwargs,
        )

    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        if constraints is not None:
            raise NotImplementedError(
                "Constrained decoding with the multilingual_translation task is not supported"
            )

        src_data = ListDataset(src_tokens, src_lengths)
        dataset = LanguagePairDataset(src_data, src_lengths, self.source_dictionary)
        src_langtok_spec, tgt_langtok_spec = self.args.langtoks["main"]
        if self.args.lang_tok_replacing_bos_eos:
            dataset = self.data_manager.alter_dataset_langtok(
                dataset,
                src_eos=self.source_dictionary.eos(),
                src_lang=self.args.source_lang,
                tgt_eos=self.target_dictionary.eos(),
                tgt_lang=self.args.target_lang,
                src_langtok_spec=src_langtok_spec,
                tgt_langtok_spec=tgt_langtok_spec,
            )
        else:
            dataset.src = self.data_manager.src_dataset_tranform_func(
                self.args.source_lang,
                self.args.target_lang,
                dataset=dataset.src,
                spec=src_langtok_spec,
            )
        return dataset
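
    # Usage sketch (hypothetical token ids; in practice fairseq-interactive
    # builds src_tokens/src_lengths from the binarized input):
    #
    #   dataset = task.build_dataset_for_inference(
    #       src_tokens=[torch.LongTensor([42, 43, 2])],
    #       src_lengths=[3],
    #   )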

    def build_generator(
        self,
        models,
        args,
        seq_gen_cls=None,
        extra_gen_cls_kwargs=None,
    ):
        if not getattr(args, "keep_inference_langtok", False):
            _, tgt_langtok_spec = self.args.langtoks["main"]
            if tgt_langtok_spec:
                tgt_lang_tok = self.data_manager.get_decoder_langtok(
                    self.args.target_lang, tgt_langtok_spec
                )
                extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
                extra_gen_cls_kwargs["symbols_to_strip_from_output"] = {tgt_lang_tok}

        return super().build_generator(
            models,
            args,
            seq_gen_cls=seq_gen_cls,
            extra_gen_cls_kwargs=extra_gen_cls_kwargs,
        )

    def build_model(self, args):
        return super().build_model(args)

    def valid_step(self, sample, model, criterion):
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
        return loss, sample_size, logging_output

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
        with torch.no_grad():
            _, tgt_langtok_spec = self.args.langtoks["main"]
            if not self.args.lang_tok_replacing_bos_eos:
                if prefix_tokens is None and tgt_langtok_spec:
                    tgt_lang_tok = self.data_manager.get_decoder_langtok(
                        self.args.target_lang, tgt_langtok_spec
                    )
                    src_tokens = sample["net_input"]["src_tokens"]
                    bsz = src_tokens.size(0)
                    prefix_tokens = (
                        torch.LongTensor([[tgt_lang_tok]]).expand(bsz, 1).to(src_tokens)
                    )
                return generator.generate(
                    models,
                    sample,
                    prefix_tokens=prefix_tokens,
                    constraints=constraints,
                )
            else:
                return generator.generate(
                    models,
                    sample,
                    prefix_tokens=prefix_tokens,
                    bos_token=self.data_manager.get_decoder_langtok(
                        self.args.target_lang, tgt_langtok_spec
                    )
                    if tgt_langtok_spec
                    else self.target_dictionary.eos(),
                )

    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.args.max_source_positions, self.args.max_target_positions)

    @property
    def source_dictionary(self):
        return self.data_manager.get_source_dictionary(self.source_langs[0])

    @property
    def target_dictionary(self):
        return self.data_manager.get_target_dictionary(self.target_langs[0])

    def create_batch_sampler_func(
        self,
        max_positions,
        ignore_invalid_inputs,
        max_tokens,
        max_sentences,
        required_batch_size_multiple=1,
        seed=1,
    ):
        def construct_batch_sampler(dataset, epoch):
            splits = [
                s for s, _ in self.datasets.items() if self.datasets[s] == dataset
            ]
            split = splits[0] if len(splits) > 0 else None
            # NEW implementation
            if epoch is not None:
                # initialize the dataset with the correct starting epoch
                dataset.set_epoch(epoch)

            # get indices ordered by example size
            start_time = time.time()
            logger.info(f"start batch sampler: mem usage: {data_utils.get_mem_usage()}")

            with data_utils.numpy_seed(seed):
                indices = dataset.ordered_indices()
            logger.info(
                f"[{split}] @batch_sampler order indices time: {get_time_gap(start_time, time.time())}"
            )
            logger.info(f"mem usage: {data_utils.get_mem_usage()}")

            # filter examples that are too large
            if max_positions is not None:
                my_time = time.time()
                indices = self.filter_indices_by_size(
                    indices, dataset, max_positions, ignore_invalid_inputs
                )
                logger.info(
                    f"[{split}] @batch_sampler filter_by_size time: {get_time_gap(my_time, time.time())}"
                )
                logger.info(f"mem usage: {data_utils.get_mem_usage()}")

            # create mini-batches with given size constraints
            my_time = time.time()
            batch_sampler = dataset.batch_by_size(
                indices,
                max_tokens=max_tokens,
                max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )
            logger.info(
                f"[{split}] @batch_sampler batch_by_size time: {get_time_gap(my_time, time.time())}"
            )
            logger.info(
                f"[{split}] per epoch batch_sampler set-up time: {get_time_gap(start_time, time.time())}"
            )
            logger.info(f"mem usage: {data_utils.get_mem_usage()}")

            return batch_sampler

        return construct_batch_sampler
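
    # Design note: the closure above is handed to EpochBatchIterator as its
    # batch_sampler, which invokes it as construct_batch_sampler(dataset, epoch),
    # so batches are rebuilt for each (virtual) epoch rather than cached once.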

    # we need to override get_batch_iterator because we want to reset the epoch iterator each time
    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
    ):
        """
        Get an iterator that yields batches of data from the given dataset.

        Args:
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_tokens (int, optional): max number of tokens in each batch
                (default: None).
            max_sentences (int, optional): max number of sentences in each
                batch (default: None).
            max_positions (optional): max sentence length supported by the
                model (default: None).
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long (default: False).
            required_batch_size_multiple (int, optional): require batch size to
                be a multiple of N (default: 1).
            seed (int, optional): seed for random number generator for
                reproducibility (default: 1).
            num_shards (int, optional): shard the data iterator into N
                shards (default: 1).
            shard_id (int, optional): which shard of the data iterator to
                return (default: 0).
            num_workers (int, optional): how many subprocesses to use for data
                loading. 0 means the data will be loaded in the main process
                (default: 0).
            epoch (int, optional): the epoch to start the iterator from
                (default: 1).
            data_buffer_size (int, optional): number of batches to
                preload (default: 0).
            disable_iterator_cache (bool, optional): don't cache the
                EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`)
                (default: False).

        Returns:
            ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
                given dataset split
        """
        # initialize the dataset with the correct starting epoch
        assert isinstance(dataset, FairseqDataset)
        if dataset in self.dataset_to_epoch_iter:
            return self.dataset_to_epoch_iter[dataset]
        if self.args.sampling_method == "RoundRobin":
            batch_iter = super().get_batch_iterator(
                dataset,
                max_tokens=max_tokens,
                max_sentences=max_sentences,
                max_positions=max_positions,
                ignore_invalid_inputs=ignore_invalid_inputs,
                required_batch_size_multiple=required_batch_size_multiple,
                seed=seed,
                num_shards=num_shards,
                shard_id=shard_id,
                num_workers=num_workers,
                epoch=epoch,
                data_buffer_size=data_buffer_size,
                disable_iterator_cache=disable_iterator_cache,
            )
            self.dataset_to_epoch_iter[dataset] = batch_iter
            return batch_iter

        construct_batch_sampler = self.create_batch_sampler_func(
            max_positions,
            ignore_invalid_inputs,
            max_tokens,
            max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
            seed=seed,
        )

        epoch_iter = iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=construct_batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
        )
        return epoch_iter
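

# End-to-end usage sketch (assumed flow mirroring fairseq's train loop; `args`
# stands in for a namespace built by fairseq's argument parser):
#
#   task = TranslationMultiSimpleEpochTask.setup_task(args)
#   task.load_dataset("train", epoch=1)
#   itr = task.get_batch_iterator(
#       dataset=task.dataset("train"),
#       max_tokens=args.max_tokens,
#       max_positions=task.max_positions(),
#       epoch=1,
#   ).next_epoch_itr(shuffle=True)
#   for sample in itr:
#       ...  # forward/backward via task.train_step(...)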