{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "os.chdir('../')" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'c:\\\\mlops project\\\\image-colorization-mlops'" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%pwd" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from dataclasses import dataclass\n", "from pathlib import Path\n", "\n", "@dataclass(frozen=True)\n", "class ModelBuildingConfig:\n", " root_dir: Path\n", " KERNEL_SIZE_RES: int\n", " PADDING: int\n", " STRIDE: int\n", " BIAS: bool\n", " SCALE_FACTOR: int\n", " DIM: int\n", " DROPOUT_RATE: float\n", " KERNEL_SIZE_GENERATOR: int\n", " INPUT_CHANNELS: int\n", " OUTPUT_CHANNELS: int\n", " IN_CHANNELS: int\n", "\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from src.imagecolorization.constants import *\n", "from src.imagecolorization.utils.common import read_yaml, create_directories\n", "\n", "class ConfigurationManager:\n", " def __init__(self, config_filepath=CONFIG_FILE_PATH, params_filepath=PARAMS_FILE_PATH):\n", " self.config = read_yaml(config_filepath)\n", " self.params = read_yaml(params_filepath)\n", " create_directories([self.config.artifacts_root])\n", "\n", " def get_model_building_config(self) -> ModelBuildingConfig:\n", " config = self.config.model_building\n", " params = self.params\n", "\n", " model_building_config = ModelBuildingConfig(\n", " root_dir=Path(config.root_dir),\n", " KERNEL_SIZE_RES=params.KERNEL_SIZE_RES,\n", " PADDING=params.PADDING,\n", " STRIDE=params.STRIDE,\n", " BIAS=params.BIAS,\n", " SCALE_FACTOR=params.SCALE_FACTOR,\n", " DIM=params.DIM,\n", " DROPOUT_RATE=params.DROPOUT_RATE,\n", " KERNEL_SIZE_GENERATOR=params.KERNEL_SIZE_GENERATOR,\n", " INPUT_CHANNELS=params.INPUT_CHANNELS,\n", " OUTPUT_CHANNELS=params.OUTPUT_CHANNELS,\n", " IN_CHANNELS=params.IN_CHANNELS\n", " )\n", " return model_building_config\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import torch \n", "import torch.nn as nn\n", "from pathlib import Path\n", "\n", "class ResBlock(nn.Module):\n", " def __init__(self, in_channles, out_channels, stride = 1, kerenl_size = 3, padding = 1, bias = False):\n", " super().__init__()\n", " self.layer = nn.Sequential(\n", " nn.Conv2d(in_channles, out_channels, kernel_size=kerenl_size, padding=padding, stride=stride, bias = bias),\n", " nn.BatchNorm2d(out_channels),\n", " nn.ReLU(inplace=True),\n", " nn.Conv2d(out_channels, out_channels, kernel_size=kerenl_size, padding=padding, stride = 1, bias = bias),\n", " nn.BatchNorm2d(out_channels),\n", " nn.ReLU(inplace=True)\n", " )\n", " \n", " self.identity_map = nn.Conv2d(in_channles, out_channels,kernel_size=1, stride=stride)\n", " self.relu = nn.ReLU(inplace= True)\n", " \n", " def forward(self, inputs):\n", " x = inputs.clone().detach()\n", " out = self.layer(x)\n", " residual = self.identity_map(inputs)\n", " skip = out + residual\n", " return self.relu(skip)\n", " \n", " \n", "class DownsampleConv(nn.Module):\n", " def __init__(self, in_channels, out_channels, stride = 1):\n", " super().__init__()\n", " self.layer = nn.Sequential(\n", " nn.MaxPool2d(2),\n", " ResBlock(in_channels, out_channels)\n", " )\n", " \n", " def forward(self, inputs):\n", " return self.layer(inputs)\n", " \n", " \n", " \n", "class UpsampleConv(nn.Module):\n", " def __init__(self, in_channels, out_channels, scale_factor=2):\n", " super().__init__()\n", " self.upsample = nn.Upsample(scale_factor=scale_factor,mode = 'bilinear', align_corners=True)\n", " self.res_block = ResBlock(in_channels + out_channels, out_channels)\n", "\n", " def forward(self, inputs, skip):\n", " x = self.upsample(inputs)\n", " x = torch.cat([x, skip], dim = 1)\n", " x = self.res_block(x)\n", " return x\n", " \n", "class Generator(nn.Module):\n", " def __init__(self, input_channels, output_channels, dropout_rate = 0.2):\n", " super().__init__()\n", " self.encoding_layer1_= ResBlock(input_channels, 64)\n", " self.encoding_layer2_ = DownsampleConv(64, 128)\n", " self.encoding_layer3_ = DownsampleConv(128, 256)\n", " self.bridge = DownsampleConv(256, 512)\n", " self.decoding_layer3 = UpsampleConv(512, 256)\n", " self.decoding_layer2 = UpsampleConv(256, 128)\n", " self.decoding_layer1 = UpsampleConv(128 , 64)\n", " self.output = nn.Conv2d(64, output_channels, kernel_size = 1)\n", " self.dropout = nn.Dropout2d(dropout_rate)\n", " \n", " def forward(self, inputs):\n", " e1 = self.encoding_layer1_(inputs)\n", " e1 = self.dropout(e1)\n", " e2 = self.encoding_layer2_(e1)\n", " e2 = self.dropout(e2)\n", " e3 = self.encoding_layer3_(e2)\n", " e3 = self.dropout(e3)\n", " \n", " bridge = self.bridge(e3)\n", " bridge = self.dropout(bridge)\n", " \n", " d3 = self.decoding_layer3(bridge, e3)\n", " d2 =self.decoding_layer2(d3, e2)\n", " d1 = self.decoding_layer1(d2, e1)\n", " \n", " output = self.dropout(d1)\n", " return output\n", " \n", " \n", "class Critic(nn.Module):\n", " def __init__(self, in_channels):\n", " super(Critic, self).__init__()\n", "\n", " def critic_block(in_filters, out_filters, normalization=True):\n", " layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]\n", " if normalization:\n", " layers.append(nn.InstanceNorm2d(out_filters))\n", " layers.append(nn.LeakyReLU(0.2, inplace=True))\n", " return layers\n", "\n", " self.model = nn.Sequential(\n", " *critic_block(in_channels, 64, normalization=False),\n", " *critic_block(64, 128),\n", " *critic_block(128, 256),\n", " *critic_block(256, 512),\n", " nn.AdaptiveAvgPool2d(1),\n", " nn.Flatten(),\n", " nn.Linear(512, 1)\n", " )\n", "\n", " def forward(self, img_input):\n", " output = self.model(img_input)\n", " return output\n", " " ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "from torchsummary import summary\n", "import torch\n", "import os\n", "\n", "class ModelBuilding:\n", " def __init__(self, config: ModelBuildingConfig):\n", " self.config = config\n", " self.root_dir = self.config.root_dir\n", " self.create_root_dir()\n", " self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", " def create_root_dir(self):\n", " os.makedirs(self.root_dir, exist_ok=True)\n", " print(f\"Created directory: {self.root_dir}\")\n", "\n", " def get_generator(self):\n", " return Generator(\n", " input_channels=self.config.INPUT_CHANNELS, # corrected argument name\n", " output_channels=self.config.OUTPUT_CHANNELS, # corrected argument name\n", " dropout_rate=self.config.DROPOUT_RATE\n", " ).to(self.device)\n", "\n", " def get_critic(self):\n", " return Critic(in_channels=self.config.IN_CHANNELS).to(self.device)\n", "\n", " def build(self):\n", " generator = self.get_generator()\n", " critic = self.get_critic()\n", " return generator, critic\n", "\n", " def save_model(self, model, filename):\n", " path = self.root_dir / filename\n", " torch.save(model.state_dict(), path)\n", " print(f\"Model saved to {path}\")\n", "\n", " def display_summary(self, model, input_size):\n", " print(f\"\\nModel Summary:\")\n", " summary(model, input_size)\n", "\n", " def build_and_save(self):\n", " generator, critic = self.build()\n", "\n", " # Display summaries\n", " print(\"\\nGenerator Summary:\")\n", " self.display_summary(generator, (self.config.INPUT_CHANNELS, 224, 224)) # Assuming input size is 224x224\n", "\n", " print(\"\\nCritic Summary:\")\n", " self.display_summary(critic, [(2, 224, 224), (1, 224, 224)]) # Critic takes two inputs: ab and l\n", "\n", " self.save_model(generator, \"generator.pth\")\n", " self.save_model(critic, \"critic.pth\")\n", " return generator, critic\n", "\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2024-08-23 00:00:44,340: INFO: common: yaml file: config\\config.yaml loaded successfully]\n", "[2024-08-23 00:00:44,342: INFO: common: yaml file: params.yaml loaded successfully]\n", "[2024-08-23 00:00:44,343: INFO: common: created directory at: artifacts]\n", "Created directory: artifacts\\model\n", "\n", "Generator Summary:\n", "\n", "Model Summary:\n", "----------------------------------------------------------------\n", " Layer (type) Output Shape Param #\n", "================================================================\n", " Conv2d-1 [-1, 64, 224, 224] 576\n", " BatchNorm2d-2 [-1, 64, 224, 224] 128\n", " ReLU-3 [-1, 64, 224, 224] 0\n", " Conv2d-4 [-1, 64, 224, 224] 36,864\n", " BatchNorm2d-5 [-1, 64, 224, 224] 128\n", " ReLU-6 [-1, 64, 224, 224] 0\n", " Conv2d-7 [-1, 64, 224, 224] 128\n", " ReLU-8 [-1, 64, 224, 224] 0\n", " ResBlock-9 [-1, 64, 224, 224] 0\n", " Dropout2d-10 [-1, 64, 224, 224] 0\n", " MaxPool2d-11 [-1, 64, 112, 112] 0\n", " Conv2d-12 [-1, 128, 112, 112] 73,728\n", " BatchNorm2d-13 [-1, 128, 112, 112] 256\n", " ReLU-14 [-1, 128, 112, 112] 0\n", " Conv2d-15 [-1, 128, 112, 112] 147,456\n", " BatchNorm2d-16 [-1, 128, 112, 112] 256\n", " ReLU-17 [-1, 128, 112, 112] 0\n", " Conv2d-18 [-1, 128, 112, 112] 8,320\n", " ReLU-19 [-1, 128, 112, 112] 0\n", " ResBlock-20 [-1, 128, 112, 112] 0\n", " DownsampleConv-21 [-1, 128, 112, 112] 0\n", " Dropout2d-22 [-1, 128, 112, 112] 0\n", " MaxPool2d-23 [-1, 128, 56, 56] 0\n", " Conv2d-24 [-1, 256, 56, 56] 294,912\n", " BatchNorm2d-25 [-1, 256, 56, 56] 512\n", " ReLU-26 [-1, 256, 56, 56] 0\n", " Conv2d-27 [-1, 256, 56, 56] 589,824\n", " BatchNorm2d-28 [-1, 256, 56, 56] 512\n", " ReLU-29 [-1, 256, 56, 56] 0\n", " Conv2d-30 [-1, 256, 56, 56] 33,024\n", " ReLU-31 [-1, 256, 56, 56] 0\n", " ResBlock-32 [-1, 256, 56, 56] 0\n", " DownsampleConv-33 [-1, 256, 56, 56] 0\n", " Dropout2d-34 [-1, 256, 56, 56] 0\n", " MaxPool2d-35 [-1, 256, 28, 28] 0\n", " Conv2d-36 [-1, 512, 28, 28] 1,179,648\n", " BatchNorm2d-37 [-1, 512, 28, 28] 1,024\n", " ReLU-38 [-1, 512, 28, 28] 0\n", " Conv2d-39 [-1, 512, 28, 28] 2,359,296\n", " BatchNorm2d-40 [-1, 512, 28, 28] 1,024\n", " ReLU-41 [-1, 512, 28, 28] 0\n", " Conv2d-42 [-1, 512, 28, 28] 131,584\n", " ReLU-43 [-1, 512, 28, 28] 0\n", " ResBlock-44 [-1, 512, 28, 28] 0\n", " DownsampleConv-45 [-1, 512, 28, 28] 0\n", " Dropout2d-46 [-1, 512, 28, 28] 0\n", " Upsample-47 [-1, 512, 56, 56] 0\n", " Conv2d-48 [-1, 256, 56, 56] 1,769,472\n", " BatchNorm2d-49 [-1, 256, 56, 56] 512\n", " ReLU-50 [-1, 256, 56, 56] 0\n", " Conv2d-51 [-1, 256, 56, 56] 589,824\n", " BatchNorm2d-52 [-1, 256, 56, 56] 512\n", " ReLU-53 [-1, 256, 56, 56] 0\n", " Conv2d-54 [-1, 256, 56, 56] 196,864\n", " ReLU-55 [-1, 256, 56, 56] 0\n", " ResBlock-56 [-1, 256, 56, 56] 0\n", " UpsampleConv-57 [-1, 256, 56, 56] 0\n", " Upsample-58 [-1, 256, 112, 112] 0\n", " Conv2d-59 [-1, 128, 112, 112] 442,368\n", " BatchNorm2d-60 [-1, 128, 112, 112] 256\n", " ReLU-61 [-1, 128, 112, 112] 0\n", " Conv2d-62 [-1, 128, 112, 112] 147,456\n", " BatchNorm2d-63 [-1, 128, 112, 112] 256\n", " ReLU-64 [-1, 128, 112, 112] 0\n", " Conv2d-65 [-1, 128, 112, 112] 49,280\n", " ReLU-66 [-1, 128, 112, 112] 0\n", " ResBlock-67 [-1, 128, 112, 112] 0\n", " UpsampleConv-68 [-1, 128, 112, 112] 0\n", " Upsample-69 [-1, 128, 224, 224] 0\n", " Conv2d-70 [-1, 64, 224, 224] 110,592\n", " BatchNorm2d-71 [-1, 64, 224, 224] 128\n", " ReLU-72 [-1, 64, 224, 224] 0\n", " Conv2d-73 [-1, 64, 224, 224] 36,864\n", " BatchNorm2d-74 [-1, 64, 224, 224] 128\n", " ReLU-75 [-1, 64, 224, 224] 0\n", " Conv2d-76 [-1, 64, 224, 224] 12,352\n", " ReLU-77 [-1, 64, 224, 224] 0\n", " ResBlock-78 [-1, 64, 224, 224] 0\n", " UpsampleConv-79 [-1, 64, 224, 224] 0\n", " Dropout2d-80 [-1, 64, 224, 224] 0\n", "================================================================\n", "Total params: 8,216,064\n", "Trainable params: 8,216,064\n", "Non-trainable params: 0\n", "----------------------------------------------------------------\n", "Input size (MB): 0.19\n", "Forward/backward pass size (MB): 1030.53\n", "Params size (MB): 31.34\n", "Estimated Total Size (MB): 1062.06\n", "----------------------------------------------------------------\n", "\n", "Critic Summary:\n", "\n", "Model Summary:\n", "----------------------------------------------------------------\n", " Layer (type) Output Shape Param #\n", "================================================================\n", " Conv2d-1 [-1, 64, 112, 112] 3,136\n", " LeakyReLU-2 [-1, 64, 112, 112] 0\n", " Conv2d-3 [-1, 128, 56, 56] 131,200\n", " InstanceNorm2d-4 [-1, 128, 56, 56] 0\n", " LeakyReLU-5 [-1, 128, 56, 56] 0\n", " Conv2d-6 [-1, 256, 28, 28] 524,544\n", " InstanceNorm2d-7 [-1, 256, 28, 28] 0\n", " LeakyReLU-8 [-1, 256, 28, 28] 0\n", " Conv2d-9 [-1, 512, 14, 14] 2,097,664\n", " InstanceNorm2d-10 [-1, 512, 14, 14] 0\n", " LeakyReLU-11 [-1, 512, 14, 14] 0\n", "AdaptiveAvgPool2d-12 [-1, 512, 1, 1] 0\n", " Flatten-13 [-1, 512] 0\n", " Linear-14 [-1, 1] 513\n", "================================================================\n", "Total params: 2,757,057\n", "Trainable params: 2,757,057\n", "Non-trainable params: 0\n", "----------------------------------------------------------------\n", "Input size (MB): 2824.00\n", "Forward/backward pass size (MB): 28.34\n", "Params size (MB): 10.52\n", "Estimated Total Size (MB): 2862.85\n", "----------------------------------------------------------------\n", "Model saved to artifacts\\model\\generator.pth\n", "Model saved to artifacts\\model\\critic.pth\n" ] } ], "source": [ "try:\n", " config_manager = ConfigurationManager()\n", " model_config = config_manager.get_model_building_config()\n", "\n", " model_building = ModelBuilding(config=model_config)\n", " generator, critic = model_building.build_and_save()\n", "except Exception as e:\n", " raise e" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 2 }