inversion_testing / models /hyperstyle /hypernetworks /shared_weights_hypernet.py
ethanNeuralImage's picture
fix GPU usage to be optional
6fa3e0e
raw
history blame
No virus
1.49 kB
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
class SharedWeightsHypernet(nn.Module):
    """Hypernetwork that turns a batch of latent vectors into conv kernels.

    Two learned affine maps are applied in sequence. ``w2``/``b2`` expand each
    latent vector of length ``z_dim`` into ``in_size`` vectors of length
    ``z_dim``; ``w1``/``b1`` then project each of those vectors to
    ``out_size * f_size * f_size`` values. The result is reshaped to
    ``(batch, out_size, in_size, k, k)``.

    Args:
        f_size: Spatial size ``k`` of the generated kernels.
        z_dim: Length of each input latent vector.
        out_size: Size of dimension 1 of the generated kernel.
        in_size: Size of dimension 2 of the generated kernel.
        mode: When equal to ``'delta_per_channel'``, a single value is
            predicted per (out, in) pair (``self.f_size`` is forced to 1) and
            that value is tiled to ``f_size x f_size`` in ``forward``. Any
            other value predicts every kernel entry independently.
        device: Device on which the parameters are created. The default is
            ``'cuda'``; pass ``'cpu'`` on machines without a GPU.
    """

    def __init__(self, f_size=3, z_dim=512, out_size=512, in_size=512, mode=None, device='cuda'):
        super().__init__()
        self.device = device
        self.mode = mode
        self.z_dim = z_dim
        # Remember the kernel size the caller asked for. In delta mode
        # ``self.f_size`` is collapsed to 1 below, and ``forward`` needs the
        # requested size to tile the per-channel value back up.
        self._requested_f_size = f_size
        self.f_size = 1 if self.mode == 'delta_per_channel' else f_size
        self.out_size = out_size
        self.in_size = in_size

        kernel_numel = self.out_size * self.f_size * self.f_size
        hidden_numel = self.in_size * self.z_dim
        # Creation order (w1, b1, w2, b2) is kept so that a fixed RNG seed
        # yields the same initial values as before.
        self.w1 = self._init_param((self.z_dim, kernel_numel), self.device)
        self.b1 = self._init_param((kernel_numel,), self.device)
        self.w2 = self._init_param((self.z_dim, hidden_numel), self.device)
        self.b2 = self._init_param((hidden_numel,), self.device)

    @staticmethod
    def _init_param(shape, device):
        """Return a Parameter of ``shape`` drawn as ``fmod(randn / 40, 2)`` on ``device``."""
        return Parameter(torch.fmod(torch.randn(*shape).to(device) / 40, 2))

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Generate one kernel per latent vector.

        Args:
            z: Latent batch whose first dimension is the batch size and whose
                remaining elements multiply with ``w2`` of shape
                ``(z_dim, in_size * z_dim)``, i.e. ``(batch, z_dim)``.

        Returns:
            Tensor of shape ``(batch, out_size, in_size, k, k)`` where ``k``
            is the ``f_size`` passed to the constructor.
        """
        batch_size = z.shape[0]
        h_in = torch.matmul(z, self.w2) + self.b2
        h_in = h_in.view(batch_size, self.in_size, self.z_dim)
        h_final = torch.matmul(h_in, self.w1) + self.b1
        # NOTE(review): h_final is (batch, in_size, out_size * f * f) and is
        # reinterpreted here without a transpose. Left untouched because
        # changing the layout would alter what existing weights produce.
        kernel = h_final.view(batch_size, self.out_size, self.in_size, self.f_size, self.f_size)
        if self.mode == 'delta_per_channel':
            # Tile the single per-channel value over the requested kernel
            # extent (previously hard-coded to 3x3 regardless of f_size).
            tile = self._requested_f_size
            kernel = kernel.repeat(1, 1, 1, tile, tile)
        return kernel