|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.nn.utils import spectral_norm |
|
|
|
class DiscriminatorVGG(nn.Module):
    """VGG-style discriminator that maps an image batch to one raw score per sample.

    The feature extractor is a 3x3 conv + LeakyReLU(0.2) followed by nine
    3x3 Conv-BatchNorm-LeakyReLU stages, five of which use stride 2. The
    flattened features feed a two-layer MLP ending in a single output unit
    with no final activation.

    Args:
        in_ch: Number of input channels.
        image_size: Spatial size of the inputs the model will receive. It fixes
            the input width of the first classifier ``Linear`` layer. The model
            assumes square inputs (height == width == image_size).
        d: Base channel width; the deepest stages use ``d * 8`` channels.
    """

    def __init__(self, in_ch=3, image_size=128, d=64):
        super(DiscriminatorVGG, self).__init__()

        # (in_channels, out_channels, stride) for each Conv-BN-LeakyReLU stage,
        # listed in the exact order the layers are registered.
        stages = [
            (d, d, 2),
            (d, d * 2, 1),
            (d * 2, d * 2, 2),
            (d * 2, d * 4, 1),
            (d * 4, d * 4, 2),
            (d * 4, d * 8, 1),
            (d * 8, d * 8, 2),
            (d * 8, d * 8, 1),
            (d * 8, d * 8, 2),
        ]

        layers = [
            nn.Conv2d(in_ch, d, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]
        # Track the spatial size through the stack. A 3x3 conv with padding 1
        # maps s -> (s - 1) // stride + 1 (i.e. ceil(s / stride)). Plain
        # `image_size // 32` under-counts whenever image_size is not a
        # multiple of 32, which mis-sizes the classifier's first Linear layer.
        size = image_size
        for c_in, c_out, stride in stages:
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
            ]
            size = (size - 1) // stride + 1

        self.feature_map_size = size
        self.d = d
        self.features = nn.Sequential(*layers)

        self.classifier = nn.Sequential(
            nn.Linear((self.d * 8) * self.feature_map_size * self.feature_map_size, 100),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(100, 1),
        )

    def forward(self, x):
        """Return scores of shape (N, 1) for the input batch ``x``.

        NOTE(review): x is presumably (N, in_ch, image_size, image_size). The
        flattened feature size must match the classifier's first Linear layer.
        """
        out = self.features(x)
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out
|
|
|
|
|
class UNetDiscriminator(nn.Module):
    """U-Net shaped discriminator that outputs one raw score per sample.

    Three stride-2 encoder convolutions are mirrored by three bilinear
    upsampling steps in the decoder. The decoder optionally adds each encoder
    feature map to the decoder map of the same resolution. Convs 1-8 are
    wrapped in spectral normalization. The final single-channel map is
    average-pooled to 1x1 and flattened.

    Args:
        num_in_ch: Number of input channels.
        num_feat: Base channel width; the encoder widens to ``num_feat * 8``.
        skip_connection: If True, add encoder features into the decoder
            (``x4 + x2``, ``x5 + x1``, ``x6 + x0``).
    """

    def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
        super(UNetDiscriminator, self).__init__()
        self.skip_connection = skip_connection
        norm = spectral_norm
        self.num_in_ch = num_in_ch
        self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)

        # Encoder: each 4x4 / stride-2 / pad-1 conv maps a side length H -> floor(H / 2).
        self.conv1 = norm(nn.Conv2d(num_feat, num_feat*2, kernel_size=4, stride=2, padding=1, bias=False))
        self.conv2 = norm(nn.Conv2d(num_feat*2, num_feat*4, kernel_size=4, stride=2, padding=1, bias=False))
        self.conv3 = norm(nn.Conv2d(num_feat*4, num_feat*8, kernel_size=4, stride=2, padding=1, bias=False))

        # Decoder: size-preserving 3x3 convs that narrow the channels back to num_feat.
        self.conv4 = norm(nn.Conv2d(num_feat*8, num_feat*4, kernel_size=3, stride=1, padding=1, bias=False))
        self.conv5 = norm(nn.Conv2d(num_feat*4, num_feat*2, kernel_size=3, stride=1, padding=1, bias=False))
        self.conv6 = norm(nn.Conv2d(num_feat*2, num_feat, kernel_size=3, stride=1, padding=1, bias=False))

        # Extra size-preserving convs at full decoder resolution.
        self.conv7 = norm(nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1, bias=False))
        self.conv8 = norm(nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1, bias=False))

        # Project to a single channel (not spectrally normalized).
        self.conv9 = nn.Conv2d(num_feat, 1, kernel_size=3, stride=1, padding=1)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)

    def _upsample(self, x, skip):
        """Bilinearly upsample ``x`` for the next decoder stage.

        With skip connections, ``x`` is resized to ``skip``'s spatial size so
        the following residual addition is shape-compatible. The encoder
        floor-halves odd sizes, so a fixed 2x would mismatch for sides that
        are not multiples of 8. When the sizes divide evenly this equals 2x.
        Without skips, the original fixed 2x upsampling is kept.
        """
        if self.skip_connection:
            return F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
        return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x):
        """Return scores of shape (N, 1) for the input batch ``x``.

        NOTE(review): x is presumably (N, num_in_ch, H, W).
        """
        # Encoder.
        x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
        x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
        x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
        x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)

        # Decoder, mirroring the encoder resolutions.
        x3 = self._upsample(x3, x2)
        x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
        if self.skip_connection:
            x4 = x4 + x2

        x4 = self._upsample(x4, x1)
        x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
        if self.skip_connection:
            x5 = x5 + x1

        x5 = self._upsample(x5, x0)
        x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
        if self.skip_connection:
            x6 = x6 + x0

        # Head: refine, project to one channel, then global average pool.
        out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
        out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
        out = self.conv9(out)
        out = self.avg_pool(out)
        out = torch.flatten(out, 1)
        return out
|
|