# Scalable T5
NB: This particular example is still WIP. We're investigating a slight training
regression compared to the "vanilla" T5 example.
This directory is very similar to the vanilla T5X "T5" example, but demonstrates
a host of techniques needed to scale model training to giant models run on
large TPU or GPU cluster environments using XLA's SPMD capabilities. See the
notes for the main "t5" example for general details on setup and execution.
__Note__: many of the APIs built on top of `pjit` by Flax and T5X for easier
model parallel programming are still experimental, and may change.
## Intermediate variable annotations
In larger models, with multi-axis model parallelism, it is typically necessary
to provide additional constraint annotations beyond those on the inputs and
outputs of a function. We do this using a special version of the `pjit`
annotation function `with_sharding_constraint` that uses _logical_ axis
names instead of raw mesh axes. This allows us to avoid tightly coupling a
specific partitioning plan to the model code itself. Instead, we merely need
to annotate the axis names used in the model in a coherent scheme, and later
map these logical axes to the physical mesh axes using a small set of rules.
Example usage can be seen in `network.py`.
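
For illustration only, the sketch below shows the general pattern of annotating
an intermediate activation with logical axis names using Flax's
`nn.with_logical_constraint`; it is not the actual `network.py` code, and the
module, axis names, and rule pairs here are assumptions.

```python
# Minimal sketch (not the actual network.py code): constrain an intermediate
# activation with *logical* axis names; the names and rules are illustrative.
import flax.linen as nn


class MlpBlock(nn.Module):
  hidden_dim: int

  @nn.compact
  def __call__(self, x):
    # x: [batch, length, embed]
    h = nn.Dense(self.hidden_dim, use_bias=False)(x)
    # Constrain the intermediate using logical names rather than mesh axes.
    h = nn.with_logical_constraint(h, ('batch', 'length', 'mlp'))
    h = nn.relu(h)
    out = nn.Dense(x.shape[-1], use_bias=False)(h)
    return nn.with_logical_constraint(out, ('batch', 'length', 'embed'))


# A small rule set later maps logical names onto physical mesh axes, e.g.
# shard the batch over the 'data' axis and the MLP dimension over 'model':
rules = (('batch', 'data'), ('mlp', 'model'), ('embed', None), ('length', None))
```

With such rules in effect (e.g. via `nn.logical_axis_rules(rules)` in recent
Flax releases, or the equivalent T5X configuration), the same model code can be
repartitioned by changing only the rules, which is exactly the decoupling
described above.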
## Scan over layers
One challenge with giant models is the increasing amount of compilation time
required to handle extremely large layer stacks in XLA. At the size of a full
TPU pod this compile-time cost can become quite extreme. To remedy this,
instead of handing the compiler a huge stack of unrolled layers, we can use
native XLA control flow constructs to simplify the computational graph that
JAX emits. For giant models this can drop the compile time from hours to
minutes, and even at base scale compilation can be roughly 5x faster.
In this case, we want to use the [XLA While Op][xla-while] via JAX's
[scan][jax-scan] control flow construct to express the idea that we're looping
over identically-defined layers when using a deep transformer network. We do
this via a custom Flax version of scan called `scan_with_axes` that also handles
the parameter logical axis name metadata needed for partitioning.
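
As a simplified illustration, the sketch below loops over identical layers with
plain `flax.linen.scan`, stacking the per-layer parameters along a new leading
axis so XLA sees a single While loop instead of an unrolled stack. The layer
body is a stand-in, and the real `network.py` goes through `scan_with_axes`
instead so that the logical axis-name metadata from the previous section is
carried along.

```python
# Minimal sketch: scan over identically-defined layers with flax.linen.scan.
import flax.linen as nn


class Layer(nn.Module):
  @nn.compact
  def __call__(self, x, _):
    # Toy residual block standing in for a real transformer layer.
    x = x + nn.Dense(x.shape[-1])(nn.LayerNorm()(x))
    return x, None  # (carry, per-step output)


class Stack(nn.Module):
  num_layers: int

  @nn.compact
  def __call__(self, x):
    # Parameters gain a leading "layer" axis and the loop is expressed as a
    # single XLA While op rather than num_layers unrolled copies.
    ScannedLayer = nn.scan(
        Layer,
        variable_axes={'params': 0},   # stack params along a new leading axis
        split_rngs={'params': True},   # distinct init RNG per layer
        length=self.num_layers)
    x, _ = ScannedLayer()(x, None)
    return x
```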
## Rematerialization / Checkpointing
"Rematerialization" or "checkpointing" is a technique for trading off compute | |
time for lower peak memory utilization when performing reverse-mode automatic | |
differentiation. JAX offers several different default rematerialization | |
"policies" that dictate which kinds of intermediate values are preserved from | |
the forward-pass to the backwards-pass calculation, and which are discarded to | |
be recomputed anew in the backwards-pass. | |
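
As a hedged sketch of the general mechanism (not this example's exact
configuration), the snippet below applies one of JAX's built-in policies via
`jax.checkpoint` (also available as `jax.remat`); the toy MLP and the
particular policy choice are assumptions for illustration.

```python
# Minimal sketch: rematerialization with a built-in jax.checkpoint policy.
import jax
import jax.numpy as jnp


def mlp(params, x):
  h = jnp.dot(x, params['wi'])
  h = jax.nn.relu(h)
  return jnp.dot(h, params['wo'])


# Save the results of non-batch matmuls for the backward pass and recompute
# everything else, trading extra compute for lower peak memory.
mlp_remat = jax.checkpoint(
    mlp, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)

loss = lambda params, x: jnp.sum(mlp_remat(params, x) ** 2)
grads = jax.grad(loss)(
    {'wi': jnp.ones((8, 32)), 'wo': jnp.ones((32, 8))},
    jnp.ones((4, 8)))
```

Policies that save more intermediates (e.g.
`jax.checkpoint_policies.everything_saveable`) recompute less in the backward
pass at the cost of higher peak memory, while `nothing_saveable` sits at the
other extreme.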
[xla-while]: https://www.tensorflow.org/xla/operation_semantics#while
[jax-scan]: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html