roberta-large-finnish / flax_model_to_pytorch.py
aapot
Add pytorch model
2670ecc
raw history blame
No virus
779 Bytes
from transformers import RobertaForMaskedLM, FlaxRobertaForMaskedLM, AutoTokenizer
import torch
import numpy as np
import jax
import jax.numpy as jnp
# Pin JAX to the CPU platform; this is set before the Flax model is loaded.
jax.config.update('jax_platform_name', 'cpu')
# Directory holding the checkpoint (here: the current working directory).
MODEL_PATH = "./"
# Load the Flax RoBERTa masked-LM checkpoint found at MODEL_PATH.
model = FlaxRobertaForMaskedLM.from_pretrained(MODEL_PATH)
def to_f32(t):
    """Return a new pytree in which every bfloat16 leaf of ``t`` is float32.

    Leaves with any other dtype are passed through unchanged.

    Args:
        t: A JAX pytree whose leaves expose ``dtype`` and ``astype``.

    Returns:
        A pytree with the same structure as ``t``.
    """
    def _upcast(leaf):
        if leaf.dtype == jnp.bfloat16:
            return leaf.astype(jnp.float32)
        return leaf

    # The top-level ``jax.tree_map`` alias was deprecated and later removed;
    # ``jax.tree_util.tree_map`` is available on both old and new JAX versions.
    return jax.tree_util.tree_map(_upcast, t)
# Upcast any bfloat16 Flax weights to float32 and write the checkpoint back to
# MODEL_PATH; the PyTorch model below is then built from that saved checkpoint.
model.params = to_f32(model.params)
model.save_pretrained(MODEL_PATH)

# Build the PyTorch model from the Flax weights stored at MODEL_PATH.
pt_model = RobertaForMaskedLM.from_pretrained(MODEL_PATH, from_flax=True).to('cpu')

# Dummy batch for a sanity check: 2 sequences of 128 tokens, every id 0.
input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
# Use int64 ids on the PyTorch side: older PyTorch releases reject int32
# indices in embedding lookups, while int64 works everywhere.
input_ids_pt = torch.tensor(input_ids, dtype=torch.long)

# Inference only, so skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    logits_pt = pt_model(input_ids_pt).logits
print(logits_pt)
logits_fx = model(input_ids).logits
print(logits_fx)

# Report the agreement between both frameworks as a single number, so it does
# not have to be judged by eye from the two printed tensors. This is only
# printed, never asserted, so the conversion itself is unaffected.
max_abs_diff = np.abs(logits_pt.numpy() - np.asarray(logits_fx)).max()
print(f"max |pytorch - flax| logit difference: {max_abs_diff}")

# Write the PyTorch weights next to the Flax checkpoint.
pt_model.save_pretrained(MODEL_PATH)