Scalable T5

NB: This particular example is still WIP. We're investigating a slight training regression compared to the "vanilla" T5 example.

This directory is very similar to the vanilla T5X "T5" example, but demonstrates several techniques needed to scale training to very large models on large TPU or GPU clusters using XLA's SPMD capabilities. See the notes for the main "t5" example for general details on setup and execution.

Note: many of the APIs built on top of pjit by Flax and T5X for easier model-parallel programming are still experimental and may change.

Intermediate variable annotations

In larger models with multi-axis model parallelism, it is typically necessary to provide additional sharding constraint annotations beyond those for a function's inputs and outputs. We do this using a special version of the pjit annotation function with_sharding_constraint that uses logical axis names instead of raw mesh axes. This avoids tightly coupling a specific partitioning plan to the model code itself. Instead, we only need to annotate the axes used in the model with a coherent naming scheme, and later map these logical axes to the physical mesh axes using a small set of rules. Example usage can be seen in network.py.
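The idea of mapping logical axis names to mesh axes through a small rule table can be sketched in plain JAX. This is a minimal illustration, not the Flax/T5X implementation: the rule table `LOGICAL_AXIS_RULES`, the helpers `logical_to_mesh_axes` and `constrain`, and the axis names are all hypothetical, and the mesh is built from a single device so that the sketch runs anywhere.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical rule table: (logical axis name, mesh axis name or None).
# None means "replicate along this logical axis".
LOGICAL_AXIS_RULES = (
    ("batch", "data"),
    ("embed", None),
    ("mlp", "model"),
)


def logical_to_mesh_axes(logical_axes, rules=LOGICAL_AXIS_RULES):
    """Translate a tuple of logical axis names into a PartitionSpec."""
    used = set()  # each mesh axis may shard at most one array dimension
    spec = []
    for name in logical_axes:
        mesh_axis = None
        for rule_name, rule_axis in rules:
            if rule_name == name and rule_axis not in used:
                mesh_axis = rule_axis
                break
        if mesh_axis is not None:
            used.add(mesh_axis)
        spec.append(mesh_axis)
    return P(*spec)


def constrain(x, mesh, logical_axes):
    """Apply a sharding constraint expressed in logical axis names."""
    sharding = NamedSharding(mesh, logical_to_mesh_axes(logical_axes))
    return jax.lax.with_sharding_constraint(x, sharding)


# A 1x1 mesh over a single device, purely so the sketch is runnable.
devices = np.array(jax.devices()[:1]).reshape(1, 1)
mesh = Mesh(devices, ("data", "model"))


@jax.jit
def mlp_hidden(x, w):
    h = jnp.dot(x, w)
    # The model code names only logical axes; the rules pick the mesh axes.
    return constrain(h, mesh, ("batch", "mlp"))


hidden = mlp_hidden(jnp.ones((4, 8)), jnp.ones((8, 16)))
```

Changing the partitioning plan then only requires editing the rule table; the annotated model function stays the same.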

Scan over layers

One challenge with giant models is the increasing compilation time required to handle extremely large layer stacks in XLA. At the size of a full TPU pod, this compile-time cost can become extreme. To remedy this, instead of handing the compiler a huge stack of unrolled layers, we can use native XLA control flow constructs to simplify the computational graph emitted by JAX. For giant models this can drop the compile time from hours to minutes, and even at base scale compilation can be roughly 5x faster.

In this case, we want to use the XLA While Op via JAX's scan control flow construct to express the idea that we're looping over identically-defined layers when using a deep transformer network. We do this via a custom Flax version of scan called scan_with_axes that also handles the parameter logical axis name metadata needed for partitioning.
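The underlying pattern can be illustrated with plain `jax.lax.scan`, which is the primitive that lowers to an XLA While op. This sketch does not use the Flax `scan_with_axes` wrapper described above and carries no logical axis metadata; the toy `layer` function and the parameter shapes are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def layer(h, params):
    """One toy layer; every layer in the stack shares this definition."""
    w, b = params
    return jnp.tanh(h @ w + b)


def unrolled_stack(h, stacked_params):
    """Python loop: the compiler sees num_layers copies of the layer."""
    ws, bs = stacked_params
    for i in range(ws.shape[0]):
        h = layer(h, (ws[i], bs[i]))
    return h


def scanned_stack(h, stacked_params):
    """lax.scan: the compiler sees the layer body once, inside a loop."""
    def body(carry, layer_params):
        return layer(carry, layer_params), None

    h, _ = jax.lax.scan(body, h, stacked_params)
    return h


num_layers, dim = 6, 8
key_w, key_b, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
# Parameters for all layers are stacked along a leading "layers" axis.
stacked = (
    0.1 * jax.random.normal(key_w, (num_layers, dim, dim)),
    0.1 * jax.random.normal(key_b, (num_layers, dim)),
)
x = jax.random.normal(key_x, (4, dim))

y_unrolled = jax.jit(unrolled_stack)(x, stacked)
y_scanned = jax.jit(scanned_stack)(x, stacked)
```

Both versions compute the same result; the scanned version requires the per-layer parameters to be stacked along a leading axis, which is why the partitioning metadata has to account for that extra axis.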

Rematerialization / Checkpointing

"Rematerialization" or "checkpointing" is a technique for trading off compute time for lower peak memory utilization when performing reverse-mode automatic differentiation. JAX offers several different default rematerialization "policies" that dictate which kinds of intermediate values are preserved from the forward-pass to the backwards-pass calculation, and which are discarded to be recomputed anew in the backwards-pass.