# gpt2news / flax_to_pt.py
# Commit 1f656b0 by imthanhlv: "Saving weights and logs of step 37500"
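from transformers import AutoModelForCausalLM

# Load the Flax checkpoint in the current directory and convert its weights to PyTorch.
pt_model = AutoModelForCausalLM.from_pretrained(".", from_flax=True)
# Save the converted PyTorch weights (and config) to the same directory, next to the Flax checkpoint.
pt_model.save_pretrained(".")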