# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pjit partitioner with Mixture of Experts overrides."""

from typing import Any, Callable, Optional, Sequence, Union

from absl import logging
from flax import core as flax_core
import jax
import numpy as np

from t5x import adafactor
from t5x import partitioning as t5x_partitioning
from t5x import train_state as train_state_lib

from t5x.contrib.moe import training_utils

DataLayout = t5x_partitioning.DataLayout
FlaxOptimTrainState = train_state_lib.FlaxOptimTrainState
HardwareMesh = t5x_partitioning.HardwareMesh
InferenceState = train_state_lib.InferenceState
LogicalAxisRules = t5x_partitioning.LogicalAxisRules
PartitionSpec = t5x_partitioning.PartitionSpec
Pytree = Any
TrainState = train_state_lib.TrainState


class MoePjitPartitioner(t5x_partitioning.PjitPartitioner):
  """Pjit partitioner with overrides for Mixture of Experts support.

  This MoE partitioner has two overrides relative to the default partitioner:
  (1) It prepends an 'expert' axis to all MoE optimizer state terms, so that
      they are sharded along the 'expert' axis; see get_logical_axes().
  (2) In cases where model parallelism is not used and the number of experts is
      less than the number of devices, we treat the 'model' axis as a secondary
      data axis. This allows us to decouple expert parallelism ('data' mesh
      axis) from data parallelism ('data' and 'model' mesh axes).
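
  Example (illustrative; the constructor arguments are placeholder values):

    partitioner = MoePjitPartitioner(num_experts=8, num_partitions=1)
    data_layout = partitioner.get_data_layout(batch_size=32)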
  """

  def __init__(self,
               num_experts: int,
               num_partitions: Optional[int] = None,
               model_parallel_submesh: Optional[HardwareMesh] = None,
               params_on_devices: bool = True,
               logical_axis_rules: Optional[LogicalAxisRules] = None,
               state_filter_fn: Optional[Callable[[str], bool]] = None):
    """Configures the partitioner.

    Args:
      num_experts: Total number of experts across all devices.
      num_partitions: Specifies the size of the model parallel submesh to be
        automatically selected for the current topology. See
        `model_parallel_submesh` for details on how this submesh is used.
        Mutually exclusive with `model_parallel_submesh`.
      model_parallel_submesh: 4-tuple that specifies the `(x, y, z, c)` submesh
        model-parallel device tile -- an axis of accelerator parallelism
        orthogonal to data parallelism. See t5x/partitioning.py for details.
        This argument is mutually exclusive with `num_partitions`.
      params_on_devices: Whether to keep the params on devices. If False, params
        stay in the host memory.
      logical_axis_rules: A priority-ordered sequence of KV tuples that maps
        logical axis names to either `None` (not sharded), 'model' (to shard
        across the model-parallel submesh), or 'data' (to shard across the
        data-parallel submesh).
      state_filter_fn: Function to identify which optimizer state axis rules
        should be overridden to be sharded along the 'expert' axis. If None
        (default), Adafactor expert sharding overrides are used.
    """
    # If True, treat 'model' axis as secondary data axis.
    self.two_data_axes = _override_model_axis(num_experts, num_partitions,
                                              model_parallel_submesh)
    if self.two_data_axes:
      # Override num_partitions to repurpose the 'model' axis as a secondary
      # data axis, along which only the batch is sharded. Experts will be
      # replicated along this secondary data axis.
      num_partitions = jax.device_count() // num_experts

      # Override user specified model parallel submesh. Rely on T5X partitioning
      # to determine new submesh from updated `num_partitions`.
      logging.info(
          'Overriding user specified `model_parallel_submesh`=%s to support '
          'expert parallelism for updated `num_partitions`=%d',
          model_parallel_submesh, num_partitions)
      model_parallel_submesh = None

    super().__init__(
        num_partitions=num_partitions,
        model_parallel_submesh=model_parallel_submesh,
        params_on_devices=params_on_devices,
        logical_axis_rules=logical_axis_rules)

    self._state_filter_fn = state_filter_fn

  def get_data_layout(self,
                      batch_size: Optional[int] = None,
                      host_index: Optional[int] = None) -> DataLayout:
    """Returns filled `DataLayout` based on the partitioned model layout.

    Overrides the default data layout in cases where both mesh axes ('data' and
    'model') are treated as data axes.
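
    Example (illustrative; two-data-axes case with a global mesh of shape
    {'data': 4, 'model': 2}):

      layout = partitioner.get_data_layout()  # batch_size defaults to 4 * 2.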

    Args:
      batch_size: If set, indicates the requested batch size. If not set, the
        batch size is inferred from the layout.
      host_index: Indicates the host index to use for the calculations. If not
        set, the JAX-provided one is used. Should be in the [0, num_hosts)
        interval, and the order should match the order of the corresponding CPU
        devices in `jax.devices()`.

    Returns:
      Filled `DataLayout` structure.
    """
    if self.two_data_axes:
      if host_index is not None:
        raise NotImplementedError('Explicit host_index is not yet implemented.')
      mesh_size = self._local_chunker.global_mesh.shape[
          'data'] * self._local_chunker.global_mesh.shape['model']
      batch_size = batch_size or mesh_size
      if batch_size % mesh_size:
        raise ValueError(
            f'Batch size ({batch_size}) must be divisible by corresponding '
            f'mesh size ({mesh_size}).')
      num_shards = self._local_chunker.num_chunks['data']
      if batch_size % num_shards:
        raise ValueError(
            f'Batch size ({batch_size}) must be divisible by number of '
            f'replicas ({num_shards}).')
      replica_id = self._local_chunker.get_local_chunk_info(
          (batch_size,), ('data', 'model')).replica_id
      return DataLayout(
          batch_size=batch_size,
          shard_id=self._local_chunker.chunk_ids['data'],
          num_shards=num_shards,
          is_first_host_in_replica_set=(replica_id == 0))
    else:
      return super().get_data_layout(batch_size, host_index)

  def get_logical_axes(
      self, train_state: Union[FlaxOptimTrainState, InferenceState]
  ) -> Union[FlaxOptimTrainState, InferenceState]:
    """Returns a copy of TrainState with Optional[AxisNames] as leaves.

    Overrides the default logical axes by prepending the 'expert' axis to any
    MoE optimizer state terms (identified by self._state_filter_fn) so they are
    correctly sharded along the 'expert' axis.

    Args:
      train_state: Object holding all relevant training or inference state.

    Returns:
      State object matching structure of input train_state but with axis names
      as leaves.
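
    For example (illustrative), a factored Adafactor `v_row` term for an expert
    kernel whose logical axes would otherwise be PartitionSpec('mlp',) is
    returned with the 'expert' axis prepended, i.e. ('expert', 'mlp').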
    """
    logical_axes = train_state.as_logical_axes()

    if isinstance(logical_axes, InferenceState):
      # InferenceState does not contain any optimizer state, so we skip all
      # expert partitioning overrides.
      return logical_axes
    else:
      train_state: FlaxOptimTrainState

    state_filter_fn = (
        self._state_filter_fn or _infer_state_filter_fn(train_state))
    if state_filter_fn is None:
      # No state updates required.
      return logical_axes

    prepend_expert = lambda x: PartitionSpec(  # pylint: disable=g-long-lambda
        'expert',) + x if x else PartitionSpec('expert',)
    optimizer_axes = logical_axes._optimizer  # pylint: disable=protected-access
    state_dict = flax_core.unfreeze(optimizer_axes.state_dict())
    state_dict['state']['param_states'] = training_utils.tree_map_with_names(
        prepend_expert, state_dict['state']['param_states'], state_filter_fn)

    return train_state.restore_state(state_dict)

  def partition(
      self,
      fn: Callable,  # pylint: disable=g-bare-generic
      in_axis_resources: Pytree,
      out_axis_resources: Pytree,
      static_argnums: Union[int, Sequence[int]] = (),
      donate_argnums: Union[int, Sequence[int]] = ()
  ) -> t5x_partitioning.PjittedFnWithContext:
    """Partitions the computation using pjit.

    Overrides the default pjit partitioning in cases where expert and data axes
    are decoupled -- wherein we treat the 'model' axis as a secondary data axis.

    Args:
      fn: Function to partition.
      in_axis_resources: Pytree of structure matching that of arguments to `fn`,
        with all actual arguments replaced by resource assignment
        specifications.
      out_axis_resources: Like `in_axis_resources`, but specifies resource
        assignment for function outputs.
      static_argnums: Specifies which positional arguments to treat as static
        (compile-time constant) in the partitioned function.
      donate_argnums: Specifies which argument buffers are "donated" to the
        computation.

    Returns:
      A partitioned version of the input function.
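
    Example (illustrative; `train_step` and `train_state_axes` are
    placeholders):

      partitioned_train_step = partitioner.partition(
          train_step,
          in_axis_resources=(train_state_axes,
                             data_partition_spec(partitioner.two_data_axes)),
          out_axis_resources=train_state_axes)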
    """
    if self.two_data_axes:
      # Both axes are used for data parallelism in this case, so we override the
      # partition specs.
      in_axis_resources = _override_partition_specs(in_axis_resources)
      out_axis_resources = _override_partition_specs(out_axis_resources)

    pjitted = t5x_partitioning.pjit(
        fn,
        in_axis_resources=in_axis_resources,
        out_axis_resources=out_axis_resources,
        static_argnums=static_argnums,
        donate_argnums=donate_argnums,
        backend=self._backend)

    return t5x_partitioning.PjittedFnWithContext(pjitted, self.mesh,
                                                 self._logical_axis_rules)


def standard_logical_axis_rules(
    num_experts: int,
    num_partitions: Optional[int] = None,
    model_parallel_submesh: Optional[HardwareMesh] = None,
    activation_partitioning_dims: int = 1,
    parameter_partitioning_dims: int = 1,
    additional_rules: Optional[LogicalAxisRules] = None):
  """Returns partitioning rules for MoE models.

  The partitioning rules vary based on whether the expert and data axes need to
  be decoupled; see also MoePjitPartitioner for details of when expert and data
  axes need to be decoupled.

  2D parameter sharding (`parameter_partitioning_dims=2`) is not supported.
  Sharding parameters along the 'data' axis will interfere with expert
  parallelism, because experts are also partitioned along the 'data' axis.
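
  Example (illustrative; assumes no model parallelism and fewer experts than
  available devices, so the 'model' mesh axis is repurposed as a secondary data
  axis):

    rules = standard_logical_axis_rules(num_experts=8, num_partitions=1)
    # Under these assumptions, the returned rules include
    # ('batch', ('data', 'model')) and ('expert', 'data').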

  Args:
    num_experts: Total number of experts across all devices.
    num_partitions: Size of the model parallel submesh. Model parallelism is
      only used if num_partitions > 1. Mutually exclusive with
      `model_parallel_submesh`.
    model_parallel_submesh: 4-tuple that specifies the `(x, y, z, c)` submesh
      model-parallel device tile -- an axis of accelerator parallelism
      orthogonal to data parallelism. Model parallelism is only used if
      np.prod(model_parallel_submesh) > 1. Mutually exclusive with
      `num_partitions`.
    activation_partitioning_dims: Enables 2-D activation sharding when set to 2.
    parameter_partitioning_dims: Enables 2-D parameter sharding when set to 2.
    additional_rules: Additional rules (a sequence of tuples) that will be
      appended to the standard rules.

  Returns:
    Sequence of logical axis rules.

  Raises:
    ValueError if parameter_partitioning_dims=2.
  """
  if parameter_partitioning_dims == 2:
    raise ValueError('2D parameter sharding (`parameter_partitioning_dims=2`) '
                     'is not supported for MoE.')

  default_rules = t5x_partitioning.standard_logical_axis_rules(
      activation_partitioning_dims, parameter_partitioning_dims)
  moe_rules = [
      ('expert', 'data'),  # Shard experts along the data axis
      ('expert_mlp', 'model'),  # Expert MLPs partitioned along model axis
      ('expert_group', None),  # Replicated axis for all-to-all constraints
      ('expert_replicas', None),  # Experts replicated along this axis
      ('unmodeled', None),  # Replicated weights
  ]
  standard_rules = list(default_rules) + moe_rules
  if additional_rules:
    standard_rules.extend(additional_rules)

  if _override_model_axis(num_experts, num_partitions, model_parallel_submesh):
    overridden_rules = []
    for logical_axis, mesh_axis in standard_rules:
      if logical_axis == 'batch':
        # Because we now treat the 'model' axis as a second data axis, we want
        # to shard batches across both axes.
        overridden_mesh_axis = ('data', 'model')
      elif logical_axis == 'expert_replicas':
        # "model" axis is repurposed as a second data axis, along which experts
        # are replicated.
        overridden_mesh_axis = 'model'
      elif mesh_axis == 'model':
        # Any weights ordinarily partitioned along the model axis, should be
        # explicitly replicated.
        overridden_mesh_axis = None
      else:
        overridden_mesh_axis = mesh_axis
      overridden_rules.append((logical_axis, overridden_mesh_axis))

    return overridden_rules

  else:
    return standard_rules


def data_partition_spec(two_data_axes: bool) -> PartitionSpec:
  """Returns data partitioning spec.

  Args:
    two_data_axes: If True, use 'model' axis as secondary data axis. Otherwise,
      only use 'data' axis for data sharding.

  Returns:
    Mesh dependent partition spec.
  """
  if two_data_axes:
    # Use 'model' axis as secondary data axis. Shard batches across both axes.
    return PartitionSpec(('data', 'model'),)
  else:
    return PartitionSpec('data',)


def _override_model_axis(
    num_experts: int, num_partitions: Optional[int],
    model_parallel_submesh: Optional[HardwareMesh]) -> bool:
  """Returns true iff there is no model parallelism & num experts < num devices.

  Args:
    num_experts: Total number of experts across all devices.
    num_partitions: Size of the model parallel submesh. Model parallelism is
      only used if num_partitions > 1. Mutually exclusive with
      `model_parallel_submesh`.
    model_parallel_submesh: 4-tuple that specifies the `(x, y, z, c)` submesh
      model-parallel device tile -- an axis of accelerator parallelism
      orthogonal to data parallelism. Model parallelism is only used if
      np.prod(model_parallel_submesh) > 1. Mutually exclusive with
      `num_partitions`.

  Returns:
    True if there is no model parallelism & num experts < num devices; False
    otherwise.
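
  Example (illustrative; assumes 32 available devices):

    _override_model_axis(8, num_partitions=1, model_parallel_submesh=None)
    # -> True: no model parallelism and 8 experts < 32 devices.
    _override_model_axis(64, num_partitions=1, model_parallel_submesh=None)
    # -> False: 64 experts >= 32 devices, so no expert replication is needed.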
  """
  if (num_partitions is None) == (model_parallel_submesh is None):
    raise ValueError(
        'One, and only one, of {num_partitions, model_parallel_submesh} must '
        'be specified. Received: %s and %s' %
        (num_partitions, model_parallel_submesh))

  if num_experts == 0 or jax.device_count() <= num_experts:
    # No expert replication required. No need to override model mesh axis.
    return False

  return ((num_partitions is not None and num_partitions <= 1) or
          (model_parallel_submesh is not None and
           np.prod(model_parallel_submesh) <= 1))


def _override_partition_specs(resources: Pytree):
  """Override axis resources for two data axes setup.

  In the two data axes setup, we treat the 'model' axis as a secondary data
  axis. To this end, we override any hardcoded, raw partition specs:
  - PartitionSpec('data',) -> PartitionSpec(('data', 'model'),)
  - PartitionSpec('model',) -> None
  There is no need to override any params or optimizer state as these will
  inherit the correct specs from the logical axis rules; see
  standard_logical_axis_rules().

  Args:
    resources: Axis resource assignment specifications.

  Returns:
    Axis resources with partition specs overridden to use 'model' as secondary
    'data' axis.
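
  Example (illustrative):

    _override_partition_specs(
        (PartitionSpec('data',), PartitionSpec('model',)))
    # -> (PartitionSpec(('data', 'model'),), None)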
  """

  def _maybe_override_spec(axis_resource: Pytree):
    """Overrides "data" and "model" partition specs; leaves others unchanged."""
    if axis_resource == PartitionSpec('data',):
      # Shard all batches across both axes.
      return PartitionSpec(('data', 'model'),)
    elif axis_resource == PartitionSpec('model',):
      # No model parallelism.
      return None
    else:
      return axis_resource

  if resources is None:
    return resources
  elif not isinstance(resources, Sequence):
    return _maybe_override_spec(resources)
  else:
    overridden_resources = []
    for resource in resources:
      overridden_resources.append(_maybe_override_spec(resource))
  return tuple(overridden_resources)


def _infer_state_filter_fn(
    train_state: FlaxOptimTrainState) -> Optional[Callable[[str], bool]]:
  """Infers relevant regex matching sharded expert model state for optimizer.

  Only the Adafactor optimizer is currently supported.

  The model state generally inherits the correct partitioning specs from the
  model parameters, except in cases where the kernel is factored (`v_col` and
  `v_row` terms); see derive_logical_axes():
  https://github.com/google-research/t5x/blob/main/t5x/adafactor.py#L591. For
  those cases, we use the state_filter_fn to identify the factored kernel terms
  that need to be partitioned along the expert axis.
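
  For example (illustrative state names), with a factored Adafactor optimizer,
  the returned function matches entries such as 'expert/wi/kernel/v_row' and
  'expert/wo/kernel/v_col', but not the parameter entry 'expert/wi/kernel'
  itself.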

  Args:
    train_state: Object holding optimizer and optimizer state (parameters).

  Returns:
    Function to identify which model state is sharded along 'expert' axis.

  Raises:
    ValueError if optimizer (on train state) is not an Adafactor optimizer.
  """
  optimizer = train_state._optimizer  # pylint: disable=protected-access
  optimizer_def = optimizer.optimizer_def

  # TODO(jamesleethorp): Revisit once other T5X optimizers are available.
  if not isinstance(optimizer_def, adafactor.Adafactor):
    raise ValueError('Inferred MoE overrides are currently only available for '
                     f'the Adafactor optimizer. Received: {optimizer_def}')

  if optimizer_def.hyper_params.factored:
    # Factored kernel terms (`v_col` and `v_row`) need to be identified for
    # expert sharding.
    return training_utils.match_fn(r'.*expert.*/kernel/v_.*')
  else:
    # Non-factored kernel terms (`v`) inherit the correct specs, so no state
    # updates will be required.
    return None