Spaces:
Running
Running
#!/usr/bin/env python | |
# coding=utf-8 | |
# Copyright 2021-2022 The HuggingFace & DALL·E Mini team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" | |
Training DALL·E Mini. | |
Script adapted from run_summarization_flax.py | |
""" | |
import io | |
import logging | |
import os | |
import sys | |
import tempfile | |
import time | |
from dataclasses import asdict, dataclass, field | |
from pathlib import Path | |
from typing import Any, Callable, NamedTuple, Optional | |
import datasets | |
import flax | |
import jax | |
import jax.numpy as jnp | |
import jaxlib | |
import numpy as np | |
import optax | |
import transformers | |
import wandb | |
from datasets import Dataset | |
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze | |
from flax.serialization import from_bytes, to_bytes | |
from flax.training import train_state | |
from flax.training.common_utils import onehot | |
from google.cloud import storage | |
from jax.experimental import PartitionSpec, maps | |
from jax.experimental.compilation_cache import compilation_cache as cc | |
from jax.experimental.pjit import pjit, with_sharding_constraint | |
from scalable_shampoo.distributed_shampoo import GraftingType, distributed_shampoo | |
from tqdm import tqdm | |
from transformers import HfArgumentParser | |
from dalle_mini.data import Dataset | |
from dalle_mini.model import ( | |
DalleBart, | |
DalleBartConfig, | |
DalleBartTokenizer, | |
set_partitions, | |
) | |
logger = logging.getLogger(name=__name__)

# Enable JAX's on-disk compilation cache in ./jax_cache, capped at 10 GiB.
# NOTE(review): presumably lets reruns reuse previously compiled programs —
# confirm against the pinned jax version (this API lives in jax.experimental).
cc.initialize_cache("./jax_cache", max_cache_size_bytes=10 * 1024**3)
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.

    Must be a dataclass: it is consumed by ``HfArgumentParser`` and relies on
    ``__post_init__`` for defaulting and validation.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model checkpoint for weights initialization. "
            "Don't set if you want to train a model from scratch. "
            "W&B artifact references are supported in addition to the sources supported by `PreTrainedModel`."
        },
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as model_name_or_path"
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name_or_path"
        },
    )
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": "Floating-point format in which the computations will be performed (not the model weights). Choose one of `[float32, float16, bfloat16]`."
        },
    )
    # Either a bool or a location string; get_opt_state() overwrites it with the
    # resolved location of the optimizer state.
    restore_state: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Restore optimizer and training state. Can be True (will retrieve associated wandb artifact), a local directory or a Google bucket path."
        },
    )

    def __post_init__(self):
        """Default the tokenizer to the model and validate state-restoration options.

        Raises:
            AssertionError: if neither ``tokenizer_name`` nor ``model_name_or_path``
                is given, or if ``restore_state`` is set while
                ``model_name_or_path`` is not a W&B ``/model-`` artifact reference.
        """
        if self.tokenizer_name is None:
            self.tokenizer_name = self.model_name_or_path
        assert (
            self.tokenizer_name is not None
        ), "Tokenizer name or model name/path needs to be specified"
        if self.restore_state:
            assert self.model_name_or_path is not None and (
                "/model-" in self.model_name_or_path
            ), "Restoring state only available with W&B artifact reference"

    def get_metadata(self):
        """Return the metadata of the W&B model artifact, or ``{}`` when not restoring.

        Process 0 uses the active ``wandb.run``; other processes query
        ``wandb.Api()`` directly (only process 0 calls ``wandb.init`` in ``main``).
        """
        if self.restore_state:
            if jax.process_index() == 0:
                artifact = wandb.run.use_artifact(self.model_name_or_path)
            else:
                artifact = wandb.Api().artifact(self.model_name_or_path)
            return artifact.metadata
        else:
            return dict()

    def get_opt_state(self):
        """Return the serialized optimizer state as raw bytes.

        The source is selected by ``restore_state``:

        * ``True``: the W&B artifact whose name is ``model_name_or_path`` with
          ``/model-`` replaced by ``/state-``. If its metadata has a
          ``bucket_path``, that bucket location is read. Otherwise the artifact
          is downloaded into a temporary directory.
        * a ``gs://`` path: ``opt_state.msgpack`` is read under that bucket prefix.
        * a local path: either a directory containing ``opt_state.msgpack`` or
          the file itself.

        Side effect: ``restore_state`` is overwritten with the resolved
        location. After a W&B download that location is inside a temporary
        directory that is deleted when this method returns.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:  # avoid multiple artifact copies
            if self.restore_state is True:
                # wandb artifact: the state artifact mirrors the model artifact name
                state_artifact = self.model_name_or_path.replace(
                    "/model-", "/state-", 1
                )
                if jax.process_index() == 0:
                    artifact = wandb.run.use_artifact(state_artifact)
                else:
                    artifact = wandb.Api().artifact(state_artifact)
                if artifact.metadata.get("bucket_path"):
                    # we will read directly file contents
                    self.restore_state = artifact.metadata["bucket_path"]
                else:
                    artifact_dir = artifact.download(tmp_dir)
                    self.restore_state = str(Path(artifact_dir) / "opt_state.msgpack")

            if self.restore_state.startswith("gs://"):
                bucket_path = Path(self.restore_state[5:]) / "opt_state.msgpack"
                bucket_name, blob_name = str(bucket_path).split("/", 1)
                client = storage.Client()
                bucket = client.bucket(bucket_name)
                blob = bucket.blob(blob_name)
                return blob.download_as_bytes()

            state_path = Path(self.restore_state)
            if state_path.is_dir():
                # A local checkpoint directory was given (documented form of
                # `restore_state`): read the optimizer state file inside it
                # instead of trying to open the directory itself.
                state_path = state_path / "opt_state.msgpack"
            # Read before leaving the context: a downloaded artifact lives in tmp_dir.
            with state_path.open("rb") as f:
                return f.read()
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    ``main`` forwards every field as a keyword argument to the dataset loader
    (via ``asdict``), so this must be a dataclass.
    """

    text_column: Optional[str] = field(
        default="caption",
        metadata={
            "help": "The name of the column in the datasets containing the full texts (for summarization)."
        },
    )
    encoding_column: Optional[str] = field(
        default="encoding",
        metadata={
            "help": "The name of the column in the datasets containing the image encodings."
        },
    )
    # Required in practice: __post_init__ rejects None.
    dataset_repo_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "The dataset repository containing encoded files."},
    )
    train_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "The input training data file (glob & braceexpand acceptable)."
        },
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input evaluation data file (glob & braceexpand acceptable)."
        },
    )
    # data loading should not be a bottleneck so we use "streaming" mode by default
    streaming: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to stream the dataset."},
    )
    use_auth_token: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to use the authentication token for private datasets."
        },
    )
    shard_by_host: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to shard data files by host in multi-host environments."
        },
    )
    blank_caption_prob: Optional[float] = field(
        default=0.0,
        metadata={
            "help": "Probability of removing some captions for classifier-free guidance."
        },
    )
    clip_score_column: Optional[str] = field(
        default="clip_score",
        metadata={"help": "Column that containts clip score for filtering."},
    )
    min_clip_score: Optional[float] = field(
        default=None,
        metadata={"help": "Minimum clip score required."},
    )
    max_clip_score: Optional[float] = field(
        default=None,
        metadata={"help": "Maximum clip score required."},
    )
    filter_column: Optional[str] = field(
        default=None,
        metadata={"help": "Column that containts classes to be filtered."},
    )
    filter_value: Optional[str] = field(
        default=None,
        metadata={"help": "Class value to be kept during filtering."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples."
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={
            "help": "The number of processes to use for the preprocessing. Not used in streaming mode."
        },
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={
            "help": "Overwrite the cached training and evaluation sets. Not used in streaming mode."
        },
    )
    # default seed of None ensures we don't repeat the same items if script was interrupted during an epoch
    seed_dataset: Optional[int] = field(
        default=None,
        metadata={
            "help": "Random seed for the dataset that will be set at the beginning of training."
        },
    )

    def __post_init__(self):
        """Validate that a dataset source was provided.

        Raises:
            ValueError: if ``dataset_repo_or_path`` is not set.
        """
        if self.dataset_repo_or_path is None:
            raise ValueError("Need a dataset repository or path.")
@dataclass
class TrainingArguments:
    """
    Arguments pertaining to training parameters.

    Must be a dataclass: it is parsed by ``HfArgumentParser``, and
    ``dp_devices`` is an ``init=False`` field computed in ``__post_init__``.
    """

    output_dir: str = field(
        metadata={
            "help": "The output directory where the model predictions and checkpoints will be written."
        },
    )
    overwrite_output_dir: bool = field(
        default=False,
        metadata={
            "help": (
                "Overwrite the content of the output directory. "
                "Use this to continue training if output_dir points to a checkpoint directory."
            )
        },
    )
    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(
        default=False, metadata={"help": "Whether to run eval on the validation set."}
    )
    per_device_train_batch_size: int = field(
        default=8,
        metadata={"help": "Batch size per data parallel device for training."},
    )
    per_device_eval_batch_size: Optional[int] = field(
        default=None,
        metadata={
            "help": "Batch size per data parallel device for evaluation. Same as training batch size if not set."
        },
    )
    gradient_accumulation_steps: int = field(
        default=1,
        metadata={
            "help": "Number of updates steps to accumulate before performing an update pass."
        },
    )
    gradient_checkpointing: bool = field(
        default=False, metadata={"help": "Use gradient checkpointing."}
    )
    learning_rate: float = field(
        default=5e-5, metadata={"help": "The initial learning rate."}
    )
    optim: str = field(
        default="distributed_shampoo",
        metadata={
            "help": 'The optimizer to use. Can be "distributed_shampoo" (default), "adam" or "adafactor"'
        },
    )
    beta1: float = field(
        default=0.9,
        metadata={"help": "Beta1 for Adam & Distributed Shampoo."},
    )
    beta2: float = field(
        default=0.999,
        metadata={"help": "Beta2 for for Adam & Distributed Shampoo."},
    )
    adam_epsilon: float = field(
        default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}
    )
    max_grad_norm: float = field(
        default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
    )
    block_size: int = field(
        default=1024,
        metadata={"help": "Chunked size for large layers with Distributed Shampoo."},
    )
    preconditioning_compute_steps: int = field(
        default=10, metadata={"help": "Number of steps to update preconditioner."}
    )
    skip_preconditioning_dim_size_gt: int = field(
        default=4096,
        metadata={"help": "Max size for preconditioning with Distributed Shampoo."},
    )
    graft_type: str = field(
        default="rmsprop_normalized",
        metadata={
            "help": "The type of grafting to use. Can be 'rmsprop_normalized' (default), 'rmsprop', 'adagrad', 'adagrad_normalized', 'sgd' or 'sqrt_n'"
        },
    )
    optim_quantized: bool = field(
        default=False,
        metadata={
            "help": "Whether to quantize optimizer (only supported with Distributed Shampoo)."
        },
    )
    num_train_epochs: int = field(
        default=3, metadata={"help": "Total number of training epochs to perform."}
    )
    warmup_steps: int = field(
        default=0, metadata={"help": "Linear warmup over warmup_steps."}
    )
    lr_decay: Optional[str] = field(
        default=None,
        metadata={
            "help": "Decay to be used in the learning rate scheduler. Can be None (default), linear or exponential."
        },
    )
    # Required (together with lr_decay_rate) when lr_decay == "exponential".
    lr_transition_steps: Optional[int] = field(
        default=None,
        metadata={
            "help": "Number of transition steps associated with learning rate decay when using exponential decay."
        },
    )
    lr_decay_rate: Optional[float] = field(
        default=None,
        metadata={
            "help": "Decay rate associated with learning rate when using exponential decay."
        },
    )
    lr_staircase: bool = field(
        default=False,
        metadata={
            "help": "Whether to use staircase or continuous learning rate when using exponential decay."
        },
    )
    logging_steps: int = field(
        default=40, metadata={"help": "Log every X updates steps."}
    )
    eval_steps: int = field(
        default=400, metadata={"help": "Run an evaluation every X steps."}
    )
    save_steps: int = field(
        default=4000, metadata={"help": "Save checkpoint every X updates steps."}
    )
    log_model: bool = field(
        default=False,
        metadata={"help": "Log model to wandb at `save_steps` frequency."},
    )
    # True means "same frequency as logging_steps" (resolved in __post_init__).
    log_norm_steps: int = field(
        default=True,
        metadata={"help": "Log parameters and gradients norm at this frequency."},
    )
    # A falsy value (False / 0) disables histogram logging.
    log_histogram_steps: int = field(
        default=False,
        metadata={
            "help": "Log parameters and gradients histograms at this frequency. Slows down training."
        },
    )
    seed_model: int = field(
        default=42,
        metadata={
            "help": "Random seed for the model that will be set at the beginning of training."
        },
    )
    wandb_entity: Optional[str] = field(
        default=None,
        metadata={"help": "The wandb entity to use (for teams)."},
    )
    wandb_project: str = field(
        default="dalle-mini",
        metadata={"help": "The name of the wandb project."},
    )
    wandb_job_type: str = field(
        default="Seq2Seq",
        metadata={"help": "The name of the wandb job type."},
    )
    assert_TPU_available: bool = field(
        default=False,
        metadata={"help": "Verify that TPU is not in use."},
    )
    mp_devices: Optional[int] = field(
        default=1,
        metadata={
            "help": "Number of devices required for model parallelism. The other dimension of available devices is used for data parallelism."
        },
    )
    # Derived in __post_init__: jax.device_count() // mp_devices.
    dp_devices: int = field(init=False)

    def __post_init__(self):
        """Validate options, fill derived defaults and compute ``dp_devices``.

        Raises:
            AssertionError: on an unsupported optimizer, graft type or LR decay,
                on exponential decay without its parameters, or on an invalid
                ``mp_devices``. It is also raised when ``assert_TPU_available``
                is set and fewer or more than 8 local devices are visible.
            ValueError: if ``output_dir`` is non-empty, ``do_train`` is set and
                ``overwrite_output_dir`` is not.
        """
        if self.assert_TPU_available:
            assert (
                jax.local_device_count() == 8
            ), "TPUs in use, please check running processes"
        assert self.optim in [
            "distributed_shampoo",
            "adam",
            "adafactor",
        ], f"Selected optimizer not supported: {self.optim}"
        assert self.graft_type in [
            "rmsprop_normalized",
            "rmsprop",
            "adagrad",
            "adagrad_normalized",
            "sgd",
            "sqrt_n",
        ], f"Selected graft type not supported: {self.graft_type}"
        assert self.lr_decay in [
            None,
            "linear",
            "exponential",
        ], f"Selected learning rate decay not supported: {self.lr_decay}"
        if self.lr_decay == "exponential":
            # Fail fast here rather than deep inside schedule construction.
            assert (
                self.lr_transition_steps is not None
                and self.lr_decay_rate is not None
            ), "Exponential learning rate decay requires lr_transition_steps and lr_decay_rate"
        if self.per_device_eval_batch_size is None:
            self.per_device_eval_batch_size = self.per_device_train_batch_size
        if self.log_norm_steps is True:
            self.log_norm_steps = self.logging_steps
        if (
            os.path.exists(self.output_dir)
            and os.listdir(self.output_dir)
            and self.do_train
            and not self.overwrite_output_dir
        ):
            raise ValueError(
                f"Output directory ({self.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        assert (
            self.mp_devices > 0
        ), "Number of devices for model parallelism must be > 0"
        assert (
            jax.device_count() % self.mp_devices == 0
        ), f"Number of available devices ({jax.device_count()}) must be divisible by number of devices used for model parallelism ({self.mp_devices})."
        self.dp_devices = jax.device_count() // self.mp_devices
class TrainState(train_state.TrainState):
    """Flax ``TrainState`` extended with the dropout RNG and training-progress counters.

    No decorator is applied here. The class relies on ``train_state.TrainState``
    turning subclasses into flax struct dataclasses automatically.
    NOTE(review): confirm this against the installed flax version.
    Field order matters: it defines the dataclass / pytree field order.
    """

    # PRNG key for dropout; the train step passes back an updated key via apply_gradients
    dropout_rng: jnp.ndarray = None
    epoch: int = 0
    train_time: float = 0.0  # total time the model trained
    train_samples: int = 0  # number of samples seen
def main(): | |
# See all possible arguments by passing the --help flag to this script. | |
parser = HfArgumentParser( | |
(ModelArguments, DataTrainingArguments, TrainingArguments) | |
) | |
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): | |
# If we pass only one argument to the script and it's the path to a json file, | |
# let's parse it to get our arguments. | |
model_args, data_args, training_args = parser.parse_json_file( | |
json_file=os.path.abspath(sys.argv[1]) | |
) | |
else: | |
model_args, data_args, training_args = parser.parse_args_into_dataclasses() | |
# Make one log on every process with the configuration for debugging. | |
logging.basicConfig( | |
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", | |
datefmt="%m/%d/%Y %H:%M:%S", | |
level=logging.INFO, | |
) | |
# Setup logging, we only want one process per machine to log things on the screen. | |
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) | |
if jax.process_index() == 0: | |
datasets.utils.logging.set_verbosity_warning() | |
transformers.utils.logging.set_verbosity_info() | |
else: | |
datasets.utils.logging.set_verbosity_error() | |
transformers.utils.logging.set_verbosity_error() | |
# Set the verbosity to info of the Transformers logger (on main process only): | |
logger.info(f"Training/evaluation parameters {training_args}") | |
# Load dataset | |
dataset = Dataset( | |
**asdict(data_args), | |
do_train=training_args.do_train, | |
do_eval=training_args.do_eval, | |
) | |
logger.info(f"Local TPUs: {jax.local_device_count()}") | |
logger.info(f"Global TPUs: {jax.device_count()}") | |
# Set up wandb run | |
if jax.process_index() == 0: | |
wandb.init( | |
entity=training_args.wandb_entity, | |
project=training_args.wandb_project, | |
job_type=training_args.wandb_job_type, | |
config=parser.parse_args(), | |
) | |
# Set up our new model config | |
if model_args.config_name: | |
config = DalleBartConfig.from_pretrained(model_args.config_name) | |
config.gradient_checkpointing = training_args.gradient_checkpointing | |
else: | |
config = None | |
# Load or create new model | |
if model_args.model_name_or_path: | |
model = DalleBart.from_pretrained( | |
model_args.model_name_or_path, | |
config=config, | |
seed=training_args.seed_model, | |
dtype=getattr(jnp, model_args.dtype), | |
abstract_init=True, # we overwrite them with loaded checkpoint | |
gradient_checkpointing=training_args.gradient_checkpointing, | |
) | |
else: | |
model = DalleBart( | |
config, | |
seed=training_args.seed_model, | |
dtype=getattr(jnp, model_args.dtype), | |
abstract_init=True, | |
) | |
# get model metadata | |
model_metadata = model_args.get_metadata() | |
# get PartitionSpec for model params (required to be a dict) | |
param_spec = set_partitions(model.params) | |
# convert params to frozen dict | |
model._params = freeze(model.params) | |
# Load tokenizer | |
tokenizer = DalleBartTokenizer.from_pretrained( | |
model_args.tokenizer_name, use_fast=True | |
) | |
# Preprocessing the datasets. | |
# We need to normalize and tokenize inputs and targets. | |
dataset.preprocess(tokenizer=tokenizer, config=model.config) | |
# Initialize our training | |
dropout_rng = jax.random.PRNGKey(training_args.seed_model) | |
# Store some constant | |
num_epochs = training_args.num_train_epochs | |
# batch size | |
batch_size_per_node_per_grad_step = ( | |
training_args.per_device_train_batch_size | |
* jax.local_device_count() | |
// training_args.mp_devices | |
) | |
batch_size_per_node = ( | |
batch_size_per_node_per_grad_step * training_args.gradient_accumulation_steps | |
) | |
batch_size_per_step = batch_size_per_node * jax.process_count() | |
eval_batch_size_per_node = ( | |
training_args.per_device_eval_batch_size | |
* jax.local_device_count() | |
// training_args.mp_devices | |
) | |
eval_batch_size_per_step = eval_batch_size_per_node * jax.process_count() | |
len_train_dataset, len_eval_dataset = dataset.length | |
steps_per_epoch = ( | |
len_train_dataset // batch_size_per_node | |
if len_train_dataset is not None | |
else None | |
) | |
num_train_steps = ( | |
steps_per_epoch * num_epochs if steps_per_epoch is not None else None | |
) | |
num_params = model.num_params | |
logger.info("***** Running training *****") | |
logger.info(f" Num examples = {len_train_dataset}") | |
logger.info(f" Num Epochs = {num_epochs}") | |
logger.info( | |
f" Batch size per dp device = {training_args.per_device_train_batch_size}" | |
) | |
logger.info(f" Number of devices = {jax.device_count()}") | |
logger.info( | |
f" Gradient accumulation steps = {training_args.gradient_accumulation_steps}" | |
) | |
logger.info(f" Batch size per update = {batch_size_per_step}") | |
logger.info(f" Model parameters = {num_params:,}") | |
# set up wandb run | |
if jax.process_index() == 0: | |
# set default x-axis as 'train/step' | |
wandb.define_metric("*", step_metric="train/step") | |
# add interesting config parameters | |
wandb.config.update( | |
{ | |
"len_train_dataset": len_train_dataset, | |
"len_eval_dataset": len_eval_dataset, | |
"batch_size_per_step": batch_size_per_step, | |
"model": {"num_params": num_params, "config": model.config.to_dict()}, | |
"num_devices": jax.device_count(), | |
"versions": { | |
"jax": jax.__version__, | |
"jaxlib": jaxlib.__version__, | |
"flax": flax.__version__, | |
"transformers": transformers.__version__, | |
"datasets": datasets.__version__, | |
"wandb": wandb.__version__, | |
}, | |
} | |
) | |
# Create learning rate schedule | |
def create_learning_rate_fn() -> Callable[[int], jnp.array]: | |
"""Create the learning rate function.""" | |
warmup_fn = optax.linear_schedule( | |
init_value=0.0, | |
end_value=training_args.learning_rate, | |
transition_steps=training_args.warmup_steps + 1, # ensure not 0 | |
) | |
# offset step when resuming | |
if model_metadata.get("step", 0): | |
warmup_fn = optax.join_schedules( | |
schedules=[optax.constant_schedule(0.0), warmup_fn], | |
boundaries=[model_metadata["step"]], | |
) | |
if training_args.lr_decay is None: | |
return warmup_fn | |
elif training_args.lr_decay == "linear": | |
assert ( | |
num_train_steps is not None | |
), "linear decay requires knowing the dataset length" | |
decay_fn = optax.linear_schedule( | |
init_value=training_args.learning_rate, | |
end_value=0, | |
transition_steps=num_train_steps - training_args.warmup_steps, | |
) | |
elif training_args.lr_decay == "exponential": | |
decay_fn = optax.exponential_decay( | |
init_value=training_args.learning_rate, | |
transition_steps=training_args.lr_transition_steps, | |
decay_rate=training_args.lr_decay_rate, | |
staircase=training_args.lr_staircase, | |
) | |
schedule_fn = optax.join_schedules( | |
schedules=[warmup_fn, decay_fn], | |
boundaries=[model_metadata.get("step", 0) + training_args.warmup_steps], | |
) | |
return schedule_fn | |
learning_rate_fn = create_learning_rate_fn() | |
# create adam optimizer | |
if training_args.optim == "distributed_shampoo": | |
# parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729 | |
graft_type = { | |
"sgd": GraftingType.SGD, | |
"adagrad": GraftingType.ADAGRAD, | |
"rmsprop": GraftingType.RMSPROP, | |
"rmsprop_normalized": GraftingType.RMSPROP_NORMALIZED, | |
"sqrt_n": GraftingType.SQRT_N, | |
"adagrad_normalized": GraftingType.ADAGRAD_NORMALIZED, | |
}[training_args.graft_type] | |
optimizer = distributed_shampoo( | |
learning_rate_fn, | |
block_size=training_args.block_size, | |
beta1=training_args.beta1, | |
beta2=training_args.beta2, | |
diagonal_epsilon=1e-10, | |
matrix_epsilon=1e-6, | |
start_preconditioning_step=max( | |
training_args.preconditioning_compute_steps + 1, 101 | |
), | |
preconditioning_compute_steps=training_args.preconditioning_compute_steps, | |
statistics_compute_steps=1, | |
best_effort_shape_interpretation=True, | |
graft_type=graft_type, | |
nesterov=False, | |
exponent_override=0, | |
statistics_partition_spec=PartitionSpec(None, "dp", None), | |
preconditioner_partition_spec=PartitionSpec("dp", None, None), | |
num_devices_for_pjit=training_args.dp_devices, | |
shard_optimizer_states=True, | |
inverse_failure_threshold=0.1, | |
moving_average_for_momentum=True, | |
skip_preconditioning_dim_size_gt=training_args.skip_preconditioning_dim_size_gt, | |
clip_by_scaled_gradient_norm=None, | |
precision=jax.lax.Precision.HIGHEST, | |
best_effort_memory_usage_reduction=training_args.optim_quantized, | |
) | |
# get the real optimizer and helper functions | |
update_fn = optimizer.update | |
optimizer = optimizer.init(model.params) | |
opt_fn = NamedTuple("opt_fn", pspec_fn=Any, shape_and_dtype_fn=Any)( | |
optimizer.pspec_fn, optimizer.shape_and_dtype_fn | |
) | |
optimizer = optax.GradientTransformation(optimizer.init_fn, update_fn) | |
elif training_args.optim == "adam": | |
optimizer = optax.adamw( | |
learning_rate=learning_rate_fn, | |
b1=training_args.beta1, | |
b2=training_args.beta2, | |
eps=training_args.adam_epsilon, | |
) | |
elif training_args.optim == "adafactor": | |
# We use the default parameters here to initialize adafactor, | |
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74 | |
optimizer = optax.adafactor( | |
learning_rate=learning_rate_fn, | |
clipping_threshold=training_args.max_grad_norm, | |
) | |
# get PartitionSpec for optimizer state | |
def get_opt_state_spec_and_shape(param_spec): | |
# get opt_state shape without actual init | |
opt_state_shape = jax.eval_shape(optimizer.init, model.params) | |
if training_args.optim == "adam": | |
def _opt_state_spec_per_leaf(x): | |
if isinstance(x, FrozenDict): | |
# variables with same structure as params | |
return param_spec | |
else: | |
# other variables such as count | |
return None | |
opt_state_spec = jax.tree_map( | |
_opt_state_spec_per_leaf, | |
opt_state_shape, | |
# return None spec for empty elements | |
is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)), | |
) | |
elif training_args.optim == "adafactor": | |
# factorized state must be replicated (rank different than params) | |
opt_state_spec = None | |
elif training_args.optim == "distributed_shampoo": | |
opt_state_spec = opt_fn.pspec_fn( | |
params=model.params, | |
params_partition_spec=param_spec, | |
partition_spec_for_statistics=PartitionSpec(None, "dp", None), | |
) | |
else: | |
raise NotImplementedError | |
return opt_state_spec, opt_state_shape | |
opt_state_spec, opt_state_shape = get_opt_state_spec_and_shape(param_spec) | |
# create a mesh | |
mesh_shape = (training_args.dp_devices, training_args.mp_devices) | |
devices = np.asarray(jax.devices()).reshape(*mesh_shape) | |
mesh = maps.Mesh(devices, ("dp", "mp")) | |
logger.info(f" Mesh shape: {mesh_shape}") | |
# define state spec | |
state_spec = TrainState( | |
params=param_spec, | |
opt_state=opt_state_spec, | |
dropout_rng=None, | |
step=None, | |
epoch=None, | |
train_time=None, | |
train_samples=None, | |
apply_fn=model.__call__, | |
tx=optimizer, | |
) | |
# init params if not available yet | |
def maybe_init_params(params): | |
if model_args.model_name_or_path: | |
# model params are correctly loaded | |
return params | |
else: | |
# params have not been initialized yet | |
return model.init_weights() | |
with mesh: | |
logger.info(" Creating state") | |
if not model_args.restore_state: | |
def init_state(params): | |
return TrainState.create( | |
apply_fn=model.__call__, | |
tx=optimizer, | |
params=maybe_init_params(params), | |
dropout_rng=dropout_rng, | |
) | |
state = pjit( | |
init_state, | |
in_axis_resources=(param_spec,) | |
if model_args.model_name_or_path | |
else None, | |
out_axis_resources=state_spec, | |
donate_argnums=(0,), | |
)(model.params if model_args.model_name_or_path else None) | |
else: | |
# load opt_state | |
opt_state = from_bytes(opt_state_shape, model_args.get_opt_state()) | |
# restore other attributes | |
attr_state = { | |
k: model_metadata[k] | |
for k in ["step", "epoch", "train_time", "train_samples"] | |
} | |
def restore_state(params, opt_state): | |
return TrainState( | |
apply_fn=model.__call__, | |
tx=optimizer, | |
params=params, | |
opt_state=opt_state, | |
dropout_rng=dropout_rng, | |
**attr_state, | |
) | |
state = pjit( | |
restore_state, | |
in_axis_resources=( | |
param_spec, | |
opt_state_spec, | |
), | |
out_axis_resources=state_spec, | |
donate_argnums=(0, 1), | |
)(model.params, opt_state) | |
# remove opt_state from CPU | |
del opt_state | |
# free CPU memory | |
del model._params, opt_state_spec, opt_state_shape | |
# define batch specs | |
batch_spec = PartitionSpec("dp") | |
grad_batch_spec = PartitionSpec(None, "dp") | |
# define loss | |
def loss_fn(logits, labels): | |
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) | |
loss = loss.mean() | |
return loss | |
# "vmap trick" avoids a crash when mp_devices > 1 (not sure why it happens) | |
# lead to better perf: see https://wandb.ai/dalle-mini/dalle-mini/reports/JAX-pmap-vs-pjit--VmlldzoxNDg1ODA2 | |
use_vmap_trick = True | |
# make grad_param_spec for vmap | |
if use_vmap_trick: | |
grad_param_spec = jax.tree_map( | |
lambda x: PartitionSpec(*("dp",) + (x if x is not None else (None,))), | |
param_spec, | |
) | |
# Define gradient update step fn | |
def train_step(state, batch, train_time): | |
# get a minibatch (one gradient accumulation slice) | |
def get_minibatch(batch, grad_idx): | |
return jax.tree_map( | |
lambda x: jax.lax.dynamic_index_in_dim(x, grad_idx, keepdims=False), | |
batch, | |
) | |
def compute_loss(params, minibatch, dropout_rng): | |
# minibatch has dim (batch_size, ...) | |
minibatch, labels = minibatch.pop("labels") | |
logits = state.apply_fn( | |
**minibatch, params=params, dropout_rng=dropout_rng, train=True | |
)[0] | |
return loss_fn(logits, labels) | |
grad_fn = jax.value_and_grad(compute_loss) | |
def loss_and_grad(grad_idx, dropout_rng): | |
# minibatch at grad_idx for gradient accumulation (None otherwise) | |
minibatch = ( | |
get_minibatch(batch, grad_idx) if grad_idx is not None else batch | |
) | |
# ensure it is sharded properly | |
minibatch = with_sharding_constraint(minibatch, batch_spec) | |
# only 1 single rng per grad step, let us handle larger batch size (not sure why) | |
dropout_rng, _ = jax.random.split(dropout_rng) | |
if use_vmap_trick: | |
# "vmap trick", calculate loss and grads independently per dp_device | |
loss, grads = jax.vmap( | |
grad_fn, in_axes=(None, 0, None), out_axes=(0, 0) | |
)(state.params, minibatch, dropout_rng) | |
# ensure they are sharded correctly | |
loss = with_sharding_constraint(loss, batch_spec) | |
grads = with_sharding_constraint(grads, grad_param_spec) | |
# average across all devices | |
# Note: we could average per device only after gradient accumulation, right before params update | |
loss, grads = jax.tree_map(lambda x: jnp.mean(x, axis=0), (loss, grads)) | |
else: | |
# "vmap trick" does not work in multi-hosts and requires too much hbm | |
loss, grads = grad_fn(state.params, minibatch, dropout_rng) | |
# ensure grads are sharded | |
grads = with_sharding_constraint(grads, param_spec) | |
# return loss and grads | |
return loss, grads, dropout_rng | |
# Loss and gradients for the full batch: a single call, or a sum over
# ``gradient_accumulation_steps`` minibatches followed by division (mean).
if training_args.gradient_accumulation_steps == 1:
    loss, grads, dropout_rng = loss_and_grad(None, state.dropout_rng)
else:
    # create initial state for cumul_minibatch_step loop
    # carry = (summed loss, summed grads as zeros shaped like params, rng)
    init_minibatch_step = (
        0.0,
        with_sharding_constraint(
            jax.tree_map(jnp.zeros_like, state.params), param_spec
        ),
        state.dropout_rng,
    )

    # accumulate gradients
    def cumul_minibatch_step(grad_idx, cumul_loss_grad_dropout):
        """fori_loop body: add loss/grads of minibatch ``grad_idx`` to the carry."""
        cumul_loss, cumul_grads, dropout_rng = cumul_loss_grad_dropout
        loss, grads, dropout_rng = loss_and_grad(grad_idx, dropout_rng)
        cumul_loss, cumul_grads = jax.tree_map(
            jnp.add, (cumul_loss, cumul_grads), (loss, grads)
        )
        cumul_grads = with_sharding_constraint(cumul_grads, param_spec)
        return cumul_loss, cumul_grads, dropout_rng

    # loop over gradients
    loss, grads, dropout_rng = jax.lax.fori_loop(
        0,
        training_args.gradient_accumulation_steps,
        cumul_minibatch_step,
        init_minibatch_step,
    )
    grads = with_sharding_constraint(grads, param_spec)
    # sum -> mean
    loss, grads = jax.tree_map(
        lambda x: x / training_args.gradient_accumulation_steps, (loss, grads)
    )

# ensure grads are sharded before the parameter update
grads = with_sharding_constraint(grads, param_spec)
# update state
# NOTE: ``state`` is rebound here; ``maybe_fn`` (defined below) reads
# ``state.step`` when called, so it sees the post-update state.
state = state.apply_gradients(
    grads=grads,
    dropout_rng=dropout_rng,
    train_time=train_time,
    train_samples=state.train_samples + batch_size_per_step,
)

metrics = {
    "loss": loss,
    # evaluated with ``state.step`` of the state returned by apply_gradients
    "learning_rate": learning_rate_fn(state.step),
}
def maybe_fn(fn, val, zeros, freq):
    """Return ``fn(val)`` when ``state.step`` is a multiple of ``freq``, else ``zeros``.

    Implemented with ``jax.lax.cond`` so it stays traceable; ``zeros`` must
    have the same structure as ``fn(val)``.
    """

    def skip(_):
        return zeros

    is_log_step = state.step % freq == 0
    return jax.lax.cond(is_log_step, fn, skip, val)
# Optional per-leaf norms of gradients and params, only on logging steps
if training_args.log_norm_steps:
    # value returned on non-logging steps (same tree structure as the norms)
    zeros_norm = jax.tree_map(lambda _: jnp.float32(0), state.params)

    def norm(val):
        """Per-leaf ``jnp.linalg.norm`` (default ord) of a pytree."""
        return jax.tree_map(lambda x: jnp.linalg.norm(x), val)

    gradients_norm = maybe_fn(
        norm, grads, zeros_norm, training_args.log_norm_steps
    )
    params_norm = maybe_fn(
        norm, state.params, zeros_norm, training_args.log_norm_steps
    )

    metrics.update(
        {
            "gradients_norm": gradients_norm,
            "params_norm": params_norm,
        }
    )

# Optional per-leaf histograms of gradients and params, only on logging steps
if training_args.log_histogram_steps:
    # value returned on non-logging steps: a jnp.histogram result per leaf
    zeros_hist = jax.tree_map(
        lambda _: jnp.histogram(jnp.zeros(1), density=True), state.params
    )

    def histogram(val):
        """Per-leaf ``jnp.histogram(..., density=True)`` of a pytree."""
        return jax.tree_map(lambda x: jnp.histogram(x, density=True), val)

    gradients_hist = maybe_fn(
        histogram, grads, zeros_hist, training_args.log_histogram_steps
    )
    params_hist = maybe_fn(
        histogram, state.params, zeros_hist, training_args.log_histogram_steps
    )

    metrics.update(
        {
            "params_hist": params_hist,
            "gradients_hist": gradients_hist,
        }
    )

return state, metrics
# Define eval fn
def eval_step(state, batch):
    """Return the evaluation loss of ``batch`` under ``state.params``.

    With the "vmap trick" the loss is computed per slice of the leading axis
    and then averaged; otherwise the whole batch is evaluated at once.
    """

    def compute_eval_loss(eval_batch):
        # pop returns (remaining inputs, popped labels)
        model_inputs, labels = eval_batch.pop("labels")
        logits = model(**model_inputs, params=state.params, train=False)[0]
        return loss_fn(logits, labels)

    if not use_vmap_trick:
        return compute_eval_loss(batch)

    per_slice_loss = jax.vmap(compute_eval_loss)(batch)
    # ensure they are sharded correctly
    per_slice_loss = with_sharding_constraint(per_slice_loss, batch_spec)
    # average across all devices
    return jnp.mean(per_slice_loss)
# Create parallel (pjit) versions of the train and eval steps.
# The train step donates its input state (argument 0) so buffers can be reused.
p_train_step = pjit(
    train_step,
    in_axis_resources=(
        state_spec,
        batch_spec
        if training_args.gradient_accumulation_steps <= 1
        else grad_batch_spec,
        None,  # train_time
    ),
    out_axis_resources=(state_spec, None),
    donate_argnums=(0,),
)
p_eval_step = pjit(
    eval_step,
    in_axis_resources=(state_spec, batch_spec),
    out_axis_resources=None,
)
# define metrics logger
class MetricsLogger:
    """Log metrics to W&B, tagged with the current train-state counters.

    ``state_dict`` holds ``train/step``, ``train/epoch``, ``train/time`` and
    ``train/samples`` and is attached to every ``wandb.log`` call (to be used
    as x-axes). The logger also measures wall-clock time per training step,
    subtracting time recorded through ``log_time`` (e.g. eval and save).
    Only process 0 writes to W&B.
    """

    def __init__(self, step):
        # keep state
        self.state_dict = {}
        # estimate speed
        self.step = step
        self.time = time.perf_counter()
        # seconds spent outside training (eval/save) since the last timing update
        self.offset_time = 0.0

    def update_state_metrics(self, state):
        """Update internal state metrics (logged at each call to be used as x-axis).

        Args:
            state: mapping with keys "step", "epoch", "train_time" and
                "train_samples" (indexed like a dict).
        """
        self.state_dict = {
            f'train/{k.split("_")[-1]}': state[k]
            for k in ["step", "epoch", "train_time", "train_samples"]
        }
        # timing metrics
        new_step = int(state["step"])
        new_time = time.perf_counter()
        if new_step > self.step:
            # remove time for eval & save
            delta_time = new_time - self.time - self.offset_time
            self.offset_time = 0
            time_per_step = delta_time / (new_step - self.step)
            self.step = new_step
            self.time = new_time
            self.log_time("train_per_step", time_per_step, offset=False)
            self.log_time("train_per_log", delta_time, offset=False)

    def log_time(self, key, duration, offset=True):
        """Log ``duration`` (seconds) under ``time/<key>``.

        Only process 0 writes to W&B, matching ``log``; this method is also
        reached on other processes (e.g. from evaluation), which must not call
        ``wandb.log``. Every process still accumulates ``offset_time`` when
        ``offset`` is true so its per-step timing excludes that duration.
        """
        if jax.process_index() == 0:
            wandb.log({f"time/{key}": duration, **self.state_dict})
        if offset:
            self.offset_time += duration

    def log(self, metrics, prefix=None):
        """Log ``metrics`` to W&B from process 0.

        ``*_norm`` and ``*_hist`` entries are pytrees and are only logged on
        their respective logging steps (histograms are converted to
        ``wandb.Histogram``); other entries are logged as-is, optionally
        under ``<prefix>/``.
        """
        if jax.process_index() == 0:
            log_metrics = {}
            for k, v in metrics.items():
                if "_norm" in k:
                    if self.step % training_args.log_norm_steps == 0:
                        log_metrics[f"{k}/"] = unfreeze(v)
                elif "_hist" in k:
                    if self.step % training_args.log_histogram_steps == 0:
                        v = jax.tree_map(lambda x: jax.device_get(x), unfreeze(v))
                        v = jax.tree_map(
                            lambda x: wandb.Histogram(np_histogram=x),
                            v,
                            # each leaf is a (hist, bin_edges) tuple
                            is_leaf=lambda x: isinstance(x, tuple),
                        )
                        log_metrics[f"{k}/"] = v
                else:
                    if prefix is not None:
                        k = f"{prefix}/{k}"
                    log_metrics[k] = v
            wandb.log({**log_metrics, **self.state_dict})
# keep local copy of state: host-side python scalars of the train counters
local_state = {
    name: jax.device_get(getattr(state, name)).item()
    for name in ("step", "epoch", "train_time", "train_samples")
}

# init variables
# shift the reference time back so train_time keeps counting when resuming
start_time = time.perf_counter() - local_state["train_time"]
train_metrics = None
metrics_logger = MetricsLogger(local_state["step"])
# epoch progress bar, shown on process 0 only
epochs = tqdm(
    range(local_state["epoch"], num_epochs),
    position=0,
    disable=jax.process_index() > 0,
    desc=f"Epoch ... (1/{num_epochs})",
)
def run_evaluation():
    """Evaluate on the "eval" split when ``training_args.do_eval`` is set.

    Each process keeps only its own slice of every global eval batch; the
    per-batch losses are averaged, logged under ``eval/`` and written to the
    epoch progress bar.

    Returns:
        ``{"loss": mean_eval_loss}``, or ``None`` when evaluation is disabled.
    """
    # ======================== Evaluating ==============================
    if not training_args.do_eval:
        return None

    start_eval_time = time.perf_counter()
    eval_loader = dataset.dataloader("eval", eval_batch_size_per_step)
    eval_steps = None
    if len_eval_dataset is not None:
        eval_steps = len_eval_dataset // eval_batch_size_per_step

    process_idx = jax.process_index()
    node_shape = (jax.process_count(), eval_batch_size_per_node)
    batch_losses = []
    progress = tqdm(
        eval_loader,
        desc="Evaluating...",
        position=2,
        leave=False,
        total=eval_steps,
        disable=process_idx > 0,
    )
    for batch in progress:
        # keep only the eval_batch_size_per_node items relevant to this node
        batch = jax.tree_map(
            lambda x: x.reshape(node_shape + x.shape[1:])[process_idx], batch
        )
        if use_vmap_trick:
            # add dp dimension when using "vmap trick"
            dp_shape = (
                jax.local_device_count() // training_args.mp_devices,
                training_args.per_device_eval_batch_size,
            )
            batch = jax.tree_map(lambda x: x.reshape(dp_shape + x.shape[1:]), batch)
        # freeze batch for jax transforms; losses are accumulated asynchronously
        batch_losses.append(p_eval_step(state, freeze(batch)))

    # mean loss over all eval batches
    eval_metrics = {"loss": jnp.mean(jnp.stack(batch_losses))}

    # log metrics
    metrics_logger.log(eval_metrics, prefix="eval")
    metrics_logger.log_time("eval", time.perf_counter() - start_eval_time)

    # Print metrics and update progress bar
    desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
    epochs.write(desc)
    epochs.desc = desc
    return eval_metrics
def run_save_model(state, eval_metrics=None):
    """Save model, tokenizer and optimizer state; optionally log W&B artifacts.

    Only process 0 saves. Files are written to ``training_args.output_dir``.
    When that is a ``gs://`` URL, model and tokenizer are staged in a
    temporary directory and uploaded to
    ``<bucket path>/<wandb run id>/step_<step>/model/``. The optimizer state
    is uploaded to ``.../state/opt_state.msgpack``.

    Args:
        state: train state whose ``params`` and ``opt_state`` are saved.
        eval_metrics: optional metrics stored in the artifact metadata.
    """
    if jax.process_index() != 0:
        return

    start_save_time = time.perf_counter()
    output_dir = training_args.output_dir
    use_bucket = output_dir.startswith("gs://")
    tmp_dir = None
    if use_bucket:
        bucket_path = Path(output_dir[5:]) / wandb.run.id / f"step_{state.step}"
        bucket_name, dir_path = str(bucket_path).split("/", 1)
        tmp_dir = tempfile.TemporaryDirectory()
        output_dir = tmp_dir.name

    try:
        # save model
        params = jax.device_get(state.params)
        model.save_pretrained(
            output_dir,
            params=params,
        )

        # save tokenizer
        tokenizer.save_pretrained(output_dir)

        # copy to bucket
        if use_bucket:
            client = storage.Client()
            bucket = client.bucket(bucket_name)
            for filename in Path(output_dir).glob("*"):
                blob_name = str(Path(dir_path) / "model" / filename.name)
                blob = bucket.blob(blob_name)
                blob.upload_from_filename(str(filename))
    finally:
        # always remove the staging directory, even if saving/uploading failed
        if tmp_dir is not None:
            tmp_dir.cleanup()

    # save state
    opt_state = jax.device_get(state.opt_state)
    if use_bucket:
        blob_name = str(Path(dir_path) / "state" / "opt_state.msgpack")
        blob = bucket.blob(blob_name)
        blob.upload_from_file(io.BytesIO(to_bytes(opt_state)))
    else:
        with (Path(output_dir) / "opt_state.msgpack").open("wb") as f:
            f.write(to_bytes(opt_state))

    # save to W&B
    if training_args.log_model:
        # save some space
        c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
        c.cleanup(wandb.util.from_human_size("20GB"))

        metadata = {
            k: jax.device_get(getattr(state, k)).item()
            for k in ["step", "epoch", "train_time", "train_samples"]
        }
        metadata["num_params"] = num_params
        if eval_metrics is not None:
            metadata["eval"] = eval_metrics

        # create model artifact
        if use_bucket:
            metadata["bucket_path"] = f"gs://{bucket_path}/model"
        artifact = wandb.Artifact(
            name=f"model-{wandb.run.id}",
            type="DalleBart_model",
            metadata=metadata,
        )
        if use_bucket:
            artifact.add_reference(metadata["bucket_path"])
        else:
            for filename in [
                "config.json",
                "flax_model.msgpack",
                "merges.txt",
                "special_tokens_map.json",
                "tokenizer.json",
                "tokenizer_config.json",
                "vocab.json",
            ]:
                artifact.add_file(
                    f"{Path(training_args.output_dir) / filename}"
                )
        wandb.run.log_artifact(artifact)

        # create state artifact
        if use_bucket:
            metadata["bucket_path"] = f"gs://{bucket_path}/state"
        artifact_state = wandb.Artifact(
            name=f"state-{wandb.run.id}",
            type="DalleBart_state",
            metadata=metadata,
        )
        if use_bucket:
            artifact_state.add_reference(metadata["bucket_path"])
        else:
            artifact_state.add_file(
                f"{Path(training_args.output_dir) / 'opt_state.msgpack'}"
            )
        wandb.run.log_artifact(artifact_state)

    metrics_logger.log_time("save_model", time.perf_counter() - start_save_time)
logger.info(" Ready to start training")
with mesh:
    for epoch in epochs:
        # The train state is an immutable Flax struct: ``replace`` returns a
        # new object, so it must be re-bound (the result was discarded before,
        # leaving ``state.epoch`` stale in saved checkpoints/metadata).
        state = state.replace(epoch=epoch)
        local_state["epoch"] = epoch
        # ======================== Training ================================
        metrics_logger.update_state_metrics(local_state)
        metrics_logger.log({})

        # Generate an epoch by shuffling sampling indices from the train dataset
        train_loader = dataset.dataloader(
            "train",
            batch_size_per_node,
            epoch,
        )
        # train
        for batch in tqdm(
            train_loader,
            desc="Training...",
            position=1,
            leave=False,
            total=steps_per_epoch,
            disable=jax.process_index() > 0,
        ):
            # calculate delta time (we have a lag of one step but it's ok)
            train_time = time.perf_counter() - start_time

            # set correct shape to batch
            # - add grad_step dim if gradient_accumulation_steps > 1
            # - split per dp device if not multi-host for vmap trick (does not work in multi-host)
            bs_shape = (
                (batch_size_per_node_per_grad_step,)
                if not use_vmap_trick
                else (
                    jax.local_device_count()
                    // training_args.mp_devices,  # local dp devices
                    training_args.per_device_train_batch_size,
                )
            )
            if training_args.gradient_accumulation_steps > 1:
                # reshape data into (gradient_accumulation_steps, batch_per_node, ...)
                # to avoid any data redistribution when sharding
                bs_shape = (training_args.gradient_accumulation_steps,) + bs_shape

            # reshape batch
            batch = jax.tree_map(
                lambda x: x.reshape(bs_shape + x.shape[1:]),
                batch,
            )
            # freeze batch to pass safely to jax transforms
            batch = freeze(batch)

            # train step
            state, train_metrics = p_train_step(state, batch, train_time)
            local_state["step"] += 1
            local_state["train_time"] = train_time
            local_state["train_samples"] += batch_size_per_step

            if (
                local_state["step"] % training_args.logging_steps == 0
                and jax.process_index() == 0
            ):
                metrics_logger.update_state_metrics(local_state)
                metrics_logger.log(train_metrics, prefix="train")

            eval_metrics = None
            if local_state["step"] % training_args.eval_steps == 0:
                eval_metrics = run_evaluation()

            if local_state["step"] % training_args.save_steps == 0:
                run_save_model(state, eval_metrics)

        # log final train metrics
        if train_metrics is not None:
            # update_state_metrics indexes its argument like a dict
            # (state["step"]), so pass the host-side counters, not TrainState
            metrics_logger.update_state_metrics(local_state)
            metrics_logger.log(train_metrics, prefix="train")

            epochs.write(
                f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
            )

        # Final evaluation
        eval_metrics = run_evaluation()

        # save checkpoint after each epoch
        run_save_model(state, eval_metrics)
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()