from pathlib import Path | |
import flax | |
import jax | |
import jax.numpy as jnp | |
def to_f32(t):
    """Upcast every bfloat16 leaf of a pytree to float32.

    Args:
        t: A pytree (for example the nested dict returned by
            ``flax.serialization.msgpack_restore``) whose leaves may be
            arrays or plain Python values.

    Returns:
        A new pytree with the same structure as ``t``. Leaves whose dtype is
        bfloat16 are cast to float32. Every other leaf is returned unchanged,
        including leaves that have no ``dtype`` attribute (e.g. Python ints
        or strings).
    """

    def _upcast(x):
        # Plain Python leaves have no ``dtype``; pass them through instead of
        # raising AttributeError.
        if getattr(x, "dtype", None) == jnp.bfloat16:
            return x.astype(jnp.float32)
        return x

    # ``jax.tree_map`` is deprecated (and removed in later JAX releases);
    # ``jax.tree_util.tree_map`` is the supported spelling.
    return jax.tree_util.tree_map(_upcast, t)
# Load the msgpack-serialized state from output/flax_model.msgpack, upcast its
# bfloat16 leaves with to_f32, and write the result alongside the input as
# output/flax_model_f32.msgpack.
src_path = Path("output") / "flax_model.msgpack"
dst_path = src_path.with_name("flax_model_f32.msgpack")

data = flax.serialization.msgpack_restore(src_path.read_bytes())
transformed_data = to_f32(data)

serialized = flax.serialization.msgpack_serialize(transformed_data)
dst_path.write_bytes(serialized)