from transformers import FlaxRobertaForMaskedLM, RobertaForMaskedLM, AutoTokenizer
import jax
import jax.numpy as jnp

def to_f32(t):
    # Cast any bfloat16 leaves of the parameter pytree up to float32;
    # all other dtypes are left untouched.
    return jax.tree_util.tree_map(
        lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t
    )

# load the flax bf16 model (the cast above targets bfloat16 weights)
model = FlaxRobertaForMaskedLM.from_pretrained("./")

# cast the parameters to fp32
model.params = to_f32(model.params)

# save the flax fp32 model
model.save_pretrained("./")

# convert the flax fp32 model to pytorch and save it
model_pt = RobertaForMaskedLM.from_pretrained("./", from_flax=True)
model_pt.save_pretrained("./")

# re-save the tokenizer alongside both checkpoints
tokenizer = AutoTokenizer.from_pretrained("./")
tokenizer.save_pretrained("./")
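
# Optional sanity check (a minimal sketch; assumes the converted checkpoints
# and tokenizer above still load from "./"): run the same sentence through
# the flax and pytorch models and compare their logits.
import numpy as np
import torch

flax_inputs = tokenizer("Hello world!", return_tensors="np")
flax_logits = model(**flax_inputs).logits

pt_inputs = tokenizer("Hello world!", return_tensors="pt")
with torch.no_grad():
    pt_logits = model_pt(**pt_inputs).logits

# The max absolute difference should be close to zero (small numerical noise).
print("max abs diff:", np.abs(np.asarray(flax_logits) - pt_logits.numpy()).max())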