# EasyLM/models/roberta/roberta_train.py
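"""RoBERTa masked-LM pretraining script for EasyLM.

Builds a FlaxRobertaForMaskedLMModule, shards it across the device mesh with
pjit according to RobertaConfig's partition rules, and trains it with
BERT-style dynamic token masking.

Typical invocation (a sketch only; the flag values below are illustrative,
not recommended settings):

    python -m EasyLM.models.roberta.roberta_train \
        --mesh_dim='1,-1,1' \
        --dtype='bf16' \
        --total_steps=100000 \
        --log_freq=50 \
        --save_model_freq=10000
"""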
import pprint

from tqdm import tqdm, trange
import mlxu
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit, with_sharding_constraint
from jax.sharding import PartitionSpec as PS
from flax.training.train_state import TrainState

from EasyLM.data import DatasetFactory
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.optimizers import OptimizerFactory
from EasyLM.jax_utils import (
    JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
    get_float_dtype_by_name, cross_entropy_loss_and_accuracy, global_norm,
    set_random_seed, average_metrics, get_weight_decay_mask,
    make_shard_and_gather_fns, tree_apply
)
from EasyLM.models.roberta.roberta_model import (
RobertaConfig, FlaxRobertaForMaskedLMModule
)
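
# Command-line flags. The nested entries (tokenizer, datasets, optimizer,
# checkpointer, roberta, logger) take their defaults from the corresponding
# EasyLM / mlxu config factories.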
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
seed=42,
mesh_dim='-1,1,1',
dtype='fp32',
mask_token_probability=0.15,
total_steps=10000,
load_roberta_config='',
update_roberta_config='',
load_checkpoint='',
load_dataset_state='',
log_freq=50,
save_model_freq=0,
save_milestone_freq=0,
eval_steps=0,
tokenizer=RobertaConfig.get_tokenizer_config(),
train_dataset=DatasetFactory.get_default_config(),
eval_dataset=DatasetFactory.get_default_config(),
optimizer=OptimizerFactory.get_default_config(),
checkpointer=StreamingCheckpointer.get_default_config(),
roberta=RobertaConfig.get_default_config(),
logger=mlxu.WandBLogger.get_default_config(),
log_all_worker=False,
jax_distributed=JaxDistributedConfig.get_default_config(),
)
def main(argv):
JaxDistributedConfig.initialize(FLAGS.jax_distributed)
variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
logger = mlxu.WandBLogger(
config=FLAGS.logger,
variant=variant,
enable=FLAGS.log_all_worker or (jax.process_index() == 0),
)
set_random_seed(FLAGS.seed)
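
    # Build the tokenizer and training dataset; optionally restore the dataset
    # state and set up a separate evaluation dataset when eval_steps > 0.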
tokenizer = RobertaConfig.get_tokenizer(FLAGS.tokenizer)
dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
if FLAGS.load_dataset_state != '':
dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
if FLAGS.eval_steps > 0:
eval_dataset = DatasetFactory.load_dataset(
FLAGS.eval_dataset, dataset.tokenizer
)
eval_iterator = iter(eval_dataset)
seq_length = dataset.seq_length
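
    # Resolve the model configuration: a config loaded from load_roberta_config
    # takes precedence over the flag defaults, update_roberta_config applies
    # ad-hoc overrides, and the special-token ids and vocab size are synced
    # from the dataset's tokenizer.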
if FLAGS.load_roberta_config != '':
roberta_config = RobertaConfig.load_config(FLAGS.load_roberta_config)
else:
roberta_config = RobertaConfig(**FLAGS.roberta)
if FLAGS.update_roberta_config != '':
roberta_config.update(dict(eval(FLAGS.update_roberta_config)))
roberta_config.update(dict(
bos_token_id=dataset.tokenizer.bos_token_id,
eos_token_id=dataset.tokenizer.eos_token_id,
pad_token_id=dataset.tokenizer.pad_token_id,
vocab_size=dataset.vocab_size,
))
model = FlaxRobertaForMaskedLMModule(
roberta_config, dtype=get_float_dtype_by_name(FLAGS.dtype)
)
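
    # Weight decay is masked out for the parameter names listed in
    # RobertaConfig.get_weight_decay_exclusions().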
optimizer, optimizer_info = OptimizerFactory.get_optimizer(
FLAGS.optimizer,
get_weight_decay_mask(RobertaConfig.get_weight_decay_exclusions()),
)
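
    # Two ways to build a TrainState: init_fn initializes parameters from
    # scratch, while create_trainstate_from_params wraps parameters restored
    # from a checkpoint or a pretrained model.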
def create_trainstate_from_params(params):
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
def init_fn(rng):
rng_generator = JaxRNG(rng)
params = model.init(
input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
token_type_ids=None,
head_mask=None,
rngs=rng_generator(roberta_config.rng_keys()),
)
return TrainState.create(params=params, tx=optimizer, apply_fn=None)
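
    # A single training step: dynamically mask the batch, compute the MLM loss
    # and accuracy on the masked positions, and apply one optimizer update.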
def train_step(train_state, rng, batch):
rng_generator = JaxRNG(rng)
tokens = with_sharding_constraint(batch['target_tokens'], PS(('dp', 'fsdp')))
def loss_and_accuracy(params):
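            # BERT-style dynamic masking: select each position with probability
            # mask_token_probability; of the selected positions, 80% become the
            # mask token, 10% a random token, and the remaining 10% are left
            # unchanged. The loss is computed only over the selected positions.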
altered_tokens = jax.random.uniform(
rng_generator(), shape=tokens.shape
) < FLAGS.mask_token_probability
random_uniform = jax.random.uniform(rng_generator(), shape=tokens.shape)
altered_by_mask = altered_tokens & (random_uniform < 0.8)
altered_by_random = altered_tokens & (random_uniform >= 0.8) & (random_uniform < 0.9)
inputs = jnp.where(altered_by_mask, dataset.tokenizer.mask_token_id, tokens)
random_tokens = jax.random.randint(
rng_generator(), shape=tokens.shape, minval=0, maxval=dataset.vocab_size
)
inputs = jnp.where(altered_by_random, random_tokens, inputs)
logits = model.apply(
params, inputs,
attention_mask=jnp.ones_like(inputs),
token_type_ids=None,
position_ids=None,
head_mask=None,
deterministic=False,
rngs=rng_generator(roberta_config.rng_keys()),
).logits
return cross_entropy_loss_and_accuracy(logits, tokens, valid=altered_tokens)
grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
(loss, accuracy), grads = grad_fn(train_state.params)
train_state = train_state.apply_gradients(grads=grads)
metrics = dict(
loss=loss,
accuracy=accuracy,
learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
gradient_norm=global_norm(grads),
param_norm=global_norm(train_state.params),
)
return train_state, rng_generator(), metrics
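
    # Evaluation step: same masking procedure as training, but no parameter
    # update; reports loss and accuracy over the masked positions.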
def eval_step(train_state, rng, batch):
rng_generator = JaxRNG(rng)
tokens = with_sharding_constraint(batch['target_tokens'], PS(('dp', 'fsdp')))
altered_tokens = jax.random.uniform(
rng_generator(), shape=tokens.shape
) < FLAGS.mask_token_probability
random_uniform = jax.random.uniform(rng_generator(), shape=tokens.shape)
altered_by_mask = altered_tokens & (random_uniform < 0.8)
altered_by_random = altered_tokens & (random_uniform >= 0.8) & (random_uniform < 0.9)
inputs = jnp.where(altered_by_mask, dataset.tokenizer.mask_token_id, tokens)
random_tokens = jax.random.randint(
rng_generator(), shape=tokens.shape, minval=0, maxval=dataset.vocab_size
)
inputs = jnp.where(altered_by_random, random_tokens, inputs)
logits = model.apply(
train_state.params, inputs,
attention_mask=jnp.ones_like(inputs),
token_type_ids=None,
position_ids=None,
head_mask=None,
            deterministic=True,  # disable dropout during evaluation
rngs=rng_generator(roberta_config.rng_keys()),
).logits
loss, accuracy = cross_entropy_loss_and_accuracy(logits, tokens, valid=altered_tokens)
metrics = dict(
eval_loss=loss,
eval_accuracy=accuracy,
)
return rng_generator(), metrics
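
    # Derive the train state's shapes abstractly, match them against the
    # model's partition rules, and build the shard/gather functions used for
    # checkpointing and for placing restored parameters on the mesh.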
train_state_shapes = jax.eval_shape(init_fn, next_rng())
train_state_partition = match_partition_rules(
RobertaConfig.get_partition_rules(), train_state_shapes
)
shard_fns, gather_fns = make_shard_and_gather_fns(
train_state_partition, train_state_shapes
)
checkpointer = StreamingCheckpointer(
FLAGS.checkpointer, logger.output_dir,
enable=jax.process_index() == 0
)
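
    # pjit-compile the init, restore, train and eval functions with explicit
    # input/output shardings; donate_argnums lets JAX reuse the donated buffers.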
sharded_init_fn = pjit(
init_fn,
in_shardings=PS(),
out_shardings=train_state_partition
)
sharded_create_trainstate_from_params = pjit(
create_trainstate_from_params,
in_shardings=(train_state_partition.params, ),
out_shardings=train_state_partition,
donate_argnums=(0, ),
)
sharded_train_step = pjit(
train_step,
in_shardings=(train_state_partition, PS(), PS()),
out_shardings=(train_state_partition, PS(), PS()),
donate_argnums=(0, 1),
)
sharded_eval_step = pjit(
eval_step,
in_shardings=(train_state_partition, PS(), PS()),
out_shardings=(PS(), PS()),
donate_argnums=(1,),
)
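
    # Checkpoints bundle the gathered train state, run metadata and the dataset
    # state so that training can be resumed exactly where it left off.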
def save_checkpoint(train_state, milestone=False):
step = int(jax.device_get(train_state.step))
metadata = dict(
step=step,
variant=variant,
flags=flags_config_dict,
roberta_config=roberta_config.to_dict(),
)
checkpointer.save_all(
train_state=train_state,
gather_fns=gather_fns,
metadata=metadata,
dataset=dataset.get_state_dict(),
milestone=milestone,
)
mesh = RobertaConfig.get_jax_mesh(FLAGS.mesh_dim)
with mesh:
train_state, restored_params = None, None
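        # load_checkpoint has the form "<type>::<path>": the 'huggingface' type
        # loads pretrained parameters only, while any other type is handled by
        # StreamingCheckpointer.load_trainstate_checkpoint.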
if FLAGS.load_checkpoint != '':
load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
if load_type == 'huggingface':
restored_params = tree_apply(
shard_fns.params, roberta_config.load_pretrained(load_path)
)
train_state = None
else:
train_state, restored_params = checkpointer.load_trainstate_checkpoint(
FLAGS.load_checkpoint, train_state_shapes, shard_fns
)
if train_state is None and restored_params is None:
# Initialize from scratch
train_state = sharded_init_fn(next_rng())
elif train_state is None and restored_params is not None:
# Restore from params but initialize train_state
train_state = sharded_create_trainstate_from_params(restored_params)
del restored_params
start_step = int(jax.device_get(train_state.step))
if FLAGS.save_model_freq > 0:
save_checkpoint(train_state)
sharded_rng = next_rng()
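
        # Main training loop: run evaluation (if configured) and log metrics
        # every log_freq steps, and save milestone or regular checkpoints at
        # the configured frequencies.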
step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
for step, (batch, dataset_metrics) in zip(step_counter, dataset):
train_state, sharded_rng, metrics = sharded_train_step(
train_state, sharded_rng, batch
)
if step % FLAGS.log_freq == 0:
if FLAGS.eval_steps > 0:
eval_metric_list = []
for _ in range(FLAGS.eval_steps):
eval_batch, _ = next(eval_iterator)
sharded_rng, eval_metrics = sharded_eval_step(
train_state, sharded_rng, eval_batch
)
eval_metric_list.append(eval_metrics)
metrics.update(average_metrics(eval_metric_list))
log_metrics = {"step": step}
log_metrics.update(metrics)
log_metrics.update(dataset_metrics)
log_metrics = jax.device_get(log_metrics)
logger.log(log_metrics)
tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
save_checkpoint(train_state, milestone=True)
elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
save_checkpoint(train_state)
if FLAGS.save_model_freq > 0:
save_checkpoint(train_state)
if __name__ == "__main__":
mlxu.run(main)