HAMIM-ML committed
Commit 821ffc1 · 1 Parent(s): b9a1fab

model building added

config/config.yaml CHANGED
@@ -12,4 +12,7 @@ data_transformation:
  data_path_black : artifacts/data_ingestion/ab/ab/ab1.npy
  data_path_grey : artifacts/data_ingestion/l/gray_scale.npy


+ model_building:
+   root_dir : artifacts/model
+
main.py CHANGED
@@ -1,5 +1,6 @@
  from src.imagecolorization.pipeline.stage01_data_ingestion import DataIngestionPipeline
  from src.imagecolorization.pipeline.stage02_data_transformation import DataTransformationPipeline
+ from src.imagecolorization.pipeline.stage_03_model_building import ModelBuildingPipeline
  from src.imagecolorization.logging import logger

  STAGE_NAME = 'Data Ingestion Config'
@@ -20,6 +21,18 @@ try:
      data_transformation = DataTransformationPipeline()
      data_transformation.main()
      logger.info(f">>>>>> stage {STAGE_NAME} completed <<<<<<\n\nx==========x")
+ except Exception as e:
+     logger.exception(e)
+     raise e
+
+
+ STAGE_NAME = 'Model Building Config'
+
+ try:
+     logger.info(f">>>>>> stage {STAGE_NAME} started <<<<<<")
+     model_building = ModelBuildingPipeline()
+     model_building.main()
+     logger.info(f">>>>>> stage {STAGE_NAME} completed <<<<<<\n\nx==========x")
  except Exception as e:
      logger.exception(e)
      raise e
params.yaml CHANGED
@@ -1,5 +1,25 @@
- BATCH_SIZE : 1
- IMAGE_SIZE : [224,224,1]
+ # Data Parameters
+ BATCH_SIZE: 1
+ IMAGE_SIZE: [224, 224, 1]
  DATA_RANGE: 5000
- KERNEL_SIZE : 3
- p
+
+ # Convolutional Layer Parameters
+ KERNEL_SIZE_RES: 3
+ PADDING: 1
+ STRIDE: 1
+ BIAS: False
+
+ # UpSampling Parameters
+ SCALE_FACTOR: 2
+ DIM: 1
+
+ # Dropout Parameters
+ DROPOUT_RATE: 0.2
+
+ # Generator Parameters
+ KERNEL_SIZE_GENERATOR: 1
+ INPUT_CHANNELS: 1
+ OUTPUT_CHANNELS: 2
+
+ # Critic Parameters
+ IN_CHANNELS: 3
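
Note: these keys are read through the project's read_yaml helper, which appears to wrap the parsed YAML in a python-box ConfigBox for dot access (the library visible in the traceback removed from research/data_transformation.ipynb below). A minimal sketch of that access pattern, assuming params.yaml is in the working directory and python-box is installed:

    import yaml
    from box import ConfigBox  # dot-access dict, as used by the repo's read_yaml helper

    with open("params.yaml") as f:
        params = ConfigBox(yaml.safe_load(f))

    # the model-building stage reads exactly these keys
    print(params.KERNEL_SIZE_RES, params.DROPOUT_RATE, params.IN_CHANNELS)
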
research/data_transformation.ipynb CHANGED
@@ -164,15 +164,28 @@
     "        return train_loader, test_loader\n",
     "    \n",
     "    \n",
+    "    \n",
+    "    \n",
+    "    \n",
+    "    \n",
+    "    \n",
     "    def save_dataloaders(self, train_loader, test_loader):\n",
+    "        # Ensure the directory exists\n",
+    "        os.makedirs(self.config.root_dir, exist_ok=True)\n",
+    "\n",
     "        train_loader_path = os.path.join(self.config.root_dir, 'train_loader.pt')\n",
     "        test_loader_path = os.path.join(self.config.root_dir, 'test_loader.pt')\n",
-    "        \n",
-    "        torch.save(train_loader, train_loader_path)\n",
-    "        torch.save(test_loader, test_loader_path)\n",
-    "        \n",
-    "        logger.info(f\"Train Loader saved at: {train_loader_path}\")\n",
-    "        logger.info(f\"Test Loader saved at: {test_loader_path}\")"
+    "\n",
+    "        try:\n",
+    "            # Save the dataloaders\n",
+    "            torch.save(train_loader, train_loader_path)\n",
+    "            torch.save(test_loader, test_loader_path)\n",
+    "\n",
+    "            logger.info(f\"Train Loader saved at: {train_loader_path}\")\n",
+    "            logger.info(f\"Test Loader saved at: {test_loader_path}\")\n",
+    "        except Exception as e:\n",
+    "            logger.error(f\"Error saving dataloaders: {str(e)}\")\n",
+    "            raise e\n"
     ]
    },
    {
@@ -184,55 +197,11 @@
     "name": "stdout",
     "output_type": "stream",
     "text": [
-     "[2024-08-18 13:20:34,232: INFO: common: yaml file: config\\config.yaml loaded successfully]\n",
-     "[2024-08-18 13:20:34,234: INFO: common: yaml file: params.yaml loaded successfully]\n",
-     "[2024-08-18 13:20:34,235: INFO: common: created directory at: artifacts]\n"
-     ]
-    },
-    {
-     "ename": "BoxKeyError",
-     "evalue": "\"'ConfigBox' object has no attribute 'data_transformation'\"",
-     "output_type": "error",
-     "traceback": [
-      "... [removed: the full BoxKeyError / KeyError traceback raised inside box.py and config_box.py when self.config.data_transformation was accessed] ..."
+     "[2024-08-18 17:50:45,127: INFO: common: yaml file: config\\config.yaml loaded successfully]\n",
+     "[2024-08-18 17:50:45,129: INFO: common: yaml file: params.yaml loaded successfully]\n",
+     "[2024-08-18 17:50:45,129: INFO: common: created directory at: artifacts]\n",
+     "[2024-08-18 17:50:57,600: INFO: 2567581832: Train Loader saved at: artifacts/data_transformation\\train_loader.pt]\n",
+     "[2024-08-18 17:50:57,605: INFO: 2567581832: Test Loader saved at: artifacts/data_transformation\\test_loader.pt]\n"
     ]
    }
   ],
@@ -241,11 +210,18 @@
     "    config = ConfigurationManager()\n",
     "    data_transformation_config = config.get_data_transformation_config()\n",
     "    data_transformation = DataTransformation(config=data_transformation_config)\n",
-    "    data_transformation.load_data()\n",
-    "    data_transformation.get_dataloader()\n",
-    "    data_transformation.sa\n",
+    "    \n",
+    "    # Load the dataset\n",
+    "    dataset = data_transformation.load_data()\n",
+    "    \n",
+    "    # Get the dataloader using the loaded dataset\n",
+    "    train_loader, test_loader = data_transformation.get_dataloader(dataset)\n",
+    "    \n",
+    "    # Perform any further operations (e.g., saving the dataloaders)\n",
+    "    data_transformation.save_dataloaders(train_loader, test_loader)\n",
+    "    \n",
     "except Exception as e:\n",
-    "    raise e"
+    "    raise e\n"
     ]
    },
    {
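
Note: with the dataloaders now persisted via torch.save, a later stage can restore them instead of rebuilding. A minimal sketch, assuming the artifact paths logged above, that the Dataset class used to build the loaders is importable (or already defined) at load time, and noting that newer PyTorch releases may additionally need weights_only=False to unpickle full DataLoader objects:

    import torch

    # paths written by save_dataloaders in the run above
    train_loader = torch.load("artifacts/data_transformation/train_loader.pt")
    test_loader = torch.load("artifacts/data_transformation/test_loader.pt")

    print(len(train_loader), len(test_loader))  # number of batches in each loader
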
research/model_building.ipynb ADDED
@@ -0,0 +1,461 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "os.chdir('../')"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 2,
16
+ "metadata": {},
17
+ "outputs": [
18
+ {
19
+ "data": {
20
+ "text/plain": [
21
+ "'c:\\\\mlops project\\\\image-colorization-mlops'"
22
+ ]
23
+ },
24
+ "execution_count": 2,
25
+ "metadata": {},
26
+ "output_type": "execute_result"
27
+ }
28
+ ],
29
+ "source": [
30
+ "%pwd"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": 3,
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "from dataclasses import dataclass\n",
40
+ "from pathlib import Path\n",
41
+ "\n",
42
+ "@dataclass(frozen=True)\n",
43
+ "class ModelBuildingConfig:\n",
44
+ " root_dir: Path\n",
45
+ " KERNEL_SIZE_RES: int\n",
46
+ " PADDING: int\n",
47
+ " STRIDE: int\n",
48
+ " BIAS: bool\n",
49
+ " SCALE_FACTOR: int\n",
50
+ " DIM: int\n",
51
+ " DROPOUT_RATE: float\n",
52
+ " KERNEL_SIZE_GENERATOR: int\n",
53
+ " INPUT_CHANNELS: int\n",
54
+ " OUTPUT_CHANNELS: int\n",
55
+ " IN_CHANNELS: int\n",
56
+ "\n"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": 4,
62
+ "metadata": {},
63
+ "outputs": [],
64
+ "source": [
65
+ "from src.imagecolorization.constants import *\n",
66
+ "from src.imagecolorization.utils.common import read_yaml, create_directories\n",
67
+ "\n",
68
+ "class ConfigurationManager:\n",
69
+ " def __init__(self, config_filepath=CONFIG_FILE_PATH, params_filepath=PARAMS_FILE_PATH):\n",
70
+ " self.config = read_yaml(config_filepath)\n",
71
+ " self.params = read_yaml(params_filepath)\n",
72
+ " create_directories([self.config.artifacts_root])\n",
73
+ "\n",
74
+ " def get_model_building_config(self) -> ModelBuildingConfig:\n",
75
+ " config = self.config.model_building\n",
76
+ " params = self.params\n",
77
+ "\n",
78
+ " model_building_config = ModelBuildingConfig(\n",
79
+ " root_dir=Path(config.root_dir),\n",
80
+ " KERNEL_SIZE_RES=params.KERNEL_SIZE_RES,\n",
81
+ " PADDING=params.PADDING,\n",
82
+ " STRIDE=params.STRIDE,\n",
83
+ " BIAS=params.BIAS,\n",
84
+ " SCALE_FACTOR=params.SCALE_FACTOR,\n",
85
+ " DIM=params.DIM,\n",
86
+ " DROPOUT_RATE=params.DROPOUT_RATE,\n",
87
+ " KERNEL_SIZE_GENERATOR=params.KERNEL_SIZE_GENERATOR,\n",
88
+ " INPUT_CHANNELS=params.INPUT_CHANNELS,\n",
89
+ " OUTPUT_CHANNELS=params.OUTPUT_CHANNELS,\n",
90
+ " IN_CHANNELS=params.IN_CHANNELS\n",
91
+ " )\n",
92
+ " return model_building_config\n"
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "code",
97
+ "execution_count": 5,
98
+ "metadata": {},
99
+ "outputs": [],
100
+ "source": [
101
+ "import torch \n",
102
+ "import torch.nn as nn\n",
103
+ "from pathlib import Path\n",
104
+ "\n",
105
+ "class ResBlock(nn.Module):\n",
106
+ " def __init__(self, in_channles, out_channels, stride = 1, kerenl_size = 3, padding = 1, bias = False):\n",
107
+ " super().__init__()\n",
108
+ " self.layer = nn.Sequential(\n",
109
+ " nn.Conv2d(in_channles, out_channels, kernel_size=kerenl_size, padding=padding, stride=stride, bias = bias),\n",
110
+ " nn.BatchNorm2d(out_channels),\n",
111
+ " nn.ReLU(inplace=True),\n",
112
+ " nn.Conv2d(out_channels, out_channels, kernel_size=kerenl_size, padding=padding, stride = 1, bias = bias),\n",
113
+ " nn.BatchNorm2d(out_channels),\n",
114
+ " nn.ReLU(inplace=True)\n",
115
+ " )\n",
116
+ " \n",
117
+ " self.identity_map = nn.Conv2d(in_channles, out_channels,kernel_size=1, stride=stride)\n",
118
+ " self.relu = nn.ReLU(inplace= True)\n",
119
+ " \n",
120
+ " def forward(self, inputs):\n",
121
+ " x = inputs.clone().detach()\n",
122
+ " out = self.layer(x)\n",
123
+ " residual = self.identity_map(inputs)\n",
124
+ " skip = out + residual\n",
125
+ " return self.relu(skip)\n",
126
+ " \n",
127
+ " \n",
128
+ "class DownsampleConv(nn.Module):\n",
129
+ " def __init__(self, in_channels, out_channels, stride = 1):\n",
130
+ " super().__init__()\n",
131
+ " self.layer = nn.Sequential(\n",
132
+ " nn.MaxPool2d(2),\n",
133
+ " ResBlock(in_channels, out_channels)\n",
134
+ " )\n",
135
+ " \n",
136
+ " def forward(self, inputs):\n",
137
+ " return self.layer(inputs)\n",
138
+ " \n",
139
+ " \n",
140
+ " \n",
141
+ "class UpsampleConv(nn.Module):\n",
142
+ " def __init__(self, in_channels, out_channels, scale_factor=2):\n",
143
+ " super().__init__()\n",
144
+ " self.upsample = nn.Upsample(scale_factor=scale_factor,mode = 'bilinear', align_corners=True)\n",
145
+ " self.res_block = ResBlock(in_channels + out_channels, out_channels)\n",
146
+ "\n",
147
+ " def forward(self, inputs, skip):\n",
148
+ " x = self.upsample(inputs)\n",
149
+ " x = torch.cat([x, skip], dim = 1)\n",
150
+ " x = self.res_block(x)\n",
151
+ " return x\n",
152
+ " \n",
153
+ "class Generator(nn.Module):\n",
154
+ " def __init__(self, input_channels, output_channels, dropout_rate = 0.2):\n",
155
+ " super().__init__()\n",
156
+ " self.encoding_layer1_= ResBlock(input_channels, 64)\n",
157
+ " self.encoding_layer2_ = DownsampleConv(64, 128)\n",
158
+ " self.encoding_layer3_ = DownsampleConv(128, 256)\n",
159
+ " self.bridge = DownsampleConv(256, 512)\n",
160
+ " self.decoding_layer3 = UpsampleConv(512, 256)\n",
161
+ " self.decoding_layer2 = UpsampleConv(256, 128)\n",
162
+ " self.decoding_layer1 = UpsampleConv(128 , 64)\n",
163
+ " self.output = nn.Conv2d(64, output_channels, kernel_size = 1)\n",
164
+ " self.dropout = nn.Dropout2d(dropout_rate)\n",
165
+ " \n",
166
+ " def forward(self, inputs):\n",
167
+ " e1 = self.encoding_layer1_(inputs)\n",
168
+ " e1 = self.dropout(e1)\n",
169
+ " e2 = self.encoding_layer2_(e1)\n",
170
+ " e2 = self.dropout(e2)\n",
171
+ " e3 = self.encoding_layer3_(e2)\n",
172
+ " e3 = self.dropout(e3)\n",
173
+ " \n",
174
+ " bridge = self.bridge(e3)\n",
175
+ " bridge = self.dropout(bridge)\n",
176
+ " \n",
177
+ " d3 = self.decoding_layer3(bridge, e3)\n",
178
+ " d2 =self.decoding_layer2(d3, e2)\n",
179
+ " d1 = self.decoding_layer1(d2, e1)\n",
180
+ " \n",
181
+ " output = self.dropout(d1)\n",
182
+ " return output\n",
183
+ " \n",
184
+ " \n",
185
+ "class Critic(nn.Module):\n",
186
+ " def __init__(self, in_channels=3):\n",
187
+ " super(Critic, self).__init__()\n",
188
+ "\n",
189
+ " def critic_block(in_filters, out_filters, normalization=True):\n",
190
+ " layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]\n",
191
+ " if normalization:\n",
192
+ " layers.append(nn.InstanceNorm2d(out_filters))\n",
193
+ " layers.append(nn.LeakyReLU(0.2, inplace=True))\n",
194
+ " return layers\n",
195
+ "\n",
196
+ " self.model = nn.Sequential(\n",
197
+ " *critic_block(in_channels, 64, normalization=False),\n",
198
+ " *critic_block(64, 128),\n",
199
+ " *critic_block(128, 256),\n",
200
+ " *critic_block(256, 512),\n",
201
+ " nn.AdaptiveAvgPool2d(1),\n",
202
+ " nn.Flatten(),\n",
203
+ " nn.Linear(512, 1)\n",
204
+ " )\n",
205
+ "\n",
206
+ " def forward(self, ab, l):\n",
207
+ " img_input = torch.cat((ab, l), 1)\n",
208
+ " output = self.model(img_input)\n",
209
+ " return output\n",
210
+ " \n",
211
+ " "
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "code",
216
+ "execution_count": 6,
217
+ "metadata": {},
218
+ "outputs": [],
219
+ "source": [
220
+ "from torchsummary import summary\n",
221
+ "import torch\n",
222
+ "import os\n",
223
+ "\n",
224
+ "class ModelBuilding:\n",
225
+ " def __init__(self, config: ModelBuildingConfig):\n",
226
+ " self.config = config\n",
227
+ " self.root_dir = self.config.root_dir\n",
228
+ " self.create_root_dir()\n",
229
+ " self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
230
+ "\n",
231
+ " def create_root_dir(self):\n",
232
+ " os.makedirs(self.root_dir, exist_ok=True)\n",
233
+ " print(f\"Created directory: {self.root_dir}\")\n",
234
+ "\n",
235
+ " def get_generator(self):\n",
236
+ " return Generator(\n",
237
+ " input_channels=self.config.INPUT_CHANNELS, # corrected argument name\n",
238
+ " output_channels=self.config.OUTPUT_CHANNELS, # corrected argument name\n",
239
+ " dropout_rate=self.config.DROPOUT_RATE\n",
240
+ " ).to(self.device)\n",
241
+ "\n",
242
+ " def get_critic(self):\n",
243
+ " return Critic(in_channels=self.config.IN_CHANNELS).to(self.device)\n",
244
+ "\n",
245
+ " def build(self):\n",
246
+ " generator = self.get_generator()\n",
247
+ " critic = self.get_critic()\n",
248
+ " return generator, critic\n",
249
+ "\n",
250
+ " def save_model(self, model, filename):\n",
251
+ " path = self.root_dir / filename\n",
252
+ " torch.save(model.state_dict(), path)\n",
253
+ " print(f\"Model saved to {path}\")\n",
254
+ "\n",
255
+ " def display_summary(self, model, input_size):\n",
256
+ " print(f\"\\nModel Summary:\")\n",
257
+ " summary(model, input_size)\n",
258
+ "\n",
259
+ " def build_and_save(self):\n",
260
+ " generator, critic = self.build()\n",
261
+ "\n",
262
+ " # Display summaries\n",
263
+ " print(\"\\nGenerator Summary:\")\n",
264
+ " self.display_summary(generator, (self.config.INPUT_CHANNELS, 224, 224)) # Assuming input size is 224x224\n",
265
+ "\n",
266
+ " print(\"\\nCritic Summary:\")\n",
267
+ " self.display_summary(critic, [(2, 224, 224), (1, 224, 224)]) # Critic takes two inputs: ab and l\n",
268
+ "\n",
269
+ " self.save_model(generator, \"generator.pth\")\n",
270
+ " self.save_model(critic, \"critic.pth\")\n",
271
+ " return generator, critic\n",
272
+ "\n"
273
+ ]
274
+ },
275
+ {
276
+ "cell_type": "code",
277
+ "execution_count": 7,
278
+ "metadata": {},
279
+ "outputs": [
280
+ {
281
+ "name": "stdout",
282
+ "output_type": "stream",
283
+ "text": [
284
+ "[2024-08-23 00:00:44,340: INFO: common: yaml file: config\\config.yaml loaded successfully]\n",
285
+ "[2024-08-23 00:00:44,342: INFO: common: yaml file: params.yaml loaded successfully]\n",
286
+ "[2024-08-23 00:00:44,343: INFO: common: created directory at: artifacts]\n",
287
+ "Created directory: artifacts\\model\n",
288
+ "\n",
289
+ "Generator Summary:\n",
290
+ "\n",
291
+ "Model Summary:\n",
292
+ "----------------------------------------------------------------\n",
293
+ " Layer (type) Output Shape Param #\n",
294
+ "================================================================\n",
295
+ " Conv2d-1 [-1, 64, 224, 224] 576\n",
296
+ " BatchNorm2d-2 [-1, 64, 224, 224] 128\n",
297
+ " ReLU-3 [-1, 64, 224, 224] 0\n",
298
+ " Conv2d-4 [-1, 64, 224, 224] 36,864\n",
299
+ " BatchNorm2d-5 [-1, 64, 224, 224] 128\n",
300
+ " ReLU-6 [-1, 64, 224, 224] 0\n",
301
+ " Conv2d-7 [-1, 64, 224, 224] 128\n",
302
+ " ReLU-8 [-1, 64, 224, 224] 0\n",
303
+ " ResBlock-9 [-1, 64, 224, 224] 0\n",
304
+ " Dropout2d-10 [-1, 64, 224, 224] 0\n",
305
+ " MaxPool2d-11 [-1, 64, 112, 112] 0\n",
306
+ " Conv2d-12 [-1, 128, 112, 112] 73,728\n",
307
+ " BatchNorm2d-13 [-1, 128, 112, 112] 256\n",
308
+ " ReLU-14 [-1, 128, 112, 112] 0\n",
309
+ " Conv2d-15 [-1, 128, 112, 112] 147,456\n",
310
+ " BatchNorm2d-16 [-1, 128, 112, 112] 256\n",
311
+ " ReLU-17 [-1, 128, 112, 112] 0\n",
312
+ " Conv2d-18 [-1, 128, 112, 112] 8,320\n",
313
+ " ReLU-19 [-1, 128, 112, 112] 0\n",
314
+ " ResBlock-20 [-1, 128, 112, 112] 0\n",
315
+ " DownsampleConv-21 [-1, 128, 112, 112] 0\n",
316
+ " Dropout2d-22 [-1, 128, 112, 112] 0\n",
317
+ " MaxPool2d-23 [-1, 128, 56, 56] 0\n",
318
+ " Conv2d-24 [-1, 256, 56, 56] 294,912\n",
319
+ " BatchNorm2d-25 [-1, 256, 56, 56] 512\n",
320
+ " ReLU-26 [-1, 256, 56, 56] 0\n",
321
+ " Conv2d-27 [-1, 256, 56, 56] 589,824\n",
322
+ " BatchNorm2d-28 [-1, 256, 56, 56] 512\n",
323
+ " ReLU-29 [-1, 256, 56, 56] 0\n",
324
+ " Conv2d-30 [-1, 256, 56, 56] 33,024\n",
325
+ " ReLU-31 [-1, 256, 56, 56] 0\n",
326
+ " ResBlock-32 [-1, 256, 56, 56] 0\n",
327
+ " DownsampleConv-33 [-1, 256, 56, 56] 0\n",
328
+ " Dropout2d-34 [-1, 256, 56, 56] 0\n",
329
+ " MaxPool2d-35 [-1, 256, 28, 28] 0\n",
330
+ " Conv2d-36 [-1, 512, 28, 28] 1,179,648\n",
331
+ " BatchNorm2d-37 [-1, 512, 28, 28] 1,024\n",
332
+ " ReLU-38 [-1, 512, 28, 28] 0\n",
333
+ " Conv2d-39 [-1, 512, 28, 28] 2,359,296\n",
334
+ " BatchNorm2d-40 [-1, 512, 28, 28] 1,024\n",
335
+ " ReLU-41 [-1, 512, 28, 28] 0\n",
336
+ " Conv2d-42 [-1, 512, 28, 28] 131,584\n",
337
+ " ReLU-43 [-1, 512, 28, 28] 0\n",
338
+ " ResBlock-44 [-1, 512, 28, 28] 0\n",
339
+ " DownsampleConv-45 [-1, 512, 28, 28] 0\n",
340
+ " Dropout2d-46 [-1, 512, 28, 28] 0\n",
341
+ " Upsample-47 [-1, 512, 56, 56] 0\n",
342
+ " Conv2d-48 [-1, 256, 56, 56] 1,769,472\n",
343
+ " BatchNorm2d-49 [-1, 256, 56, 56] 512\n",
344
+ " ReLU-50 [-1, 256, 56, 56] 0\n",
345
+ " Conv2d-51 [-1, 256, 56, 56] 589,824\n",
346
+ " BatchNorm2d-52 [-1, 256, 56, 56] 512\n",
347
+ " ReLU-53 [-1, 256, 56, 56] 0\n",
348
+ " Conv2d-54 [-1, 256, 56, 56] 196,864\n",
349
+ " ReLU-55 [-1, 256, 56, 56] 0\n",
350
+ " ResBlock-56 [-1, 256, 56, 56] 0\n",
351
+ " UpsampleConv-57 [-1, 256, 56, 56] 0\n",
352
+ " Upsample-58 [-1, 256, 112, 112] 0\n",
353
+ " Conv2d-59 [-1, 128, 112, 112] 442,368\n",
354
+ " BatchNorm2d-60 [-1, 128, 112, 112] 256\n",
355
+ " ReLU-61 [-1, 128, 112, 112] 0\n",
356
+ " Conv2d-62 [-1, 128, 112, 112] 147,456\n",
357
+ " BatchNorm2d-63 [-1, 128, 112, 112] 256\n",
358
+ " ReLU-64 [-1, 128, 112, 112] 0\n",
359
+ " Conv2d-65 [-1, 128, 112, 112] 49,280\n",
360
+ " ReLU-66 [-1, 128, 112, 112] 0\n",
361
+ " ResBlock-67 [-1, 128, 112, 112] 0\n",
362
+ " UpsampleConv-68 [-1, 128, 112, 112] 0\n",
363
+ " Upsample-69 [-1, 128, 224, 224] 0\n",
364
+ " Conv2d-70 [-1, 64, 224, 224] 110,592\n",
365
+ " BatchNorm2d-71 [-1, 64, 224, 224] 128\n",
366
+ " ReLU-72 [-1, 64, 224, 224] 0\n",
367
+ " Conv2d-73 [-1, 64, 224, 224] 36,864\n",
368
+ " BatchNorm2d-74 [-1, 64, 224, 224] 128\n",
369
+ " ReLU-75 [-1, 64, 224, 224] 0\n",
370
+ " Conv2d-76 [-1, 64, 224, 224] 12,352\n",
371
+ " ReLU-77 [-1, 64, 224, 224] 0\n",
372
+ " ResBlock-78 [-1, 64, 224, 224] 0\n",
373
+ " UpsampleConv-79 [-1, 64, 224, 224] 0\n",
374
+ " Dropout2d-80 [-1, 64, 224, 224] 0\n",
375
+ "================================================================\n",
376
+ "Total params: 8,216,064\n",
377
+ "Trainable params: 8,216,064\n",
378
+ "Non-trainable params: 0\n",
379
+ "----------------------------------------------------------------\n",
380
+ "Input size (MB): 0.19\n",
381
+ "Forward/backward pass size (MB): 1030.53\n",
382
+ "Params size (MB): 31.34\n",
383
+ "Estimated Total Size (MB): 1062.06\n",
384
+ "----------------------------------------------------------------\n",
385
+ "\n",
386
+ "Critic Summary:\n",
387
+ "\n",
388
+ "Model Summary:\n",
389
+ "----------------------------------------------------------------\n",
390
+ " Layer (type) Output Shape Param #\n",
391
+ "================================================================\n",
392
+ " Conv2d-1 [-1, 64, 112, 112] 3,136\n",
393
+ " LeakyReLU-2 [-1, 64, 112, 112] 0\n",
394
+ " Conv2d-3 [-1, 128, 56, 56] 131,200\n",
395
+ " InstanceNorm2d-4 [-1, 128, 56, 56] 0\n",
396
+ " LeakyReLU-5 [-1, 128, 56, 56] 0\n",
397
+ " Conv2d-6 [-1, 256, 28, 28] 524,544\n",
398
+ " InstanceNorm2d-7 [-1, 256, 28, 28] 0\n",
399
+ " LeakyReLU-8 [-1, 256, 28, 28] 0\n",
400
+ " Conv2d-9 [-1, 512, 14, 14] 2,097,664\n",
401
+ " InstanceNorm2d-10 [-1, 512, 14, 14] 0\n",
402
+ " LeakyReLU-11 [-1, 512, 14, 14] 0\n",
403
+ "AdaptiveAvgPool2d-12 [-1, 512, 1, 1] 0\n",
404
+ " Flatten-13 [-1, 512] 0\n",
405
+ " Linear-14 [-1, 1] 513\n",
406
+ "================================================================\n",
407
+ "Total params: 2,757,057\n",
408
+ "Trainable params: 2,757,057\n",
409
+ "Non-trainable params: 0\n",
410
+ "----------------------------------------------------------------\n",
411
+ "Input size (MB): 2824.00\n",
412
+ "Forward/backward pass size (MB): 28.34\n",
413
+ "Params size (MB): 10.52\n",
414
+ "Estimated Total Size (MB): 2862.85\n",
415
+ "----------------------------------------------------------------\n",
416
+ "Model saved to artifacts\\model\\generator.pth\n",
417
+ "Model saved to artifacts\\model\\critic.pth\n"
418
+ ]
419
+ }
420
+ ],
421
+ "source": [
422
+ "try:\n",
423
+ " config_manager = ConfigurationManager()\n",
424
+ " model_config = config_manager.get_model_building_config()\n",
425
+ "\n",
426
+ " model_building = ModelBuilding(config=model_config)\n",
427
+ " generator, critic = model_building.build_and_save()\n",
428
+ "except Exception as e:\n",
429
+ " raise e"
430
+ ]
431
+ },
432
+ {
433
+ "cell_type": "code",
434
+ "execution_count": null,
435
+ "metadata": {},
436
+ "outputs": [],
437
+ "source": []
438
+ }
439
+ ],
440
+ "metadata": {
441
+ "kernelspec": {
442
+ "display_name": "Python 3",
443
+ "language": "python",
444
+ "name": "python3"
445
+ },
446
+ "language_info": {
447
+ "codemirror_mode": {
448
+ "name": "ipython",
449
+ "version": 3
450
+ },
451
+ "file_extension": ".py",
452
+ "mimetype": "text/x-python",
453
+ "name": "python",
454
+ "nbconvert_exporter": "python",
455
+ "pygments_lexer": "ipython3",
456
+ "version": "3.11.0"
457
+ }
458
+ },
459
+ "nbformat": 4,
460
+ "nbformat_minor": 2
461
+ }
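
Note: in the Generator above, self.output (the 1x1 Conv2d mapping 64 feature maps to output_channels) is declared in __init__ but never applied in forward(), which is why the summary printed above ends at Dropout2d-80 with shape [-1, 64, 224, 224] rather than a 2-channel ab prediction. A minimal shape check of that observation, assuming the Generator class from the cell above is in scope; the last two lines only illustrate what applying the declared output conv would yield:

    import torch

    gen = Generator(input_channels=1, output_channels=2)  # class defined in the cell above
    x = torch.randn(1, 1, 224, 224)                       # dummy L-channel batch
    feats = gen(x)
    print(feats.shape)      # torch.Size([1, 64, 224, 224]) -- forward() stops at dropout(d1)
    ab = gen.output(feats)  # the declared-but-unused 1x1 conv: 64 -> 2 channels
    print(ab.shape)         # torch.Size([1, 2, 224, 224])
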
src/imagecolorization/config/configuration.py CHANGED
@@ -1,7 +1,8 @@
  from src.imagecolorization.constants import *
  from src.imagecolorization.utils.common import read_yaml, create_directories
  from src.imagecolorization.entity.config_entity import (DataIngestionConfig,
-                                                          DataTransformationConfig)
+                                                          DataTransformationConfig,
+                                                          ModelBuildingConfig)
  class ConfigurationManager:
      def __init__(
          self,
@@ -44,6 +45,26 @@ class ConfigurationManager:
          )

          return data_transformation_cofig
+
+    def get_model_building_config(self) -> ModelBuildingConfig:
+        config = self.config.model_building
+        params = self.params
+
+        model_building_config = ModelBuildingConfig(
+            root_dir=Path(config.root_dir),
+            KERNEL_SIZE_RES=params.KERNEL_SIZE_RES,
+            PADDING=params.PADDING,
+            STRIDE=params.STRIDE,
+            BIAS=params.BIAS,
+            SCALE_FACTOR=params.SCALE_FACTOR,
+            DIM=params.DIM,
+            DROPOUT_RATE=params.DROPOUT_RATE,
+            KERNEL_SIZE_GENERATOR=params.KERNEL_SIZE_GENERATOR,
+            INPUT_CHANNELS=params.INPUT_CHANNELS,
+            OUTPUT_CHANNELS=params.OUTPUT_CHANNELS,
+            IN_CHANNELS=params.IN_CHANNELS
+        )
+        return model_building_config

src/imagecolorization/conponents/model_building.py ADDED
@@ -0,0 +1,168 @@
+ import torch
+ import torch.nn as nn
+ from pathlib import Path
+ from torchsummary import summary
+ import os
+ from src.imagecolorization.config.configuration import ModelBuildingConfig
+
+ class ResBlock(nn.Module):
+     def __init__(self, in_channles, out_channels, stride=1, kerenl_size=3, padding=1, bias=False):
+         super().__init__()
+         self.layer = nn.Sequential(
+             nn.Conv2d(in_channles, out_channels, kernel_size=kerenl_size, padding=padding, stride=stride, bias=bias),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(out_channels, out_channels, kernel_size=kerenl_size, padding=padding, stride=1, bias=bias),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(inplace=True)
+         )
+
+         self.identity_map = nn.Conv2d(in_channles, out_channels, kernel_size=1, stride=stride)
+         self.relu = nn.ReLU(inplace=True)
+
+     def forward(self, inputs):
+         x = inputs.clone().detach()
+         out = self.layer(x)
+         residual = self.identity_map(inputs)
+         skip = out + residual
+         return self.relu(skip)
+
+
+ class DownsampleConv(nn.Module):
+     def __init__(self, in_channels, out_channels, stride=1):
+         super().__init__()
+         self.layer = nn.Sequential(
+             nn.MaxPool2d(2),
+             ResBlock(in_channels, out_channels)
+         )
+
+     def forward(self, inputs):
+         return self.layer(inputs)
+
+
+ class UpsampleConv(nn.Module):
+     def __init__(self, in_channels, out_channels, scale_factor=2):
+         super().__init__()
+         self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=True)
+         self.res_block = ResBlock(in_channels + out_channels, out_channels)
+
+     def forward(self, inputs, skip):
+         x = self.upsample(inputs)
+         x = torch.cat([x, skip], dim=1)
+         x = self.res_block(x)
+         return x
+
+ class Generator(nn.Module):
+     def __init__(self, input_channels, output_channels, dropout_rate=0.2):
+         super().__init__()
+         self.encoding_layer1_ = ResBlock(input_channels, 64)
+         self.encoding_layer2_ = DownsampleConv(64, 128)
+         self.encoding_layer3_ = DownsampleConv(128, 256)
+         self.bridge = DownsampleConv(256, 512)
+         self.decoding_layer3 = UpsampleConv(512, 256)
+         self.decoding_layer2 = UpsampleConv(256, 128)
+         self.decoding_layer1 = UpsampleConv(128, 64)
+         self.output = nn.Conv2d(64, output_channels, kernel_size=1)
+         self.dropout = nn.Dropout2d(dropout_rate)
+
+     def forward(self, inputs):
+         e1 = self.encoding_layer1_(inputs)
+         e1 = self.dropout(e1)
+         e2 = self.encoding_layer2_(e1)
+         e2 = self.dropout(e2)
+         e3 = self.encoding_layer3_(e2)
+         e3 = self.dropout(e3)
+
+         bridge = self.bridge(e3)
+         bridge = self.dropout(bridge)
+
+         d3 = self.decoding_layer3(bridge, e3)
+         d2 = self.decoding_layer2(d3, e2)
+         d1 = self.decoding_layer1(d2, e1)
+
+         output = self.dropout(d1)
+         return output
+
+
+ class Critic(nn.Module):
+     def __init__(self, in_channels=3):
+         super(Critic, self).__init__()
+
+         def critic_block(in_filters, out_filters, normalization=True):
+             layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
+             if normalization:
+                 layers.append(nn.InstanceNorm2d(out_filters))
+             layers.append(nn.LeakyReLU(0.2, inplace=True))
+             return layers
+
+         self.model = nn.Sequential(
+             *critic_block(in_channels, 64, normalization=False),
+             *critic_block(64, 128),
+             *critic_block(128, 256),
+             *critic_block(256, 512),
+             nn.AdaptiveAvgPool2d(1),
+             nn.Flatten(),
+             nn.Linear(512, 1)
+         )
+
+     def forward(self, ab, l):
+         img_input = torch.cat((ab, l), 1)
+         output = self.model(img_input)
+         return output
+
+
+ class ModelBuilding:
+     def __init__(self, config: ModelBuildingConfig):
+         self.config = config
+         self.root_dir = self.config.root_dir
+         self.create_root_dir()
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     def create_root_dir(self):
+         os.makedirs(self.root_dir, exist_ok=True)
+         print(f"Created directory: {self.root_dir}")
+
+     def get_generator(self):
+         return Generator(
+             input_channels=self.config.INPUT_CHANNELS,  # corrected argument name
+             output_channels=self.config.OUTPUT_CHANNELS,  # corrected argument name
+             dropout_rate=self.config.DROPOUT_RATE
+         ).to(self.device)
+
+     def get_critic(self):
+         return Critic(in_channels=self.config.IN_CHANNELS).to(self.device)
+
+     def build(self):
+         generator = self.get_generator()
+         critic = self.get_critic()
+         return generator, critic
+
+     def save_model(self, model, filename):
+         path = self.root_dir / filename
+         torch.save(model.state_dict(), path)
+         print(f"Model saved to {path}")
+
+     def display_summary(self, model, input_size):
+         print(f"\nModel Summary:")
+         summary(model, input_size)
+
+     def build_and_save(self):
+         generator, critic = self.build()
+
+         # Display summaries
+         print("\nGenerator Summary:")
+         self.display_summary(generator, (self.config.INPUT_CHANNELS, 224, 224))  # Assuming input size is 224x224
+
+         print("\nCritic Summary:")
+         self.display_summary(critic, [(2, 224, 224), (1, 224, 224)])  # Critic takes two inputs: ab and l
+
+         self.save_model(generator, "generator.pth")
+         self.save_model(critic, "critic.pth")
+         return generator, critic
+
+
+
+
src/imagecolorization/entity/config_entity.py CHANGED
@@ -16,4 +16,20 @@ class DataTransformationConfig:
     data_path_grey : Path
     BATCH_SIZE : int
     IMAGE_SIZE : list
-    DATA_RANGE: int
+    DATA_RANGE: int
+
+
+ @dataclass(frozen=True)
+ class ModelBuildingConfig:
+     root_dir: Path
+     KERNEL_SIZE_RES: int
+     PADDING: int
+     STRIDE: int
+     BIAS: bool
+     SCALE_FACTOR: int
+     DIM: int
+     DROPOUT_RATE: float
+     KERNEL_SIZE_GENERATOR: int
+     INPUT_CHANNELS: int
+     OUTPUT_CHANNELS: int
+     IN_CHANNELS: int
src/imagecolorization/pipeline/stage_03_model_building.py ADDED
@@ -0,0 +1,13 @@
+ from src.imagecolorization.conponents.model_building import ModelBuilding
+ from src.imagecolorization.config.configuration import ConfigurationManager
+
+ class ModelBuildingPipeline:
+     def __init__(self):
+         pass
+
+     def main(self):
+         config_manager = ConfigurationManager()
+         model_config = config_manager.get_model_building_config()
+
+         model_building = ModelBuilding(config=model_config)
+         generator, critic = model_building.build_and_save()
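
Note: build_and_save() stores only state_dicts, so a follow-up training stage has to re-instantiate the architectures before loading the weights. A minimal sketch, assuming the class names and artifacts/model paths from this commit and the channel settings from params.yaml:

    import torch
    from src.imagecolorization.conponents.model_building import Generator, Critic

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # INPUT_CHANNELS=1, OUTPUT_CHANNELS=2, IN_CHANNELS=3 as configured in params.yaml
    generator = Generator(input_channels=1, output_channels=2).to(device)
    generator.load_state_dict(torch.load("artifacts/model/generator.pth", map_location=device))

    critic = Critic(in_channels=3).to(device)
    critic.load_state_dict(torch.load("artifacts/model/critic.pth", map_location=device))
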