# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import logging
import os
from collections import OrderedDict

import numpy as np
from fairseq import tokenizer, utils
from fairseq.data import ConcatDataset, Dictionary, TokenBlockDataset, data_utils
from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset
from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary
from fairseq.data.multi_corpus_sampled_dataset import MultiCorpusSampledDataset
from fairseq.tasks import LegacyFairseqTask, register_task


logger = logging.getLogger(__name__)


@register_task("cross_lingual_lm")
class CrossLingualLMTask(LegacyFairseqTask):
    """
    Task for training cross-lingual language models.

    For more details, see: https://arxiv.org/pdf/1901.07291.pdf

    Args:
        dictionary (Dictionary): the dictionary for the input of the task
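
    Example (illustrative only; the data path and language list are
    placeholders, and the required model/optimization flags are omitted)::

        fairseq-train /path/to/data --task cross_lingual_lm \\
            --monolingual-langs en,fr --tokens-per-sample 512 --shuffle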
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "data",
            help="colon separated path to data directories list, \
                            will be iterated upon during epochs in round-robin manner",
        )
        parser.add_argument(
            "--tokens-per-sample",
            default=512,
            type=int,
            help="max number of total tokens over all segments" " per sample",
        )
        parser.add_argument(
            "--monolingual-langs",
            default="en",
            type=str,
            help="comma separated list of languages for which we"
            " want to train XLM on",
        )
        parser.add_argument(
            "--shuffle",
            action="store_true",
            help="shuffle each monolingual dataset while" " training",
        )

    def __init__(self, args, dictionary):
        super().__init__(args)
        self.dictionary = dictionary
        self.seed = args.seed
        self.distributed_world_size = args.distributed_world_size
        self.langs2id = self._lang_to_id(args.monolingual_langs)

    def _lang_to_id(self, languages: str):
        """
        Build a map from languages to ids. These ids are used as segment labels
        for cross-lingual LM training.
        """
        lang2id = {}
        langs = [l.strip() for l in languages.split(",")]
        for lang_id, lang in enumerate(langs):
            lang2id[lang] = lang_id
        return lang2id

    @classmethod
    def load_dictionary(cls, filename):
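        """Load a :class:`MaskedLMDictionary` from *filename*."""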
        return MaskedLMDictionary.load(filename)

    @classmethod
    def build_dictionary(
        cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8
    ):
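        """Build a :class:`MaskedLMDictionary` from one or more text files."""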
        d = MaskedLMDictionary()
        for filename in filenames:
            Dictionary.add_file_to_dictionary(
                filename, d, tokenizer.tokenize_line, workers
            )
        d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
        return d

    @property
    def target_dictionary(self):
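        """Return the output dictionary, which is shared with the task's input dictionary."""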
        return self.dictionary

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task."""
        dictionary = MaskedLMDictionary.load(os.path.join(args.data, "dict.txt"))
        logger.info("dictionary: {} types".format(len(dictionary)))
        return cls(args, dictionary)

    def _load_single_lang_dataset(self, split, epoch):
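        """
        Load all shards of a single monolingual split (``split``, ``split1``,
        ``split2``, ...) from the data path selected by ``epoch``, wrap each
        shard in a :class:`TokenBlockDataset`, and return the resulting
        dataset together with its sizes array.
        """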
        loaded_datasets = []

        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else "")
            path = os.path.join(data_path, split_k)

            ds = data_utils.load_indexed_dataset(
                path, self.dictionary, self.args.dataset_impl
            )
            if ds is None:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        "Dataset not found: {} ({})".format(split, data_path)
                    )

            # Since a classification token is later added to each block
            # (by MaskedLMDataset), create blocks of length
            # tokens_per_sample - 1.
            loaded_datasets.append(
                TokenBlockDataset(
                    ds,
                    ds.sizes,
                    self.args.tokens_per_sample - 1,
                    pad=self.dictionary.pad(),
                    eos=self.dictionary.eos(),
                )
            )

            logger.info(
                "{} {} {} examples".format(data_path, split_k, len(loaded_datasets[-1]))
            )

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        return dataset, sizes

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
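
        Note:
            One :class:`MaskedLMDataset` is built per language (from the
            ``split.lang`` files) and the per-language datasets are combined
            with :class:`MultiCorpusSampledDataset`.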
        """
        dataset_map = OrderedDict()

        for lang in self.langs2id.keys():
            # Datasets are expected to be in "split.lang" format (e.g., train.en)
            language_split = "{}.{}".format(split, lang)

            block_dataset, sizes = self._load_single_lang_dataset(
                split=language_split, epoch=epoch
            )

            dataset_map[lang] = MaskedLMDataset(
                dataset=block_dataset,
                sizes=sizes,
                vocab=self.dictionary,
                pad_idx=self.dictionary.pad(),
                mask_idx=self.dictionary.mask(),
                classif_token_idx=self.dictionary.eos(),
                sep_token_idx=self.dictionary.eos(),
                shuffle=getattr(self.args, "shuffle", False),
                has_pairs=False,
                segment_id=self.langs2id[lang],
                seed=self.seed,
            )

        self.datasets[split] = MultiCorpusSampledDataset(dataset_map)
        logger.info(
            "{} {} {} examples".format(
                utils.split_paths(self.args.data)[epoch - 1],
                split,
                len(self.datasets[split]),
            )
        )