SupermanxKiaski's picture
Upload 356 files
16d007c
raw
history blame
No virus
5.3 kB
import torch
import torch.nn as nn
import numpy as np
from .downsampler import Downsampler
def add_module(self, module):
    """Append ``module`` to ``self`` under the next numeric name.

    The child name is ``str(len(self) + 1)``, so ``self`` must support
    ``len()`` (e.g. ``nn.Sequential``). Registration goes through the
    built-in ``nn.Module.add_module``.
    """
    next_name = str(len(self) + 1)
    self.add_module(next_name, module)


# Expose the helper on every module as ``.add(module)``.
torch.nn.Module.add = add_module
class Concat(nn.Module):
    """Run several child modules on the same input and concatenate the results.

    Each child receives the same ``input``. If the children's outputs differ in
    size along dims 2 and 3, every output is center-cropped to the smallest
    size seen on each of those dims before ``torch.cat`` along ``dim``.
    """

    def __init__(self, dim, *args):
        super().__init__()
        self.dim = dim
        for position, child in enumerate(args):
            self.add_module(str(position), child)

    def forward(self, input):
        outputs = [child(input) for child in self._modules.values()]
        heights = [out.shape[2] for out in outputs]
        widths = [out.shape[3] for out in outputs]
        min_h = min(heights)
        min_w = min(widths)

        # Fast path: every output already has the same spatial size.
        if all(h == min_h for h in heights) and all(w == min_w for w in widths):
            return torch.cat(outputs, dim=self.dim)

        cropped = []
        for out in outputs:
            # Offsets that center the min_h x min_w window inside this output.
            top = (out.size(2) - min_h) // 2
            left = (out.size(3) - min_w) // 2
            cropped.append(out[:, :, top : top + min_h, left : left + min_w])
        return torch.cat(cropped, dim=self.dim)

    def __len__(self):
        return len(self._modules)
class GenNoise(nn.Module):
    """Replace the input with standard-normal noise of a chosen channel count.

    The output has the same shape as ``input`` except dim 1, which is set to
    ``dim2``. Its dtype and device follow ``input``. The input's values are
    ignored; only its shape, dtype and device are used.
    """

    def __init__(self, dim2):
        super().__init__()
        self.dim2 = dim2

    def forward(self, input):
        shape = list(input.size())
        shape[1] = self.dim2
        # Sample directly on the input's device/dtype. This avoids building a
        # CPU tensor and copying it over; the old ``Variable`` wrapper was a
        # no-op in modern PyTorch.
        return torch.randn(shape, dtype=input.dtype, device=input.device)
class Swish(nn.Module):
    """Swish activation, ``x * sigmoid(x)``.

    Reference: https://arxiv.org/abs/1710.05941
    """

    def __init__(self):
        super().__init__()
        self.s = nn.Sigmoid()

    def forward(self, x):
        gate = self.s(x)
        return x * gate
def act(act_fun="LeakyReLU"):
    """Build an activation module.

    Args:
        act_fun: either a string or a module class.
            Accepted strings:
            ``"LeakyReLU"`` (negative slope 0.2, in-place),
            ``"Swish"``,
            ``"ELU"``, or
            ``"none"`` (an empty ``nn.Sequential``).
            A module class such as ``nn.ReLU`` is instantiated with no
            arguments.

    Returns:
        A new ``nn.Module`` instance.

    Raises:
        AssertionError: if ``act_fun`` is an unrecognised string. It is raised
            explicitly (not via ``assert``) so the check still fires under
            ``python -O`` instead of silently returning ``None``.
    """
    if not isinstance(act_fun, str):
        return act_fun()
    if act_fun == "LeakyReLU":
        return nn.LeakyReLU(0.2, inplace=True)
    if act_fun == "Swish":
        return Swish()
    if act_fun == "ELU":
        return nn.ELU()
    if act_fun == "none":
        return nn.Sequential()
    raise AssertionError(f"Unknown activation function: {act_fun!r}")
