# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

from fairseq import utils
from fairseq.data import LanguagePairDataset

from . import register_task
from .translation import TranslationTask, load_langpair_dataset


@register_task("translation_from_pretrained_bart")
class TranslationFromPretrainedBARTTask(TranslationTask):
    """
    Translate from source language to target language with a model initialized
    with a multilingual pretrain.

    Args:
        src_dict (~fairseq.data.Dictionary): dictionary for the source language
        tgt_dict (~fairseq.data.Dictionary): dictionary for the target language

    .. note::

        The translation task is compatible with :mod:`fairseq-train`,
        :mod:`fairseq-generate` and :mod:`fairseq-interactive`.

    The translation task provides the following additional command-line
    arguments:

    .. argparse::
        :ref: fairseq.tasks.translation_parser
        :prog:
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        TranslationTask.add_args(parser)
        parser.add_argument('--langs', type=str, metavar='LANG',
                            help='comma-separated list of monolingual languages, '
                                 'for example, "en,de,fr". These should match the '
                                 'langs from pretraining (and be in the same order). '
                                 'You should always add all pretraining language idx '
                                 'during finetuning.')
        parser.add_argument('--prepend-bos', action='store_true',
                            help='prepend bos token to each sentence, which matches '
                                 'mBART pretraining')
        # fmt: on

    def __init__(self, args, src_dict, tgt_dict):
        super().__init__(args, src_dict, tgt_dict)
        self.langs = args.langs.split(",")
        # Extend both dictionaries with the language-id tokens (e.g. "[en_XX]")
        # and the mask token, so the vocabulary matches mBART pretraining.
        for d in [src_dict, tgt_dict]:
            for l in self.langs:
                d.add_symbol("[{}]".format(l))
            d.add_symbol("<mask>")

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.
        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        # infer langcode
        src, tgt = self.args.source_lang, self.args.target_lang

        self.datasets[split] = load_langpair_dataset(
            data_path,
            split,
            src,
            self.src_dict,
            tgt,
            self.tgt_dict,
            combine=combine,
            dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=getattr(self.args, "max_source_positions", 1024),
            max_target_positions=getattr(self.args, "max_target_positions", 1024),
            load_alignments=self.args.load_alignments,
            prepend_bos=getattr(self.args, "prepend_bos", False),
            append_source_id=True,
        )

    def build_generator(self, models, args, **unused):
        # Generation must stop at the target language-id token (e.g. "[ro_RO]")
        # rather than the regular eos, so pass it as `eos` to the generator/scorer.
        if getattr(args, "score_reference", False):
            from fairseq.sequence_scorer import SequenceScorer

            return SequenceScorer(
                self.target_dictionary,
                eos=self.tgt_dict.index("[{}]".format(self.args.target_lang)),
            )
        else:
            from fairseq.sequence_generator import SequenceGenerator

            return SequenceGenerator(
                models,
                self.target_dictionary,
                beam_size=getattr(args, "beam", 5),
                max_len_a=getattr(args, "max_len_a", 0),
                max_len_b=getattr(args, "max_len_b", 200),
                min_len=getattr(args, "min_len", 1),
                normalize_scores=(not getattr(args, "unnormalized", False)),
                len_penalty=getattr(args, "lenpen", 1),
                unk_penalty=getattr(args, "unkpen", 0),
                temperature=getattr(args, "temperature", 1.0),
                match_source_len=getattr(args, "match_source_len", False),
                no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
                eos=self.tgt_dict.index("[{}]".format(self.args.target_lang)),
            )

    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        # Append the source language-id token to each input sentence so that
        # inference inputs match the training format (append_source_id=True).
        src_lang_id = self.source_dictionary.index("[{}]".format(self.args.source_lang))
        source_tokens = []
        for s_t in src_tokens:
            s_t = torch.cat([s_t, s_t.new(1).fill_(src_lang_id)])
            source_tokens.append(s_t)
        dataset = LanguagePairDataset(
            source_tokens,
            src_lengths,
            self.source_dictionary,
            tgt_dict=self.target_dictionary,
            constraints=constraints,
        )
        return dataset
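

# Usage sketch (illustrative, not part of the task implementation): this task
# is selected on the command line via `--task translation_from_pretrained_bart`.
# Only --task, --langs, --source-lang/--target-lang, and --prepend-bos are
# interpreted by this module; the data path, language pair, checkpoint path,
# and other flags below are placeholder assumptions for a typical mBART
# finetuning run, not prescribed values.
#
#   fairseq-train /path/to/binarized-data \
#       --task translation_from_pretrained_bart \
#       --arch mbart_large \
#       --source-lang en_XX --target-lang ro_RO \
#       --langs en_XX,ro_RO,... \
#       --restore-file /path/to/mbart/model.pt
#
# Note that --langs must list every language id used during pretraining, in
# the pretraining order, so the dictionary extension in __init__ reproduces
# the pretrained embedding layout.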