"""Convert the PyTorch ``microsoft/resnet-50`` checkpoint to Flax weights.

The script loads the PyTorch image-classification model, builds a Flax model
from the same config, copies every convolution / normalization / classifier
tensor into the flattened Flax parameter tree (renaming keys and transposing
kernels where needed) and saves the result to ``resnet_50_flax``.
"""
import re
import warnings

from transformers import (
    ResNetConfig,
    FlaxResNetForImageClassification,
    ResNetForImageClassification,
    FlaxResNetModel,
    ResNetModel,
)
from flax.traverse_util import flatten_dict, unflatten_dict
from flax.core.frozen_dict import unfreeze
import jax.numpy as jnp
import torch


def _to_flax_entry(pt_key, array):
    """Translate one PyTorch state-dict entry into its Flax equivalent.

    Args:
        pt_key: Dotted key from the PyTorch ``state_dict``.
        array: The tensor for ``pt_key`` already converted to a NumPy array.

    Returns:
        A ``(flax_key, array)`` pair, where ``flax_key`` is the dotted path in
        the Flax tree prefixed with its collection (``params`` or
        ``batch_stats``), or ``None`` when the entry matches no conversion
        rule and is therefore skipped.
    """
    if "convolution.weight" in pt_key:
        # Reorder kernel axes (0, 1, 2, 3) -> (2, 3, 1, 0): PyTorch stores
        # conv weights as (out, in, h, w), Flax as (h, w, in, out).
        flax_key = "params." + pt_key.replace("weight", "kernel")
        return flax_key, array.transpose((2, 3, 1, 0))
    if "normalization.weight" in pt_key:
        return "params." + pt_key.replace("weight", "scale"), array
    if "normalization.bias" in pt_key:
        # The bias keeps its name; only the collection prefix is added.
        return "params." + pt_key, array
    if "classifier.1.weight" in pt_key:
        # The dense kernel is stored transposed relative to PyTorch.
        return "params.classifier.1.kernel", array.transpose()
    if "classifier.1.bias" in pt_key:
        return "params.classifier.1.bias", array
    if "normalization.running_mean" in pt_key:
        return "batch_stats." + pt_key.replace("running_mean", "mean"), array
    if "normalization.running_var" in pt_key:
        return "batch_stats." + pt_key.replace("running_var", "var"), array
    return None


pt_resnet = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
flax_resnet = FlaxResNetForImageClassification(pt_resnet.config)

pt_state = pt_resnet.state_dict()
flax_state = flatten_dict(unfreeze(flax_resnet.params))

# Renamed / transposed PyTorch tensors, keyed by dotted Flax path.
new_pt_state = {}
for pt_key, pt_tensor in pt_state.items():
    entry = _to_flax_entry(pt_key, pt_tensor.numpy())
    if entry is None:
        # No rule matched this key, so it is not copied.
        continue
    flax_key, flax_array = entry
    new_pt_state[flax_key] = flax_array

# Copy the converted tensors into the Flax tree, validating as we go.
# Explicit exceptions are used instead of ``assert`` so the checks still run
# under ``python -O``.
updated_paths = set()
for flax_key, flax_array in new_pt_state.items():
    path = tuple(flax_key.split("."))
    target = flax_state.get(path)
    if target is None:
        raise KeyError(f"Flax model has no parameter named {flax_key!r}")
    if target.shape != flax_array.shape:
        raise ValueError(
            f"Shape mismatch for {flax_key!r}: Flax expects {target.shape}, "
            f"converted PyTorch tensor has {flax_array.shape}"
        )
    flax_state[path] = flax_array
    updated_paths.add(path)

# Any Flax entry not overwritten above still holds its initial value; surface
# that instead of silently saving it.
untouched = sorted(".".join(path) for path in flax_state if path not in updated_paths)
if untouched:
    warnings.warn(
        f"{len(untouched)} Flax parameter(s) were not loaded from the PyTorch "
        f"checkpoint and keep their initial values: {untouched}"
    )

flax_state = unflatten_dict(flax_state)
flax_resnet.save_pretrained("resnet_50_flax", params=flax_state)