|
|
|
|
|
import argparse |
|
from collections import OrderedDict |
|
|
|
import megengine as mge |
|
import torch |
|
|
|
|
|
def make_parser():
    """Build the command-line parser for the weight-conversion script.

    Options:
        -w / --weights: path of the PyTorch weight file to read
            (no default, so the parsed value is ``None`` when omitted).
        -o / --output: path the converted weights are written to
            (defaults to ``weight_mge.pkl``).

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-w",
        "--weights",
        type=str,
        help="path of the input (PyTorch) weight file to convert",
    )
    parser.add_argument(
        "-o",
        "--output",
        default="weight_mge.pkl",
        type=str,
        # Previously this repeated the --weights help text verbatim, which made
        # the two options indistinguishable in ``--help`` output.
        help="path of the converted output weight file (default: %(default)s)",
    )
    return parser
|
|
|
|
|
def numpy_weights(weight_file):
    """Load a PyTorch weight file and convert every entry to a NumPy array.

    Args:
        weight_file: location of the weight file; passed straight to
            ``torch.load``.

    Returns:
        OrderedDict: maps each key of the loaded mapping to the NumPy array
        of its tensor, in the mapping's iteration order.
    """
    # NOTE(security): torch.load is pickle-based; only load checkpoints that
    # come from a trusted source.
    torch_weights = torch.load(weight_file, map_location="cpu")
    # If the loaded mapping has a "model" entry, the weights are taken from it.
    if "model" in torch_weights:
        torch_weights = torch_weights["model"]
    # detach() first: calling .numpy() on a tensor that requires grad (for
    # example an nn.Parameter stored directly in the file) raises RuntimeError.
    return OrderedDict(
        (key, tensor.detach().cpu().numpy())
        for key, tensor in torch_weights.items()
    )
|
|
|
|
|
def map_weights(weight_file, output_file):
    """Convert the weights in ``weight_file`` and save them with MegEngine.

    The weights are read through ``numpy_weights`` and then processed key by
    key:

    * keys containing ``"num_batches_tracked"`` are dropped;
    * keys ending in ``"bias"`` are reshaped to ``(1, -1, 1, 1)``;
    * keys containing both ``"dconv"`` and ``"conv.weight"`` (reported as
      depthwise conv weights) are reshaped from their 4-D shape
      ``(cout, cin, k1, k2)`` to ``(cout, 1, cin, k1, k2)``;
    * everything else is copied unchanged.

    A line is printed for each dropped or reshaped key, and once more after
    the result has been written with ``mge.save``.

    Args:
        weight_file: location of the source weight file.
        output_file: destination handed to ``mge.save``.
    """
    converted = OrderedDict()
    for name, array in numpy_weights(weight_file).items():
        if "num_batches_tracked" in name:
            print("drop: {}".format(name))
            continue

        is_depthwise = "dconv" in name and "conv.weight" in name
        if name.endswith("bias"):
            print("bias key: {}".format(name))
            array = array.reshape(1, -1, 1, 1)
        elif is_depthwise:
            print("depthwise conv key: {}".format(name))
            # Insert a singleton axis after the first dimension.
            # NOTE(review): presumably MegEngine's grouped-conv weight layout;
            # confirm against the target model definition.
            out_ch, in_ch, kernel_h, kernel_w = array.shape
            array = array.reshape(out_ch, 1, in_ch, kernel_h, kernel_w)

        # Single store shared by the reshaped and the untouched cases.
        converted[name] = array

    mge.save(converted, output_file)
    print("save weights to {}".format(output_file))
|
|
|
|
|
def main():
    """Parse the command-line arguments and run the weight conversion.

    Exits through ``parser.error`` (usage message, exit status 2) when no
    input weight file was given on the command line.
    """
    parser = make_parser()
    args = parser.parse_args()
    # --weights has no default, so it is None when omitted. Report that here
    # with a usage message instead of handing None to map_weights as the path.
    if args.weights is None:
        parser.error("the following arguments are required: -w/--weights")
    map_weights(args.weights, args.output)


if __name__ == "__main__":
    main()
|
|