"""Port weights from a Keras ``DeepConvAE`` HDF5 checkpoint into the PyTorch
``DeepConvAE`` defined in the local ``model`` module, then save the PyTorch
model with ``torch.save``.

Run as a script: ``python <this file>``.
"""
import numpy as np
import torch
import h5py
from model import *

# Input geometry passed to DeepConvAE (28x28, single channel).
w, h, c = 28, 28, 1

# Keras checkpoint to read. An alternative location used previously is
# "mnist_deepconvae/model.h5".
SRC_PATH = "/home/mehdi/work/code/out_of_class/ae/mnist/model.h5"
# Where the converted PyTorch model is written.
DST_PATH = "mnist_deepconvae/model.th"


def _keras_layer(param_name):
    """Map a PyTorch parameter name to the Keras layer holding its weights.

    ``param_name`` must have exactly three dot-separated parts
    (``"<part>.<index>.<weight|bias>"``), as yielded by ``named_parameters()``.
    Parameters whose first part is ``"encode"`` map to Keras ``conv2d_<n>``
    layers; everything else maps to ``up_conv2d_<n>``, with
    ``n = index // 2 + 1``.

    Returns:
        ``(layer_name, keras_index, full_layer_name, key)``, where ``key`` is
        ``"kernel"`` for ``weight`` parameters and ``"bias"`` otherwise.
    """
    enc_or_decode, module_index, bias_or_kernel = param_name.split(".")
    layer_name = "conv2d" if enc_or_decode == "encode" else "up_conv2d"
    # Parameter-holding modules sit at every other index (presumably the
    # modules in between are activations - TODO confirm against model.py),
    # and Keras numbers its layers from 1.
    keras_index = int(module_index) // 2 + 1
    key = "kernel" if bias_or_kernel == "weight" else "bias"
    return layer_name, keras_index, f"{layer_name}_{keras_index}", key


def _to_torch_layout(weights, layer_name, key):
    """Rearrange a Keras weight tensor into the PyTorch parameter layout.

    Biases are returned unchanged. Encoder (``conv2d``) kernels are permuted
    with ``(3, 2, 0, 1)`` and then flipped along both spatial axes. Decoder
    (``up_conv2d``) kernels are only permuted with ``(2, 3, 0, 1)``.
    """
    if key != "kernel":
        return weights
    if layer_name == "conv2d":
        # Assuming a standard Keras Conv2D kernel (kh, kw, in, out), this
        # yields PyTorch Conv2d's (out, in, kh, kw).
        weights = weights.permute(3, 2, 0, 1)
        # Flip both spatial axes. torch.flip works for any kernel size; the
        # previous hard-coded index list [4, 3, 2, 1, 0] only handled 5x5.
        # NOTE(review): a flip is needed only if the checkpoint came from a
        # backend using true convolution (e.g. Theano); decoder kernels are
        # not flipped - confirm both choices against the original training.
        weights = torch.flip(weights, dims=(2, 3))
        print("W", weights.shape)
        return weights
    # up_conv2d kernels: permute only.
    return weights.permute(2, 3, 0, 1)


def main(src_path=SRC_PATH, dst_path=DST_PATH):
    """Build the PyTorch model, copy every parameter from ``src_path`` and
    save the whole model to ``dst_path``.

    Raises:
        ValueError: if a converted Keras tensor does not have exactly the
            shape of the PyTorch parameter it is copied into.
    """
    model_new = DeepConvAE(
        w=w,
        h=h,
        c=c,
        nb_filters=128,
        spatial=True,
        channel=True,
        channel_stride=4,
        # total layers = nb_layers*2: nb_layers for the encoder and
        # nb_layers for the decoder
        nb_layers=3,
    )
    print(model_new)

    # Open read-only (older h5py versions default to append mode, which can
    # modify or lock the checkpoint) and release the handle when done.
    with h5py.File(src_path, "r") as model_old:
        keras_weights = model_old["model_weights"]
        print(keras_weights.keys())
        for name, param in model_new.named_parameters():
            layer_name, keras_index, full_layer_name, key = _keras_layer(name)
            print(full_layer_name)
            dataset = keras_weights[full_layer_name][full_layer_name][key]
            weights = torch.from_numpy(np.asarray(dataset[()]))
            print(name, keras_index, param.shape, weights.shape)

            weights = _to_torch_layout(weights, layer_name, key)
            print(param.shape, weights.shape)
            # copy_ broadcasts silently, so a layout mistake could otherwise
            # produce a wrong model without any error.
            if param.shape != weights.shape:
                raise ValueError(
                    f"shape mismatch for {name} <- {full_layer_name}/{key}: "
                    f"{tuple(param.shape)} vs {tuple(weights.shape)}"
                )
            with torch.no_grad():
                param.copy_(weights)
                # Max absolute error; a signed sum could cancel out to zero.
                max_abs_err = (param - weights).abs().max()
            print(max_abs_err)

    torch.save(model_new, dst_path)


if __name__ == "__main__":
    main()