import torch
import torch.nn as nn
import numpy as np

from .downsampler import Downsampler


def add_module(self, module):
    # Register each submodule under a numeric name derived from the current length.
    self.add_module(str(len(self) + 1), module)


# Monkey-patch nn.Module so network builders can append layers with `.add(...)`.
torch.nn.Module.add = add_module


class Concat(nn.Module):
    """Applies each child module to the same input and concatenates the results
    along `dim`, center-cropping outputs to the smallest spatial size if they differ."""

    def __init__(self, dim, *args):
        super(Concat, self).__init__()
        self.dim = dim

        for idx, module in enumerate(args):
            self.add_module(str(idx), module)

    def forward(self, input):
        inputs = []
        for module in self._modules.values():
            inputs.append(module(input))

        inputs_shapes2 = [x.shape[2] for x in inputs]
        inputs_shapes3 = [x.shape[3] for x in inputs]

        if np.all(np.array(inputs_shapes2) == min(inputs_shapes2)) and np.all(
            np.array(inputs_shapes3) == min(inputs_shapes3)
        ):
            # All branch outputs already share the same spatial size.
            inputs_ = inputs
        else:
            # Center-crop every output to the smallest height and width.
            target_shape2 = min(inputs_shapes2)
            target_shape3 = min(inputs_shapes3)

            inputs_ = []
            for inp in inputs:
                diff2 = (inp.size(2) - target_shape2) // 2
                diff3 = (inp.size(3) - target_shape3) // 2
                inputs_.append(inp[:, :, diff2 : diff2 + target_shape2, diff3 : diff3 + target_shape3])

        return torch.cat(inputs_, dim=self.dim)

    def __len__(self):
        return len(self._modules)
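

# Usage sketch (illustrative, not part of the original module): Concat runs each
# branch on the same input, crops all outputs to the smallest spatial size, and
# concatenates them along `dim`.
#
#     branch_a = nn.Conv2d(3, 8, 3, padding=1)  # keeps 32x32
#     branch_b = nn.Conv2d(3, 8, 3)             # shrinks to 30x30
#     cat = Concat(1, branch_a, branch_b)
#     out = cat(torch.randn(1, 3, 32, 32))      # -> shape (1, 16, 30, 30)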


class GenNoise(nn.Module):
    """Ignores the input values and returns standard-normal noise with the same
    batch and spatial dimensions but with `dim2` channels."""

    def __init__(self, dim2):
        super(GenNoise, self).__init__()
        self.dim2 = dim2

    def forward(self, input):
        shape = list(input.size())
        shape[1] = self.dim2

        noise = torch.zeros(shape, dtype=input.dtype, device=input.device)
        noise.normal_()
        return noise
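

# Usage sketch (illustrative, not from the original file): only the input's shape,
# dtype, and device are reused; its values are discarded.
#
#     noise_layer = GenNoise(16)
#     z = noise_layer(torch.randn(2, 3, 64, 64))  # -> fresh N(0, 1) noise of shape (2, 16, 64, 64)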


class Swish(nn.Module):
    """
    https://arxiv.org/abs/1710.05941
    The hype was so huge that I could not help but try it
    """

    def __init__(self):
        super(Swish, self).__init__()
        self.s = nn.Sigmoid()

    def forward(self, x):
        return x * self.s(x)


def act(act_fun="LeakyReLU"):
    """
    Returns an activation layer. `act_fun` is either a string naming an
    activation function or a module constructor (e.g. nn.ReLU).
    """
    if isinstance(act_fun, str):
        if act_fun == "LeakyReLU":
            return nn.LeakyReLU(0.2, inplace=True)
        elif act_fun == "Swish":
            return Swish()
        elif act_fun == "ELU":
            return nn.ELU()
        elif act_fun == "none":
            return nn.Sequential()
        else:
            assert False, "Unknown activation function: %s" % act_fun
    else:
        return act_fun()
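

# Usage sketch (illustrative): both a string and a module constructor are accepted.
#
#     leaky = act("LeakyReLU")  # nn.LeakyReLU(0.2, inplace=True)
#     tanh = act(nn.Tanh)       # any nn.Module constructor works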


class PixelNormLayer(nn.Module):
    """
    Pixelwise feature vector normalization.
    """

    def __init__(self, eps=1e-8):
        super(PixelNormLayer, self).__init__()
        self.eps = eps

    def forward(self, x):
        # Normalize each pixel's feature vector to unit RMS over the channel dimension.
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.eps)

    def __repr__(self):
        return self.__class__.__name__ + "(eps = %s)" % (self.eps)


def pixelnorm(num_features):
    # `num_features` is accepted only for interface parity with `bn`; PixelNormLayer does not need it.
    return PixelNormLayer()
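

# Usage sketch (illustrative): each pixel's feature vector is rescaled to roughly
# unit RMS across channels, independently of batch statistics.
#
#     norm = pixelnorm(64)                  # num_features is ignored
#     y = norm(torch.randn(1, 64, 8, 8))
#     # torch.mean(y ** 2, dim=1) is close to 1 at every spatial location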


def bn(num_features):
    return nn.BatchNorm2d(num_features)


def conv(in_f, out_f, kernel_size, stride=1, bias=True, pad="zero", downsample_mode="stride"):
    """Builds a conv block: optional padding, the convolution, and an optional downsampler."""
    downsampler = None
    if stride != 1 and downsample_mode != "stride":
        if downsample_mode == "avg":
            downsampler = nn.AvgPool2d(stride, stride)
        elif downsample_mode == "max":
            downsampler = nn.MaxPool2d(stride, stride)
        elif downsample_mode in ["lanczos2", "lanczos3"]:
            downsampler = Downsampler(
                n_planes=out_f, factor=stride, kernel_type=downsample_mode, phase=0.5, preserve_size=True
            )
        else:
            assert False, "Unknown downsample mode: %s" % downsample_mode
        # The convolution runs with stride 1; the separate downsampler provides the stride.
        stride = 1

    padder = None
    to_pad = int((kernel_size - 1) / 2)
    if pad == "reflection":
        padder = nn.ReflectionPad2d(to_pad)
        to_pad = 0

    convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias)

    layers = filter(lambda x: x is not None, [padder, convolver, downsampler])
    return nn.Sequential(*layers)
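

# Usage sketch (illustrative): a block that pads by reflection and downsamples with
# a Lanczos filter instead of a strided convolution.
#
#     block = conv(32, 64, kernel_size=3, stride=2, pad="reflection", downsample_mode="lanczos2")
#     # -> nn.Sequential(ReflectionPad2d, Conv2d with stride 1, Downsampler with factor 2)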


class DecorrelatedColorsToRGB(nn.Module):
    """Converts from a decorrelated color space to RGB. See
    https://github.com/eps696/aphantasia/blob/master/aphantasia/image.py. Usually intended
    to be followed by a sigmoid.
    """

    def __init__(self, inv_color_scale=1.6):
        super().__init__()
        color_correlation_svd_sqrt = torch.tensor([[0.26, 0.09, 0.02], [0.27, 0.00, -0.05], [0.27, -0.09, 0.03]])
        color_correlation_svd_sqrt /= torch.tensor([inv_color_scale, 1.0, 1.0])  # saturate, empirical
        max_norm_svd_sqrt = color_correlation_svd_sqrt.norm(dim=0).max()
        color_correlation_normalized = color_correlation_svd_sqrt / max_norm_svd_sqrt
        self.register_buffer("colcorr_t", color_correlation_normalized.T)

    def inverse(self, image):
        colcorr_t_inv = torch.linalg.inv(self.colcorr_t)
        return torch.einsum("nchw,cd->ndhw", image, colcorr_t_inv)

    def forward(self, image):
        if image.shape[1] == 4:
            # RGBA input: decorrelate only the RGB channels and pass alpha through unchanged.
            image_rgb, alpha = image[:, :3], image[:, 3].unsqueeze(1)
            image_rgb = torch.einsum("nchw,cd->ndhw", image_rgb, self.colcorr_t)
            image = torch.cat([image_rgb, alpha], dim=1)
        else:
            image = torch.einsum("nchw,cd->ndhw", image, self.colcorr_t)
        return image
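

# Usage sketch (illustrative): maps a decorrelated 3-channel (or RGBA) parameterization
# to correlated RGB; a sigmoid is typically applied afterwards to land in [0, 1].
#
#     to_rgb = DecorrelatedColorsToRGB()
#     latent = torch.randn(1, 3, 64, 64)
#     rgb = torch.sigmoid(to_rgb(latent))
#     # to_rgb.inverse(to_rgb(latent)) recovers `latent` up to numerical error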