Spaces:
Sleeping
Sleeping
import torch | |
from torch import nn | |
class Gauss_Pyramid_Conv(nn.Module):
    """Gaussian image pyramid built from a fixed 5x5 binomial blur.

    Code borrowed from: https://github.com/csjliang/LPTN

    ``forward`` returns a list of ``num_high + 1`` tensors:

    * For each of the first ``num_high`` levels, the blurred image at that
      level's resolution.
    * As the last element, the image after the final blur-and-decimate step.

    Decimation keeps every other pixel starting at index 0, so each level
    has ``ceil(H / 2) x ceil(W / 2)`` pixels of the previous one.

    Because the blur uses 2-pixel reflect padding, every tensor that is
    blurred must be at least 3 pixels high and wide. Smaller inputs make
    ``torch.nn.functional.pad`` raise.
    """

    def __init__(self, num_high=3):
        super().__init__()
        self.num_high = num_high
        # A buffer (rather than a plain attribute) follows module.to() /
        # .cuda() / .half(). It is non-persistent, so the state_dict matches
        # the original implementation, which never stored the kernel.
        self.register_buffer('kernel', self.gauss_kernel(), persistent=False)

    def gauss_kernel(self, device=None, channels=3):
        """Return the normalized 5x5 binomial kernel laid out for a depthwise conv.

        Args:
            device: Device to place the kernel on. ``None`` leaves it on the
                CPU. This previously defaulted to CUDA, which crashed module
                construction on CPU-only machines.
            channels: Number of identical copies stacked along dim 0.

        Returns:
            float32 tensor of shape ``(channels, 1, 5, 5)``. Each 5x5 slice
            sums to 1.
        """
        kernel = torch.tensor([[1., 4., 6., 4., 1.],
                               [4., 16., 24., 16., 4.],
                               [6., 24., 36., 24., 6.],
                               [4., 16., 24., 16., 4.],
                               [1., 4., 6., 4., 1.]])
        kernel /= 256.  # the entries sum to 256, so this normalizes to 1
        # repeat() with more sizes than the tensor has dims prepends new
        # dims: (5, 5) -> (channels, 1, 5, 5).
        kernel = kernel.repeat(channels, 1, 1, 1)
        if device is not None:
            kernel = kernel.to(device)
        return kernel

    def downsample(self, x):
        """Keep every other row and column (indices 0, 2, 4, ...) of ``x``."""
        return x[:, :, ::2, ::2]

    def conv_gauss(self, img, kernel):
        """Blur each channel of ``img`` independently with ``kernel``.

        Args:
            img: 4-D tensor ``(N, C, H, W)`` with ``H, W >= 3``.
            kernel: Depthwise kernel of shape ``(K, 1, 5, 5)``.
                - It is cast to ``img``'s device and dtype.
                - If ``K != C``, its first slice is replicated to ``C``
                  channels. This assumes all slices are identical, which is
                  true for ``gauss_kernel``'s output.

        Returns:
            Tensor with the same shape, device and dtype as ``img``.
        """
        channels = img.shape[1]
        kernel = kernel.to(device=img.device, dtype=img.dtype)
        if kernel.shape[0] != channels:
            # Without this, a 1-channel image produced a 3-channel output,
            # and channel counts not dividing 3 made conv2d raise.
            kernel = kernel[:1].repeat(channels, 1, 1, 1)
        img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode='reflect')
        return torch.nn.functional.conv2d(img, kernel, groups=channels)

    def forward(self, img):
        """Build the pyramid for ``img``. See the class docstring for the layout."""
        pyr = []
        current = img
        for _ in range(self.num_high):
            filtered = self.conv_gauss(current, self.kernel)
            pyr.append(filtered)
            current = self.downsample(filtered)
        pyr.append(current)
        return pyr