File size: 1,494 Bytes
92ec8d3
 
 
 
 
 
 
6fa3e0e
92ec8d3
6fa3e0e
92ec8d3
 
 
 
 
 
 
 
6fa3e0e
 
92ec8d3
6fa3e0e
 
92ec8d3
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter


class SharedWeightsHypernet(nn.Module):
    """Two-layer hypernetwork that maps an embedding ``z`` to a conv kernel.

    Given ``z`` of shape ``(batch, z_dim)``, two affine projections produce a
    weight tensor of shape ``(batch, out_size, in_size, f_size, f_size)``
    (Ha-style hypernetwork).  In ``mode='delta_per_channel'`` only a single
    scalar per (out, in) channel pair is predicted and then tiled across the
    spatial dimensions.
    """

    def __init__(self, f_size=3, z_dim=512, out_size=512, in_size=512, mode=None, device='cuda'):
        """Create the hypernetwork parameters.

        Args:
            f_size: Spatial size of the generated kernel.
            z_dim: Dimensionality of the input embedding ``z``.
            out_size: Number of output channels of the generated kernel.
            in_size: Number of input channels of the generated kernel.
            mode: ``None`` for full kernels, or ``'delta_per_channel'`` to
                predict one value per channel pair and tile it spatially.
            device: Device the parameters are created on.
        """
        super(SharedWeightsHypernet, self).__init__()
        self.device = device
        self.mode = mode
        self.z_dim = z_dim
        self.f_size = f_size
        # Remember the requested spatial size so delta_per_channel mode can
        # tile its 1x1 predictions back up (previously hard-coded to 3x3,
        # which silently ignored a non-default f_size).
        self.target_f_size = f_size
        if self.mode == 'delta_per_channel':
            # Only one value per (out, in) channel pair is predicted here.
            self.f_size = 1
        self.out_size = out_size
        self.in_size = in_size

        # fmod(randn / 40, 2): small-magnitude random init, bounded in (-2, 2).
        self.w1 = Parameter(torch.fmod(torch.randn((self.z_dim, self.out_size * self.f_size * self.f_size)).to(self.device) / 40, 2))
        self.b1 = Parameter(torch.fmod(torch.randn((self.out_size * self.f_size * self.f_size)).to(self.device) / 40, 2))

        self.w2 = Parameter(torch.fmod(torch.randn((self.z_dim, self.in_size * self.z_dim)).to(self.device) / 40, 2))
        self.b2 = Parameter(torch.fmod(torch.randn((self.in_size * self.z_dim)).to(self.device) / 40, 2))

    def forward(self, z):
        """Generate a kernel from the embedding ``z``.

        Args:
            z: Tensor of shape ``(batch, z_dim)``.

        Returns:
            Kernel of shape ``(batch, out_size, in_size, f, f)`` where ``f``
            is the ``f_size`` passed to ``__init__`` (in
            ``delta_per_channel`` mode the single predicted value per channel
            pair is tiled up to that size).
        """
        batch_size = z.shape[0]
        # First projection: one z_dim-dimensional embedding per input channel.
        h_in = torch.matmul(z, self.w2) + self.b2
        h_in = h_in.view(batch_size, self.in_size, self.z_dim)

        # Second projection: per-input-channel slice of the output kernel.
        h_final = torch.matmul(h_in, self.w1) + self.b1
        kernel = h_final.view(batch_size, self.out_size, self.in_size, self.f_size, self.f_size)
        if self.mode == 'delta_per_channel':
            # Tile the per-channel value across the full spatial extent
            # (generalized from the previous hard-coded 3x3 repeat).
            kernel = kernel.repeat(1, 1, 1, self.target_f_size, self.target_f_size)
        return kernel