import argparse

from transformers import RobertaForMaskedLM


def convert_flax_model_to_torch(flax_model_path: str, torch_model_path: str = "./"):
    """
    Converts Flax model weights to PyTorch weights.
    """
    model = RobertaForMaskedLM.from_pretrained(flax_model_path, from_flax=True)
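    # save_pretrained writes the PyTorch weights and config to torch_model_path.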
    model.save_pretrained(torch_model_path)


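# CLI entry point. Example invocation (script filename is illustrative):
#   python convert_flax_model_to_torch.py --flax_model_path flax-community/roberta-base-mr --torch_model_path ./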
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Flax to PyTorch model conversion")
    parser.add_argument(
        "--flax_model_path", type=str, default="flax-community/roberta-base-mr", help="Flax model path"
    )
    parser.add_argument("--torch_model_path", type=str, default="./", help="PyTorch model path")
    args = parser.parse_args()
    convert_flax_model_to_torch(args.flax_model_path, args.torch_model_path)