# File size: 647 bytes; revision c5a9149; 24 lines.
import torch
import numpy as np
import jax
import jax.numpy as jnp
from transformers import AutoTokenizer
from transformers import GPT2LMHeadModel
from transformers import TFGPT2LMHeadModel
# Convert the PyTorch GPT-2 checkpoint in ./pt to TensorFlow, save the result
# to ./tf, then run both models on the same dummy batch and compare logits.

# NOTE(review): the tokenizer is not used below; presumably it is loaded to
# verify that "../" holds a valid tokenizer. Confirm, or remove it.
# pad_token is set to EOS, presumably because the tokenizer defines no pad token.
tokenizer = AutoTokenizer.from_pretrained("../")
tokenizer.pad_token = tokenizer.eos_token

model_pt = GPT2LMHeadModel.from_pretrained("./pt")
# from_pt=True loads the PyTorch weights from ./pt into the TF model class.
model_tf = TFGPT2LMHeadModel.from_pretrained("./pt", from_pt=True)
model_tf.save_pretrained("./tf")

# Dummy batch: 2 sequences of 128 copies of token id 0.
input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
input_ids_pt = torch.tensor(input_ids)

# Inference only: disable autograd so no graph is built for this forward pass.
with torch.no_grad():
    logits_pt = model_pt(input_ids_pt).logits
print(logits_pt)

logits_tf = model_tf(input_ids).logits
print(logits_tf)

# Quantify how closely the converted TF model matches the PyTorch original.
max_abs_diff = np.max(np.abs(logits_pt.numpy() - logits_tf.numpy()))
print(f"max |logits_pt - logits_tf| = {max_abs_diff:.3e}")