|
import argparse |
|
import hashlib |
|
import os |
|
|
|
import mxnet as mx |
|
import gluoncv |
|
import torch |
|
from timm import create_model |
|
|
|
# Command-line interface: one optional --model flag naming the MXNet/GluonCV
# model to convert. The default value 'all' makes main() convert every entry
# of the ALL list.
parser = argparse.ArgumentParser(description='Convert from MXNet')
parser.add_argument(
    '--model', default='all', type=str, metavar='MODEL',
    help='Name of model to convert (default: "all")')
|
|
|
|
|
def convert(mxnet_name, torch_name):
    """Copy pretrained GluonCV (MXNet) weights into a timm model and save them.

    The MXNet parameters and the PyTorch parameters/buffers are paired purely
    by position, so both networks must enumerate their tensors in the same
    order. Each pair is sanity-checked (count, kind and shape) before the
    MXNet values are copied into a state dict, which is loaded into the
    PyTorch model and written to ``./<torch_name>-<sha256[:8]>.pth``.

    Args:
        mxnet_name: Model name passed to ``gluoncv.model_zoo.get_model``.
        torch_name: Model name passed to timm's ``create_model``; also used
            as the stem of the output file name.

    Raises:
        ValueError: If the two networks expose a different number of
            parameters or buffers, or if a positional pair disagrees in
            kind (gamma/weight, beta/bias, running_mean/running_var) or
            in shape.
    """
    # Source network with pretrained weights, and the target network whose
    # state dict layout we need to fill in.
    net = gluoncv.model_zoo.get_model(mxnet_name, pretrained=True)
    torch_net = create_model(torch_name)

    mx_all = list(net.collect_params().items())
    torch_params = {}

    # --- learnable parameters -------------------------------------------
    # MXNet keeps BatchNorm running statistics among its params; they are
    # handled separately below to line up with PyTorch buffers.
    mxp = [(k, v) for k, v in mx_all if 'running' not in k]
    torchp = list(torch_net.named_parameters())
    # zip() would silently drop the surplus tensors, so fail loudly instead.
    if len(mxp) != len(torchp):
        raise ValueError(
            'Parameter count mismatch: {} has {} but {} has {}'.format(
                mxnet_name, len(mxp), torch_name, len(torchp)))

    for (tn, tv), (mn, mv) in zip(torchp, mxp):
        m_split = mn.split('_')
        t_split = tn.split('.')
        print(t_split, m_split)
        print(tv.shape, mv.shape)

        # MXNet norm-layer gamma/beta must pair with PyTorch weight/bias.
        if m_split[-1] == 'gamma' and t_split[-1] != 'weight':
            raise ValueError(
                "MXNet '{}' (gamma) paired with non-weight '{}'".format(mn, tn))
        if m_split[-1] == 'beta' and t_split[-1] != 'bias':
            raise ValueError(
                "MXNet '{}' (beta) paired with non-bias '{}'".format(mn, tn))

        # Compare whole shapes so that a rank difference is detected too.
        if tuple(tv.shape) != tuple(mv.shape):
            raise ValueError(
                "Shape mismatch for '{}' {} vs '{}' {}".format(
                    tn, tuple(tv.shape), mn, tuple(mv.shape)))

        torch_params[tn] = torch.from_numpy(mv.data().asnumpy())

    # --- running statistics (buffers) ----------------------------------
    mxb = [(k, v) for k, v in mx_all
           if any(x in k for x in ['running_mean', 'running_var'])]
    # Buffers whose name contains 'num_batches' have no MXNet counterpart
    # in this pairing and are left out of the converted dict.
    torchb = [(k, v) for k, v in torch_net.named_buffers()
              if 'num_batches' not in k]
    if len(mxb) != len(torchb):
        raise ValueError(
            'Buffer count mismatch: {} has {} but {} has {}'.format(
                mxnet_name, len(mxb), torch_name, len(torchb)))

    for (tn, tv), (mn, mv) in zip(torchb, mxb):
        print(tn, mn)
        print(tv.shape, mv.shape)

        if 'running_var' in tn and 'running_var' not in mn:
            raise ValueError(
                "Buffer '{}' paired with MXNet '{}'".format(tn, mn))
        if 'running_mean' in tn and 'running_mean' not in mn:
            raise ValueError(
                "Buffer '{}' paired with MXNet '{}'".format(tn, mn))

        torch_params[tn] = torch.from_numpy(mv.data().asnumpy())

    # Load into the model first so load_state_dict can validate the dict,
    # then save the model's own state dict.
    torch_net.load_state_dict(torch_params)
    torch_filename = './%s.pth' % torch_name
    torch.save(torch_net.state_dict(), torch_filename)

    # Embed the first 8 hex digits of the file's SHA256 in its final name.
    with open(torch_filename, 'rb') as f:
        sha_hash = hashlib.sha256(f.read()).hexdigest()
    final_filename = os.path.splitext(torch_filename)[0] + '-' + sha_hash[:8] + '.pth'
    os.rename(torch_filename, final_filename)
    print("=> Saved converted model to '{}', SHA256: {}".format(final_filename, sha_hash))
|
|
|
|
|
def map_mx_to_torch_model(mx_name):
    """Translate a GluonCV model name into the matching PyTorch model name.

    The name is lower-cased, at most one prefix rewrite is applied (the
    first rule whose prefix matches), and the result is prefixed with
    ``'gluon_'``.

    Args:
        mx_name: Model name as used on the MXNet/GluonCV side.

    Returns:
        The derived PyTorch-side model name string.
    """
    lowered = mx_name.lower()
    # (prefix to detect, text substituted for it); order matters because
    # only the first matching rule is applied.
    rewrites = (
        ('se_', 'se'),
        ('senet_', 'senet'),
        ('inceptionv3', 'inception_v3'),
    )
    for prefix, replacement in rewrites:
        if lowered.startswith(prefix):
            # str.replace substitutes every occurrence, not only the prefix.
            lowered = lowered.replace(prefix, replacement)
            break
    return 'gluon_' + lowered
|
|
|
|
|
# MXNet/GluonCV model names converted when --model is left at 'all'.
ALL = [
    # ResNet v1b
    'resnet18_v1b', 'resnet34_v1b', 'resnet50_v1b', 'resnet101_v1b', 'resnet152_v1b',
    # ResNet v1c
    'resnet50_v1c', 'resnet101_v1c', 'resnet152_v1c',
    # ResNet v1d
    'resnet50_v1d', 'resnet101_v1d', 'resnet152_v1d',
    # ResNet v1s
    'resnet50_v1s', 'resnet101_v1s', 'resnet152_v1s',
    # ResNeXt
    'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d',
    # SE-ResNeXt
    'se_resnext50_32x4d', 'se_resnext101_32x4d', 'se_resnext101_64x4d',
    # SENet and Inception
    'senet_154', 'inceptionv3',
]
|
|
|
|
|
def main():
    """Parse the command line and convert the requested model(s).

    An empty ``--model`` value or the literal ``'all'`` converts every name
    in ``ALL``; any other value is converted on its own.
    """
    args = parser.parse_args()
    requested = args.model

    convert_everything = not requested or requested == 'all'
    targets = ALL if convert_everything else [requested]

    for mx_model in targets:
        # Derive the PyTorch-side name before handing both to convert().
        torch_model = map_mx_to_torch_model(mx_model)
        convert(mx_model, torch_model)
|
|
|
|
|
# Run the conversion only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|