|
import pprint |
|
from functools import partial |
|
|
|
from tqdm import tqdm, trange |
|
import numpy as np |
|
import mlxu |
|
|
|
import jax |
|
import jax.numpy as jnp |
|
from jax.experimental.pjit import pjit |
|
from jax.sharding import PartitionSpec as PS |
|
from flax.training.train_state import TrainState |
|
|
|
from EasyLM.data import DatasetFactory |
|
from EasyLM.checkpoint import StreamingCheckpointer |
|
from EasyLM.optimizers import OptimizerFactory |
|
from EasyLM.jax_utils import ( |
|
JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, |
|
cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name, |
|
set_random_seed, average_metrics, get_weight_decay_mask, |
|
make_shard_and_gather_fns, with_sharding_constraint, |
|
) |
|
from EasyLM.models.llama.llama_model import ( |
|
LLaMAConfig, FlaxLLaMAForCausalLMModule |
|
) |
|
|
|
|
|
# Command-line flags and their defaults for this training script. Keyword
# order is preserved so the resulting flag definitions match the original.
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(**{
    # Reproducibility and device-mesh layout string.
    'seed': 42,
    'mesh_dim': '1,-1,1',
    # Names of the computation dtype and the parameter dtype.
    'dtype': 'fp32',
    'param_dtype': 'fp32',
    # Training length plus optional config / checkpoint / dataset-state
    # sources (an empty string means "not provided").
    'total_steps': 10000,
    'load_llama_config': '',
    'update_llama_config': '',
    'load_checkpoint': '',
    'load_dataset_state': '',
    # Logging / saving / evaluation cadence in steps; main() only acts on
    # each of these when the value is > 0.
    'log_freq': 50,
    'save_model_freq': 0,
    'save_milestone_freq': 0,
    'eval_freq': 0,
    # Nested component configurations, seeded from each component's defaults.
    'tokenizer': LLaMAConfig.get_tokenizer_config(),
    'train_dataset': DatasetFactory.get_default_config(),
    'eval_dataset': DatasetFactory.get_default_config(),
    'optimizer': OptimizerFactory.get_default_config(),
    'checkpointer': StreamingCheckpointer.get_default_config(),
    'llama': LLaMAConfig.get_default_config(),
    'logger': mlxu.WandBLogger.get_default_config(),
    # When False, only process 0 logs to the logger.
    'log_all_worker': False,
    'jax_distributed': JaxDistributedConfig.get_default_config(),
})
|
|
|
|
|
def main(argv):
    """Train (and optionally evaluate) a LLaMA causal LM with pjit sharding.

    All configuration comes from the module-level ``FLAGS``. The function
    builds the tokenizer and datasets, the model config, the optimizer, and
    sharded init/train/eval step functions. It then runs the training loop
    inside the device mesh, logging metrics and saving checkpoints at the
    configured frequencies.

    Args:
        argv: Positional command-line arguments supplied by the runner;
            not used.
    """
    JaxDistributedConfig.initialize(FLAGS.jax_distributed)
    variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
    flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
    logger = mlxu.WandBLogger(
        config=FLAGS.logger,
        variant=variant,
        # By default only process 0 logs, to avoid duplicate runs.
        enable=FLAGS.log_all_worker or (jax.process_index() == 0),
    )
    set_random_seed(FLAGS.seed)

    tokenizer = LLaMAConfig.get_tokenizer(FLAGS.tokenizer)
    dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
    if FLAGS.load_dataset_state != '':
        # SECURITY NOTE(review): this unpickles a user-supplied file; only
        # point load_dataset_state at trusted locations.
        dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))

    if FLAGS.eval_freq > 0:
        eval_dataset = DatasetFactory.load_dataset(
            FLAGS.eval_dataset, dataset.tokenizer, eval_dataset=True
        )

    seq_length = dataset.seq_length

    if FLAGS.load_llama_config != '':
        llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
    else:
        llama_config = LLaMAConfig(**FLAGS.llama)

    if FLAGS.update_llama_config != '':
        # SECURITY NOTE(review): ``eval`` executes arbitrary Python taken from
        # the command line. It is kept because callers may pass call
        # expressions (e.g. "dict(...)") that ast.literal_eval would reject.
        # Never feed untrusted input to this flag.
        llama_config.update(dict(eval(FLAGS.update_llama_config)))

    # Keep special-token ids consistent with the tokenizer actually in use.
    llama_config.update(dict(
        bos_token_id=dataset.tokenizer.bos_token_id,
        eos_token_id=dataset.tokenizer.eos_token_id,
    ))
    # Grow (never shrink) the embedding table so every dataset token id fits.
    if llama_config.vocab_size < dataset.vocab_size:
        print(
            "Updating model config vocab size from",
            llama_config.vocab_size, "to", dataset.vocab_size,
        )
        llama_config.update(dict(vocab_size=dataset.vocab_size))

    model = FlaxLLaMAForCausalLMModule(
        llama_config,
        dtype=get_float_dtype_by_name(FLAGS.dtype),
        param_dtype=get_float_dtype_by_name(FLAGS.param_dtype),
    )

    optimizer, optimizer_info = OptimizerFactory.get_optimizer(
        FLAGS.optimizer,
        get_weight_decay_mask(LLaMAConfig.get_weight_decay_exclusions())
    )

    def create_trainstate_from_params(params):
        """Wrap ``params`` in a fresh TrainState driven by ``optimizer``."""
        return TrainState.create(params=params, tx=optimizer, apply_fn=None)

    def init_fn(rng):
        """Initialize model parameters from ``rng`` and wrap them in a TrainState."""
        rng_generator = JaxRNG(rng)
        # Dummy inputs; presumably only their sequence length matters for
        # parameter shapes (the batch size of 4 is arbitrary) — confirm.
        params = model.init(
            input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
            position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
            attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
            rngs=rng_generator(llama_config.rng_keys()),
        )
        return create_trainstate_from_params(params)

    def train_step(train_state, rng, batch):
        """Run one optimizer step on ``batch``.

        Returns the updated TrainState, a fresh RNG key for the next step, and
        a dict of scalar metrics (loss, accuracy, learning_rate,
        gradient_norm, param_norm).
        """
        rng_generator = JaxRNG(rng)
        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))

        def loss_and_accuracy(params):
            logits = model.apply(
                params, batch['input_tokens'], deterministic=False,
                rngs=rng_generator(llama_config.rng_keys()),
            ).logits
            return cross_entropy_loss_and_accuracy(
                logits, batch['target_tokens'], batch['loss_masks']
            )

        # has_aux=True: accuracy rides along with the loss but is not differentiated.
        grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
        (loss, accuracy), grads = grad_fn(train_state.params)
        train_state = train_state.apply_gradients(grads=grads)
        metrics = dict(
            loss=loss,
            accuracy=accuracy,
            # Evaluated at the post-update step counter.
            learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
            gradient_norm=global_norm(grads),
            param_norm=global_norm(train_state.params),
        )
        return train_state, rng_generator(), metrics

    def eval_step(train_state, rng, batch):
        """Compute eval loss/accuracy on ``batch``; returns (new_rng, metrics)."""
        rng_generator = JaxRNG(rng)
        batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
        logits = model.apply(
            train_state.params, batch['input_tokens'], deterministic=True,
            rngs=rng_generator(llama_config.rng_keys()),
        ).logits
        loss, accuracy = cross_entropy_loss_and_accuracy(
            logits, batch['target_tokens'], batch['loss_masks']
        )
        metrics = dict(
            eval_loss=loss,
            eval_accuracy=accuracy,
        )
        return rng_generator(), metrics

    # Trace init_fn abstractly (no device allocation) to derive the
    # TrainState structure that the partition rules are matched against.
    train_state_shapes = jax.eval_shape(init_fn, next_rng())
    train_state_partition = match_partition_rules(
        LLaMAConfig.get_partition_rules(), train_state_shapes
    )

    shard_fns, gather_fns = make_shard_and_gather_fns(
        train_state_partition, train_state_shapes
    )
    checkpointer = StreamingCheckpointer(
        FLAGS.checkpointer, logger.output_dir,
        enable=jax.process_index() == 0,
    )

    sharded_init_fn = pjit(
        init_fn,
        in_shardings=PS(),
        out_shardings=train_state_partition
    )

    sharded_create_trainstate_from_params = pjit(
        create_trainstate_from_params,
        in_shardings=(train_state_partition.params, ),
        out_shardings=train_state_partition,
        # The restored params are not needed afterwards; donate their buffers.
        donate_argnums=(0, ),
    )

    sharded_train_step = pjit(
        train_step,
        in_shardings=(train_state_partition, PS(), PS()),
        out_shardings=(train_state_partition, PS(), PS()),
        # Old train state and rng are replaced by the outputs; donate them.
        donate_argnums=(0, 1),
    )

    sharded_eval_step = pjit(
        eval_step,
        in_shardings=(train_state_partition, PS(), PS()),
        out_shardings=(PS(), PS()),
        donate_argnums=(1,),
    )

    def save_checkpoint(train_state, milestone=False):
        """Save train state, metadata and dataset state; return the saved step."""
        step = int(jax.device_get(train_state.step))
        metadata = dict(
            step=step,
            variant=variant,
            flags=flags_config_dict,
            llama_config=llama_config.to_dict(),
        )
        checkpointer.save_all(
            train_state=train_state,
            gather_fns=gather_fns,
            metadata=metadata,
            dataset=dataset.get_state_dict(),
            milestone=milestone,
        )
        return step

    mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
    with mesh:
        train_state, restored_params = None, None
        if FLAGS.load_checkpoint != '':
            train_state, restored_params = checkpointer.load_trainstate_checkpoint(
                FLAGS.load_checkpoint, train_state_shapes, shard_fns
            )

        if train_state is None and restored_params is None:
            # Nothing restored: initialize a fresh model and optimizer state.
            train_state = sharded_init_fn(next_rng())
        elif train_state is None and restored_params is not None:
            # Only parameters restored: build a new optimizer state around them.
            train_state = sharded_create_trainstate_from_params(restored_params)
            del restored_params

        start_step = int(jax.device_get(train_state.step))

        # Step of the most recent regular (non-milestone) save; used to avoid
        # writing the identical checkpoint twice at the end of training.
        last_saved_step = None
        if FLAGS.save_model_freq > 0:
            last_saved_step = save_checkpoint(train_state)

        sharded_rng = next_rng()

        step_counter = trange(start_step, FLAGS.total_steps, ncols=0)

        for step, (batch, dataset_metrics) in zip(step_counter, dataset):
            train_state, sharded_rng, metrics = sharded_train_step(
                train_state, sharded_rng, batch
            )

            if FLAGS.eval_freq > 0 and (step + 1) % FLAGS.eval_freq == 0:
                eval_metric_list = []
                for eval_batch, _ in eval_dataset:
                    sharded_rng, eval_metrics = sharded_eval_step(
                        train_state, sharded_rng, eval_batch
                    )
                    eval_metric_list.append(eval_metrics)
                # An eval dataset may yield no batches; average_metrics may not
                # accept an empty list, so only merge when something was measured.
                if eval_metric_list:
                    metrics.update(average_metrics(eval_metric_list))

            if FLAGS.log_freq > 0 and (step + 1) % FLAGS.log_freq == 0:
                log_metrics = {"step": step + 1}
                log_metrics.update(metrics)
                log_metrics.update(dataset_metrics)
                log_metrics = jax.device_get(log_metrics)
                logger.log(log_metrics)
                tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")

            if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
                save_checkpoint(train_state, milestone=True)
            elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
                last_saved_step = save_checkpoint(train_state)

        # Close the bar explicitly: if the dataset runs out before total_steps,
        # zip stops without exhausting the trange, leaving it open.
        step_counter.close()

        if FLAGS.save_model_freq > 0:
            final_step = int(jax.device_get(train_state.step))
            # Skip when the last regular save already captured this exact step
            # (e.g. total_steps divisible by save_model_freq, or no steps run).
            if final_step != last_saved_step:
                save_checkpoint(train_state)
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: hand ``main`` to mlxu's runner, which calls it with
    # the command-line argv (presumably after parsing FLAGS — confirm in mlxu).
    mlxu.run(main)
|
|