from torch import nn
import torch.nn.functional as F
import torch


class ResBlock2d(nn.Module):
    """Residual block: two (1x1 conv -> ReLU -> conv) stages plus a skip add.

    The layers named ``norm1`` / ``norm2`` are 1x1 convolutions, not
    normalization layers. The attribute names are kept as-is because they
    determine the parameter names of the module.

    Args:
        in_features: Number of input channels; every layer in the block
            keeps this channel count.
        kernel_size: Kernel size passed to ``conv1`` and ``conv2``.
        padding: Padding passed to ``conv1`` and ``conv2``. The residual
            addition in ``forward`` needs the conv output to be addable to
            the input, so this should preserve the spatial size
            (e.g. ``kernel_size=3, padding=1``).
    """

    def __init__(self, in_features: int, kernel_size, padding):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=in_features,
            out_channels=in_features,
            kernel_size=kernel_size,
            padding=padding,
        )
        self.conv2 = nn.Conv2d(
            in_channels=in_features,
            out_channels=in_features,
            kernel_size=kernel_size,
            padding=padding,
        )
        self.norm1 = nn.Conv2d(
            in_channels=in_features,
            out_channels=in_features,
            kernel_size=1,
        )
        self.norm2 = nn.Conv2d(
            in_channels=in_features,
            out_channels=in_features,
            kernel_size=1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both stages to ``x`` and add ``x`` back onto the result.

        Args:
            x: Input tensor accepted by ``nn.Conv2d`` with ``in_features``
                channels.

        Returns:
            The output of the second stage with ``x`` added to it.
        """
        out = self.norm1(x)
        # In-place ReLU acts on the fresh tensor produced by norm1, not on x.
        out = F.relu(out, inplace=True)
        out = self.conv1(out)
        out = self.norm2(out)
        out = F.relu(out, inplace=True)
        out = self.conv2(out)
        # In-place add modifies the conv2 output only; x itself is untouched.
        out += x
        return out


class RGBADecoderNet(nn.Module):
    """Decode a feature map into ``out_planes`` channels squashed by sigmoid.

    The input passes through ``num_bottleneck_blocks`` ``ResBlock2d`` blocks
    and then a single 3x3 convolution followed by ``torch.sigmoid``.

    Args:
        c: Number of channels of the incoming feature map.
        out_planes: Number of output channels (default 4).
        num_bottleneck_blocks: Number of ``ResBlock2d`` blocks applied before
            the output convolution.
    """

    def __init__(self, c: int = 64, out_planes: int = 4,
                 num_bottleneck_blocks: int = 1):
        super().__init__()
        # Kept wrapped in nn.Sequential so the parameter names
        # ("conv_rgba.0.*") stay the same as before.
        self.conv_rgba = nn.Sequential(
            nn.Conv2d(c, out_planes, kernel_size=3, stride=1, padding=1,
                      dilation=1, bias=True)
        )
        self.bottleneck = torch.nn.Sequential()
        for i in range(num_bottleneck_blocks):
            self.bottleneck.add_module(
                'r' + str(i),
                ResBlock2d(c, kernel_size=(3, 3), padding=(1, 1)),
            )

    def forward(self, features_weighted_mask_atfeaturesscale_list=None):
        """Decode the first feature map of the given list.

        Args:
            features_weighted_mask_atfeaturesscale_list: A list whose first
                element is the feature tensor to decode. NOTE: that first
                element is removed from the list (``pop(0)``), so the
                caller's list is one item shorter after this call.

        Returns:
            ``sigmoid(conv_rgba(bottleneck(features)))`` for the popped
            tensor.

        Raises:
            IndexError: If the list is missing (``None``) or empty.
        """
        features_list = features_weighted_mask_atfeaturesscale_list
        # The default used to be a shared mutable ``[]``; ``None`` avoids the
        # shared-default pitfall while keeping the same exception type for a
        # call without usable input.
        if features_list is None or len(features_list) == 0:
            raise IndexError(
                "RGBADecoderNet.forward expects a non-empty list of feature "
                "tensors"
            )
        features = features_list.pop(0)
        return torch.sigmoid(self.conv_rgba(self.bottleneck(features)))