# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Trainer with Mixture of Experts support."""

from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING

import cached_property
from t5x import models
from t5x import train_state as train_state_lib
from t5x import trainer
from t5x.contrib.moe import partitioning
from t5x.contrib.moe import training_utils

BatchType = trainer.BatchType
LearningRateCallable = trainer.LearningRateCallable
MetricMapType = trainer.MetricMapType
PartitionSpec = partitioning.PartitionSpec
PartitionedTrainCallable = trainer.PartitionedTrainCallable
Rng = trainer.Rng

if TYPE_CHECKING:  # See b/163639353
  cached_property = property  # pylint: disable=invalid-name
else:
  cached_property = cached_property.cached_property


class MoeTrainer(trainer.Trainer):
  """T5X trainer with overrides for Mixture of Experts support."""

  def __init__(
      self,
      model: models.BaseModel,
      train_state: train_state_lib.TrainState,
      partitioner: partitioning.MoePjitPartitioner,
      eval_names: Sequence[str],
      summary_dir: Optional[str],
      train_state_axes: Any,
      rng: Rng,
      learning_rate_fn: LearningRateCallable,
      num_microbatches: Optional[int],
      num_experts: int,
      sharded_match_fn: Optional[Callable[
          [str], bool]] = training_utils.match_fn(r'.*expert.*'),
      weight_metrics_computer: Optional[trainer.WeightMetricsComputer] = None):
    """Trainer constructor.

    Args:
      model: the instantiation of `BaseModel` to train.
      train_state: a train state with parameters and optimizer state.
      partitioner: the partitioner to use.
      eval_names: names of evaluation datasets, which must match the keys of
        the mapping passed to `eval`.
      summary_dir: optional directory to write TensorBoard metrics to.
      train_state_axes: partitioning info for the optimizer to be used.
      rng: jax PRNGKey seed for random operations, to be combined with step
        number for a deterministic RNG.
      learning_rate_fn: returns the learning rate given the current step.
      num_microbatches: the number of microbatches to use, or None for direct
        training.
      num_experts: Global number of experts. Used to scale sharded parameter
        gradients.
      sharded_match_fn: Filter function for distinguishing sharded (MoE)
        parameters from replicated parameters. Used to identify the sharded
        parameter gradients that need to be rescaled under pjit training.
      weight_metrics_computer: A WeightMetricsComputer instance, or None, to
        decide what metrics, if any, to log about weights and weight updates
        during training.
    """
    super().__init__(
        model=model,
        train_state=train_state,
        partitioner=partitioner,
        eval_names=eval_names,
        summary_dir=summary_dir,
        train_state_axes=train_state_axes,
        rng=rng,
        learning_rate_fn=learning_rate_fn,
        num_microbatches=num_microbatches,
        weight_metrics_computer=weight_metrics_computer)

    self._num_experts = num_experts
    self._sharded_match_fn = sharded_match_fn
    self.data_partition_spec = partitioning.data_partition_spec(
        partitioner.two_data_axes)
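
  # Illustrative note (added; not in the original module): `sharded_match_fn`
  # is a predicate over flattened parameter names, per the constructor's
  # `Callable[[str], bool]` annotation. The default,
  # `training_utils.match_fn(r'.*expert.*')`, should behave roughly like the
  # following, where the parameter names are illustrative:
  #
  #   is_expert = training_utils.match_fn(r'.*expert.*')
  #   is_expert('encoder/layers_1/mlp/expert/wi/kernel')      # -> True
  #   is_expert('encoder/layers_1/attention/query/kernel')    # -> False
  #
  # Only gradients of parameters matching this predicate are rescaled in
  # `_partitioned_train_step` below.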
  def _partitioned_train_step(self) -> PartitionedTrainCallable:
    """Same as a regular T5X train step, but scales expert parameter gradients.

    We must scale expert parameter gradients by the number of experts to
    account for pjit's implicit averaging over partitioned parameter
    gradients.

    Returns:
      Partitioned train step function.
    """

    def train_with_lr(train_state: train_state_lib.TrainState,
                      batch: BatchType):
      grad_accum, metrics, flax_mutables = (
          trainer.accumulate_grads_microbatched(
              self._model,
              train_state,
              batch,
              self._get_step_rng(train_state.step),
              self._num_microbatches,
              data_partition_spec=self.data_partition_spec))

      # Only difference from the regular T5X train step: rescale the sharded
      # (expert) parameter gradients.
      scaled_grads = training_utils.scale_sharded_grads(
          grad_accum, self._sharded_match_fn, scale_factor=self._num_experts)

      new_train_state, metrics = trainer.apply_grads(
          train_state,
          scaled_grads,
          metrics,
          self._learning_rate_fn(train_state.step),
          self._weight_metrics_computer,
          other_state_variables={'flax_mutables': flax_mutables}
          if flax_mutables else None)
      return new_train_state, metrics

    return self._partitioner.partition(
        train_with_lr,
        in_axis_resources=(self._train_state_axes, self.data_partition_spec),
        out_axis_resources=(self._train_state_axes, None),
        donate_argnums=(0,))
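

# Illustrative usage sketch (added; not part of the original module). The
# objects `my_model`, `my_train_state`, `my_partitioner`, `my_train_state_axes`,
# `my_rng`, and `my_learning_rate_fn` are hypothetical placeholders that would
# be built with the usual T5X setup machinery; only the `MoeTrainer` arguments
# themselves are taken from the constructor defined in this file.
#
#   moe_trainer = MoeTrainer(
#       model=my_model,
#       train_state=my_train_state,
#       partitioner=my_partitioner,  # a partitioning.MoePjitPartitioner
#       eval_names=['validation'],
#       summary_dir='/tmp/t5x_moe_summaries',
#       train_state_axes=my_train_state_axes,
#       rng=my_rng,
#       learning_rate_fn=my_learning_rate_fn,
#       num_microbatches=None,
#       num_experts=8,
#       # Default shown explicitly: rescale gradients of parameters whose
#       # names contain "expert".
#       sharded_match_fn=training_utils.match_fn(r'.*expert.*'),
#       weight_metrics_computer=None)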