# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
import torch
from omegaconf import II

from fairseq import metrics, utils
from fairseq.dataclass import ChoiceEnum
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationConfig, TranslationTask

from .logsumexp_moe import LogSumExpMoE
from .mean_pool_gating_network import MeanPoolGatingNetwork


METHOD_CHOICES = ChoiceEnum(["sMoElp", "sMoEup", "hMoElp", "hMoEup"])


@dataclass
class TranslationMoEConfig(TranslationConfig):
    method: METHOD_CHOICES = field(
        default="hMoEup",
        metadata={"help": "MoE method"},
    )
    num_experts: int = field(
        default=3,
        metadata={"help": "number of experts"},
    )
    mean_pool_gating_network: bool = field(
        default=False,
        metadata={"help": "use a simple mean-pooling gating network"},
    )
    mean_pool_gating_network_dropout: float = field(
        default=0,
        metadata={"help": "dropout for mean-pooling gating network"},
    )
    mean_pool_gating_network_encoder_dim: int = field(
        default=0,
        metadata={"help": "encoder output dim for mean-pooling gating network"},
    )
    gen_expert: int = field(
        default=0,
        metadata={"help": "which expert to use for generation"},
    )
    sentence_avg: bool = II("optimization.sentence_avg")


@register_task("translation_moe", dataclass=TranslationMoEConfig)
class TranslationMoETask(TranslationTask):
    """
    Translation task for Mixture of Experts (MoE) models.

    See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
    (Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_.

    Args:
        src_dict (~fairseq.data.Dictionary): dictionary for the source language
        tgt_dict (~fairseq.data.Dictionary): dictionary for the target language

    .. note::

        The translation task is compatible with :mod:`fairseq-train`,
        :mod:`fairseq-generate` and :mod:`fairseq-interactive`.

    The translation task provides the following additional command-line
    arguments:

    .. argparse::
        :ref: fairseq.tasks.translation_parser
        :prog:
    """

    cfg: TranslationMoEConfig

    def __init__(self, cfg: TranslationMoEConfig, src_dict, tgt_dict):
        if cfg.method == "sMoElp":
            # soft MoE with learned prior
            self.uniform_prior = False
            self.hard_selection = False
        elif cfg.method == "sMoEup":
            # soft MoE with uniform prior
            self.uniform_prior = True
            self.hard_selection = False
        elif cfg.method == "hMoElp":
            # hard MoE with learned prior
            self.uniform_prior = False
            self.hard_selection = True
        elif cfg.method == "hMoEup":
            # hard MoE with uniform prior
            self.uniform_prior = True
            self.hard_selection = True

        # add indicator tokens for each expert
        for i in range(cfg.num_experts):
            # add to both dictionaries in case we're sharing embeddings
            src_dict.add_symbol("<expert_{}>".format(i))
            tgt_dict.add_symbol("<expert_{}>".format(i))

        super().__init__(cfg, src_dict, tgt_dict)

    def build_model(self, cfg):
        from fairseq import models

        model = models.build_model(cfg, self)
        if not self.uniform_prior and not hasattr(model, "gating_network"):
            if self.cfg.mean_pool_gating_network:
                if self.cfg.mean_pool_gating_network_encoder_dim > 0:
                    encoder_dim = self.cfg.mean_pool_gating_network_encoder_dim
                elif getattr(cfg, "encoder_embed_dim", None):
                    # assume that encoder_embed_dim is the encoder's output dimension
                    encoder_dim = cfg.encoder_embed_dim
                else:
                    raise ValueError(
                        "Must specify --mean-pool-gating-network-encoder-dim"
                    )

                if self.cfg.mean_pool_gating_network_dropout > 0:
                    dropout = self.cfg.mean_pool_gating_network_dropout
                elif getattr(cfg, "dropout", None):
                    dropout = cfg.dropout
                else:
                    raise ValueError("Must specify task.mean_pool_gating_network_dropout")

                model.gating_network = MeanPoolGatingNetwork(
                    encoder_dim,
                    self.cfg.num_experts,
                    dropout,
                )
            else:
                raise ValueError(
                    "translation_moe task with learned prior requires the model to "
                    "have a gating network; try using --mean-pool-gating-network"
                )
        return model

    def expert_index(self, i):
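        # map expert id i to the dictionary index of its indicator token <expert_i>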
        return i + self.tgt_dict.index("<expert_0>")

    def _get_loss(self, sample, model, criterion):
        assert hasattr(
            criterion, "compute_loss"
        ), "translation_moe task requires the criterion to implement the compute_loss() method"

        k = self.cfg.num_experts
        bsz = sample["target"].size(0)

        def get_lprob_y(encoder_out, prev_output_tokens_k):
            net_output = model.decoder(
                prev_output_tokens=prev_output_tokens_k,
                encoder_out=encoder_out,
            )
            loss, _ = criterion.compute_loss(model, net_output, sample, reduce=False)
            loss = loss.view(bsz, -1)
            return -loss.sum(dim=1, keepdim=True)  # -> B x 1

        def get_lprob_yz(winners=None):
            encoder_out = model.encoder(
                src_tokens=sample["net_input"]["src_tokens"],
                src_lengths=sample["net_input"]["src_lengths"],
            )

            if winners is None:
                lprob_y = []
                for i in range(k):
                    prev_output_tokens_k = sample["net_input"][
                        "prev_output_tokens"
                    ].clone()
                    assert not prev_output_tokens_k.requires_grad
                    prev_output_tokens_k[:, 0] = self.expert_index(i)
                    lprob_y.append(get_lprob_y(encoder_out, prev_output_tokens_k))
                lprob_y = torch.cat(lprob_y, dim=1)  # -> B x K
            else:
                prev_output_tokens_k = sample["net_input"]["prev_output_tokens"].clone()
                prev_output_tokens_k[:, 0] = self.expert_index(winners)
                lprob_y = get_lprob_y(encoder_out, prev_output_tokens_k)  # -> B

            if self.uniform_prior:
                lprob_yz = lprob_y
            else:
                lprob_z = model.gating_network(encoder_out)  # B x K
                if winners is not None:
                    lprob_z = lprob_z.gather(dim=1, index=winners.unsqueeze(-1))
                lprob_yz = lprob_y + lprob_z.type_as(lprob_y)  # B x K

            return lprob_yz

        # compute responsibilities without dropout
        with utils.model_eval(model):  # disable dropout
            with torch.no_grad():  # disable autograd
                lprob_yz = get_lprob_yz()  # B x K
                prob_z_xy = torch.nn.functional.softmax(lprob_yz, dim=1)
        assert not prob_z_xy.requires_grad

        # compute loss with dropout
        if self.hard_selection:
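            # hard assignment: train only the expert with the highest
            # responsibility for each sentence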
            winners = prob_z_xy.max(dim=1)[1]
            loss = -get_lprob_yz(winners)
        else:
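            # soft assignment: marginalize the loss over experts; LogSumExpMoE
            # computes a standard logsumexp in the forward pass but uses the
            # fixed posterior prob_z_xy as the expert weights in the backward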
            lprob_yz = get_lprob_yz()  # B x K
            loss = -LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1)

        loss = loss.sum()
        sample_size = (
            sample["target"].size(0) if self.cfg.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": utils.item(loss.data),
            "ntokens": sample["ntokens"],
            "nsentences": bsz,
            "sample_size": sample_size,
            "posterior": prob_z_xy.float().sum(dim=0).cpu(),
        }
        return loss, sample_size, logging_output

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        model.train()
        loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
        if ignore_grad:
            loss *= 0
        optimizer.backward(loss)
        return loss, sample_size, logging_output

    def valid_step(self, sample, model, criterion):
        model.eval()
        with torch.no_grad():
            loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
        return loss, sample_size, logging_output

    def inference_step(
        self,
        generator,
        models,
        sample,
        prefix_tokens=None,
        expert=None,
        constraints=None,
    ):
        expert = expert or self.cfg.gen_expert
        with torch.no_grad():
            return generator.generate(
                models,
                sample,
                prefix_tokens=prefix_tokens,
                constraints=constraints,
                bos_token=self.expert_index(expert),
            )

    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)
        metrics.log_scalar(
            "posterior",
            sum(log["posterior"] for log in logging_outputs if "posterior" in log),
        )