File size: 440 Bytes
814b575
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from pathlib import Path

import flax
import jax
import jax.numpy as jnp


def to_f32(t):
    """Return a copy of pytree ``t`` with bfloat16 leaves cast to float32.

    Args:
        t: A pytree (e.g. nested dicts/lists) whose leaves are typically
            arrays. Leaves that do not expose a ``dtype`` attribute are
            returned unchanged.

    Returns:
        A pytree with the same structure as ``t``. Leaves whose dtype is
        ``jnp.bfloat16`` are converted with ``astype(jnp.float32)``; every
        other leaf is passed through as-is.
    """

    def _upcast(leaf):
        # Non-array leaves (plain Python scalars, strings, ...) have no
        # ``dtype``; pass them through instead of raising AttributeError.
        dtype = getattr(leaf, "dtype", None)
        if dtype is not None and dtype == jnp.bfloat16:
            return leaf.astype(jnp.float32)
        return leaf

    # Use jax.tree_util.tree_map rather than the top-level jax.tree_map
    # alias, which was deprecated and is absent from newer JAX releases.
    return jax.tree_util.tree_map(_upcast, t)


# Restore the msgpack checkpoint, run it through to_f32, and write the
# converted tree to a separate file in the same output directory.
source_path = Path("output/flax_model.msgpack")
target_path = Path("output/flax_model_f32.msgpack")

data = flax.serialization.msgpack_restore(source_path.read_bytes())
transformed_data = to_f32(data)

serialized = flax.serialization.msgpack_serialize(transformed_data)
target_path.write_bytes(serialized)