class PixelNormLayer(nn.Module):
    """Pixelwise feature vector normalization.

    Divides each position's feature vector (dim 1) by its root mean square:
    ``x / sqrt(mean(x ** 2, dim=1) + eps)``.
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        # Use the configured eps (not a hard-coded constant) so the
        # constructor argument actually takes effect.
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.eps)

    def __repr__(self):
        return self.__class__.__name__ + "(eps = %s)" % (self.eps)
def pixelnorm(num_features):
    """Return a new ``PixelNormLayer`` with its default eps.

    ``num_features`` is not used. It is presumably accepted so this factory
    can be swapped with ``bn`` (which takes the same argument).
    """
    del num_features  # unused
    return PixelNormLayer()
def bn(num_features):
    """Return an ``nn.BatchNorm2d`` over ``num_features`` channels with default settings."""
    layer = nn.BatchNorm2d(num_features)
    return layer
def conv(in_f, out_f, kernel_size, stride=1, bias=True, pad="zero", downsample_mode="stride"):
    """Build a 2-D convolution with optional reflection padding and downsampling.

    Args:
        in_f: number of input channels.
        out_f: number of output channels.
        kernel_size: integer kernel size. Padding of ``(kernel_size - 1) // 2``
            is applied, which keeps spatial size for odd kernels at stride 1.
        stride: downsampling factor.
        bias: whether the convolution has a bias term.
        pad: ``"reflection"`` uses a separate ``nn.ReflectionPad2d``. Any
            other value uses the convolution's built-in zero padding.
        downsample_mode: one of ``"stride"`` (strided conv), ``"avg"``,
            ``"max"``, ``"lanczos2"`` or ``"lanczos3"``. For the non-``"stride"``
            modes with ``stride != 1``, the conv runs at stride 1 and a
            separate pooling / ``Downsampler`` layer follows it.

    Returns:
        ``nn.Sequential`` of ``[padder], conv, [downsampler]`` (absent
        optional layers are omitted).

    Raises:
        AssertionError: for an unknown ``downsample_mode`` when ``stride != 1``.
    """
    downsampler = None
    if stride != 1 and downsample_mode != "stride":
        if downsample_mode == "avg":
            downsampler = nn.AvgPool2d(stride, stride)
        elif downsample_mode == "max":
            downsampler = nn.MaxPool2d(stride, stride)
        elif downsample_mode in ("lanczos2", "lanczos3"):
            downsampler = Downsampler(
                n_planes=out_f, factor=stride, kernel_type=downsample_mode, phase=0.5, preserve_size=True
            )
        else:
            # Explicit raise (not ``assert False``) so the check survives ``python -O``.
            raise AssertionError(f"Unknown downsample_mode: {downsample_mode!r}")
        # The separate downsampler now does the reduction; the conv itself is unstrided.
        stride = 1

    padder = None
    to_pad = (kernel_size - 1) // 2
    if pad == "reflection":
        padder = nn.ReflectionPad2d(to_pad)
        to_pad = 0  # padding is already applied by ``padder``

    convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias)
    layers = [layer for layer in (padder, convolver, downsampler) if layer is not None]
    return nn.Sequential(*layers)
class DecorrelatedColorsToRGB(nn.Module):
    """Converts from a decorrelated color space to RGB. See
    https://github.com/eps696/aphantasia/blob/master/aphantasia/image.py. Usually intended
    to be followed by a sigmoid.

    Images must be 4-D ``(N, C, H, W)`` tensors, as the einsum pattern
    ``"nchw"`` requires.
    With ``C == 3``, all channels are mixed by a 3x3 matrix.
    With ``C == 4``, the first three channels are mixed and the fourth
    (alpha) passes through unchanged.
    """

    def __init__(self, inv_color_scale=1.6):
        super().__init__()
        color_correlation_svd_sqrt = torch.tensor([[0.26, 0.09, 0.02], [0.27, 0.00, -0.05], [0.27, -0.09, 0.03]])
        # Scale down the first column by ``inv_color_scale``.
        color_correlation_svd_sqrt /= torch.tensor([inv_color_scale, 1.0, 1.0])  # saturate, empirical
        # Normalise so the largest column has unit L2 norm.
        max_norm_svd_sqrt = color_correlation_svd_sqrt.norm(dim=0).max()
        color_correlation_normalized = color_correlation_svd_sqrt / max_norm_svd_sqrt
        self.register_buffer("colcorr_t", color_correlation_normalized.T)

    def _mix(self, image, matrix):
        """Apply the 3x3 ``matrix`` over the channel dim, passing an alpha channel through."""
        if image.shape[1] == 4:
            rgb, alpha = image[:, :3], image[:, 3:]
            mixed = torch.einsum("nchw,cd->ndhw", rgb, matrix)
            return torch.cat([mixed, alpha], dim=1)
        return torch.einsum("nchw,cd->ndhw", image, matrix)

    def inverse(self, image):
        """Map an RGB(A) image back to the decorrelated color space."""
        return self._mix(image, torch.linalg.inv(self.colcorr_t))

    def forward(self, image):
        """Map a decorrelated-color image (optionally with alpha) to RGB(A)."""
        # Dispatch on the channel count, not on ``image.dim()``: every valid
        # input is 4-D, so a rank test would send 3-channel images to
        # ``image[:, 3]`` and raise IndexError.
        return self._mix(image, self.colcorr_t)