import torch.nn as nn
import torch.nn.functional as F

from utils.common import initialize_weights
from .layers import LayerNorm2d


class DownConv(nn.Module):
    """Downsample block: strided separable conv plus a bilinear-downsample branch."""

    def __init__(self, channels, bias=False):
        super(DownConv, self).__init__()
        self.conv1 = SeparableConv2D(channels, channels, stride=2, bias=bias)
        self.conv2 = SeparableConv2D(channels, channels, stride=1, bias=bias)

    def forward(self, x):
        out1 = self.conv1(x)
        out2 = F.interpolate(x, scale_factor=0.5, mode='bilinear')
        out2 = self.conv2(out2)
        return out1 + out2


class UpConv(nn.Module):
    """Upsample block: bilinear upsample followed by a separable conv."""

    def __init__(self, channels, bias=False):
        super(UpConv, self).__init__()
        self.conv = SeparableConv2D(channels, channels, stride=1, bias=bias)

    def forward(self, x):
        out = F.interpolate(x, scale_factor=2.0, mode='bilinear')
        out = self.conv(out)
        return out


class UpConvLNormLReLU(nn.Module):
    """Upsample Conv block with Layer Norm and Leaky ReLU."""

    def __init__(self, in_channels, out_channels, bias=False):
        super(UpConvLNormLReLU, self).__init__()
        self.conv_block = ConvBlock(
            in_channels,
            out_channels,
            kernel_size=3,
            bias=bias,
            norm_type="layer",  # use Layer Norm, as the class name implies
        )

    def forward(self, x):
        out = F.interpolate(x, scale_factor=2.0, mode='bilinear')
        out = self.conv_block(out)
        return out


class SeparableConv2D(nn.Module):
    """Depthwise-separable conv: depthwise 3x3 + pointwise 1x1, each followed by norm and activation."""

    def __init__(self, in_channels, out_channels, stride=1, bias=False):
        super(SeparableConv2D, self).__init__()
        self.depthwise = nn.Conv2d(
            in_channels, in_channels, kernel_size=3,
            stride=stride, padding=1, groups=in_channels, bias=bias
        )
        self.pointwise = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, bias=bias
        )
        self.ins_norm1 = nn.InstanceNorm2d(in_channels)
        self.activation1 = nn.LeakyReLU(0.2, True)
        self.ins_norm2 = nn.InstanceNorm2d(out_channels)
        self.activation2 = nn.LeakyReLU(0.2, True)

        initialize_weights(self)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.ins_norm1(out)
        out = self.activation1(out)
        out = self.pointwise(out)
        out = self.ins_norm2(out)
        return self.activation2(out)


class ConvBlock(nn.Module):
    """Stack of Conv2D + Norm + LeakyReLU."""

    def __init__(
        self,
        channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding="valid",
        bias=False,
        norm_type="instance",
    ):
        super(ConvBlock, self).__init__()
        # Reflection padding is applied before the conv, which keeps its own
        # padding at the default "valid" (i.e. none).
        if kernel_size == 3 and stride == 1:
            self.pad = nn.ReflectionPad2d((1, 1, 1, 1))
        elif kernel_size == 7 and stride == 1:
            self.pad = nn.ReflectionPad2d((3, 3, 3, 3))
        elif stride == 2:
            self.pad = nn.ReflectionPad2d((0, 1, 1, 0))
        else:
            self.pad = None

        self.conv = nn.Conv2d(
            channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias
        )
        if norm_type == "instance":
            self.ins_norm = nn.InstanceNorm2d(out_channels)
        elif norm_type == "layer":
            self.ins_norm = LayerNorm2d(out_channels)
        self.activation = nn.LeakyReLU(0.2, True)

        initialize_weights(self)

    def forward(self, x):
        if self.pad is not None:
            x = self.pad(x)
        out = self.conv(x)
        out = self.ins_norm(out)
        out = self.activation(out)
        return out


class InvertedResBlock(nn.Module):
    """Inverted residual block: expand (1x1) -> depthwise (3x3) -> project (1x1)."""

    def __init__(
        self,
        channels=256,
        out_channels=256,
        expand_ratio=2,
        bias=False,
        norm_type="instance",
    ):
        super(InvertedResBlock, self).__init__()
        bottleneck_dim = round(expand_ratio * channels)
        self.conv_block = ConvBlock(
            channels, bottleneck_dim, kernel_size=1, stride=1, padding=0, bias=bias
        )
        self.depthwise_conv = nn.Conv2d(
            bottleneck_dim, bottleneck_dim, kernel_size=3,
            groups=bottleneck_dim, stride=1, padding=1, bias=bias
        )
        self.conv = nn.Conv2d(
            bottleneck_dim, out_channels, kernel_size=1, stride=1, bias=bias
        )

        if norm_type == "instance":
            # ins_norm1 sees the depthwise output, which has bottleneck_dim channels.
            self.ins_norm1 = nn.InstanceNorm2d(bottleneck_dim)
            self.ins_norm2 = nn.InstanceNorm2d(out_channels)
        elif norm_type == "layer":
            # Keep var name as is for v1 compatibility.
            self.ins_norm1 = LayerNorm2d(bottleneck_dim)
            self.ins_norm2 = LayerNorm2d(out_channels)
        self.activation = nn.LeakyReLU(0.2, True)

        initialize_weights(self)

    def forward(self, x):
        out = self.conv_block(x)
        out = self.depthwise_conv(out)
        out = self.ins_norm1(out)
        out = self.activation(out)
        out = self.conv(out)
        out = self.ins_norm2(out)

        if out.shape[1] != x.shape[1]:
            # Skip the residual connection when channel counts differ.
            return out

        return out + x
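
# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module; added as an
# assumption about how these blocks are exercised). It runs dummy tensors
# through DownConv, UpConv and InvertedResBlock to check output shapes.
# Run it as a package module (e.g. `python -m <package>.conv_blocks`, name
# hypothetical) so the relative `.layers` import resolves; the 1x64x256x256
# input size is arbitrary.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    x = torch.randn(1, 64, 256, 256)

    down = DownConv(64)
    up = UpConv(64)
    res = InvertedResBlock(channels=64, out_channels=64)

    y_down = down(x)   # halves spatial resolution: (1, 64, 128, 128)
    y_up = up(y_down)  # doubles it back:           (1, 64, 256, 256)
    y_res = res(x)     # same shape as input, residual added since channels match

    print(y_down.shape, y_up.shape, y_res.shape)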