# NOTE: page-scrape residue captured with this file (not part of the module):
# "Spaces: Runtime error / Runtime error"
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
class SharedWeightsHypernet(nn.Module):
    """Hypernetwork that maps a latent code to a batch of conv kernels.

    A latent ``z`` is expanded in two affine steps that share their weights
    across all generated kernels:

    1. ``z @ w2 + b2`` produces ``in_size`` intermediate codes of width
       ``z_dim`` per batch element.
    2. Each intermediate code is projected by ``w1`` / ``b1`` to
       ``out_size * f_size * f_size`` values.

    The result is viewed as ``(batch, out_size, in_size, f_size, f_size)``.

    Args:
        f_size: Spatial size of the generated kernels. Ignored (forced to 1)
            when ``mode == 'delta_per_channel'``.
        z_dim: Width of the latent code and of the intermediate codes.
        out_size: Size of the second axis of the returned kernel tensor.
        in_size: Size of the third axis of the returned kernel tensor.
        mode: ``'delta_per_channel'`` generates a single value per
            (out, in) pair and tiles it to 3x3 in ``forward``; any other
            value generates full ``f_size x f_size`` kernels.
        device: Device the parameters are moved to at construction time.
    """

    def __init__(self, f_size=3, z_dim=512, out_size=512, in_size=512, mode=None, device='cuda'):
        super().__init__()
        self.device = device
        self.mode = mode
        self.z_dim = z_dim
        self.f_size = f_size
        if self.mode == 'delta_per_channel':
            # One scalar per (out, in) pair; forward() tiles it to 3x3.
            self.f_size = 1
        self.out_size = out_size
        self.in_size = in_size

        kernel_numel = self.out_size * self.f_size * self.f_size
        code_numel = self.in_size * self.z_dim

        # Creation order (w1, b1, w2, b2) is kept so that seeded
        # initialisation draws the same random values for each parameter.
        self.w1 = Parameter(self._init_tensor((self.z_dim, kernel_numel), self.device))
        self.b1 = Parameter(self._init_tensor((kernel_numel,), self.device))
        self.w2 = Parameter(self._init_tensor((self.z_dim, code_numel), self.device))
        self.b2 = Parameter(self._init_tensor((code_numel,), self.device))

    @staticmethod
    def _init_tensor(shape, device):
        """Return a normal sample scaled by 1/40, wrapped into (-2, 2) by fmod."""
        # Sampled on the default device and then moved, so the values drawn
        # for a given seed do not depend on the target device's generator.
        return torch.fmod(torch.randn(shape).to(device) / 40, 2)

    def forward(self, z):
        """Generate kernels for a batch of latent codes.

        Args:
            z: Tensor whose first axis is the batch and whose last axis has
                size ``z_dim``; the reshape below requires it to hold exactly
                ``batch * z_dim`` elements, i.e. shape ``(batch, z_dim)``.

        Returns:
            Tensor of shape ``(batch, out_size, in_size, f_size, f_size)``,
            or ``(batch, out_size, in_size, 3, 3)`` in
            ``'delta_per_channel'`` mode.
        """
        batch_size = z.shape[0]

        # (batch, z_dim) -> (batch, in_size * z_dim) -> (batch, in_size, z_dim)
        h_in = torch.matmul(z, self.w2) + self.b2
        h_in = h_in.view(batch_size, self.in_size, self.z_dim)

        # (batch, in_size, z_dim) -> (batch, in_size, out_size * f * f)
        h_final = torch.matmul(h_in, self.w1) + self.b1

        # NOTE: this view reinterprets the (batch, in_size, out_size * f * f)
        # buffer as (batch, out_size, in_size, f, f) without permuting axes.
        kernel = h_final.view(batch_size, self.out_size, self.in_size, self.f_size, self.f_size)

        if self.mode == 'delta_per_channel':
            # Repeat the per-channel values to fill 3x3 conv kernels.
            kernel = kernel.repeat(1, 1, 1, 3, 3)
        return kernel