# indonesian-roberta-large / flax_to_torch.py
from transformers import FlaxRobertaForMaskedLM, RobertaForMaskedLM, AutoTokenizer
import jax
import jax.numpy as jnp


def to_f32(t):
    # Cast every bfloat16 leaf of the parameter pytree to float32; leave other dtypes untouched.
    return jax.tree_util.tree_map(
        lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t
    )


# load the Flax half-precision (bfloat16) model from the current directory
model = FlaxRobertaForMaskedLM.from_pretrained("./")

# convert its parameters to fp32
model.params = to_f32(model.params)

# save the Flax fp32 model
model.save_pretrained("./")

# convert the Flax fp32 checkpoint to PyTorch and save it
model_pt = RobertaForMaskedLM.from_pretrained("./", from_flax=True)
model_pt.save_pretrained("./")

# save the tokenizer alongside the model weights
tokenizer = AutoTokenizer.from_pretrained("./")
tokenizer.save_pretrained("./")
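
# --- Optional sanity check (a minimal sketch, not part of the original script) ---
# Compares the in-memory Flax model against the converted PyTorch checkpoint on a
# sample masked sentence and asserts the masked-LM logits agree within a small
# tolerance. The example text is an assumption; any input containing the mask token works.
import numpy as np
import torch

text = f"Budi sedang {tokenizer.mask_token} di sekolah."  # hypothetical example sentence

inputs_pt = tokenizer(text, return_tensors="pt")
inputs_flax = tokenizer(text, return_tensors="np")

with torch.no_grad():
    logits_pt = model_pt(**inputs_pt).logits.numpy()
logits_flax = np.asarray(model(**inputs_flax).logits)

# The two checkpoints should produce (near-)identical predictions after conversion.
assert np.allclose(logits_pt, logits_flax, atol=1e-3), "Flax and PyTorch outputs diverge"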