"""Convert bfloat16 leaves of a Flax msgpack checkpoint to float32.

When run as a script, reads ``output/flax_model.msgpack``, casts every
bfloat16 array leaf to float32 (all other leaves are left untouched), and
writes the result to ``output/flax_model_f32.msgpack``.
"""

from pathlib import Path

import jax
import jax.numpy as jnp

DEFAULT_SRC = Path("output/flax_model.msgpack")
DEFAULT_DST = Path("output/flax_model_f32.msgpack")


def to_f32(t):
    """Return a copy of pytree *t* with bfloat16 array leaves cast to float32.

    Leaves that are not bfloat16 arrays are returned unchanged. This covers
    arrays of other dtypes and plain Python values without a ``dtype``
    attribute.
    """

    def _upcast(x):
        # A restored state dict may contain non-array leaves (e.g. Python
        # ints) that have no ``.dtype``; getattr keeps those from raising.
        dtype = getattr(x, "dtype", None)
        if dtype is not None and dtype == jnp.bfloat16:
            return x.astype(jnp.float32)
        return x

    # jax.tree_map is deprecated (and removed in newer JAX releases);
    # jax.tree_util.tree_map is the stable spelling.
    return jax.tree_util.tree_map(_upcast, t)


def main(src=DEFAULT_SRC, dst=DEFAULT_DST):
    """Read the msgpack checkpoint at *src*, upcast it, and write it to *dst*."""
    # Imported here so that to_f32 stays importable without flax installed.
    from flax import serialization

    data = serialization.msgpack_restore(Path(src).read_bytes())
    Path(dst).write_bytes(serialization.msgpack_serialize(to_f32(data)))


if __name__ == "__main__":
    main()