# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Trainer and MetricsManager classes for use in train loop.

To create a custom trainer, subclass `BaseTrainer` and implement
`_partitioned_train_step` and `_partitioned_eval_step` methods,
possibly by re-using the utility functions provided in this module.
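
A minimal sketch of such a subclass is shown below. It only illustrates the
shape of the two overrides; the step-function body is elided, `MyTrainer` is a
hypothetical name, and the `partition` call simply mirrors the one used by
`Trainer` later in this module:

  class MyTrainer(BaseTrainer):

    @cached_property
    def _partitioned_train_step(self) -> PartitionedTrainCallable:

      def train_step(train_state, batch):
        ...  # compute and apply gradients, return (new_train_state, metrics)

      return self._partitioner.partition(
          train_step,
          in_axis_resources=(self._train_state_axes,
                             self._partitioner.data_partition_spec),
          out_axis_resources=(self._train_state_axes, None),
          donate_argnums=(0,))

    @cached_property
    def _partitioned_eval_step(self) -> PartitionedEvalCallable:
      ...  # analogous, returning a PartitionedEvalCallable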
""" | |
import abc | |
import enum | |
import os | |
import threading | |
import time | |
from typing import Any, Dict, Iterator, Mapping, MutableMapping, Optional, Sequence, TYPE_CHECKING, Tuple, Union | |
from absl import logging | |
import cached_property | |
from clu import asynclib | |
from clu import metric_writers | |
import clu.data | |
import clu.metrics | |
import clu.values | |
from flax.core import FrozenDict | |
from jax.experimental import multihost_utils | |
import jax.lax | |
import jax.numpy as jnp | |
import jax.random | |
import numpy as np | |
from t5x import metrics as metrics_lib | |
from t5x import models | |
from t5x import partitioning | |
from t5x import train_state as train_state_lib | |
from t5x import utils | |
import typing_extensions | |
Array = Union[np.ndarray, jnp.ndarray] | |
BatchSpec = Mapping[str, jax.ShapeDtypeStruct] | |
BatchType = Mapping[str, np.ndarray] | |
FlaxMutables = FrozenDict | |
Rng = jnp.ndarray | |
MetricMapType = MutableMapping[str, clu.metrics.Metric] | |
MetricMapSpec = Mapping[str, jax.ShapeDtypeStruct] | |
MetricValueMapType = Mapping[str, clu.values.Value] | |
ModelWeights = Any | |
MutableMetricMapType = Dict[str, clu.metrics.Metric] | |
PyTreeDef = type(jax.tree_structure(None)) | |
PartitionSpec = partitioning.PartitionSpec | |
if TYPE_CHECKING: # See b/163639353 | |
cached_property = property # pylint: disable=invalid-name | |
else: | |
cached_property = cached_property.cached_property | |


def _merge_metrics(a, b):
  return jax.tree_multimap(
      lambda a, b: a.merge(b), a, b, is_leaf=metrics_lib.is_metric_obj)


# Merges two metrics pytrees (mapping of metric_name (str) to clu.Metric object)
def merge_metrics(a, b):
  a, b = jax.tree_map(utils.get_local_data, (a, b))
  return _merge_metrics(a, b)


class ArrayMapFuture(typing_extensions.Protocol):

  def result(self) -> Mapping[str, Array]:
    ...


class MetricValueMapFuture(typing_extensions.Protocol):

  def result(self) -> Mapping[str, clu.values.Value]:
    ...


class TimeFuture(typing_extensions.Protocol):

  def result(self) -> float:
    ...


class LearningRateCallable(typing_extensions.Protocol):

  def __call__(
      self,
      step: jnp.ndarray,
  ) -> jnp.ndarray:
    ...


class SummarizeMetricsCallable(typing_extensions.Protocol):
  """PyType template for a metrics summary function."""

  def __call__(self, metrics: MetricMapType, duration: float,
               num_steps: int) -> Mapping[str, jnp.ndarray]:
    """Summarizes metrics accumulated across multiple steps.

    Args:
      metrics: Metrics accumulated across multiple steps.
      duration: The duration of the run being summarized.
      num_steps: The number of steps the metrics are accumulated across.

    Returns:
      Summarized metrics.
    """
    ...


class PartitionedTrainCallable(typing_extensions.Protocol):
  """Protocol for a partitioned train step."""

  def __call__(
      self, train_state: train_state_lib.TrainState,
      batch: BatchType) -> Tuple[train_state_lib.TrainState, MetricMapType]:
    ...


class PartitionedEvalCallable(typing_extensions.Protocol):
  """Protocol for a partitioned eval step."""

  def __call__(self, train_state: train_state_lib.TrainState,
               batch: jnp.ndarray) -> MetricMapType:
    ...


class WeightMetricsComputer(object):
  """Decides which weight metrics to compute during training."""

  _WEIGHT_METRICS = [
      "weight_rms", "weight_gradient_rms", "weight_update_rms", "weight_max"
  ]

  @staticmethod
  def _make_rms_metrics(name, tree):
    """Calculates the root-mean-square metric for a pytree."""
    return {
        f"{name}/{k}": metrics_lib.AveragePerStep.from_model_output(
            jnp.sqrt(jnp.mean(jnp.square(v))))
        for k, v in utils.flatten_dict_string_keys(tree).items()
    }

  @staticmethod
  def _make_max_metrics(name, tree):
    """Calculates the L-inf norm for a pytree."""
    return {
        f"{name}/{k}":
            metrics_lib.AveragePerStep.from_model_output(jnp.max(jnp.abs(v)))
        for k, v in utils.flatten_dict_string_keys(tree).items()
    }

  def compute_metrics(
      self, gradients: ModelWeights,
      old_train_state: train_state_lib.TrainState,
      new_train_state: train_state_lib.TrainState) -> MutableMetricMapType:
    """Computes metrics about the weights after they have been updated.

    Args:
      gradients: The gradients of all weights.
      old_train_state: The training state before applying the gradients.
      new_train_state: The training state after applying the gradients.

    Returns:
      A dictionary of Metrics, where the keys are either metric names or of the
      form metric_name/parameter_name, depending on whether they are global to
      the model or specific to each model parameter.
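
    For example (an illustrative sketch, not actual output): for a parameter
    pytree entry flattened to `encoder/layer_0/kernel`, the result would
    contain per-parameter keys such as `weight_rms/encoder/layer_0/kernel`
    alongside global keys such as `weight_gradient_norm`.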
""" | |
# TODO(reinerp): Extend weight stats logging with support for non-reduced | |
# axes of tensors. For example, for stacked layers (QKV stacking or layer | |
# stacking), we might not want to reduce over the stacking dimension, in | |
# order to provide more localization in the logged stats. | |
metrics = {} | |
metrics.update(self._make_rms_metrics("weight_rms", new_train_state.params)) | |
metrics.update(self._make_rms_metrics("weight_gradient_rms", gradients)) | |
grad_norm = jnp.sqrt( | |
jnp.sum( | |
jnp.array([jnp.vdot(x, x) for x in jax.tree_leaves(gradients)]))) | |
metrics.update({ | |
"weight_gradient_norm": | |
metrics_lib.AveragePerStep.from_model_output(grad_norm) | |
}) | |
metrics.update( | |
self._make_rms_metrics( | |
"weight_update_rms", | |
jax.tree_multimap(jnp.subtract, new_train_state.params, | |
old_train_state.params))) | |
metrics.update(self._make_max_metrics("weight_max", new_train_state.params)) | |
return metrics | |


class _AsyncTimer(object):
  """A timer that computes durations between async jax operations.

  You should call close() to wait for threads started by this class to finish.
  """

  def __init__(self):
    # We use a thread pool with a single worker to ensure that calls to the
    # function are run in order (but in a background thread).
    self._pool = asynclib.Pool(thread_name_prefix="AsyncTimer", max_workers=1)
    self._start_future = None

  def close(self):
    self._pool.close()

  def __del__(self):
    self.close()

  def _get_completion_future(self, block_on: PyTreeDef = ()) -> TimeFuture:
    """Returns Future containing time when `block_on` is ready."""

    def _get_completion_time():
      try:
        jax.block_until_ready(block_on)
      except RuntimeError as e:
        # If the buffer no longer exists, we assume it was completed.
        if (str(e) !=
            "INVALID_ARGUMENT: BlockHostUntilReady() called on deleted or "
            "donated buffer"):
          raise
      return time.time()

    return self._pool(_get_completion_time)()

  def start(self, block_on: PyTreeDef = ()):
    """Starts timer after `block_on` is ready."""
    self._start_future = self._get_completion_future(block_on)

  def stop(self, block_on: PyTreeDef = ()) -> TimeFuture:
    """Stops timer after `block_on` is ready, returning the duration."""
    if not self._start_future:
      raise ValueError("The timer hasn't been started.")
    start_future = self._start_future
    self._start_future = None
    stop_future = self._get_completion_future(block_on)
    return self._pool(lambda: stop_future.result() - start_future.result())()


class MetricsManager(object):
  """Manages a set of distributed metrics and their logging.

  Logging is disabled on all but host 0.

  Logs to:
    * TensorBoard
    * ABSL

  You should call close() to wait for threads started by this class to finish.
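
  A rough usage sketch (the directory, step number, and metric values below are
  made up for illustration):

    mm = MetricsManager("train", summary_dir="/tmp/summaries")
    mm.start_duration_timer()
    # ... run some steps, accumulating a `metrics` pytree of clu.Metric ...
    summary_future = mm.write_metrics_summary(metrics, step=100, num_steps=10)
    summary = summary_future.result()  # mapping of name -> clu.values.Value
    mm.flush()
    mm.close()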
""" | |
def __init__(self, name: str, summary_dir: Optional[str] = None): | |
"""MetricsManager constructor. | |
Constructs an empty MetricWriter on all but host 0. | |
Args: | |
name: an identifier of the metrics to use when logging (e.g., 'train'). | |
summary_dir: the summary directory. If provided, TensorBoard summaries | |
will be written to a `name` subdirectory. | |
""" | |
self._name = name | |
if jax.process_index() == 0: | |
self._writer = metric_writers.create_default_writer( | |
summary_dir, | |
collection=name, | |
asynchronous=True) | |
else: | |
self._writer = metric_writers.MultiWriter([]) | |
self.summary_dir = os.path.join(summary_dir, name) if summary_dir else None | |
self._writer_lock = threading.Lock() | |
# We use a thread pool with a single worker to ensure that calls to the | |
# function are run in order (but in a background thread). | |
self._summary_pool = asynclib.Pool( | |
thread_name_prefix="MetricsManager", max_workers=1) | |
# Times the duration between steps. | |
self._duration_timer = _AsyncTimer() | |
def __del__(self): | |
self.close() | |
def close(self): | |
try: | |
self._summary_pool.close() | |
finally: | |
try: | |
self._duration_timer.close() | |
finally: | |
if self._writer: | |
self._writer.close() | |
self._writer = None | |

  @property
  def summary_writer(self) -> metric_writers.MetricWriter:
    """Returns the MetricWriter used by this class."""
    # TODO(adarob): Make returned writer threadsafe.
    return self._writer

  def write_scalar(self, key: str, val: metric_writers.interface.Scalar,
                   step: int):
    """Writes scalar value to metric writers in a threadsafe manner."""
    step = int(utils.get_local_data(step))
    self.write_scalars(step, {key: val})

  def write_scalars(self, step: int,
                    scalars: Mapping[str, metric_writers.interface.Scalar]):
    """Writes scalar value to metric writers in a threadsafe manner."""
    step = utils.get_local_data(step)
    with self._writer_lock:
      self._writer.write_scalars(step, scalars)

  def start_duration_timer(self, block_on: PyTreeDef = ()):
    """Starts the duration timer."""
    self._duration_timer.start(block_on=block_on)

  def write_metrics_summary(self, metrics: MetricMapType, step: int,
                            num_steps: int) -> MetricValueMapFuture:
    """Writes summary based on accumulated metrics in a background thread.

    Duration is automatically computed as the interval between completion of
    metrics fetching. This closely approximates the duration of `num_steps`,
    as the steps must be computed sequentially, and it is more accurate than
    computing the time since the call to the step function since its actual
    execution occurs asynchronously on the TPU/GPU device.

    Args:
      metrics: accumulated metric values.
      step: the current train step.
      num_steps: the number of steps the metrics are accumulated across.

    Returns:
      A mapping of name -> scalar value of the written summary. Only returns
      the real scalar values on host 0; for other hosts, returns None.
    """
    step = utils.get_local_data(step)
    # Must be called in the main thread to avoid race condition.
    duration_future = self._duration_timer.stop(block_on=metrics)

    def _summarize_and_write():
      # For thread safety we first copy the metrics to host.
      fetched_metrics = jax.tree_map(jax.device_get, metrics)
      duration = duration_future.result()
      # We set the duration on time-related metrics.
      final_metrics = metrics_lib.set_time_metrics_duration(
          fetched_metrics, duration)
      # Set num_steps for Step metrics (AveragePerStep, StepsPerTime, ...)
      final_metrics = metrics_lib.set_step_metrics_num_steps(
          final_metrics, num_steps)

      # Ensure the metrics are not on device, which could lead to a deadlock.
      def _ensure_not_on_device(x):
        assert not isinstance(x, jax.numpy.DeviceArray)

      jax.tree_map(_ensure_not_on_device, final_metrics)
      final_metrics = jax.tree_map(utils.get_local_data, final_metrics)
      summary = {k: v.compute_value() for k, v in final_metrics.items()}
      with self._writer_lock:
        metric_writers.write_values(self._writer, int(step), summary)
      return summary

    return self._summary_pool(_summarize_and_write)()

  def flush(self):
    try:
      self._summary_pool.join()
    finally:
      self._writer.flush()


class PreemptionError(Exception):
  """Training has been interrupted and needs an emergency checkpoint."""


class BaseTrainer(abc.ABC):
  """Abstract base trainer class.

  Internally this uses MetricsManagers that start threads. You should
  use the trainer as a context manager, or call close() directly in
  order to wait for these threads to finish after training is done.
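
  For example (a rough sketch; constructor arguments are elided, and the
  iterators and step count are placeholders):

    with trainer_cls(..., eval_names=["validation"], ...) as trainer:
      trainer.train(train_iter, num_steps=1000)
      trainer.eval({"validation": eval_iter})
    # Exiting the `with` block calls close(), joining metric-writer threads.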
""" | |
def __init__(self, model: models.BaseModel, | |
train_state: train_state_lib.TrainState, | |
partitioner: partitioning.BasePartitioner, | |
eval_names: Sequence[str], summary_dir: Optional[str], | |
train_state_axes: Any, rng: Rng): | |
"""Trainer constructor. | |
Args: | |
model: the instantiation of `BaseModel` to train. | |
train_state: A train state with model parameters and optimizer state. | |
partitioner: the partitioner to use. | |
eval_names: names of evaluation datasets, which must match the keys of the | |
mapping passed to `eval`. | |
summary_dir: optional directory to write TensorBoard metrics to. | |
train_state_axes: partitioning info for the train state to be used. | |
rng: jax PRNGKey seed for random operations, to be combined with step | |
number for a deterministic RNG. | |
""" | |
self._model = model | |
self._train_state_axes = train_state_axes | |
self._base_rng = rng | |
self._partitioner = partitioner | |
self._compiled_train_step: Optional[PartitionedTrainCallable] = None | |
self._compiled_eval_steps: MutableMapping[str, PartitionedEvalCallable] = {} | |
self._compiled_eval_step_cache: MutableMapping[ | |
BatchSpec, PartitionedEvalCallable] = {} | |
self._train_state_mutex = threading.RLock() | |
self._train_state = train_state | |
self.stop_training = False | |
# The training metrics combine metrics added by the Model (e.g., loss and | |
# accuracy) and Trainer (e.g., learning rate). | |
self.train_metrics_manager = MetricsManager( | |
"train", summary_dir=summary_dir) | |
# The eval metrics only include metrics added by the Model. | |
self.eval_metrics_managers = { # pylint:disable=g-complex-comprehension | |
n: MetricsManager(f"training_eval/{n}", summary_dir=summary_dir) | |
for n in eval_names | |
} | |
def __enter__(self): | |
return self | |
def __exit__(self, exc_type, exc_value, traceback): | |
self.close() | |

  def close(self):
    """Stops all metric manager threads."""
    self.train_metrics_manager.close()
    for mm in self.eval_metrics_managers.values():
      mm.close()

  def _get_step_rng(self, step: int) -> Rng:
    return jax.random.fold_in(self._base_rng, step)

  @property
  def train_state(self):
    with self._train_state_mutex:
      return self._train_state

  @train_state.setter
  def train_state(self, train_state: PyTreeDef):
    with self._train_state_mutex:
      self._train_state = train_state

  def train(self,
            batch_iter: Union[Iterator[BatchType], clu.data.DatasetIterator],
            num_steps: int,
            start_step: Optional[int] = None) -> ArrayMapFuture:
    """Runs the train loop for the given number of steps."""
    metrics = None
    # Use pre-compiled step, if available.
    train_step_fn = self._compiled_train_step or self._partitioned_train_step

    # We lock `train_state` access during the loop to avoid race conditions.
    with self._train_state_mutex:
      train_state = self.train_state
      # Compute step number on host to avoid communication overhead.
      start_step = int(
          start_step if start_step is not None else train_state.step)
      self.train_metrics_manager.start_duration_timer(block_on=train_state)
      for step_num in range(start_step, start_step + num_steps):
        logging.log_every_n_seconds(logging.INFO, "Training: step %d", 10,
                                    step_num)
        with jax.profiler.StepTraceAnnotation("train", step_num=step_num):
          batch = next(batch_iter)
          train_state, metrics_update = train_step_fn(train_state, batch)
          if metrics:
            metrics = merge_metrics(metrics, metrics_update)
          else:
            metrics = metrics_update

      self.train_state = train_state

    return self.train_metrics_manager.write_metrics_summary(
        metrics, start_step + num_steps, num_steps)

  def compile_train(self, batch: BatchType) -> None:
    """Pre-compiles train step (if not yet compiled).

    Not required.

    If not called before `train`, compilation will occur automatically on the
    first step and JAX's "jit cache" will be used to avoid recompilation for
    future steps.

    Args:
      batch: A sample batch that may contain dummy values, but with correct
        shapes and dtypes.
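
    For example (a rough sketch; the feature names and shapes below are
    illustrative only and depend on the model's feature converter):

      dummy_batch = {
          "encoder_input_tokens": np.zeros((8, 512), np.int32),
          "decoder_input_tokens": np.zeros((8, 62), np.int32),
          "decoder_target_tokens": np.zeros((8, 62), np.int32),
          "decoder_loss_weights": np.zeros((8, 62), np.int32),
      }
      trainer.compile_train(dummy_batch)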
""" | |
tick = time.time() | |
self._compiled_train_step = self._partitioner.compile( | |
self._partitioned_train_step, self.train_state, batch) | |
tock = time.time() | |
self.train_metrics_manager.write_scalar("timing/compilation_seconds", | |
tock - tick, self.train_state.step) | |
def eval( | |
self, batch_iters: Mapping[str, | |
Iterator[BatchType]]) -> Mapping[str, Array]: | |
"""Runs evaluation loop over the iterator and writes summary.""" | |
eval_summaries = {} | |
train_state = self.train_state | |
for iter_name, batch_iter in batch_iters.items(): | |
logging.info("Evaluating: %s.", iter_name) | |
metrics = None | |
# Use a pre-compiled step function, if available. | |
eval_step_fn = self._compiled_eval_steps.get(iter_name, | |
self._partitioned_eval_step) | |
mm = self.eval_metrics_managers[iter_name] | |
num_steps = 0 | |
mm.start_duration_timer(block_on=train_state) | |
for batch in batch_iter: | |
num_steps += 1 | |
multihost_utils.assert_equal( | |
jnp.array(num_steps), | |
"Eval step mismatch across hosts. Check for empty dataset shard.") | |
metrics_update = eval_step_fn(train_state, batch) | |
if metrics: | |
metrics = merge_metrics(metrics, metrics_update) | |
else: | |
metrics = metrics_update | |
multihost_utils.assert_equal( | |
jnp.array(-1), | |
"Eval step mismatch across hosts. Check for empty dataset shard.") | |
eval_summaries[iter_name] = mm.write_metrics_summary( | |
metrics, train_state.step, num_steps) | |
# TODO(adarob): Return futures. | |
return {k: v.result() for k, v in eval_summaries.items()} | |
def compile_eval(self, batches: Mapping[str, BatchType]) -> None: | |
"""Pre-compiles eval step (if not yet compiled). | |
Not required. | |
Pre-compiles the evaluation step for each evaluation dataset, reusing cached | |
compilations where possible. In other words, if multiple evaluation datasets | |
have equivalent shapes/dtypes for the batch and initial metrics, | |
recompilation will be avoided. | |
If not called before `eval`, compilation will occur automatically on the | |
first step and JAX's "jit cache" will be used to avoid recompilation for | |
future steps. | |
Args: | |
batches: a mapping from evaluation dataset name to a sample batch. The | |
batch may contain dummy values, but the shapes and dtypes must be | |
correct. | |
""" | |
for eval_name, batch in batches.items(): | |
tick = time.time() | |
cache_key: BatchSpec = FrozenDict(jax.eval_shape(lambda: batch)) # pylint:disable=cell-var-from-loop | |
if cache_key not in self._compiled_eval_step_cache: | |
self._compiled_eval_step_cache[cache_key] = self._partitioner.compile( | |
self._partitioned_eval_step, self.train_state, batch) | |
self._compiled_eval_steps[eval_name] = self._compiled_eval_step_cache[ | |
cache_key] | |
tock = time.time() | |
self.eval_metrics_managers[eval_name].write_scalar( | |
"timing/compilation_seconds", tock - tick, self.train_state.step) | |

  @abc.abstractmethod
  def _partitioned_train_step(self) -> PartitionedTrainCallable:
    """Partitioned train step."""
    raise NotImplementedError

  @abc.abstractmethod
  def _partitioned_eval_step(self) -> PartitionedEvalCallable:
    """Partitioned eval step."""
    raise NotImplementedError


def accumulate_grads_microbatched(
    model: models.BaseModel,
    train_state: train_state_lib.TrainState,
    batch: BatchType,
    dropout_rng: Rng,
    num_microbatches: Optional[int],
    data_partition_spec: PartitionSpec = PartitionSpec("data"),
) -> Tuple[train_state_lib.TrainState, MutableMetricMapType,
           Optional[FlaxMutables]]:
  """Implements optional microbatched gradient accumulation.

  Args:
    model: the instantiation of `BaseModel` to train.
    train_state: A train state with model parameters and optimizer state.
    batch: input batch consisting either of simply-padded batched features
      ('encoder_input_tokens', 'decoder_input_tokens', 'decoder_target_tokens',
      'decoder_loss_weights') or of packed, batched features with additional
      "(encoder|decoder)_segment_id" and "(encoder|decoder)_position" features.
    dropout_rng: jax PRNGKey for dropout.
    num_microbatches: the number of microbatches to use, or None for direct
      training.
    data_partition_spec: the PartitionSpec to use for partitioning annotations
      on the batch.

  Returns:
    Accumulated gradients, incremental metrics, and (if the train state carries
    them) updated flax mutables.
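
  As a concrete illustration of the microbatching arithmetic (numbers made up):
  with a batch whose leading dimension is 128 and `num_microbatches=4`, each
  microbatch has size 128 // 4 = 32, gradients are accumulated over the 4
  microbatches in float32, and the incremental metrics are merged across them.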
""" | |
batch_size = next(iter(batch.values())).shape[0] | |
grad_fn = jax.value_and_grad(model.loss_fn, has_aux=True) | |
# We assume that the model loss_fn supports flax mutables if and only if | |
# the train state includes non-empty flax mutables. | |
# Note: Default t5x models don't support flax_mutables. One needs to subclass | |
# them and return flax_mutables from `get_initial_variables` and `loss_fn`. | |
initial_flax_mutables = train_state.flax_mutables if train_state.flax_mutables else None | |
if num_microbatches is None or num_microbatches <= 1: | |
if initial_flax_mutables is None: | |
(_, metrics), grad_accum = grad_fn(train_state.params, batch, dropout_rng) | |
flax_mutables = None | |
else: | |
(_, metrics, flax_mutables), grad_accum = grad_fn(train_state.params, | |
batch, dropout_rng, | |
initial_flax_mutables) | |
else: | |
assert batch_size % num_microbatches == 0, ( | |
"Batch size isn't divided evenly by num_microbatches.") | |
microbatch_size = batch_size // num_microbatches | |
logging.info("using microbatches: %d microbatches, %d size", | |
num_microbatches, microbatch_size) | |
def get_microbatch(batch: BatchType, idx: int) -> Mapping[str, jnp.ndarray]: | |
"""Fetch microbatch slice from possibly-packed input data.""" | |
offset = idx * microbatch_size | |
length = microbatch_size | |
starts = {k: [offset] + [0] * (b.ndim - 1) for k, b in batch.items()} | |
limits = {k: [length] + list(b.shape[1:]) for k, b in batch.items()} | |
return { | |
k: jax.lax.dynamic_slice(b, starts[k], limits[k]) | |
for k, b in batch.items() | |
} | |
def metrics_and_grad(loop_cnt, dropout_rng, flax_mutables=None): | |
dropout_rng, sub_dropout_rng = jax.random.split(dropout_rng) | |
mbatch = get_microbatch(batch, loop_cnt) | |
# We need to annotate the microbatch sharding as we would a batch. | |
mbatch = jax.tree_map( | |
lambda x: partitioning.with_sharding_constraint( # pylint: disable=g-long-lambda | |
x, data_partition_spec), | |
mbatch) | |
if flax_mutables is None: | |
(_, metrics), grad = grad_fn(train_state.params, mbatch, | |
sub_dropout_rng) | |
else: | |
(_, metrics, flax_mutables), grad = grad_fn(train_state.params, mbatch, | |
sub_dropout_rng, | |
flax_mutables) | |
return metrics, grad, flax_mutables | |
def per_microbatch_train_step( | |
loop_cnt: int, state: Tuple[jnp.ndarray, jnp.ndarray, | |
Mapping[str, jnp.ndarray], | |
Optional[FlaxMutables]] | |
) -> Tuple[jnp.ndarray, jnp.ndarray, Mapping[str, jnp.ndarray], | |
Optional[FlaxMutables]]: | |
(dropout_rng, grad_accum, prev_metrics, flax_mutables) = state | |
metrics, grad, flax_mutables = metrics_and_grad(loop_cnt, dropout_rng, | |
flax_mutables) | |
grad_accum = jax.tree_multimap(jnp.add, grad_accum, grad) | |
metrics = jax.lax.cond(loop_cnt == 0, lambda _: metrics, | |
lambda _: merge_metrics(prev_metrics, metrics), | |
None) | |
return dropout_rng, grad_accum, metrics, flax_mutables | |
# Initialize gradient accumulation loop state. | |
accum_dtype = jnp.float32 | |
grad_accum_init = jax.tree_map(lambda x: jnp.zeros(x.shape, accum_dtype), | |
train_state.params) | |
initial_metrics_shape, _, _ = jax.eval_shape( | |
metrics_and_grad, loop_cnt=0, dropout_rng=dropout_rng) | |
initial_metrics = { | |
k: metrics_lib.shape_obj_to_defined_obj(v) | |
for k, v in initial_metrics_shape.items() | |
} | |
loop_init = (dropout_rng, grad_accum_init, initial_metrics, | |
initial_flax_mutables) | |
new_dropout_rng, grad_accum, metrics, flax_mutables = jax.lax.fori_loop( | |
0, num_microbatches, per_microbatch_train_step, loop_init) | |
del new_dropout_rng | |
return grad_accum, metrics, flax_mutables | |


def apply_grads(
    train_state: train_state_lib.TrainState,
    grad_accum: ModelWeights,
    metrics: MutableMetricMapType,
    learning_rate: jnp.ndarray,
    weight_metrics_computer: Optional[WeightMetricsComputer],
    other_state_variables: Optional[Mapping[str, Any]] = None
) -> Tuple[train_state_lib.TrainState, MetricMapType]:
  """Applies gradients to the optimizer.

  Args:
    train_state: A train state that contains model and optimizer params.
    grad_accum: results of `accumulate_grads` call.
    metrics: incremental metrics from `accumulate_grads` call.
    learning_rate: the learning rate to use for this step.
    weight_metrics_computer: A WeightMetricsComputer instance, or None, to
      decide what metrics, if any, to log about weights and weight updates
      during training.
    other_state_variables: other variables to update the state with.

  Returns:
    The updated train state, metrics.
  """
  if other_state_variables is None:
    other_state_variables = {}
  # Update optimizer using accumulated gradient.
  new_train_state = train_state.apply_gradient(
      grad_accum, learning_rate=learning_rate, **other_state_variables)
  metrics["learning_rate"] = clu.metrics.Average.from_model_output(
      jnp.asarray([learning_rate]))
  metrics["learning_rate/current"] = clu.metrics.LastValue.from_model_output(
      jnp.asarray([learning_rate]))
  if weight_metrics_computer is not None:
    metrics.update(
        weight_metrics_computer.compute_metrics(grad_accum, train_state,
                                                new_train_state))
  return new_train_state, metrics


def eval_step(model: models.BaseModel, train_state: train_state_lib.TrainState,
              batch: jnp.ndarray) -> MetricMapType:
  """Default evaluation step."""
  _, metrics = model.eval_fn(train_state.params, batch)
  return metrics


def train_with_lr(
    train_state: train_state_lib.TrainState,
    batch: BatchType,
    learning_rate: jnp.ndarray,
    dropout_rng: Rng,
    model: models.BaseModel,
    num_microbatches: Optional[int],
    weight_metrics_computer: Optional[WeightMetricsComputer] = None,
    data_partition_spec: PartitionSpec = PartitionSpec("data")):
  """Main training function with LR schedule."""
  grad_accum, metrics, flax_mutables = (
      accumulate_grads_microbatched(model, train_state, batch, dropout_rng,
                                    num_microbatches, data_partition_spec))
  new_train_state, metrics = apply_grads(
      train_state,
      grad_accum,
      metrics,
      learning_rate,
      weight_metrics_computer,
      other_state_variables={"flax_mutables": flax_mutables}
      if flax_mutables else None)
  return new_train_state, metrics


class Trainer(BaseTrainer):
  """Training loop with optional microbatches."""

  def __init__(self,
               model: models.BaseModel,
               train_state: train_state_lib.TrainState,
               partitioner: partitioning.BasePartitioner,
               eval_names: Sequence[str],
               summary_dir: Optional[str],
               train_state_axes: Any,
               rng: Rng,
               learning_rate_fn: LearningRateCallable,
               num_microbatches: Optional[int],
               weight_metrics_computer: Optional[WeightMetricsComputer] = None):
    """Trainer constructor.

    Args:
      model: the instantiation of `BaseModel` to train.
      train_state: a train state with parameters and optimizer state.
      partitioner: the partitioner to use.
      eval_names: names of evaluation datasets, which must match the keys of
        the mapping passed to `eval`.
      summary_dir: optional directory to write TensorBoard metrics to.
      train_state_axes: partitioning info for the optimizer to be used.
      rng: jax PRNGKey seed for random operations, to be combined with step
        number for a deterministic RNG.
      learning_rate_fn: returns the learning rate given the current step.
      num_microbatches: the number of microbatches to use, or None for direct
        training.
      weight_metrics_computer: A WeightMetricsComputer instance, or None, to
        decide what metrics, if any, to log about weights and weight updates
        during training.
    """
    self._learning_rate_fn = learning_rate_fn
    self._num_microbatches = num_microbatches
    self._weight_metrics_computer = weight_metrics_computer

    super().__init__(
        model=model,
        train_state=train_state,
        partitioner=partitioner,
        eval_names=eval_names,
        summary_dir=summary_dir,
        train_state_axes=train_state_axes,
        rng=rng)

  @cached_property
  def _partitioned_train_step(self) -> PartitionedTrainCallable:

    def train_step(train_state: train_state_lib.TrainState, batch: BatchType):
      return train_with_lr(
          train_state,
          batch,
          learning_rate=self._learning_rate_fn(train_state.step),
          dropout_rng=self._get_step_rng(train_state.step),
          model=self._model,
          num_microbatches=self._num_microbatches,
          weight_metrics_computer=self._weight_metrics_computer,
          data_partition_spec=self._partitioner.data_partition_spec)

    return self._partitioner.partition(
        train_step,
        in_axis_resources=(self._train_state_axes,
                           self._partitioner.data_partition_spec),
        out_axis_resources=(self._train_state_axes, None),
        donate_argnums=(0,))

  @cached_property
  def _partitioned_eval_step(self) -> PartitionedEvalCallable:
    return self._partitioner.partition(
        lambda *args, **kwargs: eval_step(self._model, *args, **kwargs),
        in_axis_resources=(self._train_state_axes,
                           self._partitioner.data_partition_spec),
        out_axis_resources=None)


def _warn_action_not_run(action, task, metric):
  logging.warning(
      "The action: %s that tracks metric: %s for task: %s is not run", action,
      metric, task)


# TODO(b/200701930): Support dynamic registration for enum.
class ActionMode(enum.Enum):
  """Defines when to run an action.

  For example, TRAIN means to run an action after a TRAIN loop is done.
  """
  TRAIN = 1
  TRAIN_EVAL = 2
  INFER_EVAL = 3


class BaseAction(abc.ABC):
  """Base Action class for override. The action itself does nothing."""

  @abc.abstractmethod
  def run(self, train_state: train_state_lib.TrainState,
          metrics_by_task: Mapping[str, MetricValueMapType]) -> bool:
    """Runs an action for the given train_state and metrics.

    Args:
      train_state: The current train_state in the training loop.
      metrics_by_task: A map of metrics that is grouped by each task.

    Returns:
      A bool indicating whether training should be halted.
    """
    raise NotImplementedError("Action must define its run method.")


ActionMapType = Mapping[ActionMode, Sequence[BaseAction]]


class EarlyStoppingAction(BaseAction):
  """Terminates training when the specified metric is not improving.

  Checks whether the monitored metrics are decreasing after every `train` or
  `eval`, or `both`. If the loss is no longer decreasing for `patience` times,
  terminate the training process.
  """

  def __init__(self,
               metric: Tuple[str, str],
               mode: str,
               patience: int = 3,
               atol: float = 0.,
               rtol: float = 0.):
"""Constructs the EarlyStoppingAction. | |
Args: | |
metric: A metric to monitor when invoking the action. When the metric does | |
not improve for a number of times (specified in patience), stop the | |
training. The tuple takes 2 strings, whereas the first string defines | |
the task to track, and the second defines the metric of the task to | |
track. e.g.,: ('mt5_xnli_dev_test.all_langs', 'accuracy') would monitor | |
the 'accuracy' for `mt5_xnli_dev_test.all_langs`. | |
mode: One of `{"min", "max"}`. In `min` mode, training will stop when the | |
quantity monitored has stopped decreasing; in `"max"` mode it will stop | |
when the quantity monitored has stopped increasing; | |
patience: The threshold of stopping criteria. Usually this is measured by | |
number of steps. | |
atol: Absolute tolerance in the monitored quantity to qualify as an | |
improvement, i.e. a change of less than `atol`, will count as no | |
improvement. | |
rtol: Relative tolerance in the monitoried quantity to qualify as an | |
improvement. This combined with `atol` defines whether a change is | |
considered improvement. The total change is calculated as following: | |
`delta = (atol + rtol * previous)` See `numpy.allclose` for detailed | |
information. | |
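
    As a worked example (numbers are illustrative): with mode="max", atol=0.0,
    rtol=0.01 and a previous accuracy of 0.80, the tolerance is
    delta = 0.0 + 0.01 * 0.80 = 0.008, so a new value only counts as an
    improvement if it exceeds 0.80 + 0.008 = 0.808; with patience=3, training
    stops once the last three recorded values show no pairwise improvement.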
""" | |
self._task, self._metric = metric | |
if mode not in ["min", "max"]: | |
raise ValueError('mode must be in ["min", "max"]') | |
self._mode = mode | |
if atol < 0: | |
raise ValueError("atol must be greater equal than 0") | |
self._atol = atol | |
if rtol < 0: | |
raise ValueError("rtol must be greater equal than 0") | |
self._rtol = rtol | |
self._patience = patience | |
self._metric_history = [] | |
def _compare_fn(self, current, previous): | |
compare_fn = jnp.greater_equal if self._mode == "min" else jnp.less_equal | |
delta = self._atol + self._rtol * abs(previous) | |
if self._mode == "max": | |
delta *= -1 | |
return compare_fn(current, previous - delta) | |

  def run(self, train_state: train_state_lib.TrainState,
          metrics_by_task: Mapping[str, MetricValueMapType]) -> bool:
    if self._task not in metrics_by_task.keys():
      logging.warning(
          "Monitoring task: %s does not exist in all task metrics. "
          "Available tasks are : %s", self._task, metrics_by_task.keys())
      _warn_action_not_run(type(self), self._task, self._metric)
      return False
    if self._metric not in metrics_by_task[self._task].keys():
      logging.warning("Metric : %s does not exist in metrics for task : %s",
                      self._metric, self._task)
      _warn_action_not_run(type(self), self._task, self._metric)
      return False

    m = metrics_by_task[self._task][self._metric]
    if not isinstance(m, clu.values.Scalar):
      logging.warning("Metric %s does not have Scalar type. Found %s.",
                      self._metric, type(m))
      _warn_action_not_run(type(self), self._task, self._metric)
      return False

    self._metric_history.append(m.value)

    # Not enough history.
    if len(self._metric_history) < self._patience:
      return False

    if all(
        self._compare_fn(self._metric_history[i + 1], self._metric_history[i])
        for i in range(self._patience - 1)):
      logging.warning(
          "Requested `stop_training` in training loop (Details below).\n "
          "Metric: %s for Task: %s has not improved for %s iterations, detail "
          "history of the metric: %s", self._metric, self._task,
          self._patience, self._metric_history)
      return True

    # Remove extra histories that we don't need to keep.
    self._metric_history.pop(0)
    return False


class TerminateOnNanAction(BaseAction):
  """Terminates training when NaN loss is detected.

  Checks whether the loss metric for the given task is NaN or Inf and
  terminates training if so.
  """

  def __init__(self, task: str, metric: str = "loss"):
    """Constructs the TerminateOnNanAction.

    Args:
      task: Defines the task from which to track the given metric.
      metric: Defines a metric to track for NaN values (defaults to "loss").
    """
    self._task = task
    self._metric = metric

  def run(self, train_state: train_state_lib.TrainState,
          metrics_by_task: Mapping[str, MetricValueMapType]) -> bool:
    if self._task not in metrics_by_task.keys():
      logging.warning(
          "Monitoring task: %s does not exist in all task metrics. "
          "Available tasks are : %s", self._task, metrics_by_task.keys())
      _warn_action_not_run(type(self), self._task, self._metric)
      return False
    if self._metric not in metrics_by_task[self._task].keys():
      logging.warning("Metric : %s does not exist in metrics for task : %s",
                      self._metric, self._task)
      _warn_action_not_run(type(self), self._task, self._metric)
      return False

    metric = metrics_by_task[self._task][self._metric]
    if not isinstance(metric, clu.values.Scalar):
      logging.warning("Metric %s does not have Scalar type. Found %s.",
                      self._metric, type(metric))
      _warn_action_not_run(type(self), self._task, self._metric)
      return False

    value = metric.value
    if np.isnan(value) or np.isinf(value):
      logging.warning(
          "Requested `stop_training` in training loop (Details below).\n "
          "NaN encountered in metric for task : %s", self._task)
      return True

    return False