"""Convert a Flax GPT-2 checkpoint to PyTorch and compare the two models' logits.

Loads the Flax checkpoint found in MODEL_DIR, converts it to a PyTorch
``GPT2LMHeadModel`` via ``from_flax=True``, saves the PyTorch weights back into
the same directory, then runs both models on the same dummy batch and prints
their logits plus the largest absolute difference between them.
"""
import torch
import numpy as np
import jax
import jax.numpy as jnp
from transformers import AutoTokenizer
from transformers import FlaxGPT2LMHeadModel
from transformers import GPT2LMHeadModel

# Directory holding the Flax checkpoint; the converted PyTorch weights are
# written back into the same directory by save_pretrained below.
MODEL_DIR = "./"

model_fx = FlaxGPT2LMHeadModel.from_pretrained(MODEL_DIR)
model_pt = GPT2LMHeadModel.from_pretrained(MODEL_DIR, from_flax=True)
model_pt.save_pretrained(MODEL_DIR)

# Dummy batch: 2 sequences, each 128 copies of token id 0.
input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
input_ids_pt = torch.tensor(input_ids)

# Inference only: no_grad avoids building an autograd graph for the forward pass.
with torch.no_grad():
    logits_pt = model_pt(input_ids_pt).logits
print(logits_pt)

logits_fx = model_fx(input_ids).logits
print(logits_fx)

# Report a single numeric agreement figure instead of relying on eyeballing
# two large printed tensors.
max_abs_diff = np.max(np.abs(logits_pt.numpy() - np.asarray(logits_fx)))
print(f"max abs diff between PyTorch and Flax logits: {max_abs_diff:.3e}")