File size: 793 Bytes
730de6f aa60c5c 730de6f d5cc6ba 730de6f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import argparse
from transformers import RobertaForMaskedLM
def convert_flax_model_to_torch(flax_model_path: str, torch_model_path: str = "./"):
    """Load a Flax RoBERTa masked-LM checkpoint and re-save it as PyTorch weights.

    Args:
        flax_model_path: Model id or local directory of the Flax checkpoint,
            passed straight to ``RobertaForMaskedLM.from_pretrained``.
        torch_model_path: Directory handed to ``save_pretrained``, where the
            converted PyTorch model is written. Defaults to the current
            directory.

    Returns:
        None. The result is the files written under ``torch_model_path``.
    """
    # from_flax=True makes transformers read Flax weights and load them into
    # the PyTorch RobertaForMaskedLM class.
    pytorch_model = RobertaForMaskedLM.from_pretrained(
        flax_model_path,
        from_flax=True,
    )
    pytorch_model.save_pretrained(torch_model_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Flax to Pytorch model coversion")
parser.add_argument(
"--flax_model_path", type=str, default="flax-community/roberta-hindi", help="Flax model path"
)
parser.add_argument("--torch_model_path", type=str, default="./", help="PyTorch model path")
args = parser.parse_args()
convert_flax_model_to_torch(args.flax_model_path, args.torch_model_path)
|