gpt-neo-125M-dutch / flax_to_pytorch.py
yhavinga's picture
Saving scripts, log and checkpoint at step 70000
f1818f3
raw history blame
No virus
836 Bytes
import torch
import numpy as np
import jax
import jax.numpy as jnp
from transformers import AutoTokenizer
from transformers import FlaxGPTNeoForCausalLM
from transformers import GPTNeoForCausalLM
# Convert the Flax GPT-Neo checkpoint in the current directory to PyTorch,
# save the PyTorch weights alongside it, then run both models on the same
# dummy batch so the conversion can be sanity-checked by comparing logits.

tokenizer = AutoTokenizer.from_pretrained(".")
# NOTE(review): the tokenizer is not used further in this script and the
# pad-token assignment below is never saved; kept for parity with the original.
tokenizer.pad_token = tokenizer.eos_token

model_fx = FlaxGPTNeoForCausalLM.from_pretrained(".")

# Optional step: upcast any bfloat16 Flax params to float32 and save a float32
# Flax copy to ./fx. (`jax.tree_map` is deprecated; use `jax.tree_util.tree_map`.)
# def to_f32(t):
#     return jax.tree_util.tree_map(
#         lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t
#     )
# model_fx.params = to_f32(model_fx.params)
# model_fx.save_pretrained("./fx")

# Load the Flax weights into the PyTorch GPT-Neo class and write the PyTorch
# checkpoint into the same directory.
model_pt = GPTNeoForCausalLM.from_pretrained(".", from_flax=True)
model_pt.save_pretrained(".")

# Dummy batch: 2 sequences of 128 copies of token id 0.
input_ids = np.zeros((2, 128), dtype=np.int32)
# int64 is PyTorch's canonical index dtype for embedding lookups.
input_ids_pt = torch.from_numpy(input_ids).long()

# Inference only: disable autograd so no graph is built for the forward pass.
with torch.no_grad():
    logits_pt = model_pt(input_ids_pt).logits
print(logits_pt)

logits_fx = model_fx(input_ids).logits
print(logits_fx)

# Numeric check of the conversion: if the weights were converted correctly,
# the two sets of logits should be close. Compare in float32 on the host.
logits_pt_np = logits_pt.cpu().numpy().astype(np.float32)
logits_fx_np = np.asarray(logits_fx, dtype=np.float32)
max_abs_diff = np.max(np.abs(logits_pt_np - logits_fx_np))
print(f"max |logits_pt - logits_fx| = {max_abs_diff:.6g}")