# USR-DA / model/discriminator.py
# Source snapshot: commit e5b70eb (author tag "DS")
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
class DiscriminatorVGG(nn.Module):
    """VGG-style discriminator.

    A stack of 3x3 convolutions (alternating stride 1 / stride 2) followed by
    a two-layer MLP that emits a single real/fake logit per sample.

    Args:
        in_ch: number of input image channels.
        image_size: spatial size of the (square) input; must be divisible
            by 32, since five stride-2 convolutions shrink it by 2**5.
        d: base channel width; deeper stages use multiples of it.
    """

    def __init__(self, in_ch=3, image_size=128, d=64):
        super(DiscriminatorVGG, self).__init__()
        # Five stride-2 convs downsample by a factor of 32 in total.
        self.feature_map_size = image_size // 32
        self.d = d

        def conv_bn_lrelu(cin, cout, stride):
            # bias=False because the following BatchNorm supplies the shift.
            return [
                nn.Conv2d(cin, cout, kernel_size=3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(cout),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
            ]

        # Stem: plain conv (with bias, no BN) on the raw image.
        layers = [
            nn.Conv2d(in_ch, d, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]
        # (out_channels, stride) for each subsequent conv-BN-LeakyReLU stage.
        plan = [
            (d, 2),
            (d * 2, 1), (d * 2, 2),
            (d * 4, 1), (d * 4, 2),
            (d * 8, 1), (d * 8, 2),
            (d * 8, 1), (d * 8, 2),
        ]
        cin = d
        for cout, stride in plan:
            layers.extend(conv_bn_lrelu(cin, cout, stride))
            cin = cout
        self.features = nn.Sequential(*layers)

        flat_dim = (self.d * 8) * self.feature_map_size * self.feature_map_size
        self.classifier = nn.Sequential(
            nn.Linear(flat_dim, 100),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(100, 1),
        )

    def forward(self, x):
        """Return a (batch, 1) logit tensor for input images x."""
        feats = self.features(x)
        return self.classifier(torch.flatten(feats, 1))
class UNetDiscriminator(nn.Module):
    """U-Net shaped discriminator with spectral normalization.

    Encodes the input through three stride-2 convolutions, decodes back with
    bilinear upsampling, and (optionally) adds encoder activations to the
    matching decoder activations. Two extra convs and a 1-channel head follow;
    global average pooling reduces the map to one logit per sample.

    Args:
        num_in_ch: number of input channels.
        num_feat: base feature width (encoder goes up to ``num_feat * 8``).
        skip_connection: add encoder features into the decoder when True.

    Note:
        Inputs must be spatially divisible by 8 for the skip additions to
        line up after the three down/upsampling steps.
    """

    def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
        super(UNetDiscriminator, self).__init__()
        self.skip_connection = skip_connection
        self.num_in_ch = num_in_ch
        sn = spectral_norm
        # Encoder: plain stem conv, then three spectral-normed stride-2 convs.
        self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
        self.conv1 = sn(nn.Conv2d(num_feat, num_feat * 2, kernel_size=4, stride=2, padding=1, bias=False))
        self.conv2 = sn(nn.Conv2d(num_feat * 2, num_feat * 4, kernel_size=4, stride=2, padding=1, bias=False))
        self.conv3 = sn(nn.Conv2d(num_feat * 4, num_feat * 8, kernel_size=4, stride=2, padding=1, bias=False))
        # Decoder: channel-reducing convs applied after each 2x upsample.
        self.conv4 = sn(nn.Conv2d(num_feat * 8, num_feat * 4, kernel_size=3, stride=1, padding=1, bias=False))
        self.conv5 = sn(nn.Conv2d(num_feat * 4, num_feat * 2, kernel_size=3, stride=1, padding=1, bias=False))
        self.conv6 = sn(nn.Conv2d(num_feat * 2, num_feat, kernel_size=3, stride=1, padding=1, bias=False))
        # Two refinement convs and the 1-channel prediction head.
        self.conv7 = sn(nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1, bias=False))
        self.conv8 = sn(nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1, bias=False))
        self.conv9 = nn.Conv2d(num_feat, 1, kernel_size=3, stride=1, padding=1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Return a (batch, 1) logit tensor for input x."""

        def act(t):
            return F.leaky_relu(t, negative_slope=0.2, inplace=True)

        def up2x(t):
            return F.interpolate(t, scale_factor=2, mode='bilinear')

        # Encoder path.
        enc0 = act(self.conv0(x))
        enc1 = act(self.conv1(enc0))
        enc2 = act(self.conv2(enc1))
        enc3 = act(self.conv3(enc2))
        # Decoder path with optional additive skips from the encoder.
        dec2 = act(self.conv4(up2x(enc3)))
        if self.skip_connection:
            dec2 = dec2 + enc2
        dec1 = act(self.conv5(up2x(dec2)))
        if self.skip_connection:
            dec1 = dec1 + enc1
        dec0 = act(self.conv6(up2x(dec1)))
        if self.skip_connection:
            dec0 = dec0 + enc0
        # Refinement and prediction head.
        out = act(self.conv7(dec0))
        out = act(self.conv8(out))
        out = self.conv9(out)
        return torch.flatten(self.avg_pool(out), 1)