{ "cells": [
  { "cell_type": "code", "execution_count": 1, "id": "310eb987-37b7-4620-b533-089644fbb440", "metadata": {}, "outputs": [], "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import yaml\n",
    "from easydict import EasyDict\n",
    "from torchinfo import summary"
  ] },
  { "cell_type": "code", "execution_count": 2, "id": "f8cff897-df8f-4e6d-893b-321805699e1b", "metadata": {}, "outputs": [], "source": [
    "# Model hyper-parameters: the encoder/decoder block specs and the number of classes\n",
    "# (read below as config.encoder_config, config.decoder_config and config[\"nclasses\"]).\n",
    "config_path = \"./config/paper_config.yml\"\n",
    "\n",
    "with open(config_path, \"r\") as file:\n",
    "    yaml_data = yaml.safe_load(file)\n",
    "\n",
    "# EasyDict lets later cells use attribute access (config.encoder_config) as well as item access.\n",
    "config = EasyDict(yaml_data)"
  ] },
  { "cell_type": "markdown", "id": "ca66846e-d2b4-4dd2-83eb-eee746c26c74", "metadata": {}, "source": [ "# Encoder" ] },
  { "cell_type": "code", "execution_count": 3, "id": "975a6f86-68ff-4fda-b7d8-acf453addade", "metadata": {}, "outputs": [ { "data": { "text/plain": [
    "==========================================================================================\n",
    "Layer (type:depth-idx) Output Shape Param #\n",
    "==========================================================================================\n",
    "EncoderLayer [64, 568, 568] --\n",
    "├─Sequential: 1-1 [64, 568, 568] --\n",
    "│ └─Sequential: 2-1 [64, 570, 570] --\n",
    "│ │ └─Conv2d: 3-1 [64, 570, 570] 640\n",
    "│ │ └─ReLU: 3-2 [64, 570, 570] --\n",
    "│ └─Sequential: 2-2 [64, 568, 568] --\n",
    "│ │ └─Conv2d: 3-3 [64, 568, 568] 36,928\n",
    "│ │ └─ReLU: 3-4 [64, 568, 568] --\n",
    "==========================================================================================\n",
    "Total params: 37,568\n",
    "Trainable params: 37,568\n",
    "Non-trainable params: 0\n",
    "Total mult-adds (G): 1.37\n",
    "==========================================================================================\n",
    "Input size (MB): 1.31\n",
    "Forward/backward pass size (MB): 331.53\n",
    "Params size (MB): 0.15\n",
    "Estimated Total Size (MB): 332.99\n",
"==========================================================================================" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [
    "# Encoder building block (first half of the 'U' in UNet) [ENCODER].\n",
    "\n",
    "\n",
    "class EncoderLayer(nn.Module):\n",
    "    \"\"\"A stack of ``n_layers`` (3x3 Conv2d -> ReLU) units; no pooling happens here.\n",
    "\n",
    "    Args:\n",
    "        in_channels: channels of the input to the first convolution.\n",
    "        out_channels: channels produced by every convolution in the stack.\n",
    "        n_layers: number of Conv2d + ReLU units.\n",
    "        all_padding: if True every convolution uses padding=1; otherwise the\n",
    "            first two convolutions are unpadded (each shrinks H and W by 2)\n",
    "            and any further ones use padding=1.\n",
    "        maxpool: only stored as ``self.maxpool``; forward() never pools.\n",
    "            ``Encoder`` reads the flag to decide whether to pool afterwards.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        in_channels=1,\n",
    "        out_channels=64,\n",
    "        n_layers=2,\n",
    "        all_padding=False,\n",
    "        maxpool=True,\n",
    "    ):\n",
    "        super().__init__()\n",
    "\n",
    "        def layer_in_channels(layer):\n",
    "            # Only the first convolution sees the block's input channels.\n",
    "            return in_channels if layer == 0 else out_channels\n",
    "\n",
    "        def layer_padding(layer):\n",
    "            return 1 if layer >= 2 or all_padding else 0\n",
    "\n",
    "        self.layer = nn.Sequential(\n",
    "            *[\n",
    "                self._conv_relu_layer(\n",
    "                    in_channels=layer_in_channels(i),\n",
    "                    out_channels=out_channels,\n",
    "                    padding=layer_padding(i),\n",
    "                )\n",
    "                for i in range(n_layers)\n",
    "            ]\n",
    "        )\n",
    "        self.maxpool = maxpool\n",
    "\n",
    "    def _conv_relu_layer(self, in_channels, out_channels, padding=0):\n",
    "        \"\"\"Return a 3x3 Conv2d followed by a ReLU.\"\"\"\n",
    "        return nn.Sequential(\n",
    "            nn.Conv2d(\n",
    "                in_channels=in_channels,\n",
    "                out_channels=out_channels,\n",
    "                kernel_size=3,\n",
    "                padding=padding,\n",
    "            ),\n",
    "            nn.ReLU(),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.layer(x)\n",
    "\n",
    "\n",
    "# Fall back to CPU so the cell also runs on machines without a GPU.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "summary(\n",
    "    EncoderLayer(in_channels=1, out_channels=64, n_layers=2, all_padding=False).to(device),\n",
    "    input_size=(1, 572, 572),\n",
    "    device=device,\n",
    ")"
  ] },
  { "cell_type": "code", "execution_count": 4, "id": "4eb7eedd-6530-44e2-9486-fbd8f39fd0ad", "metadata": {}, "outputs": [ { "data": { "text/plain": [
    "==========================================================================================\n",
    "Layer (type:depth-idx) Output Shape Param #\n",
    "==========================================================================================\n",
    "Encoder [1024, 28, 28] --\n",
    "├─ModuleDict: 1-9 -- (recursive)\n",
    "│ └─EncoderLayer: 2-1 [64, 568, 568] --\n",
    "│ │ └─Sequential: 3-1 [64, 568, 568] 37,568\n",
    "├─MaxPool2d: 1-2 [64, 284, 284] --\n",
    "├─ModuleDict: 1-9 -- (recursive)\n",
    "│ └─EncoderLayer: 2-2 [128, 280, 280] --\n",
    "│ │ └─Sequential: 3-2 [128, 280, 280] 221,440\n",
    "├─MaxPool2d: 1-4 [128, 140, 140] --\n",
    "├─ModuleDict: 1-9 -- (recursive)\n",
    "│ └─EncoderLayer: 2-3 [256, 136, 136] --\n",
    "│ │ └─Sequential: 3-3 [256, 136, 136] 885,248\n",
    "├─MaxPool2d: 1-6 [256, 68, 68] --\n",
    "├─ModuleDict: 1-9 -- (recursive)\n",
    "│ └─EncoderLayer: 2-4 [512, 64, 64] --\n",
    "│ │ └─Sequential: 3-4 [512, 64, 64] 3,539,968\n",
    "├─MaxPool2d: 1-8 [512, 32, 32] --\n",
    "├─ModuleDict: 1-9 -- (recursive)\n",
    "│ └─EncoderLayer: 2-5 [512, 28, 28] --\n",
    "│ │ └─Sequential: 3-5 [512, 28, 28] 4,719,616\n",
    "│ └─EncoderLayer: 2-6 [1024, 28, 28] --\n",
    "│ │ └─Sequential: 3-6 [1024, 28, 28] 14,157,824\n",
    "==========================================================================================\n",
    "Total params: 23,561,664\n",
    "Trainable params: 23,561,664\n",
    "Non-trainable params: 0\n",
    "Total mult-adds (G): 633.51\n",
    "==========================================================================================\n",
    "Input size (MB): 1.31\n",
    "Forward/backward pass size (MB): 624.49\n",
    "Params size (MB): 94.25\n",
    "Estimated Total Size (MB): 720.05\n",
    "=========================================================================================="
  ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [
    "class Encoder(nn.Module):\n",
    "    \"\"\"Contracting path built from the ``EncoderLayer`` blocks described in ``config``.\n",
    "\n",
    "    ``config`` maps block name -> dict with keys ``in_channels``,\n",
    "    ``out_channels``, ``n_layers``, ``all_padding`` and ``maxpool``; blocks\n",
    "    run in the mapping's iteration order.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.encoder = nn.ModuleDict(\n",
    "            {\n",
    "                name: EncoderLayer(\n",
    "                    in_channels=block[\"in_channels\"],\n",
    "                    out_channels=block[\"out_channels\"],\n",
    "                    n_layers=block[\"n_layers\"],\n",
    "                    all_padding=block[\"all_padding\"],\n",
    "                    maxpool=block[\"maxpool\"],\n",
    "                )\n",
    "                for name, block in config.items()\n",
    "            }\n",
    "        )\n",
    "        # One shared 2x2 pool, applied after every block whose maxpool flag is set.\n",
    "        self.maxpool = nn.MaxPool2d(2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return ``(x, output)``: the last feature map, and a dict mapping each\n",
    "        block name to that block's output *before* pooling (the skip features\n",
    "        the decoder looks up by the same name).\"\"\"\n",
    "        output = {}\n",
    "        for block_name, block in self.encoder.items():\n",
    "            x = block(x)\n",
    "            output[block_name] = x\n",
    "            if block.maxpool:\n",
    "                x = self.maxpool(x)\n",
    "        return x, output\n",
    "\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "summary(\n",
    "    Encoder(config.encoder_config).to(device),\n",
    "    input_size=(1, 572, 572),\n",
    "    device=device,\n",
    ")"
  ] },
  { "cell_type": "markdown", "id": "a7ad06cb-61a2-4a66-ba58-f29d402a81f2", "metadata": {}, "source": [ "# Decoder" ] },
  { "cell_type": "code", "execution_count": 5, "id": "735322d0-0dc3-4137-b906-ac7e54c43a79", "metadata": {}, "outputs": [ { "data": { "text/plain": [
    "==========================================================================================\n",
    "Layer (type:depth-idx) Output Shape Param #\n",
    "==========================================================================================\n",
    "DecoderLayer [1, 512, 52, 52] --\n",
    "├─ConvTranspose2d: 1-1 [1, 512, 56, 56] 2,097,664\n",
    "├─Sequential: 1-2 [1, 512, 52, 52] --\n",
    "│ └─Sequential: 2-1 [1, 512, 54, 54] --\n",
    "│ │ └─Conv2d: 3-1 [1, 512, 54, 54] 4,719,104\n",
    "│ │ └─ReLU: 3-2 [1, 512, 54, 54] --\n",
    "│ └─Sequential: 2-2 [1, 512, 52, 52] --\n",
    "│ │ └─Conv2d: 3-3 [1, 512, 52, 52] 2,359,808\n",
    "│ │ └─ReLU: 3-4 [1, 512, 52, 52] --\n",
    "==========================================================================================\n",
    "Total params: 9,176,576\n",
    "Trainable params: 9,176,576\n",
    "Non-trainable params: 0\n",
    "Total mult-adds (G): 26.72\n",
    "==========================================================================================\n",
    "Input size (MB): 11.60\n",
    "Forward/backward pass size (MB): 35.86\n",
    "Params size (MB): 36.71\n",
    "Estimated Total Size (MB): 84.17\n",
    "=========================================================================================="
  ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [
    "class DecoderLayer(nn.Module):\n",
    "    \"\"\"One expanding-path step: up-convolve, crop-and-concatenate the skip\n",
    "    features, then two (3x3 Conv2d -> ReLU) units.\n",
    "\n",
    "    Args:\n",
    "        in_channels: channels of the incoming feature map. The up-convolution\n",
    "            outputs ``in_channels // 2`` channels while the first 3x3 conv\n",
    "            expects ``in_channels``, so the skip tensor must provide\n",
    "            ``in_channels - in_channels // 2`` channels.\n",
    "        out_channels: channels produced by both 3x3 convolutions.\n",
    "        kernel_size, stride: ConvTranspose2d parameters.\n",
    "        padding: indexable pair ``(up_conv_padding, conv_padding)``, e.g. a\n",
    "            list from the YAML config or a tuple.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self, in_channels, out_channels, kernel_size=2, stride=2, padding=(0, 0)\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.up_conv = nn.ConvTranspose2d(\n",
    "            in_channels=in_channels,\n",
    "            out_channels=in_channels // 2,\n",
    "            kernel_size=kernel_size,\n",
    "            stride=stride,\n",
    "            padding=padding[0],\n",
    "        )\n",
    "\n",
    "        self.conv = nn.Sequential(\n",
    "            *[\n",
    "                self._conv_relu_layer(\n",
    "                    in_channels=in_channels if i == 0 else out_channels,\n",
    "                    out_channels=out_channels,\n",
    "                    padding=padding[1],\n",
    "                )\n",
    "                for i in range(2)\n",
    "            ]\n",
    "        )\n",
    "\n",
    "    def _conv_relu_layer(self, in_channels, out_channels, padding=0):\n",
    "        \"\"\"Return a 3x3 Conv2d followed by a ReLU.\"\"\"\n",
    "        return nn.Sequential(\n",
    "            nn.Conv2d(\n",
    "                in_channels=in_channels,\n",
    "                out_channels=out_channels,\n",
    "                kernel_size=3,\n",
    "                padding=padding,\n",
    "            ),\n",
    "            nn.ReLU(),\n",
    "        )\n",
    "\n",
    "    @staticmethod\n",
    "    def crop_cat(x, encoder_output):\n",
    "        \"\"\"Center-crop ``encoder_output`` to the spatial size of ``x`` and\n",
    "        concatenate it in front of ``x`` along dim 1 (channels).\n",
    "\n",
    "        Height and width are cropped independently, so non-square feature\n",
    "        maps are handled correctly.\n",
    "        \"\"\"\n",
    "        height, width = x.shape[-2], x.shape[-1]\n",
    "        top = (encoder_output.shape[-2] - height) // 2\n",
    "        left = (encoder_output.shape[-1] - width) // 2\n",
    "        encoder_output = encoder_output[\n",
    "            :, :, top : top + height, left : left + width\n",
    "        ]\n",
    "        return torch.cat((encoder_output, x), dim=1)\n",
    "\n",
    "    def forward(self, x, encoder_output):\n",
    "        x = self.crop_cat(self.up_conv(x), encoder_output)\n",
    "        return self.conv(x)\n",
    "\n",
    "\n",
    "# summary\n",
    "input_data = [torch.rand((1, 1024, 28, 28)), torch.rand((1, 512, 64, 64))]\n",
    "summary(\n",
    "    DecoderLayer(in_channels=1024, out_channels=512),\n",
    "    input_data=input_data,\n",
    ")"
  ] },
  { "cell_type": "code", "execution_count": 6, "id": "3795e85d-ff83-457c-9c12-af6cc6e2830c", "metadata": {}, "outputs": [ { "data": { "text/plain": [
    "==========================================================================================\n",
    "Layer (type:depth-idx) Output Shape Param #\n",
    "==========================================================================================\n",
    "Decoder [1, 64, 388, 388] --\n",
    "├─ModuleDict: 1-1 -- --\n",
    "│ └─DecoderLayer: 2-1 [1, 1024, 28, 28] --\n",
    "│ │ └─ConvTranspose2d: 3-1 [1, 512, 28, 28] 4,719,104\n",
    "│ │ └─Sequential: 3-2 [1, 1024, 28, 28] 18,876,416\n",
"│ └─DecoderLayer: 2-2 [1, 512, 52, 52] --\n",
    "│ │ └─ConvTranspose2d: 3-3 [1, 512, 56, 56] 2,097,664\n",
    "│ │ └─Sequential: 3-4 [1, 512, 52, 52] 7,078,912\n",
    "│ └─DecoderLayer: 2-3 [1, 256, 100, 100] --\n",
    "│ │ └─ConvTranspose2d: 3-5 [1, 256, 104, 104] 524,544\n",
    "│ │ └─Sequential: 3-6 [1, 256, 100, 100] 1,769,984\n",
    "│ └─DecoderLayer: 2-4 [1, 128, 196, 196] --\n",
    "│ │ └─ConvTranspose2d: 3-7 [1, 128, 200, 200] 131,200\n",
    "│ │ └─Sequential: 3-8 [1, 128, 196, 196] 442,624\n",
    "│ └─DecoderLayer: 2-5 [1, 64, 388, 388] --\n",
    "│ │ └─ConvTranspose2d: 3-9 [1, 64, 392, 392] 32,832\n",
    "│ │ └─Sequential: 3-10 [1, 64, 388, 388] 110,720\n",
    "==========================================================================================\n",
    "Total params: 35,784,000\n",
    "Trainable params: 35,784,000\n",
    "Non-trainable params: 0\n",
    "Total mult-adds (G): 113.38\n",
    "==========================================================================================\n",
    "Input size (MB): 158.09\n",
    "Forward/backward pass size (MB): 469.93\n",
    "Params size (MB): 143.14\n",
    "Estimated Total Size (MB): 771.16\n",
    "=========================================================================================="
  ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [
    "class Decoder(nn.Module):\n",
    "    \"\"\"Expanding path built from the ``DecoderLayer`` blocks described in ``config``.\n",
    "\n",
    "    ``config`` maps block name -> dict with keys ``in_channels``,\n",
    "    ``out_channels``, ``kernel_size``, ``stride`` and ``padding``. Blocks run\n",
    "    in the mapping's iteration order, and every block name must also be a key\n",
    "    of the encoder's output dict: each block is given the skip features\n",
    "    stored under its own name.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.decoder = nn.ModuleDict(\n",
    "            {\n",
    "                name: DecoderLayer(\n",
    "                    in_channels=block[\"in_channels\"],\n",
    "                    out_channels=block[\"out_channels\"],\n",
    "                    kernel_size=block[\"kernel_size\"],\n",
    "                    stride=block[\"stride\"],\n",
    "                    padding=block[\"padding\"],\n",
    "                )\n",
    "                for name, block in config.items()\n",
    "            }\n",
    "        )\n",
    "\n",
    "    def forward(self, x, encoder_output):\n",
    "        \"\"\"Run each block on ``x`` together with ``encoder_output[name]``.\"\"\"\n",
    "        for name, block in self.decoder.items():\n",
    "            x = block(x, encoder_output[name])\n",
    "        return x\n",
    "\n",
    "\n",
    "# summary: build realistic decoder inputs by running the encoder once.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "encoder_input = torch.rand((1, 1, 572, 572), device=device)\n",
    "# This pass only produces example inputs, so don't build an autograd graph for it.\n",
    "with torch.no_grad():\n",
    "    x, encoder_output = Encoder(config.encoder_config).to(device)(encoder_input)\n",
    "\n",
    "input_data = [x, encoder_output]\n",
    "summary(\n",
    "    Decoder(config.decoder_config).to(device),\n",
    "    input_data=input_data,\n",
    "    device=device,\n",
    ")"
  ] },
  { "cell_type": "markdown", "id": "6cd06e02-abd4-4537-8bce-5a15c4ad4f85", "metadata": {}, "source": [ "# UNet" ] },
  { "cell_type": "code", "execution_count": 7, "id": "24fd0355-3603-4a55-b827-068eda70b78a", "metadata": {}, "outputs": [ { "data": { "text/plain": [
    "===============================================================================================\n",
    "Layer (type:depth-idx) Output Shape Param #\n",
    "===============================================================================================\n",
    "UNet [1, 2, 388, 388] --\n",
    "├─Encoder: 1-1 [1, 1024, 28, 28] --\n",
    "│ └─ModuleDict: 2-9 -- (recursive)\n",
    "│ │ └─EncoderLayer: 3-1 [1, 64, 568, 568] 37,568\n",
    "│ └─MaxPool2d: 2-2 [1, 64, 284, 284] --\n",
    "│ └─ModuleDict: 2-9 -- (recursive)\n",
    "│ │ └─EncoderLayer: 3-2 [1, 128, 280, 280] 221,440\n",
    "│ └─MaxPool2d: 2-4 [1, 128, 140, 140] --\n",
    "│ └─ModuleDict: 2-9 -- (recursive)\n",
    "│ │ └─EncoderLayer: 3-3 [1, 256, 136, 136] 885,248\n",
    "│ └─MaxPool2d: 2-6 [1, 256, 68, 68] --\n",
    "│ └─ModuleDict: 2-9 -- (recursive)\n",
    "│ │ └─EncoderLayer: 3-4 [1, 512, 64, 64] 3,539,968\n",
    "│ └─MaxPool2d: 2-8 [1, 512, 32, 32] --\n",
    "│ └─ModuleDict: 2-9 -- (recursive)\n",
    "│ │ └─EncoderLayer: 3-5 [1, 512, 28, 28] 4,719,616\n",
    "│ │ └─EncoderLayer: 3-6 [1, 1024, 28, 28] 14,157,824\n",
    "├─Decoder: 1-2 [1, 64, 388, 388] --\n",
    "│ └─ModuleDict: 2-10 -- --\n",
    "│ │ └─DecoderLayer: 3-7 [1, 1024, 28, 28] 23,595,520\n",
    "│ │ └─DecoderLayer: 3-8 [1, 512, 52, 52] 9,176,576\n",
    "│ │ └─DecoderLayer: 3-9 [1, 256, 100, 100] 2,294,528\n",
    "│ │ └─DecoderLayer: 3-10 [1, 128, 196, 196] 573,824\n",
    "│ │ └─DecoderLayer: 3-11 [1, 64, 388, 388] 143,552\n",
    "├─Conv2d: 1-3 [1, 2, 388, 388] 130\n",
    "===============================================================================================\n",
"Total params: 59,345,794\n",
    "Trainable params: 59,345,794\n",
    "Non-trainable params: 0\n",
    "Total mult-adds (G): 189.38\n",
    "===============================================================================================\n",
    "Input size (MB): 1.31\n",
    "Forward/backward pass size (MB): 1096.83\n",
    "Params size (MB): 237.38\n",
    "Estimated Total Size (MB): 1335.52\n",
    "==============================================================================================="
  ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [
    "class UNet(nn.Module):\n",
    "    \"\"\"UNet: ``Encoder`` -> ``Decoder`` -> 1x1 convolution with ``nclasses`` output channels.\n",
    "\n",
    "    Args:\n",
    "        encoder_config: block specs passed to ``Encoder``.\n",
    "        decoder_config: block specs passed to ``Decoder``; the ``out_channels``\n",
    "            of its last block sets the input width of the 1x1 output conv.\n",
    "        nclasses: number of output channels.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, encoder_config, decoder_config, nclasses):\n",
    "        super().__init__()\n",
    "        self.encoder = Encoder(config=encoder_config)\n",
    "        self.decoder = Decoder(config=decoder_config)\n",
    "\n",
    "        # The head consumes the output of the *last* decoder block (blocks run in\n",
    "        # config order), so take its width from that block instead of assuming\n",
    "        # the final block is named \"block1\".\n",
    "        last_decoder_block = list(decoder_config.values())[-1]\n",
    "        self.output = nn.Conv2d(\n",
    "            in_channels=last_decoder_block[\"out_channels\"],\n",
    "            out_channels=nclasses,\n",
    "            kernel_size=1,\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x, encoder_step_output = self.encoder(x)\n",
    "        x = self.decoder(x, encoder_step_output)\n",
    "        return self.output(x)\n",
    "\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "summary(\n",
    "    UNet(\n",
    "        config[\"encoder_config\"], config[\"decoder_config\"], nclasses=config[\"nclasses\"]\n",
    "    ).to(device),\n",
    "    input_data=torch.rand((1, 1, 572, 572), device=device),\n",
    "    device=device,\n",
    ")"
  ] },
  { "cell_type": "code", "execution_count": 13, "id": "550824e4-2151-4c0b-8a12-383fa092b4ac", "metadata": {}, "outputs": [], "source": [
    "# # if config is a dict\n",
    "# with open('custom_config.yml', 'w') as outfile:\n",
    "#     yaml.dump(config, outfile, sort_keys=False)"
  ] }
 ],
 "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } },
 "nbformat": 4,
 "nbformat_minor": 5
}