Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.nn.utils import spectral_norm | |
from .conv_blocks import DownConv | |
from .conv_blocks import UpConv | |
from .conv_blocks import SeparableConv2D | |
from .conv_blocks import InvertedResBlock | |
from .conv_blocks import ConvBlock | |
from .layers import get_norm | |
from utils.common import initialize_weights | |
class GeneratorV1(nn.Module):
    """Encoder / residual-trunk / decoder image generator.

    The network is three ``nn.Sequential`` stages applied in order:

    * ``encode_blocks``: conv blocks widening 3 -> 64 -> 128 -> 256 channels,
      with two ``DownConv`` stages (presumably spatial downsampling; see
      ``conv_blocks`` to confirm).
    * ``res_blocks``: eight ``InvertedResBlock(256, 256)`` blocks.
    * ``decode_blocks``: conv blocks narrowing 256 -> 128 -> 64 channels with two
      ``UpConv`` stages, then a 1x1 ``nn.Conv2d`` to 3 channels and ``nn.Tanh``.
      The output is therefore bounded to (-1, 1).

    Args:
        dataset: Label appended to the model name stored in ``self.name``.
    """

    def __init__(self, dataset=''):
        super().__init__()
        self.name = f'{type(self).__name__}_{dataset}'
        use_bias = False

        # Encoder. The first ConvBlock is built with 3 input channels, so a
        # 3-channel input is presumably expected.
        self.encode_blocks = nn.Sequential(
            ConvBlock(3, 64, bias=use_bias),
            ConvBlock(64, 128, bias=use_bias),
            DownConv(128, bias=use_bias),
            ConvBlock(128, 128, bias=use_bias),
            SeparableConv2D(128, 256, bias=use_bias),
            DownConv(256, bias=use_bias),
            ConvBlock(256, 256, bias=use_bias),
        )

        # Residual trunk: eight identical blocks. Sequential indices 0..7 are
        # kept, so state_dict keys match checkpoints built the long-hand way.
        self.res_blocks = nn.Sequential(
            *(InvertedResBlock(256, 256) for _ in range(8))
        )

        # Decoder: mirror of the encoder, ending in a 1x1 projection to RGB.
        self.decode_blocks = nn.Sequential(
            ConvBlock(256, 128, bias=use_bias),
            UpConv(128, bias=use_bias),
            SeparableConv2D(128, 128, bias=use_bias),
            ConvBlock(128, 128, bias=use_bias),
            UpConv(128, bias=use_bias),
            ConvBlock(128, 64, bias=use_bias),
            ConvBlock(64, 64, bias=use_bias),
            nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=0, bias=use_bias),
            nn.Tanh(),
        )

        # Re-initialize all submodule weights with the project's scheme.
        initialize_weights(self)

    def forward(self, x):
        """Run ``x`` through encoder, residual trunk and decoder; return the image."""
        features = self.res_blocks(self.encode_blocks(x))
        return self.decode_blocks(features)
class Discriminator(nn.Module):
    """Fully convolutional discriminator producing a single-channel logit map.

    Layout:

    * A stride-1 3x3 conv from 3 to 32 channels, then LeakyReLU(0.2).
    * ``num_layers`` downsampling stages. Each stage is a stride-2 3x3 conv, a
      stride-1 3x3 conv, a norm layer and LeakyReLU activations. Each stride-2
      conv halves the spatial size (rounding up, given padding=1), and the base
      width doubles per stage.
    * A 3x3 conv, a norm layer, LeakyReLU, and a final 3x3 conv down to 1
      channel. No activation is applied, so the result is raw logits.

    Args:
        dataset: Label appended to ``self.name``.
        num_layers: Number of downsampling stages. ``0`` builds a network
            without downsampling.
        use_sn: If True, every ``nn.Conv2d`` is wrapped with spectral norm.
        norm_type: Forwarded to ``get_norm`` to build each normalization layer.
    """

    def __init__(
        self,
        dataset=None,
        num_layers=1,
        use_sn=False,
        norm_type="instance",
    ):
        super().__init__()
        self.name = f'discriminator_{dataset}'
        self.bias = False
        channels = 32

        layers = [
            nn.Conv2d(3, channels, kernel_size=3, stride=1, padding=1, bias=self.bias),
            nn.LeakyReLU(0.2, True),
        ]
        # Width of the tensor produced by the most recent layer in `layers`.
        in_channels = channels

        for _ in range(num_layers):
            layers += [
                nn.Conv2d(in_channels, channels * 2, kernel_size=3, stride=2, padding=1, bias=self.bias),
                nn.LeakyReLU(0.2, True),
                nn.Conv2d(channels * 2, channels * 4, kernel_size=3, stride=1, padding=1, bias=self.bias),
                get_norm(norm_type, channels * 4),
                nn.LeakyReLU(0.2, True),
            ]
            in_channels = channels * 4
            channels *= 2

        channels *= 2
        # The first head conv must consume `in_channels`, the width actually
        # emitted so far. For num_layers >= 1 this equals `channels`, so the
        # architecture is unchanged. For num_layers == 0 it is 32 while
        # `channels` is 64, and using `channels` here caused a shape mismatch
        # in forward().
        layers += [
            nn.Conv2d(in_channels, channels, kernel_size=3, stride=1, padding=1, bias=self.bias),
            get_norm(norm_type, channels),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(channels, 1, kernel_size=3, stride=1, padding=1, bias=self.bias),
        ]

        if use_sn:
            # spectral_norm registers a reparametrization on the conv module
            # and returns it. Non-conv layers pass through untouched.
            layers = [
                spectral_norm(layer) if isinstance(layer, nn.Conv2d) else layer
                for layer in layers
            ]

        self.discriminate = nn.Sequential(*layers)
        initialize_weights(self)

    def forward(self, img):
        """Return the unactivated 1-channel logit map for ``img``."""
        return self.discriminate(img)