AK391
files
d380b77
raw history blame
No virus
2.68 kB
import abc
from typing import Tuple, List
import torch
import torch.nn as nn
from saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv
from saicinpainting.training.modules.multidilated_conv import MultidilatedConv
class BaseDiscriminator(nn.Module):
@abc.abstractmethod
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""
Predict scores and get intermediate activations. Useful for feature matching loss
:return tuple (scores, list of intermediate activations)
"""
raise NotImplemented()
def get_conv_block_ctor(kind='default'):
if not isinstance(kind, str):
return kind
if kind == 'default':
return nn.Conv2d
if kind == 'depthwise':
return DepthWiseSeperableConv
if kind == 'multidilated':
return MultidilatedConv
raise ValueError(f'Unknown convolutional block kind {kind}')
def get_norm_layer(kind='bn'):
if not isinstance(kind, str):
return kind
if kind == 'bn':
return nn.BatchNorm2d
if kind == 'in':
return nn.InstanceNorm2d
raise ValueError(f'Unknown norm block kind {kind}')
def get_activation(kind='tanh'):
if kind == 'tanh':
return nn.Tanh()
if kind == 'sigmoid':
return nn.Sigmoid()
if kind is False:
return nn.Identity()
raise ValueError(f'Unknown activation kind {kind}')
class SimpleMultiStepGenerator(nn.Module):
def __init__(self, steps: List[nn.Module]):
super().__init__()
self.steps = nn.ModuleList(steps)
def forward(self, x):
cur_in = x
outs = []
for step in self.steps:
cur_out = step(cur_in)
outs.append(cur_out)
cur_in = torch.cat((cur_in, cur_out), dim=1)
return torch.cat(outs[::-1], dim=1)
def deconv_factory(kind, ngf, mult, norm_layer, activation, max_features):
if kind == 'convtranspose':
return [nn.ConvTranspose2d(min(max_features, ngf * mult),
min(max_features, int(ngf * mult / 2)),
kernel_size=3, stride=2, padding=1, output_padding=1),
norm_layer(min(max_features, int(ngf * mult / 2))), activation]
elif kind == 'bilinear':
return [nn.Upsample(scale_factor=2, mode='bilinear'),
DepthWiseSeperableConv(min(max_features, ngf * mult),
min(max_features, int(ngf * mult / 2)),
kernel_size=3, stride=1, padding=1),
norm_layer(min(max_features, int(ngf * mult / 2))), activation]
else:
raise Exception(f"Invalid deconv kind: {kind}")