SupermanxKiaski's picture
Upload 356 files
16d007c
raw
history blame
No virus
5.16 kB
import numpy as np
import torch
import torch.nn as nn
class Downsampler(nn.Module):
    """Strided convolution whose weights are initialised from a resampling kernel.

    Filter reference:
    http://www.realitypixels.com/turk/computergraphics/ResamplingFilters.pdf

    Args:
        n_planes: Channel count; used as both input and output channels of the
            convolution. Each channel is filtered on its own (only the
            ``[i, i]`` weight slices are non-zero after construction).
        factor: Stride of the convolution; also forwarded to ``get_kernel``.
        kernel_type: Either a preset (``"lanczos2"``, ``"lanczos3"``,
            ``"gauss12"``, ``"gauss1sq2"``) or a raw family name
            (``"lanczos"``, ``"gauss"``, ``"box"``).
        phase: Must be ``0`` or ``0.5``.
        kernel_width: Kernel size handed to ``get_kernel``; overwritten by
            every preset.
        support: Lanczos support; overwritten by the Lanczos presets.
        sigma: Gaussian sigma; overwritten by the Gaussian presets.
        preserve_size: When true, the input is replication-padded before the
            convolution is applied.

    Raises:
        AssertionError: If ``phase`` or ``kernel_type`` is not one of the
            accepted values.
    """

    def __init__(
        self, n_planes, factor, kernel_type, phase=0, kernel_width=None, support=None, sigma=None, preserve_size=False
    ):
        super().__init__()
        assert phase in [0, 0.5], "phase should be 0 or 0.5"

        # Resolve presets into a kernel family plus its fixed parameters.
        if kernel_type in ("lanczos2", "lanczos3"):
            support = 2 if kernel_type == "lanczos2" else 3
            kernel_width = 2 * support * factor + 1
            family = "lanczos"
        elif kernel_type in ("gauss12", "gauss1sq2"):
            if kernel_type == "gauss12":
                kernel_width, sigma = 7, 1 / 2
            else:
                kernel_width, sigma = 9, 1.0 / np.sqrt(2)
            family = "gauss"
        elif kernel_type in ("lanczos", "gauss", "box"):
            family = kernel_type
        else:
            assert False, "wrong name kernel"

        # note that `kernel width` will be different to actual size for phase = 1/2
        self.kernel = get_kernel(factor, family, phase, kernel_width, support=support, sigma=sigma)

        conv = nn.Conv2d(n_planes, n_planes, kernel_size=self.kernel.shape, stride=factor, padding=0)
        weights = conv.weight.data
        weights[:] = 0
        conv.bias.data[:] = 0

        # Place the kernel on the channel diagonal so channels are never mixed.
        kernel_tensor = torch.from_numpy(self.kernel)
        for plane in range(n_planes):
            weights[plane, plane] = kernel_tensor
        self.downsampler_ = conv

        if preserve_size:
            side = self.kernel.shape[0]
            # Odd kernels trim one tap before halving; even kernels trim `factor`.
            trimmed = side - 1 if side % 2 == 1 else side - factor
            self.padding = nn.ReplicationPad2d(int(trimmed / 2.0))
        self.preserve_size = preserve_size

    def forward(self, input):
        """Optionally pad ``input``, remember the convolved tensor, and downsample it."""
        x = self.padding(input) if self.preserve_size else input
        # The tensor fed to the convolution is kept on the module as `self.x`.
        self.x = x
        return self.downsampler_(x)
def get_kernel(factor, kernel_type, phase, kernel_width, support=None, sigma=None):
    """Build a square 2-D resampling kernel normalised to sum to one.

    Args:
        factor: Downsampling factor; scales tap distances in the Lanczos branch.
        kernel_type: ``"lanczos"``, ``"gauss"`` or ``"box"``.
        phase: ``0.5`` selects the half-phase variant. For non-box kernels this
            shrinks the kernel to ``kernel_width - 1`` taps per side.
        kernel_width: Nominal number of taps per side.
        support: Lanczos support; required (truthy) for ``"lanczos"``.
        sigma: Gaussian sigma; required (truthy) for ``"gauss"``.

    Returns:
        A float ``numpy.ndarray`` of shape ``(side, side)`` whose entries sum
        to one, where ``side`` is ``kernel_width - 1`` for half-phase non-box
        kernels and ``kernel_width`` otherwise.

    Raises:
        AssertionError: On an unknown ``kernel_type``, a box kernel with
            ``phase != 0.5``, a Gaussian kernel with ``phase == 0.5`` or a
            missing ``sigma``, or a Lanczos kernel with a missing ``support``.
    """
    assert kernel_type in ["lanczos", "gauss", "box"]

    # factor = float(factor)
    half_phase = phase == 0.5
    side = kernel_width - 1 if half_phase and kernel_type != "box" else kernel_width
    kernel = np.zeros([side, side])

    if kernel_type == "box":
        assert phase == 0.5, "Box filter is always half-phased"
        kernel[:] = 1.0 / (kernel_width * kernel_width)

    elif kernel_type == "gauss":
        assert sigma, "sigma is not specified"
        assert phase != 0.5, "phase 1/2 for gauss not implemented"

        center = (kernel_width + 1.0) / 2.0
        sigma_sq = sigma * sigma
        norm = 2.0 * np.pi * sigma_sq

        for i in range(1, side + 1):
            # Tap distances are halved by a fixed 2.0 here, not by `factor`.
            di = (i - center) / 2.0
            for j in range(1, side + 1):
                dj = (j - center) / 2.0
                kernel[i - 1, j - 1] = np.exp(-(di * di + dj * dj) / (2 * sigma_sq)) / norm

    elif kernel_type == "lanczos":
        assert support, "support is not specified"

        center = (kernel_width + 1) / 2.0
        # Half-phase kernels sample midway between the integer tap positions.
        shift = 0.5 if half_phase else 0.0

        for i in range(1, side + 1):
            di = abs(i + shift - center) / factor
            for j in range(1, side + 1):
                dj = abs(j + shift - center) / factor

                # Separable window: multiply the per-axis terms, skipping an
                # axis at distance zero (its term is 1 in the limit).
                val = 1
                for d in (di, dj):
                    if d != 0:
                        val = val * support * np.sin(np.pi * d) * np.sin(np.pi * d / support)
                        val = val / (np.pi * np.pi * d * d)

                kernel[i - 1, j - 1] = val

    else:
        assert False, "wrong method name"

    kernel /= kernel.sum()
    return kernel
# a = Downsampler(n_planes=3, factor=2, kernel_type='lanczos2', phase=0.5, preserve_size=True)
#################
# Learnable downsampler
# KS = 32
# dow = nn.Sequential(nn.ReplicationPad2d(int((KS - factor) / 2.)), nn.Conv2d(1,1,KS,factor))
# class Apply(nn.Module):
# def __init__(self, what, dim, *args):
# super(Apply, self).__init__()
# self.dim = dim
# self.what = what
# def forward(self, input):
# inputs = []
# for i in range(input.size(self.dim)):
# inputs.append(self.what(input.narrow(self.dim, i, 1)))
# return torch.cat(inputs, dim=self.dim)
# def __len__(self):
# return len(self._modules)
# downs = Apply(dow, 1)
# downs.type(dtype)(net_input.type(dtype)).size()