Updated per new model code
Browse files - pt_resnet_to_flax.py +20 -24
pt_resnet_to_flax.py
CHANGED
@@ -1,23 +1,23 @@
|
|
1 |
-
from transformers import ResNetConfig, FlaxResNetForImageClassification, ResNetForImageClassification
|
2 |
from flax.traverse_util import flatten_dict, unflatten_dict
|
3 |
from flax.core.frozen_dict import unfreeze
|
4 |
import re
|
|
|
|
|
5 |
|
6 |
-
pt_resnet = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
|
7 |
-
pt_state = pt_resnet.state_dict()
|
8 |
|
|
|
9 |
flax_resnet = FlaxResNetForImageClassification(pt_resnet.config)
|
|
|
|
|
10 |
flax_state = flatten_dict(unfreeze(flax_resnet.params))
|
11 |
-
|
|
|
12 |
new_pt_state = {}
|
13 |
-
pt_batch_stats = {}
|
14 |
for key, tensor in pt_state.items():
|
15 |
key_parts = set(key.split("."))
|
16 |
tensor = tensor.numpy()
|
17 |
|
18 |
-
key = re.sub(r"(?<=[a-zA-Z]).(?=\d)", "_", key)
|
19 |
-
|
20 |
-
|
21 |
if "convolution.weight" in key:
|
22 |
key = key.replace("weight", "kernel")
|
23 |
tensor = tensor.transpose((2, 3, 1, 0))
|
@@ -34,36 +34,32 @@ for key, tensor in pt_state.items():
|
|
34 |
key = "params."+key
|
35 |
new_pt_state[key] = tensor
|
36 |
|
37 |
-
elif "
|
38 |
-
key = "params.classifier.kernel"
|
39 |
new_pt_state[key] = tensor.transpose()
|
40 |
|
41 |
-
elif "
|
42 |
-
key = "params.classifier.bias"
|
43 |
new_pt_state[key] = tensor
|
44 |
|
45 |
elif "normalization.running_mean" in key:
|
46 |
key = key.replace("running_mean", "mean")
|
47 |
-
|
|
|
48 |
|
49 |
elif "normalization.running_var" in key:
|
50 |
key = key.replace("running_var", "var")
|
51 |
-
|
|
|
52 |
|
53 |
else:
|
54 |
-
|
55 |
|
|
|
56 |
for total_updated, (new_key, new_tensor) in enumerate(new_pt_state.items()):
|
57 |
orig_flax_tensor = flax_state.get(tuple(new_key.split(".")))
|
58 |
assert orig_flax_tensor is not None
|
59 |
-
|
60 |
-
if not("classifier" in new_key):
|
61 |
-
assert orig_flax_tensor.shape == new_tensor.shape
|
62 |
flax_state[tuple(new_key.split("."))] = new_tensor
|
63 |
-
|
64 |
flax_state = unflatten_dict(flax_state)
|
65 |
-
|
66 |
-
pt_batch_stats = unflatten_dict({tuple(k.split(".")):v for k,v in pt_batch_stats.items()})
|
67 |
-
flax_state["batch_stats"] = pt_batch_stats
|
68 |
-
|
69 |
-
flax_resnet.save_pretrained("resnet_50_flax", params=flax_state)
|
|
|
1 |
+
from transformers import ResNetConfig, FlaxResNetForImageClassification, ResNetForImageClassification, FlaxResNetModel, ResNetModel
|
2 |
from flax.traverse_util import flatten_dict, unflatten_dict
|
3 |
from flax.core.frozen_dict import unfreeze
|
4 |
import re
|
5 |
+
import jax.numpy as jnp
|
6 |
+
import torch
|
7 |
|
|
|
|
|
8 |
|
9 |
+
pt_resnet = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
|
10 |
flax_resnet = FlaxResNetForImageClassification(pt_resnet.config)
|
11 |
+
|
12 |
+
pt_state = pt_resnet.state_dict()
|
13 |
flax_state = flatten_dict(unfreeze(flax_resnet.params))
|
14 |
+
|
15 |
+
|
16 |
new_pt_state = {}
|
|
|
17 |
for key, tensor in pt_state.items():
|
18 |
key_parts = set(key.split("."))
|
19 |
tensor = tensor.numpy()
|
20 |
|
|
|
|
|
|
|
21 |
if "convolution.weight" in key:
|
22 |
key = key.replace("weight", "kernel")
|
23 |
tensor = tensor.transpose((2, 3, 1, 0))
|
|
|
34 |
key = "params."+key
|
35 |
new_pt_state[key] = tensor
|
36 |
|
37 |
+
elif "classifier.1.weight" in key:
|
38 |
+
key = "params.classifier.1.kernel"
|
39 |
new_pt_state[key] = tensor.transpose()
|
40 |
|
41 |
+
elif "classifier.1.bias" in key:
|
42 |
+
key = "params.classifier.1.bias"
|
43 |
new_pt_state[key] = tensor
|
44 |
|
45 |
elif "normalization.running_mean" in key:
|
46 |
key = key.replace("running_mean", "mean")
|
47 |
+
key = "batch_stats."+key
|
48 |
+
new_pt_state[key] = tensor
|
49 |
|
50 |
elif "normalization.running_var" in key:
|
51 |
key = key.replace("running_var", "var")
|
52 |
+
key = "batch_stats."+key
|
53 |
+
new_pt_state[key] = tensor
|
54 |
|
55 |
else:
|
56 |
+
continue
|
57 |
|
58 |
+
|
59 |
for total_updated, (new_key, new_tensor) in enumerate(new_pt_state.items()):
|
60 |
orig_flax_tensor = flax_state.get(tuple(new_key.split(".")))
|
61 |
assert orig_flax_tensor is not None
|
62 |
+
assert orig_flax_tensor.shape == new_tensor.shape
|
|
|
|
|
63 |
flax_state[tuple(new_key.split("."))] = new_tensor
|
|
|
64 |
flax_state = unflatten_dict(flax_state)
|
65 |
+
flax_resnet.save_pretrained("resnet_50_flax", params=flax_state)
|
|
|
|
|
|
|
|