In [1]:
import torch
import torch.functional as F
import torch.nn as nn
import yaml
from easydict import EasyDict
from torchinfo import summary

In [2]:
# Location of the YAML file that describes the network. The cells below read the
# `encoder_config`, `decoder_config` and `nclasses` entries from it.
config_path = "./config/paper_config.yml"

# Parse the YAML file into plain Python containers.
with open(config_path, "r") as file:
 yaml_data = yaml.safe_load(file)

# Wrap the parsed data in an EasyDict; later cells access it both as attributes
# (config.encoder_config) and as items (config["encoder_config"]).
config = EasyDict(yaml_data)

# Encoder 

In [3]:
"""
downsampling blocks 
(first half of the 'U' in UNet) 
[ENCODER]
"""


class EncoderLayer(nn.Module):
    """One contracting-path stage: a stack of 3x3 convolutions, each followed by ReLU.

    Args:
        in_channels: Channels of the incoming feature map (used by the first
            convolution only).
        out_channels: Channels produced by every convolution in the stack.
        n_layers: Number of conv + ReLU pairs to stack.
        all_padding: When truthy, every convolution uses ``padding=1``. Otherwise
            only convolutions at index 2 and above are padded; the first two use
            ``padding=0`` and therefore shrink each spatial side by 2.
        maxpool: Flag stored as ``self.maxpool``. It is not used inside this
            class; it is only kept on the instance for callers to inspect.
    """

    def __init__(
        self,
        in_channels=1,
        out_channels=64,
        n_layers=2,
        all_padding=False,
        maxpool=True,
    ):
        super().__init__()

        stages = []
        stage_in = in_channels
        for index in range(n_layers):
            pad = 1 if (index >= 2 or all_padding) else 0
            stages.append(
                self._conv_relu_layer(
                    in_channels=stage_in,
                    out_channels=out_channels,
                    padding=pad,
                )
            )
            # Every convolution after the first consumes the stage's own output width.
            stage_in = out_channels

        self.layer = nn.Sequential(*stages)
        self.maxpool = maxpool

    def _conv_relu_layer(self, in_channels, out_channels, padding=0):
        """Return ``Sequential(Conv2d(kernel_size=3, padding=padding), ReLU())``."""
        conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=padding,
        )
        return nn.Sequential(conv, nn.ReLU())

    def forward(self, x):
        """Apply the conv + ReLU stack to ``x`` and return the result."""
        return self.layer(x)


# Shape/parameter summary for the first encoder stage on a 1-channel 572x572 input.
# Use the GPU when one is present, otherwise fall back to CPU so the cell still
# runs after "Restart & Run All" on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
summary(
    EncoderLayer(in_channels=1, out_channels=64, n_layers=2, all_padding=False).to(
        device
    ),
    input_size=(1, 572, 572),
)

Layer (type:depth-idx) Output Shape Param #
EncoderLayer [64, 568, 568] --
├─Sequential: 1-1 [64, 568, 568] --
│ └─Sequential: 2-1 [64, 570, 570] --
│ │ └─Conv2d: 3-1 [64, 570, 570] 640
│ │ └─ReLU: 3-2 [64, 570, 570] --
│ └─Sequential: 2-2 [64, 568, 568] --
│ │ └─Conv2d: 3-3 [64, 568, 568] 36,928
│ │ └─ReLU: 3-4 [64, 568, 568] --
Total params: 37,568
Trainable params: 37,568
Non-trainable params: 0
Total mult-adds (G): 1.37
Input size (MB): 1.31
Forward/backward pass size (MB): 331.53
Params size (MB): 0.15
Estimated Total Size (MB): 332.99

In [4]:
class Encoder(nn.Module):
    """Contracting path: runs the configured ``EncoderLayer`` blocks in order.

    Args:
        config: Mapping of block name -> block settings. Each settings mapping is
            read for the keys ``in_channels``, ``out_channels``, ``n_layers``,
            ``all_padding`` and ``maxpool``.
    """

    def __init__(self, config):
        super().__init__()

        blocks = {}
        for name, block in config.items():
            blocks[name] = EncoderLayer(
                in_channels=block["in_channels"],
                out_channels=block["out_channels"],
                n_layers=block["n_layers"],
                all_padding=block["all_padding"],
                maxpool=block["maxpool"],
            )

        self.encoder = nn.ModuleDict(blocks)
        # A single 2x2 pooling module is shared by every block that requests pooling.
        self.maxpool = nn.MaxPool2d(2)

    def forward(self, x):
        """Run every block and collect the per-block outputs.

        Returns:
            A tuple ``(x, output)``: ``x`` is the tensor after the last block
            (pooled if that block's ``maxpool`` flag is set) and ``output`` maps
            each block name to that block's output *before* pooling.
        """
        output = {}

        for block_name, block in self.encoder.items():
            x = block(x)
            # Store the un-pooled activation under the block's name.
            output[block_name] = x

            if block.maxpool:
                x = self.maxpool(x)

        return x, output


