Shubhamai commited on
Commit
9ed3b66
1 Parent(s): d21b292

updated per new model code

Browse files
Files changed (1) hide show
  1. pt_resnet_to_flax.py +20 -24
pt_resnet_to_flax.py CHANGED
@@ -1,23 +1,23 @@
1
- from transformers import ResNetConfig, FlaxResNetForImageClassification, ResNetForImageClassification
2
  from flax.traverse_util import flatten_dict, unflatten_dict
3
  from flax.core.frozen_dict import unfreeze
4
  import re
 
 
5
 
6
- pt_resnet = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
7
- pt_state = pt_resnet.state_dict()
8
 
 
9
  flax_resnet = FlaxResNetForImageClassification(pt_resnet.config)
 
 
10
  flax_state = flatten_dict(unfreeze(flax_resnet.params))
11
-
 
12
  new_pt_state = {}
13
- pt_batch_stats = {}
14
  for key, tensor in pt_state.items():
15
  key_parts = set(key.split("."))
16
  tensor = tensor.numpy()
17
 
18
- key = re.sub(r"(?<=[a-zA-Z]).(?=\d)", "_", key)
19
-
20
-
21
  if "convolution.weight" in key:
22
  key = key.replace("weight", "kernel")
23
  tensor = tensor.transpose((2, 3, 1, 0))
@@ -34,36 +34,32 @@ for key, tensor in pt_state.items():
34
  key = "params."+key
35
  new_pt_state[key] = tensor
36
 
37
- elif "classifier_1.weight" in key:
38
- key = "params.classifier.kernel"
39
  new_pt_state[key] = tensor.transpose()
40
 
41
- elif "classifier_1.bias" in key:
42
- key = "params.classifier.bias"
43
  new_pt_state[key] = tensor
44
 
45
  elif "normalization.running_mean" in key:
46
  key = key.replace("running_mean", "mean")
47
- pt_batch_stats[key] = tensor
 
48
 
49
  elif "normalization.running_var" in key:
50
  key = key.replace("running_var", "var")
51
- pt_batch_stats[key] = tensor
 
52
 
53
  else:
54
- pass
55
 
 
56
  for total_updated, (new_key, new_tensor) in enumerate(new_pt_state.items()):
57
  orig_flax_tensor = flax_state.get(tuple(new_key.split(".")))
58
  assert orig_flax_tensor is not None
59
-
60
- if not("classifier" in new_key):
61
- assert orig_flax_tensor.shape == new_tensor.shape
62
  flax_state[tuple(new_key.split("."))] = new_tensor
63
-
64
  flax_state = unflatten_dict(flax_state)
65
-
66
- pt_batch_stats = unflatten_dict({tuple(k.split(".")):v for k,v in pt_batch_stats.items()})
67
- flax_state["batch_stats"] = pt_batch_stats
68
-
69
- flax_resnet.save_pretrained("resnet_50_flax", params=flax_state)
 
1
+ from transformers import ResNetConfig, FlaxResNetForImageClassification, ResNetForImageClassification, FlaxResNetModel, ResNetModel
2
  from flax.traverse_util import flatten_dict, unflatten_dict
3
  from flax.core.frozen_dict import unfreeze
4
  import re
5
+ import jax.numpy as jnp
6
+ import torch
7
 
 
 
8
 
9
+ pt_resnet = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
10
  flax_resnet = FlaxResNetForImageClassification(pt_resnet.config)
11
+
12
+ pt_state = pt_resnet.state_dict()
13
  flax_state = flatten_dict(unfreeze(flax_resnet.params))
14
+
15
+
16
  new_pt_state = {}
 
17
  for key, tensor in pt_state.items():
18
  key_parts = set(key.split("."))
19
  tensor = tensor.numpy()
20
 
 
 
 
21
  if "convolution.weight" in key:
22
  key = key.replace("weight", "kernel")
23
  tensor = tensor.transpose((2, 3, 1, 0))
 
34
  key = "params."+key
35
  new_pt_state[key] = tensor
36
 
37
+ elif "classifier.1.weight" in key:
38
+ key = "params.classifier.1.kernel"
39
  new_pt_state[key] = tensor.transpose()
40
 
41
+ elif "classifier.1.bias" in key:
42
+ key = "params.classifier.1.bias"
43
  new_pt_state[key] = tensor
44
 
45
  elif "normalization.running_mean" in key:
46
  key = key.replace("running_mean", "mean")
47
+ key = "batch_stats."+key
48
+ new_pt_state[key] = tensor
49
 
50
  elif "normalization.running_var" in key:
51
  key = key.replace("running_var", "var")
52
+ key = "batch_stats."+key
53
+ new_pt_state[key] = tensor
54
 
55
  else:
56
+ continue
57
 
58
+
59
  for total_updated, (new_key, new_tensor) in enumerate(new_pt_state.items()):
60
  orig_flax_tensor = flax_state.get(tuple(new_key.split(".")))
61
  assert orig_flax_tensor is not None
62
+ assert orig_flax_tensor.shape == new_tensor.shape
 
 
63
  flax_state[tuple(new_key.split("."))] = new_tensor
 
64
  flax_state = unflatten_dict(flax_state)
65
+ flax_resnet.save_pretrained("resnet_50_flax", params=flax_state)