Spaces:
Build error
Build error
import abc | |
from typing import Tuple, List | |
import torch | |
import torch.nn as nn | |
from saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv | |
from saicinpainting.training.modules.multidilated_conv import MultidilatedConv | |
class BaseDiscriminator(nn.Module):
    """Interface for discriminators that also expose intermediate activations.

    Subclasses must override :meth:`forward`.
    """

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Predict scores and get intermediate activations. Useful for feature matching loss

        :param x: input batch passed to the discriminator
        :return: tuple (scores, list of intermediate activations)
        :raises NotImplementedError: always; subclasses must override this method
        """
        # ``NotImplemented`` is a comparison sentinel, not an exception class, and
        # calling it raises ``TypeError``. ``NotImplementedError`` is the correct
        # signal for an unimplemented abstract method.
        raise NotImplementedError()
def get_conv_block_ctor(kind='default'):
    """Resolve a convolution-block constructor from a config value.

    :param kind: one of 'default' (``nn.Conv2d``), 'depthwise'
        (``DepthWiseSeperableConv``) or 'multidilated' (``MultidilatedConv``).
        Any non-string value is returned unchanged (presumably an
        already-resolved constructor).
    :return: the selected constructor, or ``kind`` itself when it is not a string.
    :raises ValueError: if ``kind`` is a string naming no known block.
    """
    if isinstance(kind, str):
        # Each class name is only looked up inside the branch that selects it.
        if kind == 'default':
            ctor = nn.Conv2d
        elif kind == 'depthwise':
            ctor = DepthWiseSeperableConv
        elif kind == 'multidilated':
            ctor = MultidilatedConv
        else:
            raise ValueError(f'Unknown convolutional block kind {kind}')
        return ctor
    return kind
def get_norm_layer(kind='bn'):
    """Resolve a normalization-layer constructor from a config value.

    :param kind: 'bn' selects ``nn.BatchNorm2d`` and 'in' selects
        ``nn.InstanceNorm2d``. Any non-string value is returned unchanged.
    :return: the normalization layer class, or ``kind`` itself when it is not a string.
    :raises ValueError: for an unrecognized string.
    """
    if not isinstance(kind, str):
        return kind  # already resolved by the caller
    known = {'bn': nn.BatchNorm2d, 'in': nn.InstanceNorm2d}
    if kind in known:
        return known[kind]
    raise ValueError(f'Unknown norm block kind {kind}')
def get_activation(kind='tanh'):
    """Build an activation module from a config value.

    :param kind: 'tanh' gives ``nn.Tanh()`` and 'sigmoid' gives ``nn.Sigmoid()``.
        ``False`` gives ``nn.Identity()`` (no activation). An ``nn.Module``
        instance is returned unchanged, mirroring how ``get_norm_layer`` and
        ``get_conv_block_ctor`` pass through already-resolved values.
    :return: a newly constructed activation module, or ``kind`` itself when it
        is already an ``nn.Module``.
    :raises ValueError: for any other value (including ``None`` and ``True``).
    """
    # Already-built activation: accept as-is instead of rejecting it.
    if isinstance(kind, nn.Module):
        return kind
    if kind == 'tanh':
        return nn.Tanh()
    if kind == 'sigmoid':
        return nn.Sigmoid()
    # Identity check on purpose: only the literal ``False`` disables the activation.
    if kind is False:
        return nn.Identity()
    raise ValueError(f'Unknown activation kind {kind}')
class SimpleMultiStepGenerator(nn.Module):
    """Run generator steps in sequence, feeding each one all earlier results.

    Step ``i`` receives the original input concatenated along dim 1 with the
    outputs of steps ``0..i-1``. The final result is every step's output
    concatenated along dim 1, with the last step's output first.
    """

    def __init__(self, steps: List[nn.Module]):
        super().__init__()
        # ModuleList registers each step as a submodule of this generator.
        self.steps = nn.ModuleList(steps)

    def forward(self, x):
        """Apply every step and return their outputs joined along dim 1 in reverse order.

        NOTE(review): assumes dim 1 is the channel axis of ``x`` and of every
        step output, so that concatenating there is meaningful; confirm.
        """
        collected = []
        accumulated = x
        for module in self.steps:
            result = module(accumulated)
            collected.append(result)
            # The next step sees everything produced so far.
            accumulated = torch.cat((accumulated, result), dim=1)
        collected.reverse()
        return torch.cat(collected, dim=1)
def deconv_factory(kind, ngf, mult, norm_layer, activation, max_features):
    """Build the layers of one 2x upsampling decoder stage.

    :param kind: 'convtranspose' for a stride-2 ``nn.ConvTranspose2d``, or
        'bilinear' for a bilinear ``nn.Upsample`` followed by a
        ``DepthWiseSeperableConv``.
    :param ngf: base channel count.
    :param mult: channel multiplier for this stage. The stage maps
        ``ngf * mult`` channels to ``int(ngf * mult / 2)`` channels.
    :param norm_layer: normalization constructor, called with the output channel count.
    :param activation: activation module appended last. The same instance is
        placed in the returned list; it is not copied.
    :param max_features: upper bound applied to both channel counts.
    :return: list of layers ``[upsampling layer(s)..., norm, activation]``.
    :raises ValueError: if ``kind`` is not recognized. This was a bare
        ``Exception`` before; ``ValueError`` subclasses it, so existing
        handlers still catch it.
    """
    if kind not in ('convtranspose', 'bilinear'):
        raise ValueError(f"Invalid deconv kind: {kind}")

    # Computed once instead of repeating the same min(...) expressions per layer.
    in_channels = min(max_features, ngf * mult)
    out_channels = min(max_features, int(ngf * mult / 2))

    if kind == 'convtranspose':
        # kernel 3, stride 2, padding 1, output_padding 1 gives exactly 2x spatial size.
        upsample = [nn.ConvTranspose2d(in_channels, out_channels,
                                       kernel_size=3, stride=2, padding=1, output_padding=1)]
    else:  # 'bilinear'
        upsample = [nn.Upsample(scale_factor=2, mode='bilinear'),
                    DepthWiseSeperableConv(in_channels, out_channels,
                                           kernel_size=3, stride=1, padding=1)]
    # The norm layer is built after the upsampling layers, matching the original
    # construction order (relevant for seeded weight initialization).
    return upsample + [norm_layer(out_channels), activation]