mariagrandury
commited on
Commit
•
814b575
1
Parent(s):
062f3fc
Add utils functions
Browse files- bfloat16_float32.py +20 -0
- flax_to_pytorch.py +4 -0
bfloat16_float32.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
import flax
|
4 |
+
import jax
|
5 |
+
import jax.numpy as jnp
|
6 |
+
|
7 |
+
|
8 |
+
def to_f32(t):
|
9 |
+
return jax.tree_map(
|
10 |
+
lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t
|
11 |
+
)
|
12 |
+
|
13 |
+
|
14 |
+
data = flax.serialization.msgpack_restore(
|
15 |
+
Path("output/flax_model.msgpack").read_bytes()
|
16 |
+
)
|
17 |
+
transformed_data = to_f32(data)
|
18 |
+
Path("output/flax_model_f32.msgpack").write_bytes(
|
19 |
+
flax.serialization.msgpack_serialize(transformed_data)
|
20 |
+
)
|
flax_to_pytorch.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import Wav2Vec2Model
|
2 |
+
|
3 |
+
model = Wav2Vec2Model.from_pretrained("output", from_flax=True)
|
4 |
+
model.save_pretrained("./")
|