# Summary of the full encoder built from the loaded config on a 1-channel
# 572x572 input. Use the GPU when one is present, otherwise fall back to CPU so
# the cell does not fail on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
summary(
    Encoder(config.encoder_config).to(device),
    input_size=(1, 572, 572),
)

Layer (type:depth-idx) Output Shape Param #
Encoder [1024, 28, 28] --
├─ModuleDict: 1-9 -- (recursive)
│ └─EncoderLayer: 2-1 [64, 568, 568] --
│ │ └─Sequential: 3-1 [64, 568, 568] 37,568
├─MaxPool2d: 1-2 [64, 284, 284] --
├─ModuleDict: 1-9 -- (recursive)
│ └─EncoderLayer: 2-2 [128, 280, 280] --
│ │ └─Sequential: 3-2 [128, 280, 280] 221,440
├─MaxPool2d: 1-4 [128, 140, 140] --
├─ModuleDict: 1-9 -- (recursive)
│ └─EncoderLayer: 2-3 [256, 136, 136] --
│ │ └─Sequential: 3-3 [256, 136, 136] 885,248
├─MaxPool2d: 1-6 [256, 68, 68] --
├─ModuleDict: 1-9 -- (recursive)
│ └─EncoderLayer: 2-4 [512, 64, 64] --
│ │ └─Sequential: 3-4 [512, 64, 64] 3,539,968
├─MaxPool2d: 1-8 [512, 32, 32] --
├─ModuleDict: 1-9 -- (recursive)
│ └─EncoderLayer: 2-5 [512, 28, 28] --
│ │ └─Sequential: 3-5 [512, 28, 28] 4,719,616
│ └─EncoderLayer: 2-6 [1024, 28, 28] --
│ │ └─Sequential: 3-6 [1024, 28, 28] 14,157,824
Total params: 23,561,664
Trainable params: 23,561,664
Non-trainable params: 0
Total mult-adds (G): 633.51
Inpu

# Decoder

