hassiahk commited on
Commit
06f30b0
1 Parent(s): c70f803

Add PyTorch model

Browse files
Files changed (2) hide show
  1. flax_to_torch.py +22 -0
  2. pytorch_model.bin +3 -0
flax_to_torch.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from transformers import RobertaForMaskedLM
4
+
5
+
6
+ def convert_flax_model_to_torch(flax_model_path: str, torch_model_path: str = "./"):
7
+ """
8
+ Converts Flax model weights to PyTorch weights.
9
+ """
10
+ model = RobertaForMaskedLM.from_pretrained(flax_model_path, from_flax=True)
11
+ model.save_pretrained(torch_model_path)
12
+
13
+
14
+ if __name__ == "__main__":
15
+ parser = argparse.ArgumentParser(
16
+ description="Flax to Pytorch model coversion")
17
+ parser.add_argument(
18
+ "--flax_model_path", type=str, default="flax-community/roberta-base-mr", help="Flax model path"
19
+ )
20
+ parser.add_argument("--torch_model_path", type=str, default="./", help="PyTorch model path")
21
+ args = parser.parse_args()
22
+ convert_flax_model_to_torch(args.flax_model_path, args.torch_model_path)
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b5417d369e470bb94982e578c90cc7d0ad1ec943577b8f92f9918fbe3a2ca2f
3
+ size 498872747