|
|
|
|
|
|
|
|
|
|
|
import pprint |
|
from functools import partial |
|
import os |
|
import numpy as np |
|
import jax |
|
import jax.numpy as jnp |
|
import flax.serialization |
|
import mlxu |
|
from EasyLM.checkpoint import StreamingCheckpointer |
|
from EasyLM.jax_utils import float_to_dtype |
|
|
|
|
|
# Command-line flags for this script (name=default).
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    # False: main() writes target - base (a parameter diff).
    # True: main() treats the "target" checkpoint as a diff and writes
    # base + diff, recovering the full parameters.
    recover_diff=False,
    # Checkpoint spec for the base model, passed to
    # StreamingCheckpointer.load_trainstate_checkpoint (full train states are
    # disallowed). Required: main() asserts it is non-empty.
    # NOTE(review): presumably of the form "<load_type>::<path>" — confirm
    # against EasyLM.checkpoint.
    load_base_checkpoint='',
    # Checkpoint spec for the target (or, with recover_diff=True, the diff).
    # Required: main() asserts it is non-empty.
    load_target_checkpoint='',
    # Destination path for the resulting params. Required.
    output_file='',
    # True: save via StreamingCheckpointer.save_train_state_to_file.
    # False: save as a single flax msgpack blob.
    streaming=True,
    # Float dtype name for the saved params; forwarded to the streaming
    # writer or to float_to_dtype in the msgpack path.
    float_dtype='bf16',
)
|
|
|
|
|
def main(argv):
    """Compute (or apply) a parameter diff between two checkpoints and save it.

    Both ``--load_base_checkpoint`` and ``--load_target_checkpoint`` are
    loaded with ``StreamingCheckpointer.load_trainstate_checkpoint`` (with
    ``disallow_trainstate=True``), and the ``'params'`` subtree of each is
    used. The two trees are combined leaf-wise with
    ``jax.tree_util.tree_map``, so they must have the same structure:

    * ``--recover_diff=False``: writes ``target - base``.
    * ``--recover_diff=True``: treats the target checkpoint as a diff and
      writes ``base + diff``.

    The result is written to ``--output_file``. Two formats are supported:

    * ``--streaming=True``: uses
      ``StreamingCheckpointer.save_train_state_to_file``, with ``float_dtype``
      forwarded to it.
    * ``--streaming=False``: casts the params with ``float_to_dtype`` and
      writes them as one flax msgpack blob.

    Args:
        argv: Positional command-line arguments forwarded by ``mlxu.run``;
            unused.

    Raises:
        AssertionError: if a required checkpoint or output flag is empty.
    """
    assert FLAGS.load_base_checkpoint != '' and FLAGS.load_target_checkpoint != '', (
        'Both --load_base_checkpoint and --load_target_checkpoint are required'
    )
    assert FLAGS.output_file != '', '--output_file is required'

    # Index [1] selects the second element of the returned tuple.
    # NOTE(review): presumably (train_state, restored_params) — confirm in
    # EasyLM.checkpoint.
    base_params = StreamingCheckpointer.load_trainstate_checkpoint(
        FLAGS.load_base_checkpoint, disallow_trainstate=True
    )[1]['params']
    target_params = StreamingCheckpointer.load_trainstate_checkpoint(
        FLAGS.load_target_checkpoint, disallow_trainstate=True
    )[1]['params']

    if FLAGS.recover_diff:
        # The "target" checkpoint holds a diff; add it back onto the base.
        params = jax.tree_util.tree_map(
            lambda b, t: b + t, base_params, target_params
        )
    else:
        params = jax.tree_util.tree_map(
            lambda b, t: t - b, base_params, target_params
        )
    # Drop our references to the inputs so they can be freed before saving,
    # lowering peak memory for large models.
    del base_params, target_params

    if FLAGS.streaming:
        StreamingCheckpointer.save_train_state_to_file(
            params, FLAGS.output_file, float_dtype=FLAGS.float_dtype
        )
    else:
        params = float_to_dtype(params, FLAGS.float_dtype)
        # The destination flag is ``output_file``; there is no ``output`` flag.
        with mlxu.open_file(FLAGS.output_file, 'wb') as fout:
            fout.write(flax.serialization.msgpack_serialize(params, in_place=True))
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: hand main to mlxu.run.
    # NOTE(review): presumably parses the mlxu/absl flags before calling
    # main(argv) — confirm in mlxu.
    mlxu.run(main)
|
|