# wav2vec2-spanish / bfloat16_float32.py
# mariagrandury: Add utils functions (commit 814b575)
from pathlib import Path
import flax
import jax
import jax.numpy as jnp
def to_f32(t):
    """Return pytree ``t`` with every bfloat16 leaf cast to float32.

    Leaves whose dtype is not bfloat16, and leaves that have no ``dtype``
    attribute at all (e.g. plain Python scalars), are returned unchanged.

    Args:
        t: A JAX pytree, such as a nested dict of arrays.

    Returns:
        A pytree with the same structure as ``t``.
    """

    def _cast(x):
        # Non-array leaves (ints, floats, ...) carry no dtype; pass them
        # through instead of failing with AttributeError.
        dtype = getattr(x, "dtype", None)
        if dtype is not None and dtype == jnp.bfloat16:
            return x.astype(jnp.float32)
        return x

    # Use jax.tree_util.tree_map rather than the top-level jax.tree_map alias:
    # the alias was deprecated and is absent from newer JAX releases, while
    # the tree_util spelling is available in both old and new versions.
    return jax.tree_util.tree_map(_cast, t)
# Script body: read the saved checkpoint, cast its bfloat16 leaves to float32
# via to_f32, and write the result to a sibling file.
_source_file = Path("output/flax_model.msgpack")
_target_file = Path("output/flax_model_f32.msgpack")

# Restore the msgpack payload from the raw bytes on disk.
_raw_bytes = _source_file.read_bytes()
data = flax.serialization.msgpack_restore(_raw_bytes)

transformed_data = to_f32(data)

# Serialize the converted tree back to msgpack and persist it.
_encoded = flax.serialization.msgpack_serialize(transformed_data)
_target_file.write_bytes(_encoded)