u-μP: Stable training in low precision for a significant speed-up and memory reduction during training
This repository holds the weights of the u-μP models trained at Aleph Alpha Research, in collaboration with Graphcore, for 72k steps (300B tokens). Please note that the released checkpoints are not fully converged and are intended for research use only.
You can find all model weights at the following links:
- umup-research-7b-bf16
- umup-research-7b-fp8
- sp-baseline-research-7b-bf16
- umup-research-3b-bf16
- umup-research-3b-fp8
- sp-baseline-research-3b-bf16
- umup-research-1b-bf16
- umup-research-1b-fp8
- sp-baseline-research-1b-bf16
The Maximal Update Parametrization (μP) aims to make the optimal hyperparameters (HPs) of a model independent of its size, allowing them to be swept using a cheap proxy model rather than the full-size target model. We present a new scheme, u-μP, which improves upon μP by combining it with Unit Scaling, a method for designing models that are easy to train in low precision. The two techniques have a natural affinity: μP ensures that the scale of activations is independent of model size, and Unit Scaling ensures that activations, weights, and gradients begin training with a scale of one. This synthesis opens the door to a simpler scheme whose default HP values are near-optimal. This in turn enables a more efficient sweeping strategy, with u-μP models reaching a lower loss than comparable μP models and working out of the box in FP8.
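To make "a scale of one" concrete, here is a minimal Python sketch of the forward pass of a unit-scaled linear layer, under simplifying assumptions: weights are drawn with unit variance and the matmul output is multiplied by 1/sqrt(fan_in). The helpers `unit_scaled_linear`, `rms`, and `demo` are hypothetical and are not taken from the Scaling codebase; full Unit Scaling also applies separate scaling factors in the backward pass, and u-μP adds further rules on top, none of which are shown here.

```python
import math
import random


def unit_scaled_linear(x, w):
    """Forward pass of a unit-scaled linear layer (simplified sketch, forward only).

    x: list of fan_in floats; w: fan_out rows of fan_in floats drawn with unit variance.
    Multiplying by 1/sqrt(fan_in) keeps the output variance near 1 when inputs and
    weights both have unit variance, regardless of fan_in.
    """
    fan_in = len(x)
    scale = 1.0 / math.sqrt(fan_in)
    return [scale * sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def rms(values):
    """Root-mean-square of a list of floats."""
    return math.sqrt(sum(v * v for v in values) / len(values))


def demo(fan_in, fan_out=128, seed=0):
    """Return (unscaled RMS, unit-scaled RMS) for random unit-variance x and w."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(fan_in)]
    w = [[rng.gauss(0.0, 1.0) for _ in range(fan_in)] for _ in range(fan_out)]
    # Plain matmul without any scale factor, for comparison.
    unscaled = [sum(wi * xi for wi, xi in zip(row, x)) for row in w]
    return rms(unscaled), rms(unit_scaled_linear(x, w))


if __name__ == "__main__":
    for width in (256, 1024, 4096):
        raw, scaled = demo(width)
        print(f"fan_in={width:5d}  unscaled RMS={raw:7.2f}  unit-scaled RMS={scaled:.2f}")
```

Running it should show the unscaled RMS growing roughly like sqrt(fan_in) while the unit-scaled RMS stays close to 1 for every width. Keeping values near 1 is what lets them fit inside the narrow dynamic range of low-precision formats such as FP8.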
If you want to learn more details about u-μP, check out our blog post and our paper.
Unit-Scaled Maximal Update Parametrization (u-μP) is available in Scaling, our official large-scale training codebase. Please note that FP8-trained checkpoints only run on chips with FP8 support, such as NVIDIA's Hopper architecture.
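Before loading an FP8 checkpoint, it can help to check whether the local GPU has FP8 tensor cores. The sketch below assumes PyTorch is installed and assumes that FP8 support corresponds to CUDA compute capability 8.9 (Ada Lovelace) or newer, which includes Hopper (9.0). `supports_fp8` is a hypothetical helper, and this check cannot confirm that Scaling's FP8 path runs on every such GPU.

```python
def supports_fp8(compute_capability):
    """Return True if a CUDA compute capability (major, minor) is assumed to have FP8 tensor cores.

    Assumption: FP8 tensor cores first appear at compute capability 8.9 (Ada Lovelace);
    Hopper is 9.0, and newer architectures are assumed to keep FP8 support.
    """
    major, minor = compute_capability
    return (major, minor) >= (8, 9)


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed; cannot query the GPU.")
    else:
        if torch.cuda.is_available():
            # get_device_capability() returns a (major, minor) tuple for the current device.
            capability = torch.cuda.get_device_capability()
            print(f"Compute capability {capability}: FP8 supported = {supports_fp8(capability)}")
        else:
            print("No CUDA device is visible; FP8 checkpoints cannot run here.")
```

If the check fails, the BF16 checkpoints listed above are the ones to use on that machine.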
Usage
You can generate tokens with the Scaling inference implementation:
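```python
from pathlib import Path

from scaling.transformer.inference import TransformerInferenceModule

# Point this at a downloaded checkpoint directory; FP8 checkpoints need FP8-capable hardware.
ckpt_path = Path("<path_to_repo>/7B_umup_fp8")
model = TransformerInferenceModule.from_checkpoint(ckpt_path)

prompt = "Yesterday I dreamt of"
# max_tokens caps the length of the generated completion.
output = model.generate(max_tokens=100, input_text=prompt)
print(output.completion_text)
```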