# gpt2-finnish / flax_model_to_pytorch.py
# Commit 2f135b2 (aapot): Add 10k train step pytorch model
from transformers import AutoModelForCausalLM, FlaxAutoModelForCausalLM, AutoTokenizer
import torch
import numpy as np
import jax
import jax.numpy as jnp
def to_f32(t):
    """Return pytree ``t`` with every bfloat16 leaf cast to float32.

    Leaves of any other dtype are returned unchanged. The input tree is
    not modified; a new tree with the same structure is returned.

    Args:
        t: A JAX pytree (e.g. a nested dict of parameters) whose leaves
            expose ``dtype`` and ``astype``.

    Returns:
        A pytree with the same structure as ``t``.
    """
    def _upcast(leaf):
        # Only bfloat16 is widened; every other dtype passes through as-is.
        return leaf.astype(jnp.float32) if leaf.dtype == jnp.bfloat16 else leaf

    # The top-level alias jax.tree_map was deprecated and later removed from
    # JAX; jax.tree_util.tree_map works on both old and current releases.
    return jax.tree_util.tree_map(_upcast, t)
# Ask JAX to use the CPU platform for everything that follows.
# NOTE(review): presumably this has to run before any JAX array is created.
jax.config.update('jax_platform_name', 'cpu')

# Directory that holds the Flax checkpoint and receives the converted weights.
MODEL_PATH = "./"

# Load the Flax checkpoint, widen any bfloat16 parameters to float32, and
# write it back in place; the PyTorch conversion below reads this same path.
model = FlaxAutoModelForCausalLM.from_pretrained(MODEL_PATH)
model.params = to_f32(model.params)
model.save_pretrained(MODEL_PATH)

# Build the PyTorch model from the Flax weights just saved, then place it on CPU.
pt_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, from_flax=True)
pt_model = pt_model.to('cpu')

# Dummy batch: 2 sequences of 128 tokens, every token id 0.
input_ids = np.zeros((2, 128), dtype=np.int32)
input_ids_pt = torch.tensor(input_ids)

# Print the logits from both frameworks so they can be compared by eye.
logits_pt = pt_model(input_ids_pt).logits
print(logits_pt)
logits_fx = model(input_ids).logits
print(logits_fx)

# Persist the converted PyTorch weights next to the Flax checkpoint.
pt_model.save_pretrained(MODEL_PATH)