# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from collections import OrderedDict, defaultdict

from fairseq import utils
from fairseq.data import (
    BacktranslationDataset,
    LanguagePairDataset,
    NoisingDataset,
    RoundRobinZipDatasets,
    data_utils,
    indexed_dataset,
)
from fairseq.models import FairseqMultiModel
from fairseq.sequence_generator import SequenceGenerator

from . import register_task
from .multilingual_translation import MultilingualTranslationTask


logger = logging.getLogger(__name__)


def _get_bt_dataset_key(lang_pair):
    return "bt:" + lang_pair


def _get_denoising_dataset_key(lang_pair):
    return "denoising:" + lang_pair


# ported from UnsupervisedMT
def parse_lambda_config(x):
    """
    Parse the configuration of lambda coefficient (for scheduling).
    x = "3"                  # lambda will be a constant equal to x
    x = "0:1,1000:0"         # lambda will start from 1 and linearly decrease
                             # to 0 during the first 1000 iterations
    x = "0:0,1000:0,2000:1"  # lambda will be equal to 0 for the first 1000
                             # iterations, then will linearly increase to 1 until iteration 2000
    """
    split = x.split(",")
    if len(split) == 1:
        return float(x), None
    else:
        split = [s.split(":") for s in split]
        assert all(len(s) == 2 for s in split)
        assert all(k.isdigit() for k, _ in split)
        assert all(
            int(split[i][0]) < int(split[i + 1][0]) for i in range(len(split) - 1)
        )
        return float(split[0][1]), [(int(k), float(v)) for k, v in split]


@register_task("semisupervised_translation")
class SemisupervisedTranslationTask(MultilingualTranslationTask):
    """A task for training multiple translation models simultaneously.

    We iterate round-robin over batches from multiple language pairs, ordered
    according to the `--lang-pairs` argument.

    The training loop is roughly:

        for i in range(len(epoch)):
            for lang_pair in args.lang_pairs:
                batch = next_batch_for_lang_pair(lang_pair)
                loss = criterion(model_for_lang_pair(lang_pair), batch)
                loss.backward()
            optimizer.step()

    In practice, `next_batch_for_lang_pair` is abstracted in a FairseqDataset
    (e.g., `RoundRobinZipDatasets`) and `model_for_lang_pair` is a model that
    implements the `FairseqMultiModel` interface.

    During inference it is required to specify a single `--source-lang` and
    `--target-lang`, instead of `--lang-pairs`.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        MultilingualTranslationTask.add_args(parser)
        parser.add_argument('--lambda-parallel-config', default="1.0", type=str, metavar='CONFIG',
                            help='cross-entropy reconstruction coefficient (parallel data). '
                                 'use fixed weight during training if set to floating point number. '
                                 'use piecewise linear function over number of updates to schedule the '
                                 'weight with the format: step0:w0,step1:w1,...')
        parser.add_argument('--lambda-denoising-config', default="0.0", type=str, metavar='CONFIG',
                            help='cross-entropy reconstruction coefficient (denoising autoencoding). '
                                 'use fixed weight during training if set to floating point number. '
                                 'use piecewise linear function over number of updates to schedule the '
                                 'weight with the format: step0:w0,step1:w1,...')
        parser.add_argument('--lambda-otf-bt-config', default="0.0", type=str, metavar='CONFIG',
                            help='cross-entropy reconstruction coefficient (on-the-fly back-translation parallel data). '
                                 'use fixed weight during training if set to floating point number. '
                                 'use piecewise linear function over number of updates to schedule the '
                                 'weight with the format: step0:w0,step1:w1,...')
        parser.add_argument('--bt-max-len-a', default=1.1, type=float, metavar='N',
                            help='generate back-translated sequences of maximum length ax + b, where x is the '
                                 'source length')
        parser.add_argument('--bt-max-len-b', default=10.0, type=float, metavar='N',
                            help='generate back-translated sequences of maximum length ax + b, where x is the '
                                 'source length')
        parser.add_argument('--bt-beam-size', default=1, type=int, metavar='N',
                            help='beam size used in beam search of online back-translation')
        parser.add_argument('--max-word-shuffle-distance', default=3.0, type=float, metavar='N',
                            help='maximum word shuffle distance for denoising autoencoding data generation')
        parser.add_argument('--word-dropout-prob', default=0.1, type=float, metavar='N',
                            help='word dropout probability for denoising autoencoding data generation')
        parser.add_argument('--word-blanking-prob', default=0.2, type=float, metavar='N',
                            help='word blanking probability for denoising autoencoding data generation')
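        # Illustrative combination (example values only): keep the parallel loss fixed
        # while ramping up back-translation and decaying denoising, e.g.
        #   --lambda-parallel-config 1.0 \
        #   --lambda-otf-bt-config 0:0,10000:1 \
        #   --lambda-denoising-config 0:1,10000:0.1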
        # fmt: on

    def __init__(self, args, dicts, training):
        super().__init__(args, dicts, training)
        self.lambda_parallel, self.lambda_parallel_steps = parse_lambda_config(
            args.lambda_parallel_config
        )
        self.lambda_otf_bt, self.lambda_otf_bt_steps = parse_lambda_config(
            args.lambda_otf_bt_config
        )
        self.lambda_denoising, self.lambda_denoising_steps = parse_lambda_config(
            args.lambda_denoising_config
        )
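        # When denoising autoencoding is enabled, add tgt->tgt "language pairs" so
        # that the multi-model also contains a per-language reconstruction model
        # (accessed as model.models["<tgt>-<tgt>"] in train_step below).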
        if self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None:
            denoising_lang_pairs = [
                "%s-%s" % (tgt, tgt)
                for tgt in {lang_pair.split("-")[1] for lang_pair in args.lang_pairs}
            ]
            self.model_lang_pairs = self.model_lang_pairs + denoising_lang_pairs
        self.backtranslate_datasets = {}
        self.backtranslators = {}

    @classmethod
    def setup_task(cls, args, **kwargs):
        dicts, training = MultilingualTranslationTask.prepare(args, **kwargs)
        return cls(args, dicts, training)

    def load_dataset(self, split, epoch=1, **kwargs):
        """Load a dataset split."""
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
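        # If --data lists multiple directories, they are sharded across epochs
        # round-robin; each epoch loads from a single shard.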

        def split_exists(split, src, tgt, lang):
            if src is not None:
                filename = os.path.join(
                    data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)
                )
            else:
                filename = os.path.join(
                    data_path, "{}.{}-None.{}".format(split, src, tgt)
                )
            return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)

        def load_indexed_dataset(path, dictionary):
            return data_utils.load_indexed_dataset(
                path, dictionary, self.args.dataset_impl
            )

        # load parallel datasets
        src_datasets, tgt_datasets = {}, {}
        if (
            self.lambda_parallel > 0.0
            or self.lambda_parallel_steps is not None
            or not split.startswith("train")
        ):
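            # Parallel data is loaded whenever its loss is (or may become) active,
            # and always for validation/test splits.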
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split("-")
                if split_exists(split, src, tgt, src):
                    prefix = os.path.join(
                        data_path, "{}.{}-{}.".format(split, src, tgt)
                    )
                elif split_exists(split, tgt, src, src):
                    prefix = os.path.join(
                        data_path, "{}.{}-{}.".format(split, tgt, src)
                    )
                else:
                    continue
                src_datasets[lang_pair] = load_indexed_dataset(
                    prefix + src, self.dicts[src]
                )
                tgt_datasets[lang_pair] = load_indexed_dataset(
                    prefix + tgt, self.dicts[tgt]
                )
                logger.info(
                    "parallel-{} {} {} examples".format(
                        data_path, split, len(src_datasets[lang_pair])
                    )
                )
            if len(src_datasets) == 0:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, data_path)
                )

        # back translation datasets
        backtranslate_datasets = {}
        if (
            self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None
        ) and split.startswith("train"):
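            # For each pair src-tgt, wrap the target-side monolingual data in a
            # BacktranslationDataset: at batch time it invokes the backtranslation_fn
            # registered for this pair in build_model() to generate synthetic sources
            # on the fly, yielding (synthetic src, real tgt) training pairs.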
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split("-")
                if not split_exists(split, tgt, None, tgt):
                    raise FileNotFoundError(
                        "Dataset not found: backtranslation {} ({})".format(
                            split, data_path
                        )
                    )
                filename = os.path.join(
                    data_path, "{}.{}-None.{}".format(split, tgt, tgt)
                )
                dataset = load_indexed_dataset(filename, self.dicts[tgt])
                lang_pair_dataset_tgt = LanguagePairDataset(
                    dataset,
                    dataset.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                )
                lang_pair_dataset = LanguagePairDataset(
                    dataset,
                    dataset.sizes,
                    src_dict=self.dicts[src],
                    tgt=dataset,
                    tgt_sizes=dataset.sizes,
                    tgt_dict=self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                )
                backtranslate_datasets[lang_pair] = BacktranslationDataset(
                    tgt_dataset=self.alter_dataset_langtok(
                        lang_pair_dataset_tgt,
                        src_eos=self.dicts[tgt].eos(),
                        src_lang=tgt,
                        tgt_lang=src,
                    ),
                    backtranslation_fn=self.backtranslators[lang_pair],
                    src_dict=self.dicts[src],
                    tgt_dict=self.dicts[tgt],
                    output_collater=self.alter_dataset_langtok(
                        lang_pair_dataset=lang_pair_dataset,
                        src_eos=self.dicts[src].eos(),
                        src_lang=src,
                        tgt_eos=self.dicts[tgt].eos(),
                        tgt_lang=tgt,
                    ).collater,
                )
                logger.info(
                    "backtranslate-{}: {} {} {} examples".format(
                        tgt,
                        data_path,
                        split,
                        len(backtranslate_datasets[lang_pair]),
                    )
                )
                self.backtranslate_datasets[lang_pair] = backtranslate_datasets[
                    lang_pair
                ]

        # denoising autoencoder
        noising_datasets = {}
        if (
            self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None
        ) and split.startswith("train"):
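            # Denoising autoencoding: load two views of the same monolingual target
            # data; NoisingDataset corrupts one copy (word shuffling, dropout and
            # blanking) to act as the source, while the clean copy is the target.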
            for lang_pair in self.lang_pairs:
                _, tgt = lang_pair.split("-")
                if not split_exists(split, tgt, None, tgt):
                    continue
                filename = os.path.join(
                    data_path, "{}.{}-None.{}".format(split, tgt, tgt)
                )
                tgt_dataset1 = load_indexed_dataset(filename, self.dicts[tgt])
                tgt_dataset2 = load_indexed_dataset(filename, self.dicts[tgt])
                noising_dataset = NoisingDataset(
                    tgt_dataset1,
                    self.dicts[tgt],
                    seed=1,
                    max_word_shuffle_distance=self.args.max_word_shuffle_distance,
                    word_dropout_prob=self.args.word_dropout_prob,
                    word_blanking_prob=self.args.word_blanking_prob,
                )
                noising_datasets[lang_pair] = self.alter_dataset_langtok(
                    LanguagePairDataset(
                        noising_dataset,
                        tgt_dataset1.sizes,
                        self.dicts[tgt],
                        tgt_dataset2,
                        tgt_dataset2.sizes,
                        self.dicts[tgt],
                        left_pad_source=self.args.left_pad_source,
                        left_pad_target=self.args.left_pad_target,
                    ),
                    src_eos=self.dicts[tgt].eos(),
                    src_lang=tgt,
                    tgt_eos=self.dicts[tgt].eos(),
                    tgt_lang=tgt,
                )
                logger.info(
                    "denoising-{}: {} {} {} examples".format(
                        tgt,
                        data_path,
                        split,
                        len(noising_datasets[lang_pair]),
                    )
                )

        def language_pair_dataset(lang_pair):
            src, tgt = lang_pair.split("-")
            src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
            return self.alter_dataset_langtok(
                LanguagePairDataset(
                    src_dataset,
                    src_dataset.sizes,
                    self.dicts[src],
                    tgt_dataset,
                    tgt_dataset.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                ),
                self.dicts[src].eos(),
                src,
                self.dicts[tgt].eos(),
                tgt,
            )

        self.datasets[split] = RoundRobinZipDatasets(
            OrderedDict(
                [
                    (lang_pair, language_pair_dataset(lang_pair))
                    for lang_pair in src_datasets.keys()
                ]
                + [
                    (_get_bt_dataset_key(lang_pair), dataset)
                    for lang_pair, dataset in backtranslate_datasets.items()
                ]
                + [
                    (_get_denoising_dataset_key(lang_pair), dataset)
                    for lang_pair, dataset in noising_datasets.items()
                ]
            ),
            eval_key=None
            if self.training
            else "%s-%s" % (self.args.source_lang, self.args.target_lang),
        )

    def build_model(self, args):
        from fairseq import models

        model = models.build_model(args, self)
        if not isinstance(model, FairseqMultiModel):
            raise ValueError(
                "SemisupervisedTranslationTask requires a FairseqMultiModel architecture"
            )

        # create SequenceGenerator for each model that has backtranslation dependency on it
        self.sequence_generators = {}
        if (
            self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None
        ) and self.training:
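            # For each pair src-tgt, build a reverse-direction (tgt->src) generator
            # and register it under the forward pair; the BacktranslationDataset
            # created in load_dataset() calls it to produce synthetic sources.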
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split("-")
                key = "{}-{}".format(tgt, src)
                self.sequence_generators[key] = SequenceGenerator(
                    [model.models[key]],
                    tgt_dict=self.dicts[src],
                    beam_size=args.bt_beam_size,
                    max_len_a=args.bt_max_len_a,
                    max_len_b=args.bt_max_len_b,
                )
                decoder_lang_tok_idx = self.get_decoder_langtok(src)

                def backtranslate_fn(
                    sample,
                    model=model.models[key],
                    bos_token=decoder_lang_tok_idx,
                    sequence_generator=self.sequence_generators[key],
                ):
                    return sequence_generator.generate(
                        [model],
                        sample,
                        bos_token=bos_token,
                    )

                self.backtranslators[lang_pair] = backtranslate_fn

        return model

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        model.train()

        if update_num > 0:
            self.update_step(update_num)

        agg_loss, agg_sample_size = 0.0, 0.0
        agg_logging_output = defaultdict(float)
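        # Each enabled objective (parallel MT, online back-translation, denoising
        # autoencoding) runs its own forward/backward pass below, scaled by its
        # current lambda; losses and logging outputs are accumulated across them.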

        def forward_backward(model, samples, logging_output_key, weight):
            nonlocal agg_loss, agg_sample_size, agg_logging_output
            if samples is None or len(samples) == 0:
                return
            loss, sample_size, logging_output = criterion(model, samples)
            if ignore_grad:
                loss *= 0
            else:
                loss *= weight
            optimizer.backward(loss)
            agg_loss += loss.detach().item()
            # TODO make summing of the sample sizes configurable
            agg_sample_size += sample_size
            for k in logging_output:
                agg_logging_output[k] += logging_output[k]
                agg_logging_output["{}:{}".format(logging_output_key, k)] += logging_output[k]

        if self.lambda_parallel > 0.0:
            for lang_pair in self.lang_pairs:
                forward_backward(
                    model.models[lang_pair],
                    sample[lang_pair],
                    lang_pair,
                    self.lambda_parallel,
                )

        if self.lambda_otf_bt > 0.0:
            for lang_pair in self.lang_pairs:
                sample_key = _get_bt_dataset_key(lang_pair)
                forward_backward(
                    model.models[lang_pair],
                    sample[sample_key],
                    sample_key,
                    self.lambda_otf_bt,
                )

        if self.lambda_denoising > 0.0:
            for lang_pair in self.lang_pairs:
                _, tgt = lang_pair.split("-")
                sample_key = _get_denoising_dataset_key(lang_pair)
                forward_backward(
                    model.models["{0}-{0}".format(tgt)],
                    sample[sample_key],
                    sample_key,
                    self.lambda_denoising,
                )

        return agg_loss, agg_sample_size, agg_logging_output

    def update_step(self, num_updates):
        def lambda_step_func(config, n_iter):
            """
            Update a lambda value according to its schedule configuration.
            """
            ranges = [
                i
                for i in range(len(config) - 1)
                if config[i][0] <= n_iter < config[i + 1][0]
            ]
            if len(ranges) == 0:
                assert n_iter >= config[-1][0]
                return config[-1][1]
            assert len(ranges) == 1
            i = ranges[0]
            x_a, y_a = config[i]
            x_b, y_b = config[i + 1]
            return y_a + (n_iter - x_a) * float(y_b - y_a) / float(x_b - x_a)
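        # Worked example (illustrative): with config [(0, 0.0), (1000, 1.0)] and
        # n_iter = 250, interpolation gives 0.0 + 250 * (1.0 - 0.0) / 1000 = 0.25;
        # beyond the last breakpoint the final value is held constant.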

        if self.lambda_parallel_steps is not None:
            self.lambda_parallel = lambda_step_func(
                self.lambda_parallel_steps, num_updates
            )
        if self.lambda_denoising_steps is not None:
            self.lambda_denoising = lambda_step_func(
                self.lambda_denoising_steps, num_updates
            )
        if self.lambda_otf_bt_steps is not None:
            self.lambda_otf_bt = lambda_step_func(self.lambda_otf_bt_steps, num_updates)