# ae_gen/convert.py: copy weights from an h5 checkpoint into the PyTorch DeepConvAE and save it
import numpy as np
import torch
import h5py

from model import DeepConvAE
# MNIST input size: 28x28, single channel
w, h, c = 28, 28, 1
model_new = DeepConvAE(
    w=w, h=h, c=c,
    nb_filters=128,
    spatial=True,
    channel=True,
    channel_stride=4,
    # total layers = nb_layers*2: nb_layers for the encoder and nb_layers for the decoder
    nb_layers=3,
)
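
# Quick look at the parameter names the conversion loop below expects
# (something like "encode.0.weight" / "decode.2.bias"); the exact indices
# depend on how DeepConvAE lays out its encoder/decoder blocks, so this is
# just an inspection aid.
for name, param in model_new.named_parameters():
    print(name, tuple(param.shape))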
# Keras-style h5 checkpoint: layer weights live under the "model_weights" group
# model_old = h5py.File("mnist_deepconvae/model.h5")
model_old = h5py.File("/home/mehdi/work/code/out_of_class/ae/mnist/model.h5")
print(model_new)
print(model_old["model_weights"].keys())
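
# Optional inspection sketch: walk the h5 hierarchy to confirm the nesting the
# loop below indexes into (model_weights/<layer>/<layer>/{kernel,bias}).
model_old["model_weights"].visititems(
    lambda name, obj: print(name, getattr(obj, "shape", ""))
)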
for name, param in model_new.named_parameters():
    # PyTorch parameter names look like "encode.<idx>.weight" / "decode.<idx>.bias"
    enc_or_decode, layer_id, bias_or_kernel = name.split(".")
    if enc_or_decode == "encode":
        layer_name = "conv2d"
    else:
        layer_name = "up_conv2d"
    # map the module index within encode/decode to the 1-based Keras layer number
    layer_id = (int(layer_id) // 2) + 1
    full_layer_name = f"{layer_name}_{layer_id}"
    print(full_layer_name)
    k = "kernel" if bias_or_kernel == "weight" else "bias"
    # Keras h5 layout: model_weights/<layer>/<layer>/{kernel,bias}
    weights = model_old["model_weights"][full_layer_name][full_layer_name][k][()]
    weights = torch.from_numpy(np.array(weights))
    print(name, layer_id, param.shape, weights.shape)
    # index order used to flip both spatial dimensions of the (5x5) kernels
    inds = [4, 3, 2, 1, 0]
    if k == "kernel":
        if layer_name == "conv2d":
            # Keras conv kernels are (kh, kw, in, out); PyTorch expects (out, in, kh, kw),
            # with the spatial axes flipped on top of the reordering
            weights = weights.permute((3, 2, 0, 1))
            weights = weights[:, :, inds]
            weights = weights[:, :, :, inds]
            print("W", weights.shape)
        elif layer_name == "up_conv2d":
            # decoder kernels only get their axes reordered to the PyTorch layout
            weights = weights.permute((2, 3, 0, 1))
    print(param.shape, weights.shape)
    param.data.copy_(weights)
    # should be (close to) zero if the copy succeeded
    print((param - weights).sum())
torch.save(model_new, "mnist_deepconvae/model.th")
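
# Sanity-check sketch: reload the saved module and push a random batch through
# it. Assumes DeepConvAE.forward accepts NCHW float tensors of shape
# (batch, c, h, w); adjust if the model expects something else.
model_check = torch.load("mnist_deepconvae/model.th")
model_check.eval()
with torch.no_grad():
    x = torch.randn(1, c, h, w)
    y = model_check(x)
print("sanity check output:", y.shape if torch.is_tensor(y) else type(y))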