"""Convert the Flax checkpoint in MODEL_PATH to a PyTorch checkpoint.

Steps performed, in order:
1. Load the Flax causal-LM from MODEL_PATH.
2. Upcast any bfloat16 parameters to float32 and save the Flax model back.
3. Load the saved Flax weights into the PyTorch model class (from_flax=True).
4. Run both models on the same all-zeros dummy batch and print the logits
   so they can be compared by eye.
5. Save the PyTorch model to MODEL_PATH.
"""
from transformers import AutoModelForCausalLM, FlaxAutoModelForCausalLM, AutoTokenizer
import torch
import numpy as np
import jax
import jax.numpy as jnp


def to_f32(t):
    """Return a copy of pytree *t* with bfloat16 leaves cast to float32.

    Leaves of any other dtype are returned unchanged.
    """
    # jax.tree_util.tree_map is the stable spelling; the top-level
    # jax.tree_map alias is deprecated and absent from recent JAX releases.
    return jax.tree_util.tree_map(
        lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x,
        t,
    )


# Pin JAX to CPU before any model is loaded or any computation runs.
jax.config.update('jax_platform_name', 'cpu')

MODEL_PATH = "./"

# Load the Flax model, upcast bf16 weights, and write it back in place so the
# PyTorch loader below reads float32 parameters.
model = FlaxAutoModelForCausalLM.from_pretrained(MODEL_PATH)
model.params = to_f32(model.params)
model.save_pretrained(MODEL_PATH)

pt_model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH, from_flax=True).to('cpu')

# Dummy batch: 2 sequences of 128 tokens, every token id 0.
input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
input_ids_pt = torch.tensor(input_ids)

# Inference-only sanity check: no autograd graph is needed.
with torch.no_grad():
    logits_pt = pt_model(input_ids_pt).logits
print(logits_pt)

logits_fx = model(input_ids).logits
print(logits_fx)

pt_model.save_pretrained(MODEL_PATH)