|
from transformers import ResNetConfig, FlaxResNetForImageClassification, ResNetForImageClassification, FlaxResNetModel, ResNetModel |
|
from flax.traverse_util import flatten_dict, unflatten_dict |
|
from flax.core.frozen_dict import unfreeze |
|
import re |
|
import jax.numpy as jnp |
|
import torch |
|
|
|
|
|
# Source checkpoint: the pretrained PyTorch ResNet-50 image classifier.
pt_resnet = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
pt_state = pt_resnet.state_dict()

# Target model: a Flax classifier built from the very same config, so its
# variable tree mirrors the PyTorch module; its freshly initialised values are
# what the conversion below overwrites.
flax_resnet = FlaxResNetForImageClassification(config=pt_resnet.config)

# Flatten the (possibly frozen) Flax variable tree into {tuple_path: array} so
# individual entries can be looked up and replaced by path.
flax_state = flatten_dict(unfreeze(flax_resnet.params))
|
|
|
|
|
# Translate every PyTorch state-dict entry into the flattened Flax variable
# path it should overwrite.
#
# Keys of `new_pt_state` are dot-joined paths whose first component is the
# Flax collection: "params" for trainable weights, "batch_stats" for the
# BatchNorm running statistics. Values are NumPy arrays already laid out the
# way Flax expects.
#
# BatchNorm's `num_batches_tracked` counters have no Flax counterpart and are
# skipped. Any other unrecognised entry raises, so a real weight can never be
# dropped silently.
new_pt_state = {}

for key, tensor in pt_state.items():
    array = tensor.numpy()

    if "convolution.weight" in key:
        # PyTorch conv kernels are (out, in, kH, kW); Flax convolution
        # kernels are (kH, kW, in, out).
        new_pt_state["params." + key.replace("weight", "kernel")] = array.transpose((2, 3, 1, 0))
    elif "normalization.weight" in key:
        # BatchNorm's affine weight is named `scale` in Flax.
        new_pt_state["params." + key.replace("weight", "scale")] = array
    elif "normalization.bias" in key:
        # BatchNorm's bias keeps its name.
        new_pt_state["params." + key] = array
    elif "classifier.1.weight" in key:
        # PyTorch Linear stores (out, in); Flax Dense stores (in, out).
        new_pt_state["params.classifier.1.kernel"] = array.transpose()
    elif "classifier.1.bias" in key:
        new_pt_state["params.classifier.1.bias"] = array
    elif "normalization.running_mean" in key:
        # Running statistics live in the non-trainable "batch_stats" collection.
        new_pt_state["batch_stats." + key.replace("running_mean", "mean")] = array
    elif "normalization.running_var" in key:
        new_pt_state["batch_stats." + key.replace("running_var", "var")] = array
    elif key.endswith("num_batches_tracked"):
        # Step counter used only by PyTorch's BatchNorm; nothing to convert.
        continue
    else:
        raise KeyError(f"don't know how to convert PyTorch weight {key!r}")
|
|
|
|
|
# Overwrite the Flax variables with the converted PyTorch weights, then save.
#
# Explicit exceptions are used instead of `assert` so these checks still run
# under `python -O`, and every message names the offending variable.
updated_paths = set()

for new_key, new_tensor in new_pt_state.items():
    flax_path = tuple(new_key.split("."))
    orig_flax_tensor = flax_state.get(flax_path)

    if orig_flax_tensor is None:
        raise KeyError(f"converted weight {new_key!r} has no matching Flax variable")

    if orig_flax_tensor.shape != new_tensor.shape:
        raise ValueError(
            f"shape mismatch for {new_key!r}: Flax expects {orig_flax_tensor.shape}, "
            f"converted PyTorch weight has {new_tensor.shape}"
        )

    flax_state[flax_path] = new_tensor
    updated_paths.add(flax_path)

# Every Flax variable must have been overwritten; anything left over would be
# saved with its initial (untrained) value.
missing_paths = sorted(".".join(path) for path in flax_state.keys() - updated_paths)
if missing_paths:
    raise ValueError(f"Flax variables not covered by the PyTorch checkpoint: {missing_paths}")

flax_state = unflatten_dict(flax_state)

flax_resnet.save_pretrained("resnet_50_flax", params=flax_state)
|
|