Add files using upload-large-folder tool
- .gitattributes +2 -0
- .idea/.gitignore +3 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +4 -0
- .idea/modules.xml +8 -0
- .idea/thundernet_upload.iml +8 -0
- .idea/workspace.xml +35 -0
- README.md +59 -3
- data_gen.py +632 -0
- images_toolkit.py +285 -0
- inference_config.py +180 -0
- model/model.py +372 -0
- model/model_ppm_factors.py +438 -0
- profiler.py +98 -0
- requirements.txt +93 -0
- resnet/.gitignore +208 -0
- resnet/.idea/inspectionProfiles/profiles_settings.xml +6 -0
- resnet/.idea/misc.xml +4 -0
- resnet/.idea/modules.xml +8 -0
- resnet/.idea/resnet.iml +12 -0
- resnet/apt.txt +1 -0
- resnet/crowdai.json +7 -0
- resnet/fmodel.py +60 -0
- resnet/main.py +7 -0
- resnet/requirements.txt +3 -0
- resnet/resnet18/__init__.py +0 -0
- resnet/resnet18/checkpoints/model/checkpoint +1 -0
- resnet/resnet18/checkpoints/model/graph.pbtxt +0 -0
- resnet/resnet18/checkpoints/model/model.ckpt-5865.data-00000-of-00001 +3 -0
- resnet/resnet18/checkpoints/model/model.ckpt-5865.index +0 -0
- resnet/resnet18/checkpoints/model/model.ckpt-5865.meta +3 -0
- resnet/resnet18/resnet_model.py +570 -0
- resnet/run.sh +2 -0
- thundernet_config.py +18 -0
- train_config.py +311 -0
- train_optuna.py +255 -0
- utils.py +505 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+resnet/resnet18/checkpoints/model/model.ckpt-5865.meta filter=lfs diff=lfs merge=lfs -text
+resnet/resnet18/checkpoints/model/model.ckpt-5865.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
.idea/.gitignore
ADDED
@@ -0,0 +1,3 @@
# Default ignored files
/shelf/
/workspace.xml
.idea/inspectionProfiles/profiles_settings.xml
ADDED
@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
  <settings>
    <option name="USE_PROJECT_PROFILE" value="false" />
    <version value="1.0" />
  </settings>
</component>
.idea/misc.xml
ADDED
@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (thundernet_upload)" project-jdk-type="Python SDK" />
</project>
.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/thundernet_upload.iml" filepath="$PROJECT_DIR$/.idea/thundernet_upload.iml" />
    </modules>
  </component>
</project>
.idea/thundernet_upload.iml
ADDED
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
    <orderEntry type="inheritedJdk" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
</module>
.idea/workspace.xml
ADDED
@@ -0,0 +1,35 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ChangeListManager">
    <list default="true" id="2b1e7e9a-7e79-45bd-bade-eff89bafbc84" name="Changes" comment="" />
    <option name="SHOW_DIALOG" value="false" />
    <option name="HIGHLIGHT_CONFLICTS" value="true" />
    <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
    <option name="LAST_RESOLUTION" value="IGNORE" />
  </component>
  <component name="MarkdownSettingsMigration">
    <option name="stateVersion" value="1" />
  </component>
  <component name="ProjectId" id="3AZCQv5vVF8siEEdrRR9I1Xw2Ty" />
  <component name="ProjectViewState">
    <option name="hideEmptyMiddlePackages" value="true" />
    <option name="showLibraryContents" value="true" />
  </component>
  <component name="PropertiesComponent"><![CDATA[{
  "keyToString": {
    "RunOnceActivity.OpenProjectViewOnStart": "true",
    "RunOnceActivity.ShowReadmeOnStart": "true"
  }
}]]></component>
  <component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
  <component name="TaskManager">
    <task active="true" id="Default" summary="Default task">
      <changelist id="2b1e7e9a-7e79-45bd-bade-eff89bafbc84" name="Changes" comment="" />
      <created>1772790692792</created>
      <option name="number" value="Default" />
      <option name="presentableId" value="Default" />
      <updated>1772790692792</updated>
    </task>
    <servers />
  </component>
</project>
README.md
CHANGED
@@ -1,3 +1,59 @@
# Thundernet
Thundernet is a semantic segmentation model that processes RGB input with convolutional networks to extract key features.

### USE OF THE REPOSITORY - TRAINING

To execute the training file: python train_config.py

You can change the parameters by adding them as flags to the previous command or by editing them directly in the thundernet_config.py file.
The parameters you can include (that are relevant for training) are:
- **train_path (str)**: path to the training data (default: "C:/Users/user/Documents/pruned_training/training/")
- **val_path (str)**: path to the validation data (default: "C:/Users/user/Documents/pruned_training/val/")
- **model_dir (str)**: path where the trained model is saved (default: "C:/Users/user/Documents/Thundernet/pruebas_modelos/")
- **model_weights (str)**: NOT APPLICABLE FOR TRAINING. Path to the trained weights (default: "C:/Users/user/Documents/Thundernet/pruebas_modelos/32_ppm/BS4_lossBCE_weights_lr_0.00013713842558297858_reg-1.1743577101671763e-05-ep-13-val_loss0.11463435739278793-train_loss0.053004469722509384-val_iou0.8959722518920898-train_iou0.9606077075004578.hdf5")
- **batch_size (int)**: batch size (default: 4)
- **loss (str)**: type of loss to use (default: "BCE")
- **classes (int)**: number of classes (default: 2). The original purpose of the code was to segment egocentric bodies, so two classes were involved (body and background)
- **pretrained (bool)**: start from a pretrained model (default: False)
- **pretrained_weights (str)**: path to the pretrained model (default: None)
- **lr (float)**: learning rate (default: 1e-4)
- **epochs (int)**: number of epochs (default: 15)
- **resolution (str)**: resolution of the input images (default: 640x480)
- **kernel_regularizer (float)**: kernel regularizer (default: 2e-4)

Note: there are more parameters in the file, but they are of no use for training.


### USE OF THE REPOSITORY - EVALUATION

To execute the evaluation: python inference_config.py

You can change the parameters by adding them as flags to the previous command or by editing them directly in the thundernet_config.py file. The parameters (that are relevant for evaluation) are:
- **model_weights (str)**: path to the trained weights (default: "C:/Users/user/Documents/Thundernet/pruebas_modelos/32_ppm/BS4_lossBCE_weights_lr_0.00013713842558297858_reg-1.1743577101671763e-05-ep-13-val_loss0.11463435739278793-train_loss0.053004469722509384-val_iou0.8959722518920898-train_iou0.9606077075004578.hdf5")
- **batch_size (int)**: batch size (default: 4)
- **resolution (str)**: resolution of the input images (default: 640x480)

Note: there are more parameters in the file, but they are of no use for evaluation.

In the inference_config.py file you can call main with "show=True" to display some predictions. However, you will have to close the image after every prediction.


### DATA PREPARATION
The data must be stored in a path with the following structure:

- data
  + images (folder with RGB images in .jpg format)
    -- example1.jpg
  + labels (folder with label images in .png format)
    -- example1.png

An image and its corresponding label must share the same file name.


### OPTIMIZE HYPERPARAMETERS
To find the best hyperparameters for training, run the train_optuna.py file. It searches for the optimal values of batch_size, learning_rate and kernel_regularizer. The values of the remaining hyperparameters are taken from the "thundernet_config.py" file.


#### Compare models
To check whether two models are the same (i.e. their weights are equal), run the "compare_models.py" file. The paths to the models' weights should be set in the "weights_path1" and "weights_path2" variables.
data_gen.py
ADDED
@@ -0,0 +1,632 @@
from pathlib import Path

import numpy as np
import cv2
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.compat.v1 as tf1

from torch.utils.data import Dataset
import random

import torch.nn.functional as F


class ImageHelper:

    COLOR_TRANSFORMATIONS = [
        "saturation",
        "contrast",
        "brightness",
    ]

    def __init__(self, img_path, label_path, output_size, **kwargs):
        self.img_path = img_path
        self.label_path = label_path
        self.output_size = output_size
        self.kwargs = kwargs

        # Stereo
        self.to_stereo = False
        if "to_stereo" in kwargs.keys() and kwargs["to_stereo"]:
            self.to_stereo = True

        # Flip
        self.flip = False
        if "flip" in kwargs.keys() and kwargs["flip"]:
            self.flip = True

        # Color transformations
        self.color_transformations = []
        for k, v in self.kwargs.items():
            if k in self.COLOR_TRANSFORMATIONS and v:
                self.color_transformations.append(k)

    def get(self):

        # Load
        img = cv2.imread(str(self.img_path))
        label = cv2.imread(str(self.label_path))

        # Size checking
        assert img.shape == label.shape

        # Flip
        if self.flip:
            img, label = self.apply_transformation("flip", img, label)

        # Color transformations
        for color_tr in self.color_transformations:
            img, label = self.apply_transformation(color_tr, img, label)

        # Make sure both are numpy arrays
        if type(img) != np.ndarray:
            img = np.array(img)
        if type(label) != np.ndarray:
            label = np.array(label)

        # To stereo
        if self.to_stereo:
            img = np.concatenate((img, img), axis=1)
            label = np.concatenate((label, label), axis=1)

        # Size
        img = cv2.resize(img, self.output_size[::-1])
        label = cv2.resize(
            label, self.output_size[::-1], interpolation=cv2.INTER_NEAREST
        )

        label = label[:, :, 0]
        return img, label

    @classmethod
    def apply_transformation(cls, transformation, img, label):
        if transformation == "flip":
            return cls.tensor_to_numpy(
                tf.image.flip_left_right(img)
            ), cls.tensor_to_numpy(tf.image.flip_left_right(label))
        elif transformation == "saturation":
            return cls.tensor_to_numpy(tf.image.random_saturation(img, 0.5, 1.5)), label
        elif transformation == "contrast":
            return cls.tensor_to_numpy(tf.image.random_contrast(img, 0.5, 1.5)), label
        elif transformation == "brightness":
            return cls.tensor_to_numpy(tf.image.random_brightness(img, 0.3)), label
        elif transformation == "rotate":
            raise ValueError("This transformation is not supported yet")
        elif transformation == "directed_crop":
            raise ValueError("This transformation is not supported")

    @staticmethod
    def tensor_to_numpy(tensor):
        if tf.executing_eagerly():
            a = tensor.numpy()
        else:
            raise NotImplementedError(
                "Please adapt the Data Generator to work when not executing eagerly"
            )
        return a


class DataGenerator(keras.utils.Sequence):

    def __init__(
        self,
        images_path,
        labels_path,
        n_classes,
        batch_size=32,
        output_size=(480, 640),
        to_stereo=False,
        flip=False,
        saturation=False,
        contrast=False,
        brightness=False,
        class_mappings=None,
    ):

        self.images_path = Path(images_path)
        self.labels_path = Path(labels_path)
        self.n_classes = n_classes
        self.batch_size = batch_size
        self.output_size = output_size
        self.to_stereo = to_stereo
        self.class_mappings = class_mappings

        # Check image and labels dir
        img_paths = sorted(list(self.images_path.iterdir()))

        def has_label(img_filename):
            return (self.labels_path / f"{img_filename.stem}.png").exists()

        if not all(map(has_label, img_paths)):
            raise FileNotFoundError("Check every image has a label")

        # Obtain transformations
        transformations = []
        if flip:
            transformations.append("flip")
        if saturation:
            transformations.append("saturation")
        if contrast:
            transformations.append("contrast")
        if brightness:
            transformations.append("brightness")

        # Prepare augmentation: one plain element per image, plus one extra
        # element per requested transformation
        elements = []
        for image_path in img_paths:
            label_path = self.labels_path / f"{image_path.stem}.png"
            elements.append(
                ImageHelper(
                    image_path,
                    label_path,
                    self.output_size,
                    to_stereo=self.to_stereo,
                )
            )
            for tr in transformations:
                elements.append(
                    ImageHelper(
                        image_path,
                        label_path,
                        self.output_size,
                        to_stereo=self.to_stereo,
                        **{tr: True},
                    )
                )

        self.elements = elements

        # Shuffle
        np.random.shuffle(self.elements)

    def __getitem__(self, idx):
        batch_elements = self.elements[
            idx * self.batch_size : (idx + 1) * self.batch_size
        ]
        batch_elements_tuple = list(map(lambda x: x.get(), batch_elements))
        X, y = zip(*batch_elements_tuple)
        X, y = np.array(X), np.array(y)
        y_onehot = np.zeros(y.shape + (self.n_classes,))
        for i in np.unique(y):
            i = int(i)
            idx_for_this_class = np.where(y == i)
            if self.class_mappings:
                y_onehot[
                    idx_for_this_class
                    + (
                        np.ones(len(idx_for_this_class[0]), dtype=int)
                        * self.class_mappings[i],
                    )
                ] = 1
            else:
                y_onehot[
                    idx_for_this_class
                    + (np.ones(len(idx_for_this_class[0]), dtype=int) * i,)
                ] = 1
        final_X, final_y = X.astype(np.float64) / 255, y_onehot.astype(bool)
        # assert final_X.shape[:-1] == final_y.shape[:-1]
        return final_X, final_y

    def get_item_name(self, idx):
        return self.elements[idx].img_path.stem

    def __len__(self):

        try:
            return np.int(len(self.elements) / self.batch_size)
        except AttributeError:
            # np.int was removed in newer NumPy; fall back to the builtin
            return int(len(self.elements) / self.batch_size)

    def on_epoch_end(self):
        np.random.shuffle(self.elements)

    @classmethod
    def create_generators(
        cls,
        dataset_dir,
        n_classes,
        training_batch_size=32,
        validation_batch_size=8,
        output_size=(480, 640),
        to_stereo=False,
        transformations=tuple(),
        class_mappings=None,
    ):
        """
        Utility method to create both generators
        Args:
            dataset_dir: path of the dataset, must have training and val dirs
            training_batch_size: batch size of the training generator
            output_size: shape of the generated images
            transformations: for data augmentations
            to_stereo: whether the image and label must be converted to stereo
            class_mappings: dict containing a mapping for each class

        Returns: a tuple with the training and val generators

        """
        dataset_dir = Path(dataset_dir)
        training_generator = cls(
            dataset_dir / "training" / "images",
            dataset_dir / "training" / "labels",
            n_classes,
            batch_size=training_batch_size,
            output_size=output_size,
            to_stereo=to_stereo,
            **{tr: True for tr in transformations},
            class_mappings=class_mappings,
        )
        validation_generator = cls(
            dataset_dir / "val" / "images",
            dataset_dir / "val" / "labels",
            n_classes,
            batch_size=validation_batch_size,
            output_size=output_size,
            to_stereo=to_stereo,
            **{tr: True for tr in transformations},
            class_mappings=class_mappings,
        )

        return training_generator, validation_generator


y_k_size = 6
x_k_size = 6


class BaseDataset(Dataset):
    def __init__(
        self,
        ignore_label=255,
        base_size=2048,
        crop_size=(512, 1024),
        scale_factor=16,
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ):

        self.base_size = base_size
        self.crop_size = crop_size
        self.ignore_label = ignore_label

        self.mean = mean
        self.std = std
        self.scale_factor = scale_factor

        self.files = []

    def __len__(self):
        return len(self.files)

    def input_transform(self, image, city=True):
        if city:
            image = image.astype(np.float32)[:, :, ::-1]
        else:
            image = image.astype(np.float32)
        image = image / 255.0
        image -= self.mean
        image /= self.std
        return image

    def label_transform(self, label):
        return np.array(label).astype(np.uint8)

    def pad_image(self, image, h, w, size, padvalue):
        pad_image = image.copy()
        pad_h = max(size[0] - h, 0)
        pad_w = max(size[1] - w, 0)
        if pad_h > 0 or pad_w > 0:
            pad_image = cv2.copyMakeBorder(
                image, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=padvalue
            )

        return pad_image

    def rand_crop(self, image, label, edge):
        h, w = image.shape[:-1]
        image = self.pad_image(image, h, w, self.crop_size, (0.0, 0.0, 0.0))
        label = self.pad_image(label, h, w, self.crop_size, (self.ignore_label,))
        edge = self.pad_image(edge, h, w, self.crop_size, (0.0,))

        new_h, new_w = label.shape
        x = random.randint(0, new_w - self.crop_size[1])
        y = random.randint(0, new_h - self.crop_size[0])
        image = image[y : y + self.crop_size[0], x : x + self.crop_size[1]]
        label = label[y : y + self.crop_size[0], x : x + self.crop_size[1]]
        edge = edge[y : y + self.crop_size[0], x : x + self.crop_size[1]]

        return image, label, edge

    def multi_scale_aug(
        self, image, label=None, edge=None, rand_scale=1, rand_crop=True
    ):
        long_size = int(self.base_size * rand_scale + 0.5)
        h, w = image.shape[:2]
        if h > w:
            new_h = long_size
            new_w = int(w * long_size / h + 0.5)
        else:
            new_w = long_size
            new_h = int(h * long_size / w + 0.5)

        image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        if label is not None:
            label = cv2.resize(label, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
            if edge is not None:
                edge = cv2.resize(edge, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
        else:
            return image

        if rand_crop:
            image, label, edge = self.rand_crop(image, label, edge)

        return image, label, edge

    def gen_sample(
        self,
        image,
        label,
        multi_scale=True,
        is_flip=True,
        edge_pad=True,
        edge_size=4,
        city=False,
    ):

        edge = cv2.Canny(label, 0.1, 0.2)
        kernel = np.ones((edge_size, edge_size), np.uint8)
        if edge_pad:
            edge = edge[y_k_size:-y_k_size, x_k_size:-x_k_size]
            edge = np.pad(
                edge, ((y_k_size, y_k_size), (x_k_size, x_k_size)), mode="constant"
            )
        edge = (cv2.dilate(edge, kernel, iterations=1) > 50) * 1.0

        if multi_scale:
            rand_scale = 0.5 + random.randint(0, self.scale_factor) / 10.0
            image, label, edge = self.multi_scale_aug(
                image, label, edge, rand_scale=rand_scale
            )

        image = self.input_transform(image, city=city)
        label = self.label_transform(label)

        image = image.transpose((2, 0, 1))

        if is_flip:
            flip = np.random.choice(2) * 2 - 1
            image = image[:, :, ::flip]
            label = label[:, ::flip]
            edge = edge[:, ::flip]

        return image, label, edge

    def inference(self, config, model, image):
        size = image.size()
        pred = model(image)

        if config.MODEL.NUM_OUTPUTS > 1:
            pred = pred[config.TEST.OUTPUT_INDEX]

        pred = F.interpolate(
            input=pred,
            size=size[-2:],
            mode="bilinear",
            align_corners=config.MODEL.ALIGN_CORNERS,
        )

        return pred.exp()


class PIDNetDataset(BaseDataset):

    def __init__(
        self,
        images_path,
        labels_path,
        n_classes,
        output_size=(480, 640),
        to_stereo=False,
        flip=False,
        saturation=False,
        contrast=False,
        brightness=False,
        class_mappings=None,
        multi_scale=True,
        ignore_label=255,
        base_size=2048,
        crop_size=(512, 1024),
        scale_factor=16,
        # mean=[0.485, 0.456, 0.406],
        # std=[0.229, 0.224, 0.225],
        mean=[0.342, 0.374, 0.416],
        std=[0.241, 0.239, 0.253],
        bd_dilate_size=4,
    ):

        super(PIDNetDataset, self).__init__(
            ignore_label, base_size, crop_size, scale_factor, mean, std
        )

        self.images_path = Path(images_path)
        self.labels_path = Path(labels_path)
        self.n_classes = n_classes
        self.output_size = output_size
        self.to_stereo = to_stereo
        self.class_mappings = class_mappings

        self.bd_dilate_size = bd_dilate_size
        self.multi_scale = multi_scale
        self.flip = flip

        # Check image and labels dir
        img_paths = sorted(list(self.images_path.iterdir()))

        def has_label(img_filename):
            return (self.labels_path / f"{img_filename.stem}.png").exists()

        if not all(map(has_label, img_paths)):
            raise FileNotFoundError("Check every image has a label")

        # Obtain transformations
        transformations = []
        # if flip:
        #     transformations.append('flip')
        if saturation:
            transformations.append("saturation")
        if contrast:
            transformations.append("contrast")
        if brightness:
            transformations.append("brightness")

        # Prepare augmentation
        elements = []
        for image_path in img_paths:
            label_path = self.labels_path / f"{image_path.stem}.png"
            elements.append(
                ImageHelper(
                    image_path,
                    label_path,
                    self.output_size,
                    to_stereo=self.to_stereo,
                )
            )
            for tr in transformations:
                elements.append(
                    ImageHelper(
                        image_path,
                        label_path,
                        self.output_size,
                        to_stereo=self.to_stereo,
                        **{tr: True},
                    )
                )
        self.elements = elements

    def __len__(self):
        return len(self.elements)

    def __getitem__(self, idx):

        element = self.elements[idx]
        name = element.img_path.stem

        X, y = element.get()

        # Class mappings
        if self.class_mappings:
            y = np.vectorize(lambda x: self.class_mappings[x])(y).astype(np.uint8)

        y_onehot = np.zeros(y.shape + (self.n_classes,))
        for i in np.unique(y):
            i = int(i)
            idx_for_this_class = np.where(y == i)
            if self.class_mappings:
                y_onehot[
                    idx_for_this_class
                    + (
                        np.ones(len(idx_for_this_class[0]), dtype=int)
                        * self.class_mappings[i],
                    )
                ] = 1
            else:
                y_onehot[
                    idx_for_this_class
                    + (np.ones(len(idx_for_this_class[0]), dtype=int) * i,)
                ] = 1

        # assert final_X.shape[:-1] == final_y.shape[:-1]
        image, label = X, y

        image, label, edge = self.gen_sample(
            image, label, self.multi_scale, self.flip, edge_size=self.bd_dilate_size
        )

        return image.copy(), label.copy(), edge.copy(), np.array(image.shape), name

    @classmethod
    def create_train_and_test_datasets(
        cls,
        dataset_dir,
        n_classes,
        output_size=(480, 640),
        to_stereo=False,
        transformations=tuple(),
        class_mappings=None,
    ):
        dataset_dir = Path(dataset_dir)
        training_generator = cls(
            dataset_dir / "training" / "images",
            dataset_dir / "training" / "labels",
            n_classes,
            output_size=output_size,
            to_stereo=to_stereo,
            **{tr: True for tr in transformations},
            class_mappings=class_mappings,
        )
        validation_generator = cls(
            dataset_dir / "val" / "images",
            dataset_dir / "val" / "labels",
            n_classes,
            output_size=output_size,
            to_stereo=to_stereo,
            # **{tr: True for tr in transformations}
            class_mappings=class_mappings,
        )
        return training_generator, validation_generator


class MergedDataset(Dataset):

    def __init__(self, *datasets):
        self.datasets = datasets
        for d in self.datasets:
            assert isinstance(d, Dataset)
        self.lens = [len(d) for d in self.datasets]
        self.acc_lens = [sum(self.lens[: i + 1]) for i in range(len(self.lens))]

    def __len__(self):
        return sum(self.lens)

    def __getitem__(self, idx):
        for i in range(len(self.acc_lens)):
            if idx < self.acc_lens[i]:
                diff = self.acc_lens[i - 1] if i != 0 else 0
                s = self.datasets[i][idx - diff]
                # assert s[1].max() <= 3
                return s
        raise ValueError(
            f"idx out of range, was {idx}, should be less than {self.__len__()}"
        )


if __name__ == "__main__":
    """
    dataset_dir = Path('/home/user/nas/Datasets/egocentric_segmentation/joint-ep-of-thu-ego-for-5-office-objects/')
    helper = ImageHelper(
        dataset_dir / 'training' / 'images' / 'L515_020_003_rgb_0246.jpg',
        dataset_dir / 'training' / 'labels' / 'L515_020_003_rgb_0246.png',
        (480, 640),
        to_stereo=True
    )
    image, label = helper.get()
    """
    gen = DataGenerator(
        Path(
            "C:/Users/xruser/RealTimeSemanticSegmentation/joint-ep-of-thu-ego-stereo-1280x480/joint-ep-of-thu-ego-stereo-1280x480/"
        )
        / "pruned_training"
        / "images",
        Path(
            "C:/Users/xruser/RealTimeSemanticSegmentation/joint-ep-of-thu-ego-stereo-1280x480/joint-ep-of-thu-ego-stereo-1280x480/"
        )
        / "pruned_training"
        / "labels",
        7,
        batch_size=4,
        to_stereo=True,
    )
    images, labels = gen[0]
    print("hola")
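As a usage sketch for the generator above: the dataset root below is a placeholder, and per create_generators it must contain training/ and val/ folders, each with images/ and labels/ subfolders.

```python
# Minimal sketch: build the two generators and pull one batch.
# "path/to/dataset" is a placeholder, not a path from this repository.
from data_gen import DataGenerator

train_gen, val_gen = DataGenerator.create_generators(
    "path/to/dataset",
    n_classes=2,                 # body vs. background, as in the README
    training_batch_size=4,
    validation_batch_size=1,
    output_size=(480, 640),
    transformations=("flip", "brightness"),
)
X, y = train_gen[0]  # X: floats in [0, 1]; y: boolean one-hot masks
```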
images_toolkit.py
ADDED
@@ -0,0 +1,285 @@
import argparse
import tensorflow as tf
import sys
import os
import numpy as np
import cv2
from random import random
import scipy.misc
from glob import glob
import matplotlib.pyplot as plt
import math


def make_image_square(image):

    if image.shape[0] > image.shape[1]:
        shift = int((image.shape[0] - image.shape[1]) / 2)
        print(shift)
        small_dim = image.shape[1]
        image = image[shift : shift + small_dim, :, :]

    else:
        shift = int((image.shape[1] - image.shape[0]) / 2)
        print(shift)
        small_dim = image.shape[0]
        image = image[:, shift : shift + small_dim]

    return image


def mix_with_background(path_background, frame, fg):

    images = [img for img in os.listdir(path_background) if img.endswith(".jpg")]

    # Make sure the foreground mask is binary
    # fg[np.where(fg > 150)] = 255
    # fg[np.where(fg < 150)] = 0

    # fg[np.where(fg == 1)] = 255
    # fg[np.where(fg != 1)] = 0

    upper_limit = 0
    lower_limit = len(images) - 1
    num = np.uint8(random() * (upper_limit - lower_limit) + lower_limit)
    # print(images[num])
    bkg = cv2.imread(os.path.join(path_background, images[num]))
    # bkg = cv2.cvtColor(bkg, cv2.COLOR_BGR2RGB)
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    [index_r, index_c] = np.where((fg[:, :] >= 200))
    # [index_r, index_c] = np.where((fg[:,:] == 1))

    for i in range(1, len(index_r)):

        bkg[index_r[i], index_c[i], 0] = frame[index_r[i], index_c[i], 0]
        bkg[index_r[i], index_c[i], 1] = frame[index_r[i], index_c[i], 1]
        bkg[index_r[i], index_c[i], 2] = frame[index_r[i], index_c[i], 2]

    return bkg


def equalize(image):

    output = np.zeros((image.shape[0], image.shape[1]))

    img_yuv = cv2.cvtColor(image, cv2.COLOR_BGR2YUV)
    # This configuration achieves a very slight equalization
    clahe = cv2.createCLAHE(clipLimit=1.0, tileGridSize=(1, 1))
    img_yuv[:, :, 0] = clahe.apply(img_yuv[:, :, 0])
    output = cv2.cvtColor(img_yuv, cv2.COLOR_YUV2BGR)

    return output


def show_image(image, text=None):
    plt.imshow(image)
    if text is not None:
        plt.title(text)
    plt.show()


def show_two_images(image, image2, text=None, cmap=None, horizontal=True):

    plt.figure(1)

    if horizontal:
        plt.subplot(121)
    else:
        plt.subplot(211)

    if cmap:
        plt.imshow(image, cmap=cmap)
    else:
        plt.imshow(cv2.cvtColor(image.astype("float32"), cv2.COLOR_BGR2RGB))

    if text is not None:
        plt.title(text)

    if horizontal:
        plt.subplot(122)
    else:
        plt.subplot(212)

    if cmap:
        plt.imshow(image2, cmap=cmap)
    else:
        plt.imshow(cv2.cvtColor(image2.astype("float32"), cv2.COLOR_BGR2RGB))
    plt.show()


def show_three_images(image, image2, image3, text=None):

    plt.figure(1)
    plt.subplot(311)
    plt.imshow(image)

    if text is not None:
        plt.title(text)

    plt.subplot(312)

    # plt.imshow(image2,cmap='gray', vmin=0, vmax=1)
    plt.imshow(image2)
    plt.subplot(313)

    # plt.imshow(image3,cmap='gray', vmin=0, vmax=1)
    plt.imshow(image3)
    plt.show()


def show_four_images(image, image2, image3, image4, text=None):

    plt.figure(1)
    plt.subplot(221)
    plt.imshow(image)

    if text is not None:
        plt.title(text)

    plt.subplot(222)
    plt.imshow(image2)
    # plt.imshow(image2, cmap='gray', vmin=0, vmax=1)

    plt.subplot(223)
    plt.imshow(image3)
    # plt.imshow(image3, cmap='gray', vmin=0, vmax=1)

    plt.subplot(224)
    plt.imshow(image4)
    # plt.imshow(image4, cmap='gray', vmin=0, vmax=1)
    plt.show()


def show_six_images(image, image2, image3, image4, image5, image6, text=None):

    plt.figure(1)
    plt.subplot(231)
    plt.imshow(image)

    if text is not None:
        plt.title(text)

    plt.subplot(232)
    plt.imshow(image2, cmap="gray", vmin=0, vmax=1)

    plt.subplot(233)
    plt.imshow(image3, cmap="gray", vmin=0, vmax=1)

    plt.subplot(234)
    plt.imshow(image4, cmap="gray", vmin=0, vmax=1)

    plt.subplot(235)
    plt.imshow(image5, cmap="gray", vmin=0, vmax=1)

    plt.subplot(236)
    plt.imshow(image6, cmap="gray", vmin=0, vmax=1)

    plt.show()


def show_image_per_channel(image, text):

    plt.figure(1)
    plt.subplot(311)
    plt.imshow(image[:, :, 0])
    plt.subplot(312)
    plt.imshow(image[:, :, 1])
    plt.subplot(313)
    plt.imshow(image[:, :, 2])

    if text is not None:
        plt.title(text)
    plt.show()


def pixels_to_labels(image):

    print(image.shape)
    output = np.zeros((image.shape[0], image.shape[1]))

    output[np.where(image > 200)] = 255
    output[np.where(image < 200)] = 0

    return output


def jpg_to_png(path):

    image_paths = glob(os.path.join(path, "*.jpg"))

    for i in range(0, len(image_paths)):
        print(i)
        name = image_paths[i]
        idx = name.rfind("/")
        image = scipy.misc.imread(image_paths[i])
        scipy.misc.imsave(
            os.path.join(path, name[idx + 1 : -4] + ".png"), image.astype(np.uint8)
        )  # Really important to convert to uint8


def jpeg_to_jpg(path):

    image_paths = glob(os.path.join(path, "*.jpeg"))

    for i in range(0, len(image_paths)):
        print(i)
        name = image_paths[i]
        idx = name.rfind("/")
        image = scipy.misc.imread(image_paths[i])
        scipy.misc.imsave(
            os.path.join(path, name[idx + 1 : -5] + ".jpg"), image
        )  # Really important to convert to uint8


def overlap_image_with_label(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """
    This function overlaps the mask with the image.
    In other words, it only plots the image where the
    segmentation mask is >= 1
    Args:
    - image (numpy array): RGB image of shape (1, 480, 640, 3)
    - label (numpy array): segmentation mask of shape (480, 640)
    Returns:
    - overlap (numpy array): overlapped image
    """
    binary_mask = (mask > 0).astype(np.uint8)
    overlapped = image.copy().squeeze(0)
    overlapped[binary_mask == 0] = 0

    return overlapped


def show_x_images(images, titles=None, cmap=None, horizontal=False):
    num_images = len(images)

    if num_images == 1:
        plt.imshow(images[0])
        if titles:
            plt.title(titles[0])
        plt.show()
        return

    if horizontal:
        cols = num_images
        rows = 1
    else:
        cols = math.ceil(math.sqrt(num_images))
        rows = math.ceil(num_images / cols)

    plt.figure(figsize=(15, 5))
    for i, image in enumerate(images):
        plt.subplot(rows, cols, i + 1)
        if cmap:
            plt.imshow(image, cmap=cmap)
        else:
            plt.imshow(cv2.cvtColor(image.astype("float32"), cv2.COLOR_BGR2RGB))

        if titles and i < len(titles):
            plt.title(titles[i])

    plt.show()
    return
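A short sketch of how the last two helpers combine, using synthetic arrays with the shapes stated in the docstring above:

```python
# Synthetic example for overlap_image_with_label + show_x_images.
import numpy as np
from images_toolkit import overlap_image_with_label, show_x_images

image = np.random.rand(1, 480, 640, 3).astype("float32")  # batched RGB image
mask = (np.random.rand(480, 640) > 0.5).astype(np.uint8)  # binary segmentation mask

masked = overlap_image_with_label(image, mask)  # keeps pixels only where mask > 0
show_x_images([image[0], masked], titles=["Input", "Masked"], horizontal=True)
```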
inference_config.py
ADDED
@@ -0,0 +1,180 @@
import time
import sys
from data_gen import DataGenerator

from model.model import Thundernet as Thundernet_original
from model.model_ppm_factors import Thundernet as Thundernet_ppm

from collections import defaultdict

import thundernet_config as Thundernet_config
import numpy as np
import argparse
from glob import glob
from utils import resolution2framesize3cha, simple_iou_for_multiple_classes, image_test
import tqdm
from pathlib import Path
import matplotlib.pyplot as plt
from images_toolkit import show_two_images, overlap_image_with_label, show_x_images

# Example command: python inference_config.py --model_path C:/Users/user/Documents/Thundernet/pruebas_modelos/32_ppm/BS4_lossBCE_weights_lr_0.00013713842558297858_reg-1.1743577101671763e-05-ep-13-val_loss0.11463435739278793-train_loss0.053004469722509384-val_iou0.8959722518920898-train_iou0.9606077075004578.hdf5 --classes 2

baseline_duration = None

parser = argparse.ArgumentParser()

parser.add_argument(
    "--model_path",
    type=str,
    default=Thundernet_config.model_weights,
    help="Base directory for the hdf5 model, they are usually stored in /home/user/nas/deep_experiments/",
)

parser.add_argument(
    "--classes", type=int, default=Thundernet_config.classes, help="Number of classes."
)

parser.add_argument(
    "--resolution",
    type=str,
    default=Thundernet_config.resolution,
    help="Input Resolution",
)


def main(
    args: list,
    model: str = "original",
    class_mappings: dict = None,
    transformations: tuple = tuple(),
    show: bool = False,
) -> None:
    """
    Perform inference on a set of images. If show=True, each prediction
    will be shown on the screen.
    Args:
    - args (list): list of parsed arguments
    - model (str): type of model. Default: "original"
    - class_mappings (dict): class mapper. Default: None
    - transformations (tuple): list of transformations to execute on the data. Default: tuple()
    - show (bool): display the predictions. Default: False
    Returns:
    - None
    """

    FLAGS: argparse.Namespace = parser.parse_args(args)

    # Get the model
    if model == "original":
        Thundernet = Thundernet_original
    elif model == "ppm":
        Thundernet = Thundernet_ppm
    else:
        raise ValueError(f"Unknown model: {model}")

    # Set class mapping
    if class_mappings is not None:
        FLAGS.classes = len(set(class_mappings.values())) + 1

    # Get the shape and the classes
    input_shape = resolution2framesize3cha(FLAGS.resolution)
    classes = FLAGS.classes

    # Initialize the model with loaded weights
    try:
        thundernet = Thundernet(
            input_shape=input_shape, resnet_trainable=False, n_classes=classes
        )
        model = thundernet.model
    except ValueError:
        # If building this variant fails, try the other architecture
        if model == "ppm":
            Thundernet = Thundernet_original
        else:
            Thundernet = Thundernet_ppm

        thundernet = Thundernet(
            input_shape=input_shape, resnet_trainable=False, n_classes=classes
        )
        model = thundernet.model

    thundernet.model.load_weights(FLAGS.model_path)

    # Create dataloader for data
    dataset_dir: Path = Path(Thundernet_config.train_path).parent
    validation_generator: DataGenerator
    _, validation_generator = DataGenerator.create_generators(
        dataset_dir,
        FLAGS.classes,
        training_batch_size=1,
        validation_batch_size=1,
        to_stereo=False,
        transformations=transformations,
        class_mappings=class_mappings,
    )
    # Initialize lists to save data
    iou_aux: list = []
    iou_global: list = []
    durations: list = []

    # Iterate through the generator to get the iou metrics
    for i in tqdm.tqdm(range(len(validation_generator))):

        X, y = validation_generator[i]
        start_t = time.perf_counter()
        pred = model.predict(X)  # Shape: [1, 480, 640, 2]
        duration = time.perf_counter() - start_t
        durations.append(1000 * duration)

        pred = pred[0, :, :, :]  # Shape [480, 640, 2]

        prediction = np.argmax(pred, axis=2)  # Shape [480, 640]

        label = y[0].argmax(axis=-1) * 255

        if show:

            label_RGB = overlap_image_with_label(X, label)
            prediction_RGB = overlap_image_with_label(X, prediction)
            show_x_images(
                images=[label_RGB, prediction_RGB],
                titles=["Real", "Prediction"],
                horizontal=True,
            )

        iou_simple_iou = simple_iou_for_multiple_classes(
            y[0].argmax(axis=-1), prediction, classes
        )
        iou_global.append(iou_simple_iou)
        iou_aux = np.array(iou_global)

        name_image = validation_generator.get_item_name(i)

    for i in range(0, classes + 1):
        if classes <= 3 and i == classes:
            break
        values = iou_aux[:, i]
        values = values[~np.isnan(values)]
        print("IoU for class=", i, "is ", np.mean(values))

    durations = np.array(durations)

    print("")
    print("INFERENCE TIME")
    print(f" - Mean: {np.mean(durations)}")
    print(f" - Std: {np.std(durations)}")

    if baseline_duration:
        durations_baseline = np.load(Path(baseline_duration).open("rb"))
        diff_durations = durations - durations_baseline
        print("INFERENCE TIME WITH RESPECT TO BASELINE (ABSOLUTE)")
        print(f" - Mean: {np.mean(diff_durations)}")
        print(f" - Std: {np.std(diff_durations)}")
        increase_durations = (durations - durations_baseline) / durations_baseline
        print("INFERENCE TIME WITH RESPECT TO BASELINE (RELATIVE)")
        print(f" - Mean: {np.mean(increase_durations)}")
        print(f" - Std: {np.std(increase_durations)}")


if __name__ == "__main__":
    main(sys.argv[1:], model="ppm", class_mappings=defaultdict(int, {1: 1}))
    # main(sys.argv[1:], model="original", class_mappings=defaultdict(int, {1: 1}), show=False)
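Mirroring the entry point above, the evaluation loop can also be launched programmatically with on-screen predictions. This is a sketch based on the calls in __main__; only show=True is added.

```python
# Sketch: evaluate the "ppm" variant and display each prediction.
# The defaultdict maps every label to background (0) except class 1.
import sys
from collections import defaultdict
from inference_config import main

main(sys.argv[1:], model="ppm", class_mappings=defaultdict(int, {1: 1}), show=True)
```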
model/model.py
ADDED
|
@@ -0,0 +1,372 @@
| 1 |
+
from tensorflow.keras.layers import (
|
| 2 |
+
Input,
|
| 3 |
+
Lambda,
|
| 4 |
+
Concatenate,
|
| 5 |
+
Conv2D,
|
| 6 |
+
Conv2DTranspose,
|
| 7 |
+
MaxPooling2D,
|
| 8 |
+
BatchNormalization,
|
| 9 |
+
Activation,
|
| 10 |
+
Add,
|
| 11 |
+
AveragePooling2D,
|
| 12 |
+
UpSampling2D,
|
| 13 |
+
SeparableConv2D,
|
| 14 |
+
SpatialDropout2D,
|
| 15 |
+
)
|
| 16 |
+
from tensorflow.keras.models import Model
|
| 17 |
+
from tensorflow.keras.layers import ConvLSTM2D
|
| 18 |
+
from tensorflow.keras import callbacks
|
| 19 |
+
import tensorflow.keras.optimizers
|
| 20 |
+
from tensorflow.keras.regularizers import l2
|
| 21 |
+
from tensorflow.python import pywrap_tensorflow
|
| 22 |
+
import tensorflow as tf
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class Thundernet:
|
| 26 |
+
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
input_shape=(512, 1024, 3),
|
| 30 |
+
resnet_trainable=False,
|
| 31 |
+
kernel_regularizer=0,
|
| 32 |
+
n_classes=38,
|
| 33 |
+
):
|
| 34 |
+
self.input_shape = input_shape
|
| 35 |
+
self.resnet_trainable = resnet_trainable
|
| 36 |
+
self.n_classes = n_classes
|
| 37 |
+
self.model = self.thundernet(input_shape, resnet_trainable, kernel_regularizer)
|
| 38 |
+
self.load_resnet_weights()
|
| 39 |
+
|
| 40 |
+
def resnet_layer(
|
| 41 |
+
self,
|
| 42 |
+
inp,
|
| 43 |
+
downsample_first=True,
|
| 44 |
+
filters=64,
|
| 45 |
+
first=False,
|
| 46 |
+
number=0,
|
| 47 |
+
resnet_trainable=False,
|
| 48 |
+
kernel_regularizer=0,
|
| 49 |
+
):
|
| 50 |
+
if downsample_first:
|
| 51 |
+
conv_1 = Conv2D(
|
| 52 |
+
filters,
|
| 53 |
+
kernel_size=3,
|
| 54 |
+
strides=2,
|
| 55 |
+
padding="same",
|
| 56 |
+
name="conv2d_" + str(2 + (number - 1) * 5),
|
| 57 |
+
use_bias=False,
|
| 58 |
+
trainable=resnet_trainable,
|
| 59 |
+
kernel_regularizer=l2(kernel_regularizer),
|
| 60 |
+
)(inp)
|
| 61 |
+
else:
|
| 62 |
+
conv_1 = Conv2D(
|
| 63 |
+
filters,
|
| 64 |
+
kernel_size=3,
|
| 65 |
+
strides=1,
|
| 66 |
+
padding="same",
|
| 67 |
+
name="conv2d_" + str(2 + (number - 1) * 5),
|
| 68 |
+
use_bias=False,
|
| 69 |
+
trainable=resnet_trainable,
|
| 70 |
+
kernel_regularizer=l2(kernel_regularizer),
|
| 71 |
+
)(inp)
|
| 72 |
+
bn_1 = BatchNormalization(
|
| 73 |
+
axis=3,
|
| 74 |
+
name="batch_normalization_" + str(1 + (number - 1) * 4),
|
| 75 |
+
trainable=resnet_trainable,
|
| 76 |
+
)(conv_1)
|
| 77 |
+
relu_1 = Activation("relu")(bn_1)
|
| 78 |
+
conv_2 = Conv2D(
|
| 79 |
+
filters,
|
| 80 |
+
kernel_size=3,
|
| 81 |
+
strides=1,
|
| 82 |
+
padding="same",
|
| 83 |
+
name="conv2d_" + str(3 + (number - 1) * 5),
|
| 84 |
+
use_bias=False,
|
| 85 |
+
trainable=resnet_trainable,
|
| 86 |
+
kernel_regularizer=l2(kernel_regularizer),
|
| 87 |
+
)(relu_1)
|
| 88 |
+
bn_2 = BatchNormalization(
|
| 89 |
+
axis=3,
|
| 90 |
+
name="batch_normalization_" + str(2 + (number - 1) * 4),
|
| 91 |
+
trainable=resnet_trainable,
|
| 92 |
+
)(conv_2)
|
| 93 |
+
if downsample_first:
|
| 94 |
+
shortcut_1 = Conv2D(
|
| 95 |
+
filters,
|
| 96 |
+
kernel_size=1,
|
| 97 |
+
strides=2,
|
| 98 |
+
padding="same",
|
| 99 |
+
name="conv2d_" + str(1 + (number - 1) * 5),
|
| 100 |
+
use_bias=False,
|
| 101 |
+
trainable=resnet_trainable,
|
| 102 |
+
kernel_regularizer=l2(kernel_regularizer),
|
| 103 |
+
)(inp)
|
| 104 |
+
# bn_short = BatchNormalization(axis = 3, name = 'batch_normalization_' + str(1+(number-1)*5))(shortcut_1)
|
| 105 |
+
joint = Add()([shortcut_1, bn_2])
|
| 106 |
+
elif first:
|
| 107 |
+
shortcut_1 = Conv2D(
|
| 108 |
+
filters,
|
| 109 |
+
kernel_size=1,
|
| 110 |
+
strides=1,
|
| 111 |
+
padding="same",
|
| 112 |
+
name="conv2d_" + str(1 + (number - 1) * 5),
|
| 113 |
+
use_bias=False,
|
| 114 |
+
trainable=resnet_trainable,
|
| 115 |
+
kernel_regularizer=l2(kernel_regularizer),
|
| 116 |
+
)(inp)
|
| 117 |
+
# bn_short = BatchNormalization(axis=3, name = 'batch_normalization_' + str(1+(number-1)*5))(shortcut_1)
|
| 118 |
+
joint = Add()([shortcut_1, bn_2])
|
| 119 |
+
else:
|
| 120 |
+
joint = Add()([inp, bn_2])
|
| 121 |
+
block_1 = Activation("relu")(joint)
|
| 122 |
+
conv_3 = Conv2D(
|
| 123 |
+
filters,
|
| 124 |
+
kernel_size=3,
|
| 125 |
+
strides=1,
|
| 126 |
+
padding="same",
|
| 127 |
+
name="conv2d_" + str(4 + (number - 1) * 5),
|
| 128 |
+
use_bias=False,
|
| 129 |
+
trainable=resnet_trainable,
|
| 130 |
+
kernel_regularizer=l2(kernel_regularizer),
|
| 131 |
+
)(block_1)
|
| 132 |
+
bn_3 = BatchNormalization(
|
| 133 |
+
axis=3,
|
| 134 |
+
name="batch_normalization_" + str(3 + (number - 1) * 4),
|
| 135 |
+
trainable=resnet_trainable,
|
| 136 |
+
)(conv_3)
|
| 137 |
+
relu_3 = Activation("relu")(bn_3)
|
| 138 |
+
conv_4 = Conv2D(
|
| 139 |
+
filters,
|
| 140 |
+
kernel_size=3,
|
| 141 |
+
strides=1,
|
| 142 |
+
padding="same",
|
| 143 |
+
name="conv2d_" + str(5 + (number - 1) * 5),
|
| 144 |
+
use_bias=False,
|
| 145 |
+
trainable=resnet_trainable,
|
| 146 |
+
kernel_regularizer=l2(kernel_regularizer),
|
| 147 |
+
)(relu_3)
|
| 148 |
+
bn_4 = BatchNormalization(
|
| 149 |
+
axis=3,
|
| 150 |
+
name="batch_normalization_" + str(4 + (number - 1) * 4),
|
| 151 |
+
trainable=resnet_trainable,
|
| 152 |
+
)(conv_4)
|
| 153 |
+
joint_2 = Add()([block_1, bn_4])
|
| 154 |
+
out = Activation("relu")(joint_2)
|
| 155 |
+
return out
|
| 156 |
+
|
| 157 |
+
def pyramid_pooling_block(self, input_tensor, number=0, kernel_regularizer=0):
|
| 158 |
+
concat_list = []
|
| 159 |
+
|
| 160 |
+
w = input_tensor.shape[1]
|
| 161 |
+
h = input_tensor.shape[2]
|
| 162 |
+
|
| 163 |
+
if w is None:
|
| 164 |
+
w = 45
|
| 165 |
+
if h is None:
|
| 166 |
+
h = 45
|
| 167 |
+
|
| 168 |
+
k = 0
|
| 169 |
+
import tensorflow as tf
|
| 170 |
+
|
| 171 |
+
for bin_size in [6, 12]:
|
| 172 |
+
x = AveragePooling2D(
|
| 173 |
+
pool_size=(w // bin_size, h // bin_size),
|
| 174 |
+
strides=(w // bin_size, h // bin_size),
|
| 175 |
+
)(input_tensor)
|
| 176 |
+
x = Conv2D(
|
| 177 |
+
512,
|
| 178 |
+
kernel_size=1,
|
| 179 |
+
strides=1,
|
| 180 |
+
padding="same",
|
| 181 |
+
name="conv2d_" + str(number + k),
|
| 182 |
+
kernel_regularizer=l2(kernel_regularizer),
|
| 183 |
+
)(x)
|
| 184 |
+
x = Lambda(lambda x: tf.image.resize(x, (w, h)))(x)
|
| 185 |
+
concat_list.append(x)
|
| 186 |
+
k += 1
|
| 187 |
+
|
| 188 |
+
for bin_size in [18, 24]:
|
| 189 |
+
x = AveragePooling2D(
|
| 190 |
+
pool_size=(w // bin_size, h // bin_size),
|
| 191 |
+
strides=(w // bin_size, h // bin_size),
|
| 192 |
+
)(input_tensor)
|
| 193 |
+
x = Conv2D(
|
| 194 |
+
256,
|
| 195 |
+
kernel_size=1,
|
| 196 |
+
strides=1,
|
| 197 |
+
padding="same",
|
| 198 |
+
name="conv2d_" + str(number + k),
|
| 199 |
+
kernel_regularizer=l2(kernel_regularizer),
|
| 200 |
+
)(x)
|
| 201 |
+
x = Lambda(lambda x: tf.image.resize(x, (w, h)))(x)
|
| 202 |
+
concat_list.append(x)
|
| 203 |
+
k += 1
|
| 204 |
+
|
| 205 |
+
ppm = Concatenate()(concat_list)
|
| 206 |
+
conv = Conv2D(
|
| 207 |
+
256,
|
| 208 |
+
kernel_size=1,
|
| 209 |
+
name="conv2d_" + str(number + k),
|
| 210 |
+
kernel_regularizer=l2(kernel_regularizer),
|
| 211 |
+
)(ppm)
|
| 212 |
+
out = Activation("relu")(conv)
|
| 213 |
+
|
| 214 |
+
return out
|
| 215 |
+
|
| 216 |
+
def decoder_block(self, inp, filters, number=0, kernel_regularizer=0):
|
| 217 |
+
# filters = inp.shape[3]
|
| 218 |
+
conv_1 = Conv2D(
|
| 219 |
+
filters,
|
| 220 |
+
kernel_size=1,
|
| 221 |
+
name="conv2d_" + str(number),
|
| 222 |
+
kernel_regularizer=l2(kernel_regularizer),
|
| 223 |
+
)(inp)
|
| 224 |
+
# conv_1 = SeparableConv2D(filters, kernel_size=1, name='conv2d_' + str(number), kernel_regularizer=l2(kernel_regularizer))(inp)
|
| 225 |
+
deconv = Conv2DTranspose(filters, kernel_size=3, strides=2, padding="same")(
|
| 226 |
+
conv_1
|
| 227 |
+
)
|
| 228 |
+
bn_1 = BatchNormalization(axis=3, name="batch_normalization_" + str(number))(
|
| 229 |
+
deconv
|
| 230 |
+
)
|
| 231 |
+
conv_2 = Conv2D(
|
| 232 |
+
filters // 2,
|
| 233 |
+
kernel_size=1,
|
| 234 |
+
name="conv2d_" + str(number + 1),
|
| 235 |
+
kernel_regularizer=l2(kernel_regularizer),
|
| 236 |
+
)(bn_1)
|
| 237 |
+
# conv_2 = SeparableConv2D(filters // 2, kernel_size=1, name='conv2d_' + str(number + 1), kernel_regularizer=l2(kernel_regularizer))(bn_1)
|
| 238 |
+
bn_2 = BatchNormalization(
|
| 239 |
+
axis=3, name="batch_normalization_" + str(number + 1)
|
| 240 |
+
)(conv_2)
|
| 241 |
+
|
| 242 |
+
inp_deconv = Conv2DTranspose(
|
| 243 |
+
filters // 2, kernel_size=3, strides=2, padding="same"
|
| 244 |
+
)(inp)
|
| 245 |
+
inp_bn = BatchNormalization(
|
| 246 |
+
axis=3, name="batch_normalization_" + str(number + 2)
|
| 247 |
+
)(inp_deconv)
|
| 248 |
+
|
| 249 |
+
joint = Add()([inp_bn, bn_2])
|
| 250 |
+
out = Activation("relu")(joint)
|
| 251 |
+
return out
|
| 252 |
+
|
| 253 |
+
def thundernet(
|
| 254 |
+
self, input_shape=(512, 1024, 3), resnet_trainable=False, kernel_regularizer=0
|
| 255 |
+
):
|
| 256 |
+
# This returns a tensor
|
| 257 |
+
inputs = Input(shape=(input_shape))
|
| 258 |
+
|
| 259 |
+
# a layer instance is callable on a tensor, and returns a tensor
|
| 260 |
+
conv_1 = Conv2D(
|
| 261 |
+
64,
|
| 262 |
+
kernel_size=3,
|
| 263 |
+
strides=2,
|
| 264 |
+
padding="same",
|
| 265 |
+
name="conv2d",
|
| 266 |
+
use_bias=False,
|
| 267 |
+
trainable=resnet_trainable,
|
| 268 |
+
kernel_regularizer=l2(kernel_regularizer),
|
| 269 |
+
)(inputs)
|
| 270 |
+
bn_1 = BatchNormalization(
|
| 271 |
+
axis=3, name="batch_normalization", trainable=resnet_trainable
|
| 272 |
+
)(conv_1)
|
| 273 |
+
relu_1 = Activation("relu")(bn_1)
|
| 274 |
+
maxp_1 = MaxPooling2D(pool_size=(3, 3), strides=2, padding="same")(relu_1)
|
| 275 |
+
|
| 276 |
+
res1 = self.resnet_layer(
|
| 277 |
+
maxp_1,
|
| 278 |
+
downsample_first=False,
|
| 279 |
+
filters=64,
|
| 280 |
+
first=True,
|
| 281 |
+
number=1,
|
| 282 |
+
resnet_trainable=resnet_trainable,
|
| 283 |
+
kernel_regularizer=kernel_regularizer,
|
| 284 |
+
)
|
| 285 |
+
# res1 = SpatialDropout2D(0.25)(res1)
|
| 286 |
+
res2 = self.resnet_layer(
|
| 287 |
+
res1,
|
| 288 |
+
downsample_first=True,
|
| 289 |
+
filters=128,
|
| 290 |
+
first=False,
|
| 291 |
+
number=2,
|
| 292 |
+
resnet_trainable=resnet_trainable,
|
| 293 |
+
kernel_regularizer=kernel_regularizer,
|
| 294 |
+
)
|
| 295 |
+
# res2 = SpatialDropout2D(0.25)(res2)
|
| 296 |
+
res3 = self.resnet_layer(
|
| 297 |
+
res2,
|
| 298 |
+
downsample_first=True,
|
| 299 |
+
filters=256,
|
| 300 |
+
first=False,
|
| 301 |
+
number=3,
|
| 302 |
+
resnet_trainable=resnet_trainable,
|
| 303 |
+
kernel_regularizer=kernel_regularizer,
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
ppm = self.pyramid_pooling_block(
|
| 307 |
+
res3, number=16, kernel_regularizer=kernel_regularizer
|
| 308 |
+
)
|
| 309 |
+
# ppm = Add()([ppm,res3])
|
| 310 |
+
ppm = Concatenate()([ppm, res3])
|
| 311 |
+
|
| 312 |
+
dec_1 = self.decoder_block(
|
| 313 |
+
ppm, 256, number=21, kernel_regularizer=kernel_regularizer
|
| 314 |
+
)
|
| 315 |
+
# dec_1 = Add()([dec_1, res2])
|
| 316 |
+
dec_1 = Concatenate()([dec_1, res2])
|
| 317 |
+
|
| 318 |
+
dec_2 = self.decoder_block(
|
| 319 |
+
dec_1, 128, number=24, kernel_regularizer=kernel_regularizer
|
| 320 |
+
)
|
| 321 |
+
# dec_2 = Add()([dec_2, res1])
|
| 322 |
+
dec_2 = Concatenate()([dec_2, res1])
|
| 323 |
+
|
| 324 |
+
# dec_3 = self.decoder_block(dec_2, 128, number=27)
|
| 325 |
+
|
| 326 |
+
ups = UpSampling2D(size=(4, 4), interpolation="bilinear")(dec_2)
|
| 327 |
+
# ups = UpSampling2D(size=(2, 2), interpolation='bilinear')(dec_3)
|
| 328 |
+
|
| 329 |
+
out = Conv2D(
|
| 330 |
+
filters=int(self.n_classes),
|
| 331 |
+
kernel_size=1,
|
| 332 |
+
activation="softmax",
|
| 333 |
+
name="conv2d_out",
|
| 334 |
+
)(ups)
|
| 335 |
+
|
| 336 |
+
model = Model(inputs=inputs, outputs=out)
|
| 337 |
+
return model
|
| 338 |
+
|
| 339 |
+
def load_resnet_weights(self):
|
| 340 |
+
print("Loading weights for resnet18 backbone")
|
| 341 |
+
checkpoint_path = "./resnet/resnet18/checkpoints/model/model.ckpt-5865"
|
| 342 |
+
reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
|
| 343 |
+
var_to_shape_map = reader.get_variable_to_shape_map()
|
| 344 |
+
|
| 345 |
+
# for key in var_to_shape_map:
|
| 346 |
+
# print("tensor_name: ", key)
|
| 347 |
+
# print(reader.get_tensor(key).shape)  # Remove this if you want to print only variable names
|
| 348 |
+
|
| 349 |
+
for k in range(0, 16):
|
| 350 |
+
layer_name = "conv2d"
|
| 351 |
+
if k != 0:
|
| 352 |
+
layer_name += "_" + str(k)
|
| 353 |
+
weights_key = layer_name + "/kernel"
|
| 354 |
+
weights = reader.get_tensor(weights_key)
|
| 355 |
+
keras_weights = self.model.get_layer(layer_name).get_weights()
|
| 356 |
+
self.model.get_layer(layer_name).set_weights([weights])
|
| 357 |
+
|
| 358 |
+
layer_name = "batch_normalization"
|
| 359 |
+
if k != 0:
|
| 360 |
+
layer_name += "_" + str(k)
|
| 361 |
+
if k < 13:
|
| 362 |
+
beta_key = layer_name + "/beta"
|
| 363 |
+
beta = reader.get_tensor(beta_key)
|
| 364 |
+
gamma_key = layer_name + "/gamma"
|
| 365 |
+
gamma = reader.get_tensor(gamma_key)
|
| 366 |
+
mean_key = layer_name + "/moving_mean"
|
| 367 |
+
mean = reader.get_tensor(mean_key)
|
| 368 |
+
var_key = layer_name + "/moving_variance"
|
| 369 |
+
var = reader.get_tensor(var_key)
|
| 370 |
+
keras_weights = self.model.get_layer(layer_name).get_weights()
|
| 371 |
+
self.model.get_layer(layer_name).set_weights([gamma, beta, mean, var])
|
| 372 |
+
print("Weights for resnet18 backbone loaded!")
|
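A minimal usage sketch for the class above: the constructor builds the Keras graph and immediately loads the ResNet-18 backbone weights from the bundled checkpoint, so the returned model is ready to compile. The optimizer and loss here are illustrative assumptions, not values fixed by this file:

from model.model import Thundernet

net = Thundernet(input_shape=(512, 1024, 3), resnet_trainable=False, n_classes=38)
net.model.compile(optimizer="adam", loss="categorical_crossentropy")
net.model.summary()  # softmax output of shape (batch, 512, 1024, n_classes)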
model/model_ppm_factors.py
ADDED
|
@@ -0,0 +1,438 @@
| 1 |
+
from tensorflow.keras.layers import (
|
| 2 |
+
Input,
|
| 3 |
+
Lambda,
|
| 4 |
+
Concatenate,
|
| 5 |
+
Conv2D,
|
| 6 |
+
Conv2DTranspose,
|
| 7 |
+
MaxPooling2D,
|
| 8 |
+
BatchNormalization,
|
| 9 |
+
Activation,
|
| 10 |
+
Add,
|
| 11 |
+
AveragePooling2D,
|
| 12 |
+
UpSampling2D,
|
| 13 |
+
SeparableConv2D,
|
| 14 |
+
SpatialDropout2D,
|
| 15 |
+
)
|
| 16 |
+
from tensorflow.keras.models import Model
|
| 17 |
+
from tensorflow.keras import callbacks
|
| 18 |
+
import tensorflow.keras.optimizers
|
| 19 |
+
from tensorflow.keras.regularizers import l2
|
| 20 |
+
import tensorflow as tf
|
| 21 |
+
from tensorflow.python import pywrap_tensorflow
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Thundernet:
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
input_shape=(512, 1024, 3),
|
| 29 |
+
resnet_trainable=False,
|
| 30 |
+
kernel_regularizer=0,
|
| 31 |
+
n_classes=2,
|
| 32 |
+
add_2x1up_layer=False,
|
| 33 |
+
add_2up_layer=False,
|
| 34 |
+
resize_first=False,
|
| 35 |
+
):
|
| 36 |
+
self.input_shape = input_shape
|
| 37 |
+
self.resnet_trainable = resnet_trainable
|
| 38 |
+
self.n_classes = n_classes
|
| 39 |
+
self.model = self.thundernet(
|
| 40 |
+
input_shape,
|
| 41 |
+
resnet_trainable,
|
| 42 |
+
kernel_regularizer,
|
| 43 |
+
add_2x1up_layer,
|
| 44 |
+
add_2up_layer,
|
| 45 |
+
resize_first,
|
| 46 |
+
)
|
| 47 |
+
self.load_resnet_weights()
|
| 48 |
+
self.add_2x1up_layer = add_2x1up_layer
|
| 49 |
+
self.add_2up_layer = add_2up_layer
|
| 50 |
+
self.resize_first = resize_first
|
| 51 |
+
|
| 52 |
+
def resnet_layer(
|
| 53 |
+
self,
|
| 54 |
+
inp,
|
| 55 |
+
downsample_first=True,
|
| 56 |
+
filters=64,
|
| 57 |
+
first=False,
|
| 58 |
+
number=0,
|
| 59 |
+
resnet_trainable=False,
|
| 60 |
+
kernel_regularizer=0,
|
| 61 |
+
):
|
| 62 |
+
if downsample_first:
|
| 63 |
+
conv_1 = Conv2D(
|
| 64 |
+
filters,
|
| 65 |
+
kernel_size=3,
|
| 66 |
+
strides=2,
|
| 67 |
+
padding="same",
|
| 68 |
+
name="conv2d_" + str(2 + (number - 1) * 5),
|
| 69 |
+
use_bias=False,
|
| 70 |
+
trainable=resnet_trainable,
|
| 71 |
+
kernel_regularizer=l2(kernel_regularizer),
|
| 72 |
+
)(inp)
|
| 73 |
+
else:
|
| 74 |
+
conv_1 = Conv2D(
|
| 75 |
+
filters,
|
| 76 |
+
kernel_size=3,
|
| 77 |
+
strides=1,
|
| 78 |
+
padding="same",
|
| 79 |
+
name="conv2d_" + str(2 + (number - 1) * 5),
|
| 80 |
+
use_bias=False,
|
| 81 |
+
trainable=resnet_trainable,
|
| 82 |
+
kernel_regularizer=l2(kernel_regularizer),
|
| 83 |
+
)(inp)
|
| 84 |
+
bn_1 = BatchNormalization(
|
| 85 |
+
axis=3,
|
| 86 |
+
name="batch_normalization_" + str(1 + (number - 1) * 4),
|
| 87 |
+
trainable=resnet_trainable,
|
| 88 |
+
)(conv_1)
|
| 89 |
+
relu_1 = Activation("relu")(bn_1)
|
| 90 |
+
conv_2 = Conv2D(
|
| 91 |
+
filters,
|
| 92 |
+
kernel_size=3,
|
| 93 |
+
strides=1,
|
| 94 |
+
padding="same",
|
| 95 |
+
name="conv2d_" + str(3 + (number - 1) * 5),
|
| 96 |
+
use_bias=False,
|
| 97 |
+
trainable=resnet_trainable,
|
| 98 |
+
kernel_regularizer=l2(kernel_regularizer),
|
| 99 |
+
)(relu_1)
|
| 100 |
+
bn_2 = BatchNormalization(
|
| 101 |
+
axis=3,
|
| 102 |
+
name="batch_normalization_" + str(2 + (number - 1) * 4),
|
| 103 |
+
trainable=resnet_trainable,
|
| 104 |
+
)(conv_2)
|
| 105 |
+
if downsample_first:
|
| 106 |
+
shortcut_1 = Conv2D(
|
| 107 |
+
filters,
|
| 108 |
+
kernel_size=1,
|
| 109 |
+
strides=2,
|
| 110 |
+
padding="same",
|
| 111 |
+
name="conv2d_" + str(1 + (number - 1) * 5),
|
| 112 |
+
use_bias=False,
|
| 113 |
+
trainable=resnet_trainable,
|
| 114 |
+
kernel_regularizer=l2(kernel_regularizer),
|
| 115 |
+
)(inp)
|
| 116 |
+
# bn_short = BatchNormalization(axis = 3, name = 'batch_normalization_' + str(1+(number-1)*5))(shortcut_1)
|
| 117 |
+
joint = Add()([shortcut_1, bn_2])
|
| 118 |
+
elif first:
|
| 119 |
+
shortcut_1 = Conv2D(
|
| 120 |
+
filters,
|
| 121 |
+
kernel_size=1,
|
| 122 |
+
strides=1,
|
| 123 |
+
padding="same",
|
| 124 |
+
name="conv2d_" + str(1 + (number - 1) * 5),
|
| 125 |
+
use_bias=False,
|
| 126 |
+
trainable=resnet_trainable,
|
| 127 |
+
kernel_regularizer=l2(kernel_regularizer),
|
| 128 |
+
)(inp)
|
| 129 |
+
# bn_short = BatchNormalization(axis=3, name = 'batch_normalization_' + str(1+(number-1)*5))(shortcut_1)
|
| 130 |
+
joint = Add()([shortcut_1, bn_2])
|
| 131 |
+
else:
|
| 132 |
+
joint = Add()([inp, bn_2])
|
| 133 |
+
block_1 = Activation("relu")(joint)
|
| 134 |
+
conv_3 = Conv2D(
|
| 135 |
+
filters,
|
| 136 |
+
kernel_size=3,
|
| 137 |
+
strides=1,
|
| 138 |
+
padding="same",
|
| 139 |
+
name="conv2d_" + str(4 + (number - 1) * 5),
|
| 140 |
+
use_bias=False,
|
| 141 |
+
trainable=resnet_trainable,
|
| 142 |
+
kernel_regularizer=l2(kernel_regularizer),
|
| 143 |
+
)(block_1)
|
| 144 |
+
bn_3 = BatchNormalization(
|
| 145 |
+
axis=3,
|
| 146 |
+
name="batch_normalization_" + str(3 + (number - 1) * 4),
|
| 147 |
+
trainable=resnet_trainable,
|
| 148 |
+
)(conv_3)
|
| 149 |
+
relu_3 = Activation("relu")(bn_3)
|
| 150 |
+
conv_4 = Conv2D(
|
| 151 |
+
filters,
|
| 152 |
+
kernel_size=3,
|
| 153 |
+
strides=1,
|
| 154 |
+
padding="same",
|
| 155 |
+
name="conv2d_" + str(5 + (number - 1) * 5),
|
| 156 |
+
use_bias=False,
|
| 157 |
+
trainable=resnet_trainable,
|
| 158 |
+
kernel_regularizer=l2(kernel_regularizer),
|
| 159 |
+
)(relu_3)
|
| 160 |
+
bn_4 = BatchNormalization(
|
| 161 |
+
axis=3,
|
| 162 |
+
name="batch_normalization_" + str(4 + (number - 1) * 4),
|
| 163 |
+
trainable=resnet_trainable,
|
| 164 |
+
)(conv_4)
|
| 165 |
+
joint_2 = Add()([block_1, bn_4])
|
| 166 |
+
out = Activation("relu")(joint_2)
|
| 167 |
+
return out
|
| 168 |
+
|
| 169 |
+
def pyramid_pooling_block(self, input_tensor, number=0, kernel_regularizer=0):
|
| 170 |
+
|
| 171 |
+
concat_list = []
|
| 172 |
+
|
| 173 |
+
# w = input_tensor.shape[1].value
|
| 174 |
+
# h = input_tensor.shape[2].value
|
| 175 |
+
|
| 176 |
+
w = input_tensor.shape[1]
|
| 177 |
+
h = input_tensor.shape[2]
|
| 178 |
+
|
| 179 |
+
if w is None:
|
| 180 |
+
w = 45
|
| 181 |
+
if h is None:
|
| 182 |
+
h = 45
|
| 183 |
+
|
| 184 |
+
k = 0
|
| 185 |
+
for bin_size in [1, 3, 6]:
|
| 186 |
+
x = AveragePooling2D(
|
| 187 |
+
pool_size=(w // bin_size, h // bin_size),
|
| 188 |
+
strides=(w // bin_size, h // bin_size),
|
| 189 |
+
)(input_tensor)
|
| 190 |
+
x = Conv2D(
|
| 191 |
+
512,
|
| 192 |
+
kernel_size=1,
|
| 193 |
+
strides=1,
|
| 194 |
+
padding="same",
|
| 195 |
+
name="conv2d_" + str(number + k),
|
| 196 |
+
kernel_regularizer=l2(kernel_regularizer),
|
| 197 |
+
)(x)
|
| 198 |
+
x = Lambda(lambda x: tf.image.resize(x, (w, h)))(x)
|
| 199 |
+
concat_list.append(x)
|
| 200 |
+
k += 1
|
| 201 |
+
|
| 202 |
+
for bin_size in [12, 18, 24]:
|
| 203 |
+
x = AveragePooling2D(
|
| 204 |
+
pool_size=(w // bin_size, h // bin_size),
|
| 205 |
+
strides=(w // bin_size, h // bin_size),
|
| 206 |
+
)(input_tensor)
|
| 207 |
+
x = Conv2D(
|
| 208 |
+
256,
|
| 209 |
+
kernel_size=1,
|
| 210 |
+
strides=1,
|
| 211 |
+
padding="same",
|
| 212 |
+
name="conv2d_" + str(number + k),
|
| 213 |
+
kernel_regularizer=l2(kernel_regularizer),
|
| 214 |
+
)(x)
|
| 215 |
+
x = Lambda(lambda x: tf.image.resize(x, (w, h)))(x)
|
| 216 |
+
concat_list.append(x)
|
| 217 |
+
k += 1
|
| 218 |
+
|
| 219 |
+
ppm = Concatenate()(concat_list)
|
| 220 |
+
conv = Conv2D(
|
| 221 |
+
256,
|
| 222 |
+
kernel_size=1,
|
| 223 |
+
name="conv2d_" + str(number + k),
|
| 224 |
+
kernel_regularizer=l2(kernel_regularizer),
|
| 225 |
+
)(ppm)
|
| 226 |
+
out = Activation("relu")(conv)
|
| 227 |
+
|
| 228 |
+
return out
|
| 229 |
+
|
| 230 |
+
def decoder_block(self, inp, filters, number=0, kernel_regularizer=0):
|
| 231 |
+
# filters = inp.shape[3]
|
| 232 |
+
conv_1 = Conv2D(
|
| 233 |
+
filters,
|
| 234 |
+
kernel_size=1,
|
| 235 |
+
name="conv2d_" + str(number),
|
| 236 |
+
kernel_regularizer=l2(kernel_regularizer),
|
| 237 |
+
)(inp)
|
| 238 |
+
# conv_1 = SeparableConv2D(filters, kernel_size=1, name='conv2d_' + str(number), kernel_regularizer=l2(kernel_regularizer))(inp)
|
| 239 |
+
deconv = Conv2DTranspose(filters, kernel_size=3, strides=2, padding="same")(
|
| 240 |
+
conv_1
|
| 241 |
+
)
|
| 242 |
+
bn_1 = BatchNormalization(axis=3, name="batch_normalization_" + str(number))(
|
| 243 |
+
deconv
|
| 244 |
+
)
|
| 245 |
+
conv_2 = Conv2D(
|
| 246 |
+
filters // 2,
|
| 247 |
+
kernel_size=1,
|
| 248 |
+
name="conv2d_" + str(number + 1),
|
| 249 |
+
kernel_regularizer=l2(kernel_regularizer),
|
| 250 |
+
)(bn_1)
|
| 251 |
+
# conv_2 = SeparableConv2D(filters // 2, kernel_size=1, name='conv2d_' + str(number + 1), kernel_regularizer=l2(kernel_regularizer))(bn_1)
|
| 252 |
+
bn_2 = BatchNormalization(
|
| 253 |
+
axis=3, name="batch_normalization_" + str(number + 1)
|
| 254 |
+
)(conv_2)
|
| 255 |
+
|
| 256 |
+
inp_deconv = Conv2DTranspose(
|
| 257 |
+
filters // 2, kernel_size=3, strides=2, padding="same"
|
| 258 |
+
)(inp)
|
| 259 |
+
inp_bn = BatchNormalization(
|
| 260 |
+
axis=3, name="batch_normalization_" + str(number + 2)
|
| 261 |
+
)(inp_deconv)
|
| 262 |
+
|
| 263 |
+
joint = Add()([inp_bn, bn_2])
|
| 264 |
+
out = Activation("relu")(joint)
|
| 265 |
+
return out
|
| 266 |
+
|
| 267 |
+
def thundernet(
|
| 268 |
+
self,
|
| 269 |
+
input_shape=(512, 1024, 3),
|
| 270 |
+
resnet_trainable=False,
|
| 271 |
+
kernel_regularizer=0,
|
| 272 |
+
add_2x1up_layer=False,
|
| 273 |
+
add_2up_layer=False,
|
| 274 |
+
resize_first=False,
|
| 275 |
+
):
|
| 276 |
+
|
| 277 |
+
# This returns a tensor
|
| 278 |
+
inputs = Input(shape=(input_shape))
|
| 279 |
+
|
| 280 |
+
if resize_first:
|
| 281 |
+
|
| 282 |
+
# Lambda are needed so that you can have
|
| 283 |
+
# aux = Lambda(lambda x: tf.image.resize_images(x, (480, 640)))(inputs)
|
| 284 |
+
aux = Lambda(
|
| 285 |
+
lambda x: tf.image.resize(
|
| 286 |
+
x, (inputs.shape[0] // 2, inputs.shape[1] // 2)
|
| 287 |
+
)
|
| 288 |
+
)(inputs)
|
| 289 |
+
|
| 290 |
+
else:
|
| 291 |
+
|
| 292 |
+
aux = inputs
|
| 293 |
+
|
| 294 |
+
# a layer instance is callable on a tensor, and returns a tensor
|
| 295 |
+
conv_1 = Conv2D(
|
| 296 |
+
64,
|
| 297 |
+
kernel_size=3,
|
| 298 |
+
strides=2,
|
| 299 |
+
padding="same",
|
| 300 |
+
name="conv2d",
|
| 301 |
+
use_bias=False,
|
| 302 |
+
trainable=resnet_trainable,
|
| 303 |
+
kernel_regularizer=l2(kernel_regularizer),
|
| 304 |
+
)(aux)
|
| 305 |
+
bn_1 = BatchNormalization(
|
| 306 |
+
axis=3, name="batch_normalization", trainable=resnet_trainable
|
| 307 |
+
)(conv_1)
|
| 308 |
+
relu_1 = Activation("relu")(bn_1)
|
| 309 |
+
maxp_1 = MaxPooling2D(pool_size=(3, 3), strides=2, padding="same")(relu_1)
|
| 310 |
+
|
| 311 |
+
res1 = self.resnet_layer(
|
| 312 |
+
maxp_1,
|
| 313 |
+
downsample_first=False,
|
| 314 |
+
filters=64,
|
| 315 |
+
first=True,
|
| 316 |
+
number=1,
|
| 317 |
+
resnet_trainable=resnet_trainable,
|
| 318 |
+
kernel_regularizer=kernel_regularizer,
|
| 319 |
+
)
|
| 320 |
+
# res1 = SpatialDropout2D(0.25)(res1)
|
| 321 |
+
res2 = self.resnet_layer(
|
| 322 |
+
res1,
|
| 323 |
+
downsample_first=True,
|
| 324 |
+
filters=128,
|
| 325 |
+
first=False,
|
| 326 |
+
number=2,
|
| 327 |
+
resnet_trainable=resnet_trainable,
|
| 328 |
+
kernel_regularizer=kernel_regularizer,
|
| 329 |
+
)
|
| 330 |
+
# res2 = SpatialDropout2D(0.25)(res2)
|
| 331 |
+
res3 = self.resnet_layer(
|
| 332 |
+
res2,
|
| 333 |
+
downsample_first=True,
|
| 334 |
+
filters=256,
|
| 335 |
+
first=False,
|
| 336 |
+
number=3,
|
| 337 |
+
resnet_trainable=resnet_trainable,
|
| 338 |
+
kernel_regularizer=kernel_regularizer,
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
ppm = self.pyramid_pooling_block(
|
| 342 |
+
res3, number=16, kernel_regularizer=kernel_regularizer
|
| 343 |
+
)
|
| 344 |
+
# ppm = Add()([ppm,res3])
|
| 345 |
+
ppm = Concatenate()([ppm, res3])
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
dec_1 = self.decoder_block(
|
| 349 |
+
ppm, 256, number=30, kernel_regularizer=kernel_regularizer
|
| 350 |
+
)
|
| 351 |
+
# dec_1 = Add()([dec_1, res2])
|
| 352 |
+
dec_1 = Concatenate()([dec_1, res2])
|
| 353 |
+
|
| 354 |
+
dec_2 = self.decoder_block(
|
| 355 |
+
dec_1, 128, number=33, kernel_regularizer=kernel_regularizer
|
| 356 |
+
)
|
| 357 |
+
# dec_2 = Add()([dec_2, res1])
|
| 358 |
+
dec_2 = Concatenate()([dec_2, res1])
|
| 359 |
+
|
| 360 |
+
# dec_3 = self.decoder_block(dec_2, 128, number=27)
|
| 361 |
+
|
| 362 |
+
if add_2x1up_layer:
|
| 363 |
+
|
| 364 |
+
if add_2up_layer:
|
| 365 |
+
|
| 366 |
+
dec_3 = UpSampling2D(size=(2, 2), interpolation="bilinear")(dec_2)
|
| 367 |
+
ups = UpSampling2D(size=(2, 2), interpolation="bilinear")(dec_3)
|
| 368 |
+
|
| 369 |
+
else:
|
| 370 |
+
|
| 371 |
+
ups = UpSampling2D(size=(4, 4), interpolation="bilinear")(dec_2)
|
| 372 |
+
|
| 373 |
+
print("adding the new upsampling")
|
| 374 |
+
ups_2 = UpSampling2D(size=(1, 2), interpolation="bilinear")(ups)
|
| 375 |
+
|
| 376 |
+
else:
|
| 377 |
+
|
| 378 |
+
if add_2up_layer:
|
| 379 |
+
|
| 380 |
+
dec_3 = UpSampling2D(size=(2, 2), interpolation="bilinear")(dec_2)
|
| 381 |
+
ups_2 = UpSampling2D(size=(2, 2), interpolation="bilinear")(dec_3)
|
| 382 |
+
|
| 383 |
+
else:
|
| 384 |
+
|
| 385 |
+
ups_2 = UpSampling2D(size=(4, 4), interpolation="bilinear")(dec_2)
|
| 386 |
+
|
| 387 |
+
out = Conv2D(
|
| 388 |
+
filters=int(self.n_classes),
|
| 389 |
+
kernel_size=1,
|
| 390 |
+
activation="softmax",
|
| 391 |
+
name="conv2d_out",
|
| 392 |
+
)(ups_2)
|
| 393 |
+
|
| 394 |
+
model = Model(inputs=inputs, outputs=out)
|
| 395 |
+
|
| 396 |
+
return model
|
| 397 |
+
|
| 398 |
+
def load_resnet_weights(self):
|
| 399 |
+
|
| 400 |
+
print("Loading weights for resnet18 backbone")
|
| 401 |
+
checkpoint_path = "./resnet/resnet18/checkpoints/model/model.ckpt-5865"
|
| 402 |
+
# reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
|
| 403 |
+
reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path) # for tf 2.0
|
| 404 |
+
|
| 405 |
+
var_to_shape_map = reader.get_variable_to_shape_map()
|
| 406 |
+
|
| 407 |
+
# for key in var_to_shape_map:
|
| 408 |
+
# print("tensor_name: ", key)
|
| 409 |
+
# print(reader.get_tensor(key).shape)  # Remove this if you want to print only variable names
|
| 410 |
+
|
| 411 |
+
for k in range(0, 16):
|
| 412 |
+
layer_name = "conv2d"
|
| 413 |
+
if k != 0:
|
| 414 |
+
layer_name += "_" + str(k)
|
| 415 |
+
weights_key = layer_name + "/kernel"
|
| 416 |
+
weights = reader.get_tensor(weights_key)
|
| 417 |
+
# print(weights.shape)
|
| 418 |
+
keras_weights = self.model.get_layer(layer_name).get_weights()
|
| 419 |
+
# print(keras_weights[0].shape)
|
| 420 |
+
self.model.get_layer(layer_name).set_weights([weights])
|
| 421 |
+
|
| 422 |
+
layer_name = "batch_normalization"
|
| 423 |
+
if k != 0:
|
| 424 |
+
layer_name += "_" + str(k)
|
| 425 |
+
if k < 13:
|
| 426 |
+
beta_key = layer_name + "/beta"
|
| 427 |
+
beta = reader.get_tensor(beta_key)
|
| 428 |
+
gamma_key = layer_name + "/gamma"
|
| 429 |
+
gamma = reader.get_tensor(gamma_key)
|
| 430 |
+
mean_key = layer_name + "/moving_mean"
|
| 431 |
+
mean = reader.get_tensor(mean_key)
|
| 432 |
+
var_key = layer_name + "/moving_variance"
|
| 433 |
+
var = reader.get_tensor(var_key)
|
| 434 |
+
keras_weights = self.model.get_layer(layer_name).get_weights()
|
| 435 |
+
# print(len(keras_weights))
|
| 436 |
+
# print(keras_weights[0].shape)
|
| 437 |
+
self.model.get_layer(layer_name).set_weights([gamma, beta, mean, var])
|
| 438 |
+
print("Weights for resnet18 backbone loaded!")
|
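This variant exposes three extra constructor flags that control how resolution is handled; a compact summary with an instantiation sketch (flag values are illustrative):

from model.model_ppm_factors import Thundernet

# resize_first=True    halves the input with a Lambda resize before the backbone
# add_2up_layer=True   performs the final 4x upsampling as two bilinear 2x stages
# add_2x1up_layer=True appends an extra (1, 2) bilinear stage after the 4x upsampling
net = Thundernet(input_shape=(512, 1024, 3), n_classes=2,
                 add_2up_layer=True, add_2x1up_layer=False)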
profiler.py
ADDED
|
@@ -0,0 +1,98 @@
| 1 |
+
import warnings
|
| 2 |
+
|
| 3 |
+
warnings.simplefilter("ignore", FutureWarning)
|
| 4 |
+
warnings.simplefilter("ignore", UserWarning)
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch.profiler import profile, ProfilerActivity
|
| 8 |
+
from model.model import Thundernet
|
| 9 |
+
from models_repo.model_attention import Thundernet as Thundernet_attention
|
| 10 |
+
from models_repo.model_attention_2 import Thundernet as Thundernet_attention2
|
| 11 |
+
from models_repo.model_ppm_factors import Thundernet as Thundernet_ppm
|
| 12 |
+
|
| 13 |
+
import time
|
| 14 |
+
import cv2
|
| 15 |
+
import numpy as np
|
| 16 |
+
import tensorflow as tf
|
| 17 |
+
|
| 18 |
+
device: torch.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 19 |
+
# Define input shape
|
| 20 |
+
input_shape = (480, 640, 3)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def execute_profiler(model: Thundernet) -> None:
|
| 24 |
+
"""
|
| 25 |
+
Function to measure the CPU and CUDA times.
|
| 26 |
+
It prints the results on the console
|
| 27 |
+
Args:
|
| 28 |
+
- model: loaded model to profile
|
| 29 |
+
Returns:
|
| 30 |
+
- None
|
| 31 |
+
"""
|
| 32 |
+
image = torch.randn(1, 480, 640, 3).cpu().numpy()
|
| 33 |
+
|
| 34 |
+
# model = torch.jit.trace(model, (image, depth))
|
| 35 |
+
with profile(
|
| 36 |
+
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
| 37 |
+
record_shapes=True,
|
| 38 |
+
profile_memory=True,
|
| 39 |
+
) as prof:
|
| 40 |
+
with torch.no_grad():
|
| 41 |
+
_ = model.predict(image)
|
| 42 |
+
|
| 43 |
+
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
|
| 44 |
+
|
| 45 |
+
return
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def measure_inference_time(model: Thundernet) -> None:
|
| 49 |
+
"""
|
| 50 |
+
Function to measure the average inference time
|
| 51 |
+
and the FPS of a given ThunderNet model.
|
| 52 |
+
It prints the results on the console
|
| 53 |
+
Args:
|
| 54 |
+
- model: loaded model to profile
|
| 55 |
+
Returns:
|
| 56 |
+
- None
|
| 57 |
+
"""
|
| 58 |
+
image = torch.randn(1, 480, 640, 3).cpu().numpy()
|
| 59 |
+
|
| 60 |
+
for _ in range(5):
|
| 61 |
+
_ = model.predict(image)
|
| 62 |
+
|
| 63 |
+
times = []
|
| 64 |
+
for _ in range(20):
|
| 65 |
+
tf.constant(0).numpy()
|
| 66 |
+
start = time.time()
|
| 67 |
+
_ = model.predict(image)
|
| 68 |
+
tf.constant(0).numpy()
|
| 69 |
+
times.append((time.time() - start) * 1000)
|
| 70 |
+
|
| 71 |
+
avg_time = sum(times) / len(times)
|
| 72 |
+
print(f"Average inference time: {avg_time:.2f} ms")
|
| 73 |
+
|
| 74 |
+
fps = 1000 / avg_time
|
| 75 |
+
print(f"FPS: {fps:.2f}")
|
| 76 |
+
return
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def main() -> None:
|
| 80 |
+
|
| 81 |
+
# PATH TO THE BEST MODEL SO FAR (.hdf5)
|
| 82 |
+
# weights_path = "D:/RealTimeSemanticSegmentation_Sofia/keras.hdf5"
|
| 83 |
+
# weights_path = "C:/Users/user/Documents/Thundernet/pruebas_modelos/32_ppm/BS4_lossBCE_weights_lr_0.00013713842558297858_reg-1.1743577101671763e-05-ep-13-val_loss0.11463435739278793-train_loss0.053004469722509384-val_iou0.8959722518920898-train_iou0.9606077075004578.hdf5"
|
| 84 |
+
weights_path = "keras.hdf5"
|
| 85 |
+
# Load the model. Change it depending on where it was trained
|
| 86 |
+
# ThunderNet = Thundernet_ppm(input_shape=input_shape, resnet_trainable=False, n_classes = 2)
|
| 87 |
+
ThunderNet = Thundernet(
|
| 88 |
+
input_shape=input_shape, resnet_trainable=False, n_classes=2
|
| 89 |
+
)
|
| 90 |
+
model = ThunderNet.model
|
| 91 |
+
ThunderNet.model.load_weights(weights_path)
|
| 92 |
+
execute_profiler(model)
|
| 93 |
+
measure_inference_time(model)
|
| 94 |
+
return
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
if __name__ == "__main__":
|
| 98 |
+
main()
|
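In measure_inference_time above, the tf.constant(0).numpy() calls act as crude synchronization barriers so that any queued GPU work is flushed before and after the timed predict(). A sketch of the same measurement built on time.perf_counter, under the assumption that predict() blocks until its NumPy output is materialized:

import time
import numpy as np

def time_predict(model, image, warmup=5, runs=20):
    for _ in range(warmup):
        model.predict(image)  # warm-up: graph tracing, kernel autotuning, caches
    times_ms = []
    for _ in range(runs):
        start = time.perf_counter()
        model.predict(image)  # returns a NumPy array, so GPU work has completed
        times_ms.append((time.perf_counter() - start) * 1000.0)
    return float(np.mean(times_ms)), float(np.std(times_ms))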
requirements.txt
ADDED
|
@@ -0,0 +1,93 @@
| 1 |
+
absl-py==0.15.0
|
| 2 |
+
asttokens==2.1.0
|
| 3 |
+
astunparse==1.6.3
|
| 4 |
+
backcall==0.2.0
|
| 5 |
+
cachetools==5.0.0
|
| 6 |
+
certifi==2021.10.8
|
| 7 |
+
charset-normalizer==2.0.12
|
| 8 |
+
clang==5.0
|
| 9 |
+
colorama==0.4.6
|
| 10 |
+
contourpy==1.0.5
|
| 11 |
+
cycler==0.11.0
|
| 12 |
+
decorator==5.1.1
|
| 13 |
+
efficientnet==1.0.0
|
| 14 |
+
executing==1.2.0
|
| 15 |
+
flatbuffers==1.12
|
| 16 |
+
fonttools==4.37.4
|
| 17 |
+
gast==0.4.0
|
| 18 |
+
google-auth==2.6.5
|
| 19 |
+
google-auth-oauthlib==0.4.6
|
| 20 |
+
google-pasta==0.2.0
|
| 21 |
+
grpcio==1.44.0
|
| 22 |
+
h5py==3.1.0
|
| 23 |
+
idna==3.3
|
| 24 |
+
image-classifiers==1.0.0
|
| 25 |
+
imageio==2.22.4
|
| 26 |
+
importlib-metadata==4.11.3
|
| 27 |
+
install==1.3.5
|
| 28 |
+
ipython==8.6.0
|
| 29 |
+
jedi==0.18.1
|
| 30 |
+
keras==2.6.0
|
| 31 |
+
Keras-Applications==1.0.8
|
| 32 |
+
Keras-Preprocessing==1.1.2
|
| 33 |
+
kiwisolver==1.4.4
|
| 34 |
+
Markdown==3.3.6
|
| 35 |
+
matplotlib==3.6.1
|
| 36 |
+
matplotlib-inline==0.1.6
|
| 37 |
+
mtcnn==0.1.1
|
| 38 |
+
networkx==2.8.8
|
| 39 |
+
numpy==1.19.5
|
| 40 |
+
nvidia-cublas-cu11==11.10.3.66
|
| 41 |
+
nvidia-cuda-nvrtc-cu11==11.7.99
|
| 42 |
+
nvidia-cuda-runtime-cu11==11.7.99
|
| 43 |
+
nvidia-cudnn-cu11==8.5.0.96
|
| 44 |
+
nvidia-pyindex==1.0.9
|
| 45 |
+
oauthlib==3.2.0
|
| 46 |
+
opencv-python==4.3.0.38
|
| 47 |
+
opt-einsum==3.3.0
|
| 48 |
+
packaging==21.3
|
| 49 |
+
pandas==1.4.0
|
| 50 |
+
parso==0.8.3
|
| 51 |
+
pexpect==4.8.0
|
| 52 |
+
pickleshare==0.7.5
|
| 53 |
+
Pillow==9.2.0
|
| 54 |
+
pkg_resources==0.0.0
|
| 55 |
+
prompt-toolkit==3.0.32
|
| 56 |
+
protobuf==3.20.0
|
| 57 |
+
ptyprocess==0.7.0
|
| 58 |
+
pure-eval==0.2.2
|
| 59 |
+
pyasn1==0.4.8
|
| 60 |
+
pyasn1-modules==0.2.8
|
| 61 |
+
Pygments==2.13.0
|
| 62 |
+
pyparsing==3.0.9
|
| 63 |
+
python-dateutil==2.8.2
|
| 64 |
+
pytz==2022.6
|
| 65 |
+
PyWavelets==1.4.1
|
| 66 |
+
PyYAML==6.0
|
| 67 |
+
pyzmq==22.3.0
|
| 68 |
+
requests==2.27.1
|
| 69 |
+
requests-oauthlib==1.3.1
|
| 70 |
+
rsa==4.8
|
| 71 |
+
scikit-image==0.19.3
|
| 72 |
+
scipy==1.8.0
|
| 73 |
+
segmentation-models==1.0.1
|
| 74 |
+
six==1.15.0
|
| 75 |
+
stack-data==0.6.1
|
| 76 |
+
tensorboard==2.8.0
|
| 77 |
+
tensorboard-data-server==0.6.1
|
| 78 |
+
tensorboard-plugin-wit==1.8.1
|
| 79 |
+
tensorflow==2.6.0
|
| 80 |
+
tensorflow-estimator==2.8.0
|
| 81 |
+
tensorflow-gpu==2.6.0
|
| 82 |
+
termcolor==1.1.0
|
| 83 |
+
tifffile==2022.10.10
|
| 84 |
+
torch==1.13.1
|
| 85 |
+
tqdm==4.64.1
|
| 86 |
+
traitlets==5.5.0
|
| 87 |
+
typing-extensions==3.7.4.3
|
| 88 |
+
urllib3==1.26.9
|
| 89 |
+
wcwidth==0.2.5
|
| 90 |
+
Werkzeug==2.1.1
|
| 91 |
+
wrapt==1.12.1
|
| 92 |
+
yacs==0.1.8
|
| 93 |
+
zipp==3.8.0
|
resnet/.gitignore
ADDED
|
@@ -0,0 +1,208 @@
| 1 |
+
# customer
|
| 2 |
+
docker.sh
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Created by https://www.gitignore.io/api/macos,python,pycharm,visualstudiocode
|
| 6 |
+
|
| 7 |
+
### macOS ###
|
| 8 |
+
*.DS_Store
|
| 9 |
+
.AppleDouble
|
| 10 |
+
.LSOverride
|
| 11 |
+
|
| 12 |
+
# Icon must end with two \r
|
| 13 |
+
Icon
|
| 14 |
+
|
| 15 |
+
# Thumbnails
|
| 16 |
+
._*
|
| 17 |
+
|
| 18 |
+
# Files that might appear in the root of a volume
|
| 19 |
+
.DocumentRevisions-V100
|
| 20 |
+
.fseventsd
|
| 21 |
+
.Spotlight-V100
|
| 22 |
+
.TemporaryItems
|
| 23 |
+
.Trashes
|
| 24 |
+
.VolumeIcon.icns
|
| 25 |
+
.com.apple.timemachine.donotpresent
|
| 26 |
+
|
| 27 |
+
# Directories potentially created on remote AFP share
|
| 28 |
+
.AppleDB
|
| 29 |
+
.AppleDesktop
|
| 30 |
+
Network Trash Folder
|
| 31 |
+
Temporary Items
|
| 32 |
+
.apdisk
|
| 33 |
+
|
| 34 |
+
### PyCharm ###
|
| 35 |
+
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
|
| 36 |
+
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
|
| 37 |
+
|
| 38 |
+
# User-specific stuff:
|
| 39 |
+
.idea/**/workspace.xml
|
| 40 |
+
.idea/**/tasks.xml
|
| 41 |
+
.idea/dictionaries
|
| 42 |
+
|
| 43 |
+
# Sensitive or high-churn files:
|
| 44 |
+
.idea/**/dataSources/
|
| 45 |
+
.idea/**/dataSources.ids
|
| 46 |
+
.idea/**/dataSources.xml
|
| 47 |
+
.idea/**/dataSources.local.xml
|
| 48 |
+
.idea/**/sqlDataSources.xml
|
| 49 |
+
.idea/**/dynamic.xml
|
| 50 |
+
.idea/**/uiDesigner.xml
|
| 51 |
+
|
| 52 |
+
# Gradle:
|
| 53 |
+
.idea/**/gradle.xml
|
| 54 |
+
.idea/**/libraries
|
| 55 |
+
|
| 56 |
+
# CMake
|
| 57 |
+
cmake-build-debug/
|
| 58 |
+
|
| 59 |
+
# Mongo Explorer plugin:
|
| 60 |
+
.idea/**/mongoSettings.xml
|
| 61 |
+
|
| 62 |
+
## File-based project format:
|
| 63 |
+
*.iws
|
| 64 |
+
|
| 65 |
+
## Plugin-specific files:
|
| 66 |
+
|
| 67 |
+
# IntelliJ
|
| 68 |
+
/out/
|
| 69 |
+
|
| 70 |
+
# mpeltonen/sbt-idea plugin
|
| 71 |
+
.idea_modules/
|
| 72 |
+
|
| 73 |
+
# JIRA plugin
|
| 74 |
+
atlassian-ide-plugin.xml
|
| 75 |
+
|
| 76 |
+
# Cursive Clojure plugin
|
| 77 |
+
.idea/replstate.xml
|
| 78 |
+
|
| 79 |
+
# Ruby plugin and RubyMine
|
| 80 |
+
/.rakeTasks
|
| 81 |
+
|
| 82 |
+
# Crashlytics plugin (for Android Studio and IntelliJ)
|
| 83 |
+
com_crashlytics_export_strings.xml
|
| 84 |
+
crashlytics.properties
|
| 85 |
+
crashlytics-build.properties
|
| 86 |
+
fabric.properties
|
| 87 |
+
|
| 88 |
+
### PyCharm Patch ###
|
| 89 |
+
# Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721
|
| 90 |
+
|
| 91 |
+
# *.iml
|
| 92 |
+
# modules.xml
|
| 93 |
+
# .idea/misc.xml
|
| 94 |
+
# *.ipr
|
| 95 |
+
|
| 96 |
+
# Sonarlint plugin
|
| 97 |
+
.idea/sonarlint
|
| 98 |
+
|
| 99 |
+
### Python ###
|
| 100 |
+
# Byte-compiled / optimized / DLL files
|
| 101 |
+
__pycache__/
|
| 102 |
+
*.py[cod]
|
| 103 |
+
*$py.class
|
| 104 |
+
|
| 105 |
+
# C extensions
|
| 106 |
+
*.so
|
| 107 |
+
|
| 108 |
+
# Distribution / packaging
|
| 109 |
+
.Python
|
| 110 |
+
build/
|
| 111 |
+
develop-eggs/
|
| 112 |
+
dist/
|
| 113 |
+
downloads/
|
| 114 |
+
eggs/
|
| 115 |
+
.eggs/
|
| 116 |
+
lib/
|
| 117 |
+
lib64/
|
| 118 |
+
parts/
|
| 119 |
+
sdist/
|
| 120 |
+
var/
|
| 121 |
+
wheels/
|
| 122 |
+
*.egg-info/
|
| 123 |
+
.installed.cfg
|
| 124 |
+
*.egg
|
| 125 |
+
|
| 126 |
+
# PyInstaller
|
| 127 |
+
# Usually these files are written by a python script from a template
|
| 128 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 129 |
+
*.manifest
|
| 130 |
+
*.spec
|
| 131 |
+
|
| 132 |
+
# Installer logs
|
| 133 |
+
pip-log.txt
|
| 134 |
+
pip-delete-this-directory.txt
|
| 135 |
+
|
| 136 |
+
# Unit test / coverage reports
|
| 137 |
+
htmlcov/
|
| 138 |
+
.tox/
|
| 139 |
+
.coverage
|
| 140 |
+
.coverage.*
|
| 141 |
+
.cache
|
| 142 |
+
.pytest_cache/
|
| 143 |
+
nosetests.xml
|
| 144 |
+
coverage.xml
|
| 145 |
+
*.cover
|
| 146 |
+
.hypothesis/
|
| 147 |
+
|
| 148 |
+
# Translations
|
| 149 |
+
*.mo
|
| 150 |
+
*.pot
|
| 151 |
+
|
| 152 |
+
# Flask stuff:
|
| 153 |
+
instance/
|
| 154 |
+
.webassets-cache
|
| 155 |
+
|
| 156 |
+
# Scrapy stuff:
|
| 157 |
+
.scrapy
|
| 158 |
+
|
| 159 |
+
# Sphinx documentation
|
| 160 |
+
docs/_build/
|
| 161 |
+
|
| 162 |
+
# PyBuilder
|
| 163 |
+
target/
|
| 164 |
+
|
| 165 |
+
# Jupyter Notebook
|
| 166 |
+
.ipynb_checkpoints
|
| 167 |
+
|
| 168 |
+
# pyenv
|
| 169 |
+
.python-version
|
| 170 |
+
|
| 171 |
+
# celery beat schedule file
|
| 172 |
+
celerybeat-schedule.*
|
| 173 |
+
|
| 174 |
+
# SageMath parsed files
|
| 175 |
+
*.sage.py
|
| 176 |
+
|
| 177 |
+
# Environments
|
| 178 |
+
.env
|
| 179 |
+
.venv
|
| 180 |
+
env/
|
| 181 |
+
venv/
|
| 182 |
+
ENV/
|
| 183 |
+
env.bak/
|
| 184 |
+
venv.bak/
|
| 185 |
+
|
| 186 |
+
# Spyder project settings
|
| 187 |
+
.spyderproject
|
| 188 |
+
.spyproject
|
| 189 |
+
|
| 190 |
+
# Rope project settings
|
| 191 |
+
.ropeproject
|
| 192 |
+
|
| 193 |
+
# mkdocs documentation
|
| 194 |
+
/site
|
| 195 |
+
|
| 196 |
+
# mypy
|
| 197 |
+
.mypy_cache/
|
| 198 |
+
|
| 199 |
+
### VisualStudioCode ###
|
| 200 |
+
.vscode/*
|
| 201 |
+
!.vscode/settings.json
|
| 202 |
+
!.vscode/tasks.json
|
| 203 |
+
!.vscode/launch.json
|
| 204 |
+
!.vscode/extensions.json
|
| 205 |
+
.history
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
# End of https://www.gitignore.io/api/macos,python,pycharm,visualstudiocode
|
resnet/.idea/inspectionProfiles/profiles_settings.xml
ADDED
|
@@ -0,0 +1,6 @@
| 1 |
+
<component name="InspectionProjectProfileManager">
|
| 2 |
+
<settings>
|
| 3 |
+
<option name="USE_PROJECT_PROFILE" value="false" />
|
| 4 |
+
<version value="1.0" />
|
| 5 |
+
</settings>
|
| 6 |
+
</component>
|
resnet/.idea/misc.xml
ADDED
|
@@ -0,0 +1,4 @@
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (deeplearning)" project-jdk-type="Python SDK" />
|
| 4 |
+
</project>
|
resnet/.idea/modules.xml
ADDED
|
@@ -0,0 +1,8 @@
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="ProjectModuleManager">
|
| 4 |
+
<modules>
|
| 5 |
+
<module fileurl="file://$PROJECT_DIR$/.idea/resnet.iml" filepath="$PROJECT_DIR$/.idea/resnet.iml" />
|
| 6 |
+
</modules>
|
| 7 |
+
</component>
|
| 8 |
+
</project>
|
resnet/.idea/resnet.iml
ADDED
|
@@ -0,0 +1,12 @@
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<module type="PYTHON_MODULE" version="4">
|
| 3 |
+
<component name="NewModuleRootManager">
|
| 4 |
+
<content url="file://$MODULE_DIR$" />
|
| 5 |
+
<orderEntry type="inheritedJdk" />
|
| 6 |
+
<orderEntry type="sourceFolder" forTests="false" />
|
| 7 |
+
</component>
|
| 8 |
+
<component name="PyDocumentationSettings">
|
| 9 |
+
<option name="format" value="PLAIN" />
|
| 10 |
+
<option name="myDocStringFormat" value="Plain" />
|
| 11 |
+
</component>
|
| 12 |
+
</module>
|
resnet/apt.txt
ADDED
|
@@ -0,0 +1 @@
| 1 |
+
git
|
resnet/crowdai.json
ADDED
|
@@ -0,0 +1,7 @@
| 1 |
+
{
|
| 2 |
+
"challenge_id" : "nips-2018-avc-robust-model-track",
|
| 3 |
+
"grader_id": "nips-2018-avc-robust-model-track",
|
| 4 |
+
"authors" : ["bveliqi"],
|
| 5 |
+
"description" : "resnet-18 baseline model",
|
| 6 |
+
"gpu": true
|
| 7 |
+
}
|
resnet/fmodel.py
ADDED
|
@@ -0,0 +1,60 @@
| 1 |
+
import tensorflow as tf
|
| 2 |
+
import os
|
| 3 |
+
from foolbox.models import TensorFlowModel
|
| 4 |
+
|
| 5 |
+
from resnet18.resnet_model import Model
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def create_model():
|
| 9 |
+
graph = tf.Graph()
|
| 10 |
+
with graph.as_default():
|
| 11 |
+
images = tf.placeholder(tf.float32, (None, 64, 64, 3))
|
| 12 |
+
|
| 13 |
+
# preprocessing
|
| 14 |
+
_R_MEAN = 123.68
|
| 15 |
+
_G_MEAN = 116.78
|
| 16 |
+
_B_MEAN = 103.94
|
| 17 |
+
_CHANNEL_MEANS = [_R_MEAN, _G_MEAN, _B_MEAN]
|
| 18 |
+
features = images - tf.constant(_CHANNEL_MEANS)
|
| 19 |
+
|
| 20 |
+
model = Model(
|
| 21 |
+
resnet_size=18,
|
| 22 |
+
bottleneck=False,
|
| 23 |
+
num_classes=200,
|
| 24 |
+
num_filters=64,
|
| 25 |
+
kernel_size=3,
|
| 26 |
+
conv_stride=1,
|
| 27 |
+
first_pool_size=0,
|
| 28 |
+
first_pool_stride=2,
|
| 29 |
+
second_pool_size=7,
|
| 30 |
+
second_pool_stride=1,
|
| 31 |
+
block_sizes=[2, 2, 2, 2],
|
| 32 |
+
block_strides=[1, 2, 2, 2],
|
| 33 |
+
final_size=512,
|
| 34 |
+
version=2,
|
| 35 |
+
data_format=None,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
logits = model(features, False)
|
| 39 |
+
|
| 40 |
+
with tf.variable_scope("utilities"):
|
| 41 |
+
saver = tf.train.Saver()
|
| 42 |
+
|
| 43 |
+
return graph, saver, images, logits
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def create_fmodel():
|
| 47 |
+
graph, saver, images, logits = create_model()
|
| 48 |
+
sess = tf.Session(graph=graph)
|
| 49 |
+
path = os.path.dirname(os.path.abspath(__file__))
|
| 50 |
+
path = os.path.join(path, "resnet18", "checkpoints", "model")
|
| 51 |
+
saver.restore(sess, tf.train.latest_checkpoint(path))
|
| 52 |
+
|
| 53 |
+
with sess.as_default():
|
| 54 |
+
fmodel = TensorFlowModel(images, logits, bounds=(0, 255))
|
| 55 |
+
return fmodel
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
if __name__ == "__main__":
|
| 59 |
+
# executable for debugging and testing
|
| 60 |
+
print(create_fmodel())
|
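A minimal sketch of querying the wrapped model through the foolbox 1.x API (the random input and the predictions() call are assumptions based on that API version, matching the foolbox==1.1.0 pin in resnet/requirements.txt below):

import numpy as np
from fmodel import create_fmodel

fmodel = create_fmodel()
image = np.random.uniform(0, 255, size=(64, 64, 3)).astype(np.float32)
logits = fmodel.predictions(image)  # per-class logits for a single image
print(int(np.argmax(logits)))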
resnet/main.py
ADDED
|
@@ -0,0 +1,7 @@
| 1 |
+
from fmodel import create_fmodel
|
| 2 |
+
from adversarial_vision_challenge import model_server
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
if __name__ == "__main__":
|
| 6 |
+
fmodel = create_fmodel()
|
| 7 |
+
model_server(fmodel)
|
resnet/requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
| 1 |
+
tensorflow-gpu==1.8.0
|
| 2 |
+
foolbox==1.1.0
|
| 3 |
+
git+https://github.com/bveliqi/adversarial-vision-challenge
|
resnet/resnet18/__init__.py
ADDED
|
File without changes
|
resnet/resnet18/checkpoints/model/checkpoint
ADDED
|
@@ -0,0 +1 @@
| 1 |
+
model_checkpoint_path: "model.ckpt-5865"
|
resnet/resnet18/checkpoints/model/graph.pbtxt
ADDED
|
The diff for this file is too large to render.
|
|
|
resnet/resnet18/checkpoints/model/model.ckpt-5865.data-00000-of-00001
ADDED
|
@@ -0,0 +1,3 @@
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4a567213201e4c8fda90ed9196633feb86a551ed76578456a90796f94b674b96
|
| 3 |
+
size 90221128
|
resnet/resnet18/checkpoints/model/model.ckpt-5865.index
ADDED
|
Binary file (5.74 kB).
|
|
|
resnet/resnet18/checkpoints/model/model.ckpt-5865.meta
ADDED
|
@@ -0,0 +1,3 @@
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d6a76e71158e1b23d8993993d90ce435122e086def1a0c249e6a224655e73592
|
| 3 |
+
size 1161995
|
resnet/resnet18/resnet_model.py
ADDED
|
@@ -0,0 +1,570 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains definitions for Residual Networks.

Residual networks ('v1' ResNets) were originally proposed in:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385

The full preactivation 'v2' ResNet variant was introduced by:
[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Identity Mappings in Deep Residual Networks. arXiv:1603.05027

The key difference of the full preactivation 'v2' variant compared to the
'v1' variant in [1] is the use of batch normalization before every weight
layer rather than after.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

_BATCH_NORM_DECAY = 0.997
_BATCH_NORM_EPSILON = 1e-5
DEFAULT_VERSION = 2


################################################################################
# Convenience functions for building the ResNet model.
################################################################################
def batch_norm(inputs, training, data_format):
    """Performs a batch normalization using a standard set of parameters."""
    # We set fused=True for a significant performance boost. See
    # https://www.tensorflow.org/performance/performance_guide#common_fused_ops
    return tf.layers.batch_normalization(
        inputs=inputs,
        axis=1 if data_format == "channels_first" else 3,
        momentum=_BATCH_NORM_DECAY,
        epsilon=_BATCH_NORM_EPSILON,
        center=True,
        scale=True,
        training=training,
        fused=True,
    )


def fixed_padding(inputs, kernel_size, data_format):
    """Pads the input along the spatial dimensions independently of input size.

    Args:
      inputs: A tensor of size [batch, channels, height_in, width_in] or
        [batch, height_in, width_in, channels] depending on data_format.
      kernel_size: The kernel to be used in the conv2d or max_pool2d operation.
        Should be a positive integer.
      data_format: The input format ('channels_last' or 'channels_first').

    Returns:
      A tensor with the same format as the input with the data either intact
      (if kernel_size == 1) or padded (if kernel_size > 1).
    """
    pad_total = kernel_size - 1
    pad_beg = pad_total // 2
    pad_end = pad_total - pad_beg

    if data_format == "channels_first":
        padded_inputs = tf.pad(
            inputs, [[0, 0], [0, 0], [pad_beg, pad_end], [pad_beg, pad_end]]
        )
    else:
        padded_inputs = tf.pad(
            inputs, [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]]
        )
    return padded_inputs


def conv2d_fixed_padding(inputs, filters, kernel_size, strides, data_format):
    """Strided 2-D convolution with explicit padding."""
    # The padding is consistent and is based only on `kernel_size`, not on the
    # dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone).
    if strides > 1:
        inputs = fixed_padding(inputs, kernel_size, data_format)

    return tf.layers.conv2d(
        inputs=inputs,
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding=("SAME" if strides == 1 else "VALID"),
        use_bias=False,
        kernel_initializer=tf.variance_scaling_initializer(),
        data_format=data_format,
    )

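# Note (added for clarity): with fixed padding, the output spatial size of a
# strided convolution depends only on the stride, never on how "SAME" padding
# would be distributed for a particular input size. For example, a 7x7 conv
# with stride 2 on a 224x224 input is first padded to 230x230 (pad_beg=3,
# pad_end=3); the VALID convolution then yields (230 - 7) // 2 + 1 = 112.
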

################################################################################
# ResNet block definitions.
################################################################################
def _building_block_v1(
    inputs, filters, training, projection_shortcut, strides, data_format
):
    """A single block for ResNet v1, without a bottleneck.

    Convolution then batch normalization then ReLU as described by:
      Deep Residual Learning for Image Recognition
      https://arxiv.org/pdf/1512.03385.pdf
      by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015.

    Args:
      inputs: A tensor of size [batch, channels, height_in, width_in] or
        [batch, height_in, width_in, channels] depending on data_format.
      filters: The number of filters for the convolutions.
      training: A Boolean for whether the model is in training or inference
        mode. Needed for batch normalization.
      projection_shortcut: The function to use for projection shortcuts
        (typically a 1x1 convolution when downsampling the input).
      strides: The block's stride. If greater than 1, this block will
        ultimately downsample the input.
      data_format: The input format ('channels_last' or 'channels_first').

    Returns:
      The output tensor of the block; shape should match inputs.
    """
    shortcut = inputs

    if projection_shortcut is not None:
        shortcut = projection_shortcut(inputs)
        shortcut = batch_norm(
            inputs=shortcut, training=training, data_format=data_format
        )

    inputs = conv2d_fixed_padding(
        inputs=inputs,
        filters=filters,
        kernel_size=3,
        strides=strides,
        data_format=data_format,
    )
    inputs = batch_norm(inputs, training, data_format)
    inputs = tf.nn.relu(inputs)

    inputs = conv2d_fixed_padding(
        inputs=inputs,
        filters=filters,
        kernel_size=3,
        strides=1,
        data_format=data_format,
    )
    inputs = batch_norm(inputs, training, data_format)
    inputs += shortcut
    inputs = tf.nn.relu(inputs)

    return inputs


def _building_block_v2(
    inputs, filters, training, projection_shortcut, strides, data_format
):
    """A single block for ResNet v2, without a bottleneck.

    Batch normalization then ReLU then convolution as described by:
      Identity Mappings in Deep Residual Networks
      https://arxiv.org/pdf/1603.05027.pdf
      by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Jul 2016.

    Args:
      inputs: A tensor of size [batch, channels, height_in, width_in] or
        [batch, height_in, width_in, channels] depending on data_format.
      filters: The number of filters for the convolutions.
      training: A Boolean for whether the model is in training or inference
        mode. Needed for batch normalization.
      projection_shortcut: The function to use for projection shortcuts
        (typically a 1x1 convolution when downsampling the input).
      strides: The block's stride. If greater than 1, this block will
        ultimately downsample the input.
      data_format: The input format ('channels_last' or 'channels_first').

    Returns:
      The output tensor of the block; shape should match inputs.
    """
    shortcut = inputs
    inputs = batch_norm(inputs, training, data_format)
    inputs = tf.nn.relu(inputs)

    # The projection shortcut should come after the first batch norm and ReLU
    # since it performs a 1x1 convolution.
    if projection_shortcut is not None:
        shortcut = projection_shortcut(inputs)

    inputs = conv2d_fixed_padding(
        inputs=inputs,
        filters=filters,
        kernel_size=3,
        strides=strides,
        data_format=data_format,
    )

    inputs = batch_norm(inputs, training, data_format)
    inputs = tf.nn.relu(inputs)
    inputs = conv2d_fixed_padding(
        inputs=inputs,
        filters=filters,
        kernel_size=3,
        strides=1,
        data_format=data_format,
    )

    return inputs + shortcut


def _bottleneck_block_v1(
    inputs, filters, training, projection_shortcut, strides, data_format
):
    """A single block for ResNet v1, with a bottleneck.

    Similar to _building_block_v1(), except using the "bottleneck" blocks
    described in:
      Convolution then batch normalization then ReLU as described by:
        Deep Residual Learning for Image Recognition
        https://arxiv.org/pdf/1512.03385.pdf
        by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015.

    Args:
      inputs: A tensor of size [batch, channels, height_in, width_in] or
        [batch, height_in, width_in, channels] depending on data_format.
      filters: The number of filters for the convolutions.
      training: A Boolean for whether the model is in training or inference
        mode. Needed for batch normalization.
      projection_shortcut: The function to use for projection shortcuts
        (typically a 1x1 convolution when downsampling the input).
      strides: The block's stride. If greater than 1, this block will
        ultimately downsample the input.
      data_format: The input format ('channels_last' or 'channels_first').

    Returns:
      The output tensor of the block; shape should match inputs.
    """
    shortcut = inputs

    if projection_shortcut is not None:
        shortcut = projection_shortcut(inputs)
        shortcut = batch_norm(
            inputs=shortcut, training=training, data_format=data_format
        )

    inputs = conv2d_fixed_padding(
        inputs=inputs,
        filters=filters,
        kernel_size=1,
        strides=1,
        data_format=data_format,
    )
    inputs = batch_norm(inputs, training, data_format)
    inputs = tf.nn.relu(inputs)

    inputs = conv2d_fixed_padding(
        inputs=inputs,
        filters=filters,
        kernel_size=3,
        strides=strides,
        data_format=data_format,
    )
    inputs = batch_norm(inputs, training, data_format)
    inputs = tf.nn.relu(inputs)

    inputs = conv2d_fixed_padding(
        inputs=inputs,
        filters=4 * filters,
        kernel_size=1,
        strides=1,
        data_format=data_format,
    )
    inputs = batch_norm(inputs, training, data_format)
    inputs += shortcut
    inputs = tf.nn.relu(inputs)

    return inputs


def _bottleneck_block_v2(
    inputs, filters, training, projection_shortcut, strides, data_format
):
    """A single block for ResNet v2, with a bottleneck.

    Similar to _building_block_v2(), except using the "bottleneck" blocks
    described in:
      Convolution then batch normalization then ReLU as described by:
        Deep Residual Learning for Image Recognition
        https://arxiv.org/pdf/1512.03385.pdf
        by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015.

    Adapted to the ordering conventions of:
      Batch normalization then ReLU then convolution as described by:
        Identity Mappings in Deep Residual Networks
        https://arxiv.org/pdf/1603.05027.pdf
        by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Jul 2016.

    Args:
      inputs: A tensor of size [batch, channels, height_in, width_in] or
        [batch, height_in, width_in, channels] depending on data_format.
      filters: The number of filters for the convolutions.
      training: A Boolean for whether the model is in training or inference
        mode. Needed for batch normalization.
      projection_shortcut: The function to use for projection shortcuts
        (typically a 1x1 convolution when downsampling the input).
      strides: The block's stride. If greater than 1, this block will
        ultimately downsample the input.
      data_format: The input format ('channels_last' or 'channels_first').

    Returns:
      The output tensor of the block; shape should match inputs.
    """
    shortcut = inputs
    inputs = batch_norm(inputs, training, data_format)
    inputs = tf.nn.relu(inputs)

    # The projection shortcut should come after the first batch norm and ReLU
    # since it performs a 1x1 convolution.
    if projection_shortcut is not None:
        shortcut = projection_shortcut(inputs)

    inputs = conv2d_fixed_padding(
        inputs=inputs,
        filters=filters,
        kernel_size=1,
        strides=1,
        data_format=data_format,
    )

    inputs = batch_norm(inputs, training, data_format)
    inputs = tf.nn.relu(inputs)
    inputs = conv2d_fixed_padding(
        inputs=inputs,
        filters=filters,
        kernel_size=3,
        strides=strides,
        data_format=data_format,
    )

    inputs = batch_norm(inputs, training, data_format)
    inputs = tf.nn.relu(inputs)
    inputs = conv2d_fixed_padding(
        inputs=inputs,
        filters=4 * filters,
        kernel_size=1,
        strides=1,
        data_format=data_format,
    )

    return inputs + shortcut


def block_layer(
    inputs, filters, bottleneck, block_fn, blocks, strides, training, name, data_format
):
    """Creates one layer of blocks for the ResNet model.

    Args:
      inputs: A tensor of size [batch, channels, height_in, width_in] or
        [batch, height_in, width_in, channels] depending on data_format.
      filters: The number of filters for the first convolution of the layer.
      bottleneck: Whether the blocks created are bottleneck blocks.
      block_fn: The block to use within the model, either `building_block` or
        `bottleneck_block`.
      blocks: The number of blocks contained in the layer.
      strides: The stride to use for the first convolution of the layer. If
        greater than 1, this layer will ultimately downsample the input.
      training: Either True or False, whether we are currently training the
        model. Needed for batch norm.
      name: A string name for the tensor output of the block layer.
      data_format: The input format ('channels_last' or 'channels_first').

    Returns:
      The output tensor of the block layer.
    """

    # Bottleneck blocks end with 4x the number of filters as they start with
    filters_out = filters * 4 if bottleneck else filters

    def projection_shortcut(inputs):
        return conv2d_fixed_padding(
            inputs=inputs,
            filters=filters_out,
            kernel_size=1,
            strides=strides,
            data_format=data_format,
        )

    # Only the first block per block_layer uses projection_shortcut and strides
    inputs = block_fn(
        inputs, filters, training, projection_shortcut, strides, data_format
    )

    for _ in range(1, blocks):
        inputs = block_fn(inputs, filters, training, None, 1, data_format)

    return tf.identity(inputs, name)


class Model(object):
    """Base class for building the Resnet Model."""

    def __init__(
        self,
        resnet_size,
        bottleneck,
        num_classes,
        num_filters,
        kernel_size,
        conv_stride,
        first_pool_size,
        first_pool_stride,
        second_pool_size,
        second_pool_stride,
        block_sizes,
        block_strides,
        final_size,
        version=DEFAULT_VERSION,
        data_format=None,
    ):
        """Creates a model for classifying an image.

        Args:
          resnet_size: A single integer for the size of the ResNet model.
          bottleneck: Use regular blocks or bottleneck blocks.
          num_classes: The number of classes used as labels.
          num_filters: The number of filters to use for the first block layer
            of the model. This number is then doubled for each subsequent
            block layer.
          kernel_size: The kernel size to use for convolution.
          conv_stride: Stride size for the initial convolutional layer.
          first_pool_size: Pool size to be used for the first pooling layer.
            If None, the first pooling layer is skipped.
          first_pool_stride: Stride size for the first pooling layer. Not used
            if first_pool_size is None.
          second_pool_size: Pool size to be used for the second pooling layer.
          second_pool_stride: Stride size for the final pooling layer.
          block_sizes: A list containing n values, where n is the number of
            sets of block layers desired. Each value should be the number of
            blocks in the i-th set.
          block_strides: List of integers representing the desired stride size
            for each of the sets of block layers. Should be same length as
            block_sizes.
          final_size: The expected size of the model after the second pooling.
          version: Integer representing which version of the ResNet network to
            use. See README for details. Valid values: [1, 2]
          data_format: Input format ('channels_last', 'channels_first', or
            None). If set to None, the format is dependent on whether a GPU
            is available.

        Raises:
          ValueError: if invalid version is selected.
        """
        self.resnet_size = resnet_size

        if not data_format:
            data_format = (
                "channels_first" if tf.test.is_built_with_cuda() else "channels_last"
            )

        self.resnet_version = version
        if version not in (1, 2):
            raise ValueError(
                "Resnet version should be 1 or 2. See README for citations."
            )

        self.bottleneck = bottleneck
        if bottleneck:
            if version == 1:
                self.block_fn = _bottleneck_block_v1
            else:
                self.block_fn = _bottleneck_block_v2
        else:
            if version == 1:
                self.block_fn = _building_block_v1
            else:
                self.block_fn = _building_block_v2

        self.data_format = data_format
        self.num_classes = num_classes
        self.num_filters = num_filters
        self.kernel_size = kernel_size
        self.conv_stride = conv_stride
        self.first_pool_size = first_pool_size
        self.first_pool_stride = first_pool_stride
        self.second_pool_size = second_pool_size
        self.second_pool_stride = second_pool_stride
        self.block_sizes = block_sizes
        self.block_strides = block_strides
        self.final_size = final_size

    def __call__(self, inputs, training):
        """Add operations to classify a batch of input images.

        Args:
          inputs: A Tensor representing a batch of input images.
          training: A boolean. Set to True to add operations required only
            when training the classifier.

        Returns:
          A logits Tensor with shape [<batch_size>, self.num_classes].
        """

        if self.data_format == "channels_first":
            # Convert the inputs from channels_last (NHWC) to channels_first (NCHW).
            # This provides a large performance boost on GPU. See
            # https://www.tensorflow.org/performance/performance_guide#data_formats
            inputs = tf.transpose(inputs, [0, 3, 1, 2])

        inputs = conv2d_fixed_padding(
            inputs=inputs,
            filters=self.num_filters,
            kernel_size=self.kernel_size,
            strides=self.conv_stride,
            data_format=self.data_format,
        )
        inputs = tf.identity(inputs, "initial_conv")

        if self.first_pool_size:
            inputs = tf.layers.max_pooling2d(
                inputs=inputs,
                pool_size=self.first_pool_size,
                strides=self.first_pool_stride,
                padding="SAME",
                data_format=self.data_format,
            )
            inputs = tf.identity(inputs, "initial_max_pool")

        for i, num_blocks in enumerate(self.block_sizes):
            num_filters = self.num_filters * (2**i)
            inputs = block_layer(
                inputs=inputs,
                filters=num_filters,
                bottleneck=self.bottleneck,
                block_fn=self.block_fn,
                blocks=num_blocks,
                strides=self.block_strides[i],
                training=training,
                name="block_layer{}".format(i + 1),
                data_format=self.data_format,
            )

        inputs = batch_norm(inputs, training, self.data_format)
        inputs = tf.nn.relu(inputs)

        # The current top layer has shape
        # `batch_size x pool_size x pool_size x final_size`.
        # ResNet does an Average Pooling layer over pool_size,
        # but that is the same as doing a reduce_mean. We do a reduce_mean
        # here because it performs better than AveragePooling2D.
        axes = [2, 3] if self.data_format == "channels_first" else [1, 2]
        inputs = tf.reduce_mean(inputs, axes, keepdims=True)
        inputs = tf.identity(inputs, "final_reduce_mean")

        inputs = tf.reshape(inputs, [-1, self.final_size])
        readout_layer = tf.layers.Dense(units=self.num_classes, name="readout_layer")
        inputs = readout_layer(inputs)
        inputs = tf.identity(inputs, "final_dense")

        return inputs
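
The checkpoints shipped under resnet/resnet18/checkpoints/ pair with this file. As a minimal sketch of how the Model class above could be configured for an 18-layer v2 network (the values follow He et al. [1]; the repo's actual construction lives in resnet/fmodel.py and is not reproduced here, so treat every argument below as an assumption):

# Hypothetical ResNet-18 (v2) configuration for the Model class above.
resnet18 = Model(
    resnet_size=18,
    bottleneck=False,           # basic (non-bottleneck) blocks
    num_classes=1000,           # assumption: an ImageNet-style head
    num_filters=64,
    kernel_size=7,
    conv_stride=2,
    first_pool_size=3,
    first_pool_stride=2,
    second_pool_size=7,
    second_pool_stride=1,
    block_sizes=[2, 2, 2, 2],   # two basic blocks per stage -> 18 layers
    block_strides=[1, 2, 2, 2],
    final_size=512,             # 64 * 2**3, the channel count after stage 4
    version=2,
)
logits = resnet18(inputs, training=True)  # inputs: a [batch, H, W, 3] tensor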
resnet/run.sh
ADDED
@@ -0,0 +1,2 @@
echo "Starting Server..."
python ./main.py
thundernet_config.py
ADDED
@@ -0,0 +1,18 @@
train_path: str = "C:/Users/user/Documents/pruned_training/training/"
val_path: str = "C:/Users/user/Documents/pruned_training/val/"
model_dir: str = "C:/Users/user/Documents/Thundernet/models/"
model_weights: str = (
    "C:/Users/user/Documents/Thundernet/model/BS4_lossBCE_weights_lr_0.00013713842558297858_reg-1.1743577101671763e-05-ep-13-val_loss0.11463435739278793-train_loss0.053004469722509384-val_iou0.8959722518920898-train_iou0.9606077075004578.hdf5"
)
batch_size: int = 4
augment: bool = False  # True
rand_crop: float = 0.05
loss: str = "BCE"
weights: list = None  # [0.56, 3.27]
classes: int = 2
pretrained_bool: bool = False
pretrained_weigths: str = None
lr: float = 1e-4
epochs: int = 15
resolution: str = "640x480"
kernel_regularizer: float = 2e-4
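
These module-level defaults are read by train_config.py (next file), both as argparse defaults and directly; any flag-backed value can be overridden per run on the command line, for example:

python train_config.py --batch_size 8 --loss DCL --lr 5e-5 --epochs 30

The flag names match the parser.add_argument calls below; values without a flag, such as train_path, still come from this file.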
train_config.py
ADDED
@@ -0,0 +1,311 @@
import os
from data_gen import DataGenerator
from os import listdir
from utils import (
    iou,
    PlotLosses,
    dice_loss,
    focal_loss,
    categorical_loss,
    categorical_focal_loss,
    resolution2framesize3cha,
    resolution2framesize,
    bce_loss,
)
import matplotlib.pyplot as plt
import tensorflow as tf

tf.config.run_functions_eagerly(True)
# from keras.backend.tensorflow_backend import set_session
import argparse
import sys
import numpy as np
import thundernet_config as Thundernet_config
from datetime import datetime

from model.model import Thundernet as Thundernet_original
from model.model_ppm_factors import Thundernet as Thundernet_ppm

from pathlib import Path
from collections import defaultdict
import copy

plt.switch_backend("agg")

parser = argparse.ArgumentParser()

parser.add_argument(
    "--train_dir",
    type=str,
    default=Thundernet_config.train_path,
    help="The directory containing the train image dataset.",
)

parser.add_argument(
    "--val_dir",
    type=str,
    default=Thundernet_config.val_path,
    help="The directory containing the validation image dataset.",
)

parser.add_argument(
    "--batch_size",
    type=int,
    default=Thundernet_config.batch_size,
    choices=[1, 2, 4, 8, 16],
    help="Batch size used for training Thundernet.",
)

parser.add_argument(
    "--augment",
    type=bool,
    default=Thundernet_config.augment,
    choices=[False, True],
    help="Whether to use color augmentation for training Thundernet.",
)

parser.add_argument(
    "--rand_crop",
    type=float,
    default=Thundernet_config.rand_crop,
    choices=[0, 0.02, 0.05, 0.1, 0.2, 0.5],
    help="Frequency of the random-crop data augmentation technique.",
)

parser.add_argument(
    "--loss",
    type=str,
    default=Thundernet_config.loss,
    choices=["BCE", "BFL", "CFL", "DCL", "FTL", "CAT"],
    help="Loss function to be used: Binary Cross Entropy (BCE), Binary Focal "
    "Loss (BFL), Categorical Focal Loss (CFL), Dice Coefficient Loss (DCL), "
    "Focal Tversky Loss (FTL) or Categorical Cross Entropy (CAT).",
)

parser.add_argument(
    "--model_dir",
    type=str,
    default=Thundernet_config.model_dir,
    help="Base directory for the model. "
    "Make sure 'model_checkpoint_path' given in 'checkpoint' file matches "
    "with checkpoint name.",
)

parser.add_argument(
    "--weights",
    type=dict,
    default=Thundernet_config.weights,
    help="Class weights used for Weighted Binary Cross Entropy Loss.",
)

parser.add_argument(
    "--lr", type=float, default=Thundernet_config.lr, help="Learning rate."
)

parser.add_argument(
    "--epochs", type=int, default=Thundernet_config.epochs, help="Epochs."
)

parser.add_argument(
    "--classes", type=int, default=Thundernet_config.classes, help="Number of classes."
)

parser.add_argument(
    "--resolution",
    type=str,
    default=Thundernet_config.resolution,
    help="Input resolution.",
)

parser.add_argument(
    "--kernel_regularizer",
    type=float,
    default=Thundernet_config.kernel_regularizer,
    help="Kernel regularizer factor.",
)

parser.add_argument(
    "--pretrained",
    type=bool,
    default=Thundernet_config.pretrained_bool,
    help="Whether to start training from pretrained weights.",
)

parser.add_argument(
    "--pretrained_weigths",
    type=str,
    default=Thundernet_config.pretrained_weigths,
    help="Path to the pretrained weights to start training from.",
)


def main(
    args: list,
    transformations: tuple = tuple(),
    model: str = "original",
    class_mappings: dict = None,
):
    """
    Train the model.
    Args:
        - args (list): list of arguments to parse
        - model (str): type of model. Default: "original"
        - class_mappings (dict): class mapper. Default: None
        - transformations (tuple): transformations to apply to the data. Default: tuple()
    Returns:
        - None
    """

    FLAGS = parser.parse_args(args)

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # use id from $ nvidia-smi

    mypath_train = FLAGS.train_dir + "images/"
    label_path_train = FLAGS.train_dir + "labels/"

    list_IDs_train = [f[:-4] for f in listdir(mypath_train) if f[-4:] == ".jpg"]
    mypath_val = FLAGS.val_dir + "images/"
    label_path_val = FLAGS.val_dir + "labels/"
    list_IDs_val = [f[:-4] for f in listdir(mypath_val) if f[-4:] == ".jpg"]

    # First we make sure that the dir for saving the experiments is created
    if not os.path.exists(FLAGS.model_dir):
        os.makedirs(FLAGS.model_dir)

    # For every trial of the same experiment we create a new subfolder
    k = 1
    dir_created = False
    while not dir_created:
        model_dir = FLAGS.model_dir + str(k) + "/"
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
            dir_created = True
        else:
            k += 1

    # Model
    if model == "original":
        Thundernet = Thundernet_original
    elif model == "ppm":
        Thundernet = Thundernet_ppm
    else:
        raise ValueError(f"Unknown model: {model}")

    # Class mappings
    if class_mappings is not None:
        FLAGS.classes = len(set(class_mappings.values())) + 1

    # Write the run configuration to model_dir
    file = open(model_dir + "config.txt", "w")
    file.write("Experiment num " + str(k) + "\n")
    file.write("Date=" + str(datetime.now()) + "\n")
    file.write("Train with=" + FLAGS.train_dir + "\n")
    file.write("Val with=" + FLAGS.val_dir + "\n")
    file.write("Input Resolution=" + FLAGS.resolution + "\n")
    file.write("Batch Size=" + str(FLAGS.batch_size) + "\n")
    file.write("Batch augment=" + str(FLAGS.augment) + "\n")
    file.write("Rand Crop=" + str(FLAGS.rand_crop) + "\n")
    file.write("Loss=" + FLAGS.loss + "\n")
    file.write("Model dir=" + FLAGS.model_dir + "\n")
    file.write("weights=" + str(FLAGS.weights) + "\n")
    file.write("lr=" + str(FLAGS.lr) + "\n")
    file.write("epochs=" + str(FLAGS.epochs) + "\n")
    file.write("classes=" + str(FLAGS.classes) + "\n")
    file.write("kernel_regularizer=" + str(FLAGS.kernel_regularizer) + "\n")
    file.write("pretrained=" + str(FLAGS.pretrained) + "\n")
    file.write("pretrained_weigths=" + str(FLAGS.pretrained_weigths) + "\n")
    file.write("Class mappings=" + str(class_mappings) + "\n")
    file.write("Model=" + model + "\n")
    file.write(f"Transformations: {transformations}\n")
    file.write("Comments=" + "" + "\n")
    file.close()

    print(
        "resolution2framesize3cha(FLAGS.resolution) ",
        resolution2framesize3cha(FLAGS.resolution),
    )
    thundernet = Thundernet(
        input_shape=resolution2framesize3cha(FLAGS.resolution),
        n_classes=FLAGS.classes,
        resnet_trainable=True,
        kernel_regularizer=FLAGS.kernel_regularizer,
    )

    if FLAGS.pretrained:
        print("loading weights from", FLAGS.pretrained_weigths)
        thundernet.model.load_weights(
            FLAGS.pretrained_weigths, by_name=True, skip_mismatch=True
        )

    lr = FLAGS.lr
    opt = tf.keras.optimizers.Adam(learning_rate=lr)  # for keras 2.6.0

    if not model_dir.endswith(os.path.sep):
        model_dir += os.path.sep

    callbacks = [
        PlotLosses(model_dir),
        tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.normpath(
                os.path.join(
                    model_dir,
                    f"BS{FLAGS.batch_size}_loss{FLAGS.loss}_weights_lr_{lr}_reg-{FLAGS.kernel_regularizer}-ep-{{epoch}}-val_loss{{val_loss}}-train_loss{{loss}}-val_iou{{val_iou}}-train_iou{{iou}}.hdf5",
                )
            ),
            save_best_only=True,
            save_weights_only=True,
        ),
    ]

    if FLAGS.loss == "BCE":
        loss = bce_loss()
    elif FLAGS.loss == "BFL":
        loss = focal_loss()
    elif FLAGS.loss == "DCL":
        loss = dice_loss()
    elif FLAGS.loss == "CFL":
        loss = categorical_focal_loss()
    elif FLAGS.loss == "CAT":
        loss = categorical_loss()

    thundernet.model.compile(loss=loss, optimizer=opt, metrics=[iou])

    dataset_dir = Path(Thundernet_config.train_path).parent

    training_generator, validation_generator = DataGenerator.create_generators(
        dataset_dir,
        FLAGS.classes,
        training_batch_size=Thundernet_config.batch_size,
        validation_batch_size=Thundernet_config.batch_size,
        to_stereo=False,
        transformations=transformations,
        class_mappings=class_mappings,
    )

    weights = FLAGS.weights  # note: currently unused; class_weight is passed as None below

    history = thundernet.model.fit_generator(
        generator=training_generator,
        validation_data=validation_generator,
        callbacks=callbacks,
        use_multiprocessing=False,
        workers=6,
        epochs=FLAGS.epochs,
        class_weight=None,
    )


if __name__ == "__main__":

    main(sys.argv[1:], model="original", class_mappings=defaultdict(int, {1: 1}))
    # main(sys.argv[1:], model="ppm", class_mappings=defaultdict(int, {1: 1}))
    # main(sys.argv[1:], model='original', class_mappings=defaultdict(int, {1: 1, 2: 2, 5: 3}))  # In case you also want to segment two specific types of objects (original class_id=2 and class_id=5)
    # main(sys.argv[1:], model='ppm', class_mappings=defaultdict(int, {1: 1, 2: 2, 5: 2}))  # In case you want to treat both objects as the same class
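
One caveat in the script above: the if/elif chain that maps FLAGS.loss to a loss factory from utils.py has no branch for the advertised "FTL" choice, so selecting it leaves loss bound to the raw string. A table-driven dispatch is a compact alternative that fails loudly on unhandled codes; this is a sketch of that pattern, not what the script currently does:

from utils import (
    bce_loss, focal_loss, dice_loss, categorical_focal_loss, categorical_loss,
)

# One factory per loss code accepted by the --loss flag.
LOSS_FACTORIES = {
    "BCE": bce_loss,
    "BFL": focal_loss,
    "DCL": dice_loss,
    "CFL": categorical_focal_loss,
    "CAT": categorical_loss,
}

def make_loss(code: str):
    """Build the loss for a flag value, raising on unknown codes like 'FTL'."""
    try:
        return LOSS_FACTORIES[code]()
    except KeyError:
        raise ValueError(f"Unsupported loss code: {code}") from None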
train_optuna.py
ADDED
@@ -0,0 +1,255 @@
import argparse
import copy
import os
import sys
from collections import defaultdict
from datetime import datetime
from os import listdir
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import optuna
import tensorflow as tf

import thundernet_config as Thundernet_config
# from data_gen_tfkeras import DataGenerator
from data_gen import DataGenerator
from model.model import Thundernet as Thundernet_original
from models_repo.model_attention import Thundernet as Thundernet_attention
from models_repo.model_attention_2 import Thundernet as Thundernet_attention2
from models_repo.model_ppm_factors import Thundernet as Thundernet_ppm
from utils import (
    iou,
    PlotLosses,
    dice_loss,
    focal_loss,
    categorical_loss,
    categorical_focal_loss,
    resolution2framesize3cha,
    resolution2framesize,
    bce_loss,
)

tf.config.run_functions_eagerly(True)
# from keras.backend.tensorflow_backend import set_session

plt.switch_backend("agg")


def objective(trial):
    # Define the hyperparameters to tune
    batch_size = trial.suggest_categorical("batch_size", [1, 2, 4])
    lr = trial.suggest_loguniform("lr", 1e-5, 1e-1)  # learning rate
    kernel_regularizer = trial.suggest_loguniform("kernel_regularizer", 1e-5, 1e-2)

    # Call the main function with the trial parameters
    return main(
        model="ppm",
        class_mappings=defaultdict(int, {1: 1}),
        batch_size=batch_size,
        lr=lr,
        kernel_regularizer=kernel_regularizer,
        epochs=1,  # run only one epoch per trial
        loss="BCE",
        transformations=(),  # add transformations as needed
    )


def main(
    model="original",
    class_mappings=None,
    batch_size=8,
    lr=1e-4,
    kernel_regularizer=0.001,
    epochs=1,
    loss="BCE",
    transformations=tuple(),
):

    # Collect the settings for this run in an argparse-style namespace
    FLAGS = argparse.Namespace(
        train_dir=Thundernet_config.train_path,
        val_dir=Thundernet_config.val_path,
        batch_size=batch_size,
        augment=Thundernet_config.augment,
        rand_crop=Thundernet_config.rand_crop,
        loss=loss,
        model_dir=Thundernet_config.model_dir,
        weights=Thundernet_config.weights,
        lr=lr,
        epochs=epochs,
        classes=Thundernet_config.classes,
        resolution=Thundernet_config.resolution,
        kernel_regularizer=kernel_regularizer,
        pretrained=Thundernet_config.pretrained_bool,
        pretrained_weigths=Thundernet_config.pretrained_weigths,
    )

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    mypath_train = FLAGS.train_dir + "images/"
    label_path_train = FLAGS.train_dir + "labels/"
    list_IDs_train = [f[:-4] for f in listdir(mypath_train) if f[-4:] == ".jpg"]
    mypath_val = FLAGS.val_dir + "images/"
    label_path_val = FLAGS.val_dir + "labels/"
    list_IDs_val = [f[:-4] for f in listdir(mypath_val) if f[-4:] == ".jpg"]

    # Model setup
    if model == "original":
        Thundernet = Thundernet_original
    elif model == "attention":
        Thundernet = Thundernet_attention
    elif model == "attention2":
        Thundernet = Thundernet_attention2
    elif model == "ppm":
        Thundernet = Thundernet_ppm
    else:
        raise ValueError(f"Unknown model: {model}")

    # Model directory setup
    model_dir = FLAGS.model_dir
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    thundernet = Thundernet(
        input_shape=resolution2framesize3cha(FLAGS.resolution),
        n_classes=FLAGS.classes,
        resnet_trainable=True,
        kernel_regularizer=FLAGS.kernel_regularizer,
    )

    if FLAGS.pretrained:
        thundernet.model.load_weights(
            FLAGS.pretrained_weigths, by_name=True, skip_mismatch=True
        )

    # Optimizer setup
    opt = tf.keras.optimizers.Adam(learning_rate=FLAGS.lr)

    # Set the loss function
    if FLAGS.loss == "BCE":
        loss = bce_loss()
    elif FLAGS.loss == "BFL":
        loss = focal_loss()
    elif FLAGS.loss == "DCL":
        loss = dice_loss()
    elif FLAGS.loss == "CFL":
        loss = categorical_focal_loss()
    elif FLAGS.loss == "CAT":
        loss = categorical_loss()

    thundernet.model.compile(loss=loss, optimizer=opt, metrics=[iou])

    # Data generators setup
    dataset_dir = Path(Thundernet_config.train_path).parent
    training_generator, validation_generator = DataGenerator.create_generators(
        dataset_dir,
        FLAGS.classes,
        training_batch_size=FLAGS.batch_size,
        to_stereo=False,
        transformations=transformations,
        class_mappings=class_mappings,
    )

    # Train the model
    history = thundernet.model.fit(
        training_generator,
        validation_data=validation_generator,
        epochs=FLAGS.epochs,
        class_weight=None,
        callbacks=[PlotLosses(model_dir)],
        use_multiprocessing=False,
        workers=6,
    )
    # Return the mean training IoU, which Optuna maximizes
    print(history)
    return np.mean(history.history["iou"])


# Optuna study setup
if __name__ == "__main__":
    study = optuna.create_study(
        direction="maximize", storage="sqlite:///db.sqlite3"
    )  # maximize the mean IoU returned by main()
    study.optimize(objective, n_trials=100)  # run 100 trials
    print("Best hyperparameters found: ", study.best_params)

    import optuna.visualization as vis

    # Save the parameter-importance plot
    fig = vis.plot_param_importances(study)
    fig.write_image("param_importance_IoU.png")

    # Save the optimization-history plot
    fig = vis.plot_optimization_history(study)
    fig.write_image("optimization_history_IoU.png")

    import pandas as pd

    # Assuming `study` is the Optuna study object
    df = study.trials_dataframe()

    df.to_csv("results_optuna_IoU.csv")

    # Plot learning rate vs the objective value (mean IoU)
    plt.figure(figsize=(8, 6))
    plt.scatter(df["params_lr"], df["value"], color="blue", alpha=0.7)
    plt.title("Learning Rate vs Loss")
    plt.xlabel("Learning Rate")
    plt.ylabel("Loss")
    plt.grid(True)
    plt.savefig("lr_vs_loss_IoU.png")
    plt.close()

    # Plot batch size vs the objective value
    plt.figure(figsize=(8, 6))
    plt.scatter(df["params_batch_size"], df["value"], color="green", alpha=0.7)
    plt.title("Batch size vs Loss")
    plt.xlabel("Batch size")
    plt.ylabel("Loss")
    plt.grid(True)
    plt.savefig("batch_size_vs_loss_IoU.png")
    plt.close()

    # Plot kernel regularizer vs the objective value
    plt.figure(figsize=(8, 6))
    plt.scatter(df["params_kernel_regularizer"], df["value"], color="red", alpha=0.7)
    plt.title("Kernel regularizer vs Loss")
    plt.xlabel("Kernel regularizer")
    plt.ylabel("Loss")
    plt.grid(True)
    plt.savefig("kernel_regularizer_vs_loss_IoU.png")
    plt.close()
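
Because the study is persisted to sqlite:///db.sqlite3, it can be reopened after the run for further analysis (or watched live with optuna-dashboard). A minimal sketch, assuming the auto-generated study name is the only one in that storage:

import optuna

storage = "sqlite:///db.sqlite3"
# create_study() above did not pass study_name, so Optuna generated one;
# recover it from the storage before loading.
name = optuna.get_all_study_summaries(storage=storage)[0].study_name
study = optuna.load_study(study_name=name, storage=storage)
print(study.best_trial.params)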
utils.py
ADDED
@@ -0,0 +1,505 @@
import numpy as np
from tensorflow.keras import backend as K
import tensorflow.keras as keras
import math
from matplotlib import pyplot as plt
import cv2
import time
import scipy
from os import listdir
from IPython.display import clear_output
import segmentation_models as sm
from PIL import Image
import images_toolkit as tlk


def dice_coef(y_true, y_pred, smooth=1):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)


def dice_loss(alpha=1):
    def dice_coef_loss(y_true, y_pred):
        return 1 - alpha * dice_coef(y_true, y_pred)

    return dice_coef_loss


def categorical_loss():
    def categorical(y_true, y_pred):
        return keras.losses.CategoricalCrossentropy()(y_true, y_pred)

    return categorical


def bce_loss():
    def bce(y_true, y_pred):
        return keras.losses.BinaryCrossentropy()(y_true, y_pred)

    return bce


def tversky(y_true, y_pred, smooth=1, alpha=0.7):
    y_true_pos = K.flatten(y_true)
    y_pred_pos = K.flatten(y_pred)
    true_pos = K.sum(y_true_pos * y_pred_pos)
    false_neg = K.sum(y_true_pos * (1 - y_pred_pos))
    false_pos = K.sum((1 - y_true_pos) * y_pred_pos)
    return (true_pos + smooth) / (
        true_pos + alpha * false_neg + (1 - alpha) * false_pos + smooth
    )


def tversky_loss(y_true, y_pred):
    return 1 - tversky(y_true, y_pred)


# def focal_tversky_loss(y_true, y_pred, gamma=0.75):
#     tv = tversky(y_true, y_pred)
#     return K.pow((1 - tv), gamma)


def categorical_focal_loss(gamma=2.0, alpha=0.25):
    def cate_focal_loss(y_true, y_pred):
        CAT_FL = sm.losses.categorical_focal_loss
        CAT_FL.gamma = gamma
        CAT_FL.alpha = alpha
        return CAT_FL(y_true, y_pred)

    return cate_focal_loss


def focal_loss(gamma=2.0, alpha=0.7):
    def focal_tversky_loss(y_true, y_pred):
        # Pass alpha by keyword: the original positional call
        # tversky(y_true, y_pred, alpha) would bind alpha to `smooth`.
        tv = tversky(y_true, y_pred, alpha=alpha)
        return K.pow((1 - tv), gamma)

    return focal_tversky_loss

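# Added note: the Tversky index computed by tversky() generalizes the Dice
# coefficient:
#     TI = (TP + s) / (TP + alpha * FN + (1 - alpha) * FP + s)
# alpha = 0.5 recovers Dice, while the default alpha = 0.7 penalizes false
# negatives more heavily (recall-oriented). focal_loss() then returns
# (1 - TI) ** gamma, which down-weights well-segmented (high-TI) examples.
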
| 81 |
+
|
| 82 |
+

def single_iou(y_true, y_pred, label: int):
    """
    Return the Intersection over Union (IoU) for a given label.
    Args:
        y_true: the expected y values as a one-hot
        y_pred: the predicted y values as a one-hot or softmax output
        label: the label to return the IoU for
    Returns:
        the IoU for the given label
    """
    # extract the label values using the argmax operator, then
    # compare predictions and ground truth against the label
    y_true = K.cast(K.equal(K.argmax(y_true), label), K.floatx())
    y_pred = K.cast(K.equal(K.argmax(y_pred), label), K.floatx())

    # calculate the |intersection| (AND) of the labels
    intersection = K.sum(y_true * y_pred)
    # calculate the |union| (OR) of the labels
    union = K.sum(y_true) + K.sum(y_pred) - intersection
    # avoid divide by zero - if the union is zero, return 1,
    # otherwise return the intersection over union
    return K.switch(K.equal(union, 0), 1.0, intersection / union)


def iou(y_true, y_pred):
    """
    Return the Intersection over Union (IoU) score.
    Args:
        y_true: the expected y values as a one-hot
        y_pred: the predicted y values as a one-hot or softmax output
    Returns:
        the scalar IoU value (mean over all labels)
    """
    # get number of labels to calculate IoU for
    num_labels = K.int_shape(y_pred)[-1]
    # accumulate the total IoU over all labels
    total_iou = K.variable(0)
    for label in range(num_labels):
        total_iou = total_iou + single_iou(y_true, y_pred, label)
    # divide total IoU by number of labels to get mean IoU
    return total_iou / num_labels

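
# Usage sketch (illustrative assumption): `iou` has the standard Keras metric
# signature, so it can be logged during training alongside the loss, e.g.
#
#   model.compile(optimizer="adam", loss=dice_loss(), metrics=[iou])
#
# which is what makes the "iou"/"val_iou" keys read by PlotLosses below
# appear in the training logs.
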

def simple_iou(gt, pred):
    """Computes IoU for a binary classified image. Input shapes: (h, w)"""
    # nan_to_num maps the 0/0 case (empty ground truth and prediction) to 0
    return np.nan_to_num(
        np.sum((gt == 1) & (pred == 1)) / np.sum((gt == 1) | (pred == 1)), nan=0
    )


def simple_iou_for_multiple_classes(gt, pred, n_classes):
    """Computes per-class IoU for a categorically classified image. Input shapes: (h, w)
    If n_classes > 3, it also computes the IoU of the union of all classes
    that are >= 2 (i.e., the IoU of all objects merged into one).
    Returns: array of n_classes IoU values if n_classes <= 3,
             array of n_classes + 1 IoU values if n_classes > 3
    """
    assert gt.shape == pred.shape and gt.ndim == 2
    assert np.max(gt) < n_classes and np.max(pred) < n_classes

    # one-hot encode both maps so per-class intersections and unions
    # reduce to column-wise sums
    f_gt = gt.flatten()
    f_pred = pred.flatten()
    gt_matrix = np.zeros((f_gt.size, n_classes), dtype=int)
    pred_matrix = gt_matrix.copy()
    gt_matrix[np.arange(f_gt.size), f_gt] = 1
    pred_matrix[np.arange(f_gt.size), f_pred] = 1
    intersections = np.sum((gt_matrix == 1) & (pred_matrix == 1), axis=0)
    unions = np.sum((gt_matrix == 1) | (pred_matrix == 1), axis=0)
    ious = intersections / unions
    if n_classes > 3:
        gt_as_one = f_gt >= 2
        pred_as_one = f_pred >= 2
        iou_as_one = np.sum(gt_as_one & pred_as_one) / np.sum(gt_as_one | pred_as_one)
        return np.append(ious, iou_as_one)
    else:
        return ious

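
# Worked example (illustrative): with n_classes=2,
#   gt   = np.array([[0, 1], [1, 1]])
#   pred = np.array([[0, 1], [0, 1]])
# class 0 has intersection 1 and union 2 (IoU 0.5); class 1 has
# intersection 2 and union 3 (IoU ~0.667), so
# simple_iou_for_multiple_classes(gt, pred, 2) ~= [0.5, 0.667].
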

def add_mask(image, mask):
    """Overlays a segmentation mask on a BGR uint8 image: the mask brightens
    the green channel and is also returned as the alpha channel."""
    b_channel, g_channel, r_channel = cv2.split(image)
    alpha_channel = mask * 255

    alpha_channel = alpha_channel.astype(np.float64)

    g_channel_out = np.clip(np.add(alpha_channel, g_channel), 0, 255)
    g_channel_out = g_channel_out.astype(np.uint8)

    alpha_channel = alpha_channel.astype(np.uint8)
    img_BGRA = cv2.merge((b_channel, g_channel_out, r_channel, alpha_channel))
    image_RGBA = cv2.cvtColor(img_BGRA, cv2.COLOR_BGRA2RGBA)

    return image_RGBA, alpha_channel


# Supported "WxH" resolution strings mapped to (h, w) frame sizes.
_FRAMESIZES = {
    "640x240": (240, 640),
    "640x480": (480, 640),
    "1280x480": (480, 1280),
    "1280x720": (720, 1280),
    "960x540": (540, 960),
    "320x240": (240, 320),
    "1024x768": (768, 1024),
    "2560x960": (960, 2560),
    "2560x720": (720, 2560),
}


def resolution2framesize(resolution):
    """Maps a "WxH" resolution string to an (h, w) frame size,
    e.g. "640x480" -> (480, 640)."""
    try:
        return _FRAMESIZES[resolution]
    except KeyError:
        raise ValueError("Unsupported resolution: " + resolution)


def resolution2framesize3cha(resolution):
    """Maps a "WxH" resolution string to an (h, w, 3) frame size,
    e.g. "640x480" -> (480, 640, 3)."""
    return resolution2framesize(resolution) + (3,)


def webcam_test(model):
    # Device index 2 matches the original setup; adjust for other machines.
    cap = cv2.VideoCapture(2)

    while True:
        # Capture a frame from the camera
        ret, frame = cap.read()
        if not ret:
            break
        print(frame.shape)

        # normalise to [0, 1] and add a batch dimension (assumes a 640x480 camera)
        frame = np.array(frame) / 255.0
        x = np.reshape(frame, (1, 480, 640, 3))

        start_t = time.time()
        pred = model.predict(x)
        duration = time.time() - start_t
        pred = pred[0, :, :, :]
        pred = np.argmax(pred, 2)

        print(pred.shape)
        # add_mask expects a uint8 BGR image and returns (overlay, alpha)
        overlap, _ = add_mask((frame * 255).astype(np.uint8), pred)

        print(duration)

        cv2.imshow("Overlap", overlap)
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break

    cap.release()
    cv2.destroyAllWindows()


def image_test(model, img_dir, img_num, label_dir=None):
    list_IDs = [f[:-4] for f in listdir(img_dir) if f[-4:] == ".jpg"]
    img_path = img_dir + list_IDs[img_num] + ".jpg"
    test_img = cv2.imread(img_path) / 255.0
    test_img = np.reshape(test_img, (1, test_img.shape[0], test_img.shape[1], 3))

    pred = model.predict(test_img)
    pred = pred[0, :, :, :]
    predict = np.argmax(pred, 2)

    # add_mask expects a uint8 BGR image and returns (overlay, alpha)
    overlapping, _ = add_mask((test_img[0, :, :, :] * 255).astype(np.uint8), predict)

    cv2.imshow("Prediction", overlapping)
    cv2.imwrite(
        "./models_repo/frozen_resnet/Trial11/prediction_" + str(img_num) + ".png",
        overlapping,
    )

    if label_dir is not None:
        lab = label_dir + list_IDs[img_num] + ".png"

        lab_img = cv2.imread(lab) * 255
        cv2.imshow("Label", lab_img)
        cv2.imwrite(
            "./models_repo/frozen_resnet/Trial10/label_" + str(img_num) + ".png",
            lab_img,
        )

    cv2.waitKey(0)


class PlotLosses(keras.callbacks.Callback):
    """Keras callback that saves running loss/IoU curves as PNGs in out_dir."""

    def __init__(self, out_dir):
        super().__init__()
        self.out_dir = out_dir

    def on_train_begin(self, logs=None):
        self.i = 0
        self.x = []
        self.losses = []
        self.val_losses = []
        self.train_iou = []
        self.val_iou = []

        self.live_loss = []
        self.live_iou = []

        self.logs = []
        self.live_logs = []
        self.b = 0
        self.x_b = []
        self.loss = 0
        self.iou = 0
        self.num = 0

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.iou += logs.get("iou")
        self.loss += logs.get("loss")
        self.num += 1
        # every 50 batches, refresh the running-average training curves
        if self.b % 50 == 0:
            self.x_b.append(self.num)
            self.live_loss.append(self.loss / float(self.b + 1))
            self.live_iou.append(self.iou / float(self.b + 1))
            clear_output(wait=True)
            plt.ioff()
            fig1 = plt.figure(1)
            plt.plot(self.x_b, self.live_loss, label="Training loss")
            plt.title("Training loss")
            plt.xlabel("Iteration")
            plt.ylabel("Loss")
            plt.savefig(self.out_dir + "training_loss.png")
            plt.close(fig1)
            clear_output(wait=True)
            fig2 = plt.figure(2)
            plt.plot(self.x_b, self.live_iou, label="Training iou")
            plt.title("Training IoU")
            plt.xlabel("Iteration")
            plt.ylabel("IoU")
            plt.savefig(self.out_dir + "training_iou.png")
            plt.close(fig2)
        self.b += 1

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # reset the per-epoch batch accumulators
        self.loss = 0
        self.iou = 0
        self.b = 0
        self.logs.append(logs)
        self.x.append(self.i)
        self.losses.append(logs.get("loss"))
        self.val_losses.append(logs.get("val_loss"))
        self.i += 1
        self.train_iou.append(logs.get("iou"))
        self.val_iou.append(logs.get("val_iou"))
        plt.ioff()
        fig3 = plt.figure(3)
        clear_output(wait=True)
        plt.plot(self.x, self.losses, label="loss")
        plt.plot(self.x, self.val_losses, label="val_loss")
        plt.title("Loss curve")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.legend()
        plt.savefig(self.out_dir + "loss_curve.png")
        plt.close(fig3)
        fig4 = plt.figure(4)
        clear_output(wait=True)
        plt.plot(self.x, self.train_iou, label="train_iou")
        plt.plot(self.x, self.val_iou, label="val_iou")
        plt.title("IoU curve")
        plt.xlabel("Epoch")
        plt.ylabel("IoU")
        plt.legend()
        plt.savefig(self.out_dir + "mean_iou_curve.png")
        plt.close(fig4)

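
# Usage sketch (an assumption about the training scripts, not part of this
# file): pass the callback to fit(), together with the metric that feeds it:
#
#   plot_losses = PlotLosses(out_dir="./out/")
#   model.fit(train_gen, validation_data=val_gen, epochs=50,
#             callbacks=[plot_losses])
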

def step_decay(epoch):
    # halve an initial learning rate of 0.1 every 10 epochs
    initial_lrate = 0.1
    drop = 0.5
    epochs_drop = 10.0
    lrate = initial_lrate * math.pow(drop, math.floor(epoch / epochs_drop))
    return lrate

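
# Usage sketch (illustrative assumption): step_decay has the signature
# expected by keras.callbacks.LearningRateScheduler, e.g.
#
#   lr_schedule = keras.callbacks.LearningRateScheduler(step_decay)
#   model.fit(train_gen, epochs=50, callbacks=[lr_schedule])
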

label_colours = [
    (0, 0, 0),  # 0=background
    # 1=wall, 2=floor, 3=cabinet, 4=bed, 5=chair
    (128, 0, 0),
    (0, 128, 0),
    (128, 128, 0),
    (0, 0, 128),
    (128, 0, 128),
    # 6=sofa, 7=table, 8=door, 9=window, 10=bookshelf
    (0, 128, 128),
    (128, 128, 128),
    (255, 200, 180),
    (192, 0, 0),
    (192, 192, 192),
    # 11=picture, 12=counter, 13=blinds, 14=desk, 15=shelves
    (192, 128, 0),
    (64, 0, 128),
    (192, 0, 128),
    (255, 128, 0),
    (192, 128, 128),
    # 16=curtain, 17=dresser, 18=pillow, 19=mirror, 20=floor_mat, 21=clothes
    (0, 64, 0),
    (128, 64, 0),
    (0, 192, 0),
    (153, 153, 255),
    (0, 64, 128),
    (255, 255, 0),
    # 22=ceiling, 23=books, 24=fridge, 25=tv, 26=paper, 27=towel
    (250, 250, 250),
    (0, 192, 128),
    (250, 102, 250),
    (102, 250, 250),
    (44, 166, 44),
    (44, 44, 166),
    # 28=shower_curtain, 29=box, 30=whiteboard, 31=person, 32=night_stand, 33=toilet
    (166, 44, 44),
    (0, 250, 0),
    (250, 0, 0),
    (0, 0, 250),
    (206, 219, 156),
    (219, 156, 206),
    # 34=sink, 35=lamp, 36=bathtub, 37=bag, 38=Unknown
    (156, 206, 219),
    (23, 190, 207),
    (207, 23, 190),
    (190, 207, 23),
    (153, 0, 76),
]

# Binary alternative (background / hand):
# label_colours = [(0, 0, 0),  # 0=background
#                  (128, 0, 0)]  # 1=hand


def decode_labels(mask, num_classes=38):
    """Decode a batch of segmentation masks into an RGB image.

    Args:
        mask: network output of shape (n, h, w, num_classes); only the first
            image in the batch is decoded.
        num_classes: number of classes to predict (including background).

    Returns:
        An (h, w, 3) uint8 RGB image colour-coded with label_colours.
    """
    n, h, w, c = mask.shape
    binary = np.zeros((h, w), dtype=np.uint8)
    R = np.zeros((h, w), dtype=np.uint8)
    G = np.zeros((h, w), dtype=np.uint8)
    B = np.zeros((h, w), dtype=np.uint8)

    for i in range(0, num_classes):
        # threshold the per-class probability map into a binary mask
        binary[mask[0, :, :, i] >= 0.5] = 1
        binary[mask[0, :, :, i] < 0.5] = 0

        # paint the class colour onto the three channel planes
        R += np.multiply(binary, label_colours[i][0] * np.ones([h, w])).astype(np.uint8)
        G += np.multiply(binary, label_colours[i][1] * np.ones([h, w])).astype(np.uint8)
        B += np.multiply(binary, label_colours[i][2] * np.ones([h, w])).astype(np.uint8)

    # stack the channel planes into an (h, w, 3) image
    R_ = R.reshape(*R.shape, 1)
    G_ = G.reshape(*G.shape, 1)
    B_ = B.reshape(*B.shape, 1)
    outputs = np.concatenate((R_, G_, B_), axis=2)
    return outputs
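

# Usage sketch (illustrative assumption): colour-code a single prediction
# after model.predict():
#
#   pred = model.predict(x)            # shape (1, h, w, num_classes)
#   rgb = decode_labels(pred, num_classes=38)
#   Image.fromarray(rgb).save("prediction_rgb.png")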