#!/usr/bin/env python | |
import tempfile | |
import jax | |
from jax import numpy as jnp | |
from transformers import AutoTokenizer, FlaxRobertaForMaskedLM, RobertaForMaskedLM, FlaxBertForMaskedLM, BertForMaskedLM | |
def main():
    """Convert a bfloat16 Flax BERT checkpoint in the current directory into
    a float32 PyTorch checkpoint, written next to the existing config.json
    and tokenizer files.
    """
    # Save the tokenizer files (tokenizer.json etc.) alongside the model config.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
    tokenizer.save_pretrained("./")

    # Round-trip the Flax model through a temporary directory so the bfloat16
    # weights are materialized as float32 before the PyTorch conversion.
    # TemporaryDirectory (unlike the original mkdtemp) guarantees the temp
    # files are cleaned up even if a step raises.
    with tempfile.TemporaryDirectory() as tmp:
        flax_model = FlaxBertForMaskedLM.from_pretrained("./")
        flax_model.save_pretrained(tmp)

        # Convert the float32 Flax weights to PyTorch.
        # save_config=False keeps the existing config.json untouched.
        model = BertForMaskedLM.from_pretrained(tmp, from_flax=True)
        model.save_pretrained("./", save_config=False)
# Script entry point: run the conversion only when executed directly.
if __name__ == "__main__":
    main()