"""Convert a Flax RoBERTa masked-LM checkpoint into PyTorch weights."""

import argparse
from typing import Optional, Sequence

from transformers import RobertaForMaskedLM


def convert_flax_model_to_torch(flax_model_path: str, torch_model_path: str = "./") -> None:
    """Convert Flax model weights to PyTorch weights and save them.

    Args:
        flax_model_path: Local path or model identifier of the Flax checkpoint
            to load into ``RobertaForMaskedLM``.
        torch_model_path: Directory passed to ``save_pretrained``; the
            converted PyTorch model is written there. Defaults to the current
            working directory.
    """
    # from_flax=True makes transformers read Flax weights and load them into
    # the PyTorch RobertaForMaskedLM class.
    model = RobertaForMaskedLM.from_pretrained(flax_model_path, from_flax=True)
    model.save_pretrained(torch_model_path)


def _parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse command-line options; ``argv=None`` reads ``sys.argv[1:]``."""
    parser = argparse.ArgumentParser(description="Flax to PyTorch model conversion")
    parser.add_argument(
        "--flax_model_path",
        type=str,
        default="flax-community/roberta-base-mr",
        help="Flax model path",
    )
    parser.add_argument(
        "--torch_model_path",
        type=str,
        default="./",
        help="PyTorch model path",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = _parse_args()
    convert_flax_model_to_torch(args.flax_model_path, args.torch_model_path)