Commit 060b41f · Parent: ead9f2a
gabgrenier committed: added harmonizer
This view is limited to 50 files because the commit contains too many changes.
- csai.py +34 -1
- harmonizer/.gitignore +101 -0
- harmonizer/README.md +101 -0
- harmonizer/pretrained/README.md +3 -0
- harmonizer/src/__init__.py +0 -0
- harmonizer/src/model/__init__.py +2 -0
- harmonizer/src/model/backbone/__init__.py +1 -0
- harmonizer/src/model/backbone/efficientnet/__init__.py +116 -0
- harmonizer/src/model/backbone/efficientnet/model.py +395 -0
- harmonizer/src/model/backbone/efficientnet/utils.py +586 -0
- harmonizer/src/model/enhancer.py +40 -0
- harmonizer/src/model/filter.py +231 -0
- harmonizer/src/model/harmonizer.py +44 -0
- harmonizer/src/model/module.py +80 -0
- harmonizer/src/requirements.txt +6 -0
- harmonizer/src/train/README.md +14 -0
- harmonizer/src/train/harmonizer/__init__.py +0 -0
- harmonizer/src/train/harmonizer/criterion.py +47 -0
- harmonizer/src/train/harmonizer/data.py +198 -0
- harmonizer/src/train/harmonizer/func.py +41 -0
- harmonizer/src/train/harmonizer/model.py +41 -0
- harmonizer/src/train/harmonizer/module/__init__.py +1 -0
- harmonizer/src/train/harmonizer/module/backbone/__init__.py +1 -0
- harmonizer/src/train/harmonizer/module/backbone/efficientnet/__init__.py +116 -0
- harmonizer/src/train/harmonizer/module/backbone/efficientnet/model.py +395 -0
- harmonizer/src/train/harmonizer/module/backbone/efficientnet/utils.py +586 -0
- harmonizer/src/train/harmonizer/module/filter.py +231 -0
- harmonizer/src/train/harmonizer/module/harmonizer.py +83 -0
- harmonizer/src/train/harmonizer/module/module.py +80 -0
- harmonizer/src/train/harmonizer/proxy.py +20 -0
- harmonizer/src/train/harmonizer/script/train.py +85 -0
- harmonizer/src/train/harmonizer/trainer.py +322 -0
- harmonizer/src/train/torchtask/__init__.py +9 -0
- harmonizer/src/train/torchtask/nn/__init__.py +3 -0
- harmonizer/src/train/torchtask/nn/data.py +190 -0
- harmonizer/src/train/torchtask/nn/func.py +99 -0
- harmonizer/src/train/torchtask/nn/lrer.py +179 -0
- harmonizer/src/train/torchtask/nn/module/__init__.py +3 -0
- harmonizer/src/train/torchtask/nn/module/gaussian_blur.py +64 -0
- harmonizer/src/train/torchtask/nn/module/gaussian_noise.py +40 -0
- harmonizer/src/train/torchtask/nn/module/third_party/__init__.py +1 -0
- harmonizer/src/train/torchtask/nn/module/third_party/sync_batchnorm/__init__.py +12 -0
- harmonizer/src/train/torchtask/nn/module/third_party/sync_batchnorm/batchnorm.py +282 -0
- harmonizer/src/train/torchtask/nn/module/third_party/sync_batchnorm/comm.py +129 -0
- harmonizer/src/train/torchtask/nn/module/third_party/sync_batchnorm/replicate.py +88 -0
- harmonizer/src/train/torchtask/nn/module/third_party/sync_batchnorm/unittest.py +29 -0
- harmonizer/src/train/torchtask/nn/optimizer.py +247 -0
- harmonizer/src/train/torchtask/requirements.txt +5 -0
- harmonizer/src/train/torchtask/runner.py +33 -0
- harmonizer/src/train/torchtask/template/__init__.py +16 -0
csai.py
CHANGED
@@ -67,9 +67,42 @@ def process(fg, bg):
 
     # Use the final_mask_img when pasting
     bg.paste(fg, (0, 0), final_mask_img)
+
+    # now run the harmonizer to make sure the foreground matches the background
+    harmonized = harmonizer(bg, final_mask_img)
+
+    if harmonized != None:
+        return harmonized
+    else:
+        return bg
 
-    return bg
 
+def harmonizer(comp, mask):
+    try:
+        import torchvision.transforms.functional as tf
+        from harmonizer.src import model
+        harmonizer = model.Harmonizer()
+        harmonizer = harmonizer.cuda()
+        harmonizer.load_state_dict(torch.load("harmonizer/pretrained/harmonizer.pth"), strict=True)
+        harmonizer.eval()
+
+        comp = tf.to_tensor(comp)[None, ...]
+        mask = tf.to_tensor(mask)[None, ...]
+        comp = comp.cuda()
+        mask = mask.cuda()
+
+        with torch.no_grad():
+            arguments = harmonizer.predict_arguments(comp, mask)
+            harmonized = harmonizer.restore_image(comp, mask, arguments)[-1]
+
+        harmonized = np.transpose(harmonized[0].cpu().numpy(), (1, 2, 0)) * 255
+        harmonized = Image.fromarray(harmonized.astype(np.uint8))
+
+        return harmonized
+
+    except:
+        return None
+
 
 def rvm(fg):
     model = MattingNetwork('mobilenetv3').eval().cuda()  # or "resnet50"
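For context, a minimal sketch of how the new `harmonizer()` helper could be exercised on its own, assuming `comp` is the pasted RGB composite and `mask` the paste alpha, both PIL images as in the diff. The file names below are illustrative placeholders. Note two rough edges in the committed code: `harmonized != None` would more idiomatically be `harmonized is not None`, and the bare `except:` silently swallows every failure (including a missing `harmonizer/pretrained/harmonizer.pth`), in which case the un-harmonized composite is returned.

```python
# Hedged usage sketch of the harmonizer() helper added in this commit.
# File names here are placeholders, not part of the repository.
from PIL import Image

comp = Image.open("composite.png").convert("RGB")      # background with foreground pasted in
mask = Image.open("foreground_mask.png").convert("L")  # alpha mask used for the paste

result = harmonizer(comp, mask)  # PIL.Image on success, None if loading/inference failed
if result is not None:
    result.save("harmonized.png")
```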
harmonizer/.gitignore
ADDED
@@ -0,0 +1,101 @@
+# Temporary directories and files
+*.ckpt
+*.pth
+*.zip
+*.tar
+result/
+dataset/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*,cover
+.hypothesis/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# IPython Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# dotenv
+.env
+
+# virtualenv
+venv/
+ENV/
+
+# Spyder project settings
+.spyderproject
+
+# Rope project settings
+.ropeproject
+
+
+# Project files
+.vscode
harmonizer/README.md
ADDED
@@ -0,0 +1,101 @@
+<h2 align="center">Harmonizer: High-Resolution Image/Video Harmonization</h2>
+
+<p align="center"><i>Harmonizer: Learning to Perform White-Box
+Image and Video Harmonization (ECCV 2022)</i></p>
+
+<p align="center">
+  <a href="https://arxiv.org/abs/2207.01322">Paper</a> |
+  <a href="#demo">Demo</a> |
+  <a href="#code">Code</a> |
+  <a href="#license">License</a> |
+  <a href="#citation">Citation</a> |
+  <a href="#contact">Contact</a>
+</p>
+
+<p align="center">
+  <a href="https://youtu.be/kKKK3D1f_Mc">Harmonizer Result Video</a> |
+  <a href="https://youtu.be/NS8f-eJY9cc">Enhancer Result Video</a>
+</p>
+
+<div align="center"><b>Harmonizer</b> is a <b>lightweight (20MB)</b> model that enables image/video harmonization up to <b>8K</b> resolution.</div>
+<div align="center">With GPUs, Harmonizer has <b>real-time</b> performance at <b>Full HD</b> resolution.</div>
+<img src="doc/gif/harmonizer.gif" width="100%">
+
+<div align="center"><b>Enhancer</b> is a model that applies the Harmonizer architecture to image/video color enhancement.</div>
+<img src="doc/gif/enhancer.gif" width="100%">
+
+---
+
+## Demo
+
+In our demos, the <b>Harmonizer</b> model is trained on the *iHarmony4* dataset, while the <b>Enhancer</b> model is trained on the *FiveK + HDRPlus* datasets.
+
+### Online Demo
+Try our online demos for fun without code!
+
+| Image Harmonization | Image Enhancement |
+| :---: | :---: |
+| [Online Demo](https://zhke.io/?harmonizer_demo) | [Online Demo](https://zhke.io/?enhancer_demo) |
+
+<img src="doc/gif/online_demo.gif" width="100%">
+
+### Offline Demo
+We provide offline demos for image/video harmonization/enhancement.
+
+| Image Harmonization | Video Harmonization | Image Enhancement | Video Enhancement |
+| :---: | :---: | :---: | :---: |
+| [Offline Demo](demo/image_harmonization) | [Offline Demo](demo/video_harmonization) | [Offline Demo](demo/image_enhancement) | [Offline Demo](demo/video_enhancement) |
+
+
+## Code
+
+### Training
+
+The training code is released in the folder `./src/train`.
+Refer to [README.md](src/train/README.md) for more details about training.
+
+
+### Validation
+
+We provide PyTorch validation code to reproduce the iHarmony4 results reported in our [paper](https://arxiv.org/abs/2207.01322). Please:
+
+1. Download the Harmonizer model pre-trained on the iHarmony4 dataset from [this link](https://drive.google.com/file/d/15XGPQHBppaYGnhsP9l7iOGZudXNw1WbA/view?usp=sharing) and put it in the folder `./pretrained`.
+
+2. Download the four subsets of iHarmony4 from [this repository](https://github.com/bcmi/Image-Harmonization-Dataset-iHarmony4) and put them in the folder `./dataset/harmonization/iHarmony4`.
+
+3. Install python requirements. In the root path of this repository, run:
+```
+pip install -r src/requirements.txt
+```
+
+4. For validation, in the root path of this repository, run:
+```
+python -m src.val_harmonizer \
+    --pretrained ./pretrained/harmonizer \
+    --datasets HCOCO HFlickr HAdobe5k Hday2night \
+    --metric-size 256
+```
+- You can change `--datasets` to validate a specific subset.
+- You can remove `--metric-size` to calculate the metrics without resizing the outputs.
+- The metric values may differ slightly from our [paper](https://arxiv.org/abs/2207.01322) due to dependency versions.
+
+## License
+This project is released under the [Creative Commons Attribution NonCommercial ShareAlike 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) license.
+
+
+## Citation
+If this work helps your research, please consider citing:
+
+```bibtex
+@InProceedings{Harmonizer,
+  author = {Zhanghan Ke and Chunyi Sun and Lei Zhu and Ke Xu and Rynson W.H. Lau},
+  title = {Harmonizer: Learning to Perform White-Box Image and Video Harmonization},
+  booktitle = {European Conference on Computer Vision (ECCV)},
+  year = {2022},
+}
+```
+
+
+## Contact
+This repository is maintained by Zhanghan Ke ([@ZHKKKe](https://github.com/ZHKKKe)).
+For questions, please contact `kezhanghan@outlook.com`.
harmonizer/pretrained/README.md
ADDED
@@ -0,0 +1,3 @@
+## Harmonizer - Pre-Trained Models
+This folder is used to save the official pre-trained models of Harmonizer/Enhancer.
+You can download them from [this link](https://drive.google.com/drive/folders/1k7TCcwETeF5SYoD2Ic211UQyV1lwIBHY?usp=sharing).
harmonizer/src/__init__.py
ADDED
(empty file)
harmonizer/src/model/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .harmonizer import Harmonizer
+from .enhancer import Enhancer
harmonizer/src/model/backbone/__init__.py
ADDED
@@ -0,0 +1 @@
+from .efficientnet import EfficientBackbone, EfficientBackboneCommon
harmonizer/src/model/backbone/efficientnet/__init__.py
ADDED
@@ -0,0 +1,116 @@
+"""
+This EfficientNet implementation comes from:
+    Author: lukemelas (github username)
+    Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
+"""
+
+import torch
+import torch.nn as nn
+
+from .model import EfficientNet
+from .utils import round_filters, get_same_padding_conv2d
+
+
+# for EfficientNet
+class EfficientBackbone(EfficientNet):
+    def __init__(self, blocks_args=None, global_params=None):
+        super(EfficientBackbone, self).__init__(blocks_args, global_params)
+
+        self.enc_channels = [16, 24, 40, 112, 1280]
+
+        # ------------------------------------------------------------
+        # delete the useless layers
+        # ------------------------------------------------------------
+        del self._conv_stem
+        del self._bn0
+        # ------------------------------------------------------------
+
+        # ------------------------------------------------------------
+        # parameters for the input layers
+        # ------------------------------------------------------------
+        bn_mom = 1 - self._global_params.batch_norm_momentum
+        bn_eps = self._global_params.batch_norm_epsilon
+
+        in_channels = 4
+        out_channels = round_filters(32, self._global_params)
+        out_channels = int(out_channels / 2)
+        # ------------------------------------------------------------
+
+        # ------------------------------------------------------------
+        # define the input layers
+        # ------------------------------------------------------------
+        image_size = global_params.image_size
+        Conv2d = get_same_padding_conv2d(image_size=image_size)
+        self._conv_fg = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+        self._bn_fg = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+
+        self._conv_bg = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+        self._bn_bg = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+        # ------------------------------------------------------------
+
+    def forward(self, xfg, xbg):
+        xfg = self._swish(self._bn_fg(self._conv_fg(xfg)))
+        xbg = self._swish(self._bn_bg(self._conv_bg(xbg)))
+
+        x = torch.cat((xfg, xbg), dim=1)
+
+        block_outputs = []
+        for idx, block in enumerate(self._blocks):
+            drop_connect_rate = self._global_params.drop_connect_rate
+            drop_connect_rate *= float(idx) / len(self._blocks)
+            x = block(x, drop_connect_rate=drop_connect_rate)
+            block_outputs.append(x)
+
+        # Head
+        x = self._swish(self._bn1(self._conv_head(x)))
+
+        return block_outputs[0], block_outputs[2], block_outputs[4], block_outputs[10], x
+
+
+# for EfficientNet
+class EfficientBackboneCommon(EfficientNet):
+    def __init__(self, blocks_args=None, global_params=None):
+        super(EfficientBackboneCommon, self).__init__(blocks_args, global_params)
+
+        self.enc_channels = [16, 24, 40, 112, 1280]
+
+        # ------------------------------------------------------------
+        # delete the useless layers
+        # ------------------------------------------------------------
+        del self._conv_stem
+        del self._bn0
+        # ------------------------------------------------------------
+
+        # ------------------------------------------------------------
+        # parameters for the input layers
+        # ------------------------------------------------------------
+        bn_mom = 1 - self._global_params.batch_norm_momentum
+        bn_eps = self._global_params.batch_norm_epsilon
+
+        in_channels = 3
+        out_channels = round_filters(32, self._global_params)
+        # ------------------------------------------------------------
+
+        # ------------------------------------------------------------
+        # define the input layers
+        # ------------------------------------------------------------
+        image_size = global_params.image_size
+        Conv2d = get_same_padding_conv2d(image_size=image_size)
+        self._conv = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+        self._bn = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+        # ------------------------------------------------------------
+
+    def forward(self, x):
+        x = self._swish(self._bn(self._conv(x)))
+
+        block_outputs = []
+        for idx, block in enumerate(self._blocks):
+            drop_connect_rate = self._global_params.drop_connect_rate
+            drop_connect_rate *= float(idx) / len(self._blocks)
+            x = block(x, drop_connect_rate=drop_connect_rate)
+            block_outputs.append(x)
+
+        # Head
+        x = self._swish(self._bn1(self._conv_head(x)))
+
+        return block_outputs[0], block_outputs[2], block_outputs[4], block_outputs[10], x
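To make the dual-stream input concrete, here is a hedged construction sketch (not part of the commit). It assumes the `get_model_params` helper that `model.py` imports from `utils.py` keeps the upstream EfficientNet-PyTorch signature `get_model_params(model_name, override_params)`, and that the fourth input channel alongside RGB is the mask, as `in_channels = 4` above suggests; the `'efficientnet-b0'` choice and 256x256 size are illustrative assumptions.

```python
# Hedged sketch: constructing EfficientBackbone and running its two-stream forward.
import torch
from harmonizer.src.model.backbone.efficientnet import EfficientBackbone
from harmonizer.src.model.backbone.efficientnet.utils import get_model_params

blocks_args, global_params = get_model_params('efficientnet-b0', {'image_size': 256})
backbone = EfficientBackbone(blocks_args, global_params).eval()

xfg = torch.rand(1, 4, 256, 256)  # foreground stream: RGB + mask (assumed layout)
xbg = torch.rand(1, 4, 256, 256)  # background stream: RGB + mask (assumed layout)
with torch.no_grad():
    enc1, enc2, enc3, enc4, head = backbone(xfg, xbg)  # five encoder feature maps
```

Each stream gets its own half-width stem (`round_filters(32)/2` channels); the two stems are concatenated before the shared MBConv blocks, so the rest of the network is a standard EfficientNet trunk.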
harmonizer/src/model/backbone/efficientnet/model.py
ADDED
@@ -0,0 +1,395 @@
+"""model.py - Model and module class for EfficientNet.
+   They are built to mirror those in the official TensorFlow implementation.
+"""
+
+# Author: lukemelas (github username)
+# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
+# With adjustments and added comments by workingcoder (github username).
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from .utils import (
+    round_filters,
+    round_repeats,
+    drop_connect,
+    get_same_padding_conv2d,
+    get_model_params,
+    efficientnet_params,
+    load_pretrained_weights,
+    Swish,
+    MemoryEfficientSwish,
+    calculate_output_image_size
+)
+
+
+VALID_MODELS = (
+    'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3',
+    'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b6', 'efficientnet-b7',
+    'efficientnet-b8',
+
+    # Support the construction of 'efficientnet-l2' without pretrained weights
+    'efficientnet-l2'
+)
+
+
+class MBConvBlock(nn.Module):
+    """Mobile Inverted Residual Bottleneck Block.
+    Args:
+        block_args (namedtuple): BlockArgs, defined in utils.py.
+        global_params (namedtuple): GlobalParam, defined in utils.py.
+        image_size (tuple or list): [image_height, image_width].
+    References:
+        [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
+        [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
+        [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
+    """
+
+    def __init__(self, block_args, global_params, image_size=None):
+        super().__init__()
+        self._block_args = block_args
+        self._bn_mom = 1 - global_params.batch_norm_momentum  # pytorch's difference from tensorflow
+        self._bn_eps = global_params.batch_norm_epsilon
+        self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
+        self.id_skip = block_args.id_skip  # whether to use skip connection and drop connect
+
+        # Expansion phase (Inverted Bottleneck)
+        inp = self._block_args.input_filters  # number of input channels
+        oup = self._block_args.input_filters * self._block_args.expand_ratio  # number of output channels
+        if self._block_args.expand_ratio != 1:
+            Conv2d = get_same_padding_conv2d(image_size=image_size)
+            self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
+            self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
+            # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
+
+        # Depthwise convolution phase
+        k = self._block_args.kernel_size
+        s = self._block_args.stride
+        Conv2d = get_same_padding_conv2d(image_size=image_size)
+        self._depthwise_conv = Conv2d(
+            in_channels=oup, out_channels=oup, groups=oup,  # groups makes it depthwise
+            kernel_size=k, stride=s, bias=False)
+        self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
+        image_size = calculate_output_image_size(image_size, s)
+
+        # Squeeze and Excitation layer, if desired
+        if self.has_se:
+            Conv2d = get_same_padding_conv2d(image_size=(1, 1))
+            num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
+            self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
+            self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
+
+        # Pointwise convolution phase
+        final_oup = self._block_args.output_filters
+        Conv2d = get_same_padding_conv2d(image_size=image_size)
+        self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
+        self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
+        self._swish = MemoryEfficientSwish()
+
+    def forward(self, inputs, drop_connect_rate=None):
+        """MBConvBlock's forward function.
+        Args:
+            inputs (tensor): Input tensor.
+            drop_connect_rate (bool): Drop connect rate (float, between 0 and 1).
+        Returns:
+            Output of this block after processing.
+        """
+
+        # Expansion and Depthwise Convolution
+        x = inputs
+        if self._block_args.expand_ratio != 1:
+            x = self._expand_conv(inputs)
+            x = self._bn0(x)
+            x = self._swish(x)
+
+        x = self._depthwise_conv(x)
+        x = self._bn1(x)
+        x = self._swish(x)
+
+        # Squeeze and Excitation
+        if self.has_se:
+            x_squeezed = F.adaptive_avg_pool2d(x, 1)
+            x_squeezed = self._se_reduce(x_squeezed)
+            x_squeezed = self._swish(x_squeezed)
+            x_squeezed = self._se_expand(x_squeezed)
+            x = torch.sigmoid(x_squeezed) * x
+
+        # Pointwise Convolution
+        x = self._project_conv(x)
+        x = self._bn2(x)
+
+        # Skip connection and drop connect
+        input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
+        if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
+            # The combination of skip connection and drop connect brings about stochastic depth.
+            if drop_connect_rate:
+                x = drop_connect(x, p=drop_connect_rate, training=self.training)
+            x = x + inputs  # skip connection
+        return x
+
+    def set_swish(self, memory_efficient=True):
+        """Sets swish function as memory efficient (for training) or standard (for export).
+        Args:
+            memory_efficient (bool): Whether to use memory-efficient version of swish.
+        """
+        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
+
+
+class EfficientNet(nn.Module):
+    """EfficientNet model.
+       Most easily loaded with the .from_name or .from_pretrained methods.
+    Args:
+        blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
+        global_params (namedtuple): A set of GlobalParams shared between blocks.
+    References:
+        [1] https://arxiv.org/abs/1905.11946 (EfficientNet)
+    Example:
+        >>> import torch
+        >>> from efficientnet.model import EfficientNet
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> model = EfficientNet.from_pretrained('efficientnet-b0')
+        >>> model.eval()
+        >>> outputs = model(inputs)
+    """
+
+    def __init__(self, blocks_args=None, global_params=None):
+        super().__init__()
+        assert isinstance(blocks_args, list), 'blocks_args should be a list'
+        assert len(blocks_args) > 0, 'block args must be greater than 0'
+        self._global_params = global_params
+        self._blocks_args = blocks_args
+
+        # Batch norm parameters
+        bn_mom = 1 - self._global_params.batch_norm_momentum
+        bn_eps = self._global_params.batch_norm_epsilon
+
+        # Get stem static or dynamic convolution depending on image size
+        image_size = global_params.image_size
+        Conv2d = get_same_padding_conv2d(image_size=image_size)
+
+        # Stem
+        in_channels = 3  # rgb
+        out_channels = round_filters(32, self._global_params)  # number of output channels
+        self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+        self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+        image_size = calculate_output_image_size(image_size, 2)
+
+        # Build blocks
+        self._blocks = nn.ModuleList([])
+        for block_args in self._blocks_args:
+
+            # Update block input and output filters based on depth multiplier.
+            block_args = block_args._replace(
+                input_filters=round_filters(block_args.input_filters, self._global_params),
+                output_filters=round_filters(block_args.output_filters, self._global_params),
+                num_repeat=round_repeats(block_args.num_repeat, self._global_params)
+            )
+
+            # The first block needs to take care of stride and filter size increase.
+            self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
+            image_size = calculate_output_image_size(image_size, block_args.stride)
+            if block_args.num_repeat > 1:  # modify block_args to keep same output size
+                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
+            for _ in range(block_args.num_repeat - 1):
+                self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
+                # image_size = calculate_output_image_size(image_size, block_args.stride)  # stride = 1
+
+        # Head
+        in_channels = block_args.output_filters  # output of final block
+        out_channels = round_filters(1280, self._global_params)
+        Conv2d = get_same_padding_conv2d(image_size=image_size)
+        self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
+        self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+
+        # Final linear layer
+        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
+        if self._global_params.include_top:
+            self._dropout = nn.Dropout(self._global_params.dropout_rate)
+            self._fc = nn.Linear(out_channels, self._global_params.num_classes)
+
+        # set activation to memory efficient swish by default
+        self._swish = MemoryEfficientSwish()
+
+    def set_swish(self, memory_efficient=True):
+        """Sets swish function as memory efficient (for training) or standard (for export).
+        Args:
+            memory_efficient (bool): Whether to use memory-efficient version of swish.
+        """
+        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
+        for block in self._blocks:
+            block.set_swish(memory_efficient)
+
+    def extract_endpoints(self, inputs):
+        """Use convolution layer to extract features
+        from reduction levels i in [1, 2, 3, 4, 5].
+        Args:
+            inputs (tensor): Input tensor.
+        Returns:
+            Dictionary of last intermediate features
+            with reduction levels i in [1, 2, 3, 4, 5].
+        Example:
+            >>> import torch
+            >>> from efficientnet.model import EfficientNet
+            >>> inputs = torch.rand(1, 3, 224, 224)
+            >>> model = EfficientNet.from_pretrained('efficientnet-b0')
+            >>> endpoints = model.extract_endpoints(inputs)
+            >>> print(endpoints['reduction_1'].shape)  # torch.Size([1, 16, 112, 112])
+            >>> print(endpoints['reduction_2'].shape)  # torch.Size([1, 24, 56, 56])
+            >>> print(endpoints['reduction_3'].shape)  # torch.Size([1, 40, 28, 28])
+            >>> print(endpoints['reduction_4'].shape)  # torch.Size([1, 112, 14, 14])
+            >>> print(endpoints['reduction_5'].shape)  # torch.Size([1, 320, 7, 7])
+            >>> print(endpoints['reduction_6'].shape)  # torch.Size([1, 1280, 7, 7])
+        """
+        endpoints = dict()
+
+        # Stem
+        x = self._swish(self._bn0(self._conv_stem(inputs)))
+        prev_x = x
+
+        # Blocks
+        for idx, block in enumerate(self._blocks):
+            drop_connect_rate = self._global_params.drop_connect_rate
+            if drop_connect_rate:
+                drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
+            x = block(x, drop_connect_rate=drop_connect_rate)
+            if prev_x.size(2) > x.size(2):
+                endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x
+            elif idx == len(self._blocks) - 1:
+                endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
+            prev_x = x
+
+        # Head
+        x = self._swish(self._bn1(self._conv_head(x)))
+        endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
+
+        return endpoints
+
+    def extract_features(self, inputs):
+        """use convolution layer to extract feature .
+        Args:
+            inputs (tensor): Input tensor.
+        Returns:
+            Output of the final convolution
+            layer in the efficientnet model.
+        """
+        # Stem
+        x = self._swish(self._bn0(self._conv_stem(inputs)))
+
+        # Blocks
+        for idx, block in enumerate(self._blocks):
+            drop_connect_rate = self._global_params.drop_connect_rate
+            if drop_connect_rate:
+                drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
+            x = block(x, drop_connect_rate=drop_connect_rate)
+
+        # Head
+        x = self._swish(self._bn1(self._conv_head(x)))
+
+        return x
+
+    def forward(self, inputs):
+        """EfficientNet's forward function.
+           Calls extract_features to extract features, applies final linear layer, and returns logits.
+        Args:
+            inputs (tensor): Input tensor.
+        Returns:
+            Output of this model after processing.
+        """
+        # Convolution layers
+        x = self.extract_features(inputs)
+        # Pooling and final linear layer
+        x = self._avg_pooling(x)
+        if self._global_params.include_top:
+            x = x.flatten(start_dim=1)
+            x = self._dropout(x)
+            x = self._fc(x)
+        return x
+
+    @classmethod
+    def from_name(cls, model_name, in_channels=3, **override_params):
+        """Create an efficientnet model according to name.
+        Args:
+            model_name (str): Name for efficientnet.
+            in_channels (int): Input data's channel number.
+            override_params (other key word params):
+                Params to override model's global_params.
+                Optional key:
+                    'width_coefficient', 'depth_coefficient',
+                    'image_size', 'dropout_rate',
+                    'num_classes', 'batch_norm_momentum',
+                    'batch_norm_epsilon', 'drop_connect_rate',
+                    'depth_divisor', 'min_depth'
+        Returns:
+            An efficientnet model.
+        """
+        cls._check_model_name_is_valid(model_name)
+        blocks_args, global_params = get_model_params(model_name, override_params)
+        model = cls(blocks_args, global_params)
+        model._change_in_channels(in_channels)
+        return model
+
+    @classmethod
+    def from_pretrained(cls, model_name, weights_path=None, advprop=False,
+                        in_channels=3, num_classes=1000, **override_params):
+        """Create an efficientnet model according to name.
+        Args:
+            model_name (str): Name for efficientnet.
+            weights_path (None or str):
+                str: path to pretrained weights file on the local disk.
+                None: use pretrained weights downloaded from the Internet.
+            advprop (bool):
+                Whether to load pretrained weights
+                trained with advprop (valid when weights_path is None).
+            in_channels (int): Input data's channel number.
+            num_classes (int):
+                Number of categories for classification.
+                It controls the output size for final linear layer.
+            override_params (other key word params):
+                Params to override model's global_params.
+                Optional key:
+                    'width_coefficient', 'depth_coefficient',
+                    'image_size', 'dropout_rate',
+                    'batch_norm_momentum',
+                    'batch_norm_epsilon', 'drop_connect_rate',
+                    'depth_divisor', 'min_depth'
+        Returns:
+            A pretrained efficientnet model.
+        """
+        model = cls.from_name(model_name, num_classes=num_classes, **override_params)
+        load_pretrained_weights(model, model_name, weights_path=weights_path,
+                                load_fc=(num_classes == 1000), advprop=advprop)
+        model._change_in_channels(in_channels)
+        return model
+
+    @classmethod
+    def get_image_size(cls, model_name):
+        """Get the input image size for a given efficientnet model.
+        Args:
+            model_name (str): Name for efficientnet.
+        Returns:
+            Input image size (resolution).
+        """
+        cls._check_model_name_is_valid(model_name)
+        _, _, res, _ = efficientnet_params(model_name)
+        return res
+
+    @classmethod
+    def _check_model_name_is_valid(cls, model_name):
+        """Validates model name.
+        Args:
+            model_name (str): Name for efficientnet.
+        Returns:
+            bool: Is a valid name or not.
+        """
+        if model_name not in VALID_MODELS:
+            raise ValueError('model_name should be one of: ' + ', '.join(VALID_MODELS))
+
+    def _change_in_channels(self, in_channels):
+        """Adjust model's first convolution layer to in_channels, if in_channels not equals 3.
+        Args:
+            in_channels (int): Input data's channel number.
+        """
+        if in_channels != 3:
+            Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
+            out_channels = round_filters(32, self._global_params)
+            self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
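A hedged sketch of how `_change_in_channels` comes into play via `from_name` (the model choice, class count, and shapes are illustrative, not from the commit):

```python
# Hedged sketch: an un-pretrained EfficientNet with a 4-channel input,
# exercising from_name() and _change_in_channels() from the file above.
import torch
from harmonizer.src.model.backbone.efficientnet.model import EfficientNet

model = EfficientNet.from_name('efficientnet-b0', in_channels=4, num_classes=10).eval()
with torch.no_grad():
    logits = model(torch.rand(2, 4, 224, 224))  # -> torch.Size([2, 10])
```

Because `in_channels != 3`, the stem convolution is rebuilt for 4 input channels after construction; everything downstream is unchanged.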
harmonizer/src/model/backbone/efficientnet/utils.py
ADDED
|
@@ -0,0 +1,586 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""utils.py - Helper functions for building the model and for loading model parameters.
|
| 2 |
+
These helper functions are built to mirror those in the official TensorFlow implementation.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
# Author: lukemelas (github username)
|
| 6 |
+
# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
|
| 7 |
+
# With adjustments and added comments by workingcoder (github username).
|
| 8 |
+
|
| 9 |
+
import re
|
| 10 |
+
import math
|
| 11 |
+
import collections
|
| 12 |
+
from functools import partial
|
| 13 |
+
import torch
|
| 14 |
+
from torch import nn
|
| 15 |
+
from torch.nn import functional as F
|
| 16 |
+
from torch.utils import model_zoo
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
################################################################################
|
| 20 |
+
# Help functions for model architecture
|
| 21 |
+
################################################################################
|
| 22 |
+
|
| 23 |
+
# GlobalParams and BlockArgs: Two namedtuples
|
| 24 |
+
# Swish and MemoryEfficientSwish: Two implementations of the method
|
| 25 |
+
# round_filters and round_repeats:
|
| 26 |
+
# Functions to calculate params for scaling model width and depth ! ! !
|
| 27 |
+
# get_width_and_height_from_size and calculate_output_image_size
|
| 28 |
+
# drop_connect: A structural design
|
| 29 |
+
# get_same_padding_conv2d:
|
| 30 |
+
# Conv2dDynamicSamePadding
|
| 31 |
+
# Conv2dStaticSamePadding
|
| 32 |
+
# get_same_padding_maxPool2d:
|
| 33 |
+
# MaxPool2dDynamicSamePadding
|
| 34 |
+
# MaxPool2dStaticSamePadding
|
| 35 |
+
# It's an additional function, not used in EfficientNet,
|
| 36 |
+
# but can be used in other model (such as EfficientDet).
|
| 37 |
+
|
| 38 |
+
# Parameters for the entire model (stem, all blocks, and head)
|
| 39 |
+
GlobalParams = collections.namedtuple('GlobalParams', [
|
| 40 |
+
'width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate',
|
| 41 |
+
'num_classes', 'batch_norm_momentum', 'batch_norm_epsilon',
|
| 42 |
+
'drop_connect_rate', 'depth_divisor', 'min_depth', 'include_top'])
|
| 43 |
+
|
| 44 |
+
# Parameters for an individual model block
|
| 45 |
+
BlockArgs = collections.namedtuple('BlockArgs', [
|
| 46 |
+
'num_repeat', 'kernel_size', 'stride', 'expand_ratio',
|
| 47 |
+
'input_filters', 'output_filters', 'se_ratio', 'id_skip'])
|
| 48 |
+
|
| 49 |
+
# Set GlobalParams and BlockArgs's defaults
|
| 50 |
+
GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
|
| 51 |
+
BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
|
| 52 |
+
|
| 53 |
+
# Swish activation function
|
| 54 |
+
if hasattr(nn, 'SiLU'):
|
| 55 |
+
Swish = nn.SiLU
|
| 56 |
+
else:
|
| 57 |
+
# For compatibility with old PyTorch versions
|
| 58 |
+
class Swish(nn.Module):
|
| 59 |
+
def forward(self, x):
|
| 60 |
+
return x * torch.sigmoid(x)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# A memory-efficient implementation of Swish function
|
| 64 |
+
class SwishImplementation(torch.autograd.Function):
|
| 65 |
+
@staticmethod
|
| 66 |
+
def forward(ctx, i):
|
| 67 |
+
result = i * torch.sigmoid(i)
|
| 68 |
+
ctx.save_for_backward(i)
|
| 69 |
+
return result
|
| 70 |
+
|
| 71 |
+
@staticmethod
|
| 72 |
+
def backward(ctx, grad_output):
|
| 73 |
+
i = ctx.saved_tensors[0]
|
| 74 |
+
sigmoid_i = torch.sigmoid(i)
|
| 75 |
+
return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class MemoryEfficientSwish(nn.Module):
|
| 79 |
+
def forward(self, x):
|
| 80 |
+
return SwishImplementation.apply(x)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def round_filters(filters, global_params):
|
| 84 |
+
"""Calculate and round number of filters based on width multiplier.
|
| 85 |
+
Use width_coefficient, depth_divisor and min_depth of global_params.
|
| 86 |
+
Args:
|
| 87 |
+
filters (int): Filters number to be calculated.
|
| 88 |
+
global_params (namedtuple): Global params of the model.
|
| 89 |
+
Returns:
|
| 90 |
+
new_filters: New filters number after calculating.
|
| 91 |
+
"""
|
| 92 |
+
multiplier = global_params.width_coefficient
|
| 93 |
+
if not multiplier:
|
| 94 |
+
return filters
|
| 95 |
+
# TODO: modify the params names.
|
| 96 |
+
# maybe the names (width_divisor,min_width)
|
| 97 |
+
# are more suitable than (depth_divisor,min_depth).
|
| 98 |
+
divisor = global_params.depth_divisor
|
| 99 |
+
min_depth = global_params.min_depth
|
| 100 |
+
filters *= multiplier
|
| 101 |
+
min_depth = min_depth or divisor # pay attention to this line when using min_depth
|
| 102 |
+
# follow the formula transferred from official TensorFlow implementation
|
| 103 |
+
new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
|
| 104 |
+
if new_filters < 0.9 * filters: # prevent rounding by more than 10%
|
| 105 |
+
new_filters += divisor
|
| 106 |
+
return int(new_filters)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def round_repeats(repeats, global_params):
|
| 110 |
+
"""Calculate module's repeat number of a block based on depth multiplier.
|
| 111 |
+
Use depth_coefficient of global_params.
|
| 112 |
+
Args:
|
| 113 |
+
repeats (int): num_repeat to be calculated.
|
| 114 |
+
global_params (namedtuple): Global params of the model.
|
| 115 |
+
Returns:
|
| 116 |
+
new repeat: New repeat number after calculating.
|
| 117 |
+
"""
|
| 118 |
+
multiplier = global_params.depth_coefficient
|
| 119 |
+
if not multiplier:
|
| 120 |
+
return repeats
|
| 121 |
+
# follow the formula transferred from official TensorFlow implementation
|
| 122 |
+
return int(math.ceil(multiplier * repeats))
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def drop_connect(inputs, p, training):
|
| 126 |
+
"""Drop connect.
|
| 127 |
+
Args:
|
| 128 |
+
input (tensor: BCWH): Input of this structure.
|
| 129 |
+
p (float: 0.0~1.0): Probability of drop connection.
|
| 130 |
+
training (bool): The running mode.
|
| 131 |
+
Returns:
|
| 132 |
+
output: Output after drop connection.
|
| 133 |
+
"""
|
| 134 |
+
assert 0 <= p <= 1, 'p must be in range of [0,1]'
|
| 135 |
+
|
| 136 |
+
if not training:
|
| 137 |
+
return inputs
|
| 138 |
+
|
| 139 |
+
batch_size = inputs.shape[0]
|
| 140 |
+
keep_prob = 1 - p
|
| 141 |
+
|
| 142 |
+
# generate binary_tensor mask according to probability (p for 0, 1-p for 1)
|
| 143 |
+
random_tensor = keep_prob
|
| 144 |
+
random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
|
| 145 |
+
binary_tensor = torch.floor(random_tensor)
|
| 146 |
+
|
| 147 |
+
output = inputs / keep_prob * binary_tensor
|
| 148 |
+
return output
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def get_width_and_height_from_size(x):
|
| 152 |
+
"""Obtain height and width from x.
|
| 153 |
+
Args:
|
| 154 |
+
x (int, tuple or list): Data size.
|
| 155 |
+
Returns:
|
| 156 |
+
size: A tuple or list (H,W).
|
| 157 |
+
"""
|
| 158 |
+
if isinstance(x, int):
|
| 159 |
+
return x, x
|
| 160 |
+
if isinstance(x, list) or isinstance(x, tuple):
|
| 161 |
+
return x
|
| 162 |
+
else:
|
| 163 |
+
raise TypeError()
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def calculate_output_image_size(input_image_size, stride):
|
| 167 |
+
"""Calculates the output image size when using Conv2dSamePadding with a stride.
|
| 168 |
+
Necessary for static padding. Thanks to mannatsingh for pointing this out.
|
| 169 |
+
Args:
|
| 170 |
+
input_image_size (int, tuple or list): Size of input image.
|
| 171 |
+
stride (int, tuple or list): Conv2d operation's stride.
|
| 172 |
+
Returns:
|
| 173 |
+
output_image_size: A list [H,W].
|
| 174 |
+
"""
|
| 175 |
+
if input_image_size is None:
|
| 176 |
+
return None
|
| 177 |
+
image_height, image_width = get_width_and_height_from_size(input_image_size)
|
| 178 |
+
stride = stride if isinstance(stride, int) else stride[0]
|
| 179 |
+
image_height = int(math.ceil(image_height / stride))
|
| 180 |
+
image_width = int(math.ceil(image_width / stride))
|
| 181 |
+
return [image_height, image_width]
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
# Note:
|
| 185 |
+
# The following 'SamePadding' functions make output size equal ceil(input size/stride).
# Only when stride equals 1, can the output size be the same as input size.
# Don't be confused by their function names ! ! !

def get_same_padding_conv2d(image_size=None):
    """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
       Static padding is necessary for ONNX exporting of models.
    Args:
        image_size (int or tuple): Size of the image.
    Returns:
        Conv2dDynamicSamePadding or Conv2dStaticSamePadding.
    """
    if image_size is None:
        return Conv2dDynamicSamePadding
    else:
        return partial(Conv2dStaticSamePadding, image_size=image_size)


class Conv2dDynamicSamePadding(nn.Conv2d):
    """2D Convolutions like TensorFlow, for a dynamic image size.
       The padding is operated in forward function by calculating dynamically.
    """

    # Tips for 'SAME' mode padding.
    #     Given the following:
    #         i: width or height
    #         s: stride
    #         k: kernel size
    #         d: dilation
    #         p: padding
    #     Output after Conv2d:
    #         o = floor((i+p-((k-1)*d+1))/s+1)
    # If o equals i, i = floor((i+p-((k-1)*d+1))/s+1),
    # => p = (i-1)*s+((k-1)*d+1)-i

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
        super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

    def forward(self, x):
        ih, iw = x.size()[-2:]
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)  # change the output size according to stride ! ! !
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


class Conv2dStaticSamePadding(nn.Conv2d):
    """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size.
       The padding module is calculated in the constructor, then used in forward.
    """

    # With the same calculation as Conv2dDynamicSamePadding

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

        # Calculate padding based on image size and save it
        assert image_size is not None
        ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2,
                                                pad_h // 2, pad_h - pad_h // 2))
        else:
            self.static_padding = nn.Identity()

    def forward(self, x):
        x = self.static_padding(x)
        x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return x


def get_same_padding_maxPool2d(image_size=None):
    """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
       Static padding is necessary for ONNX exporting of models.
    Args:
        image_size (int or tuple): Size of the image.
    Returns:
        MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding.
    """
    if image_size is None:
        return MaxPool2dDynamicSamePadding
    else:
        return partial(MaxPool2dStaticSamePadding, image_size=image_size)


class MaxPool2dDynamicSamePadding(nn.MaxPool2d):
    """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size.
       The padding is operated in forward function by calculating dynamically.
    """

    def __init__(self, kernel_size, stride, padding=0, dilation=1, return_indices=False, ceil_mode=False):
        super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode)
        self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
        self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size
        self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation

    def forward(self, x):
        ih, iw = x.size()[-2:]
        kh, kw = self.kernel_size
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
        return F.max_pool2d(x, self.kernel_size, self.stride, self.padding,
                            self.dilation, self.ceil_mode, self.return_indices)


class MaxPool2dStaticSamePadding(nn.MaxPool2d):
    """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size.
       The padding module is calculated in the constructor, then used in forward.
    """

    def __init__(self, kernel_size, stride, image_size=None, **kwargs):
        super().__init__(kernel_size, stride, **kwargs)
        self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
        self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size
        self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation

        # Calculate padding based on image size and save it
        assert image_size is not None
        ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
        kh, kw = self.kernel_size
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
        else:
            self.static_padding = nn.Identity()

    def forward(self, x):
        x = self.static_padding(x)
        x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding,
                         self.dilation, self.ceil_mode, self.return_indices)
        return x


################################################################################
# Helper functions for loading model params
################################################################################

# BlockDecoder: A Class for encoding and decoding BlockArgs
# efficientnet_params: A function to query compound coefficient
# get_model_params and efficientnet:
#     Functions to get BlockArgs and GlobalParams for efficientnet
# url_map and url_map_advprop: Dicts of url_map for pretrained weights
# load_pretrained_weights: A function to load pretrained weights

class BlockDecoder(object):
    """Block Decoder for readability,
       straight from the official TensorFlow repository.
    """

    @staticmethod
    def _decode_block_string(block_string):
        """Get a block through a string notation of arguments.
        Args:
            block_string (str): A string notation of arguments.
                                Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.
        Returns:
            BlockArgs: The namedtuple defined at the top of this file.
        """
        assert isinstance(block_string, str)

        ops = block_string.split('_')
        options = {}
        for op in ops:
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        # Check stride
        assert (('s' in options and len(options['s']) == 1) or
                (len(options['s']) == 2 and options['s'][0] == options['s'][1]))

        return BlockArgs(
            num_repeat=int(options['r']),
            kernel_size=int(options['k']),
            stride=[int(options['s'][0])],
            expand_ratio=int(options['e']),
            input_filters=int(options['i']),
            output_filters=int(options['o']),
            se_ratio=float(options['se']) if 'se' in options else None,
            id_skip=('noskip' not in block_string))

    @staticmethod
    def _encode_block_string(block):
        """Encode a block to a string.
        Args:
            block (namedtuple): A BlockArgs type argument.
        Returns:
            block_string: A String form of BlockArgs.
        """
        args = [
            'r%d' % block.num_repeat,
            'k%d' % block.kernel_size,
            's%d%d' % (block.strides[0], block.strides[1]),
            'e%s' % block.expand_ratio,
            'i%d' % block.input_filters,
            'o%d' % block.output_filters
        ]
        if 0 < block.se_ratio <= 1:
            args.append('se%s' % block.se_ratio)
        if block.id_skip is False:
            args.append('noskip')
        return '_'.join(args)

    @staticmethod
    def decode(string_list):
        """Decode a list of string notations to specify blocks inside the network.
        Args:
            string_list (list[str]): A list of strings, each string is a notation of block.
        Returns:
            blocks_args: A list of BlockArgs namedtuples of block args.
        """
        assert isinstance(string_list, list)
        blocks_args = []
        for block_string in string_list:
            blocks_args.append(BlockDecoder._decode_block_string(block_string))
        return blocks_args

    @staticmethod
    def encode(blocks_args):
        """Encode a list of BlockArgs to a list of strings.
        Args:
            blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.
        Returns:
            block_strings: A list of strings, each string is a notation of block.
        """
        block_strings = []
        for block in blocks_args:
            block_strings.append(BlockDecoder._encode_block_string(block))
        return block_strings


def efficientnet_params(model_name):
    """Map EfficientNet model name to parameter coefficients.
    Args:
        model_name (str): Model name to be queried.
    Returns:
        params_dict[model_name]: A (width,depth,res,dropout) tuple.
    """
    params_dict = {
        # Coefficients:   width,depth,res,dropout
        'efficientnet-b0': (1.0, 1.0, 224, 0.2),
        'efficientnet-b1': (1.0, 1.1, 240, 0.2),
        'efficientnet-b2': (1.1, 1.2, 260, 0.3),
        'efficientnet-b3': (1.2, 1.4, 300, 0.3),
        'efficientnet-b4': (1.4, 1.8, 380, 0.4),
        'efficientnet-b5': (1.6, 2.2, 456, 0.4),
        'efficientnet-b6': (1.8, 2.6, 528, 0.5),
        'efficientnet-b7': (2.0, 3.1, 600, 0.5),
        'efficientnet-b8': (2.2, 3.6, 672, 0.5),
        'efficientnet-l2': (4.3, 5.3, 800, 0.5),
    }
    return params_dict[model_name]


def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None,
                 dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000, include_top=False):
    """Create BlockArgs and GlobalParams for efficientnet model.
    Args:
        width_coefficient (float)
        depth_coefficient (float)
        image_size (int)
        dropout_rate (float)
        drop_connect_rate (float)
        num_classes (int)
            Meaning as the name suggests.
    Returns:
        blocks_args, global_params.
    """

    # Blocks args for the whole model (efficientnet-b0 by default)
    # It will be modified in the construction of EfficientNet Class according to model
    blocks_args = [
        'r1_k3_s11_e1_i32_o16_se0.25',
        'r2_k3_s22_e6_i16_o24_se0.25',
        'r2_k5_s22_e6_i24_o40_se0.25',
        'r3_k3_s22_e6_i40_o80_se0.25',
        'r3_k5_s11_e6_i80_o112_se0.25',
        'r4_k5_s22_e6_i112_o192_se0.25',
        'r1_k3_s11_e6_i192_o320_se0.25',
    ]
    blocks_args = BlockDecoder.decode(blocks_args)

    global_params = GlobalParams(
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        image_size=image_size,
        dropout_rate=dropout_rate,

        num_classes=num_classes,
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        drop_connect_rate=drop_connect_rate,
        depth_divisor=8,
        min_depth=None,
        include_top=include_top,
    )

    return blocks_args, global_params


def get_model_params(model_name, override_params):
    """Get the block args and global params for a given model name.
    Args:
        model_name (str): Model's name.
        override_params (dict): A dict to modify global_params.
    Returns:
        blocks_args, global_params
    """
    if model_name.startswith('efficientnet'):
        w, d, s, p = efficientnet_params(model_name)
        # note: all models have drop connect rate = 0.2
        blocks_args, global_params = efficientnet(
            width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s)
    else:
        raise NotImplementedError('model name is not pre-defined: {}'.format(model_name))
    if override_params:
        # ValueError will be raised here if override_params has fields not included in global_params.
        global_params = global_params._replace(**override_params)
    return blocks_args, global_params


# train with Standard methods
# check more details in paper (EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks)
url_map = {
    'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth',
    'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth',
    'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth',
    'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth',
    'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth',
    'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth',
    'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth',
    'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth',
}

# train with Adversarial Examples (AdvProp)
# check more details in paper (Adversarial Examples Improve Image Recognition)
url_map_advprop = {
    'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth',
    'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth',
    'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth',
    'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth',
    'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth',
    'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth',
    'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth',
    'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth',
    'efficientnet-b8': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth',
}

# TODO: add the pretrained weights url map of 'efficientnet-l2'


def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False, verbose=True):
    """Loads pretrained weights from weights path or download using url.
    Args:
        model (Module): The whole model of efficientnet.
        model_name (str): Model name of efficientnet.
        weights_path (None or str):
            str: path to pretrained weights file on the local disk.
            None: use pretrained weights downloaded from the Internet.
        load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model.
        advprop (bool): Whether to load pretrained weights
                        trained with advprop (valid when weights_path is None).
    """
    if isinstance(weights_path, str):
        state_dict = torch.load(weights_path)
    else:
        # AutoAugment or Advprop (different preprocessing)
        url_map_ = url_map_advprop if advprop else url_map
        state_dict = model_zoo.load_url(url_map_[model_name])

    if load_fc:
        ret = model.load_state_dict(state_dict, strict=False)
        # assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
    else:
        state_dict.pop('_fc.weight')
        state_dict.pop('_fc.bias')
        ret = model.load_state_dict(state_dict, strict=False)
        # assert set(ret.missing_keys) == set(
        #     ['_fc.weight', '_fc.bias']), 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
        # assert not ret.unexpected_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.unexpected_keys)

    if verbose:
        print('Loaded pretrained weights for {}'.format(model_name))
harmonizer/src/model/enhancer.py
ADDED
@@ -0,0 +1,40 @@
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms.functional as tf

from .filter import Filter
from .backbone import EfficientBackboneCommon
from .module import CascadeArgumentRegressor, FilterPerformer


class Enhancer(nn.Module):
    def __init__(self):
        super(Enhancer, self).__init__()

        self.input_size = (256, 256)
        self.filter_types = [
            Filter.BRIGHTNESS,
            Filter.CONTRAST,
            Filter.SATURATION,
            Filter.HIGHLIGHT,
            Filter.SHADOW,
        ]

        self.backbone = EfficientBackboneCommon.from_name('efficientnet-b0')
        self.regressor = CascadeArgumentRegressor(1280, 160, 1, len(self.filter_types))
        self.performer = FilterPerformer(self.filter_types)

    def predict_arguments(self, x, mask):
        x = F.interpolate(x, self.input_size, mode='bilinear', align_corners=False)
        enc2x, enc4x, enc8x, enc16x, enc32x = self.backbone(x)
        arguments = self.regressor(enc32x)

        return arguments

    def restore_image(self, x, mask, arguments):
        assert len(arguments) == len(self.filter_types)

        arguments = [torch.clamp(arg, -1, 1).view(-1, 1, 1, 1) for arg in arguments]
        return self.performer.restore(x, mask, arguments)
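Not part of the commit — a minimal usage sketch of the `Enhancer` above: the filter arguments are predicted on a 256x256 copy, then the filters run at full resolution. Without a trained checkpoint the output is meaningless, and the import line assumes the `harmonizer/` folder is on PYTHONPATH:

```python
import torch

from src.model.enhancer import Enhancer  # assumes harmonizer/ is on PYTHONPATH

model = Enhancer().eval()
x = torch.rand(1, 3, 768, 1024)     # image to enhance, values in [0, 1]
mask = torch.ones(1, 1, 768, 1024)  # enhance everywhere

with torch.no_grad():
    args = model.predict_arguments(x, mask)            # five scalar filter arguments
    enhanced = model.restore_image(x, mask, args)[-1]  # last output = all filters applied
```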
harmonizer/src/model/filter.py
ADDED
@@ -0,0 +1,231 @@
import math
from enum import Enum

import torch
import torch.nn as nn
import torch.nn.functional as F
import kornia


class BrightnessFilter(nn.Module):
    def __init__(self):
        super(BrightnessFilter, self).__init__()
        self.epsilon = 1e-6

    def forward(self, image, x):
        """
        Arguments:
            image (tensor [n, 3, h, w]): RGB image with pixel values between [0, 1]
            x (tensor [n, 1, 1, 1]): brightness argument with values between [-1, 1]
        """

        # convert image from RGB to HSV
        image = kornia.color.rgb_to_hsv(image)
        h = image[:, 0:1, :, :]
        s = image[:, 1:2, :, :]
        v = image[:, 2:3, :, :]

        # calculate alpha
        amask = (x >= 0).float()
        alpha = (1 / ((1 - x) + self.epsilon)) * amask + (x + 1) * (1 - amask)

        # adjust the V channel
        v = v * alpha

        # convert image from HSV to RGB
        image = torch.cat((h, s, v), dim=1)
        image = kornia.color.hsv_to_rgb(image)

        # clip pixel values to [0, 1]
        image = torch.clamp(image, 0.0, 1.0)

        return image


class ContrastFilter(nn.Module):
    def __init__(self):
        super(ContrastFilter, self).__init__()

    def forward(self, image, x):
        """
        Arguments:
            image (tensor [n, 3, h, w]): RGB image with pixel values between [0, 1]
            x (tensor [n, 1, 1, 1]): contrast argument with values between [-1, 1]
        """

        # calculate the mean of the image as the threshold
        threshold = torch.mean(image, dim=(1, 2, 3), keepdim=True)

        # pre-process x if it is a positive value
        mask = (x.detach() > 0).float()
        x_ = 255 / (256 - torch.floor(x * 255)) - 1
        x_ = x * (1 - mask) + x_ * mask

        # modify the contrast of the image
        image = image + (image - threshold) * x_

        # clip pixel values to [0, 1]
        image = torch.clamp(image, 0.0, 1.0)

        return image


class SaturationFilter(nn.Module):
    def __init__(self):
        super(SaturationFilter, self).__init__()

        self.epsilon = 1e-6

    def forward(self, image, x):
        """
        Arguments:
            image (tensor [n, 3, h, w]): RGB image with pixel values between [0, 1]
            x (tensor [n, 1, 1, 1]): saturation argument with values between [-1, 1]
        """

        # calculate the basic properties of the image
        cmin = torch.min(image, dim=1, keepdim=True)[0]
        cmax = torch.max(image, dim=1, keepdim=True)[0]
        var = cmax - cmin
        ran = cmax + cmin
        mean = ran / 2

        is_positive = (x.detach() >= 0).float()

        # calculate s
        m = (mean < 0.5).float()
        s = (var / (ran + self.epsilon)) * m + (var / (2 - ran + self.epsilon)) * (1 - m)

        # if x is positive
        m = ((x + s) > 1).float()
        a_pos = s * m + (1 - x) * (1 - m)
        a_pos = 1 / (a_pos + self.epsilon) - 1

        # if x is negative
        a_neg = 1 + x

        a = a_pos * is_positive + a_neg * (1 - is_positive)
        image = image * is_positive + mean * (1 - is_positive) + (image - mean) * a

        # clip pixel values to [0, 1]
        image = torch.clamp(image, 0.0, 1.0)

        return image


class TemperatureFilter(nn.Module):
    def __init__(self):
        super(TemperatureFilter, self).__init__()

        self.epsilon = 1e-6

    def forward(self, image, x):
        """
        Arguments:
            image (tensor [n, 3, h, w]): RGB image with pixel values between [0, 1]
            x (tensor [n, 1, 1, 1]): color temperature argument with values between [-1, 1]
        """
        # split the R/G/B channels
        R, G, B = image[:, 0:1, ...], image[:, 1:2, ...], image[:, 2:3, ...]

        # calculate the mean of each channel
        meanR = torch.mean(R, dim=(2, 3), keepdim=True)
        meanG = torch.mean(G, dim=(2, 3), keepdim=True)
        meanB = torch.mean(B, dim=(2, 3), keepdim=True)

        # calculate correction factors
        gray = (meanR + meanG + meanB) / 3
        coefR = gray / (meanR + self.epsilon)
        coefG = gray / (meanG + self.epsilon)
        coefB = gray / (meanB + self.epsilon)
        aR = 1 - coefR
        aG = 1 - coefG
        aB = 1 - coefB

        # adjust temperature
        is_positive = (x.detach() > 0).float()
        is_negative = (x.detach() < 0).float()
        is_zero = (x.detach() == 0).float()

        meanR_ = meanR + x * torch.sign(x) * is_negative
        meanG_ = meanG + x * torch.sign(x) * 0.5 * (1 - is_zero)
        meanB_ = meanB + x * torch.sign(x) * is_positive
        gray_ = (meanR_ + meanG_ + meanB_) / 3

        coefR_ = gray_ / (meanR_ + self.epsilon) + aR
        coefG_ = gray_ / (meanG_ + self.epsilon) + aG
        coefB_ = gray_ / (meanB_ + self.epsilon) + aB

        R_ = coefR_ * R
        G_ = coefG_ * G
        B_ = coefB_ * B

        # the RGB image with the adjusted temperature
        image = torch.cat((R_, G_, B_), dim=1)

        # clip pixel values to [0, 1]
        image = torch.clamp(image, 0.0, 1.0)

        return image


class HighlightFilter(nn.Module):
    def __init__(self):
        super(HighlightFilter, self).__init__()

    def forward(self, image, x):
        """
        Arguments:
            image (tensor [n, 3, h, w]): RGB image with pixel values between [0, 1]
            x (tensor [n, 1, 1, 1]): highlight argument with values between [-1, 1]
        """

        x = x + 1

        image = kornia.enhance.invert(image, image.detach() * 0 + 1)
        image = torch.clamp(torch.pow(image + 1e-9, x), 0.0, 1.0)
        image = kornia.enhance.invert(image, image.detach() * 0 + 1)

        # clip pixel values to [0, 1]
        image = torch.clamp(image, 0.0, 1.0)

        return image


class ShadowFilter(nn.Module):
    def __init__(self):
        super(ShadowFilter, self).__init__()

    def forward(self, image, x):
        """
        Arguments:
            image (tensor [n, 3, h, w]): RGB image with pixel values between [0, 1]
            x (tensor [n, 1, 1, 1]): shadow argument with values between [-1, 1]
        """

        x = -x + 1
        image = torch.clamp(torch.pow(image + 1e-9, x), 0.0, 1.0)

        # clip pixel values to [0, 1]
        image = torch.clamp(image, 0.0, 1.0)

        return image


class Filter(Enum):
    BRIGHTNESS = 1
    CONTRAST = 2
    SATURATION = 3
    TEMPERATURE = 4
    HIGHLIGHT = 5
    SHADOW = 6


FILTER_MODULES = {
    Filter.BRIGHTNESS: BrightnessFilter,
    Filter.CONTRAST: ContrastFilter,
    Filter.SATURATION: SaturationFilter,
    Filter.TEMPERATURE: TemperatureFilter,
    Filter.HIGHLIGHT: HighlightFilter,
    Filter.SHADOW: ShadowFilter,
}
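Not part of the commit — a quick shape and gradient check of the filters above (run next to `filter.py` so that `BrightnessFilter` and its `kornia` dependency are importable):

```python
import torch

bf = BrightnessFilter()
image = torch.rand(2, 3, 64, 64, requires_grad=True)
arg = torch.full((2, 1, 1, 1), 0.3)  # positive argument -> brighten

out = bf(image, arg)
assert out.shape == image.shape
assert 0.0 <= float(out.min()) and float(out.max()) <= 1.0
out.sum().backward()  # the filters stay differentiable end to end
```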
harmonizer/src/model/harmonizer.py
ADDED
@@ -0,0 +1,44 @@
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms.functional as tf

from .filter import Filter
from .backbone import EfficientBackbone
from .module import CascadeArgumentRegressor, FilterPerformer


class Harmonizer(nn.Module):
    def __init__(self):
        super(Harmonizer, self).__init__()

        self.input_size = (256, 256)
        self.filter_types = [
            Filter.TEMPERATURE,
            Filter.BRIGHTNESS,
            Filter.CONTRAST,
            Filter.SATURATION,
            Filter.HIGHLIGHT,
            Filter.SHADOW,
        ]

        self.backbone = EfficientBackbone.from_name('efficientnet-b0')
        self.regressor = CascadeArgumentRegressor(1280, 160, 1, len(self.filter_types))
        self.performer = FilterPerformer(self.filter_types)

    def predict_arguments(self, comp, mask):
        comp = F.interpolate(comp, self.input_size, mode='bilinear', align_corners=False)
        mask = F.interpolate(mask, self.input_size, mode='bilinear', align_corners=False)

        fg = torch.cat((comp, mask), dim=1)
        bg = torch.cat((comp, (1 - mask)), dim=1)

        enc2x, enc4x, enc8x, enc16x, enc32x = self.backbone(fg, bg)
        arguments = self.regressor(enc32x)
        return arguments

    def restore_image(self, comp, mask, arguments):
        assert len(arguments) == len(self.filter_types)

        arguments = [torch.clamp(arg, -1, 1).view(-1, 1, 1, 1) for arg in arguments]
        return self.performer.restore(comp, mask, arguments)
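Not part of the commit — an inference sketch for the `Harmonizer` above; the checkpoint path is hypothetical and the import assumes the `harmonizer/` folder is on PYTHONPATH:

```python
import torch

from src.model.harmonizer import Harmonizer  # assumes harmonizer/ is on PYTHONPATH

model = Harmonizer().eval()
# model.load_state_dict(torch.load('pretrained/harmonizer.pth'))  # hypothetical path

comp = torch.rand(1, 3, 512, 512)   # composite image in [0, 1]
mask = torch.zeros(1, 1, 512, 512)  # 1 marks the pasted foreground
mask[:, :, 128:384, 128:384] = 1

with torch.no_grad():
    args = model.predict_arguments(comp, mask)              # six filter arguments
    harmonized = model.restore_image(comp, mask, args)[-1]  # final harmonized image
```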
harmonizer/src/model/module.py
ADDED
@@ -0,0 +1,80 @@
import cv2
import math
from enum import Enum

import torch
from torch import nn
import torch.nn.functional as F

from .filter import Filter, FILTER_MODULES


class CascadeArgumentRegressor(nn.Module):
    def __init__(self, in_channels, base_channels, out_channels, head_num):
        super(CascadeArgumentRegressor, self).__init__()
        self.in_channels = in_channels
        self.base_channels = base_channels
        self.out_channels = out_channels
        self.head_num = head_num

        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        self.f = nn.Linear(self.in_channels, 160)
        self.g = nn.Linear(self.in_channels, self.base_channels)

        self.headers = nn.ModuleList()
        for i in range(0, self.head_num):
            self.headers.append(
                nn.ModuleList([
                    nn.Linear(160 + self.base_channels, self.base_channels),
                    nn.Linear(self.base_channels, self.out_channels),
                ])
            )

    def forward(self, x):
        x = self.pool(x)
        n, c, _, _ = x.shape
        x = x.view(n, c)

        f = self.f(x)
        g = self.g(x)

        pred_args = []
        for i in range(0, self.head_num):
            g = self.headers[i][0](torch.cat((f, g), dim=1))
            pred_args.append(self.headers[i][1](g))

        return pred_args


class FilterPerformer(nn.Module):
    def __init__(self, filter_types):
        super(FilterPerformer, self).__init__()

        self.filters = [FILTER_MODULES[filter_type]() for filter_type in filter_types]

    def forward(self):
        pass

    def restore(self, x, mask, arguments):
        assert len(self.filters) == len(arguments)

        outputs = []
        _image = x
        for filter, arg in zip(self.filters, arguments):
            _image = filter(_image, arg)
            outputs.append(_image * mask + x * (1 - mask))

        return outputs

    def adjust(self, image, mask, arguments):
        assert len(self.filters) == len(arguments)

        outputs = []
        _image = image
        for filter, arg in zip(reversed(self.filters), reversed(arguments)):
            _image = filter(_image, arg)
            outputs.append(_image * mask + image * (1 - mask))

        return outputs
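Not part of the commit — a dummy shape check of the cascade regressor above, matching how `Harmonizer` instantiates it:

```python
import torch

regressor = CascadeArgumentRegressor(in_channels=1280, base_channels=160, out_channels=1, head_num=6)
feat = torch.randn(2, 1280, 8, 8)  # encoder feature map, pooled internally
args = regressor(feat)
assert len(args) == 6 and all(a.shape == (2, 1) for a in args)
```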
harmonizer/src/requirements.txt
ADDED
@@ -0,0 +1,6 @@
tqdm
numpy
Pillow
argparse
scikit-image == 0.19.2
kornia
harmonizer/src/train/README.md
ADDED
@@ -0,0 +1,14 @@
## Quick Start - Training Harmonizer


1. Download the iHarmony4 dataset and put it in the folder `./harmonizer/dataset/`
2. Pre-process the iHarmony4 dataset for training.
   We provide the processed Hday2night subset as an example at [this link](https://drive.google.com/drive/folders/1HtrmUlFsT1yIfJ2JkGWwAwFDlv8StD6e?usp=sharing).
   You should convert the other subsets to the same format for training
   (a sketch of the expected folder layout follows this file).
   Otherwise, you need to implement new dataset loaders in the file `./harmonizer/data.py` to load datasets in other formats.
3. Run the training script by:
   ```
   cd ./harmonizer
   python -m script.train
   ```
   You can configure the training arguments in the script.
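Not part of the commit — judging from the loaders in `data.py` below, each processed subset is expected to look roughly like this (an assumption based on the code, not an official spec):

```
dataset/
  Hday2night/
    comp/    # composite images (used by OriginalIHarmony4)
    mask/    # foreground masks (or matte/, for HarmonizerIHarmony4)
    image/   # ground-truth real images
```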
harmonizer/src/train/harmonizer/__init__.py
ADDED
File without changes
harmonizer/src/train/harmonizer/criterion.py
ADDED
@@ -0,0 +1,47 @@
import torch
import torch.nn as nn

import torchtask


def add_parser_arguments(parser):
    torchtask.criterion_template.add_parser_arguments(parser)


def harmonizer_loss():
    return HarmonizerLoss


class AbsoluteLoss(nn.Module):
    def __init__(self, epsilon=1e-6):
        super(AbsoluteLoss, self).__init__()
        self.epsilon = epsilon

    def forward(self, pred, gt):
        loss = torch.sqrt((pred - gt) ** 2 + self.epsilon)
        return loss


class HarmonizerLoss(torchtask.criterion_template.TaskCriterion):
    def __init__(self, args):
        super(HarmonizerLoss, self).__init__(args)

        self.l1 = AbsoluteLoss()
        self.l2 = nn.MSELoss(reduction='none')

    def forward(self, pred, gt, inp):
        pred_outputs, = pred
        x, mask = inp

        assert len(pred_outputs) == len(gt)

        image_losses = []
        for pred_, gt_ in zip(pred_outputs, gt):
            l1_loss = torch.sum(self.l1(pred_, gt_) * mask, dim=(1, 2, 3)) / (torch.sum(mask, dim=(1, 2, 3)) + 1e-6)
            l2_loss = torch.sum(self.l2(pred_, gt_) * mask, dim=(1, 2, 3)) / (torch.sum(mask, dim=(1, 2, 3)) + 1e-6) * 10
            loss = (l1_loss + l2_loss)
            image_losses.append(loss)

        return image_losses
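Not part of the commit — the foreground-weighted loss normalization above, reproduced in isolation on dummy tensors:

```python
import torch

l1 = AbsoluteLoss()
pred = torch.rand(2, 3, 64, 64)
gt = torch.rand(2, 3, 64, 64)
mask = torch.ones(2, 1, 64, 64)  # broadcasts over the 3 color channels

l1_loss = torch.sum(l1(pred, gt) * mask, dim=(1, 2, 3)) / (torch.sum(mask, dim=(1, 2, 3)) + 1e-6)
assert l1_loss.shape == (2,)  # one loss value per sample in the batch
```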
harmonizer/src/train/harmonizer/data.py
ADDED
@@ -0,0 +1,198 @@
import os
import cv2
import random
import numpy as np
from PIL import Image

from torchvision import transforms

import torchtask


def add_parser_arguments(parser):
    torchtask.data_template.add_parser_arguments(parser)


def harmonizer_iharmony4():
    return HarmonizerIHarmony4


def original_iharmony4():
    return OriginalIHarmony4


def resize(img, size):
    interp = cv2.INTER_LINEAR

    return Image.fromarray(
        cv2.resize(np.array(img).astype('uint8'), size, interpolation=interp))


im_train_transform = transforms.Compose([
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.03),
    transforms.ToTensor(),
])

im_val_transform = transforms.Compose([
    transforms.ToTensor(),
])


class HarmonizerIHarmony4(torchtask.data_template.TaskDataset):
    def __init__(self, args, is_train):
        super(HarmonizerIHarmony4, self).__init__(args, is_train)

        self.im_dir = os.path.join(self.root_dir, 'image')
        self.mask_dir = os.path.join(self.root_dir, 'mask')

        if not os.path.exists(self.mask_dir):
            self.mask_dir = os.path.join(self.root_dir, 'matte')

        self.sample_list = [_ for _ in os.listdir(self.im_dir)]
        self.idxs = [_ for _ in range(0, len(self.sample_list))]

        self.im_size = self.args.im_size

        self.rotation = True if self.is_train else False
        self.fliplr = True if self.is_train else False

    def __getitem__(self, idx):
        image_path = os.path.join(self.im_dir, self.sample_list[idx])
        mask_path = os.path.join(self.mask_dir, self.sample_list[idx])

        image = self.im_loader.load(image_path)
        mask = self.im_loader.load(mask_path)

        width, height = image.size

        # resize to self.im_size
        image = resize(image, (self.im_size, self.im_size))
        mask = resize(mask, (self.im_size, self.im_size))

        # convert to np array and scale to [0, 1]
        image = np.array(image).astype('float32') / 255.0
        mask = np.array(mask).astype('float32') / 255.0

        # check image shape
        if len(mask.shape) == 3:
            mask = mask[:, :, -1]

        if len(image.shape) == 2:
            image = image[:, :, None]
        if image.shape[2] == 1:
            image = np.repeat(image, 3, axis=2)
        elif image.shape[2] == 4:
            image = image[:, :, 0:3]

        # random rotate
        rerotation = 0
        if self.rotation and random.randint(0, 1) == 0:
            rotate_num = random.randint(1, 3)
            rerotation = 4 - rotate_num
            image = np.rot90(image, k=rotate_num).copy()
            mask = np.rot90(mask, k=rotate_num).copy()

        # random flip
        if self.fliplr and (random.randint(0, 1) == 0):
            image = np.fliplr(image).copy()
            mask = np.fliplr(mask).copy()

        image = Image.fromarray((image * 255.0).astype('uint8'))
        if self.is_train:
            image = im_train_transform(image)
        else:
            image = im_val_transform(image)

        mask = mask[None, :, :]
        adjusted = image.numpy() * -1

        return (adjusted, mask), (image, )


class OriginalIHarmony4(torchtask.data_template.TaskDataset):
    def __init__(self, args, is_train):
        super(OriginalIHarmony4, self).__init__(args, is_train)

        self.adjusted_dir = os.path.join(self.root_dir, 'comp')
        self.mask_dir = os.path.join(self.root_dir, 'mask')
        self.im_dir = os.path.join(self.root_dir, 'image')

        self.sample_list = [_ for _ in os.listdir(self.adjusted_dir)]
        self.idxs = [_ for _ in range(0, len(self.sample_list))]

        self.im_size = self.args.im_size

        self.rotation = True if self.is_train else False
        self.fliplr = True if self.is_train else False

    def __getitem__(self, idx):
        sname = self.sample_list[idx]
        adjusted_path = os.path.join(self.adjusted_dir, sname)
        image_path = os.path.join(self.im_dir, sname)
        mask_path = os.path.join(self.mask_dir, sname)

        if not os.path.exists(image_path):
            prefix = '_'.join(sname.split('_')[:-1])
            image_path = os.path.join(self.im_dir, '{0}.jpg'.format(prefix))
            mask_path = os.path.join(self.mask_dir, '{0}.jpg'.format(prefix))

        adjusted = self.im_loader.load(adjusted_path)
        image = self.im_loader.load(image_path)
        mask = self.im_loader.load(mask_path)

        width, height = image.size

        # resize to self.im_size
        adjusted = resize(adjusted, (self.im_size, self.im_size))
        image = resize(image, (self.im_size, self.im_size))
        mask = resize(mask, (self.im_size, self.im_size))

        # convert to np array and scale to [0, 1]
        adjusted = np.array(adjusted).astype('float32') / 255.0
        image = np.array(image).astype('float32') / 255.0
        mask = np.array(mask).astype('float32') / 255.0

        # check image shape
        if len(mask.shape) == 3:
            mask = mask[:, :, -1]

        if len(image.shape) == 2:
            image = image[:, :, None]
        if image.shape[2] == 1:
            image = np.repeat(image, 3, axis=2)
        elif image.shape[2] == 4:
            image = image[:, :, 0:3]

        if len(adjusted.shape) == 2:
            adjusted = adjusted[:, :, None]
        if adjusted.shape[2] == 1:
            adjusted = np.repeat(adjusted, 3, axis=2)
        elif adjusted.shape[2] == 4:
            adjusted = adjusted[:, :, 0:3]

        # random rotate
        rerotation = 0
        if self.rotation and random.randint(0, 1) == 0:
            rotate_num = random.randint(1, 3)
            rerotation = 4 - rotate_num
            adjusted = np.rot90(adjusted, k=rotate_num).copy()
            image = np.rot90(image, k=rotate_num).copy()
            mask = np.rot90(mask, k=rotate_num).copy()

        # random flip
        if self.fliplr and (random.randint(0, 1) == 0):
            adjusted = np.fliplr(adjusted).copy()
            image = np.fliplr(image).copy()
            mask = np.fliplr(mask).copy()

        adjusted = Image.fromarray((adjusted * 255.0).astype('uint8'))
        image = Image.fromarray((image * 255.0).astype('uint8'))

        # NOTE: do not add random color adjustment here
        adjusted = im_val_transform(adjusted)
        image = im_val_transform(image)

        mask = mask[None, :, :]

        return (adjusted, mask), (image, )
harmonizer/src/train/harmonizer/func.py
ADDED
@@ -0,0 +1,41 @@
import numpy as np
import torch

import skimage

import torchtask


def task_func():
    return HarmonizationFunc


class HarmonizationFunc(torchtask.func_template.TaskFunc):
    def __init__(self, args):
        super(HarmonizationFunc, self).__init__(args)

    def metrics(self, pred_image, gt_image, mask, meters, id_str=''):
        n, c, h, w = pred_image.shape

        assert n == 1

        total_pixels = h * w
        fg_pixels = int(torch.sum(mask, dim=(2, 3))[0][0].cpu().numpy())

        pred_image = torch.clamp(pred_image * 255, 0, 255)
        gt_image = torch.clamp(gt_image * 255, 0, 255)

        pred_image = pred_image[0].permute(1, 2, 0).detach().cpu().numpy()
        gt_image = gt_image[0].permute(1, 2, 0).detach().cpu().numpy()
        mask = mask[0].permute(1, 2, 0).detach().cpu().numpy()

        batch_mse = skimage.metrics.mean_squared_error(pred_image, gt_image)
        meters.update('{0}_{1}_mse'.format(id_str, self.METRIC_STR), batch_mse)

        batch_fmse = skimage.metrics.mean_squared_error(pred_image * mask, gt_image * mask) * total_pixels / fg_pixels
        meters.update('{0}_{1}_fmse'.format(id_str, self.METRIC_STR), batch_fmse)

        batch_psnr = skimage.metrics.peak_signal_noise_ratio(pred_image, gt_image, data_range=pred_image.max() - pred_image.min())
        meters.update('{0}_{1}_psnr'.format(id_str, self.METRIC_STR), batch_psnr)

        batch_ssim = skimage.metrics.structural_similarity(pred_image, gt_image, multichannel=True)
        meters.update('{0}_{1}_ssim'.format(id_str, self.METRIC_STR), batch_ssim)
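Not part of the commit — the fMSE rescaling used above (plain MSE over the whole frame, rescaled by the total-to-foreground pixel ratio), reproduced standalone:

```python
import numpy as np
import skimage.metrics

pred = np.random.rand(64, 64, 3)
gt = np.random.rand(64, 64, 3)
mask = np.zeros((64, 64, 1))
mask[16:48, 16:48] = 1

total_pixels = 64 * 64
fg_pixels = int(mask.sum())
fmse = skimage.metrics.mean_squared_error(pred * mask, gt * mask) * total_pixels / fg_pixels
```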
harmonizer/src/train/harmonizer/model.py
ADDED
@@ -0,0 +1,41 @@
import torch
import torch.nn.functional as F

import torchtask

from module import harmonizer as _harmonizer


def add_parser_arguments(parser):
    torchtask.model_template.add_parser_arguments(parser)


def harmonizer():
    return Harmonizer


class Harmonizer(torchtask.model_template.TaskModel):
    def __init__(self, args):
        super(Harmonizer, self).__init__(args)

        self.model = _harmonizer.Harmonizer()
        self.param_groups = [
            {'params': filter(lambda p: p.requires_grad, self.model.backbone.parameters()), 'lr': self.args.lr},
            {'params': filter(lambda p: p.requires_grad, self.model.regressor.parameters()), 'lr': self.args.lr},
            {'params': filter(lambda p: p.requires_grad, self.model.performer.parameters()), 'lr': self.args.lr},
        ]

    def forward(self, inp):
        resulter, debugger = {}, {}
        x, mask = inp
        pred = self.model(x, mask)
        resulter['outputs'] = pred
        return resulter, debugger

    def restore(self, x, mask, arguments):
        with torch.no_grad():
            return self.model.restore_image(x, mask, arguments)

    def adjust(self, x, mask, arguments):
        with torch.no_grad():
            return self.model.adjust_image(x, mask, arguments)
harmonizer/src/train/harmonizer/module/__init__.py
ADDED
@@ -0,0 +1 @@
from .harmonizer import Harmonizer
harmonizer/src/train/harmonizer/module/backbone/__init__.py
ADDED
@@ -0,0 +1 @@
from .efficientnet import EfficientBackbone, EfficientBackboneCommon
harmonizer/src/train/harmonizer/module/backbone/efficientnet/__init__.py
ADDED
@@ -0,0 +1,116 @@
"""
This EfficientNet implementation comes from:
Author: lukemelas (github username)
Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
"""

import torch
import torch.nn as nn

from .model import EfficientNet
from .utils import round_filters, get_same_padding_conv2d


# for EfficientNet
class EfficientBackbone(EfficientNet):
    def __init__(self, blocks_args=None, global_params=None):
        super(EfficientBackbone, self).__init__(blocks_args, global_params)

        self.enc_channels = [16, 24, 40, 112, 1280]

        # ------------------------------------------------------------
        # delete the useless layers
        # ------------------------------------------------------------
        del self._conv_stem
        del self._bn0
        # ------------------------------------------------------------

        # ------------------------------------------------------------
        # parameters for the input layers
        # ------------------------------------------------------------
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        in_channels = 4
        out_channels = round_filters(32, self._global_params)
        out_channels = int(out_channels / 2)
        # ------------------------------------------------------------

        # ------------------------------------------------------------
        # define the input layers
        # ------------------------------------------------------------
        image_size = global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._conv_fg = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        self._bn_fg = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

        self._conv_bg = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        self._bn_bg = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
        # ------------------------------------------------------------

    def forward(self, xfg, xbg):
        xfg = self._swish(self._bn_fg(self._conv_fg(xfg)))
        xbg = self._swish(self._bn_bg(self._conv_bg(xbg)))

        x = torch.cat((xfg, xbg), dim=1)

        block_outputs = []
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            drop_connect_rate *= float(idx) / len(self._blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)
            block_outputs.append(x)

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))

        return block_outputs[0], block_outputs[2], block_outputs[4], block_outputs[10], x


# for EfficientNet
class EfficientBackboneCommon(EfficientNet):
    def __init__(self, blocks_args=None, global_params=None):
        super(EfficientBackboneCommon, self).__init__(blocks_args, global_params)

        self.enc_channels = [16, 24, 40, 112, 1280]

        # ------------------------------------------------------------
        # delete the useless layers
        # ------------------------------------------------------------
        del self._conv_stem
        del self._bn0
        # ------------------------------------------------------------

        # ------------------------------------------------------------
        # parameters for the input layers
        # ------------------------------------------------------------
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        in_channels = 3
        out_channels = round_filters(32, self._global_params)
        # ------------------------------------------------------------

        # ------------------------------------------------------------
        # define the input layers
        # ------------------------------------------------------------
        image_size = global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._conv = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        self._bn = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
        # ------------------------------------------------------------

    def forward(self, x):
        x = self._swish(self._bn(self._conv(x)))

        block_outputs = []
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            drop_connect_rate *= float(idx) / len(self._blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)
            block_outputs.append(x)

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))

        return block_outputs[0], block_outputs[2], block_outputs[4], block_outputs[10], x
harmonizer/src/train/harmonizer/module/backbone/efficientnet/model.py
ADDED
@@ -0,0 +1,395 @@
"""model.py - Model and module class for EfficientNet.
   They are built to mirror those in the official TensorFlow implementation.
"""

# Author: lukemelas (github username)
# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
# With adjustments and added comments by workingcoder (github username).

import torch
from torch import nn
from torch.nn import functional as F
from .utils import (
    round_filters,
    round_repeats,
    drop_connect,
    get_same_padding_conv2d,
    get_model_params,
    efficientnet_params,
    load_pretrained_weights,
    Swish,
    MemoryEfficientSwish,
    calculate_output_image_size
)


VALID_MODELS = (
    'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3',
    'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b6', 'efficientnet-b7',
    'efficientnet-b8',

    # Support the construction of 'efficientnet-l2' without pretrained weights
    'efficientnet-l2'
)


class MBConvBlock(nn.Module):
    """Mobile Inverted Residual Bottleneck Block.

    Args:
        block_args (namedtuple): BlockArgs, defined in utils.py.
        global_params (namedtuple): GlobalParam, defined in utils.py.
        image_size (tuple or list): [image_height, image_width].

    References:
        [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
        [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
        [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
    """

    def __init__(self, block_args, global_params, image_size=None):
        super().__init__()
        self._block_args = block_args
        self._bn_mom = 1 - global_params.batch_norm_momentum  # pytorch's difference from tensorflow
        self._bn_eps = global_params.batch_norm_epsilon
        self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
        self.id_skip = block_args.id_skip  # whether to use skip connection and drop connect

        # Expansion phase (Inverted Bottleneck)
        inp = self._block_args.input_filters  # number of input channels
        oup = self._block_args.input_filters * self._block_args.expand_ratio  # number of output channels
        if self._block_args.expand_ratio != 1:
            Conv2d = get_same_padding_conv2d(image_size=image_size)
            self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
            self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
            # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size

        # Depthwise convolution phase
        k = self._block_args.kernel_size
        s = self._block_args.stride
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._depthwise_conv = Conv2d(
            in_channels=oup, out_channels=oup, groups=oup,  # groups makes it depthwise
            kernel_size=k, stride=s, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
        image_size = calculate_output_image_size(image_size, s)

        # Squeeze and Excitation layer, if desired
        if self.has_se:
            Conv2d = get_same_padding_conv2d(image_size=(1, 1))
            num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
            self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
            self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)

        # Pointwise convolution phase
        final_oup = self._block_args.output_filters
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
        self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs, drop_connect_rate=None):
        """MBConvBlock's forward function.

        Args:
            inputs (tensor): Input tensor.
            drop_connect_rate (bool): Drop connect rate (float, between 0 and 1).

        Returns:
            Output of this block after processing.
        """

        # Expansion and Depthwise Convolution
        x = inputs
        if self._block_args.expand_ratio != 1:
            x = self._expand_conv(inputs)
            x = self._bn0(x)
            x = self._swish(x)

        x = self._depthwise_conv(x)
        x = self._bn1(x)
        x = self._swish(x)

        # Squeeze and Excitation
        if self.has_se:
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self._se_reduce(x_squeezed)
            x_squeezed = self._swish(x_squeezed)
            x_squeezed = self._se_expand(x_squeezed)
            x = torch.sigmoid(x_squeezed) * x

        # Pointwise Convolution
        x = self._project_conv(x)
        x = self._bn2(x)

        # Skip connection and drop connect
        input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
        if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
            # The combination of skip connection and drop connect brings about stochastic depth.
            if drop_connect_rate:
                x = drop_connect(x, p=drop_connect_rate, training=self.training)
            x = x + inputs  # skip connection
        return x

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.
        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()


class EfficientNet(nn.Module):
    """EfficientNet model.
       Most easily loaded with the .from_name or .from_pretrained methods.

    Args:
        blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
        global_params (namedtuple): A set of GlobalParams shared between blocks.

    References:
        [1] https://arxiv.org/abs/1905.11946 (EfficientNet)

    Example:
        >>> import torch
        >>> from efficientnet.model import EfficientNet
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> model = EfficientNet.from_pretrained('efficientnet-b0')
        >>> model.eval()
        >>> outputs = model(inputs)
    """

    def __init__(self, blocks_args=None, global_params=None):
        super().__init__()
        assert isinstance(blocks_args, list), 'blocks_args should be a list'
        assert len(blocks_args) > 0, 'block args must be greater than 0'
        self._global_params = global_params
        self._blocks_args = blocks_args

        # Batch norm parameters
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Get stem static or dynamic convolution depending on image size
        image_size = global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=image_size)

        # Stem
        in_channels = 3  # rgb
        out_channels = round_filters(32, self._global_params)  # number of output channels
        self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
        image_size = calculate_output_image_size(image_size, 2)

        # Build blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:

            # Update block input and output filters based on depth multiplier.
            block_args = block_args._replace(
                input_filters=round_filters(block_args.input_filters, self._global_params),
                output_filters=round_filters(block_args.output_filters, self._global_params),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params)
            )

            # The first block needs to take care of stride and filter size increase.
            self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
            image_size = calculate_output_image_size(image_size, block_args.stride)
            if block_args.num_repeat > 1:  # modify block_args to keep same output size
                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
                # image_size = calculate_output_image_size(image_size, block_args.stride)  # stride = 1

        # Head
        in_channels = block_args.output_filters  # output of final block
        out_channels = round_filters(1280, self._global_params)
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

        # Final linear layer
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        if self._global_params.include_top:
            self._dropout = nn.Dropout(self._global_params.dropout_rate)
            self._fc = nn.Linear(out_channels, self._global_params.num_classes)

        # set activation to memory efficient swish by default
        self._swish = MemoryEfficientSwish()

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.
        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def extract_endpoints(self, inputs):
        """Use convolution layer to extract features
        from reduction levels i in [1, 2, 3, 4, 5].

        Args:
            inputs (tensor): Input tensor.

        Returns:
            Dictionary of last intermediate features
            with reduction levels i in [1, 2, 3, 4, 5].

        Example:
            >>> import torch
            >>> from efficientnet.model import EfficientNet
            >>> inputs = torch.rand(1, 3, 224, 224)
            >>> model = EfficientNet.from_pretrained('efficientnet-b0')
            >>> endpoints = model.extract_endpoints(inputs)
            >>> print(endpoints['reduction_1'].shape)  # torch.Size([1, 16, 112, 112])
            >>> print(endpoints['reduction_2'].shape)  # torch.Size([1, 24, 56, 56])
            >>> print(endpoints['reduction_3'].shape)  # torch.Size([1, 40, 28, 28])
            >>> print(endpoints['reduction_4'].shape)  # torch.Size([1, 112, 14, 14])
            >>> print(endpoints['reduction_5'].shape)  # torch.Size([1, 320, 7, 7])
            >>> print(endpoints['reduction_6'].shape)  # torch.Size([1, 1280, 7, 7])
        """
        endpoints = dict()

        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))
        prev_x = x

        # Blocks
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
            x = block(x, drop_connect_rate=drop_connect_rate)
            if prev_x.size(2) > x.size(2):
                endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x
            elif idx == len(self._blocks) - 1:
                endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
            prev_x = x

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))
        endpoints['reduction_{}'.format(len(endpoints) + 1)] = x

        return endpoints

    def extract_features(self, inputs):
        """Use convolution layers to extract features.

        Args:
            inputs (tensor): Input tensor.

        Returns:
            Output of the final convolution
            layer in the efficientnet model.
        """
        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))

        # Blocks
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
            x = block(x, drop_connect_rate=drop_connect_rate)

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))

        return x

    def forward(self, inputs):
        """EfficientNet's forward function.
           Calls extract_features to extract features, applies final linear layer, and returns logits.

        Args:
            inputs (tensor): Input tensor.

        Returns:
            Output of this model after processing.
        """
        # Convolution layers
        x = self.extract_features(inputs)
        # Pooling and final linear layer
        x = self._avg_pooling(x)
        if self._global_params.include_top:
            x = x.flatten(start_dim=1)
            x = self._dropout(x)
            x = self._fc(x)
        return x

    @classmethod
    def from_name(cls, model_name, in_channels=3, **override_params):
        """Create an efficientnet model according to name.

        Args:
            model_name (str): Name for efficientnet.
            in_channels (int): Input data's channel number.
            override_params (other key word params):
                Params to override model's global_params.
                Optional key:
                    'width_coefficient', 'depth_coefficient',
                    'image_size', 'dropout_rate',
                    'num_classes', 'batch_norm_momentum',
                    'batch_norm_epsilon', 'drop_connect_rate',
                    'depth_divisor', 'min_depth'

        Returns:
            An efficientnet model.
        """
        cls._check_model_name_is_valid(model_name)
        blocks_args, global_params = get_model_params(model_name, override_params)
        model = cls(blocks_args, global_params)
        model._change_in_channels(in_channels)
        return model

    @classmethod
    def from_pretrained(cls, model_name, weights_path=None, advprop=False,
                        in_channels=3, num_classes=1000, **override_params):
        """Create an efficientnet model according to name.

        Args:
            model_name (str): Name for efficientnet.
            weights_path (None or str):
                str: path to pretrained weights file on the local disk.
                None: use pretrained weights downloaded from the Internet.
            advprop (bool):
                Whether to load pretrained weights
                trained with advprop (valid when weights_path is None).
            in_channels (int): Input data's channel number.
            num_classes (int):
                Number of categories for classification.
                It controls the output size for final linear layer.
            override_params (other key word params):
                Params to override model's global_params.
                Optional key:
                    'width_coefficient', 'depth_coefficient',
                    'image_size', 'dropout_rate',
                    'batch_norm_momentum',
                    'batch_norm_epsilon', 'drop_connect_rate',
                    'depth_divisor', 'min_depth'

        Returns:
            A pretrained efficientnet model.
        """
        model = cls.from_name(model_name, num_classes=num_classes, **override_params)
        load_pretrained_weights(model, model_name, weights_path=weights_path,
                                load_fc=(num_classes == 1000), advprop=advprop)
        model._change_in_channels(in_channels)
        return model

    @classmethod
    def get_image_size(cls, model_name):
        """Get the input image size for a given efficientnet model.

        Args:
            model_name (str): Name for efficientnet.

        Returns:
            Input image size (resolution).
        """
        cls._check_model_name_is_valid(model_name)
        _, _, res, _ = efficientnet_params(model_name)
        return res

    @classmethod
    def _check_model_name_is_valid(cls, model_name):
        """Validates model name.

        Args:
            model_name (str): Name for efficientnet.

        Returns:
            bool: Is a valid name or not.
        """
        if model_name not in VALID_MODELS:
            raise ValueError('model_name should be one of: ' + ', '.join(VALID_MODELS))

    def _change_in_channels(self, in_channels):
        """Adjust model's first convolution layer to in_channels, if in_channels not equals 3.

        Args:
            in_channels (int): Input data's channel number.
        """
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
            out_channels = round_filters(32, self._global_params)
            self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
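
For reference, the per-block drop-connect schedule in extract_features and extract_endpoints scales the base rate linearly with block index, so the first block is never dropped and the deepest blocks are dropped most often. A small sketch with illustrative numbers (efficientnet-b0 has 16 blocks and a base rate of 0.2):

    base_rate, num_blocks = 0.2, 16
    rates = [base_rate * idx / num_blocks for idx in range(num_blocks)]
    # rates[0] == 0.0, rates[-1] == 0.1875 -- stochastic depth grows with depth
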
harmonizer/src/train/harmonizer/module/backbone/efficientnet/utils.py
ADDED
@@ -0,0 +1,586 @@
"""utils.py - Helper functions for building the model and for loading model parameters.
   These helper functions are built to mirror those in the official TensorFlow implementation.
"""

# Author: lukemelas (github username)
# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
# With adjustments and added comments by workingcoder (github username).

import re
import math
import collections
from functools import partial
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import model_zoo


################################################################################
# Help functions for model architecture
################################################################################

# GlobalParams and BlockArgs: Two namedtuples
# Swish and MemoryEfficientSwish: Two implementations of the method
# round_filters and round_repeats:
#     Functions to calculate params for scaling model width and depth ! ! !
# get_width_and_height_from_size and calculate_output_image_size
# drop_connect: A structural design
# get_same_padding_conv2d:
#     Conv2dDynamicSamePadding
#     Conv2dStaticSamePadding
# get_same_padding_maxPool2d:
#     MaxPool2dDynamicSamePadding
#     MaxPool2dStaticSamePadding
#     It's an additional function, not used in EfficientNet,
#     but can be used in other models (such as EfficientDet).

# Parameters for the entire model (stem, all blocks, and head)
GlobalParams = collections.namedtuple('GlobalParams', [
    'width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate',
    'num_classes', 'batch_norm_momentum', 'batch_norm_epsilon',
    'drop_connect_rate', 'depth_divisor', 'min_depth', 'include_top'])

# Parameters for an individual model block
BlockArgs = collections.namedtuple('BlockArgs', [
    'num_repeat', 'kernel_size', 'stride', 'expand_ratio',
    'input_filters', 'output_filters', 'se_ratio', 'id_skip'])

# Set GlobalParams and BlockArgs's defaults
GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)

# Swish activation function
if hasattr(nn, 'SiLU'):
    Swish = nn.SiLU
else:
    # For compatibility with old PyTorch versions
    class Swish(nn.Module):
        def forward(self, x):
            return x * torch.sigmoid(x)


# A memory-efficient implementation of the Swish function
class SwishImplementation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        i = ctx.saved_tensors[0]
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class MemoryEfficientSwish(nn.Module):
    def forward(self, x):
        return SwishImplementation.apply(x)


def round_filters(filters, global_params):
    """Calculate and round number of filters based on width multiplier.
       Use width_coefficient, depth_divisor and min_depth of global_params.

    Args:
        filters (int): Filters number to be calculated.
        global_params (namedtuple): Global params of the model.

    Returns:
        new_filters: New filters number after calculating.
    """
    multiplier = global_params.width_coefficient
    if not multiplier:
        return filters
    # TODO: modify the params names.
    #       maybe the names (width_divisor,min_width)
    #       are more suitable than (depth_divisor,min_depth).
    divisor = global_params.depth_divisor
    min_depth = global_params.min_depth
    filters *= multiplier
    min_depth = min_depth or divisor  # pay attention to this line when using min_depth
    # follow the formula transferred from the official TensorFlow implementation
    new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
    if new_filters < 0.9 * filters:  # prevent rounding by more than 10%
        new_filters += divisor
    return int(new_filters)


def round_repeats(repeats, global_params):
    """Calculate module's repeat number of a block based on depth multiplier.
       Use depth_coefficient of global_params.

    Args:
        repeats (int): num_repeat to be calculated.
        global_params (namedtuple): Global params of the model.

    Returns:
        new repeat: New repeat number after calculating.
    """
    multiplier = global_params.depth_coefficient
    if not multiplier:
        return repeats
    # follow the formula transferred from the official TensorFlow implementation
    return int(math.ceil(multiplier * repeats))


def drop_connect(inputs, p, training):
    """Drop connect.

    Args:
        inputs (tensor: BCWH): Input of this structure.
        p (float: 0.0~1.0): Probability of drop connection.
        training (bool): The running mode.

    Returns:
        output: Output after drop connection.
    """
    assert 0 <= p <= 1, 'p must be in range of [0,1]'

    if not training:
        return inputs

    batch_size = inputs.shape[0]
    keep_prob = 1 - p

    # generate binary_tensor mask according to probability (p for 0, 1-p for 1)
    random_tensor = keep_prob
    random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
    binary_tensor = torch.floor(random_tensor)

    output = inputs / keep_prob * binary_tensor
    return output


def get_width_and_height_from_size(x):
    """Obtain height and width from x.

    Args:
        x (int, tuple or list): Data size.

    Returns:
        size: A tuple or list (H,W).
    """
    if isinstance(x, int):
        return x, x
    if isinstance(x, list) or isinstance(x, tuple):
        return x
    else:
        raise TypeError()


def calculate_output_image_size(input_image_size, stride):
    """Calculates the output image size when using Conv2dSamePadding with a stride.
       Necessary for static padding. Thanks to mannatsingh for pointing this out.

    Args:
        input_image_size (int, tuple or list): Size of input image.
        stride (int, tuple or list): Conv2d operation's stride.

    Returns:
        output_image_size: A list [H,W].
    """
    if input_image_size is None:
        return None
    image_height, image_width = get_width_and_height_from_size(input_image_size)
    stride = stride if isinstance(stride, int) else stride[0]
    image_height = int(math.ceil(image_height / stride))
    image_width = int(math.ceil(image_width / stride))
    return [image_height, image_width]


# Note:
# The following 'SamePadding' functions make output size equal ceil(input size/stride).
# Only when stride equals 1, can the output size be the same as input size.
# Don't be confused by their function names ! ! !

def get_same_padding_conv2d(image_size=None):
    """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
       Static padding is necessary for ONNX exporting of models.

    Args:
        image_size (int or tuple): Size of the image.

    Returns:
        Conv2dDynamicSamePadding or Conv2dStaticSamePadding.
    """
    if image_size is None:
        return Conv2dDynamicSamePadding
    else:
        return partial(Conv2dStaticSamePadding, image_size=image_size)


class Conv2dDynamicSamePadding(nn.Conv2d):
    """2D Convolutions like TensorFlow, for a dynamic image size.
       The padding is operated in forward function by calculating dynamically.
    """

    # Tips for 'SAME' mode padding.
    #     Given the following:
    #         i: width or height
    #         s: stride
    #         k: kernel size
    #         d: dilation
    #         p: padding
    #     Output after Conv2d:
    #         o = floor((i+p-((k-1)*d+1))/s+1)
    # If o equals i, i = floor((i+p-((k-1)*d+1))/s+1),
    # => p = (i-1)*s+((k-1)*d+1)-i

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
        super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

    def forward(self, x):
        ih, iw = x.size()[-2:]
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)  # change the output size according to stride ! ! !
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


class Conv2dStaticSamePadding(nn.Conv2d):
    """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size.
       The padding module is calculated in construction function, then used in forward.
    """

    # With the same calculation as Conv2dDynamicSamePadding

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

        # Calculate padding based on image size and save it
        assert image_size is not None
        ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2,
                                                pad_h // 2, pad_h - pad_h // 2))
        else:
            self.static_padding = nn.Identity()

    def forward(self, x):
        x = self.static_padding(x)
        x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return x


def get_same_padding_maxPool2d(image_size=None):
    """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
       Static padding is necessary for ONNX exporting of models.

    Args:
        image_size (int or tuple): Size of the image.

    Returns:
        MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding.
    """
    if image_size is None:
        return MaxPool2dDynamicSamePadding
    else:
        return partial(MaxPool2dStaticSamePadding, image_size=image_size)


class MaxPool2dDynamicSamePadding(nn.MaxPool2d):
    """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size.
       The padding is operated in forward function by calculating dynamically.
    """

    def __init__(self, kernel_size, stride, padding=0, dilation=1, return_indices=False, ceil_mode=False):
        super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode)
        self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
        self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size
        self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation

    def forward(self, x):
        ih, iw = x.size()[-2:]
        kh, kw = self.kernel_size
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
        return F.max_pool2d(x, self.kernel_size, self.stride, self.padding,
                            self.dilation, self.ceil_mode, self.return_indices)


class MaxPool2dStaticSamePadding(nn.MaxPool2d):
    """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size.
       The padding module is calculated in construction function, then used in forward.
    """

    def __init__(self, kernel_size, stride, image_size=None, **kwargs):
        super().__init__(kernel_size, stride, **kwargs)
        self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
        self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size
        self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation

        # Calculate padding based on image size and save it
        assert image_size is not None
        ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
        kh, kw = self.kernel_size
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
        else:
            self.static_padding = nn.Identity()

    def forward(self, x):
        x = self.static_padding(x)
        x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding,
                         self.dilation, self.ceil_mode, self.return_indices)
        return x


################################################################################
# Helper functions for loading model params
################################################################################

# BlockDecoder: A Class for encoding and decoding BlockArgs
# efficientnet_params: A function to query compound coefficient
# get_model_params and efficientnet:
#     Functions to get BlockArgs and GlobalParams for efficientnet
# url_map and url_map_advprop: Dicts of url_map for pretrained weights
# load_pretrained_weights: A function to load pretrained weights

class BlockDecoder(object):
    """Block Decoder for readability,
       straight from the official TensorFlow repository.
    """

    @staticmethod
    def _decode_block_string(block_string):
        """Get a block through a string notation of arguments.

        Args:
            block_string (str): A string notation of arguments.
                                Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.

        Returns:
            BlockArgs: The namedtuple defined at the top of this file.
        """
        assert isinstance(block_string, str)

        ops = block_string.split('_')
        options = {}
        for op in ops:
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        # Check stride
        assert (('s' in options and len(options['s']) == 1) or
                (len(options['s']) == 2 and options['s'][0] == options['s'][1]))

        return BlockArgs(
            num_repeat=int(options['r']),
            kernel_size=int(options['k']),
            stride=[int(options['s'][0])],
            expand_ratio=int(options['e']),
            input_filters=int(options['i']),
            output_filters=int(options['o']),
            se_ratio=float(options['se']) if 'se' in options else None,
            id_skip=('noskip' not in block_string))

    @staticmethod
    def _encode_block_string(block):
        """Encode a block to a string.

        Args:
            block (namedtuple): A BlockArgs type argument.

        Returns:
            block_string: A String form of BlockArgs.
        """
        args = [
            'r%d' % block.num_repeat,
            'k%d' % block.kernel_size,
            's%d%d' % (block.strides[0], block.strides[1]),
            'e%s' % block.expand_ratio,
            'i%d' % block.input_filters,
            'o%d' % block.output_filters
        ]
        if 0 < block.se_ratio <= 1:
            args.append('se%s' % block.se_ratio)
        if block.id_skip is False:
            args.append('noskip')
        return '_'.join(args)

    @staticmethod
    def decode(string_list):
        """Decode a list of string notations to specify blocks inside the network.

        Args:
            string_list (list[str]): A list of strings, each string is a notation of block.

        Returns:
            blocks_args: A list of BlockArgs namedtuples of block args.
        """
        assert isinstance(string_list, list)
        blocks_args = []
        for block_string in string_list:
            blocks_args.append(BlockDecoder._decode_block_string(block_string))
        return blocks_args

    @staticmethod
    def encode(blocks_args):
        """Encode a list of BlockArgs to a list of strings.

        Args:
            blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.

        Returns:
            block_strings: A list of strings, each string is a notation of block.
        """
        block_strings = []
        for block in blocks_args:
            block_strings.append(BlockDecoder._encode_block_string(block))
        return block_strings


def efficientnet_params(model_name):
    """Map EfficientNet model name to parameter coefficients.

    Args:
        model_name (str): Model name to be queried.

    Returns:
        params_dict[model_name]: A (width,depth,res,dropout) tuple.
    """
    params_dict = {
        # Coefficients:   width,depth,res,dropout
        'efficientnet-b0': (1.0, 1.0, 224, 0.2),
        'efficientnet-b1': (1.0, 1.1, 240, 0.2),
        'efficientnet-b2': (1.1, 1.2, 260, 0.3),
        'efficientnet-b3': (1.2, 1.4, 300, 0.3),
        'efficientnet-b4': (1.4, 1.8, 380, 0.4),
        'efficientnet-b5': (1.6, 2.2, 456, 0.4),
        'efficientnet-b6': (1.8, 2.6, 528, 0.5),
        'efficientnet-b7': (2.0, 3.1, 600, 0.5),
        'efficientnet-b8': (2.2, 3.6, 672, 0.5),
        'efficientnet-l2': (4.3, 5.3, 800, 0.5),
    }
    return params_dict[model_name]


def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None,
                 dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000, include_top=False):
    """Create BlockArgs and GlobalParams for efficientnet model.

    Args:
        width_coefficient (float)
        depth_coefficient (float)
        image_size (int)
        dropout_rate (float)
        drop_connect_rate (float)
        num_classes (int)
            Meaning as the name suggests.

    Returns:
        blocks_args, global_params.
    """

    # Blocks args for the whole model (efficientnet-b0 by default)
    # It will be modified in the construction of EfficientNet Class according to model
    blocks_args = [
        'r1_k3_s11_e1_i32_o16_se0.25',
        'r2_k3_s22_e6_i16_o24_se0.25',
        'r2_k5_s22_e6_i24_o40_se0.25',
        'r3_k3_s22_e6_i40_o80_se0.25',
        'r3_k5_s11_e6_i80_o112_se0.25',
        'r4_k5_s22_e6_i112_o192_se0.25',
        'r1_k3_s11_e6_i192_o320_se0.25',
    ]
    blocks_args = BlockDecoder.decode(blocks_args)

    global_params = GlobalParams(
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        image_size=image_size,
        dropout_rate=dropout_rate,

        num_classes=num_classes,
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        drop_connect_rate=drop_connect_rate,
        depth_divisor=8,
        min_depth=None,
        include_top=include_top,
    )

    return blocks_args, global_params


def get_model_params(model_name, override_params):
    """Get the block args and global params for a given model name.

    Args:
        model_name (str): Model's name.
        override_params (dict): A dict to modify global_params.

    Returns:
        blocks_args, global_params
    """
    if model_name.startswith('efficientnet'):
        w, d, s, p = efficientnet_params(model_name)
        # note: all models have drop connect rate = 0.2
        blocks_args, global_params = efficientnet(
            width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s)
    else:
        raise NotImplementedError('model name is not pre-defined: {}'.format(model_name))
    if override_params:
        # ValueError will be raised here if override_params has fields not included in global_params.
        global_params = global_params._replace(**override_params)
    return blocks_args, global_params


# trained with Standard methods
# check more details in paper (EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks)
url_map = {
    'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth',
    'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth',
    'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth',
    'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth',
    'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth',
    'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth',
    'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth',
    'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth',
}

# trained with Adversarial Examples (AdvProp)
# check more details in paper (Adversarial Examples Improve Image Recognition)
url_map_advprop = {
    'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth',
    'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth',
    'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth',
    'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth',
    'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth',
    'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth',
    'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth',
    'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth',
    'efficientnet-b8': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth',
}

# TODO: add the pretrained weights url map of 'efficientnet-l2'


def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False, verbose=True):
    """Loads pretrained weights from weights path or download using url.

    Args:
        model (Module): The whole model of efficientnet.
        model_name (str): Model name of efficientnet.
        weights_path (None or str):
            str: path to pretrained weights file on the local disk.
            None: use pretrained weights downloaded from the Internet.
        load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model.
        advprop (bool): Whether to load pretrained weights
                        trained with advprop (valid when weights_path is None).
    """
    if isinstance(weights_path, str):
        state_dict = torch.load(weights_path)
    else:
        # AutoAugment or Advprop (different preprocessing)
        url_map_ = url_map_advprop if advprop else url_map
        state_dict = model_zoo.load_url(url_map_[model_name])

    if load_fc:
        ret = model.load_state_dict(state_dict, strict=False)
        # assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
    else:
        state_dict.pop('_fc.weight')
        state_dict.pop('_fc.bias')
        ret = model.load_state_dict(state_dict, strict=False)
        # assert set(ret.missing_keys) == set(
        #     ['_fc.weight', '_fc.bias']), 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
    # assert not ret.unexpected_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.unexpected_keys)

    if verbose:
        print('Loaded pretrained weights for {}'.format(model_name))
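
To make the 'SAME' padding formula above concrete, here is a worked sketch with illustrative numbers (a 224x224 input, 3x3 kernel, stride 2, dilation 1):

    import math

    i, k, s, d = 224, 3, 2, 1
    o = math.ceil(i / s)                              # target output size: 112
    pad = max((o - 1) * s + (k - 1) * d + 1 - i, 0)   # total padding needed: 1
    # Conv2dDynamicSamePadding splits this as (pad // 2, pad - pad // 2) = (0, 1)
    # per spatial dimension, matching TensorFlow's asymmetric 'SAME' behaviour.

The block-string notation handled by BlockDecoder can likewise be checked by hand; decoding the first b0 block string from efficientnet() gives:

    first_block = BlockDecoder.decode(['r1_k3_s11_e1_i32_o16_se0.25'])[0]
    # -> BlockArgs(num_repeat=1, kernel_size=3, stride=[1], expand_ratio=1,
    #              input_filters=32, output_filters=16, se_ratio=0.25, id_skip=True)
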
harmonizer/src/train/harmonizer/module/filter.py
ADDED
@@ -0,0 +1,231 @@
import math
from enum import Enum

import torch
import torch.nn as nn
import torch.nn.functional as F
import kornia


class BrightnessFilter(nn.Module):
    def __init__(self):
        super(BrightnessFilter, self).__init__()
        self.epsilon = 1e-6

    def forward(self, image, x):
        """
        Arguments:
            image (tensor [n, 3, h, w]): RGB image with pixel values between [0, 1]
            x (tensor [n, 1, 1, 1]): brightness argument with values between [-1, 1]
        """

        # convert image from RGB to HSV
        image = kornia.color.rgb_to_hsv(image)
        h = image[:, 0:1, :, :]
        s = image[:, 1:2, :, :]
        v = image[:, 2:3, :, :]

        # calculate alpha
        amask = (x >= 0).float()
        alpha = (1 / ((1 - x) + self.epsilon)) * amask + (x + 1) * (1 - amask)

        # adjust the V channel
        v = v * alpha

        # convert image from HSV to RGB
        image = torch.cat((h, s, v), dim=1)
        image = kornia.color.hsv_to_rgb(image)

        # clip pixel values to [0, 1]
        image = torch.clamp(image, 0.0, 1.0)

        return image


class ContrastFilter(nn.Module):
    def __init__(self):
        super(ContrastFilter, self).__init__()

    def forward(self, image, x):
        """
        Arguments:
            image (tensor [n, 3, h, w]): RGB image with pixel values between [0, 1]
            x (tensor [n, 1, 1, 1]): contrast argument with values between [-1, 1]
        """

        # calculate the mean of the image as the threshold
        threshold = torch.mean(image, dim=(1, 2, 3), keepdim=True)

        # pre-process x if it is a positive value
        mask = (x.detach() > 0).float()
        x_ = 255 / (256 - torch.floor(x * 255)) - 1
        x_ = x * (1 - mask) + x_ * mask

        # modify the contrast of the image
        image = image + (image - threshold) * x_

        # clip pixel values to [0, 1]
        image = torch.clamp(image, 0.0, 1.0)

        return image


class SaturationFilter(nn.Module):
    def __init__(self):
        super(SaturationFilter, self).__init__()

        self.epsilon = 1e-6

    def forward(self, image, x):
        """
        Arguments:
            image (tensor [n, 3, h, w]): RGB image with pixel values between [0, 1]
            x (tensor [n, 1, 1, 1]): saturation argument with values between [-1, 1]
        """

        # calculate the basic properties of the image
        cmin = torch.min(image, dim=1, keepdim=True)[0]
        cmax = torch.max(image, dim=1, keepdim=True)[0]
        var = cmax - cmin
        ran = cmax + cmin
        mean = ran / 2

        is_positive = (x.detach() >= 0).float()

        # calculate s
        m = (mean < 0.5).float()
        s = (var / (ran + self.epsilon)) * m + (var / (2 - ran + self.epsilon)) * (1 - m)

        # if x is positive
        m = ((x + s) > 1).float()
        a_pos = s * m + (1 - x) * (1 - m)
        a_pos = 1 / (a_pos + self.epsilon) - 1

        # if x is negative
        a_neg = 1 + x

        a = a_pos * is_positive + a_neg * (1 - is_positive)
        image = image * is_positive + mean * (1 - is_positive) + (image - mean) * a

        # clip pixel values to [0, 1]
        image = torch.clamp(image, 0.0, 1.0)

        return image


class TemperatureFilter(nn.Module):
    def __init__(self):
        super(TemperatureFilter, self).__init__()

        self.epsilon = 1e-6

    def forward(self, image, x):
        """
        Arguments:
            image (tensor [n, 3, h, w]): RGB image with pixel values between [0, 1]
            x (tensor [n, 1, 1, 1]): color temperature argument with values between [-1, 1]
        """
        # split the R/G/B channels
        R, G, B = image[:, 0:1, ...], image[:, 1:2, ...], image[:, 2:3, ...]

        # calculate the mean of each channel
        meanR = torch.mean(R, dim=(2, 3), keepdim=True)
        meanG = torch.mean(G, dim=(2, 3), keepdim=True)
        meanB = torch.mean(B, dim=(2, 3), keepdim=True)

        # calculate correction factors
        gray = (meanR + meanG + meanB) / 3
        coefR = gray / (meanR + self.epsilon)
        coefG = gray / (meanG + self.epsilon)
        coefB = gray / (meanB + self.epsilon)
        aR = 1 - coefR
        aG = 1 - coefG
        aB = 1 - coefB

        # adjust temperature
        is_positive = (x.detach() > 0).float()
        is_negative = (x.detach() < 0).float()
        is_zero = (x.detach() == 0).float()

        meanR_ = meanR + x * torch.sign(x) * is_negative
        meanG_ = meanG + x * torch.sign(x) * 0.5 * (1 - is_zero)
        meanB_ = meanB + x * torch.sign(x) * is_positive
        gray_ = (meanR_ + meanG_ + meanB_) / 3

        coefR_ = gray_ / (meanR_ + self.epsilon) + aR
        coefG_ = gray_ / (meanG_ + self.epsilon) + aG
        coefB_ = gray_ / (meanB_ + self.epsilon) + aB

        R_ = coefR_ * R
        G_ = coefG_ * G
        B_ = coefB_ * B

        # the RGB image with the adjusted color temperature
        image = torch.cat((R_, G_, B_), dim=1)

        # clip pixel values to [0, 1]
        image = torch.clamp(image, 0.0, 1.0)

        return image

+
class HighlightFilter(nn.Module):
|
| 173 |
+
def __init__(self):
|
| 174 |
+
super(HighlightFilter, self).__init__()
|
| 175 |
+
|
| 176 |
+
def forward(self, image, x):
|
| 177 |
+
"""
|
| 178 |
+
Arguments:
|
| 179 |
+
image(tensor [n, 3, h, w]): RGB image with pixel values between [0, 1]
|
| 180 |
+
x (tensor [n, 1, 1, 1]): highlight argument with values between [-1, 1]
|
| 181 |
+
"""
|
| 182 |
+
|
| 183 |
+
x = x + 1
|
| 184 |
+
|
| 185 |
+
image = kornia.enhance.invert(image, image.detach() * 0 + 1)
|
| 186 |
+
image = torch.clamp(torch.pow(image + 1e-9, x), 0.0, 1.0)
|
| 187 |
+
image = kornia.enhance.invert(image, image.detach() * 0 + 1)
|
| 188 |
+
|
| 189 |
+
# clip pixel values to [0, 1]
|
| 190 |
+
image = torch.clamp(image, 0.0, 1.0)
|
| 191 |
+
|
| 192 |
+
return image
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class ShadowFilter(nn.Module):
|
| 196 |
+
def __init__(self):
|
| 197 |
+
super(ShadowFilter, self).__init__()
|
| 198 |
+
|
| 199 |
+
def forward(self, image, x):
|
| 200 |
+
"""
|
| 201 |
+
Arguments:
|
| 202 |
+
image(tensor [n, 3, h, w]): RGB image with pixel values between [0, 1]
|
| 203 |
+
x (tensor [n, 1, 1, 1]): shadow argument with values between [-1, 1]
|
| 204 |
+
"""
|
| 205 |
+
|
| 206 |
+
x = -x + 1
|
| 207 |
+
image = torch.clamp(torch.pow(image + 1e-9, x), 0.0, 1.0)
|
| 208 |
+
|
| 209 |
+
# clip pixel values to [0, 1]
|
| 210 |
+
image = torch.clamp(image, 0.0, 1.0)
|
| 211 |
+
|
| 212 |
+
return image
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class Filter(Enum):
|
| 216 |
+
BRIGHTNESS = 1
|
| 217 |
+
CONTRAST = 2
|
| 218 |
+
SATURATION = 3
|
| 219 |
+
TEMPERATURE = 4
|
| 220 |
+
HIGHLIGHT = 5
|
| 221 |
+
SHADOW = 6
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
FILTER_MODULES = {
|
| 225 |
+
Filter.BRIGHTNESS: BrightnessFilter,
|
| 226 |
+
Filter.CONTRAST: ContrastFilter,
|
| 227 |
+
Filter.SATURATION: SaturationFilter,
|
| 228 |
+
Filter.TEMPERATURE: TemperatureFilter,
|
| 229 |
+
Filter.HIGHLIGHT: HighlightFilter,
|
| 230 |
+
Filter.SHADOW: ShadowFilter,
|
| 231 |
+
}
|
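A quick usage sketch for the filters above (assuming torch and kornia are installed; the tensor shapes and the 0.5 argument are illustrative, not values used by the project):

import torch

# apply BrightnessFilter to a random image batch
f = BrightnessFilter()
image = torch.rand(2, 3, 64, 64)        # RGB in [0, 1]
x = torch.full((2, 1, 1, 1), 0.5)       # positive argument -> brighten
out = f(image, x)
assert out.shape == image.shape
assert out.min() >= 0.0 and out.max() <= 1.0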
harmonizer/src/train/harmonizer/module/harmonizer.py
ADDED
|
@@ -0,0 +1,83 @@
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms.functional as tf

from .filter import Filter
from .backbone import EfficientBackbone
from .module import CascadeArgumentRegressor, FilterPerformer


class Harmonizer(nn.Module):
    def __init__(self):
        super(Harmonizer, self).__init__()

        self.input_size = (256, 256)
        self.filter_types = [
            Filter.TEMPERATURE,
            Filter.BRIGHTNESS,
            Filter.CONTRAST,
            Filter.SATURATION,
            Filter.HIGHLIGHT,
            Filter.SHADOW,
        ]
        self.filter_argument_ranges = [
            0.3,
            0.5,
            0.5,
            0.6,
            0.4,
            0.4,
        ]

        self.backbone = EfficientBackbone.from_name('efficientnet-b0')
        self.regressor = CascadeArgumentRegressor(1280, 160, 1, len(self.filter_types))
        self.performer = FilterPerformer(self.filter_types)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                self._init_conv(m)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
                self._init_norm(m)

        # re-create the backbone with pretrained weights
        # (this overrides the random initialization applied above)
        self.backbone = EfficientBackbone.from_pretrained('efficientnet-b0')

    def forward(self, comp, mask):
        arguments = self.predict_arguments(comp, mask)
        pred = self.restore_image(comp, mask, arguments)
        return pred

    def predict_arguments(self, comp, mask):
        comp = F.interpolate(comp, self.input_size, mode='bilinear', align_corners=False)
        mask = F.interpolate(mask, self.input_size, mode='bilinear', align_corners=False)

        fg = torch.cat((comp, mask), dim=1)
        bg = torch.cat((comp, (1 - mask)), dim=1)

        enc2x, enc4x, enc8x, enc16x, enc32x = self.backbone(fg, bg)
        arguments = self.regressor(enc32x)
        return arguments

    def restore_image(self, comp, mask, arguments):
        assert len(arguments) == len(self.filter_types)

        arguments = [torch.clamp(arg, -1, 1).view(-1, 1, 1, 1) for arg in arguments]
        return self.performer.restore(comp, mask, arguments)

    def adjust_image(self, image, mask, arguments):
        assert len(arguments) == len(self.filter_types)

        arguments = [(torch.clamp(arg, -1, 1) * r).view(-1, 1, 1, 1)
                     for arg, r in zip(arguments, self.filter_argument_ranges)]
        return self.performer.adjust(image, mask, arguments)

    def _init_conv(self, conv):
        nn.init.kaiming_uniform_(
            conv.weight, a=0, mode='fan_in', nonlinearity='relu')
        if conv.bias is not None:
            nn.init.constant_(conv.bias, 0)

    def _init_norm(self, bn):
        if bn.weight is not None:
            nn.init.constant_(bn.weight, 1)
            nn.init.constant_(bn.bias, 0)
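For reference, the inference flow implied by the class above, as a sketch (it assumes the harmonizer package and its pretrained EfficientNet backbone are importable, which is handled elsewhere in this repository):

import torch

model = Harmonizer().cuda().eval()
comp = torch.rand(1, 3, 512, 512).cuda()   # composite image, RGB in [0, 1]
mask = torch.rand(1, 1, 512, 512).cuda()   # foreground mask in [0, 1]

with torch.no_grad():
    args = model.predict_arguments(comp, mask)       # one scalar per filter
    outputs = model.restore_image(comp, mask, args)  # one intermediate result per filter
harmonized = outputs[-1]                             # final harmonized image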
harmonizer/src/train/harmonizer/module/module.py
ADDED
|
@@ -0,0 +1,80 @@
import cv2
import math
from enum import Enum

import torch
from torch import nn
import torch.nn.functional as F

from .filter import Filter, FILTER_MODULES


class CascadeArgumentRegressor(nn.Module):
    def __init__(self, in_channels, base_channels, out_channels, head_num):
        super(CascadeArgumentRegressor, self).__init__()
        self.in_channels = in_channels
        self.base_channels = base_channels
        self.out_channels = out_channels
        self.head_num = head_num

        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        self.f = nn.Linear(self.in_channels, 160)
        self.g = nn.Linear(self.in_channels, self.base_channels)

        self.headers = nn.ModuleList()
        for i in range(0, self.head_num):
            self.headers.append(
                nn.ModuleList([
                    nn.Linear(160 + self.base_channels, self.base_channels),
                    nn.Linear(self.base_channels, self.out_channels),
                ])
            )

    def forward(self, x):
        x = self.pool(x)
        n, c, _, _ = x.shape
        x = x.view(n, c)

        f = self.f(x)
        g = self.g(x)

        # each head consumes the shared features f plus the previous head's
        # hidden state g, so the filter arguments are predicted in a cascade
        pred_args = []
        for i in range(0, self.head_num):
            g = self.headers[i][0](torch.cat((f, g), dim=1))
            pred_args.append(self.headers[i][1](g))

        return pred_args


class FilterPerformer(nn.Module):
    def __init__(self, filter_types):
        super(FilterPerformer, self).__init__()

        self.filters = [FILTER_MODULES[filter_type]() for filter_type in filter_types]

    def forward(self):
        pass

    def restore(self, x, mask, arguments):
        assert len(self.filters) == len(arguments)

        # apply the filters in order; each intermediate result is composited
        # into the masked (foreground) region only
        outputs = []
        _image = x
        for filter, arg in zip(self.filters, arguments):
            _image = filter(_image, arg)
            outputs.append(_image * mask + x * (1 - mask))

        return outputs

    def adjust(self, image, mask, arguments):
        assert len(self.filters) == len(arguments)

        # apply the filters in reverse order, so that 'restore' can
        # undo the adjustments filter by filter
        outputs = []
        _image = image
        for filter, arg in zip(reversed(self.filters), reversed(arguments)):
            _image = filter(_image, arg)
            outputs.append(_image * mask + image * (1 - mask))

        return outputs
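A minimal shape check for the cascade regressor (the batch size and the 8x8 feature map are made up; the 1280 channels and 6 heads match the Harmonizer defaults above):

import torch

reg = CascadeArgumentRegressor(in_channels=1280, base_channels=160,
                               out_channels=1, head_num=6)
enc32x = torch.rand(4, 1280, 8, 8)   # backbone features before pooling
pred_args = reg(enc32x)
assert len(pred_args) == 6 and pred_args[0].shape == (4, 1)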
harmonizer/src/train/harmonizer/proxy.py
ADDED
|
@@ -0,0 +1,20 @@
import torchtask

import func, data, model, criterion, trainer


def add_parser_arguments(parser):
    torchtask.proxy_template.add_parser_arguments(parser)

    data.add_parser_arguments(parser)
    model.add_parser_arguments(parser)
    criterion.add_parser_arguments(parser)
    trainer.add_parser_arguments(parser)


class HarmonizerProxy(torchtask.proxy_template.TaskProxy):

    NAME = 'harmonizer'

    def __init__(self, args):
        super(HarmonizerProxy, self).__init__(args, func, data, model, criterion, trainer)
harmonizer/src/train/harmonizer/script/train.py
ADDED
|
@@ -0,0 +1,85 @@
import os
import sys
import collections


sys.path.append('..')
import torchtask

import proxy

config = collections.OrderedDict(
    [
        ('exp_id', os.path.basename(__file__).split(".")[0]),

        ('trainer', 'harmonizer_trainer'),

        # arguments - Task Proxy
        ('short_ep', False),

        # arguments - exp
        ('resume', ''),
        ('validation', False),

        ('out_path', 'result'),

        ('visualize', False),
        ('debug', False),

        ('val_freq', 1),
        ('log_freq', 100),
        ('visual_freq', 100),
        ('checkpoint_freq', 1),

        # arguments - dataset / dataloader
        ('im_size', 256),
        ('num_workers', 4),
        ('ignore_additional', False),

        ('trainset', {
            'harmonizer_iharmony4': [
                './dataset/iHarmony4/HAdobe5k/train',
                './dataset/iHarmony4/HCOCO/train',
                './dataset/iHarmony4/Hday2night/train',
                './dataset/iHarmony4/HFlickr/train',
            ]
        }),
        ('additionalset', {
            'original_iharmony4': [
                './dataset/iHarmony4/HAdobe5k/train',
                './dataset/iHarmony4/HCOCO/train',
                './dataset/iHarmony4/Hday2night/train',
                './dataset/iHarmony4/HFlickr/train',
            ],
        }),
        ('valset', {
            'original_iharmony4': [
                './dataset/iHarmony4/HAdobe5k/test',
                './dataset/iHarmony4/HCOCO/test',
                './dataset/iHarmony4/Hday2night/test',
                './dataset/iHarmony4/HFlickr/test',
            ]
        }),

        # arguments - task specific components
        ('models', {'model': 'harmonizer'}),
        ('optimizers', {'model': 'adam'}),
        ('lrers', {'model': 'multisteplr'}),
        ('criterions', {'model': 'harmonizer_loss'}),

        # arguments - task specific optimizer / lr scheduler
        ('lr', 0.0003),

        ('milestones', [25, 50]),
        ('gamma', 0.1),

        # arguments - training details
        ('epochs', 60),
        ('batch_size', 16),
        ('additional_batch_size', 8),
    ]
)


if __name__ == '__main__':
    torchtask.run_script(config, proxy, proxy.HarmonizerProxy)
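For reference, the 'multisteplr' settings above amount to the following decay schedule (a worked example only, not extra configuration):

# lr = 0.0003 for epochs 0-24, 0.00003 for 25-49, 0.000003 for 50-59
lr, milestones, gamma = 0.0003, [25, 50], 0.1
for epoch in range(60):
    current_lr = lr * gamma ** sum(epoch >= m for m in milestones)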
harmonizer/src/train/harmonizer/trainer.py
ADDED
|
@@ -0,0 +1,322 @@
import os
import time
import numpy as np
from PIL import Image

import torch
from torch.autograd import Variable

import torchtask
from torchtask.utils import logger, cmd, tool
from torchtask.nn import func


def add_parser_arguments(parser):
    torchtask.trainer_template.add_parser_arguments(parser)


def harmonizer_trainer(args, model_dict, optimizer_dict, lrer_dict, criterion_dict, task_func):
    model_funcs = [model_dict['model']]
    optimizer_funcs = [optimizer_dict['model']]
    lrer_funcs = [lrer_dict['model']]
    criterion_funcs = [criterion_dict['model']]

    algorithm = HarmonizerTrainer(args)
    algorithm.build(model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs, task_func)
    return algorithm


class HarmonizerTrainer(torchtask.trainer_template.TaskTrainer):
    def __init__(self, args):
        super(HarmonizerTrainer, self).__init__(args)

        self.model = None
        self.optimizer = None
        self.lrer = None
        self.criterion = None

    def _build(self, model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs, task_func):
        self.task_func = task_func

        self.model = func.create_model(model_funcs[0], 'model', args=self.args)
        self.models = {'model': self.model}

        self.optimizer = optimizer_funcs[0](self.model.module.param_groups)
        self.optimizers = {'optimizer': self.optimizer}

        self.lrer = lrer_funcs[0](self.optimizer)
        self.lrers = {'lrer': self.lrer}

        self.criterion = criterion_funcs[0](self.args)
        self.criterions = {'criterion': self.criterion}

    def _train(self, data_loader, epoch):
        self.meters.reset()

        lbs = self.args.labeled_batch_size

        self.model.train()

        timer = time.time()
        for idx, (inp, gt) in enumerate(data_loader):
            # pre-process input tensor and ground truth tensor
            inp, gt = self._batch_prehandle(inp, gt, True)
            x, mask = inp

            # forward the model
            self.optimizer.zero_grad()
            resulter, debugger = self.model(inp)

            pred_outputs = tool.dict_value(resulter, 'outputs')

            # calculate loss for the fine labeled data
            l_pred_outputs = func.split_tensor_tuple(pred_outputs, 0, lbs)
            l_pred = (l_pred_outputs, )

            l_gt = func.split_tensor_tuple(gt, 0, lbs)
            l_inp = func.split_tensor_tuple(inp, 0, lbs)

            l_image_losses = self.criterion(l_pred, l_gt, l_inp)

            # dynamic loss weighting: accumulate only the positive loss
            # increments between consecutive filters
            sum_losses = l_image_losses[0].detach()
            for i in range(1, len(l_image_losses)):
                delta = l_image_losses[i].detach() - l_image_losses[i - 1].detach()
                sum_losses = sum_losses + delta * (delta > 0).float()
            sum_losses = sum_losses + 1e-9
            sum_losses = sum_losses.detach()

            scaled_l_image_losses = [torch.mean(l_image_losses[0] / sum_losses)]
            self.meters.update('fine_filter_0_loss', torch.mean(l_image_losses[0] / sum_losses).item())

            for i in range(1, len(l_image_losses)):
                loss = (l_image_losses[i] - l_image_losses[i-1].detach()) / sum_losses
                loss = loss * (loss > 0).float()
                loss = torch.mean(loss)
                scaled_l_image_losses.append(loss)
                self.meters.update('fine_filter_{0}_loss'.format(i), loss.item())

            # calculate loss for the coarse labeled data
            if not self.args.ignore_additional:
                u_pred_outputs = func.split_tensor_tuple(pred_outputs, lbs, self.args.batch_size)
                u_pred_outputs = (u_pred_outputs[-1], )
                u_pred = (u_pred_outputs, )

                u_gt = func.split_tensor_tuple(gt, lbs, self.args.batch_size)
                u_gt = (u_gt[-1], )

                u_inp = func.split_tensor_tuple(inp, lbs, self.args.batch_size)

                u_image_losses = self.criterion(u_pred, u_gt, u_inp)

                u_image_loss = torch.mean(u_image_losses[0]) * 10

                self.meters.update('coarse_filter_loss', u_image_loss.item())
            else:
                # fix: define a zero coarse loss so the summation below is valid
                u_image_loss = torch.zeros(1).cuda()
                self.meters.update('coarse_filter_loss', u_image_loss.item())

            # calculate the sum of all losses
            loss = 0
            for l_image_loss in scaled_l_image_losses:
                loss = loss + l_image_loss
            loss = loss + u_image_loss

            # backward and update
            loss.backward()
            self.optimizer.step()

            # logging
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info('step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}'.format(epoch+1, idx, len(data_loader), meters=self.meters))
                logger.log_info('\tfine-filter-0-loss: {meters[fine_filter_0_loss]:.6f}'.format(meters=self.meters))
                logger.log_info('\tfine-filter-1-loss: {meters[fine_filter_1_loss]:.6f}'.format(meters=self.meters))
                logger.log_info('\tfine-filter-2-loss: {meters[fine_filter_2_loss]:.6f}'.format(meters=self.meters))
                logger.log_info('\tfine-filter-3-loss: {meters[fine_filter_3_loss]:.6f}'.format(meters=self.meters))
                logger.log_info('\tfine-filter-4-loss: {meters[fine_filter_4_loss]:.6f}'.format(meters=self.meters))
                logger.log_info('\tfine-filter-5-loss: {meters[fine_filter_5_loss]:.6f}'.format(meters=self.meters))
                logger.log_info('\tcoarse-filter-loss: {meters[coarse_filter_loss]:.6f}'.format(meters=self.meters))

            # visualization
            if self.args.visualize and idx % self.args.visual_freq == 0:
                self._visualization(
                    epoch, idx, True,
                    func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(pred_outputs, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

            # update iteration-based lrers
            if not self.args.is_epoch_lrer:
                self.lrer.step()

            timer = time.time()

        # update epoch-based lrers
        if self.args.is_epoch_lrer:
            self.lrer.step()

    def _validate(self, data_loader, epoch):
        self.meters.reset()

        self.model.eval()

        timer = time.time()
        for idx, (inp, gt) in enumerate(data_loader):
            inp, gt = self._batch_prehandle(inp, gt, False)
            x, mask = inp

            resulter, debugger = self.model(inp)

            pred_outputs = tool.dict_value(resulter, 'outputs')

            pred = (pred_outputs[-1], )
            gt = (gt[-1], )

            # calculate loss for the fine labeled data
            losses = self.criterion.forward(pred, gt, inp)
            loss = 0
            for _loss in losses:
                loss = loss + _loss
            loss = loss / len(losses)

            self.meters.update('loss', loss.item())

            self.task_func.metrics(pred_outputs[-1].detach(), gt[-1], mask, self.meters, id_str='IH')

            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info('step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                                'loss: {meters[loss]:.6f}\n'
                                .format(epoch+1, idx, len(data_loader), meters=self.meters))

            if self.args.visualize:
                self._visualization(
                    epoch, idx, False,
                    func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
                    func.split_tensor_tuple((pred_outputs[-1], ), 0, 1, reduce_dim=True),
                    func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

            timer = time.time()

        metrics_info = {'IH': ''}
        for key in sorted(list(self.meters.keys())):
            if self.task_func.METRIC_STR in key:
                for id_str in metrics_info.keys():
                    if key.startswith(id_str):
                        metrics_info[id_str] += '{0}: {1:.6}\t'.format(key, self.meters[key])

        logger.log_info('Validation metrics:\n task-metrics\t=>\t{0}\n'.format(metrics_info['IH'].replace('_', '-')))

    def _visualization(self, epoch, idx, is_train, inp, pred, gt):
        visualize_path = self.args.visual_train_path if is_train else self.args.visual_val_path
        out_path = os.path.join(visualize_path, '{0}_{1}'.format(epoch, idx))

        x, mask = inp

        x = np.transpose(x.cpu().numpy(), (1, 2, 0))
        Image.fromarray((x * 255).astype('uint8')).save(out_path + '_1_0_x.jpg')

        mask = mask[0].data.cpu().numpy()
        Image.fromarray((mask * 255).astype('uint8'), mode='L').save(out_path + '_2_0_mask.jpg')

        for fdx, (pred_, gt_) in enumerate(zip(pred, gt)):
            pred_ = np.transpose(pred_.detach().cpu().numpy(), (1, 2, 0))
            Image.fromarray((pred_ * 255).astype('uint8')).save(out_path + '_1_{0}_pred_filter.jpg'.format(fdx+1))

            if torch.mean(gt_) != -999:
                gt_ = np.transpose(gt_.cpu().numpy(), (1, 2, 0))
                Image.fromarray((gt_ * 255).astype('uint8')).save(out_path + '_2_{0}_gt_filter.jpg'.format(fdx+1))

    def _save_checkpoint(self, epoch):
        state = {
            'epoch': epoch,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'lrer': self.lrer.state_dict(),
        }
        checkpoint = os.path.join(self.args.checkpoint_path, 'checkpoint_{0}.ckpt'.format(epoch))

        torch.save(state, checkpoint)

    def _load_checkpoint(self):
        checkpoint = torch.load(self.args.resume)
        self.model.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.lrer.load_state_dict(checkpoint['lrer'])
        return checkpoint['epoch']

    def _batch_prehandle(self, inp, gt, is_train):
        lbs = self.args.labeled_batch_size
        ubs = self.args.additional_batch_size

        # convert all input and ground truth to Variables
        inp_var = []
        for i in inp:
            inp_var.append(Variable(i).cuda())
        inp = tuple(inp_var)

        gt_var = []
        for g in gt:
            gt_var.append(Variable(g).cuda())
        gt = tuple(gt_var)

        filter_num = len(self.model.module.model.filter_types)

        if is_train:
            # ----------------------------------------------------------------
            # for fine labeled data, we generate the adjusted input
            # ----------------------------------------------------------------
            l_inp = func.split_tensor_tuple(inp, 0, lbs)
            l_gt = func.split_tensor_tuple(gt, 0, lbs)

            _, l_mask = l_inp
            l_gt_image, = l_gt

            n = l_gt_image.shape[0]
            l_rand_arguments = [self._rand_adjustment_values(n) for _ in range(0, filter_num)]

            l_x = self.model.module.adjust(l_gt_image, l_mask, l_rand_arguments)

            l_inp = (l_x[-1], l_mask)
            l_gt = []
            for _ in reversed(l_x[:-1]):
                l_gt.append(_)
            l_gt.append(l_gt_image)

            if not self.args.ignore_additional:
                # ----------------------------------------------------------------
                # for coarse labeled data, we use the existing adjusted input
                # ----------------------------------------------------------------
                u_inp = func.split_tensor_tuple(inp, lbs, self.args.batch_size)
                u_gt = func.split_tensor_tuple(gt, lbs, self.args.batch_size)

                u_gt_image, = u_gt
                none_value = torch.ones(ubs).view(ubs, 1).cuda() * -999
                none_im = u_gt_image.cuda() * 0 - 999

                u_gt = [none_im for _ in range(0, filter_num)]
                u_gt[-1] = u_gt_image

                inp = func.combine_tensor_tuple(l_inp, u_inp, 0)
                gt = func.combine_tensor_tuple(l_gt, u_gt, 0)

            else:
                inp = l_inp
                gt = l_gt

        else:
            gt_image, = gt

            none_value = torch.ones(1).view(1, 1).cuda() * -999
            none_im = gt_image.cuda() * 0 - 999

            gt = [none_im for _ in range(0, filter_num)]
            gt[-1] = gt_image

        return inp, gt

    def _rand_adjustment_values(self, n):
        x = torch.FloatTensor(np.random.uniform(-1, 1, n))
        x = x.view(n, 1).cuda()
        return x
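The per-filter loss weighting inside _train is easier to read in isolation. A sketch with made-up scalar losses (not values from training):

import torch

# normalize the positive per-filter loss increments so they sum to ~1
losses = [torch.tensor(v) for v in (0.9, 0.7, 0.8, 0.5, 0.4, 0.3)]  # fake per-filter losses
total = losses[0].clone()
for i in range(1, len(losses)):
    delta = losses[i] - losses[i - 1]
    total += delta * (delta > 0).float()
total += 1e-9

scaled = [losses[0] / total]
for i in range(1, len(losses)):
    d = (losses[i] - losses[i - 1]) / total
    scaled.append(d * (d > 0).float())
# only filters that increased the loss contribute a positive, normalized term;
# here total = 0.9 + 0.1 = 1.0 and the scaled terms sum to 1.0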
harmonizer/src/train/torchtask/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
import os

from .utils import *

from .nn import *

from .template import *

from .runner import run_script
harmonizer/src/train/torchtask/nn/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
from .lrer import VALID_LRER
from .optimizer import VALID_OPTIMIZER
from .module import SynchronizedBatchNorm2d
harmonizer/src/train/torchtask/nn/data.py
ADDED
|
@@ -0,0 +1,190 @@
import os
import itertools
import numpy as np

from torch.utils.data import Dataset
from torch.utils.data.sampler import Sampler


""" This file implements dataset wrappers and batch samplers for TorchTask.
"""


class _TorchTaskDatasetWrapper(Dataset):
    """ This is the superclass of TorchTask dataset wrappers.
    """

    def __init__(self):
        super(_TorchTaskDatasetWrapper, self).__init__()

        self.labeled_idxs = []      # indices of the labeled data
        self.additional_idxs = []   # indices of the additional data


class SplitUnlabeledWrapper(_TorchTaskDatasetWrapper):
    """ Split the fully labeled dataset into a labeled subset and an
        additional dataset based on a given sublabeled prefix list.

        For a fully labeled dataset, a common operation is to remove the labels
        of some samples and treat them as the additional samples.

        This dataset wrapper implements the dataset-split operation by using
        the given sublabeled prefix list. Samples whose prefix is in the list
        are treated as the labeled samples, while the other samples are treated
        as the additional samples.
    """

    def __init__(self, dataset, sublabeled_prefix, ignore_additional=False):
        super(SplitUnlabeledWrapper, self).__init__()

        self.dataset = dataset
        self.sublabeled_prefix = sublabeled_prefix
        self.ignore_additional = ignore_additional

        self._split_labeled()

    def __len__(self):
        return self.dataset.__len__()

    def __getitem__(self, idx):
        return self.dataset.__getitem__(idx)

    def _split_labeled(self):
        labeled_list, additional_list = [], []
        for img in self.dataset.sample_list:
            is_labeled = False
            for pdx, prefix in enumerate(self.sublabeled_prefix):
                if img.startswith(prefix):
                    labeled_list.append(img)
                    is_labeled = True
                    break

            if not is_labeled:
                additional_list.append(img)

        labeled_size, additional_size = len(labeled_list), len(additional_list)
        assert labeled_size + additional_size == len(self.dataset.sample_list)

        if self.ignore_additional:
            self.dataset.sample_list = labeled_list
            self.dataset.idxs = [_ for _ in range(0, len(self.dataset.sample_list))]
            self.labeled_idxs = self.dataset.idxs
            self.additional_idxs = []
        else:
            self.dataset.sample_list = labeled_list + additional_list
            self.dataset.idxs = [_ for _ in range(0, len(self.dataset.sample_list))]
            self.labeled_idxs = [_ for _ in range(0, labeled_size)]
            self.additional_idxs = [_ + labeled_size for _ in range(0, additional_size)]


class JointDatasetsWrapper(_TorchTaskDatasetWrapper):
    """ Combine several datasets (can be labeled or additional) into one dataset.

        This dataset wrapper combines multiple given datasets into one big dataset.
        The new dataset consists of a labeled subset and an additional subset.
    """

    def __init__(self, labeled_datasets, additional_datasets, ignore_additional=False):
        super(JointDatasetsWrapper, self).__init__()

        self.labeled_datasets = labeled_datasets
        self.additional_datasets = additional_datasets
        self.ignore_additional = ignore_additional

        self.labeled_datasets_size = [len(d) for d in self.labeled_datasets]
        self.additional_datasets_size = [len(d) for d in self.additional_datasets]

        self.labeled_size = np.sum(np.asarray(self.labeled_datasets_size))
        self.labeled_idxs = [_ for _ in range(0, self.labeled_size)]

        self.additional_size = 0
        if not self.ignore_additional:
            self.additional_size = np.sum(np.asarray(self.additional_datasets_size))
            self.additional_idxs = [self.labeled_size + _ for _ in range(0, self.additional_size)]

    def __len__(self):
        return int(self.labeled_size + self.additional_size)

    def __getitem__(self, idx):
        assert 0 <= idx < self.__len__()

        if idx >= self.labeled_size:
            idx -= self.labeled_size
            datasets = self.additional_datasets
            datasets_size = self.additional_datasets_size
        else:
            datasets = self.labeled_datasets
            datasets_size = self.labeled_datasets_size

        accumulated_idxs = 0
        for ddx, dsize in enumerate(datasets_size):
            accumulated_idxs += dsize
            if idx < accumulated_idxs:
                return datasets[ddx].__getitem__(idx - (accumulated_idxs - dsize))


class TwoStreamBatchSampler(Sampler):
    """ This two-stream batch sampler is used to read data from '_TorchTaskDatasetWrapper'.

        It iterates over two sets of indices simultaneously to read mini-batches for TorchTask.
        There are two sets of indices:
            labeled_idxs, additional_idxs
        An 'epoch' is defined by going through the longer indices once.
        In each 'epoch', the shorter indices are iterated through as many times as needed.
    """

    def __init__(self, labeled_idxs, additional_idxs, labeled_batch_size, additional_batch_size, short_ep=False):
        self.labeled_idxs = labeled_idxs
        self.additional_idxs = additional_idxs
        self.labeled_batch_size = labeled_batch_size
        self.additional_batch_size = additional_batch_size

        assert len(self.labeled_idxs) >= self.labeled_batch_size > 0
        assert len(self.additional_idxs) >= self.additional_batch_size > 0

        self.additional_batchs = len(self.additional_idxs) // self.additional_batch_size
        self.labeled_batchs = len(self.labeled_idxs) // self.labeled_batch_size

        self.short_ep = short_ep

    def __iter__(self):
        if not self.short_ep:
            if self.additional_batchs >= self.labeled_batchs:
                additional_iter = self.iterate_once(self.additional_idxs)
                labeled_iter = self.iterate_eternally(self.labeled_idxs)
            else:
                additional_iter = self.iterate_eternally(self.additional_idxs)
                labeled_iter = self.iterate_once(self.labeled_idxs)
        else:
            if self.additional_batchs >= self.labeled_batchs:
                additional_iter = self.iterate_eternally(self.additional_idxs)
                labeled_iter = self.iterate_once(self.labeled_idxs)
            else:
                additional_iter = self.iterate_once(self.additional_idxs)
                labeled_iter = self.iterate_eternally(self.labeled_idxs)

        return (labeled_batch + additional_batch
                for (labeled_batch, additional_batch) in zip(
                    self.grouper(labeled_iter, self.labeled_batch_size),
                    self.grouper(additional_iter, self.additional_batch_size)))

    def __len__(self):
        if self.short_ep:
            return min(self.additional_batchs, self.labeled_batchs)
        else:
            return max(self.additional_batchs, self.labeled_batchs)

    def iterate_once(self, iterable):
        return np.random.permutation(iterable)

    def iterate_eternally(self, indices):
        def infinite_shuffles():
            while True:
                yield np.random.permutation(indices)

        return itertools.chain.from_iterable(infinite_shuffles())

    def grouper(self, iterable, n):
        # e.g., grouper('ABCDEFG', 3) --> ABC DEF
        args = [iter(iterable)] * n
        return zip(*args)
harmonizer/src/train/torchtask/nn/func.py
ADDED
|
@@ -0,0 +1,99 @@
import numpy as np

import torch

from torchtask.utils import logger


""" This file provides tool functions for deep learning.
"""


def sigmoid_rampup(current, rampup_length):
    """ Exponential rampup from https://arxiv.org/abs/1610.02242 .
    """
    if rampup_length == 0:
        return 1.0
    else:
        current = np.clip(current, 0.0, rampup_length)
        phase = 1.0 - current / rampup_length
        return float(np.exp(-5.0 * phase * phase))


def split_tensor_tuple(ttuple, start, end, reduce_dim=False):
    """ Slice each tensor in the input tuple along the batch-dim.

    Arguments:
        ttuple (tuple): tuple of torch.Tensor
        start (int): start index of slicing
        end (int): end index of slicing
        reduce_dim (bool): whether to reduce the batch-dim when end - start == 1

    Returns:
        tuple: a sliced tensor tuple
    """

    result = []

    if reduce_dim:
        assert end - start == 1

    for t in ttuple:
        if end - start == 1 and reduce_dim:
            result.append(t[start, ...])
        else:
            result.append(t[start:end, ...])

    return tuple(result)


def combine_tensor_tuple(ttuple1, ttuple2, dim):
    result = []

    assert len(ttuple1) == len(ttuple2)

    for t1, t2 in zip(ttuple1, ttuple2):
        result.append(torch.cat((t1, t2), dim=dim))

    return tuple(result)


def create_model(mclass, mname, **kwargs):
    """ Create a nn.Module and set it up on multiple GPUs.
    """
    model = mclass(**kwargs)
    model = torch.nn.DataParallel(model)
    model = model.cuda()

    logger.log_info(' ' + '=' * 76 + '\n {0} parameters \n{1}'.format(mname, model_str(model)))
    return model


def model_str(module):
    """ Output the model structure and parameter counts as strings.
    """
    row_format = ' {name:<40} {shape:>20} = {total_size:>12,d}'
    lines = [' ' + '-' * 76,]

    params = list(module.named_parameters())
    for name, param in params:
        lines.append(row_format.format(name=name,
            shape=' * '.join(str(p) for p in param.size()), total_size=param.numel()))

    lines.append(' ' + '-' * 76)
    lines.append(row_format.format(name='all parameters', shape='sum of above',
        total_size=sum(int(param.numel()) for name, param in params)))
    lines.append(' ' + '=' * 76)
    lines.append('')

    return '\n'.join(lines)


def pytorch_support(required_version='1.0.0', info_str=''):
    if torch.__version__ < required_version:
        logger.log_err('{0} requires PyTorch >= {1}\n'
                       'However, the current PyTorch == {2}\n'
                       .format(info_str, required_version, torch.__version__))
    else:
        return True
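A round-trip sketch for the two tuple helpers above (shapes are illustrative):

import torch

# split a batch of 4 into labeled (first 3) and additional (last 1), then recombine
inp = (torch.rand(4, 3, 8, 8), torch.rand(4, 1, 8, 8))
l_inp = split_tensor_tuple(inp, 0, 3)
u_inp = split_tensor_tuple(inp, 3, 4)
full = combine_tensor_tuple(l_inp, u_inp, dim=0)
assert torch.equal(full[0], inp[0]) and torch.equal(full[1], inp[1])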
harmonizer/src/train/torchtask/nn/lrer.py
ADDED
|
@@ -0,0 +1,179 @@
import math

import torch
import torch.optim as optim

from torchtask.utils import cmd, logger
from torchtask.nn.func import pytorch_support


""" This file wraps the learning rate schedulers used in the script.
"""


EPOCH_LRERS = ['steplr', 'multisteplr', 'exponentiallr', 'cosineannealinglr']
ITER_LRERS = ['polynomiallr']
VALID_LRER = EPOCH_LRERS + ITER_LRERS


def add_parser_arguments(parser):
    """ Add the arguments related to the learning rate (LR) schedulers.

    This 'add_parser_arguments' function will be called every time.
    Please do not reuse an argument name that is already defined in this function.
    The default value '-1' means that the default value corresponding to
    the chosen LR scheduler will be used.
    """

    parser.add_argument('--last-epoch', type=int, default=-1, metavar='',
                        help='lr scheduler - the index of last epoch required by [all]')

    parser.add_argument('--step-size', type=int, default=-1, metavar='',
                        help='lr scheduler - period (epoch) of learning rate decay required by [steplr]')
    parser.add_argument('--milestones', type=cmd.str2intlist, default=[], metavar='',
                        help='lr scheduler - increasing list of epoch indices required by [multisteplr]')
    parser.add_argument('--gamma', type=float, default=-1, metavar='',
                        help='lr scheduler - multiplicative factor of learning rate decay required by [steplr, multisteplr, exponentiallr]')

    parser.add_argument('--T-max', type=int, default=-1, metavar='',
                        help='lr scheduler - maximum number of epochs required by [cosineannealinglr]')
    parser.add_argument('--eta-min', type=float, default=-1, metavar='',
                        help='lr scheduler - minimum learning rate required by [cosineannealinglr]')

    parser.add_argument('--power', type=float, default=-1, metavar='',
                        help='lr scheduler - power factor of learning rate decay required by [polynomiallr]')


# ---------------------------------------------------------------------
# Wrapper of Learning Rate Scheduler
# ---------------------------------------------------------------------

def steplr(args):
    """ Wrapper of torch.optim.lr_scheduler.StepLR (PyTorch >= 1.0.0).

    Sets the learning rate of each parameter group to the initial lr decayed by gamma every
    step_size epochs. When last_epoch=-1, sets initial lr as lr.
    """
    args.step_size = args.epochs if args.step_size == -1 else args.step_size
    args.gamma = 0.1 if args.gamma == -1 else args.gamma
    args.last_epoch = -1 if args.last_epoch == -1 else args.last_epoch

    def steplr_wrapper(optimizer):
        pytorch_support(required_version='1.0.0', info_str='LRScheduler - StepLR')
        return optim.lr_scheduler.StepLR(
            optimizer, step_size=args.step_size, gamma=args.gamma, last_epoch=args.last_epoch)

    return steplr_wrapper


def multisteplr(args):
    """ Wrapper of torch.optim.lr_scheduler.MultiStepLR (PyTorch >= 1.0.0).

    Set the learning rate of each parameter group to the initial lr decayed by gamma once the
    number of epochs reaches one of the milestones. When last_epoch=-1, sets initial lr as lr.
    """
    args.milestones = [i for i in range(1, args.epochs)] if args.milestones == [] else args.milestones
    args.gamma = 0.1 if args.gamma == -1 else args.gamma
    args.last_epoch = -1 if args.last_epoch == -1 else args.last_epoch

    def multisteplr_wrapper(optimizer):
        pytorch_support(required_version='1.0.0', info_str='LRScheduler - MultiStepLR')
        return optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=args.milestones, gamma=args.gamma, last_epoch=args.last_epoch)

    return multisteplr_wrapper


def exponentiallr(args):
    """ Wrapper of torch.optim.lr_scheduler.ExponentialLR (PyTorch >= 1.0.0).

    Set the learning rate of each parameter group to the initial lr decayed by gamma every epoch.
    When last_epoch=-1, sets initial lr as lr.
    """
    args.gamma = 0.1 if args.gamma == -1 else args.gamma
    args.last_epoch = -1 if args.last_epoch == -1 else args.last_epoch

    def exponentiallr_wrapper(optimizer):
        pytorch_support(required_version='1.0.0', info_str='LRScheduler - ExponentialLR')
        return optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=args.gamma, last_epoch=args.last_epoch)

    return exponentiallr_wrapper


def cosineannealinglr(args):
    """ Wrapper of torch.optim.lr_scheduler.CosineAnnealingLR (PyTorch >= 1.0.0).

    Set the learning rate of each parameter group using a cosine annealing schedule.
    When last_epoch=-1, sets initial lr as lr.
    """
    args.T_max = args.epochs if args.T_max == -1 else args.T_max
    args.eta_min = 0 if args.eta_min == -1 else args.eta_min
    args.last_epoch = -1 if args.last_epoch == -1 else args.last_epoch

    def cosineannealinglr_wrapper(optimizer):
        pytorch_support(required_version='1.0.0', info_str='LRScheduler - CosineAnnealingLR')
        return optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.T_max, eta_min=args.eta_min, last_epoch=args.last_epoch)

    return cosineannealinglr_wrapper


def polynomiallr(args):
    """ Wrapper of torchtask.nn.lrer.PolynomialLR (PyTorch >= 1.0.0).

    Set the learning rate of each parameter group to the initial lr decayed by power every
    iteration. When last_epoch=-1, sets initial lr as lr.
    """
    args.power = 0.9 if args.power == -1 else args.power
    args.last_epoch = -1 if args.last_epoch == -1 else args.last_epoch

    def polynomiallr_wrapper(optimizer):
        pytorch_support(required_version='1.0.0', info_str='LRScheduler - PolynomialLR')
        return PolynomialLR(optimizer, epochs=args.epochs, iters_per_epoch=args.iters_per_epoch,
                            power=args.power, last_epoch=args.last_epoch)

    return polynomiallr_wrapper


# ---------------------------------------------------------------------
# Implementation of Learning Rate Scheduler
# ---------------------------------------------------------------------

class PolynomialLR(torch.optim.lr_scheduler._LRScheduler):
    """ Polynomial decay learning rate scheduler.
    """

    def __init__(self, optimizer, epochs, iters_per_epoch, power=0.9, last_epoch=-1):
        self.epochs = epochs
        self.iters_per_epoch = iters_per_epoch
        self.max_iters = self.epochs * self.iters_per_epoch
        self.cur_iter = 0
        self.power = power
        self.is_warn = False
        super(PolynomialLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        return [base_lr * ((1 - float(self.cur_iter) / self.max_iters) ** self.power)
                for base_lr in self.base_lrs]

    def step(self, epoch=None):
        if epoch is not None and epoch != 0:
            # update lr after each epoch if epoch is given
            # after each epoch, set epoch += 1 and call this function
            if not self.is_warn:
                logger.log_warn('PolynomialLR is designed for updating the learning rate after each iteration.\n'
                                'However, it will be updated after each epoch now, please be careful.\n')
                self.is_warn = True

            self.last_epoch = epoch
            assert self.last_epoch <= self.epochs
            self.cur_iter = self.last_epoch * self.iters_per_epoch

        elif epoch is None:
            # update lr after each iteration if epoch is None
            self.cur_iter += 1
            self.last_epoch = math.floor(self.cur_iter / self.iters_per_epoch)

        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr
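A usage sketch for PolynomialLR (the model, lr, and iteration counts are made up; the exact initialization behavior of _LRScheduler can vary slightly across PyTorch versions):

import torch

net = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(net.parameters(), lr=0.01)
scheduler = PolynomialLR(opt, epochs=2, iters_per_epoch=100, power=0.9)
for _ in range(200):
    opt.step()
    scheduler.step()    # no epoch argument -> decay once per iteration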
harmonizer/src/train/torchtask/nn/module/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
from .third_party import SynchronizedBatchNorm2d, patch_replication_callback
from .gaussian_blur import GaussianBlurLayer
from .gaussian_noise import GaussianNoiseLayer
harmonizer/src/train/torchtask/nn/module/gaussian_blur.py
ADDED
@@ -0,0 +1,64 @@
import math
import numpy as np
import scipy.ndimage

import torch
import torch.nn as nn

from torchtask.utils import logger


class GaussianBlurLayer(nn.Module):
    """ Add Gaussian blur to a 4D tensor.

    This layer takes a 4D tensor of {N, C, H, W} as input.
    The Gaussian blur is performed on each of the given channels (C) separately.
    """

    def __init__(self, channels, kernel_size):
        """
        Arguments:
            channels (int): Channel number of the input tensor
            kernel_size (int): Size of the kernel used in blurring
        """

        super(GaussianBlurLayer, self).__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        assert self.kernel_size % 2 != 0

        self.op = nn.Sequential(
            nn.ReflectionPad2d(math.floor(self.kernel_size / 2)),
            nn.Conv2d(channels, channels, self.kernel_size,
                      stride=1, padding=0, bias=None, groups=channels)
        )

        self._init_kernel()

    def forward(self, x):
        """
        Arguments:
            x (torch.Tensor): input 4D tensor

        Returns:
            torch.Tensor: Blurred version of the input
        """

        if not len(list(x.shape)) == 4:
            logger.log_err('\'GaussianBlurLayer\' requires a 4D tensor as input\n')
        elif not x.shape[1] == self.channels:
            logger.log_err('In \'GaussianBlurLayer\', the required channel ({0}) is '
                           'not the same as input ({1})\n'.format(self.channels, x.shape[1]))

        return self.op(x)

    def _init_kernel(self):
        sigma = 0.3 * ((self.kernel_size - 1) * 0.5 - 1) + 0.8

        n = np.zeros((self.kernel_size, self.kernel_size))
        i = math.floor(self.kernel_size / 2)
        n[i, i] = 1
        kernel = scipy.ndimage.gaussian_filter(n, sigma)

        for name, param in self.named_parameters():
            param.data.copy_(torch.from_numpy(kernel))
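A small usage sketch (assumed, not from the commit): the kernel size must be odd, each channel is blurred independently via the grouped convolution, and the Gaussian weights are fixed at construction rather than learned.

```python
import torch

blur = GaussianBlurLayer(channels=3, kernel_size=5)  # kernel_size must be odd
img = torch.rand(1, 3, 64, 64)                       # {N, C, H, W}
with torch.no_grad():
    smoothed = blur(img)                             # same shape as the input
```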
harmonizer/src/train/torchtask/nn/module/gaussian_noise.py
ADDED
@@ -0,0 +1,40 @@
import random

import torch
import torch.nn as nn


class GaussianNoiseLayer(nn.Module):
    """ Add Gaussian noise to a 4D tensor.
    """

    def __init__(self, std):
        super(GaussianNoiseLayer, self).__init__()
        self.std = std
        self.noise = torch.zeros(0)
        self.enable = False if self.std is None else True

    def forward(self, inp):
        if not self.enable:
            return inp

        if self.noise.shape != inp.shape:
            self.noise = torch.zeros(inp.shape).cuda()
        self.noise.data.normal_(0, std=random.uniform(0, self.std))

        imax = inp.max(dim=3, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
        imin = inp.min(dim=3, keepdim=True)[0].min(dim=2, keepdim=True)[0].min(dim=1, keepdim=True)[0]

        # normalize to [0, 1]
        inp.sub_(imin).div_(imax - imin + 1e-9)
        # add noise
        inp.add_(self.noise)
        # clip to [0, 1]
        upper_bound = (inp > 1.0).float()
        lower_bound = (inp < 0.0).float()
        inp.mul_(1 - upper_bound).add_(upper_bound)
        inp.mul_(1 - lower_bound)
        # de-normalize
        inp.mul_(imax - imin + 1e-9).add_(imin)

        return inp
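A usage sketch (assumed; note the layer allocates its noise buffer with `.cuda()`, so it requires a GPU, and it perturbs its input in place):

```python
import torch

noise = GaussianNoiseLayer(std=0.1)   # std=None disables the layer entirely
x = torch.rand(2, 3, 32, 32).cuda()
y = noise(x)   # x is normalized, perturbed, clipped, and de-normalized in place
```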
harmonizer/src/train/torchtask/nn/module/third_party/__init__.py
ADDED
@@ -0,0 +1 @@
from .sync_batchnorm import SynchronizedBatchNorm2d, patch_replication_callback
harmonizer/src/train/torchtask/nn/module/third_party/sync_batchnorm/__init__.py
ADDED
@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
# File   : __init__.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
from .replicate import DataParallelWithCallback, patch_replication_callback
harmonizer/src/train/torchtask/nn/module/third_party/sync_batchnorm/batchnorm.py
ADDED
@@ -0,0 +1,282 @@
# -*- coding: utf-8 -*-
# File   : batchnorm.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import collections

import torch
import torch.nn.functional as F

from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

from .comm import SyncMaster

__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']


def _sum_ft(tensor):
    """sum over the first and last dimension"""
    return tensor.sum(dim=0).sum(dim=-1)


def _unsqueeze_ft(tensor):
    """add new dimensions at the front and the tail"""
    return tensor.unsqueeze(0).unsqueeze(-1)


_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])


class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using the same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data

        return mean, bias_var.clamp(self.eps) ** -0.5


class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{\sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d as the mean and
    standard-deviation are reduced across all devices during training.
    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).
    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.
    During evaluation, this running mean/variance is used for normalization.
    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``
    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)


class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs.

    .. math::

        y = \frac{x - mean[x]}{\sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d as the mean and
    standard-deviation are reduced across all devices during training.
    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).
    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.
    During evaluation, this running mean/variance is used for normalization.
    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``
    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)
    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)


class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs.

    .. math::

        y = \frac{x - mean[x]}{\sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d as the mean and
    standard-deviation are reduced across all devices during training.
    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).
    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.
    During evaluation, this running mean/variance is used for normalization.
    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``
    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)
    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
harmonizer/src/train/torchtask/nn/module/third_party/sync_batchnorm/comm.py
ADDED
@@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
# File   : comm.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import queue
import collections
import threading

__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']


class FutureResult(object):
    """A thread-safe future implementation. Used only as a one-to-one pipe."""

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        with self._lock:
            assert self._result is None, 'Previous result hasn\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        with self._lock:
            if self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as data parallel will trigger a callback on each module, all slave devices should
      call `register(id)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all messages from the slave devices will be
      collected and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine the message to be
      passed back to each slave device.
    """

    def __init__(self, master_callback):
        """
        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.
        Args:
            identifier: an identifier, usually the device id.
        Returns: a `SlavePipe` object which can be used to communicate with the master device.
        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages are first collected from each device (including the master device), and then
        a callback is invoked to compute the message to be sent back to each device
        (including the master device).
        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the first
            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
        Returns: the message to be sent back to the master device.
        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belong to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        return len(self._registry)
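To make the master/slave handshake concrete, here is a minimal single-slave sketch (the callback below is hypothetical, not from the commit): the master callback receives `(identifier, message)` pairs with the master's entry first, and must return one reply per device in the same keyed form.

```python
import threading

def master_callback(intermediates):
    # intermediates: [(0, master_msg), (1, slave_msg), ...]
    total = sum(msg for _, msg in intermediates)
    return [(identifier, total) for identifier, _ in intermediates]

master = SyncMaster(master_callback)
pipe = master.register_slave(identifier=1)

# The slave blocks in run_slave until the master has run the callback.
worker = threading.Thread(target=lambda: print('slave got', pipe.run_slave(2)))
worker.start()
print('master got', master.run_master(1))  # both sides receive 3
worker.join()
```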
harmonizer/src/train/torchtask/nn/module/third_party/sync_batchnorm/replicate.py
ADDED
@@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-
# File   : replicate.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import functools

from torch.nn.parallel.data_parallel import DataParallel

__all__ = [
    'CallbackContext',
    'execute_replication_callbacks',
    'DataParallelWithCallback',
    'patch_replication_callback'
]


class CallbackContext(object):
    pass


def execute_replication_callbacks(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.
    Note that, as all modules are isomorphic, we assign each sub-module a context
    (shared among multiple copies of this module on different devices).
    Through this context, different copies can share some information.
    We guarantee that the callback on the master copy (the first copy) will be called ahead of the callback
    of any slave copies.
    """
    master_copy = modules[0]
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)


class DataParallelWithCallback(DataParallel):
    """
    Data Parallel with a replication callback.
    A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
    the original `replicate` function.
    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.
    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have a customized `DataParallel` implementation.
    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    data_parallel.replicate = new_replicate
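Putting the pieces together, a sketch of the intended wiring (assumed, not from the commit; requires at least two visible GPUs): wrapping a model that contains `SynchronizedBatchNorm2d` with `DataParallelWithCallback` triggers the replication callback that links each GPU copy back to the master's `SyncMaster`.

```python
import torch
import torch.nn as nn

net = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    SynchronizedBatchNorm2d(16),
    nn.ReLU(),
).cuda()
net = DataParallelWithCallback(net, device_ids=[0, 1])
out = net(torch.rand(8, 3, 32, 32).cuda())  # BN statistics synchronized across GPUs
```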
harmonizer/src/train/torchtask/nn/module/third_party/sync_batchnorm/unittest.py
ADDED
@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# File   : unittest.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import unittest

import numpy as np
from torch.autograd import Variable


def as_numpy(v):
    if isinstance(v, Variable):
        v = v.data
    return v.cpu().numpy()


class TorchTestCase(unittest.TestCase):
    def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
        npa, npb = as_numpy(a), as_numpy(b)
        self.assertTrue(
            np.allclose(npa, npb, atol=atol),
            'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
        )
harmonizer/src/train/torchtask/nn/optimizer.py
ADDED
@@ -0,0 +1,247 @@
import math

import torch
import torch.optim as optim
from torch.optim.optimizer import Optimizer

from torchtask.utils import cmd
from torchtask.nn.func import pytorch_support


""" This file wraps the optimizers used in the script.
"""


VALID_OPTIMIZER = ['sgd', 'rmsprop', 'adam', 'wdadam']


def add_parser_arguments(parser):
    """ Add the arguments related to the optimizer.

    This 'add_parser_arguments' function will be called every time.
    Please do not reuse argument names that are already defined in this function.
    The default value '-1' means that the default value corresponding to
    the chosen optimizer will be used.
    """

    parser.add_argument('--lr', type=float, default=-1, metavar='',
                        help='optimizer - learning rate (required by [all])')

    parser.add_argument('--dampening', type=float, default=-1, metavar='',
                        help='optimizer - dampening for momentum (required by [sgd])')
    parser.add_argument('--nesterov', type=cmd.str2bool, default=False, metavar='',
                        help='optimizer - enables Nesterov momentum if True (required by [sgd])')
    parser.add_argument('--weight-decay', type=float, default=-1, metavar='',
                        help='optimizer - weight decay (L2 penalty) (required by [sgd, rmsprop, adam, wdadam])')
    parser.add_argument('--momentum', type=float, default=-1, metavar='',
                        help='optimizer - momentum factor (required by [sgd, rmsprop])')
    parser.add_argument('--alpha', type=float, default=-1, metavar='',
                        help='optimizer - smoothing constant (required by [rmsprop])')
    parser.add_argument('--centered', type=cmd.str2bool, default=False, metavar='',
                        help='optimizer - if True, compute the centered RMSprop, where the gradient is normalized by an estimation of its variance (required by [rmsprop])')
    parser.add_argument('--eps', type=float, default=-1, metavar='',
                        help='optimizer - term added to the denominator to improve numerical stability (required by [rmsprop, adam, wdadam])')
    parser.add_argument('--beta1', type=float, default=-1, metavar='',
                        help='optimizer - coefficient used for computing running averages of the gradient and its square (required by [adam, wdadam])')
    parser.add_argument('--beta2', type=float, default=-1, metavar='',
                        help='optimizer - coefficient used for computing running averages of the gradient and its square (required by [adam, wdadam])')
    parser.add_argument('--amsgrad', type=cmd.str2bool, default=False, metavar='',
                        help='optimizer - use the AMSGrad variant if True (required by [wdadam])')


# ---------------------------------------------------------------------
# Wrapper of Optimizer
# ---------------------------------------------------------------------

def sgd(args):
    """ Wrapper of torch.optim.SGD (PyTorch >= 1.0.0).

    Implements stochastic gradient descent (optionally with momentum).
    """
    args.lr = 0.01 if args.lr == -1 else args.lr
    args.weight_decay = 0 if args.weight_decay == -1 else args.weight_decay
    args.momentum = 0 if args.momentum == -1 else args.momentum
    args.dampening = 0 if args.dampening == -1 else args.dampening
    args.nesterov = False if args.nesterov == False else args.nesterov

    def sgd_wrapper(param_groups):
        pytorch_support(required_version='1.0.0', info_str='Optimizer - SGD')
        return optim.SGD(
            param_groups,
            lr=args.lr, momentum=args.momentum, dampening=args.dampening,
            weight_decay=args.weight_decay, nesterov=args.nesterov)

    return sgd_wrapper


def rmsprop(args):
    """ Wrapper of torch.optim.RMSprop (PyTorch >= 1.0.0).

    Implements the RMSprop algorithm.
    Proposed by G. Hinton in his course.
    The centered version first appears in Generating Sequences With Recurrent Neural Networks.
    """

    args.lr = 0.01 if args.lr == -1 else args.lr
    args.alpha = 0.99 if args.alpha == -1 else args.alpha
    args.eps = 1e-08 if args.eps == -1 else args.eps
    args.weight_decay = 0 if args.weight_decay == -1 else args.weight_decay
    args.momentum = 0 if args.momentum == -1 else args.momentum
    args.centered = False if args.centered == False else args.centered

    def rmsprop_wrapper(param_groups):
        pytorch_support(required_version='1.0.0', info_str='Optimizer - RMSprop')
        return optim.RMSprop(
            param_groups,
            lr=args.lr, alpha=args.alpha, eps=args.eps, weight_decay=args.weight_decay,
            momentum=args.momentum, centered=args.centered)

    return rmsprop_wrapper


def adam(args):
    """ Wrapper of torch.optim.Adam (PyTorch >= 1.0.0).

    Implements the Adam algorithm.
    It has been proposed in 'Adam: A Method for Stochastic Optimization'.
    """
    args.lr = 0.001 if args.lr == -1 else args.lr
    args.beta1 = 0.9 if args.beta1 == -1 else args.beta1
    args.beta2 = 0.999 if args.beta2 == -1 else args.beta2
    args.eps = 1e-08 if args.eps == -1 else args.eps
    args.weight_decay = 0.0 if args.weight_decay == -1 else args.weight_decay

    def adam_wrapper(param_groups):
        pytorch_support(required_version='1.0.0', info_str='Optimizer - Adam')
        return optim.Adam(
            param_groups,
            lr=args.lr, betas=(args.beta1, args.beta2), eps=args.eps,
            weight_decay=args.weight_decay)

    return adam_wrapper


def wdadam(args):
    """ Wrapper of torchtask.nn.optimizer.WDAdam (PyTorch >= 1.0.0).

    Implements the Adam algorithm with weight decay and AMSGrad.
    """
    args.lr = 0.001 if args.lr == -1 else args.lr
    args.beta1 = 0.9 if args.beta1 == -1 else args.beta1
    args.beta2 = 0.999 if args.beta2 == -1 else args.beta2
    args.eps = 1e-08 if args.eps == -1 else args.eps
    args.weight_decay = 0.0 if args.weight_decay == -1 else args.weight_decay
    args.amsgrad = False if args.amsgrad == False else args.amsgrad

    def wdadam_wrapper(param_groups):
        pytorch_support(required_version='1.0.0', info_str='Optimizer - WDAdam')
        return WDAdam(
            param_groups,
            lr=args.lr, betas=(args.beta1, args.beta2), eps=args.eps,
            weight_decay=args.weight_decay, amsgrad=args.amsgrad)

    return wdadam_wrapper


# ---------------------------------------------------------------------
# Implementation of Optimizer
# ---------------------------------------------------------------------

class WDAdam(Optimizer):
    """ Implements the Adam algorithm with weight decay and AMSGrad.

    It has been proposed in `Adam: A Method for Stochastic Optimization`.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay using the method from
            the paper `Fixing Weight Decay Regularization in Adam` (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {0}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {0}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {0}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {0}".format(betas[1]))

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay / lr, amsgrad=amsgrad)
        super(WDAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(WDAdam, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    def step(self, closure=None):
        """ Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients')
                amsgrad = group['amsgrad']

                # State initialization
                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    # Maintains max of all exp. moving avg. of sq. grad. values
                    if amsgrad:
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']
                state['step'] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                if group['weight_decay'] != 0:
                    p.data.add_(-group['weight_decay'] * group['lr'], p.data)

                p.data.addcdiv_(-step_size, exp_avg, denom)

        return loss
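A usage sketch for the decoupled-weight-decay optimizer above (the model is hypothetical; the in-place `add_`/`addcmul_` signatures target PyTorch 1.x). Note that `weight_decay` is stored divided by the initial `lr`, so the per-step decay is `wd * (lr_current / lr_init) * p` and tracks any external learning-rate schedule.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(10, 1)
opt = WDAdam(model.parameters(), lr=1e-3, weight_decay=1e-4, amsgrad=True)

x, y = torch.randn(32, 10), torch.randn(32, 1)
opt.zero_grad()
F.mse_loss(model(x), y).backward()
opt.step()  # decoupled decay is applied to p directly, not folded into the gradient
```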
harmonizer/src/train/torchtask/requirements.txt
ADDED
@@ -0,0 +1,5 @@
numpy
scipy
Pillow
pyyaml
opencv-python
harmonizer/src/train/torchtask/runner.py
ADDED
@@ -0,0 +1,33 @@
import sys
import argparse

from torchtask.utils import cmd
from torchtask.nn import optimizer, lrer
from torchtask.nn.func import pytorch_support


def create_parser():
    parser = argparse.ArgumentParser(description='TorchTask Script Parser')

    optimizer.add_parser_arguments(parser)
    lrer.add_parser_arguments(parser)

    return parser


def run_script(config, proxy_file, proxy_class):
    # TorchTask requires PyTorch >= 1.0.0
    pytorch_support(required_version='1.0.0', info_str='TorchTask')

    # help information
    if len(sys.argv) > 1 and sys.argv[1] in ['help', '--help', 'h', '-h']:
        config['h'] = True

    # create parser and parse args from config
    parser = create_parser()
    proxy_file.add_parser_arguments(parser)
    args = cmd.parse_args(parser, config)

    task_proxy = proxy_class(args)
    task_proxy.run()
harmonizer/src/train/torchtask/template/__init__.py
ADDED
@@ -0,0 +1,16 @@
from . import func as func_template
from . import data as data_template
from . import model as model_template
from . import criterion as criterion_template
from . import proxy as proxy_template
from . import trainer as trainer_template


__all__ = [
    'func_template',
    'data_template',
    'model_template',
    'criterion_template',
    'proxy_template',
    'trainer_template',
]