In [5]:
class DecoderLayer(nn.Module):
    """One expanding-path stage: up-convolution, skip concatenation, two conv + ReLU.

    Args:
        in_channels: Channels of the tensor entering the stage. The transposed
            convolution outputs ``in_channels // 2`` channels, and the first 3x3
            convolution expects ``in_channels`` channels after concatenation, so
            the skip tensor must supply the remaining channels.
        out_channels: Channels produced by both 3x3 convolutions.
        kernel_size: Kernel size of the transposed convolution.
        stride: Stride of the transposed convolution.
        padding: Two-item sequence. ``padding[0]`` is the padding of the
            transposed convolution, ``padding[1]`` the padding of both 3x3
            convolutions. Lists and tuples are both accepted.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=2, stride=2, padding=(0, 0)
    ):
        super().__init__()
        self.up_conv = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=in_channels // 2,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding[0],
        )

        self.conv = nn.Sequential(
            *[
                self._conv_relu_layer(
                    # Only the first convolution sees the concatenated tensor.
                    in_channels=in_channels if i == 0 else out_channels,
                    out_channels=out_channels,
                    padding=padding[1],
                )
                for i in range(2)
            ]
        )

    def _conv_relu_layer(self, in_channels, out_channels, padding=0):
        """Return ``Sequential(Conv2d(kernel_size=3, padding=padding), ReLU())``."""
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=padding,
            ),
            nn.ReLU(),
        )

    @staticmethod
    def crop_cat(x, encoder_output):
        """Center-crop ``encoder_output`` to the spatial size of ``x`` and concatenate.

        Height and width are cropped independently, so non-square feature maps
        are handled. The last two dimensions are treated as (height, width) and
        the third-from-last as channels, which covers both batched
        ``(N, C, H, W)`` and unbatched ``(C, H, W)`` tensors.

        Args:
            x: Upsampled tensor that defines the target height and width.
            encoder_output: Skip tensor; must be at least as large as ``x`` in
                both spatial dimensions.

        Returns:
            The cropped ``encoder_output`` followed by ``x`` along the channel
            dimension.

        Raises:
            ValueError: If ``encoder_output`` is spatially smaller than ``x``.
        """
        height, width = x.shape[-2:]
        skip_height, skip_width = encoder_output.shape[-2:]
        if skip_height < height or skip_width < width:
            raise ValueError(
                "encoder_output spatial size "
                f"{(skip_height, skip_width)} is smaller than x {(height, width)}"
            )

        # Integer division keeps the extra row/column on the bottom/right when the
        # size difference is odd.
        top = (skip_height - height) // 2
        left = (skip_width - width) // 2
        cropped = encoder_output[..., top : top + height, left : left + width]
        return torch.cat((cropped, x), dim=-3)

    def forward(self, x, encoder_output):
        """Upsample ``x``, merge it with the cropped skip tensor and convolve."""
        x = self.crop_cat(self.up_conv(x), encoder_output)
        return self.conv(x)


# Shape check for a single decoder stage: the first tensor (1, 1024, 28, 28) is
# the input to upsample, the second (1, 512, 64, 64) is the skip tensor that
# gets cropped and concatenated inside DecoderLayer.forward.
input_data = [torch.rand((1, 1024, 28, 28)), torch.rand((1, 512, 64, 64))]
summary(
 DecoderLayer(in_channels=1024, out_channels=512),
 input_data=input_data,
)

Layer (type:depth-idx) Output Shape Param #
DecoderLayer [1, 512, 52, 52] --
├─ConvTranspose2d: 1-1 [1, 512, 56, 56] 2,097,664
├─Sequential: 1-2 [1, 512, 52, 52] --
│ └─Sequential: 2-1 [1, 512, 54, 54] --
│ │ └─Conv2d: 3-1 [1, 512, 54, 54] 4,719,104
│ │ └─ReLU: 3-2 [1, 512, 54, 54] --
│ └─Sequential: 2-2 [1, 512, 52, 52] --
│ │ └─Conv2d: 3-3 [1, 512, 52, 52] 2,359,808
│ │ └─ReLU: 3-4 [1, 512, 52, 52] --
Total params: 9,176,576
Trainable params: 9,176,576
Non-trainable params: 0
Total mult-adds (G): 26.72
Input size (MB): 11.60
Forward/backward pass size (MB): 35.86
Params size (MB): 36.71
Estimated Total Size (MB): 84.17

In [6]:
class Decoder(nn.Module):
    """Expanding path: runs the configured ``DecoderLayer`` blocks in order.

    Args:
        config: Mapping of block name -> block settings. Each settings mapping is
            read for the keys ``in_channels``, ``out_channels``, ``kernel_size``,
            ``stride`` and ``padding``.
    """

    def __init__(self, config):
        super().__init__()

        blocks = {}
        for name, block in config.items():
            blocks[name] = DecoderLayer(
                in_channels=block["in_channels"],
                out_channels=block["out_channels"],
                kernel_size=block["kernel_size"],
                stride=block["stride"],
                padding=block["padding"],
            )

        self.decoder = nn.ModuleDict(blocks)

    def forward(self, x, encoder_output):
        """Feed ``x`` through every block, pairing each with its skip tensor.

        Args:
            x: Tensor entering the first decoder block.
            encoder_output: Mapping that is indexed with each decoder block's
                name to obtain the skip tensor for that block.

        Returns:
            The output of the last decoder block.
        """
        for block_name in self.decoder:
            skip = encoder_output[block_name]
            x = self.decoder[block_name](x, skip)
        return x


# Summary of the full decoder. A real encoder forward pass supplies the
# bottleneck tensor and the per-block skip tensors it consumes.
# Use the GPU when one is present, otherwise fall back to CPU so the cell does
# not fail on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder_input = torch.rand((1, 1, 572, 572), device=device)
x, encoder_output = Encoder(config.encoder_config).to(device)(encoder_input)

input_data = [x, encoder_output]
summary(
    Decoder(config.decoder_config).to(device),
    input_data=input_data,
)

Layer (type:depth-idx) Output Shape Param #
Decoder [1, 64, 388, 388] --
├─ModuleDict: 1-1 -- --
│ └─DecoderLayer: 2-1 [1, 1024, 28, 28] --
│ │ └─ConvTranspose2d: 3-1 [1, 512, 28, 28] 4,719,104
│ │ └─Sequential: 3-2 [1, 1024, 28, 28] 18,876,416
│ └─DecoderLayer: 2-2 [1, 512, 52, 52] --
│ │ └─ConvTranspose2d: 3-3 [1, 512, 56, 56] 2,097,664
│ │ └─Sequential: 3-4 [1, 512, 52, 52] 7,078,912
│ └─DecoderLayer: 2-3 [1, 256, 100, 100] --
│ │ └─ConvTranspose2d: 3-5 [1, 256, 104, 104] 524,544
│ │ └─Sequential: 3-6 [1, 256, 100, 100] 1,769,984
│ └─DecoderLayer: 2-4 [1, 128, 196, 196] --
│ │ └─ConvTranspose2d: 3-7 [1, 128, 200, 200] 131,200
│ │ └─Sequential: 3-8 [1, 128, 196, 196] 442,624
│ └─DecoderLayer: 2-5 [1, 64, 388, 388] --
│ │ └─ConvTranspose2d: 3-9 [1, 64, 392, 392] 32,832
│ │ └─Sequential: 3-10 [1, 64, 388, 388] 110,720
Total params: 35,784,000
Trainable params: 35,784,000
Non-trainable params: 0
Total mult-adds (G): 113.38
Input size (MB): 158.09
Forward/backward pass size (MB): 469.93


# UNet

In [7]:
class UNet(nn.Module):
    """Encoder + decoder + 1x1 output convolution.

    Args:
        encoder_config: Configuration passed to ``Encoder``.
        decoder_config: Configuration passed to ``Decoder``. Its
            ``["block1"]["out_channels"]`` entry sets the input width of the
            output convolution.
        nclasses: Number of output channels of the final 1x1 convolution.
    """

    def __init__(self, encoder_config, decoder_config, nclasses):
        super().__init__()
        self.encoder = Encoder(config=encoder_config)
        self.decoder = Decoder(config=decoder_config)

        head_in_channels = decoder_config["block1"]["out_channels"]
        self.output = nn.Conv2d(
            in_channels=head_in_channels,
            out_channels=nclasses,
            kernel_size=1,
        )

    def forward(self, x):
        """Encode ``x``, decode it with the encoder's per-block outputs, then project."""
        bottleneck, skips = self.encoder(x)
        decoded = self.decoder(bottleneck, skips)
        return self.output(decoded)


