from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

# NOTE(review): "DownSamplingNetwork" is exported here but is not defined in the
# code visible in this section; presumably it is defined further down the file.
# Confirm, otherwise `from <module> import *` raises AttributeError.
__all__ = [
    "ResidualDenseBlock", "ResidualResidualDenseBlock", "Generator", "DownSamplingNetwork"
]


class ResidualDenseBlock(nn.Module):
    """Residual block built from densely connected convolutional layers.

    Follows the dense-connectivity pattern of the "Densely Connected
    Convolutional Networks" paper. Each of the five 3x3 convolutions receives
    the concatenation (along dim 1) of the block input and the outputs of all
    preceding convolutions. The first four convolutions are followed by
    LeakyReLU(0.2). The output of the fifth is scaled by 0.2 and added back to
    the block input.

    Args:
        channels (int): The number of channels in the input feature map; the
            block returns the same number of channels.
        growths (int): The number of channels each of the first four
            convolutions adds to the concatenated feature map.
    """

    def __init__(self, channels: int, growths: int) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(channels + growths * 0, growths, (3, 3), (1, 1), (1, 1))
        self.conv2 = nn.Conv2d(channels + growths * 1, growths, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(channels + growths * 2, growths, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(channels + growths * 3, growths, (3, 3), (1, 1), (1, 1))
        self.conv5 = nn.Conv2d(channels + growths * 4, channels, (3, 3), (1, 1), (1, 1))
        self.leaky_relu = nn.LeakyReLU(0.2, True)
        # No-op activation on the last convolution (no parameters).
        self.identity = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """Apply the dense block and the 0.2-scaled residual connection."""
        identity = x
        out1 = self.leaky_relu(self.conv1(x))
        out2 = self.leaky_relu(self.conv2(torch.cat([x, out1], 1)))
        out3 = self.leaky_relu(self.conv3(torch.cat([x, out1, out2], 1)))
        out4 = self.leaky_relu(self.conv4(torch.cat([x, out1, out2, out3], 1)))
        out5 = self.identity(self.conv5(torch.cat([x, out1, out2, out3, out4], 1)))
        # Residual scaling: damp the dense branch before adding the skip path.
        return out5 * 0.2 + identity


class MiniResidualDenseBlock(nn.Module):
    """Variant of ``ResidualDenseBlock`` that also activates the last convolution.

    Same dense connectivity as ``ResidualDenseBlock`` (see the "Densely
    Connected Convolutional Networks" paper). The only difference is that
    LeakyReLU(0.2) is applied to the fifth convolution's output as well,
    before the 0.2-scaled residual addition. It is not referenced by
    ``Generator`` in this section.

    Args:
        channels (int): The number of channels in the input feature map; the
            block returns the same number of channels.
        growths (int): The number of channels each of the first four
            convolutions adds to the concatenated feature map.
    """

    def __init__(self, channels: int, growths: int) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(channels + growths * 0, growths, (3, 3), (1, 1), (1, 1))
        self.conv2 = nn.Conv2d(channels + growths * 1, growths, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(channels + growths * 2, growths, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(channels + growths * 3, growths, (3, 3), (1, 1), (1, 1))
        self.conv5 = nn.Conv2d(channels + growths * 4, channels, (3, 3), (1, 1), (1, 1))
        self.leaky_relu = nn.LeakyReLU(0.2, True)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the dense block and the 0.2-scaled residual connection."""
        identity = x
        out1 = self.leaky_relu(self.conv1(x))
        out2 = self.leaky_relu(self.conv2(torch.cat([x, out1], 1)))
        out3 = self.leaky_relu(self.conv3(torch.cat([x, out1, out2], 1)))
        out4 = self.leaky_relu(self.conv4(torch.cat([x, out1, out2, out3], 1)))
        out5 = self.leaky_relu(self.conv5(torch.cat([x, out1, out2, out3, out4], 1)))
        return out5 * 0.2 + identity


class ResidualResidualDenseBlock(nn.Module):
    """Three ``ResidualDenseBlock``s in sequence wrapped in a 0.2-scaled residual.

    Args:
        channels (int): The number of channels in the input feature map; the
            block returns the same number of channels.
        growths (int): Growth rate passed to each inner ``ResidualDenseBlock``.
    """

    def __init__(self, channels: int, growths: int) -> None:
        super().__init__()
        self.rdb1 = ResidualDenseBlock(channels, growths)
        self.rdb2 = ResidualDenseBlock(channels, growths)
        self.rdb3 = ResidualDenseBlock(channels, growths)

    def forward(self, x: Tensor) -> Tensor:
        """Run the three dense blocks, then add the 0.2-scaled result to ``x``."""
        identity = x
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        return out * 0.2 + identity


class MiniResidualResidualDenseBlock(nn.Module):
    """Three ``MiniResidualDenseBlock``s in sequence wrapped in a 0.2-scaled residual.

    Args:
        channels (int): The number of channels in the input feature map; the
            block returns the same number of channels.
        growths (int): Growth rate passed to each inner ``MiniResidualDenseBlock``.
    """

    def __init__(self, channels: int, growths: int) -> None:
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them unchanged.
        self.M_rdb1 = MiniResidualDenseBlock(channels, growths)
        self.M_rdb2 = MiniResidualDenseBlock(channels, growths)
        self.M_rdb3 = MiniResidualDenseBlock(channels, growths)

    def forward(self, x: Tensor) -> Tensor:
        """Run the three mini dense blocks, then add the 0.2-scaled result to ``x``."""
        identity = x
        out = self.M_rdb1(x)
        out = self.M_rdb2(out)
        out = self.M_rdb3(out)
        return out * 0.2 + identity


class Generator(nn.Module):
    """Two-stage network with three outputs.

    Stage 1 ("RLNet") keeps the spatial size. It uses a conv, 4
    ``ResidualResidualDenseBlock``s and a conv with a long skip connection.
    Then conv + LeakyReLU and conv -> 1 channel + tanh, whose result is added
    to the input.

    Stage 2 is a super-resolution trunk applied to the stage-1 result. It uses
    a conv, 16 ``ResidualResidualDenseBlock``s and a conv with a long skip
    connection. Then two x2 bicubic interpolations, each followed by the shared
    ``upsampling`` conv (x4 total spatial size), and a conv + LeakyReLU. Two
    heads follow. Each adds a tanh-activated residual branch to these features
    and maps the sum to 1 channel with a sigmoid.

    The first convolution of each stage takes a single input channel, so the
    input must have 1 channel. Input/output naming ("gray", "dem") comes from
    the original code; presumably a grayscale image and a DEM (digital
    elevation model) — TODO confirm with the training pipeline.

    Submodule attribute names, ``nn.Sequential`` layouts and construction order
    are part of the checkpoint format (state_dict keys) and of the seeded
    initialisation order; keep them unchanged.
    """

    def __init__(self) -> None:
        super().__init__()

        # ----- Stage 1: RLNet (same resolution) -----
        self.RLNetconv_block1 = nn.Conv2d(1, 64, (3, 3), (1, 1), (1, 1))
        self.RLNettrunk = nn.Sequential(
            *[ResidualResidualDenseBlock(64, 32) for _ in range(4)]
        )
        self.RLNetconv_block2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.RLNetconv_block3 = nn.Sequential(
            nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
            nn.LeakyReLU(0.2, True),
        )
        self.RLNetconv_block4 = nn.Sequential(
            nn.Conv2d(64, 1, (3, 3), (1, 1), (1, 1)),
            nn.Tanh(),
        )

        # ----- Stage 2: super-resolution trunk -----
        self.conv_block1 = nn.Conv2d(1, 64, (3, 3), (1, 1), (1, 1))
        self.trunk = nn.Sequential(
            *[ResidualResidualDenseBlock(64, 32) for _ in range(16)]
        )
        # Convolution after the feature-extraction trunk (before the long skip).
        self.conv_block2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        # Applied after each x2 interpolation; the same module (and weights) is
        # used for both upsampling steps.
        self.upsampling = nn.Sequential(
            nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
            nn.LeakyReLU(0.2, True),
        )
        # Convolution after upsampling.
        self.conv_block3 = nn.Sequential(
            nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
            nn.LeakyReLU(0.2, True),
        )
        # NOTE(review): not used in the forward pass. It is kept registered because
        # removing it would drop "conv_block4.0.*" from the state_dict and break
        # strict loading of existing checkpoints.
        self.conv_block4 = nn.Sequential(
            nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
        )

        # Output heads: tanh residual branches, then 1-channel sigmoid heads.
        self.conv_block0_branch0 = self._tanh_branch()  # gray branch
        self.conv_block0_branch1 = self._tanh_branch()  # dem branch
        self.conv_block1_branch0 = self._sigmoid_head()  # gray head
        self.conv_block1_branch1 = self._sigmoid_head()  # dem head

    @staticmethod
    def _tanh_branch() -> nn.Sequential:
        """Build a 64 -> 64 -> 128 -> 128 -> 64 conv stack ending in tanh."""
        return nn.Sequential(
            nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1)),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 64, (3, 3), (1, 1), (1, 1)),
            nn.Tanh(),
        )

    @staticmethod
    def _sigmoid_head() -> nn.Sequential:
        """Build a 64 -> 64 -> 1 conv stack ending in sigmoid."""
        return nn.Sequential(
            nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 1, (3, 3), (1, 1), (1, 1)),
            nn.Sigmoid(),
        )

    def _forward_impl(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Run both stages.

        Returns:
            ``(out_gray, out_dem, rlNet_out)``. ``out_gray`` and ``out_dem`` are
            1-channel sigmoid outputs at 4x the input's spatial size.
            ``rlNet_out`` is the stage-1 result at the input's spatial size.
        """
        # Stage 1: RLNet, with a global residual on the input.
        out1 = self.RLNetconv_block1(x)
        out = self.RLNettrunk(out1)
        out2 = self.RLNetconv_block2(out)
        out = out1 + out2
        out = self.RLNetconv_block3(out)
        out = self.RLNetconv_block4(out)
        rlNet_out = out + x

        # Stage 2: super-resolution trunk.
        out1 = self.conv_block1(rlNet_out)
        out = self.trunk(out1)
        out2 = self.conv_block2(out)
        out = out1 + out2
        out = self.upsampling(F.interpolate(out, scale_factor=2, mode="bicubic"))
        out = self.upsampling(F.interpolate(out, scale_factor=2, mode="bicubic"))
        out = self.conv_block3(out)

        # Heads: residual tanh refinement, then 1-channel sigmoid projection.
        out_dem = out + self.conv_block0_branch1(out)
        out_gray = out + self.conv_block0_branch0(out)
        out_gray = self.conv_block1_branch0(out_gray)
        out_dem = self.conv_block1_branch1(out_dem)

        return out_gray, out_dem, rlNet_out

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Return ``(out_gray, out_dem, rlNet_out)``; see ``_forward_impl``."""
        return self._forward_impl(x)

    def _initialize_weights(self) -> None:
        """Re-initialise conv weights (Kaiming normal scaled by 0.1, zero bias).

        NOTE(review): this method is not called anywhere in the visible code
        (not from ``__init__``), so the model uses PyTorch's default conv
        initialisation unless a caller invokes it explicitly. Confirm this is
        intended. The BatchNorm2d branch has no effect on this model, which
        registers no BatchNorm layers.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
                m.weight.data *= 0.1
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                m.weight.data *= 0.1