import dataclasses
import pprint
from functools import partial
import re

from tqdm import tqdm, trange
import numpy as np
import mlxu

import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit, with_sharding_constraint
from jax.sharding import PartitionSpec as PS
from flax.training.train_state import TrainState

from EasyLM.data import DatasetFactory
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.optimizers import OptimizerFactory
from EasyLM.jax_utils import (
    JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
    get_float_dtype_by_name, cross_entropy_loss_and_accuracy,
    named_tree_map, global_norm, set_random_seed, average_metrics,
    get_weight_decay_mask, make_shard_and_gather_fns, tree_apply
)
from EasyLM.models.roberta.roberta_model import (
    RobertaConfig, FlaxRobertaForMaskedLMModule
)


FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    seed=42,
    mesh_dim='-1,1,1',
    dtype='fp32',
    mask_token_probability=0.15,
    total_steps=10000,
    load_roberta_config='',
    update_roberta_config='',
    load_checkpoint='',
    load_dataset_state='',
    log_freq=50,
    save_model_freq=0,
    save_milestone_freq=0,
    eval_steps=0,
    tokenizer=RobertaConfig.get_tokenizer_config(),
    train_dataset=DatasetFactory.get_default_config(),
    eval_dataset=DatasetFactory.get_default_config(),
    optimizer=OptimizerFactory.get_default_config(),
    checkpointer=StreamingCheckpointer.get_default_config(),
    roberta=RobertaConfig.get_default_config(),
    logger=mlxu.WandBLogger.get_default_config(),
    log_all_worker=False,
    jax_distributed=JaxDistributedConfig.get_default_config(),
)


def main(argv):
    JaxDistributedConfig.initialize(FLAGS.jax_distributed)
    variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
    flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
    logger = mlxu.WandBLogger(
        config=FLAGS.logger,
        variant=variant,
        enable=FLAGS.log_all_worker or (jax.process_index() == 0),
    )
    set_random_seed(FLAGS.seed)

    tokenizer = RobertaConfig.get_tokenizer(FLAGS.tokenizer)
    dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
    if FLAGS.load_dataset_state != '':
        dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))

    if FLAGS.eval_steps > 0:
        eval_dataset = DatasetFactory.load_dataset(
            FLAGS.eval_dataset, dataset.tokenizer
        )
        eval_iterator = iter(eval_dataset)

    seq_length = dataset.seq_length

    if FLAGS.load_roberta_config != '':
        roberta_config = RobertaConfig.load_config(FLAGS.load_roberta_config)
    else:
        roberta_config = RobertaConfig(**FLAGS.roberta)

    if FLAGS.update_roberta_config != '':
        roberta_config.update(dict(eval(FLAGS.update_roberta_config)))

    # Keep the special token ids and vocab size consistent with the tokenizer.
    roberta_config.update(dict(
        bos_token_id=dataset.tokenizer.bos_token_id,
        eos_token_id=dataset.tokenizer.eos_token_id,
        pad_token_id=dataset.tokenizer.pad_token_id,
        vocab_size=dataset.vocab_size,
    ))

    model = FlaxRobertaForMaskedLMModule(
        roberta_config, dtype=get_float_dtype_by_name(FLAGS.dtype)
    )

    optimizer, optimizer_info = OptimizerFactory.get_optimizer(
        FLAGS.optimizer,
        get_weight_decay_mask(RobertaConfig.get_weight_decay_exclusions()),
    )

    def create_trainstate_from_params(params):
        return TrainState.create(params=params, tx=optimizer, apply_fn=None)

    def init_fn(rng):
        rng_generator = JaxRNG(rng)
        params = model.init(
            input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
            position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
            attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
            token_type_ids=None,
            head_mask=None,
            rngs=rng_generator(roberta_config.rng_keys()),
        )
        return TrainState.create(params=params, tx=optimizer, apply_fn=None)
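    # train_step and eval_step below implement BERT-style dynamic masking:
    # each token is selected for prediction with probability
    # FLAGS.mask_token_probability; of the selected tokens, 80% are replaced
    # with the mask token, 10% with a uniformly random token, and the
    # remaining 10% are left unchanged. Loss and accuracy are computed only
    # over the selected positions (valid=altered_tokens).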
    def train_step(train_state, rng, batch):
        rng_generator = JaxRNG(rng)
        tokens = with_sharding_constraint(batch['target_tokens'], PS(('dp', 'fsdp')))

        def loss_and_accuracy(params):
            altered_tokens = jax.random.uniform(
                rng_generator(), shape=tokens.shape
            ) < FLAGS.mask_token_probability
            random_uniform = jax.random.uniform(rng_generator(), shape=tokens.shape)
            altered_by_mask = altered_tokens & (random_uniform < 0.8)
            altered_by_random = (
                altered_tokens & (random_uniform >= 0.8) & (random_uniform < 0.9)
            )
            inputs = jnp.where(
                altered_by_mask, dataset.tokenizer.mask_token_id, tokens
            )
            random_tokens = jax.random.randint(
                rng_generator(), shape=tokens.shape,
                minval=0, maxval=dataset.vocab_size
            )
            inputs = jnp.where(altered_by_random, random_tokens, inputs)
            logits = model.apply(
                params, inputs,
                attention_mask=jnp.ones_like(inputs),
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                deterministic=False,
                rngs=rng_generator(roberta_config.rng_keys()),
            ).logits
            return cross_entropy_loss_and_accuracy(
                logits, tokens, valid=altered_tokens
            )

        grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
        (loss, accuracy), grads = grad_fn(train_state.params)
        train_state = train_state.apply_gradients(grads=grads)
        metrics = dict(
            loss=loss,
            accuracy=accuracy,
            learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
            gradient_norm=global_norm(grads),
            param_norm=global_norm(train_state.params),
        )
        return train_state, rng_generator(), metrics

    def eval_step(train_state, rng, batch):
        rng_generator = JaxRNG(rng)
        tokens = with_sharding_constraint(batch['target_tokens'], PS(('dp', 'fsdp')))
        altered_tokens = jax.random.uniform(
            rng_generator(), shape=tokens.shape
        ) < FLAGS.mask_token_probability
        random_uniform = jax.random.uniform(rng_generator(), shape=tokens.shape)
        altered_by_mask = altered_tokens & (random_uniform < 0.8)
        altered_by_random = (
            altered_tokens & (random_uniform >= 0.8) & (random_uniform < 0.9)
        )
        inputs = jnp.where(altered_by_mask, dataset.tokenizer.mask_token_id, tokens)
        random_tokens = jax.random.randint(
            rng_generator(), shape=tokens.shape,
            minval=0, maxval=dataset.vocab_size
        )
        inputs = jnp.where(altered_by_random, random_tokens, inputs)
        logits = model.apply(
            train_state.params, inputs,
            attention_mask=jnp.ones_like(inputs),
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            deterministic=False,
            rngs=rng_generator(roberta_config.rng_keys()),
        ).logits
        loss, accuracy = cross_entropy_loss_and_accuracy(
            logits, tokens, valid=altered_tokens
        )
        metrics = dict(
            eval_loss=loss,
            eval_accuracy=accuracy,
        )
        return rng_generator(), metrics

    train_state_shapes = jax.eval_shape(init_fn, next_rng())
    train_state_partition = match_partition_rules(
        RobertaConfig.get_partition_rules(), train_state_shapes
    )

    shard_fns, gather_fns = make_shard_and_gather_fns(
        train_state_partition, train_state_shapes
    )
    checkpointer = StreamingCheckpointer(
        FLAGS.checkpointer, logger.output_dir,
        enable=jax.process_index() == 0
    )

    sharded_init_fn = pjit(
        init_fn,
        in_shardings=PS(),
        out_shardings=train_state_partition
    )

    sharded_create_trainstate_from_params = pjit(
        create_trainstate_from_params,
        in_shardings=(train_state_partition.params, ),
        out_shardings=train_state_partition,
        donate_argnums=(0, ),
    )

    sharded_train_step = pjit(
        train_step,
        in_shardings=(train_state_partition, PS(), PS()),
        out_shardings=(train_state_partition, PS(), PS()),
        donate_argnums=(0, 1),
    )

    sharded_eval_step = pjit(
        eval_step,
        in_shardings=(train_state_partition, PS(), PS()),
        out_shardings=(PS(), PS()),
        donate_argnums=(1,),
    )
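    # The pjit-compiled functions above keep the train state sharded across
    # the device mesh according to train_state_partition. When checkpointing,
    # gather_fns (built by make_shard_and_gather_fns) reassemble the full
    # arrays on the host; only process 0 writes, since the checkpointer was
    # created with enable=jax.process_index() == 0.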
    def save_checkpoint(train_state, milestone=False):
        step = int(jax.device_get(train_state.step))
        metadata = dict(
            step=step,
            variant=variant,
            flags=flags_config_dict,
            roberta_config=roberta_config.to_dict(),
        )
        checkpointer.save_all(
            train_state=train_state,
            gather_fns=gather_fns,
            metadata=metadata,
            dataset=dataset.get_state_dict(),
            milestone=milestone,
        )

    mesh = RobertaConfig.get_jax_mesh(FLAGS.mesh_dim)
    with mesh:
        train_state, restored_params = None, None
        if FLAGS.load_checkpoint != '':
            load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
            if load_type == 'huggingface':
                restored_params = tree_apply(
                    shard_fns.params, roberta_config.load_pretrained(load_path)
                )
                train_state = None
            else:
                train_state, restored_params = checkpointer.load_trainstate_checkpoint(
                    FLAGS.load_checkpoint, train_state_shapes, shard_fns
                )

        if train_state is None and restored_params is None:
            # Initialize from scratch
            train_state = sharded_init_fn(next_rng())
        elif train_state is None and restored_params is not None:
            # Restore from params but initialize train_state
            train_state = sharded_create_trainstate_from_params(restored_params)
            del restored_params

        start_step = int(jax.device_get(train_state.step))

        if FLAGS.save_model_freq > 0:
            save_checkpoint(train_state)

        sharded_rng = next_rng()

        step_counter = trange(start_step, FLAGS.total_steps, ncols=0)

        for step, (batch, dataset_metrics) in zip(step_counter, dataset):
            train_state, sharded_rng, metrics = sharded_train_step(
                train_state, sharded_rng, batch
            )

            if step % FLAGS.log_freq == 0:
                if FLAGS.eval_steps > 0:
                    eval_metric_list = []
                    for _ in range(FLAGS.eval_steps):
                        eval_batch, _ = next(eval_iterator)
                        sharded_rng, eval_metrics = sharded_eval_step(
                            train_state, sharded_rng, eval_batch
                        )
                        eval_metric_list.append(eval_metrics)
                    metrics.update(average_metrics(eval_metric_list))

                log_metrics = {"step": step}
                log_metrics.update(metrics)
                log_metrics.update(dataset_metrics)
                log_metrics = jax.device_get(log_metrics)
                logger.log(log_metrics)
                tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")

            if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
                save_checkpoint(train_state, milestone=True)
            elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
                save_checkpoint(train_state)

        # Save a final checkpoint after the training loop completes.
        if FLAGS.save_model_freq > 0:
            save_checkpoint(train_state)


if __name__ == "__main__":
    mlxu.run(main)
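# Example invocation (hypothetical module path and flag values; adjust to
# your checkout and hardware). The 'huggingface::<path>' checkpoint format
# corresponds to the load_type == 'huggingface' branch above:
#
#   python -m EasyLM.models.roberta.roberta_train \
#       --mesh_dim='-1,1,1' \
#       --total_steps=10000 \
#       --load_checkpoint='huggingface::roberta-base' \
#       --logger.output_dir=/tmp/roberta_mlm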