from collections import OrderedDict

import torch
import torch.nn as nn


def shape_to_num_params(shapes):
    """Return the total number of parameters described by a list of shape tensors."""
    return torch.sum(torch.tensor([torch.prod(s) for s in shapes])).int().item()
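
# For example (illustrative only): a 16->16 conv layer with 3x3 kernels has
# 16 * 16 * 3 * 3 = 2304 weights:
# shape_to_num_params([torch.tensor([16, 16, 3, 3])])  # -> 2304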


class WeightRegressor(nn.Module):
    """Regresses fused latent features to a convolutional weight kernel."""

    def __init__(self, input_dim, hidden_dim, kernel_size=3, out_channels=16, in_channels=16):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.in_channels = in_channels

        # Feature transformer: fuse the two concatenated code streams back to input_dim channels.
        self.fusion = nn.Sequential(
            nn.Conv2d(2 * self.input_dim, self.input_dim, kernel_size=1, padding=0, stride=1, bias=True),
            nn.InstanceNorm2d(self.input_dim),
            nn.ReLU(),
        )
        # Three stride-2 convolutions downsample the fused features by 8x.
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(self.input_dim, 64, kernel_size=3, padding=1, stride=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(64, self.hidden_dim, kernel_size=4, stride=2, padding=1, bias=True),
            nn.ReLU(),
        )
        # Linear mapper: two learned projections from hidden features to kernel weights.
        self.w1 = nn.Parameter(torch.randn((self.hidden_dim, self.in_channels * self.hidden_dim)))
        self.b1 = nn.Parameter(torch.randn((self.in_channels * self.hidden_dim)))
        self.w2 = nn.Parameter(torch.randn((self.hidden_dim, self.out_channels * self.kernel_size * self.kernel_size)))
        self.b2 = nn.Parameter(torch.randn((self.out_channels * self.kernel_size * self.kernel_size)))
        self.weight_init()

    def weight_init(self):
        nn.init.kaiming_normal_(self.w1)
        nn.init.zeros_(self.b1)
        nn.init.kaiming_normal_(self.w2)
        nn.init.zeros_(self.b2)

    def forward(self, w_image_codes, w_bar_codes):
        bs = w_image_codes.size(0)
        # Feature transformation: concatenate the two codes along channels, fuse,
        # and extract features. The flatten below assumes the spatial size collapses
        # to 1x1 (e.g. 8x8 input codes after the three stride-2 convolutions).
        out = self.fusion(torch.cat((w_image_codes, w_bar_codes), 1))
        out = self.feature_extractor(out)
        out = out.view(bs, -1)  # (bs, hidden_dim)
        # Linear map to weights.
        out = torch.matmul(out, self.w1) + self.b1  # (bs, in_channels * hidden_dim)
        out = out.view(bs, self.in_channels, self.hidden_dim)
        out = torch.matmul(out, self.w2) + self.b2  # (bs, in_channels, out_channels * k * k)
        kernel = out.view(bs, self.out_channels, self.in_channels, self.kernel_size, self.kernel_size)
        return kernel
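

# A minimal shape sanity check (illustrative sketch; `_demo_weight_regressor` is a
# hypothetical helper, not part of the original module). With the default channel
# sizes and 8x8 latent codes, the three stride-2 convolutions collapse the spatial
# dimensions to 1x1 and the regressor emits one 16x16x3x3 kernel per sample.
def _demo_weight_regressor():
    regressor = WeightRegressor(input_dim=512, hidden_dim=64)
    w_image_codes = torch.randn(2, 512, 8, 8)
    w_bar_codes = torch.randn(2, 512, 8, 8)
    kernel = regressor(w_image_codes, w_bar_codes)
    assert kernel.shape == (2, 16, 16, 3, 3)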


class Hypernetwork(nn.Module):
    def __init__(self, input_dim=512, hidden_dim=64, target_shape=None):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        # target_shape maps each target layer name to a dict holding its weight
        # "shape" and the latent-code index "w_idx"; it is required in practice.
        self.target_shape = target_shape

        num_predicted_weights = 0
        weight_regressors = OrderedDict()
        for layer_name in target_shape:
            # nn.ModuleDict keys may not contain dots, so replace them with underscores.
            new_layer_name = "_".join(layer_name.split("."))
            shape = target_shape[layer_name]["shape"]
            # Bias terms are not considered; only weight tensors are predicted.
            if len(shape) == 4:
                # Conv weight (out_channels, in_channels, kh, kw); kernels are assumed square.
                out_channels, in_channels, kernel_size = shape[:3]
            else:
                # Linear weight (out_channels, in_channels), treated as a 1x1 kernel.
                out_channels, in_channels = shape
                kernel_size = 1
            num_predicted_weights += shape_to_num_params([torch.tensor(list(shape))])
            weight_regressors[new_layer_name] = WeightRegressor(
                input_dim=self.input_dim,
                hidden_dim=self.hidden_dim,
                kernel_size=kernel_size,
                out_channels=out_channels,
                in_channels=in_channels,
            )
        self.weight_regressors = nn.ModuleDict(weight_regressors)
        self.num_predicted_weights = num_predicted_weights

    def forward(self, w_image_codes, w_bar_codes):
        bs = w_image_codes.size(0)
        out_weights = {}
        for layer_name in self.weight_regressors:
            # Recover the original dotted layer name (this assumes the original
            # names contain no underscores of their own).
            ori_layer_name = ".".join(layer_name.split("_"))
            w_idx = self.target_shape[ori_layer_name]["w_idx"]
            # Select the per-layer slice of the (bs, num_layers, C, H, W) codes.
            weights = self.weight_regressors[layer_name](
                w_image_codes[:, w_idx, :, :, :], w_bar_codes[:, w_idx, :, :, :]
            )
            out_weights[ori_layer_name] = weights.view(bs, *list(self.target_shape[ori_layer_name]["shape"]))
        return out_weights
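

# A minimal usage sketch (illustrative only; the layer names and the
# `_demo_hypernetwork` helper are hypothetical, not from the original code).
# Each entry of `target_shape` supplies the weight shape to predict and the
# index of the latent-code slice feeding its regressor. Note the dotted layer
# names contain no underscores, matching the name round-trip in forward().
def _demo_hypernetwork():
    target_shape = {
        "decoder.conv0.weight": {"shape": (16, 16, 3, 3), "w_idx": 0},
        "decoder.conv1.weight": {"shape": (3, 16, 1, 1), "w_idx": 1},
    }
    hypernet = Hypernetwork(input_dim=512, hidden_dim=64, target_shape=target_shape)
    # One (input_dim, 8, 8) code per predicted layer, for a batch of 2 images.
    w_image_codes = torch.randn(2, 2, 512, 8, 8)
    w_bar_codes = torch.randn(2, 2, 512, 8, 8)
    out_weights = hypernet(w_image_codes, w_bar_codes)
    assert out_weights["decoder.conv0.weight"].shape == (2, 16, 16, 3, 3)
    assert out_weights["decoder.conv1.weight"].shape == (2, 3, 16, 1, 1)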