# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Trainer with Mixture of Experts support."""

from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING

import cached_property
from t5x import models
from t5x import train_state as train_state_lib
from t5x import trainer
from t5x.contrib.moe import partitioning
from t5x.contrib.moe import training_utils

BatchType = trainer.BatchType
LearningRateCallable = trainer.LearningRateCallable
MetricMapType = trainer.MetricMapType
PartitionSpec = partitioning.PartitionSpec
PartitionedTrainCallable = trainer.PartitionedTrainCallable
Rng = trainer.Rng

if TYPE_CHECKING:  # See b/163639353
  cached_property = property  # pylint: disable=invalid-name
else:
  cached_property = cached_property.cached_property


class MoeTrainer(trainer.Trainer):
  """T5X trainer with overrides for Mixture of Experts support."""

  def __init__(
      self,
      model: models.BaseModel,
      train_state: train_state_lib.TrainState,
      partitioner: partitioning.MoePjitPartitioner,
      eval_names: Sequence[str],
      summary_dir: Optional[str],
      train_state_axes: Any,
      rng: Rng,
      learning_rate_fn: LearningRateCallable,
      num_microbatches: Optional[int],
      num_experts: int,
      sharded_match_fn: Optional[Callable[
          [str], bool]] = training_utils.match_fn(r'.*expert.*'),
      weight_metrics_computer: Optional[trainer.WeightMetricsComputer] = None):
    """Trainer constructor.

    Args:
      model: the instantiation of `BaseModel` to train.
      train_state: a train state with parameters and optimizer state.
      partitioner: the partitioner to use.
      eval_names: names of evaluation datasets, which must match the keys of
        the mapping passed to `eval`.
      summary_dir: optional directory to write TensorBoard metrics to.
      train_state_axes: partitioning info for the optimizer to be used.
      rng: jax PRNGKey seed for random operations, to be combined with step
        number for a deterministic RNG.
      learning_rate_fn: returns the learning rate given the current step.
      num_microbatches: the number of microbatches to use, or None for direct
        training.
      num_experts: Global number of experts. Used to scale sharded parameter
        gradients.
      sharded_match_fn: Filter function for distinguishing sharded (MoE)
        parameters from replicated parameters. Used to identify the sharded
        parameter gradients that need to be rescaled under pjit training.
      weight_metrics_computer: A WeightMetricsComputer instance, or None, to
        decide what metrics, if any, to log about weights and weight updates
        during training.
    """
    super().__init__(
        model=model,
        train_state=train_state,
        partitioner=partitioner,
        eval_names=eval_names,
        summary_dir=summary_dir,
        train_state_axes=train_state_axes,
        rng=rng,
        learning_rate_fn=learning_rate_fn,
        num_microbatches=num_microbatches,
        weight_metrics_computer=weight_metrics_computer)

    self._num_experts = num_experts
    self._sharded_match_fn = sharded_match_fn

    self.data_partition_spec = partitioning.data_partition_spec(
        partitioner.two_data_axes)

  @cached_property
  def _partitioned_train_step(self) -> PartitionedTrainCallable:
    """Same as a regular T5X train step, but scales expert parameter gradients.

    We must scale expert parameter gradients by the number of experts to
    account for pjit's implicit averaging over partitioned parameter gradients.

    Returns:
      Partitioned train step function.
    """

    def train_with_lr(train_state: train_state_lib.TrainState,
                      batch: BatchType):
      grad_accum, metrics, flax_mutables = (
          trainer.accumulate_grads_microbatched(
              self._model,
              train_state,
              batch,
              self._get_step_rng(train_state.step),
              self._num_microbatches,
              data_partition_spec=self.data_partition_spec))

      # Only difference between this train step and the regular T5X train
      # step: rescale sharded (expert) parameter gradients.
      scaled_grads = training_utils.scale_sharded_grads(
          grad_accum, self._sharded_match_fn, scale_factor=self._num_experts)

      new_train_state, metrics = trainer.apply_grads(
          train_state,
          scaled_grads,
          metrics,
          self._learning_rate_fn(train_state.step),
          self._weight_metrics_computer,
          other_state_variables={'flax_mutables': flax_mutables}
          if flax_mutables else None)
      return new_train_state, metrics

    return self._partitioner.partition(
        train_with_lr,
        in_axis_resources=(self._train_state_axes, self.data_partition_spec),
        out_axis_resources=(self._train_state_axes, None),
        donate_argnums=(0,))