mariagrandury commited on
Commit
814b575
1 Parent(s): 062f3fc

Add utils functions

Browse files
Files changed (2) hide show
  1. bfloat16_float32.py +20 -0
  2. flax_to_pytorch.py +4 -0
bfloat16_float32.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import flax
4
+ import jax
5
+ import jax.numpy as jnp
6
+
7
+
8
+ def to_f32(t):
9
+ return jax.tree_map(
10
+ lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t
11
+ )
12
+
13
+
14
+ data = flax.serialization.msgpack_restore(
15
+ Path("output/flax_model.msgpack").read_bytes()
16
+ )
17
+ transformed_data = to_f32(data)
18
+ Path("output/flax_model_f32.msgpack").write_bytes(
19
+ flax.serialization.msgpack_serialize(transformed_data)
20
+ )
flax_to_pytorch.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
1
+ from transformers import Wav2Vec2Model
2
+
3
+ model = Wav2Vec2Model.from_pretrained("output", from_flax=True)
4
+ model.save_pretrained("./")