"""Convert a PyTorch GPT-2 checkpoint to TensorFlow and sanity-check the port.

Loads the PyTorch weights from ``./pt``, builds a TensorFlow model from the
same weights (``from_pt=True``), saves it to ``./tf``, then runs both models
on an identical dummy batch. It prints both sets of logits plus the largest
element-wise difference between them.
"""
import torch
import numpy as np
import jax
import jax.numpy as jnp
from transformers import AutoTokenizer
from transformers import GPT2LMHeadModel
from transformers import TFGPT2LMHeadModel

# Tokenizer is loaded from the parent directory; the EOS token is reused as the
# padding token (presumably because the tokenizer defines none — confirm).
tokenizer = AutoTokenizer.from_pretrained("../")
tokenizer.pad_token = tokenizer.eos_token

model_pt = GPT2LMHeadModel.from_pretrained("./pt")
# from_pretrained already returns the model in eval mode; made explicit so the
# comparison below is never affected by dropout.
model_pt.eval()
model_tf = TFGPT2LMHeadModel.from_pretrained("./pt", from_pt=True)
model_tf.save_pretrained("./tf")

# Dummy batch: 2 sequences, each 128 copies of token id 0.
input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
input_ids_pt = torch.tensor(input_ids)

# Inference only: no_grad avoids building an autograd graph and lets the
# resulting logits be converted to NumPy directly.
with torch.no_grad():
    logits_pt = model_pt(input_ids_pt).logits
print(logits_pt)

logits_tf = model_tf(input_ids).logits
print(logits_tf)

# Quantify how faithful the port is: both frameworks ran the same weights on
# the same input, so this should be at the level of float rounding.
max_abs_diff = np.max(np.abs(logits_pt.numpy() - logits_tf.numpy()))
print(f"max |logits_pt - logits_tf| = {max_abs_diff:.3e}")