# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from collections import OrderedDict, defaultdict

from fairseq import utils
from fairseq.data import (
    BacktranslationDataset,
    IndexedCachedDataset,
    IndexedDataset,
    IndexedRawTextDataset,
    LanguagePairDataset,
    NoisingDataset,
    RoundRobinZipDatasets,
    data_utils,
    indexed_dataset,
)
from fairseq.models import FairseqMultiModel
from fairseq.sequence_generator import SequenceGenerator

from . import register_task
from .multilingual_translation import MultilingualTranslationTask


logger = logging.getLogger(__name__)


def _get_bt_dataset_key(lang_pair):
    return "bt:" + lang_pair


def _get_denoising_dataset_key(lang_pair):
    return "denoising:" + lang_pair


# ported from UnsupervisedMT
def parse_lambda_config(x):
    """
    Parse the configuration of lambda coefficient (for scheduling).
    x = "3"                  # lambda will be a constant equal to x
    x = "0:1,1000:0"         # lambda will start from 1 and linearly decrease
                             # to 0 during the first 1000 iterations
    x = "0:0,1000:0,2000:1"  # lambda will be equal to 0 for the first 1000
                             # iterations, then will linearly increase to 1 until iteration 2000
    """
    split = x.split(",")
    if len(split) == 1:
        return float(x), None
    else:
        # each entry has the form "step:weight"
        split = [s.split(":") for s in split]
        assert all(len(s) == 2 for s in split)
        assert all(k.isdigit() for k, _ in split)
        assert all(
            int(split[i][0]) < int(split[i + 1][0]) for i in range(len(split) - 1)
        )
        return float(split[0][1]), [(int(k), float(v)) for k, v in split]
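
# A minimal usage sketch of parse_lambda_config (illustrative values only, not
# part of the original module); the return value is (initial_lambda, schedule):
#
#   parse_lambda_config("1.0")
#   # -> (1.0, None): constant weight, no schedule
#   parse_lambda_config("0:1,1000:0")
#   # -> (1.0, [(0, 1.0), (1000, 0.0)]): start at 1, decay linearly to 0 by update 1000
#   parse_lambda_config("0:0,1000:0,2000:1")
#   # -> (0.0, [(0, 0.0), (1000, 0.0), (2000, 1.0)]): flat at 0, then ramp to 1 by update 2000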


@register_task("semisupervised_translation")
class SemisupervisedTranslationTask(MultilingualTranslationTask):
    """A task for training multiple translation models simultaneously.

    We iterate round-robin over batches from multiple language pairs, ordered
    according to the `--lang-pairs` argument.

    The training loop is roughly:

        for i in range(len(epoch)):
            for lang_pair in args.lang_pairs:
                batch = next_batch_for_lang_pair(lang_pair)
                loss = criterion(model_for_lang_pair(lang_pair), batch)
                loss.backward()
            optimizer.step()

    In practice, `next_batch_for_lang_pair` is abstracted in a FairseqDataset
    (e.g., `RoundRobinZipDatasets`) and `model_for_lang_pair` is a model that
    implements the `FairseqMultiModel` interface.

    During inference it is required to specify a single `--source-lang` and
    `--target-lang`, instead of `--lang-pairs`.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        MultilingualTranslationTask.add_args(parser)
        parser.add_argument('--lambda-parallel-config', default="1.0", type=str, metavar='CONFIG',
                            help='cross-entropy reconstruction coefficient (parallel data). '
                                 'use fixed weight during training if set to floating point number. '
                                 'use piecewise linear function over number of updates to schedule the '
                                 'weight with the format: step0:w0,step1:w1,...')
        parser.add_argument('--lambda-denoising-config', default="0.0", type=str, metavar='CONFIG',
                            help='cross-entropy reconstruction coefficient (denoising autoencoding). '
                                 'use fixed weight during training if set to floating point number. '
                                 'use piecewise linear function over number of updates to schedule the '
                                 'weight with the format: step0:w0,step1:w1,...')
        parser.add_argument('--lambda-otf-bt-config', default="0.0", type=str, metavar='CONFIG',
                            help='cross-entropy reconstruction coefficient (on-the-fly back-translation parallel data). '
                                 'use fixed weight during training if set to floating point number. '
                                 'use piecewise linear function over number of updates to schedule the '
                                 'weight with the format: step0:w0,step1:w1,...')
        parser.add_argument('--bt-max-len-a', default=1.1, type=float, metavar='N',
                            help='generate back-translated sequences of maximum length ax + b, where x is the '
                                 'source length')
        parser.add_argument('--bt-max-len-b', default=10.0, type=float, metavar='N',
                            help='generate back-translated sequences of maximum length ax + b, where x is the '
                                 'source length')
        parser.add_argument('--bt-beam-size', default=1, type=int, metavar='N',
                            help='beam size used in beam search of online back-translation')
        parser.add_argument('--max-word-shuffle-distance', default=3.0, type=float, metavar='N',
                            help='maximum word shuffle distance for denoising autoencoding data generation')
        parser.add_argument('--word-dropout-prob', default=0.1, type=float, metavar='N',
                            help='word dropout probability for denoising autoencoding data generation')
        parser.add_argument('--word-blanking-prob', default=0.2, type=float, metavar='N',
                            help='word blanking probability for denoising autoencoding data generation')
        # fmt: on

    def __init__(self, args, dicts, training):
        super().__init__(args, dicts, training)
        self.lambda_parallel, self.lambda_parallel_steps = parse_lambda_config(
            args.lambda_parallel_config
        )
        self.lambda_otf_bt, self.lambda_otf_bt_steps = parse_lambda_config(
            args.lambda_otf_bt_config
        )
        self.lambda_denoising, self.lambda_denoising_steps = parse_lambda_config(
            args.lambda_denoising_config
        )
        if self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None:
            denoising_lang_pairs = [
                "%s-%s" % (tgt, tgt)
                for tgt in {lang_pair.split("-")[1] for lang_pair in args.lang_pairs}
            ]
            self.model_lang_pairs = self.model_lang_pairs + denoising_lang_pairs
        self.backtranslate_datasets = {}
        self.backtranslators = {}

    @classmethod
    def setup_task(cls, args, **kwargs):
        dicts, training = MultilingualTranslationTask.prepare(args, **kwargs)
        return cls(args, dicts, training)
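
    # Data layout expected by load_dataset below (a sketch; "de-en" and the
    # file names are illustrative, not part of the original module):
    #   - parallel data:    {data_path}/train.de-en.de and {data_path}/train.de-en.en
    #   - monolingual data: {data_path}/train.en-None.en (target side, used for
    #     on-the-fly back-translation and/or denoising autoencoding)
    # The resulting RoundRobinZipDatasets is keyed by "de-en" for parallel
    # batches, "bt:de-en" for back-translation batches, and "denoising:de-en"
    # for denoising batches (see _get_bt_dataset_key / _get_denoising_dataset_key).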

    def load_dataset(self, split, epoch=1, **kwargs):
        """Load a dataset split."""
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        def split_exists(split, src, tgt, lang):
            if src is not None:
                filename = os.path.join(
                    data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)
                )
            else:
                filename = os.path.join(
                    data_path, "{}.{}-None.{}".format(split, src, tgt)
                )
            return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)

        def load_indexed_dataset(path, dictionary):
            return data_utils.load_indexed_dataset(
                path, dictionary, self.args.dataset_impl
            )

        # load parallel datasets
        src_datasets, tgt_datasets = {}, {}
        if (
            self.lambda_parallel > 0.0
            or self.lambda_parallel_steps is not None
            or not split.startswith("train")
        ):
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split("-")
                if split_exists(split, src, tgt, src):
                    prefix = os.path.join(
                        data_path, "{}.{}-{}.".format(split, src, tgt)
                    )
                elif split_exists(split, tgt, src, src):
                    prefix = os.path.join(
                        data_path, "{}.{}-{}.".format(split, tgt, src)
                    )
                else:
                    continue
                src_datasets[lang_pair] = load_indexed_dataset(
                    prefix + src, self.dicts[src]
                )
                tgt_datasets[lang_pair] = load_indexed_dataset(
                    prefix + tgt, self.dicts[tgt]
                )
                logger.info(
                    "parallel-{} {} {} examples".format(
                        data_path, split, len(src_datasets[lang_pair])
                    )
                )
            if len(src_datasets) == 0:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, data_path)
                )

        # back translation datasets
        backtranslate_datasets = {}
        if (
            self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None
        ) and split.startswith("train"):
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split("-")
                if not split_exists(split, tgt, None, tgt):
                    raise FileNotFoundError(
                        "Dataset not found: backtranslation {} ({})".format(
                            split, data_path
                        )
                    )
                filename = os.path.join(
                    data_path, "{}.{}-None.{}".format(split, tgt, tgt)
                )
                dataset = load_indexed_dataset(filename, self.dicts[tgt])
                lang_pair_dataset_tgt = LanguagePairDataset(
                    dataset,
                    dataset.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                )
                lang_pair_dataset = LanguagePairDataset(
                    dataset,
                    dataset.sizes,
                    src_dict=self.dicts[src],
                    tgt=dataset,
                    tgt_sizes=dataset.sizes,
                    tgt_dict=self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                )
                backtranslate_datasets[lang_pair] = BacktranslationDataset(
                    tgt_dataset=self.alter_dataset_langtok(
                        lang_pair_dataset_tgt,
                        src_eos=self.dicts[tgt].eos(),
                        src_lang=tgt,
                        tgt_lang=src,
                    ),
                    backtranslation_fn=self.backtranslators[lang_pair],
                    src_dict=self.dicts[src],
                    tgt_dict=self.dicts[tgt],
                    output_collater=self.alter_dataset_langtok(
                        lang_pair_dataset=lang_pair_dataset,
                        src_eos=self.dicts[src].eos(),
                        src_lang=src,
                        tgt_eos=self.dicts[tgt].eos(),
                        tgt_lang=tgt,
                    ).collater,
                )
                logger.info(
                    "backtranslate-{}: {} {} {} examples".format(
                        tgt,
                        data_path,
                        split,
                        len(backtranslate_datasets[lang_pair]),
                    )
                )
                self.backtranslate_datasets[lang_pair] = backtranslate_datasets[
                    lang_pair
                ]

        # denoising autoencoder
        noising_datasets = {}
        if (
            self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None
        ) and split.startswith("train"):
            for lang_pair in self.lang_pairs:
                _, tgt = lang_pair.split("-")
                if not split_exists(split, tgt, None, tgt):
                    continue
                filename = os.path.join(
                    data_path, "{}.{}-None.{}".format(split, tgt, tgt)
                )
                tgt_dataset1 = load_indexed_dataset(filename, self.dicts[tgt])
                tgt_dataset2 = load_indexed_dataset(filename, self.dicts[tgt])
                noising_dataset = NoisingDataset(
                    tgt_dataset1,
                    self.dicts[tgt],
                    seed=1,
                    max_word_shuffle_distance=self.args.max_word_shuffle_distance,
                    word_dropout_prob=self.args.word_dropout_prob,
                    word_blanking_prob=self.args.word_blanking_prob,
                )
                noising_datasets[lang_pair] = self.alter_dataset_langtok(
                    LanguagePairDataset(
                        noising_dataset,
                        tgt_dataset1.sizes,
                        self.dicts[tgt],
                        tgt_dataset2,
                        tgt_dataset2.sizes,
                        self.dicts[tgt],
                        left_pad_source=self.args.left_pad_source,
                        left_pad_target=self.args.left_pad_target,
                    ),
                    src_eos=self.dicts[tgt].eos(),
                    src_lang=tgt,
                    tgt_eos=self.dicts[tgt].eos(),
                    tgt_lang=tgt,
                )
                logger.info(
                    "denoising-{}: {} {} {} examples".format(
                        tgt,
                        data_path,
                        split,
                        len(noising_datasets[lang_pair]),
                    )
                )

        def language_pair_dataset(lang_pair):
            src, tgt = lang_pair.split("-")
            src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
            return self.alter_dataset_langtok(
                LanguagePairDataset(
                    src_dataset,
                    src_dataset.sizes,
                    self.dicts[src],
                    tgt_dataset,
                    tgt_dataset.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                ),
                self.dicts[src].eos(),
                src,
                self.dicts[tgt].eos(),
                tgt,
            )

        self.datasets[split] = RoundRobinZipDatasets(
            OrderedDict(
                [
                    (lang_pair, language_pair_dataset(lang_pair))
                    for lang_pair in src_datasets.keys()
                ]
                + [
                    (_get_bt_dataset_key(lang_pair), dataset)
                    for lang_pair, dataset in backtranslate_datasets.items()
                ]
                + [
                    (_get_denoising_dataset_key(lang_pair), dataset)
                    for lang_pair, dataset in noising_datasets.items()
                ]
            ),
            eval_key=None
            if self.training
            else "%s-%s" % (self.args.source_lang, self.args.target_lang),
        )
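
    # How build_model wires on-the-fly back-translation (a sketch; "de-en" is an
    # illustrative language pair, not part of the original module): for the pair
    # de-en, the reverse sub-model "en-de" is wrapped in a SequenceGenerator and
    # registered as self.backtranslators["de-en"]. The BacktranslationDataset
    # built in load_dataset then calls it to translate monolingual English
    # batches back into German, yielding synthetic de-en pairs that train the
    # "de-en" sub-model.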

    def build_model(self, args):
        from fairseq import models

        model = models.build_model(args, self)
        if not isinstance(model, FairseqMultiModel):
            raise ValueError(
                "SemisupervisedTranslationTask requires a FairseqMultiModel architecture"
            )

        # create SequenceGenerator for each model that has backtranslation dependency on it
        self.sequence_generators = {}
        if (
            self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None
        ) and self.training:
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split("-")
                key = "{}-{}".format(tgt, src)
                self.sequence_generators[key] = SequenceGenerator(
                    [model.models[key]],
                    tgt_dict=self.dicts[src],
                    beam_size=args.bt_beam_size,
                    max_len_a=args.bt_max_len_a,
                    max_len_b=args.bt_max_len_b,
                )
                decoder_lang_tok_idx = self.get_decoder_langtok(src)

                def backtranslate_fn(
                    sample,
                    model=model.models[key],
                    bos_token=decoder_lang_tok_idx,
                    sequence_generator=self.sequence_generators[key],
                ):
                    return sequence_generator.generate(
                        [model],
                        sample,
                        bos_token=bos_token,
                    )

                self.backtranslators[lang_pair] = backtranslate_fn

        return model

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        model.train()

        if update_num > 0:
            self.update_step(update_num)

        agg_loss, agg_sample_size = 0.0, 0.0
        # use a defaultdict so missing logging keys can be accumulated with "+="
        agg_logging_output = defaultdict(float)

        def forward_backward(model, samples, logging_output_key, weight):
            nonlocal agg_loss, agg_sample_size, agg_logging_output
            if samples is None or len(samples) == 0:
                return
            loss, sample_size, logging_output = criterion(model, samples)
            if ignore_grad:
                loss *= 0
            else:
                loss *= weight
            optimizer.backward(loss)
            agg_loss += loss.detach().item()
            # TODO make summing of the sample sizes configurable
            agg_sample_size += sample_size
            for k in logging_output:
                agg_logging_output[k] += logging_output[k]
                agg_logging_output[logging_output_key] += logging_output[k]

        if self.lambda_parallel > 0.0:
            for lang_pair in self.lang_pairs:
                forward_backward(
                    model.models[lang_pair],
                    sample[lang_pair],
                    lang_pair,
                    self.lambda_parallel,
                )

        if self.lambda_otf_bt > 0.0:
            for lang_pair in self.lang_pairs:
                sample_key = _get_bt_dataset_key(lang_pair)
                forward_backward(
                    model.models[lang_pair],
                    sample[sample_key],
                    sample_key,
                    self.lambda_otf_bt,
                )

        if self.lambda_denoising > 0.0:
            for lang_pair in self.lang_pairs:
                _, tgt = lang_pair.split("-")
                sample_key = _get_denoising_dataset_key(lang_pair)
                forward_backward(
                    model.models["{0}-{0}".format(tgt)],
                    sample[sample_key],
                    sample_key,
                    self.lambda_denoising,
                )

        return agg_loss, agg_sample_size, agg_logging_output

    def update_step(self, num_updates):
        def lambda_step_func(config, n_iter):
            """
            Update a lambda value according to its schedule configuration.
            """
            ranges = [
                i
                for i in range(len(config) - 1)
                if config[i][0] <= n_iter < config[i + 1][0]
            ]
            if len(ranges) == 0:
                assert n_iter >= config[-1][0]
                return config[-1][1]
            assert len(ranges) == 1
            i = ranges[0]
            x_a, y_a = config[i]
            x_b, y_b = config[i + 1]
            return y_a + (n_iter - x_a) * float(y_b - y_a) / float(x_b - x_a)

        if self.lambda_parallel_steps is not None:
            self.lambda_parallel = lambda_step_func(
                self.lambda_parallel_steps, num_updates
            )
        if self.lambda_denoising_steps is not None:
            self.lambda_denoising = lambda_step_func(
                self.lambda_denoising_steps, num_updates
            )
        if self.lambda_otf_bt_steps is not None:
            self.lambda_otf_bt = lambda_step_func(self.lambda_otf_bt_steps, num_updates)
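
# Worked example of the schedule above (a sketch with illustrative numbers):
# with --lambda-otf-bt-config "0:0,1000:0,2000:1", parse_lambda_config returns
# (0.0, [(0, 0.0), (1000, 0.0), (2000, 1.0)]), so update_step keeps
# lambda_otf_bt at 0.0 up to update 1000, interpolates linearly afterwards
# (e.g. 0.5 at update 1500), and holds it at 1.0 from update 2000 onwards.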