Spaces:
Runtime error
Runtime error
HAMIM-ML
commited on
Commit
·
821ffc1
1
Parent(s):
b9a1fab
model building added
Browse files- config/config.yaml +3 -0
- main.py +13 -0
- params.yaml +24 -4
- research/data_transformation.ipynb +35 -59
- research/model_building.ipynb +461 -0
- src/imagecolorization/config/configuration.py +22 -1
- src/imagecolorization/conponents/model_building.py +168 -0
- src/imagecolorization/entity/config_entity.py +17 -1
- src/imagecolorization/pipeline/stage_03_model_building.py +13 -0
config/config.yaml
CHANGED
@@ -12,4 +12,7 @@ data_transformation:
|
|
12 |
data_path_black : artifacts/data_ingestion/ab/ab/ab1.npy
|
13 |
data_path_grey : artifacts/data_ingestion/l/gray_scale.npy
|
14 |
|
|
|
|
|
|
|
15 |
|
|
|
12 |
data_path_black : artifacts/data_ingestion/ab/ab/ab1.npy
|
13 |
data_path_grey : artifacts/data_ingestion/l/gray_scale.npy
|
14 |
|
15 |
+
model_building:
|
16 |
+
root_dir : artifacts/model
|
17 |
+
|
18 |
|
main.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
from src.imagecolorization.pipeline.stage01_data_ingestion import DataIngestionPipeline
|
2 |
from src.imagecolorization.pipeline.stage02_data_transformation import DataTransformationPipeline
|
|
|
3 |
from src.imagecolorization.logging import logger
|
4 |
|
5 |
STAGE_NAME = 'Data Ingestion Config'
|
@@ -20,6 +21,18 @@ try:
|
|
20 |
data_transformation = DataTransformationPipeline()
|
21 |
data_transformation.main()
|
22 |
logger.info(f">>>>>> stage {STAGE_NAME} completed <<<<<<\n\nx==========x")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
except Exception as e:
|
24 |
logger.exception(e)
|
25 |
raise e
|
|
|
1 |
from src.imagecolorization.pipeline.stage01_data_ingestion import DataIngestionPipeline
|
2 |
from src.imagecolorization.pipeline.stage02_data_transformation import DataTransformationPipeline
|
3 |
+
from src.imagecolorization.pipeline.stage_03_model_building import ModelBuildingPipeline
|
4 |
from src.imagecolorization.logging import logger
|
5 |
|
6 |
STAGE_NAME = 'Data Ingestion Config'
|
|
|
21 |
data_transformation = DataTransformationPipeline()
|
22 |
data_transformation.main()
|
23 |
logger.info(f">>>>>> stage {STAGE_NAME} completed <<<<<<\n\nx==========x")
|
24 |
+
except Exception as e:
|
25 |
+
logger.exception(e)
|
26 |
+
raise e
|
27 |
+
|
28 |
+
|
29 |
+
STAGE_NAME = 'Model Building Config'
|
30 |
+
|
31 |
+
try:
|
32 |
+
logger.info(f">>>>>> stage {STAGE_NAME} started <<<<<<")
|
33 |
+
model_building = ModelBuildingPipeline()
|
34 |
+
model_building.main()
|
35 |
+
logger.info(f">>>>>> stage {STAGE_NAME} completed <<<<<<\n\nx==========x")
|
36 |
except Exception as e:
|
37 |
logger.exception(e)
|
38 |
raise e
|
params.yaml
CHANGED
@@ -1,5 +1,25 @@
|
|
1 |
-
|
2 |
-
|
|
|
3 |
DATA_RANGE: 5000
|
4 |
-
|
5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Data Parameters
|
2 |
+
BATCH_SIZE: 1
|
3 |
+
IMAGE_SIZE: [224, 224, 1]
|
4 |
DATA_RANGE: 5000
|
5 |
+
|
6 |
+
# Convolutional Layer Parameters
|
7 |
+
KERNEL_SIZE_RES: 3
|
8 |
+
PADDING: 1
|
9 |
+
STRIDE: 1
|
10 |
+
BIAS: False
|
11 |
+
|
12 |
+
# UpSampling Parameters
|
13 |
+
SCALE_FACTOR: 2
|
14 |
+
DIM: 1
|
15 |
+
|
16 |
+
# Dropout Parameters
|
17 |
+
DROPOUT_RATE: 0.2
|
18 |
+
|
19 |
+
# Generator Parameters
|
20 |
+
KERNEL_SIZE_GENERATOR: 1
|
21 |
+
INPUT_CHANNELS: 1
|
22 |
+
OUTPUT_CHANNELS: 2
|
23 |
+
|
24 |
+
# Critic Parameters
|
25 |
+
IN_CHANNELS: 3
|
research/data_transformation.ipynb
CHANGED
@@ -164,15 +164,28 @@
|
|
164 |
" return train_loader, test_loader\n",
|
165 |
" \n",
|
166 |
" \n",
|
|
|
|
|
|
|
|
|
|
|
167 |
" def save_dataloaders(self, train_loader, test_loader):\n",
|
|
|
|
|
|
|
168 |
" train_loader_path = os.path.join(self.config.root_dir, 'train_loader.pt')\n",
|
169 |
" test_loader_path = os.path.join(self.config.root_dir, 'test_loader.pt')\n",
|
170 |
-
"
|
171 |
-
"
|
172 |
-
"
|
173 |
-
"
|
174 |
-
"
|
175 |
-
"
|
|
|
|
|
|
|
|
|
|
|
176 |
]
|
177 |
},
|
178 |
{
|
@@ -184,55 +197,11 @@
|
|
184 |
"name": "stdout",
|
185 |
"output_type": "stream",
|
186 |
"text": [
|
187 |
-
"[2024-08-18
|
188 |
-
"[2024-08-18
|
189 |
-
"[2024-08-18
|
190 |
-
|
191 |
-
|
192 |
-
{
|
193 |
-
"ename": "BoxKeyError",
|
194 |
-
"evalue": "\"'ConfigBox' object has no attribute 'data_transformation'\"",
|
195 |
-
"output_type": "error",
|
196 |
-
"traceback": [
|
197 |
-
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
198 |
-
"\u001b[1;31mKeyError\u001b[0m Traceback (most recent call last)",
|
199 |
-
"File \u001b[1;32mc:\\Users\\azizu\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\box\\box.py:503\u001b[0m, in \u001b[0;36mBox.__getitem__\u001b[1;34m(self, item, _ignore_default)\u001b[0m\n\u001b[0;32m 502\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m--> 503\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__getitem__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mitem\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 504\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n",
|
200 |
-
"\u001b[1;31mKeyError\u001b[0m: 'data_transformation'",
|
201 |
-
"\nThe above exception was the direct cause of the following exception:\n",
|
202 |
-
"\u001b[1;31mBoxKeyError\u001b[0m Traceback (most recent call last)",
|
203 |
-
"File \u001b[1;32mc:\\Users\\azizu\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\box\\box.py:536\u001b[0m, in \u001b[0;36mBox.__getattr__\u001b[1;34m(self, item)\u001b[0m\n\u001b[0;32m 535\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m--> 536\u001b[0m value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__getitem__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mitem\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_ignore_default\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 537\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m:\n",
|
204 |
-
"File \u001b[1;32mc:\\Users\\azizu\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\box\\box.py:524\u001b[0m, in \u001b[0;36mBox.__getitem__\u001b[1;34m(self, item, _ignore_default)\u001b[0m\n\u001b[0;32m 523\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m__get_default(item)\n\u001b[1;32m--> 524\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m BoxKeyError(\u001b[38;5;28mstr\u001b[39m(err)) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m_exception_cause\u001b[39;00m(err)\n\u001b[0;32m 525\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n",
|
205 |
-
"\u001b[1;31mBoxKeyError\u001b[0m: \"'data_transformation'\"",
|
206 |
-
"\nDuring handling of the above exception, another exception occurred:\n",
|
207 |
-
"\u001b[1;31mAttributeError\u001b[0m Traceback (most recent call last)",
|
208 |
-
"File \u001b[1;32mc:\\Users\\azizu\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\box\\box.py:538\u001b[0m, in \u001b[0;36mBox.__getattr__\u001b[1;34m(self, item)\u001b[0m\n\u001b[0;32m 537\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m:\n\u001b[1;32m--> 538\u001b[0m value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mobject\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__getattribute__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mitem\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 539\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n",
|
209 |
-
"\u001b[1;31mAttributeError\u001b[0m: 'ConfigBox' object has no attribute 'data_transformation'",
|
210 |
-
"\nThe above exception was the direct cause of the following exception:\n",
|
211 |
-
"\u001b[1;31mBoxKeyError\u001b[0m Traceback (most recent call last)",
|
212 |
-
"File \u001b[1;32mc:\\Users\\azizu\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\box\\config_box.py:28\u001b[0m, in \u001b[0;36mConfigBox.__getattr__\u001b[1;34m(self, item)\u001b[0m\n\u001b[0;32m 27\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m---> 28\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__getattr__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mitem\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 29\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m:\n",
|
213 |
-
"File \u001b[1;32mc:\\Users\\azizu\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\box\\box.py:552\u001b[0m, in \u001b[0;36mBox.__getattr__\u001b[1;34m(self, item)\u001b[0m\n\u001b[0;32m 551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m__get_default(item, attr\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m--> 552\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m BoxKeyError(\u001b[38;5;28mstr\u001b[39m(err)) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m_exception_cause\u001b[39;00m(err)\n\u001b[0;32m 553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m value\n",
|
214 |
-
"\u001b[1;31mBoxKeyError\u001b[0m: \"'ConfigBox' object has no attribute 'data_transformation'\"",
|
215 |
-
"\nDuring handling of the above exception, another exception occurred:\n",
|
216 |
-
"\u001b[1;31mKeyError\u001b[0m Traceback (most recent call last)",
|
217 |
-
"File \u001b[1;32mc:\\Users\\azizu\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\box\\box.py:503\u001b[0m, in \u001b[0;36mBox.__getitem__\u001b[1;34m(self, item, _ignore_default)\u001b[0m\n\u001b[0;32m 502\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m--> 503\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__getitem__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mitem\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 504\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n",
|
218 |
-
"\u001b[1;31mKeyError\u001b[0m: 'data_transformation'",
|
219 |
-
"\nThe above exception was the direct cause of the following exception:\n",
|
220 |
-
"\u001b[1;31mBoxKeyError\u001b[0m Traceback (most recent call last)",
|
221 |
-
"File \u001b[1;32mc:\\Users\\azizu\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\box\\box.py:536\u001b[0m, in \u001b[0;36mBox.__getattr__\u001b[1;34m(self, item)\u001b[0m\n\u001b[0;32m 535\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m--> 536\u001b[0m value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__getitem__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mitem\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_ignore_default\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 537\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m:\n",
|
222 |
-
"File \u001b[1;32mc:\\Users\\azizu\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\box\\box.py:524\u001b[0m, in \u001b[0;36mBox.__getitem__\u001b[1;34m(self, item, _ignore_default)\u001b[0m\n\u001b[0;32m 523\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m__get_default(item)\n\u001b[1;32m--> 524\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m BoxKeyError(\u001b[38;5;28mstr\u001b[39m(err)) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m_exception_cause\u001b[39;00m(err)\n\u001b[0;32m 525\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n",
|
223 |
-
"\u001b[1;31mBoxKeyError\u001b[0m: \"'data_transformation'\"",
|
224 |
-
"\nDuring handling of the above exception, another exception occurred:\n",
|
225 |
-
"\u001b[1;31mAttributeError\u001b[0m Traceback (most recent call last)",
|
226 |
-
"File \u001b[1;32mc:\\Users\\azizu\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\box\\box.py:538\u001b[0m, in \u001b[0;36mBox.__getattr__\u001b[1;34m(self, item)\u001b[0m\n\u001b[0;32m 537\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m:\n\u001b[1;32m--> 538\u001b[0m value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mobject\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__getattribute__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mitem\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 539\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n",
|
227 |
-
"\u001b[1;31mAttributeError\u001b[0m: 'ConfigBox' object has no attribute 'data_transformation'",
|
228 |
-
"\nThe above exception was the direct cause of the following exception:\n",
|
229 |
-
"\u001b[1;31mBoxKeyError\u001b[0m Traceback (most recent call last)",
|
230 |
-
"Cell \u001b[1;32mIn[7], line 9\u001b[0m\n\u001b[0;32m 7\u001b[0m data_transformation\u001b[38;5;241m.\u001b[39msa\n\u001b[0;32m 8\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m----> 9\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\n",
|
231 |
-
"Cell \u001b[1;32mIn[7], line 3\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 2\u001b[0m config \u001b[38;5;241m=\u001b[39m ConfigurationManager()\n\u001b[1;32m----> 3\u001b[0m data_transformation_config \u001b[38;5;241m=\u001b[39m \u001b[43mconfig\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_data_transformation_config\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 4\u001b[0m data_transformation \u001b[38;5;241m=\u001b[39m DataTransformation(config\u001b[38;5;241m=\u001b[39mdata_transformation_config)\n\u001b[0;32m 5\u001b[0m data_transformation\u001b[38;5;241m.\u001b[39mload_data()\n",
|
232 |
-
"Cell \u001b[1;32mIn[4], line 17\u001b[0m, in \u001b[0;36mConfigurationManager.get_data_transformation_config\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 16\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_data_transformation_config\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m DataTransformationConfig:\n\u001b[1;32m---> 17\u001b[0m config \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconfig\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdata_transformation\u001b[49m\n\u001b[0;32m 18\u001b[0m params \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparams\n\u001b[0;32m 20\u001b[0m data_transformation_cofig \u001b[38;5;241m=\u001b[39m DataTransformationConfig(\n\u001b[0;32m 21\u001b[0m root_dir\u001b[38;5;241m=\u001b[39m config\u001b[38;5;241m.\u001b[39mroot_dir,\n\u001b[0;32m 22\u001b[0m data_path_black\u001b[38;5;241m=\u001b[39mconfig\u001b[38;5;241m.\u001b[39mdata_path_black,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 26\u001b[0m DATA_RANGE\u001b[38;5;241m=\u001b[39mparams\u001b[38;5;241m.\u001b[39mDATA_RANGE\n\u001b[0;32m 27\u001b[0m )\n",
|
233 |
-
"File \u001b[1;32mc:\\Users\\azizu\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\box\\config_box.py:30\u001b[0m, in \u001b[0;36mConfigBox.__getattr__\u001b[1;34m(self, item)\u001b[0m\n\u001b[0;32m 28\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__getattr__\u001b[39m(item)\n\u001b[0;32m 29\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m:\n\u001b[1;32m---> 30\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__getattr__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mitem\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlower\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
|
234 |
-
"File \u001b[1;32mc:\\Users\\azizu\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\box\\box.py:552\u001b[0m, in \u001b[0;36mBox.__getattr__\u001b[1;34m(self, item)\u001b[0m\n\u001b[0;32m 550\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m BoxKeyError(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mitem\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: Does not exist and internal methods are never defaulted\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m__get_default(item, attr\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m--> 552\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m BoxKeyError(\u001b[38;5;28mstr\u001b[39m(err)) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m_exception_cause\u001b[39;00m(err)\n\u001b[0;32m 553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m value\n",
|
235 |
-
"\u001b[1;31mBoxKeyError\u001b[0m: \"'ConfigBox' object has no attribute 'data_transformation'\""
|
236 |
]
|
237 |
}
|
238 |
],
|
@@ -241,11 +210,18 @@
|
|
241 |
" config = ConfigurationManager()\n",
|
242 |
" data_transformation_config = config.get_data_transformation_config()\n",
|
243 |
" data_transformation = DataTransformation(config=data_transformation_config)\n",
|
244 |
-
"
|
245 |
-
"
|
246 |
-
" data_transformation.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
247 |
"except Exception as e:\n",
|
248 |
-
" raise e"
|
249 |
]
|
250 |
},
|
251 |
{
|
|
|
164 |
" return train_loader, test_loader\n",
|
165 |
" \n",
|
166 |
" \n",
|
167 |
+
" \n",
|
168 |
+
" \n",
|
169 |
+
" \n",
|
170 |
+
" \n",
|
171 |
+
" \n",
|
172 |
" def save_dataloaders(self, train_loader, test_loader):\n",
|
173 |
+
" # Ensure the directory exists\n",
|
174 |
+
" os.makedirs(self.config.root_dir, exist_ok=True)\n",
|
175 |
+
"\n",
|
176 |
" train_loader_path = os.path.join(self.config.root_dir, 'train_loader.pt')\n",
|
177 |
" test_loader_path = os.path.join(self.config.root_dir, 'test_loader.pt')\n",
|
178 |
+
"\n",
|
179 |
+
" try:\n",
|
180 |
+
" # Save the dataloaders\n",
|
181 |
+
" torch.save(train_loader, train_loader_path)\n",
|
182 |
+
" torch.save(test_loader, test_loader_path)\n",
|
183 |
+
"\n",
|
184 |
+
" logger.info(f\"Train Loader saved at: {train_loader_path}\")\n",
|
185 |
+
" logger.info(f\"Test Loader saved at: {test_loader_path}\")\n",
|
186 |
+
" except Exception as e:\n",
|
187 |
+
" logger.error(f\"Error saving dataloaders: {str(e)}\")\n",
|
188 |
+
" raise e\n"
|
189 |
]
|
190 |
},
|
191 |
{
|
|
|
197 |
"name": "stdout",
|
198 |
"output_type": "stream",
|
199 |
"text": [
|
200 |
+
"[2024-08-18 17:50:45,127: INFO: common: yaml file: config\\config.yaml loaded successfully]\n",
|
201 |
+
"[2024-08-18 17:50:45,129: INFO: common: yaml file: params.yaml loaded successfully]\n",
|
202 |
+
"[2024-08-18 17:50:45,129: INFO: common: created directory at: artifacts]\n",
|
203 |
+
"[2024-08-18 17:50:57,600: INFO: 2567581832: Train Loader saved at: artifacts/data_transformation\\train_loader.pt]\n",
|
204 |
+
"[2024-08-18 17:50:57,605: INFO: 2567581832: Test Loader saved at: artifacts/data_transformation\\test_loader.pt]\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
]
|
206 |
}
|
207 |
],
|
|
|
210 |
" config = ConfigurationManager()\n",
|
211 |
" data_transformation_config = config.get_data_transformation_config()\n",
|
212 |
" data_transformation = DataTransformation(config=data_transformation_config)\n",
|
213 |
+
" \n",
|
214 |
+
" # Load the dataset\n",
|
215 |
+
" dataset = data_transformation.load_data()\n",
|
216 |
+
" \n",
|
217 |
+
" # Get the dataloader using the loaded dataset\n",
|
218 |
+
" train_loader, test_loader = data_transformation.get_dataloader(dataset)\n",
|
219 |
+
" \n",
|
220 |
+
" # Perform any further operations (e.g., saving the dataloaders)\n",
|
221 |
+
" data_transformation.save_dataloaders(train_loader, test_loader)\n",
|
222 |
+
" \n",
|
223 |
"except Exception as e:\n",
|
224 |
+
" raise e\n"
|
225 |
]
|
226 |
},
|
227 |
{
|
research/model_building.ipynb
ADDED
@@ -0,0 +1,461 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import os\n",
|
10 |
+
"os.chdir('../')"
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "code",
|
15 |
+
"execution_count": 2,
|
16 |
+
"metadata": {},
|
17 |
+
"outputs": [
|
18 |
+
{
|
19 |
+
"data": {
|
20 |
+
"text/plain": [
|
21 |
+
"'c:\\\\mlops project\\\\image-colorization-mlops'"
|
22 |
+
]
|
23 |
+
},
|
24 |
+
"execution_count": 2,
|
25 |
+
"metadata": {},
|
26 |
+
"output_type": "execute_result"
|
27 |
+
}
|
28 |
+
],
|
29 |
+
"source": [
|
30 |
+
"%pwd"
|
31 |
+
]
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"cell_type": "code",
|
35 |
+
"execution_count": 3,
|
36 |
+
"metadata": {},
|
37 |
+
"outputs": [],
|
38 |
+
"source": [
|
39 |
+
"from dataclasses import dataclass\n",
|
40 |
+
"from pathlib import Path\n",
|
41 |
+
"\n",
|
42 |
+
"@dataclass(frozen=True)\n",
|
43 |
+
"class ModelBuildingConfig:\n",
|
44 |
+
" root_dir: Path\n",
|
45 |
+
" KERNEL_SIZE_RES: int\n",
|
46 |
+
" PADDING: int\n",
|
47 |
+
" STRIDE: int\n",
|
48 |
+
" BIAS: bool\n",
|
49 |
+
" SCALE_FACTOR: int\n",
|
50 |
+
" DIM: int\n",
|
51 |
+
" DROPOUT_RATE: float\n",
|
52 |
+
" KERNEL_SIZE_GENERATOR: int\n",
|
53 |
+
" INPUT_CHANNELS: int\n",
|
54 |
+
" OUTPUT_CHANNELS: int\n",
|
55 |
+
" IN_CHANNELS: int\n",
|
56 |
+
"\n"
|
57 |
+
]
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"cell_type": "code",
|
61 |
+
"execution_count": 4,
|
62 |
+
"metadata": {},
|
63 |
+
"outputs": [],
|
64 |
+
"source": [
|
65 |
+
"from src.imagecolorization.constants import *\n",
|
66 |
+
"from src.imagecolorization.utils.common import read_yaml, create_directories\n",
|
67 |
+
"\n",
|
68 |
+
"class ConfigurationManager:\n",
|
69 |
+
" def __init__(self, config_filepath=CONFIG_FILE_PATH, params_filepath=PARAMS_FILE_PATH):\n",
|
70 |
+
" self.config = read_yaml(config_filepath)\n",
|
71 |
+
" self.params = read_yaml(params_filepath)\n",
|
72 |
+
" create_directories([self.config.artifacts_root])\n",
|
73 |
+
"\n",
|
74 |
+
" def get_model_building_config(self) -> ModelBuildingConfig:\n",
|
75 |
+
" config = self.config.model_building\n",
|
76 |
+
" params = self.params\n",
|
77 |
+
"\n",
|
78 |
+
" model_building_config = ModelBuildingConfig(\n",
|
79 |
+
" root_dir=Path(config.root_dir),\n",
|
80 |
+
" KERNEL_SIZE_RES=params.KERNEL_SIZE_RES,\n",
|
81 |
+
" PADDING=params.PADDING,\n",
|
82 |
+
" STRIDE=params.STRIDE,\n",
|
83 |
+
" BIAS=params.BIAS,\n",
|
84 |
+
" SCALE_FACTOR=params.SCALE_FACTOR,\n",
|
85 |
+
" DIM=params.DIM,\n",
|
86 |
+
" DROPOUT_RATE=params.DROPOUT_RATE,\n",
|
87 |
+
" KERNEL_SIZE_GENERATOR=params.KERNEL_SIZE_GENERATOR,\n",
|
88 |
+
" INPUT_CHANNELS=params.INPUT_CHANNELS,\n",
|
89 |
+
" OUTPUT_CHANNELS=params.OUTPUT_CHANNELS,\n",
|
90 |
+
" IN_CHANNELS=params.IN_CHANNELS\n",
|
91 |
+
" )\n",
|
92 |
+
" return model_building_config\n"
|
93 |
+
]
|
94 |
+
},
|
95 |
+
{
|
96 |
+
"cell_type": "code",
|
97 |
+
"execution_count": 5,
|
98 |
+
"metadata": {},
|
99 |
+
"outputs": [],
|
100 |
+
"source": [
|
101 |
+
"import torch \n",
|
102 |
+
"import torch.nn as nn\n",
|
103 |
+
"from pathlib import Path\n",
|
104 |
+
"\n",
|
105 |
+
"class ResBlock(nn.Module):\n",
|
106 |
+
" def __init__(self, in_channles, out_channels, stride = 1, kerenl_size = 3, padding = 1, bias = False):\n",
|
107 |
+
" super().__init__()\n",
|
108 |
+
" self.layer = nn.Sequential(\n",
|
109 |
+
" nn.Conv2d(in_channles, out_channels, kernel_size=kerenl_size, padding=padding, stride=stride, bias = bias),\n",
|
110 |
+
" nn.BatchNorm2d(out_channels),\n",
|
111 |
+
" nn.ReLU(inplace=True),\n",
|
112 |
+
" nn.Conv2d(out_channels, out_channels, kernel_size=kerenl_size, padding=padding, stride = 1, bias = bias),\n",
|
113 |
+
" nn.BatchNorm2d(out_channels),\n",
|
114 |
+
" nn.ReLU(inplace=True)\n",
|
115 |
+
" )\n",
|
116 |
+
" \n",
|
117 |
+
" self.identity_map = nn.Conv2d(in_channles, out_channels,kernel_size=1, stride=stride)\n",
|
118 |
+
" self.relu = nn.ReLU(inplace= True)\n",
|
119 |
+
" \n",
|
120 |
+
" def forward(self, inputs):\n",
|
121 |
+
" x = inputs.clone().detach()\n",
|
122 |
+
" out = self.layer(x)\n",
|
123 |
+
" residual = self.identity_map(inputs)\n",
|
124 |
+
" skip = out + residual\n",
|
125 |
+
" return self.relu(skip)\n",
|
126 |
+
" \n",
|
127 |
+
" \n",
|
128 |
+
"class DownsampleConv(nn.Module):\n",
|
129 |
+
" def __init__(self, in_channels, out_channels, stride = 1):\n",
|
130 |
+
" super().__init__()\n",
|
131 |
+
" self.layer = nn.Sequential(\n",
|
132 |
+
" nn.MaxPool2d(2),\n",
|
133 |
+
" ResBlock(in_channels, out_channels)\n",
|
134 |
+
" )\n",
|
135 |
+
" \n",
|
136 |
+
" def forward(self, inputs):\n",
|
137 |
+
" return self.layer(inputs)\n",
|
138 |
+
" \n",
|
139 |
+
" \n",
|
140 |
+
" \n",
|
141 |
+
"class UpsampleConv(nn.Module):\n",
|
142 |
+
" def __init__(self, in_channels, out_channels, scale_factor=2):\n",
|
143 |
+
" super().__init__()\n",
|
144 |
+
" self.upsample = nn.Upsample(scale_factor=scale_factor,mode = 'bilinear', align_corners=True)\n",
|
145 |
+
" self.res_block = ResBlock(in_channels + out_channels, out_channels)\n",
|
146 |
+
"\n",
|
147 |
+
" def forward(self, inputs, skip):\n",
|
148 |
+
" x = self.upsample(inputs)\n",
|
149 |
+
" x = torch.cat([x, skip], dim = 1)\n",
|
150 |
+
" x = self.res_block(x)\n",
|
151 |
+
" return x\n",
|
152 |
+
" \n",
|
153 |
+
"class Generator(nn.Module):\n",
|
154 |
+
" def __init__(self, input_channels, output_channels, dropout_rate = 0.2):\n",
|
155 |
+
" super().__init__()\n",
|
156 |
+
" self.encoding_layer1_= ResBlock(input_channels, 64)\n",
|
157 |
+
" self.encoding_layer2_ = DownsampleConv(64, 128)\n",
|
158 |
+
" self.encoding_layer3_ = DownsampleConv(128, 256)\n",
|
159 |
+
" self.bridge = DownsampleConv(256, 512)\n",
|
160 |
+
" self.decoding_layer3 = UpsampleConv(512, 256)\n",
|
161 |
+
" self.decoding_layer2 = UpsampleConv(256, 128)\n",
|
162 |
+
" self.decoding_layer1 = UpsampleConv(128 , 64)\n",
|
163 |
+
" self.output = nn.Conv2d(64, output_channels, kernel_size = 1)\n",
|
164 |
+
" self.dropout = nn.Dropout2d(dropout_rate)\n",
|
165 |
+
" \n",
|
166 |
+
" def forward(self, inputs):\n",
|
167 |
+
" e1 = self.encoding_layer1_(inputs)\n",
|
168 |
+
" e1 = self.dropout(e1)\n",
|
169 |
+
" e2 = self.encoding_layer2_(e1)\n",
|
170 |
+
" e2 = self.dropout(e2)\n",
|
171 |
+
" e3 = self.encoding_layer3_(e2)\n",
|
172 |
+
" e3 = self.dropout(e3)\n",
|
173 |
+
" \n",
|
174 |
+
" bridge = self.bridge(e3)\n",
|
175 |
+
" bridge = self.dropout(bridge)\n",
|
176 |
+
" \n",
|
177 |
+
" d3 = self.decoding_layer3(bridge, e3)\n",
|
178 |
+
" d2 =self.decoding_layer2(d3, e2)\n",
|
179 |
+
" d1 = self.decoding_layer1(d2, e1)\n",
|
180 |
+
" \n",
|
181 |
+
" output = self.dropout(d1)\n",
|
182 |
+
" return output\n",
|
183 |
+
" \n",
|
184 |
+
" \n",
|
185 |
+
"class Critic(nn.Module):\n",
|
186 |
+
" def __init__(self, in_channels=3):\n",
|
187 |
+
" super(Critic, self).__init__()\n",
|
188 |
+
"\n",
|
189 |
+
" def critic_block(in_filters, out_filters, normalization=True):\n",
|
190 |
+
" layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]\n",
|
191 |
+
" if normalization:\n",
|
192 |
+
" layers.append(nn.InstanceNorm2d(out_filters))\n",
|
193 |
+
" layers.append(nn.LeakyReLU(0.2, inplace=True))\n",
|
194 |
+
" return layers\n",
|
195 |
+
"\n",
|
196 |
+
" self.model = nn.Sequential(\n",
|
197 |
+
" *critic_block(in_channels, 64, normalization=False),\n",
|
198 |
+
" *critic_block(64, 128),\n",
|
199 |
+
" *critic_block(128, 256),\n",
|
200 |
+
" *critic_block(256, 512),\n",
|
201 |
+
" nn.AdaptiveAvgPool2d(1),\n",
|
202 |
+
" nn.Flatten(),\n",
|
203 |
+
" nn.Linear(512, 1)\n",
|
204 |
+
" )\n",
|
205 |
+
"\n",
|
206 |
+
" def forward(self, ab, l):\n",
|
207 |
+
" img_input = torch.cat((ab, l), 1)\n",
|
208 |
+
" output = self.model(img_input)\n",
|
209 |
+
" return output\n",
|
210 |
+
" \n",
|
211 |
+
" "
|
212 |
+
]
|
213 |
+
},
|
214 |
+
{
|
215 |
+
"cell_type": "code",
|
216 |
+
"execution_count": 6,
|
217 |
+
"metadata": {},
|
218 |
+
"outputs": [],
|
219 |
+
"source": [
|
220 |
+
"from torchsummary import summary\n",
|
221 |
+
"import torch\n",
|
222 |
+
"import os\n",
|
223 |
+
"\n",
|
224 |
+
"class ModelBuilding:\n",
|
225 |
+
" def __init__(self, config: ModelBuildingConfig):\n",
|
226 |
+
" self.config = config\n",
|
227 |
+
" self.root_dir = self.config.root_dir\n",
|
228 |
+
" self.create_root_dir()\n",
|
229 |
+
" self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
230 |
+
"\n",
|
231 |
+
" def create_root_dir(self):\n",
|
232 |
+
" os.makedirs(self.root_dir, exist_ok=True)\n",
|
233 |
+
" print(f\"Created directory: {self.root_dir}\")\n",
|
234 |
+
"\n",
|
235 |
+
" def get_generator(self):\n",
|
236 |
+
" return Generator(\n",
|
237 |
+
" input_channels=self.config.INPUT_CHANNELS, # corrected argument name\n",
|
238 |
+
" output_channels=self.config.OUTPUT_CHANNELS, # corrected argument name\n",
|
239 |
+
" dropout_rate=self.config.DROPOUT_RATE\n",
|
240 |
+
" ).to(self.device)\n",
|
241 |
+
"\n",
|
242 |
+
" def get_critic(self):\n",
|
243 |
+
" return Critic(in_channels=self.config.IN_CHANNELS).to(self.device)\n",
|
244 |
+
"\n",
|
245 |
+
" def build(self):\n",
|
246 |
+
" generator = self.get_generator()\n",
|
247 |
+
" critic = self.get_critic()\n",
|
248 |
+
" return generator, critic\n",
|
249 |
+
"\n",
|
250 |
+
" def save_model(self, model, filename):\n",
|
251 |
+
" path = self.root_dir / filename\n",
|
252 |
+
" torch.save(model.state_dict(), path)\n",
|
253 |
+
" print(f\"Model saved to {path}\")\n",
|
254 |
+
"\n",
|
255 |
+
" def display_summary(self, model, input_size):\n",
|
256 |
+
" print(f\"\\nModel Summary:\")\n",
|
257 |
+
" summary(model, input_size)\n",
|
258 |
+
"\n",
|
259 |
+
" def build_and_save(self):\n",
|
260 |
+
" generator, critic = self.build()\n",
|
261 |
+
"\n",
|
262 |
+
" # Display summaries\n",
|
263 |
+
" print(\"\\nGenerator Summary:\")\n",
|
264 |
+
" self.display_summary(generator, (self.config.INPUT_CHANNELS, 224, 224)) # Assuming input size is 224x224\n",
|
265 |
+
"\n",
|
266 |
+
" print(\"\\nCritic Summary:\")\n",
|
267 |
+
" self.display_summary(critic, [(2, 224, 224), (1, 224, 224)]) # Critic takes two inputs: ab and l\n",
|
268 |
+
"\n",
|
269 |
+
" self.save_model(generator, \"generator.pth\")\n",
|
270 |
+
" self.save_model(critic, \"critic.pth\")\n",
|
271 |
+
" return generator, critic\n",
|
272 |
+
"\n"
|
273 |
+
]
|
274 |
+
},
|
275 |
+
{
|
276 |
+
"cell_type": "code",
|
277 |
+
"execution_count": 7,
|
278 |
+
"metadata": {},
|
279 |
+
"outputs": [
|
280 |
+
{
|
281 |
+
"name": "stdout",
|
282 |
+
"output_type": "stream",
|
283 |
+
"text": [
|
284 |
+
"[2024-08-23 00:00:44,340: INFO: common: yaml file: config\\config.yaml loaded successfully]\n",
|
285 |
+
"[2024-08-23 00:00:44,342: INFO: common: yaml file: params.yaml loaded successfully]\n",
|
286 |
+
"[2024-08-23 00:00:44,343: INFO: common: created directory at: artifacts]\n",
|
287 |
+
"Created directory: artifacts\\model\n",
|
288 |
+
"\n",
|
289 |
+
"Generator Summary:\n",
|
290 |
+
"\n",
|
291 |
+
"Model Summary:\n",
|
292 |
+
"----------------------------------------------------------------\n",
|
293 |
+
" Layer (type) Output Shape Param #\n",
|
294 |
+
"================================================================\n",
|
295 |
+
" Conv2d-1 [-1, 64, 224, 224] 576\n",
|
296 |
+
" BatchNorm2d-2 [-1, 64, 224, 224] 128\n",
|
297 |
+
" ReLU-3 [-1, 64, 224, 224] 0\n",
|
298 |
+
" Conv2d-4 [-1, 64, 224, 224] 36,864\n",
|
299 |
+
" BatchNorm2d-5 [-1, 64, 224, 224] 128\n",
|
300 |
+
" ReLU-6 [-1, 64, 224, 224] 0\n",
|
301 |
+
" Conv2d-7 [-1, 64, 224, 224] 128\n",
|
302 |
+
" ReLU-8 [-1, 64, 224, 224] 0\n",
|
303 |
+
" ResBlock-9 [-1, 64, 224, 224] 0\n",
|
304 |
+
" Dropout2d-10 [-1, 64, 224, 224] 0\n",
|
305 |
+
" MaxPool2d-11 [-1, 64, 112, 112] 0\n",
|
306 |
+
" Conv2d-12 [-1, 128, 112, 112] 73,728\n",
|
307 |
+
" BatchNorm2d-13 [-1, 128, 112, 112] 256\n",
|
308 |
+
" ReLU-14 [-1, 128, 112, 112] 0\n",
|
309 |
+
" Conv2d-15 [-1, 128, 112, 112] 147,456\n",
|
310 |
+
" BatchNorm2d-16 [-1, 128, 112, 112] 256\n",
|
311 |
+
" ReLU-17 [-1, 128, 112, 112] 0\n",
|
312 |
+
" Conv2d-18 [-1, 128, 112, 112] 8,320\n",
|
313 |
+
" ReLU-19 [-1, 128, 112, 112] 0\n",
|
314 |
+
" ResBlock-20 [-1, 128, 112, 112] 0\n",
|
315 |
+
" DownsampleConv-21 [-1, 128, 112, 112] 0\n",
|
316 |
+
" Dropout2d-22 [-1, 128, 112, 112] 0\n",
|
317 |
+
" MaxPool2d-23 [-1, 128, 56, 56] 0\n",
|
318 |
+
" Conv2d-24 [-1, 256, 56, 56] 294,912\n",
|
319 |
+
" BatchNorm2d-25 [-1, 256, 56, 56] 512\n",
|
320 |
+
" ReLU-26 [-1, 256, 56, 56] 0\n",
|
321 |
+
" Conv2d-27 [-1, 256, 56, 56] 589,824\n",
|
322 |
+
" BatchNorm2d-28 [-1, 256, 56, 56] 512\n",
|
323 |
+
" ReLU-29 [-1, 256, 56, 56] 0\n",
|
324 |
+
" Conv2d-30 [-1, 256, 56, 56] 33,024\n",
|
325 |
+
" ReLU-31 [-1, 256, 56, 56] 0\n",
|
326 |
+
" ResBlock-32 [-1, 256, 56, 56] 0\n",
|
327 |
+
" DownsampleConv-33 [-1, 256, 56, 56] 0\n",
|
328 |
+
" Dropout2d-34 [-1, 256, 56, 56] 0\n",
|
329 |
+
" MaxPool2d-35 [-1, 256, 28, 28] 0\n",
|
330 |
+
" Conv2d-36 [-1, 512, 28, 28] 1,179,648\n",
|
331 |
+
" BatchNorm2d-37 [-1, 512, 28, 28] 1,024\n",
|
332 |
+
" ReLU-38 [-1, 512, 28, 28] 0\n",
|
333 |
+
" Conv2d-39 [-1, 512, 28, 28] 2,359,296\n",
|
334 |
+
" BatchNorm2d-40 [-1, 512, 28, 28] 1,024\n",
|
335 |
+
" ReLU-41 [-1, 512, 28, 28] 0\n",
|
336 |
+
" Conv2d-42 [-1, 512, 28, 28] 131,584\n",
|
337 |
+
" ReLU-43 [-1, 512, 28, 28] 0\n",
|
338 |
+
" ResBlock-44 [-1, 512, 28, 28] 0\n",
|
339 |
+
" DownsampleConv-45 [-1, 512, 28, 28] 0\n",
|
340 |
+
" Dropout2d-46 [-1, 512, 28, 28] 0\n",
|
341 |
+
" Upsample-47 [-1, 512, 56, 56] 0\n",
|
342 |
+
" Conv2d-48 [-1, 256, 56, 56] 1,769,472\n",
|
343 |
+
" BatchNorm2d-49 [-1, 256, 56, 56] 512\n",
|
344 |
+
" ReLU-50 [-1, 256, 56, 56] 0\n",
|
345 |
+
" Conv2d-51 [-1, 256, 56, 56] 589,824\n",
|
346 |
+
" BatchNorm2d-52 [-1, 256, 56, 56] 512\n",
|
347 |
+
" ReLU-53 [-1, 256, 56, 56] 0\n",
|
348 |
+
" Conv2d-54 [-1, 256, 56, 56] 196,864\n",
|
349 |
+
" ReLU-55 [-1, 256, 56, 56] 0\n",
|
350 |
+
" ResBlock-56 [-1, 256, 56, 56] 0\n",
|
351 |
+
" UpsampleConv-57 [-1, 256, 56, 56] 0\n",
|
352 |
+
" Upsample-58 [-1, 256, 112, 112] 0\n",
|
353 |
+
" Conv2d-59 [-1, 128, 112, 112] 442,368\n",
|
354 |
+
" BatchNorm2d-60 [-1, 128, 112, 112] 256\n",
|
355 |
+
" ReLU-61 [-1, 128, 112, 112] 0\n",
|
356 |
+
" Conv2d-62 [-1, 128, 112, 112] 147,456\n",
|
357 |
+
" BatchNorm2d-63 [-1, 128, 112, 112] 256\n",
|
358 |
+
" ReLU-64 [-1, 128, 112, 112] 0\n",
|
359 |
+
" Conv2d-65 [-1, 128, 112, 112] 49,280\n",
|
360 |
+
" ReLU-66 [-1, 128, 112, 112] 0\n",
|
361 |
+
" ResBlock-67 [-1, 128, 112, 112] 0\n",
|
362 |
+
" UpsampleConv-68 [-1, 128, 112, 112] 0\n",
|
363 |
+
" Upsample-69 [-1, 128, 224, 224] 0\n",
|
364 |
+
" Conv2d-70 [-1, 64, 224, 224] 110,592\n",
|
365 |
+
" BatchNorm2d-71 [-1, 64, 224, 224] 128\n",
|
366 |
+
" ReLU-72 [-1, 64, 224, 224] 0\n",
|
367 |
+
" Conv2d-73 [-1, 64, 224, 224] 36,864\n",
|
368 |
+
" BatchNorm2d-74 [-1, 64, 224, 224] 128\n",
|
369 |
+
" ReLU-75 [-1, 64, 224, 224] 0\n",
|
370 |
+
" Conv2d-76 [-1, 64, 224, 224] 12,352\n",
|
371 |
+
" ReLU-77 [-1, 64, 224, 224] 0\n",
|
372 |
+
" ResBlock-78 [-1, 64, 224, 224] 0\n",
|
373 |
+
" UpsampleConv-79 [-1, 64, 224, 224] 0\n",
|
374 |
+
" Dropout2d-80 [-1, 64, 224, 224] 0\n",
|
375 |
+
"================================================================\n",
|
376 |
+
"Total params: 8,216,064\n",
|
377 |
+
"Trainable params: 8,216,064\n",
|
378 |
+
"Non-trainable params: 0\n",
|
379 |
+
"----------------------------------------------------------------\n",
|
380 |
+
"Input size (MB): 0.19\n",
|
381 |
+
"Forward/backward pass size (MB): 1030.53\n",
|
382 |
+
"Params size (MB): 31.34\n",
|
383 |
+
"Estimated Total Size (MB): 1062.06\n",
|
384 |
+
"----------------------------------------------------------------\n",
|
385 |
+
"\n",
|
386 |
+
"Critic Summary:\n",
|
387 |
+
"\n",
|
388 |
+
"Model Summary:\n",
|
389 |
+
"----------------------------------------------------------------\n",
|
390 |
+
" Layer (type) Output Shape Param #\n",
|
391 |
+
"================================================================\n",
|
392 |
+
" Conv2d-1 [-1, 64, 112, 112] 3,136\n",
|
393 |
+
" LeakyReLU-2 [-1, 64, 112, 112] 0\n",
|
394 |
+
" Conv2d-3 [-1, 128, 56, 56] 131,200\n",
|
395 |
+
" InstanceNorm2d-4 [-1, 128, 56, 56] 0\n",
|
396 |
+
" LeakyReLU-5 [-1, 128, 56, 56] 0\n",
|
397 |
+
" Conv2d-6 [-1, 256, 28, 28] 524,544\n",
|
398 |
+
" InstanceNorm2d-7 [-1, 256, 28, 28] 0\n",
|
399 |
+
" LeakyReLU-8 [-1, 256, 28, 28] 0\n",
|
400 |
+
" Conv2d-9 [-1, 512, 14, 14] 2,097,664\n",
|
401 |
+
" InstanceNorm2d-10 [-1, 512, 14, 14] 0\n",
|
402 |
+
" LeakyReLU-11 [-1, 512, 14, 14] 0\n",
|
403 |
+
"AdaptiveAvgPool2d-12 [-1, 512, 1, 1] 0\n",
|
404 |
+
" Flatten-13 [-1, 512] 0\n",
|
405 |
+
" Linear-14 [-1, 1] 513\n",
|
406 |
+
"================================================================\n",
|
407 |
+
"Total params: 2,757,057\n",
|
408 |
+
"Trainable params: 2,757,057\n",
|
409 |
+
"Non-trainable params: 0\n",
|
410 |
+
"----------------------------------------------------------------\n",
|
411 |
+
"Input size (MB): 2824.00\n",
|
412 |
+
"Forward/backward pass size (MB): 28.34\n",
|
413 |
+
"Params size (MB): 10.52\n",
|
414 |
+
"Estimated Total Size (MB): 2862.85\n",
|
415 |
+
"----------------------------------------------------------------\n",
|
416 |
+
"Model saved to artifacts\\model\\generator.pth\n",
|
417 |
+
"Model saved to artifacts\\model\\critic.pth\n"
|
418 |
+
]
|
419 |
+
}
|
420 |
+
],
|
421 |
+
"source": [
|
422 |
+
"try:\n",
|
423 |
+
" config_manager = ConfigurationManager()\n",
|
424 |
+
" model_config = config_manager.get_model_building_config()\n",
|
425 |
+
"\n",
|
426 |
+
" model_building = ModelBuilding(config=model_config)\n",
|
427 |
+
" generator, critic = model_building.build_and_save()\n",
|
428 |
+
"except Exception as e:\n",
|
429 |
+
" raise e"
|
430 |
+
]
|
431 |
+
},
|
432 |
+
{
|
433 |
+
"cell_type": "code",
|
434 |
+
"execution_count": null,
|
435 |
+
"metadata": {},
|
436 |
+
"outputs": [],
|
437 |
+
"source": []
|
438 |
+
}
|
439 |
+
],
|
440 |
+
"metadata": {
|
441 |
+
"kernelspec": {
|
442 |
+
"display_name": "Python 3",
|
443 |
+
"language": "python",
|
444 |
+
"name": "python3"
|
445 |
+
},
|
446 |
+
"language_info": {
|
447 |
+
"codemirror_mode": {
|
448 |
+
"name": "ipython",
|
449 |
+
"version": 3
|
450 |
+
},
|
451 |
+
"file_extension": ".py",
|
452 |
+
"mimetype": "text/x-python",
|
453 |
+
"name": "python",
|
454 |
+
"nbconvert_exporter": "python",
|
455 |
+
"pygments_lexer": "ipython3",
|
456 |
+
"version": "3.11.0"
|
457 |
+
}
|
458 |
+
},
|
459 |
+
"nbformat": 4,
|
460 |
+
"nbformat_minor": 2
|
461 |
+
}
|
src/imagecolorization/config/configuration.py
CHANGED
@@ -1,7 +1,8 @@
|
|
1 |
from src.imagecolorization.constants import *
|
2 |
from src.imagecolorization.utils.common import read_yaml, create_directories
|
3 |
from src.imagecolorization.entity.config_entity import (DataIngestionConfig,
|
4 |
-
DataTransformationConfig
|
|
|
5 |
class ConfigurationManager:
|
6 |
def __init__(
|
7 |
self,
|
@@ -44,6 +45,26 @@ class ConfigurationManager:
|
|
44 |
)
|
45 |
|
46 |
return data_transformation_cofig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
|
49 |
|
|
|
1 |
from src.imagecolorization.constants import *
|
2 |
from src.imagecolorization.utils.common import read_yaml, create_directories
|
3 |
from src.imagecolorization.entity.config_entity import (DataIngestionConfig,
|
4 |
+
DataTransformationConfig,
|
5 |
+
ModelBuildingConfig)
|
6 |
class ConfigurationManager:
|
7 |
def __init__(
|
8 |
self,
|
|
|
45 |
)
|
46 |
|
47 |
return data_transformation_cofig
|
48 |
+
|
49 |
+
def get_model_building_config(self) -> ModelBuildingConfig:
|
50 |
+
config = self.config.model_building
|
51 |
+
params = self.params
|
52 |
+
|
53 |
+
model_building_config = ModelBuildingConfig(
|
54 |
+
root_dir=Path(config.root_dir),
|
55 |
+
KERNEL_SIZE_RES=params.KERNEL_SIZE_RES,
|
56 |
+
PADDING=params.PADDING,
|
57 |
+
STRIDE=params.STRIDE,
|
58 |
+
BIAS=params.BIAS,
|
59 |
+
SCALE_FACTOR=params.SCALE_FACTOR,
|
60 |
+
DIM=params.DIM,
|
61 |
+
DROPOUT_RATE=params.DROPOUT_RATE,
|
62 |
+
KERNEL_SIZE_GENERATOR=params.KERNEL_SIZE_GENERATOR,
|
63 |
+
INPUT_CHANNELS=params.INPUT_CHANNELS,
|
64 |
+
OUTPUT_CHANNELS=params.OUTPUT_CHANNELS,
|
65 |
+
IN_CHANNELS=params.IN_CHANNELS
|
66 |
+
)
|
67 |
+
return model_building_config
|
68 |
|
69 |
|
70 |
|
src/imagecolorization/conponents/model_building.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from pathlib import Path
|
4 |
+
from torchsummary import summary
|
5 |
+
import os
|
6 |
+
from src.imagecolorization.config.configuration import ModelBuildingConfig
|
7 |
+
|
8 |
+
class ResBlock(nn.Module):
|
9 |
+
def __init__(self, in_channles, out_channels, stride = 1, kerenl_size = 3, padding = 1, bias = False):
|
10 |
+
super().__init__()
|
11 |
+
self.layer = nn.Sequential(
|
12 |
+
nn.Conv2d(in_channles, out_channels, kernel_size=kerenl_size, padding=padding, stride=stride, bias = bias),
|
13 |
+
nn.BatchNorm2d(out_channels),
|
14 |
+
nn.ReLU(inplace=True),
|
15 |
+
nn.Conv2d(out_channels, out_channels, kernel_size=kerenl_size, padding=padding, stride = 1, bias = bias),
|
16 |
+
nn.BatchNorm2d(out_channels),
|
17 |
+
nn.ReLU(inplace=True)
|
18 |
+
)
|
19 |
+
|
20 |
+
self.identity_map = nn.Conv2d(in_channles, out_channels,kernel_size=1, stride=stride)
|
21 |
+
self.relu = nn.ReLU(inplace= True)
|
22 |
+
|
23 |
+
def forward(self, inputs):
|
24 |
+
x = inputs.clone().detach()
|
25 |
+
out = self.layer(x)
|
26 |
+
residual = self.identity_map(inputs)
|
27 |
+
skip = out + residual
|
28 |
+
return self.relu(skip)
|
29 |
+
|
30 |
+
|
31 |
+
class DownsampleConv(nn.Module):
|
32 |
+
def __init__(self, in_channels, out_channels, stride = 1):
|
33 |
+
super().__init__()
|
34 |
+
self.layer = nn.Sequential(
|
35 |
+
nn.MaxPool2d(2),
|
36 |
+
ResBlock(in_channels, out_channels)
|
37 |
+
)
|
38 |
+
|
39 |
+
def forward(self, inputs):
|
40 |
+
return self.layer(inputs)
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
class UpsampleConv(nn.Module):
|
45 |
+
def __init__(self, in_channels, out_channels, scale_factor=2):
|
46 |
+
super().__init__()
|
47 |
+
self.upsample = nn.Upsample(scale_factor=scale_factor,mode = 'bilinear', align_corners=True)
|
48 |
+
self.res_block = ResBlock(in_channels + out_channels, out_channels)
|
49 |
+
|
50 |
+
def forward(self, inputs, skip):
|
51 |
+
x = self.upsample(inputs)
|
52 |
+
x = torch.cat([x, skip], dim = 1)
|
53 |
+
x = self.res_block(x)
|
54 |
+
return x
|
55 |
+
|
56 |
+
class Generator(nn.Module):
|
57 |
+
def __init__(self, input_channels, output_channels, dropout_rate = 0.2):
|
58 |
+
super().__init__()
|
59 |
+
self.encoding_layer1_= ResBlock(input_channels, 64)
|
60 |
+
self.encoding_layer2_ = DownsampleConv(64, 128)
|
61 |
+
self.encoding_layer3_ = DownsampleConv(128, 256)
|
62 |
+
self.bridge = DownsampleConv(256, 512)
|
63 |
+
self.decoding_layer3 = UpsampleConv(512, 256)
|
64 |
+
self.decoding_layer2 = UpsampleConv(256, 128)
|
65 |
+
self.decoding_layer1 = UpsampleConv(128 , 64)
|
66 |
+
self.output = nn.Conv2d(64, output_channels, kernel_size = 1)
|
67 |
+
self.dropout = nn.Dropout2d(dropout_rate)
|
68 |
+
|
69 |
+
def forward(self, inputs):
|
70 |
+
e1 = self.encoding_layer1_(inputs)
|
71 |
+
e1 = self.dropout(e1)
|
72 |
+
e2 = self.encoding_layer2_(e1)
|
73 |
+
e2 = self.dropout(e2)
|
74 |
+
e3 = self.encoding_layer3_(e2)
|
75 |
+
e3 = self.dropout(e3)
|
76 |
+
|
77 |
+
bridge = self.bridge(e3)
|
78 |
+
bridge = self.dropout(bridge)
|
79 |
+
|
80 |
+
d3 = self.decoding_layer3(bridge, e3)
|
81 |
+
d2 =self.decoding_layer2(d3, e2)
|
82 |
+
d1 = self.decoding_layer1(d2, e1)
|
83 |
+
|
84 |
+
output = self.dropout(d1)
|
85 |
+
return output
|
86 |
+
|
87 |
+
|
88 |
+
class Critic(nn.Module):
|
89 |
+
def __init__(self, in_channels=3):
|
90 |
+
super(Critic, self).__init__()
|
91 |
+
|
92 |
+
def critic_block(in_filters, out_filters, normalization=True):
|
93 |
+
layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
|
94 |
+
if normalization:
|
95 |
+
layers.append(nn.InstanceNorm2d(out_filters))
|
96 |
+
layers.append(nn.LeakyReLU(0.2, inplace=True))
|
97 |
+
return layers
|
98 |
+
|
99 |
+
self.model = nn.Sequential(
|
100 |
+
*critic_block(in_channels, 64, normalization=False),
|
101 |
+
*critic_block(64, 128),
|
102 |
+
*critic_block(128, 256),
|
103 |
+
*critic_block(256, 512),
|
104 |
+
nn.AdaptiveAvgPool2d(1),
|
105 |
+
nn.Flatten(),
|
106 |
+
nn.Linear(512, 1)
|
107 |
+
)
|
108 |
+
|
109 |
+
def forward(self, ab, l):
|
110 |
+
img_input = torch.cat((ab, l), 1)
|
111 |
+
output = self.model(img_input)
|
112 |
+
return output
|
113 |
+
|
114 |
+
|
115 |
+
|
116 |
+
|
117 |
+
class ModelBuilding:
|
118 |
+
def __init__(self, config: ModelBuildingConfig):
|
119 |
+
self.config = config
|
120 |
+
self.root_dir = self.config.root_dir
|
121 |
+
self.create_root_dir()
|
122 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
123 |
+
|
124 |
+
def create_root_dir(self):
|
125 |
+
os.makedirs(self.root_dir, exist_ok=True)
|
126 |
+
print(f"Created directory: {self.root_dir}")
|
127 |
+
|
128 |
+
def get_generator(self):
|
129 |
+
return Generator(
|
130 |
+
input_channels=self.config.INPUT_CHANNELS, # corrected argument name
|
131 |
+
output_channels=self.config.OUTPUT_CHANNELS, # corrected argument name
|
132 |
+
dropout_rate=self.config.DROPOUT_RATE
|
133 |
+
).to(self.device)
|
134 |
+
|
135 |
+
def get_critic(self):
|
136 |
+
return Critic(in_channels=self.config.IN_CHANNELS).to(self.device)
|
137 |
+
|
138 |
+
def build(self):
|
139 |
+
generator = self.get_generator()
|
140 |
+
critic = self.get_critic()
|
141 |
+
return generator, critic
|
142 |
+
|
143 |
+
def save_model(self, model, filename):
|
144 |
+
path = self.root_dir / filename
|
145 |
+
torch.save(model.state_dict(), path)
|
146 |
+
print(f"Model saved to {path}")
|
147 |
+
|
148 |
+
def display_summary(self, model, input_size):
|
149 |
+
print(f"\nModel Summary:")
|
150 |
+
summary(model, input_size)
|
151 |
+
|
152 |
+
def build_and_save(self):
|
153 |
+
generator, critic = self.build()
|
154 |
+
|
155 |
+
# Display summaries
|
156 |
+
print("\nGenerator Summary:")
|
157 |
+
self.display_summary(generator, (self.config.INPUT_CHANNELS, 224, 224)) # Assuming input size is 224x224
|
158 |
+
|
159 |
+
print("\nCritic Summary:")
|
160 |
+
self.display_summary(critic, [(2, 224, 224), (1, 224, 224)]) # Critic takes two inputs: ab and l
|
161 |
+
|
162 |
+
self.save_model(generator, "generator.pth")
|
163 |
+
self.save_model(critic, "critic.pth")
|
164 |
+
return generator, critic
|
165 |
+
|
166 |
+
|
167 |
+
|
168 |
+
|
src/imagecolorization/entity/config_entity.py
CHANGED
@@ -16,4 +16,20 @@ class DataTransformationConfig:
|
|
16 |
data_path_grey : Path
|
17 |
BATCH_SIZE : int
|
18 |
IMAGE_SIZE : list
|
19 |
-
DATA_RANGE: int
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
data_path_grey : Path
|
17 |
BATCH_SIZE : int
|
18 |
IMAGE_SIZE : list
|
19 |
+
DATA_RANGE: int
|
20 |
+
|
21 |
+
|
22 |
+
@dataclass(frozen=True)
|
23 |
+
class ModelBuildingConfig:
|
24 |
+
root_dir: Path
|
25 |
+
KERNEL_SIZE_RES: int
|
26 |
+
PADDING: int
|
27 |
+
STRIDE: int
|
28 |
+
BIAS: bool
|
29 |
+
SCALE_FACTOR: int
|
30 |
+
DIM: int
|
31 |
+
DROPOUT_RATE: float
|
32 |
+
KERNEL_SIZE_GENERATOR: int
|
33 |
+
INPUT_CHANNELS: int
|
34 |
+
OUTPUT_CHANNELS: int
|
35 |
+
IN_CHANNELS: int
|
src/imagecolorization/pipeline/stage_03_model_building.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.imagecolorization.conponents.model_building import ModelBuilding
|
2 |
+
from src.imagecolorization.config.configuration import ConfigurationManager
|
3 |
+
|
4 |
+
class ModelBuildingPipeline:
|
5 |
+
def __init__(slef):
|
6 |
+
pass
|
7 |
+
|
8 |
+
def main(self):
|
9 |
+
config_manager = ConfigurationManager()
|
10 |
+
model_config = config_manager.get_model_building_config()
|
11 |
+
|
12 |
+
model_building = ModelBuilding(config=model_config)
|
13 |
+
generator, critic = model_building.build_and_save()
|