import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class DownsamplingBlock(nn.Module):
    """
    Halves the spatial dimensions while doubling the channels through two
    parallel branches of convolutions and antialiased downsampling.
    In:  H x W x C
    Out: H/2 x W/2 x 2C
    """
def __init__(self, in_channels, bias=False):
super().__init__()
        self.branch1 = nn.Sequential(
            # 1x1 conv + PReLU -> 3x3 conv + PReLU -> antialiased downsample -> 1x1 conv
            nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0, bias=bias),
            nn.PReLU(),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=bias),
            nn.PReLU(),
            DownSample(channels=in_channels, filter_size=3, stride=2),
            nn.Conv2d(in_channels, in_channels * 2, kernel_size=1, padding=0, bias=bias),
        )
        self.branch2 = nn.Sequential(
            # antialiased downsample -> 1x1 conv
            DownSample(channels=in_channels, filter_size=3, stride=2),
            nn.Conv2d(in_channels, in_channels * 2, kernel_size=1, padding=0, bias=bias),
        )

    def forward(self, x):
        return self.branch1(x) + self.branch2(x)  # H/2 x W/2 x 2C
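
# Usage sketch (an addition, not in the original file): with 64 input channels,
#   DownsamplingBlock(64)(torch.randn(1, 64, 32, 32)).shape == (1, 128, 16, 16)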


class DownsamplingModule(nn.Module):
    """
    Downsampling module of the network, composed of log2(scaling_factor)
    DownsamplingBlocks.
    In:  H x W x C
    Out: H/scaling_factor x W/scaling_factor x C*scaling_factor
    """

    def __init__(self, in_channels, scaling_factor, stride=2):
        super().__init__()
        # Number of halving blocks needed to reach the target scaling factor.
        self.scaling_factor = int(np.log2(scaling_factor))
        blocks = []
        for _ in range(self.scaling_factor):
            blocks.append(DownsamplingBlock(in_channels))
            in_channels = int(in_channels * stride)  # each block doubles the channel count
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        # H/scaling_factor x W/scaling_factor x C*scaling_factor
        return self.blocks(x)
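
# Usage sketch (an addition, not in the original file): scaling_factor=4 yields
# log2(4) = 2 blocks, so DownsamplingModule(64, scaling_factor=4) maps a
# (1, 64, 32, 32) input to a (1, 256, 8, 8) output.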


class DownSample(nn.Module):
    """
    Antialiased downsampling module using the blur-pooling method.
    Based on Adobe's implementation, available here:
    https://github.com/yilundu/improved_contrastive_divergence/blob/master/downsample.py
    """

    def __init__(
        self, pad_type="reflect", filter_size=3, stride=2, channels=None, pad_off=0
    ):
super().__init__()
self.filter_size = filter_size
self.stride = stride
self.pad_off = pad_off
self.channels = channels
        # Pad (filter_size - 1) / 2 on each side, rounding the trailing edge up.
        self.pad_sizes = [
int(1.0 * (filter_size - 1) / 2),
int(np.ceil(1.0 * (filter_size - 1) / 2)),
int(1.0 * (filter_size - 1) / 2),
int(np.ceil(1.0 * (filter_size - 1) / 2)),
]
self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
self.off = int((self.stride - 1) / 2.0)
        # Binomial (Pascal's triangle) filter taps, approximating a Gaussian.
        if self.filter_size == 1:
            a = np.array([1.0])
        elif self.filter_size == 2:
            a = np.array([1.0, 1.0])
        elif self.filter_size == 3:
            a = np.array([1.0, 2.0, 1.0])
        elif self.filter_size == 4:
            a = np.array([1.0, 3.0, 3.0, 1.0])
        elif self.filter_size == 5:
            a = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
        elif self.filter_size == 6:
            a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0])
        elif self.filter_size == 7:
            a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0])
        else:
            raise ValueError("filter_size [%i] not supported" % self.filter_size)
        # Outer product gives the 2D blur kernel; normalize it to sum to 1.
        filt = torch.Tensor(a[:, None] * a[None, :])
        filt = filt / torch.sum(filt)
        # One copy of the kernel per channel, applied as a depthwise convolution.
        self.register_buffer(
            "filt", filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
        )
self.pad = get_pad_layer(pad_type)(self.pad_sizes)

    def forward(self, x):
        if self.filter_size == 1:
            # 1-tap filter: plain strided subsampling (with optional padding).
            if self.pad_off == 0:
                return x[:, :, :: self.stride, :: self.stride]
            else:
                return self.pad(x)[:, :, :: self.stride, :: self.stride]
        else:
            # Depthwise blur (groups = channels) followed by strided subsampling.
            return F.conv2d(
                self.pad(x), self.filt, stride=self.stride, groups=x.shape[1]
            )
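
# Usage sketch (an addition, not in the original file): DownSample(channels=64)
# blurs with the normalized outer product of [1, 2, 1] with itself before
# subsampling, so a (1, 64, 32, 32) input yields a (1, 64, 16, 16) output.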


def get_pad_layer(pad_type):
    if pad_type == "reflect":
        pad_layer = nn.ReflectionPad2d
    elif pad_type == "replication":
        pad_layer = nn.ReplicationPad2d
    else:
        # Fail loudly instead of returning an undefined name.
        raise ValueError("Pad type [%s] not recognized" % pad_type)
    return pad_layer
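

if __name__ == "__main__":
    # Minimal smoke test (an addition, not part of the original file): checks
    # the shape contracts documented in the docstrings above on random inputs.
    x = torch.randn(1, 64, 32, 32)

    block = DownsamplingBlock(in_channels=64)
    assert block(x).shape == (1, 128, 16, 16)

    module = DownsamplingModule(in_channels=64, scaling_factor=4)
    assert module(x).shape == (1, 256, 8, 8)

    down = DownSample(channels=64, filter_size=3, stride=2)
    assert down(x).shape == (1, 64, 16, 16)

    print("All shape checks passed.")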