#!/usr/bin/env python
"""Convert a (possibly bfloat16) Flax T5 checkpoint to PyTorch.

The script reads the Flax checkpoint and tokenizer from a model directory
(``./`` by default). It upcasts bfloat16 parameters to float32 and loads them
into ``T5ForConditionalGeneration`` via ``from_flax=True``. It then writes the
PyTorch weights back into the same directory, leaving the existing config
untouched. Finally it runs both models on the same dummy batch and reports how
far their logits diverge.

NOTE: an earlier revision of this script also exported a TensorFlow model
(``TFT5ForConditionalGeneration.from_pretrained(".", from_pt=True)``); that
step is intentionally not performed here.
"""
import tempfile

import jax
import numpy as np
import torch
from jax import numpy as jnp
from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration, T5ForConditionalGeneration


def to_f32(t):
    """Return a copy of the pytree *t* with every bfloat16 leaf cast to float32.

    Leaves of any other dtype are returned unchanged. ``jax.tree_util.tree_map``
    is used because the ``jax.tree_map`` alias is deprecated and has been
    removed from newer JAX releases.
    """
    return jax.tree_util.tree_map(
        lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t
    )


def main(model_dir="./"):
    """Convert the Flax T5 checkpoint in *model_dir* to PyTorch and compare outputs.

    Args:
        model_dir: directory holding the Flax weights, ``config.json`` and the
            tokenizer files. The re-saved tokenizer files and the PyTorch
            weights are written back into this same directory.
    """
    # Saving extra files from config.json and tokenizer.json files.
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    tokenizer.save_pretrained(model_dir)

    flax_model = FlaxT5ForConditionalGeneration.from_pretrained(model_dir)
    # Upcast bfloat16 params so the PyTorch model is created in float32.
    flax_model.params = to_f32(flax_model.params)

    # Stage the float32 Flax checkpoint in a temporary directory that is
    # cleaned up afterwards (tempfile.mkdtemp() leaked it).
    with tempfile.TemporaryDirectory() as tmp:
        flax_model.save_pretrained(tmp)
        pt_model = T5ForConditionalGeneration.from_pretrained(tmp, from_flax=True)
    # The config already lives in model_dir; only write the weights.
    pt_model.save_pretrained(model_dir, save_config=False)

    # Sanity check on a dummy batch. T5 is an encoder-decoder model, so both
    # the PyTorch and the Flax forward pass raise unless decoder_input_ids
    # are supplied; feed a single decoder start token per row.
    input_ids = np.zeros((2, 128), dtype=np.int32)
    config = pt_model.config
    start_id = (
        config.decoder_start_token_id
        if config.decoder_start_token_id is not None
        else config.pad_token_id
    )
    decoder_input_ids = np.full((input_ids.shape[0], 1), start_id, dtype=np.int32)

    pt_model.eval()
    with torch.no_grad():  # inference only; no autograd graph needed
        logits_pt = pt_model(
            input_ids=torch.from_numpy(input_ids).long(),
            decoder_input_ids=torch.from_numpy(decoder_input_ids).long(),
        ).logits
    print(logits_pt)

    logits_fx = flax_model(input_ids, decoder_input_ids=decoder_input_ids).logits
    print(logits_fx)

    # Both models now hold float32 weights, so the logits should agree closely.
    max_diff = np.max(np.abs(logits_pt.numpy() - np.asarray(logits_fx)))
    print(f"max |logits_pt - logits_fx| = {max_diff:.3e}")


if __name__ == "__main__":
    main()