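# gpt2-medium-dutch-nedd / flax_to_pytorch.py
# Committed by yhavinga in 21e9f42: "Add scripts, model config and vocabulary".
#
# Converts the Flax GPT-2 checkpoint in the current directory to PyTorch and
# compares the logits of both implementations on a dummy batch.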
import torch
import numpy as np
import jax
import jax.numpy as jnp
from transformers import AutoTokenizer
from transformers import FlaxGPT2LMHeadModel
from transformers import GPT2LMHeadModel
# Directory holding the Flax checkpoint, config and tokenizer files; the
# converted PyTorch checkpoint is written back into the same directory.
MODEL_DIR = "."
# Opt-in: cast bfloat16 Flax parameters to float32 and save that copy to ./fx.
# Off by default so the script behaves as before unless explicitly enabled.
CONVERT_FLAX_TO_FP32 = False
# Shape of the dummy batch (all token id 0) used to sanity-check the conversion.
BATCH_SIZE = 2
SEQ_LEN = 128
# Absolute tolerance for the PyTorch-vs-Flax logits report. Kept loose because
# the Flax weights may be stored in bfloat16 (see CONVERT_FLAX_TO_FP32).
ATOL = 1e-3

# The tokenizer is loaded and given a pad token but is not used further in this
# script -- presumably kept to confirm the tokenizer files load; verify intent.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
tokenizer.pad_token = tokenizer.eos_token
model_fx = FlaxGPT2LMHeadModel.from_pretrained(MODEL_DIR)

if CONVERT_FLAX_TO_FP32:
    # jax.tree_map is deprecated; jax.tree_util.tree_map is the stable spelling.
    model_fx.params = jax.tree_util.tree_map(
        lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x,
        model_fx.params,
    )
    model_fx.save_pretrained("./fx")

# Load the Flax weights into the PyTorch architecture and write the PyTorch
# checkpoint next to the Flax one.
model_pt = GPT2LMHeadModel.from_pretrained(MODEL_DIR, from_flax=True)
model_pt.save_pretrained(MODEL_DIR)

input_ids = np.zeros((BATCH_SIZE, SEQ_LEN), dtype=np.int32)
# nn.Embedding is documented to take int64 (LongTensor) indices; cast for PyTorch.
input_ids_pt = torch.from_numpy(input_ids).long()

# Inference only: no need to record the autograd graph.
with torch.no_grad():
    logits_pt = model_pt(input_ids_pt).logits
print(logits_pt)
logits_fx = model_fx(input_ids).logits
print(logits_fx)

# Quantify the agreement instead of relying on eyeballing the printed tensors.
logits_pt_np = logits_pt.float().numpy()
logits_fx_np = np.asarray(logits_fx.astype(jnp.float32))
max_abs_diff = float(np.max(np.abs(logits_pt_np - logits_fx_np)))
print(f"max |logits_pt - logits_fx| = {max_abs_diff:.6g}")
print(f"allclose(atol={ATOL}): {np.allclose(logits_pt_np, logits_fx_np, atol=ATOL)}")