"""Upcast a bf16 Flax RoBERTa masked-LM checkpoint to fp32 and export it as a
PyTorch checkpoint, saving everything back to the current directory."""
from transformers import FlaxRobertaForMaskedLM, RobertaForMaskedLM, AutoTokenizer
import jax
import jax.numpy as jnp


def to_f32(t):
    # Upcast every bfloat16 leaf of the parameter pytree to float32;
    # leaves of any other dtype pass through unchanged.
    return jax.tree_util.tree_map(
        lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t
    )
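
# Quick illustration (a sketch, not part of the original script): to_f32 only
# touches bfloat16 leaves of a pytree; leaves of other dtypes pass through.
_demo = {"w": jnp.zeros((2,), dtype=jnp.bfloat16), "b": jnp.zeros((2,), dtype=jnp.int32)}
assert to_f32(_demo)["w"].dtype == jnp.float32
assert to_f32(_demo)["b"].dtype == jnp.int32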


# load the Flax bf16 model from the current directory
model = FlaxRobertaForMaskedLM.from_pretrained("./")
# upcast its parameters to fp32
model.params = to_f32(model.params)
# save the Flax fp32 model back to the same directory (overwrites the bf16 weights)
model.save_pretrained("./")

# convert flax fp32 model to pytorch
model_pt = RobertaForMaskedLM.from_pretrained("./", from_flax=True)
model_pt.save_pretrained("./")

tokenizer = AutoTokenizer.from_pretrained("./")
tokenizer.save_pretrained("./")
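
# Optional sanity check (a sketch, not in the original script): compare Flax
# fp32 logits against the converted PyTorch logits on one example to confirm
# the conversion round-trip did not perturb the weights.
import numpy as np
import torch

sample = f"Paris is the {tokenizer.mask_token} of France."
flax_logits = np.asarray(model(**tokenizer(sample, return_tensors="np")).logits)
with torch.no_grad():
    pt_logits = model_pt(**tokenizer(sample, return_tensors="pt")).logits.numpy()
print("max abs logit diff:", np.abs(flax_logits - pt_logits).max())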