# Build the complete network from the loaded config and print its layer summary
# for one random 1-channel 572x572 input (batch size 1).
summary(
 UNet(
 config["encoder_config"], config["decoder_config"], nclasses=config["nclasses"]
 ),
 input_data=torch.rand((1, 1, 572, 572)),
)

Layer (type:depth-idx) Output Shape Param #
UNet [1, 2, 388, 388] --
├─Encoder: 1-1 [1, 1024, 28, 28] --
│ └─ModuleDict: 2-9 -- (recursive)
│ │ └─EncoderLayer: 3-1 [1, 64, 568, 568] 37,568
│ └─MaxPool2d: 2-2 [1, 64, 284, 284] --
│ └─ModuleDict: 2-9 -- (recursive)
│ │ └─EncoderLayer: 3-2 [1, 128, 280, 280] 221,440
│ └─MaxPool2d: 2-4 [1, 128, 140, 140] --
│ └─ModuleDict: 2-9 -- (recursive)
│ │ └─EncoderLayer: 3-3 [1, 256, 136, 136] 885,248
│ └─MaxPool2d: 2-6 [1, 256, 68, 68] --
│ └─ModuleDict: 2-9 -- (recursive)
│ │ └─EncoderLayer: 3-4 [1, 512, 64, 64] 3,539,968
│ └─MaxPool2d: 2-8 [1, 512, 32, 32] --
│ └─ModuleDict: 2-9 -- (recursive)
│ │ └─EncoderLayer: 3-5 [1, 512, 28, 28] 4,719,616
│ │ └─EncoderLayer: 3-6 [1, 1024, 28, 28] 14,157,824
├─Decoder: 1-2 [1, 64, 388, 388] --
│ └─ModuleDict: 2-10 -- --
│ │ └─DecoderLayer: 3-7 [1, 1024, 28, 28] 23,595,520
│ │ └─DecoderLayer: 3-8 [1, 512, 52, 52] 9,176,576
│ │ └─DecoderLayer: 3-9 [1, 256, 100, 100] 2,294,528
│ │ └─DecoderLayer: 3-10 [1, 128, 1

In [13]:
# # if config is a dict
# with open('custom_config.yml', 'w') as outfile:
# yaml.dump(config, outfile, sort_keys=False)