File size: 9,293 Bytes
b100e1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Provides model subclasses with Mixture of Experts support."""

import dataclasses
from typing import Callable, Mapping, Optional, Sequence, Tuple, Union

import clu.metrics as clu_metrics
from flax import core as flax_core
from flax import linen as nn
from flax import optim
from flax import traverse_util
from flax.core import scope as flax_scope
import jax.numpy as jnp
import seqio
from t5x import decoding
from t5x import losses
from t5x import metrics as metrics_lib
from t5x import models

AveragePerStep = metrics_lib.AveragePerStep
DecodeFnCallable = models.DecodeFnCallable
FrozenVariableDict = flax_scope.FrozenVariableDict
MetricsMap = metrics_lib.MetricsMap
PyTreeDef = models.PyTreeDef
Sum = metrics_lib.Sum


@dataclasses.dataclass
class ExpertMetrics:
  """Per-layer record of mixture of experts losses and routing statistics.

  The first two fields are losses; the remaining three describe how tokens were
  routed to experts.

  Attributes:
    auxiliary_loss: Auxiliary load balancing loss.
    router_z_loss: Router z-loss, which encourages router logits to remain
      small in an effort to improve stability.
    fraction_tokens_left_behind: Fraction of tokens NOT processed by any expert.
    expert_usage: Fraction of the total capacity, across all experts, used to
      process tokens. Larger expert capacities or non-uniform token routing
      result in smaller values.
    router_confidence: How confident the router is about the tokens that it has
      routed.
  """
  # Losses.
  auxiliary_loss: float
  router_z_loss: float
  # Routing statistics.
  fraction_tokens_left_behind: float
  expert_usage: float
  router_confidence: float


class MoeEncoderDecoderModel(models.EncoderDecoderModel):
  """Encoder-decoder model whose loss includes the MoE auxiliary losses.

  `loss_fn` adds the scaled load balancing loss and router z-loss to the
  weighted cross-entropy loss and reports expert metrics alongside the metrics
  produced by `_compute_metrics`.
  """

  def __init__(
      self,
      module: nn.Module,
      input_vocabulary: seqio.Vocabulary,
      output_vocabulary: seqio.Vocabulary,
      optimizer_def: optim.OptimizerDef,
      decode_fn: DecodeFnCallable = decoding.beam_search,
      feature_converter_cls: Optional[Callable[...,
                                               seqio.FeatureConverter]] = None,
      label_smoothing: float = 0.0,
      z_loss: float = 0.0,
      loss_normalizing_factor: Optional[float] = None,
      aux_loss_factor: float = 0.,
      router_z_loss_factor: float = 0.):
    """Initializes the model.

    Every argument except the last two is forwarded unchanged to the
    `models.EncoderDecoderModel` constructor.

    Args:
      module: Forwarded to the parent constructor.
      input_vocabulary: Forwarded to the parent constructor.
      output_vocabulary: Forwarded to the parent constructor.
      optimizer_def: Forwarded to the parent constructor.
      decode_fn: Forwarded to the parent constructor.
      feature_converter_cls: Forwarded to the parent constructor.
      label_smoothing: Forwarded to the parent constructor.
      z_loss: Forwarded to the parent constructor.
      loss_normalizing_factor: Forwarded to the parent constructor.
      aux_loss_factor: Scale applied to the auxiliary load balancing loss in
        `loss_fn`.
      router_z_loss_factor: Scale applied to the router z-loss in `loss_fn`.
    """
    parent_kwargs = dict(
        module=module,
        input_vocabulary=input_vocabulary,
        output_vocabulary=output_vocabulary,
        optimizer_def=optimizer_def,
        decode_fn=decode_fn,
        feature_converter_cls=feature_converter_cls,
        label_smoothing=label_smoothing,
        z_loss=z_loss,
        loss_normalizing_factor=loss_normalizing_factor)
    super().__init__(**parent_kwargs)

    # MoE-specific loss scales; read back in loss_fn.
    self.aux_loss_factor = aux_loss_factor
    self.router_z_loss_factor = router_z_loss_factor

  def loss_fn(
      self, params: models.PyTreeDef, batch: Mapping[str, jnp.ndarray],
      dropout_rng: Optional[jnp.ndarray]) -> Tuple[jnp.ndarray, MetricsMap]:
    """Computes cross-entropy loss plus the scaled MoE auxiliary losses.

    Args:
      params: Model parameters.
      batch: Batch of training examples. Must contain the key
        'decoder_target_tokens'.
      dropout_rng: Random number generator key for dropout.

    Returns:
      - Total model loss (cross-entropy loss as returned by
        `losses.compute_weighted_cross_entropy`, plus the scaled load balancing
        loss and router z-loss).
      - Metrics map: the output of `_compute_metrics` updated with the expert
        metrics.
    """
    # Request the 'intermediates' collection as mutable; the returned state is
    # what the expert diversity metrics are extracted from below.
    logits, state = self._compute_logits(
        params, batch, dropout_rng, mutable=['intermediates'])

    loss_normalizing_factor: Optional[Union[
        float, int, str, losses.SpecialLossNormalizingFactor]]
    loss_normalizing_factor, weights = (
        losses.get_loss_normalizing_factor_and_weights(
            self._loss_normalizing_factor, batch))

    targets = batch['decoder_target_tokens']
    total_loss, z_loss, _ = losses.compute_weighted_cross_entropy(
        logits,
        targets=targets,
        weights=weights,
        label_smoothing=self._label_smoothing,
        z_loss=self._z_loss,
        loss_normalizing_factor=loss_normalizing_factor)

    # Fold the scaled MoE losses into the total loss.
    per_layer_metrics = _extract_diversity_metrics(state)
    aux_loss, router_z_loss = _expert_losses(
        per_layer_metrics, self.aux_loss_factor, self.router_z_loss_factor)
    total_loss += aux_loss + router_z_loss

    metrics = self._compute_metrics(
        logits=logits,
        targets=targets,
        mask=weights,
        loss=total_loss,
        z_loss=z_loss)
    # Expert metrics are merged last, so entries with matching keys replace
    # those produced by _compute_metrics.
    moe_metrics = _expert_metrics(
        per_layer_metrics,
        total_loss,
        z_loss,
        aux_loss,
        router_z_loss,
        num_tokens=targets.size)
    metrics.update(moe_metrics)

    return total_loss, metrics


def _extract_diversity_metrics(
    state: flax_scope.FrozenVariableDict) -> Sequence[ExpertMetrics]:
  """Collects expert diversity metrics from the model's intermediate state.

  The state is unfrozen and flattened; every entry whose final path component
  is 'diversity_metrics' is treated as the metrics of one MoE layer.

  Args:
    state: Model state holding sown intermediate metrics.

  Returns:
    One ExpertMetrics instance per matching entry found in the state.

  Raises:
    ValueError if no diversity metrics are found in the model state.
  """
  flat_state = traverse_util.flatten_dict(flax_core.unfreeze(state))

  sown_metrics = []
  for key_path, value in flat_state.items():
    if key_path[-1] == 'diversity_metrics':
      sown_metrics.append(value)

  if not sown_metrics:
    raise ValueError(
        'Unable to find any expert diversity metrics. Please check that MoE '
        'metrics and losses are correctly sown.')

  # Copy the fields into local ExpertMetrics objects so that this module does
  # not depend on the modeling library's DiversityMetrics type.
  converted = []
  for sown in sown_metrics:
    converted.append(
        ExpertMetrics(
            auxiliary_loss=sown.auxiliary_loss,
            router_z_loss=sown.router_z_loss,
            fraction_tokens_left_behind=sown.fraction_tokens_left_behind,
            expert_usage=sown.expert_usage,
            router_confidence=sown.router_confidence))
  return converted


def _expert_losses(diversity_metrics: Sequence[ExpertMetrics],
                   auxiliary_loss_factor: float,
                   router_z_loss_factor: float) -> Tuple[float, float]:
  """Reduces per-layer MoE losses to two scaled, model-level losses.

  Each loss is averaged (as float32) across MoE layers and then multiplied by
  its scaling factor.

  Args:
    diversity_metrics: Per-layer mixture of expert metrics.
    auxiliary_loss_factor: Factor by which to scale the auxiliary load
      balancing loss. The raw per-layer losses are averaged and the mean is
      scaled by this factor.
    router_z_loss_factor: Factor by which to scale the router z-loss. The raw
      per-layer losses are averaged and the mean is scaled by this factor.

  Returns:
    - Scaled load balancing loss.
    - Scaled router z-loss.
  """
  per_layer_aux = jnp.array(
      [layer.auxiliary_loss for layer in diversity_metrics],
      dtype=jnp.float32)
  scaled_aux_loss = auxiliary_loss_factor * per_layer_aux.mean()

  per_layer_router_z = jnp.array(
      [layer.router_z_loss for layer in diversity_metrics],
      dtype=jnp.float32)
  scaled_router_z_loss = router_z_loss_factor * per_layer_router_z.mean()

  return scaled_aux_loss, scaled_router_z_loss


def _expert_metrics(diversity_metrics: Sequence[ExpertMetrics],
                    total_loss: float, z_loss: float, auxiliary_loss: float,
                    router_z_loss: float, num_tokens: int) -> MetricsMap:
  """Builds the model-level expert metrics map from per-layer metrics.

  Routing statistics are averaged (as float32) across MoE layers. The returned
  map also carries 'cross_ent_loss' and 'cross_ent_loss_per_all_target_tokens'
  entries computed with the MoE losses subtracted from the total loss, intended
  to override the vanilla cross entropy loss metrics.

  Args:
    diversity_metrics: Per-layer mixture of expert metrics.
    total_loss: Total model loss.
    z_loss: Output logits z-loss (not MoE specific).
    auxiliary_loss: Auxiliary load balancing loss for MoE models.
    router_z_loss: Router z-loss for MoE models.
    num_tokens: Total number of target tokens.

  Returns:
    Expert diversity metrics plus the corrected cross entropy loss metrics.
  """

  def _layer_mean(per_layer_values):
    # Mean across MoE layers, computed in float32.
    return jnp.array(per_layer_values, dtype=jnp.float32).mean()

  # Strip every non-cross-entropy term from the total loss.
  cross_ent_loss = total_loss - z_loss - auxiliary_loss - router_z_loss

  summary = {}
  summary['experts/auxiliary_loss'] = AveragePerStep.from_model_output(
      auxiliary_loss)
  summary['experts/router_z_loss'] = AveragePerStep.from_model_output(
      router_z_loss)
  summary['experts/fraction_tokens_left_behind'] = (
      AveragePerStep.from_model_output(
          _layer_mean(
              [m.fraction_tokens_left_behind for m in diversity_metrics])))
  summary['experts/expert_usage'] = AveragePerStep.from_model_output(
      _layer_mean([m.expert_usage for m in diversity_metrics]))
  summary['experts/router_confidence'] = AveragePerStep.from_model_output(
      _layer_mean([m.router_confidence for m in diversity_metrics]))

  # Override vanilla T5 cross entropy loss metrics with corrected loss that
  # accounts for MoE losses.
  summary['cross_ent_loss'] = metrics_lib.AveragePerStep(total=cross_ent_loss)
  summary['cross_ent_loss_per_all_target_tokens'] = clu_metrics.Average(
      total=jnp.sum(cross_ent_loss), count=num_tokens)
  return summary