# gpt2-bengali / jax2tensor.py
# Author: khalidsaifullaah
# Commit 9b64d7f: "pytorch weights added"
# (Hugging Face Hub page metadata: raw / history / blame; scan: no virus; size: 555 bytes)
import torch
import numpy as np
import jax
import jax.numpy as jnp
from transformers import AutoTokenizer
from transformers import FlaxGPT2LMHeadModel
from transformers import GPT2LMHeadModel
def dummy_input_ids(batch_size=2, seq_len=128):
    """Return an all-zeros ``(batch_size, seq_len)`` int32 batch of token ids.

    Used only as a fixed probe input for comparing the two frameworks'
    outputs; the token id 0 carries no special meaning here.
    """
    return np.zeros((batch_size, seq_len), dtype=np.int32)


def max_abs_diff(logits_a, logits_b):
    """Return the largest element-wise absolute difference of two logit arrays.

    Both arguments may be anything ``np.asarray`` accepts (NumPy arrays,
    JAX arrays, nested lists). Comparison is done in float64.

    Raises:
        ValueError: if the two arrays do not have the same shape.
    """
    a = np.asarray(logits_a, dtype=np.float64)
    b = np.asarray(logits_b, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"logit shapes differ: {a.shape} vs {b.shape}")
    # np.max on an empty array raises, so treat "nothing to compare" as equal.
    if a.size == 0:
        return 0.0
    return float(np.max(np.abs(a - b)))


def convert_flax_to_pytorch(model_dir="./", batch_size=2, seq_len=128, atol=1e-2):
    """Convert the Flax GPT-2 checkpoint in *model_dir* to PyTorch weights.

    Loads the Flax model, builds a PyTorch model from the same Flax weights
    (``from_flax=True``), and runs both on the same dummy batch. It prints
    both logit tensors and their largest absolute difference. The PyTorch
    weights are written back to *model_dir* only when the two agree within
    *atol*.

    Args:
        model_dir: directory holding the Flax checkpoint; also where the
            PyTorch weights are saved.
        batch_size: rows in the dummy probe batch.
        seq_len: tokens per row in the dummy probe batch.
        atol: maximum tolerated absolute logit difference.

    Returns:
        The maximum absolute logit difference, as a float.

    Raises:
        ValueError: if the two models' logits have different shapes.
        RuntimeError: if the difference exceeds *atol*. Nothing is saved
            in that case.
    """
    model_fx = FlaxGPT2LMHeadModel.from_pretrained(model_dir)
    model_pt = GPT2LMHeadModel.from_pretrained(model_dir, from_flax=True)
    model_pt.eval()  # explicit: inference-only comparison (no dropout)

    input_ids = dummy_input_ids(batch_size, seq_len)

    # Inference only: skip building the autograd graph. This also lets the
    # output tensor be turned into a NumPy array for the comparison below.
    with torch.no_grad():
        logits_pt = model_pt(torch.tensor(input_ids)).logits
    logits_fx = model_fx(input_ids).logits

    print(logits_pt)
    print(logits_fx)

    diff = max_abs_diff(logits_pt.cpu().numpy(), logits_fx)
    print(f"max |logits_pt - logits_fx| = {diff:.3e}")
    if diff > atol:
        raise RuntimeError(
            f"Flax and PyTorch logits differ by {diff:.3e} (> atol={atol}); "
            "refusing to save unverified PyTorch weights"
        )

    # Save only after the conversion has been checked.
    model_pt.save_pretrained(model_dir)
    return diff


if __name__ == "__main__":
    convert_flax_to_pytorch()