# roberta-base-mr / flax_to_torch.py
# Author: hassiahk
# Commit 06f30b0: "Add PyTorch model"
import argparse
from transformers import RobertaForMaskedLM
def convert_flax_model_to_torch(flax_model_path: str, torch_model_path: str = "./"):
    """Convert Flax model weights to PyTorch weights.

    Loads ``flax_model_path`` as a ``RobertaForMaskedLM`` from its Flax
    checkpoint and writes the resulting PyTorch model to ``torch_model_path``.

    Args:
        flax_model_path: Identifier or path handed to
            ``RobertaForMaskedLM.from_pretrained``.
        torch_model_path: Destination handed to ``save_pretrained``;
            defaults to the current directory (``"./"``).
    """
    # from_flax=True tells from_pretrained that the source weights are Flax,
    # so the loaded object is a PyTorch model ready to be saved as such.
    pytorch_model = RobertaForMaskedLM.from_pretrained(
        flax_model_path,
        from_flax=True,
    )
    pytorch_model.save_pretrained(torch_model_path)
if __name__ == "__main__":
    # Command-line entry point: parse the source/destination paths and run
    # the Flax -> PyTorch conversion.
    parser = argparse.ArgumentParser(
        description="Flax to PyTorch model conversion"
    )
    parser.add_argument(
        "--flax_model_path",
        type=str,
        default="flax-community/roberta-base-mr",
        help="Flax model path",
    )
    # Same default ("./") as convert_flax_model_to_torch's torch_model_path.
    parser.add_argument(
        "--torch_model_path",
        type=str,
        default="./",
        help="PyTorch model path",
    )
    args = parser.parse_args()
    convert_flax_model_to_torch(args.flax_model_path, args.torch_model_path)