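# Port the weights of a DeepConvAE saved as a Keras-style HDF5 checkpoint (model.h5)
# into the PyTorch DeepConvAE defined in model.py, then save the converted model as .th.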
import numpy as np
import torch
import h5py

from model import DeepConvAE
w, h, c = 28, 28, 1  # MNIST input size: 28x28, single channel
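# The PyTorch architecture must match the saved model so that parameter shapes line up.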
model_new = DeepConvAE(
    w=w, h=h, c=c, 
    nb_filters=128, 
    spatial=True, 
    channel=True, 
    channel_stride=4,
    # total layers = nb_layers*2, where we have nb_layers for encoder and nb_layers for decoder
    nb_layers=3, 
)
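# Open the saved Keras checkpoint; its weights live under the "model_weights" group.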
# model_old = h5py.File("mnist_deepconvae/model.h5")
model_old = h5py.File("/home/mehdi/work/code/out_of_class/ae/mnist/model.h5", "r")


print(model_new)
print(model_old["model_weights"].keys())


for name, param in model_new.named_parameters():
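    # Parameter names are expected to look like "encode.0.weight" or "decode.2.bias";
    # map each one to the matching layer in the Keras checkpoint and copy its weights.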
    enc_or_decode, layer_id, bias_or_kernel = name.split(".")

    if enc_or_decode == "encode":
        layer_name = "conv2d"
    else:
        layer_name = "up_conv2d"

    # Keras layer names are 1-indexed; the // 2 skips the non-parameter modules (e.g. activations) between convs.
    layer_id = (int(layer_id) // 2) + 1

    full_layer_name = f"{layer_name}_{layer_id}"
    print(full_layer_name)

    k = "kernel" if bias_or_kernel == "weight" else "bias"
    weights = model_old["model_weights"][full_layer_name][full_layer_name][k][()]
    weights = torch.from_numpy(np.asarray(weights))
    print(name, layer_id, param.shape, weights.shape)
    if k == "kernel":
        if layer_name == "conv2d":
            # Keras stores conv kernels as (H, W, in, out); PyTorch Conv2d expects (out, in, H, W).
            weights = weights.permute((3, 2, 0, 1))
            # The encoder kernels are also flipped along both spatial axes.
            weights = torch.flip(weights, dims=(2, 3))
            print("W", weights.shape)
        elif layer_name == "up_conv2d":
            # Decoder kernels only need their axes reordered, no spatial flip.
            weights = weights.permute((2, 3, 0, 1))
    print(param.shape, weights.shape)
    param.data.copy_(weights)
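    # Sanity check: the difference should be (close to) zero after the copy.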
    print((param-weights).sum())
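# Save the converted model (the whole module, not just a state_dict).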
torch.save(model_new, "mnist_deepconvae/model.th")