File size: 555 Bytes
9b64d7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
import numpy as np
import jax
import jax.numpy as jnp
from transformers import AutoTokenizer
from transformers import FlaxGPT2LMHeadModel
from transformers import GPT2LMHeadModel


# Load the Flax GPT-2 checkpoint in the current directory, convert it to a
# PyTorch model (from_flax=True), write the PyTorch weights back to the same
# directory, then run both models on a dummy batch to compare their logits.
model_fx = FlaxGPT2LMHeadModel.from_pretrained("./")
model_pt = GPT2LMHeadModel.from_pretrained("./", from_flax=True)

model_pt.save_pretrained("./")


# Dummy batch: 2 sequences, each 128 copies of token id 0.
input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
# PyTorch embedding lookups conventionally take int64 indices; convert
# explicitly instead of inheriting NumPy's int32.
input_ids_pt = torch.tensor(input_ids, dtype=torch.long)

# Inference only: no_grad avoids building an autograd graph for the forward.
with torch.no_grad():
    logits_pt = model_pt(input_ids_pt).logits
print(logits_pt)
logits_fx = model_fx(input_ids).logits
print(logits_fx)

# Printing full logit tensors is not a usable equivalence check; report the
# largest element-wise deviation between the two frameworks' outputs.
max_abs_diff = np.max(np.abs(logits_pt.numpy() - np.asarray(logits_fx)))
print(f"max |logits_pt - logits_fx| = {max_abs_diff:.3e}")