"""Convert a Flax causal-LM checkpoint to PyTorch weights, in place.

Run this from (or point it at) a directory holding a Flax checkpoint. It
loads the weights into a PyTorch ``AutoModelForCausalLM`` and saves the
PyTorch weights back into the same directory.
"""

from transformers import AutoModelForCausalLM


def convert_flax_to_pytorch(model_dir: str = ".") -> None:
    """Load the Flax checkpoint in *model_dir* and save PyTorch weights there.

    Args:
        model_dir: Directory containing the Flax checkpoint and config.
            Defaults to the current working directory, matching the
            original script's behaviour.

    Note:
        ``from_flax=True`` makes ``from_pretrained`` read Flax weights and
        convert them into the PyTorch model. This presumably requires both
        ``torch`` and ``flax`` to be installed (TODO: confirm for the
        pinned transformers version).
    """
    pt_model = AutoModelForCausalLM.from_pretrained(model_dir, from_flax=True)
    # Write the converted PyTorch weights next to the Flax ones.
    pt_model.save_pretrained(model_dir)


if __name__ == "__main__":
    # Guard so that importing this module does not trigger the conversion.
    convert_flax_to_pytorch()