| import os | |
| import sys | |
| import torch.nn as nn | |
| sys.path.append(os.getcwd()) | |
| from main.library.predictors.DJCM.utils import ResConvBlock | |
class ResEncoderBlock(nn.Module):
    """A stack of ``ResConvBlock`` layers with an optional max-pool on top.

    The first ``ResConvBlock`` maps ``in_channels`` to ``out_channels``; the
    remaining ``n_blocks - 1`` layers map ``out_channels`` to
    ``out_channels``. At least one layer is always created, even when
    ``n_blocks`` is less than 1.

    Args:
        in_channels: Channel count handed to the first ``ResConvBlock``.
        out_channels: Channel count produced by every ``ResConvBlock``.
        n_blocks: Total number of stacked ``ResConvBlock`` layers.
        kernel_size: Kernel passed to ``nn.MaxPool2d``; ``None`` disables
            pooling entirely.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        n_blocks,
        kernel_size
    ):
        super().__init__()

        # Build the layers in the same order they are applied so parameter
        # creation order (and state-dict indices) stay stable.
        layers = [ResConvBlock(in_channels, out_channels)]
        layers.extend(
            ResConvBlock(out_channels, out_channels)
            for _ in range(n_blocks - 1)
        )
        self.conv = nn.ModuleList(layers)

        if kernel_size is None:
            self.pool = None
        else:
            self.pool = nn.MaxPool2d(kernel_size)

    def forward(self, x):
        """Run ``x`` through every conv layer, then pool if configured.

        Returns:
            ``x`` after the conv stack when pooling is disabled; otherwise a
            ``(features, pooled)`` tuple where ``features`` is the conv-stack
            output and ``pooled`` is that output after max-pooling.
        """
        for block in self.conv:
            x = block(x)

        # NOTE: the return type depends on whether pooling is enabled.
        if self.pool is None:
            return x
        return x, self.pool(x)
class Encoder(nn.Module):
    """Six chained ``ResEncoderBlock`` stages with fixed channel widths.

    The stages map ``in_channels -> 32 -> 64 -> 128 -> 256 -> 384 -> 384``.
    Every stage pools with a ``(1, 2)`` max-pool kernel.

    Args:
        in_channels: Channel count of the tensor fed to the first stage.
        n_blocks: Number of ``ResConvBlock`` layers inside each stage.
    """

    def __init__(
        self,
        in_channels,
        n_blocks
    ):
        super().__init__()

        # Consecutive entries form the (input, output) widths of each stage.
        channel_plan = (in_channels, 32, 64, 128, 256, 384, 384)
        self.en_blocks = nn.ModuleList([
            ResEncoderBlock(stage_in, stage_out, n_blocks, (1, 2))
            for stage_in, stage_out in zip(channel_plan[:-1], channel_plan[1:])
        ])

    def forward(self, x):
        """Encode ``x`` and collect each stage's pre-pooling output.

        Returns:
            A ``(x, concat_tensors)`` tuple: ``x`` is the pooled output of
            the last stage and ``concat_tensors`` lists the un-pooled output
            of every stage in application order.
        """
        concat_tensors = []
        for stage in self.en_blocks:
            # Each stage yields (features before pooling, pooled features);
            # only the pooled tensor is fed to the next stage.
            skip, x = stage(x)
            concat_tensors.append(skip)
        return x, concat_tensors