gabgrenier committed
Commit 060b41f · 1 Parent(s): ead9f2a

added harmonizer

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete commit.
Files changed (50)
  1. csai.py +34 -1
  2. harmonizer/.gitignore +101 -0
  3. harmonizer/README.md +101 -0
  4. harmonizer/pretrained/README.md +3 -0
  5. harmonizer/src/__init__.py +0 -0
  6. harmonizer/src/model/__init__.py +2 -0
  7. harmonizer/src/model/backbone/__init__.py +1 -0
  8. harmonizer/src/model/backbone/efficientnet/__init__.py +116 -0
  9. harmonizer/src/model/backbone/efficientnet/model.py +395 -0
  10. harmonizer/src/model/backbone/efficientnet/utils.py +586 -0
  11. harmonizer/src/model/enhancer.py +40 -0
  12. harmonizer/src/model/filter.py +231 -0
  13. harmonizer/src/model/harmonizer.py +44 -0
  14. harmonizer/src/model/module.py +80 -0
  15. harmonizer/src/requirements.txt +6 -0
  16. harmonizer/src/train/README.md +14 -0
  17. harmonizer/src/train/harmonizer/__init__.py +0 -0
  18. harmonizer/src/train/harmonizer/criterion.py +47 -0
  19. harmonizer/src/train/harmonizer/data.py +198 -0
  20. harmonizer/src/train/harmonizer/func.py +41 -0
  21. harmonizer/src/train/harmonizer/model.py +41 -0
  22. harmonizer/src/train/harmonizer/module/__init__.py +1 -0
  23. harmonizer/src/train/harmonizer/module/backbone/__init__.py +1 -0
  24. harmonizer/src/train/harmonizer/module/backbone/efficientnet/__init__.py +116 -0
  25. harmonizer/src/train/harmonizer/module/backbone/efficientnet/model.py +395 -0
  26. harmonizer/src/train/harmonizer/module/backbone/efficientnet/utils.py +586 -0
  27. harmonizer/src/train/harmonizer/module/filter.py +231 -0
  28. harmonizer/src/train/harmonizer/module/harmonizer.py +83 -0
  29. harmonizer/src/train/harmonizer/module/module.py +80 -0
  30. harmonizer/src/train/harmonizer/proxy.py +20 -0
  31. harmonizer/src/train/harmonizer/script/train.py +85 -0
  32. harmonizer/src/train/harmonizer/trainer.py +322 -0
  33. harmonizer/src/train/torchtask/__init__.py +9 -0
  34. harmonizer/src/train/torchtask/nn/__init__.py +3 -0
  35. harmonizer/src/train/torchtask/nn/data.py +190 -0
  36. harmonizer/src/train/torchtask/nn/func.py +99 -0
  37. harmonizer/src/train/torchtask/nn/lrer.py +179 -0
  38. harmonizer/src/train/torchtask/nn/module/__init__.py +3 -0
  39. harmonizer/src/train/torchtask/nn/module/gaussian_blur.py +64 -0
  40. harmonizer/src/train/torchtask/nn/module/gaussian_noise.py +40 -0
  41. harmonizer/src/train/torchtask/nn/module/third_party/__init__.py +1 -0
  42. harmonizer/src/train/torchtask/nn/module/third_party/sync_batchnorm/__init__.py +12 -0
  43. harmonizer/src/train/torchtask/nn/module/third_party/sync_batchnorm/batchnorm.py +282 -0
  44. harmonizer/src/train/torchtask/nn/module/third_party/sync_batchnorm/comm.py +129 -0
  45. harmonizer/src/train/torchtask/nn/module/third_party/sync_batchnorm/replicate.py +88 -0
  46. harmonizer/src/train/torchtask/nn/module/third_party/sync_batchnorm/unittest.py +29 -0
  47. harmonizer/src/train/torchtask/nn/optimizer.py +247 -0
  48. harmonizer/src/train/torchtask/requirements.txt +5 -0
  49. harmonizer/src/train/torchtask/runner.py +33 -0
  50. harmonizer/src/train/torchtask/template/__init__.py +16 -0
csai.py CHANGED
@@ -67,9 +67,42 @@ def process(fg, bg):
 
     # Use the final_mask_img when pasting
     bg.paste(fg, (0, 0), final_mask_img)
+
+    # now run the harmonizer to make sure the foreground matches the background
+    harmonized = harmonizer(bg, final_mask_img)
+
+    if harmonized is not None:
+        return harmonized
+    else:
+        return bg
 
-    return bg
+
+def harmonizer(comp, mask):
+    try:
+        import torchvision.transforms.functional as tf
+        from harmonizer.src import model
+        harmonizer = model.Harmonizer()
+        harmonizer = harmonizer.cuda()
+        harmonizer.load_state_dict(torch.load("harmonizer/pretrained/harmonizer.pth"), strict=True)
+        harmonizer.eval()
+
+        comp = tf.to_tensor(comp)[None, ...]
+        mask = tf.to_tensor(mask)[None, ...]
+        comp = comp.cuda()
+        mask = mask.cuda()
+
+        with torch.no_grad():
+            arguments = harmonizer.predict_arguments(comp, mask)
+            harmonized = harmonizer.restore_image(comp, mask, arguments)[-1]
+
+        harmonized = np.transpose(harmonized[0].cpu().numpy(), (1, 2, 0)) * 255
+        harmonized = Image.fromarray(harmonized.astype(np.uint8))
+
+        return harmonized
+
+    except Exception:
+        return None
+
 
 def rvm(fg):
     model = MattingNetwork('mobilenetv3').eval().cuda()  # or "resnet50"
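
A design note on the new `harmonizer()` above: it rebuilds the network and reloads the checkpoint on every call. Below is a minimal sketch of a cached variant; it assumes the same `model.Harmonizer`, `predict_arguments`, and `restore_image` API used in the diff plus an available CUDA device, and the `harmonize` / `_get_harmonizer` names are hypothetical.

```python
import numpy as np
import torch
import torchvision.transforms.functional as tf
from PIL import Image

from harmonizer.src import model

_HARMONIZER = None  # lazily-initialized singleton so the weights load only once


def _get_harmonizer():
    global _HARMONIZER
    if _HARMONIZER is None:
        net = model.Harmonizer().cuda()
        net.load_state_dict(torch.load("harmonizer/pretrained/harmonizer.pth"), strict=True)
        net.eval()
        _HARMONIZER = net
    return _HARMONIZER


def harmonize(comp, mask):
    """Harmonize a PIL composite image against its foreground mask."""
    net = _get_harmonizer()
    comp_t = tf.to_tensor(comp)[None, ...].cuda()  # 1x3xHxW, values in [0, 1]
    mask_t = tf.to_tensor(mask)[None, ...].cuda()  # 1x1xHxW foreground mask
    with torch.no_grad():
        args = net.predict_arguments(comp_t, mask_t)
        out = net.restore_image(comp_t, mask_t, args)[-1]
    out = np.transpose(out[0].cpu().numpy(), (1, 2, 0)) * 255
    return Image.fromarray(out.astype(np.uint8))
```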
harmonizer/.gitignore ADDED
@@ -0,0 +1,101 @@
+ # Temporary directories and files
+ *.ckpt
+ *.pth
+ *.zip
+ *.tar
+ result/
+ dataset/
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ env/
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *,cover
+ .hypothesis/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # IPython Notebook
+ .ipynb_checkpoints
+
+ # pyenv
+ .python-version
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # dotenv
+ .env
+
+ # virtualenv
+ venv/
+ ENV/
+
+ # Spyder project settings
+ .spyderproject
+
+ # Rope project settings
+ .ropeproject
+
+
+ # Project files
+ .vscode
harmonizer/README.md ADDED
@@ -0,0 +1,101 @@
+ <h2 align="center">Harmonizer: High-Resolution Image/Video Harmonization</h2>
+
+ <p align="center"><i>Harmonizer: Learning to Perform White-Box
+ Image and Video Harmonization (ECCV 2022)</i></p>
+
+ <p align="center">
+ <a href="https://arxiv.org/abs/2207.01322">Paper</a> |
+ <a href="#demo">Demo</a> |
+ <a href="#code">Code</a> |
+ <a href="#license">License</a> |
+ <a href="#citation">Citation</a> |
+ <a href="#contact">Contact</a>
+ </p>
+
+ <p align="center">
+ <a href="https://youtu.be/kKKK3D1f_Mc">Harmonizer Result Video</a> |
+ <a href="https://youtu.be/NS8f-eJY9cc">Enhancer Result Video</a>
+ </p>
+
+ <div align="center"><b>Harmonizer</b> is a <b>lightweight (20MB)</b> model that enables image/video harmonization at up to <b>8K</b> resolution.</div>
+ <div align="center">On GPUs, Harmonizer achieves <b>real-time</b> performance at <b>Full HD</b> resolution.</div>
+ <img src="doc/gif/harmonizer.gif" width="100%">
+
+ <div align="center"><b>Enhancer</b> is a model that applies the Harmonizer architecture to image/video color enhancement.</div>
+ <img src="doc/gif/enhancer.gif" width="100%">
+
+ ---
+
+ ## Demo
+
+ In our demos, the <b>Harmonizer</b> model is trained on the *iHarmony4* dataset, while the <b>Enhancer</b> model is trained on the *FiveK + HDRPlus* datasets.
+
+ ### Online Demo
+ Try our online demos for fun without code!
+
+ | Image Harmonization | Image Enhancement |
+ | :---: | :---: |
+ | [Online Demo](https://zhke.io/?harmonizer_demo) | [Online Demo](https://zhke.io/?enhancer_demo) |
+
+ <img src="doc/gif/online_demo.gif" width="100%">
+
+ ### Offline Demo
+ We provide offline demos for image/video harmonization/enhancement.
+
+ | Image Harmonization | Video Harmonization | Image Enhancement | Video Enhancement |
+ | :---: | :---: | :---: | :---: |
+ | [Offline Demo](demo/image_harmonization) | [Offline Demo](demo/video_harmonization) | [Offline Demo](demo/image_enhancement) | [Offline Demo](demo/video_enhancement) |
+
+
+ ## Code
+
+ ### Training
+
+ The training code is released in the folder `./src/train`.
+ Refer to [README.md](src/train/README.md) for more details about training.
+
+
+ ### Validation
+
+ To reproduce the iHarmony4 results reported in our [paper](https://arxiv.org/abs/2207.01322) with the provided PyTorch validation code, please:
+
+ 1. Download the Harmonizer model pre-trained on the iHarmony4 dataset from [this link](https://drive.google.com/file/d/15XGPQHBppaYGnhsP9l7iOGZudXNw1WbA/view?usp=sharing) and put it in the folder `./pretrained`.
+
+ 2. Download the four subsets of iHarmony4 from [this repository](https://github.com/bcmi/Image-Harmonization-Dataset-iHarmony4) and put them in the folder `./dataset/harmonization/iHarmony4`.
+
+ 3. Install the Python requirements. In the root path of this repository, run:
+ ```
+ pip install -r src/requirements.txt
+ ```
+
+ 4. For validation, in the root path of this repository, run:
+ ```
+ python -m src.val_harmonizer \
+     --pretrained ./pretrained/harmonizer \
+     --datasets HCOCO HFlickr HAdobe5k Hday2night \
+     --metric-size 256
+ ```
+ - You can change `--datasets` to validate a specific subset.
+ - You can remove `--metric-size` to calculate the metrics without resizing the outputs.
+ - The metric values may differ slightly from those in our [paper](https://arxiv.org/abs/2207.01322) due to dependency versions.
+
+ ## License
+ This project is released under the [Creative Commons Attribution NonCommercial ShareAlike 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) license.
+
+
+ ## Citation
+ If this work helps your research, please consider citing:
+
+ ```bibtex
+ @InProceedings{Harmonizer,
+     author = {Zhanghan Ke and Chunyi Sun and Lei Zhu and Ke Xu and Rynson W.H. Lau},
+     title = {Harmonizer: Learning to Perform White-Box Image and Video Harmonization},
+     booktitle = {European Conference on Computer Vision (ECCV)},
+     year = {2022},
+ }
+ ```
+
+
+ ## Contact
+ This repository is maintained by Zhanghan Ke ([@ZHKKKe](https://github.com/ZHKKKe)).
+ For questions, please contact `kezhanghan@outlook.com`.
harmonizer/pretrained/README.md ADDED
@@ -0,0 +1,3 @@
+ ## Harmonizer - Pre-Trained Models
+ This folder is used to save the official pre-trained models of Harmonizer/Enhancer.
+ You can download them from [this link](https://drive.google.com/drive/folders/1k7TCcwETeF5SYoD2Ic211UQyV1lwIBHY?usp=sharing).
harmonizer/src/__init__.py ADDED
File without changes
harmonizer/src/model/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .harmonizer import Harmonizer
+ from .enhancer import Enhancer
harmonizer/src/model/backbone/__init__.py ADDED
@@ -0,0 +1 @@
+ from .efficientnet import EfficientBackbone, EfficientBackboneCommon
harmonizer/src/model/backbone/efficientnet/__init__.py ADDED
@@ -0,0 +1,116 @@
+ """
+ This EfficientNet implementation comes from:
+     Author: lukemelas (github username)
+     Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
+ """
+
+ import torch
+ import torch.nn as nn
+
+ from .model import EfficientNet
+ from .utils import round_filters, get_same_padding_conv2d
+
+
+ # EfficientNet backbone with separate foreground/background input stems
+ class EfficientBackbone(EfficientNet):
+     def __init__(self, blocks_args=None, global_params=None):
+         super(EfficientBackbone, self).__init__(blocks_args, global_params)
+
+         self.enc_channels = [16, 24, 40, 112, 1280]
+
+         # ------------------------------------------------------------
+         # delete the unused stem layers
+         # ------------------------------------------------------------
+         del self._conv_stem
+         del self._bn0
+         # ------------------------------------------------------------
+
+         # ------------------------------------------------------------
+         # parameters for the input layers
+         # ------------------------------------------------------------
+         bn_mom = 1 - self._global_params.batch_norm_momentum
+         bn_eps = self._global_params.batch_norm_epsilon
+
+         in_channels = 4
+         out_channels = round_filters(32, self._global_params)
+         out_channels = int(out_channels / 2)
+         # ------------------------------------------------------------
+
+         # ------------------------------------------------------------
+         # define the input layers
+         # ------------------------------------------------------------
+         image_size = global_params.image_size
+         Conv2d = get_same_padding_conv2d(image_size=image_size)
+         self._conv_fg = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+         self._bn_fg = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+
+         self._conv_bg = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+         self._bn_bg = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+         # ------------------------------------------------------------
+
+     def forward(self, xfg, xbg):
+         xfg = self._swish(self._bn_fg(self._conv_fg(xfg)))
+         xbg = self._swish(self._bn_bg(self._conv_bg(xbg)))
+
+         x = torch.cat((xfg, xbg), dim=1)
+
+         block_outputs = []
+         for idx, block in enumerate(self._blocks):
+             drop_connect_rate = self._global_params.drop_connect_rate
+             drop_connect_rate *= float(idx) / len(self._blocks)
+             x = block(x, drop_connect_rate=drop_connect_rate)
+             block_outputs.append(x)
+
+         # Head
+         x = self._swish(self._bn1(self._conv_head(x)))
+
+         return block_outputs[0], block_outputs[2], block_outputs[4], block_outputs[10], x
+
+
+ # EfficientNet backbone with a single shared input stem
+ class EfficientBackboneCommon(EfficientNet):
+     def __init__(self, blocks_args=None, global_params=None):
+         super(EfficientBackboneCommon, self).__init__(blocks_args, global_params)
+
+         self.enc_channels = [16, 24, 40, 112, 1280]
+
+         # ------------------------------------------------------------
+         # delete the unused stem layers
+         # ------------------------------------------------------------
+         del self._conv_stem
+         del self._bn0
+         # ------------------------------------------------------------
+
+         # ------------------------------------------------------------
+         # parameters for the input layers
+         # ------------------------------------------------------------
+         bn_mom = 1 - self._global_params.batch_norm_momentum
+         bn_eps = self._global_params.batch_norm_epsilon
+
+         in_channels = 3
+         out_channels = round_filters(32, self._global_params)
+         # ------------------------------------------------------------
+
+         # ------------------------------------------------------------
+         # define the input layers
+         # ------------------------------------------------------------
+         image_size = global_params.image_size
+         Conv2d = get_same_padding_conv2d(image_size=image_size)
+         self._conv = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+         self._bn = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+         # ------------------------------------------------------------
+
+     def forward(self, x):
+         x = self._swish(self._bn(self._conv(x)))
+
+         block_outputs = []
+         for idx, block in enumerate(self._blocks):
+             drop_connect_rate = self._global_params.drop_connect_rate
+             drop_connect_rate *= float(idx) / len(self._blocks)
+             x = block(x, drop_connect_rate=drop_connect_rate)
+             block_outputs.append(x)
+
+         # Head
+         x = self._swish(self._bn1(self._conv_head(x)))
+
+         return block_outputs[0], block_outputs[2], block_outputs[4], block_outputs[10], x
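
For orientation, a sketch of how this two-stream backbone might be exercised (hedged: it assumes `get_model_params` from the sibling `utils.py`, 4-channel inputs, i.e. RGB plus a mask channel, matching `in_channels = 4` above, and import paths following this repository's layout):

```python
import torch

from harmonizer.src.model.backbone.efficientnet import EfficientBackbone
from harmonizer.src.model.backbone.efficientnet.utils import get_model_params

# efficientnet-b0 configuration; the subclass replaces the single stem with
# two half-width stems (_conv_fg / _conv_bg) whose outputs are concatenated.
blocks_args, global_params = get_model_params('efficientnet-b0', None)
backbone = EfficientBackbone(blocks_args, global_params).eval()

xfg = torch.rand(1, 4, 224, 224)  # foreground RGB + mask
xbg = torch.rand(1, 4, 224, 224)  # background RGB + mask
with torch.no_grad():
    feats = backbone(xfg, xbg)

# Five feature maps at decreasing resolution; channel counts follow
# self.enc_channels = [16, 24, 40, 112, 1280]
print([tuple(f.shape) for f in feats])
```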
harmonizer/src/model/backbone/efficientnet/model.py ADDED
@@ -0,0 +1,395 @@
+ """model.py - Model and module class for EfficientNet.
+    They are built to mirror those in the official TensorFlow implementation.
+ """
+
+ # Author: lukemelas (github username)
+ # Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
+ # With adjustments and added comments by workingcoder (github username).
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from .utils import (
+     round_filters,
+     round_repeats,
+     drop_connect,
+     get_same_padding_conv2d,
+     get_model_params,
+     efficientnet_params,
+     load_pretrained_weights,
+     Swish,
+     MemoryEfficientSwish,
+     calculate_output_image_size
+ )
+
+
+ VALID_MODELS = (
+     'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3',
+     'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b6', 'efficientnet-b7',
+     'efficientnet-b8',
+
+     # Support the construction of 'efficientnet-l2' without pretrained weights
+     'efficientnet-l2'
+ )
+
+
+ class MBConvBlock(nn.Module):
+     """Mobile Inverted Residual Bottleneck Block.
+     Args:
+         block_args (namedtuple): BlockArgs, defined in utils.py.
+         global_params (namedtuple): GlobalParam, defined in utils.py.
+         image_size (tuple or list): [image_height, image_width].
+     References:
+         [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
+         [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
+         [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
+     """
+
+     def __init__(self, block_args, global_params, image_size=None):
+         super().__init__()
+         self._block_args = block_args
+         self._bn_mom = 1 - global_params.batch_norm_momentum  # pytorch's difference from tensorflow
+         self._bn_eps = global_params.batch_norm_epsilon
+         self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
+         self.id_skip = block_args.id_skip  # whether to use skip connection and drop connect
+
+         # Expansion phase (Inverted Bottleneck)
+         inp = self._block_args.input_filters  # number of input channels
+         oup = self._block_args.input_filters * self._block_args.expand_ratio  # number of output channels
+         if self._block_args.expand_ratio != 1:
+             Conv2d = get_same_padding_conv2d(image_size=image_size)
+             self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
+             self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
+             # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
+
+         # Depthwise convolution phase
+         k = self._block_args.kernel_size
+         s = self._block_args.stride
+         Conv2d = get_same_padding_conv2d(image_size=image_size)
+         self._depthwise_conv = Conv2d(
+             in_channels=oup, out_channels=oup, groups=oup,  # groups makes it depthwise
+             kernel_size=k, stride=s, bias=False)
+         self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
+         image_size = calculate_output_image_size(image_size, s)
+
+         # Squeeze and Excitation layer, if desired
+         if self.has_se:
+             Conv2d = get_same_padding_conv2d(image_size=(1, 1))
+             num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
+             self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
+             self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
+
+         # Pointwise convolution phase
+         final_oup = self._block_args.output_filters
+         Conv2d = get_same_padding_conv2d(image_size=image_size)
+         self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
+         self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
+         self._swish = MemoryEfficientSwish()
+
+     def forward(self, inputs, drop_connect_rate=None):
+         """MBConvBlock's forward function.
+         Args:
+             inputs (tensor): Input tensor.
+             drop_connect_rate (float): Drop connect rate (between 0 and 1).
+         Returns:
+             Output of this block after processing.
+         """
+
+         # Expansion and Depthwise Convolution
+         x = inputs
+         if self._block_args.expand_ratio != 1:
+             x = self._expand_conv(inputs)
+             x = self._bn0(x)
+             x = self._swish(x)
+
+         x = self._depthwise_conv(x)
+         x = self._bn1(x)
+         x = self._swish(x)
+
+         # Squeeze and Excitation
+         if self.has_se:
+             x_squeezed = F.adaptive_avg_pool2d(x, 1)
+             x_squeezed = self._se_reduce(x_squeezed)
+             x_squeezed = self._swish(x_squeezed)
+             x_squeezed = self._se_expand(x_squeezed)
+             x = torch.sigmoid(x_squeezed) * x
+
+         # Pointwise Convolution
+         x = self._project_conv(x)
+         x = self._bn2(x)
+
+         # Skip connection and drop connect
+         input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
+         if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
+             # The combination of skip connection and drop connect brings about stochastic depth.
+             if drop_connect_rate:
+                 x = drop_connect(x, p=drop_connect_rate, training=self.training)
+             x = x + inputs  # skip connection
+         return x
+
+     def set_swish(self, memory_efficient=True):
+         """Sets swish function as memory efficient (for training) or standard (for export).
+         Args:
+             memory_efficient (bool): Whether to use memory-efficient version of swish.
+         """
+         self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
+
+
+ class EfficientNet(nn.Module):
+     """EfficientNet model.
+        Most easily loaded with the .from_name or .from_pretrained methods.
+     Args:
+         blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
+         global_params (namedtuple): A set of GlobalParams shared between blocks.
+     References:
+         [1] https://arxiv.org/abs/1905.11946 (EfficientNet)
+     Example:
+         >>> import torch
+         >>> from efficientnet.model import EfficientNet
+         >>> inputs = torch.rand(1, 3, 224, 224)
+         >>> model = EfficientNet.from_pretrained('efficientnet-b0')
+         >>> model.eval()
+         >>> outputs = model(inputs)
+     """
+
+     def __init__(self, blocks_args=None, global_params=None):
+         super().__init__()
+         assert isinstance(blocks_args, list), 'blocks_args should be a list'
+         assert len(blocks_args) > 0, 'block args must be greater than 0'
+         self._global_params = global_params
+         self._blocks_args = blocks_args
+
+         # Batch norm parameters
+         bn_mom = 1 - self._global_params.batch_norm_momentum
+         bn_eps = self._global_params.batch_norm_epsilon
+
+         # Get stem static or dynamic convolution depending on image size
+         image_size = global_params.image_size
+         Conv2d = get_same_padding_conv2d(image_size=image_size)
+
+         # Stem
+         in_channels = 3  # rgb
+         out_channels = round_filters(32, self._global_params)  # number of output channels
+         self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+         self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+         image_size = calculate_output_image_size(image_size, 2)
+
+         # Build blocks
+         self._blocks = nn.ModuleList([])
+         for block_args in self._blocks_args:
+
+             # Update block input and output filters based on depth multiplier.
+             block_args = block_args._replace(
+                 input_filters=round_filters(block_args.input_filters, self._global_params),
+                 output_filters=round_filters(block_args.output_filters, self._global_params),
+                 num_repeat=round_repeats(block_args.num_repeat, self._global_params)
+             )
+
+             # The first block needs to take care of stride and filter size increase.
+             self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
+             image_size = calculate_output_image_size(image_size, block_args.stride)
+             if block_args.num_repeat > 1:  # modify block_args to keep same output size
+                 block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
+             for _ in range(block_args.num_repeat - 1):
+                 self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
+                 # image_size = calculate_output_image_size(image_size, block_args.stride)  # stride = 1
+
+         # Head
+         in_channels = block_args.output_filters  # output of final block
+         out_channels = round_filters(1280, self._global_params)
+         Conv2d = get_same_padding_conv2d(image_size=image_size)
+         self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
+         self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+
+         # Final linear layer
+         self._avg_pooling = nn.AdaptiveAvgPool2d(1)
+         if self._global_params.include_top:
+             self._dropout = nn.Dropout(self._global_params.dropout_rate)
+             self._fc = nn.Linear(out_channels, self._global_params.num_classes)
+
+         # set activation to memory efficient swish by default
+         self._swish = MemoryEfficientSwish()
+
+     def set_swish(self, memory_efficient=True):
+         """Sets swish function as memory efficient (for training) or standard (for export).
+         Args:
+             memory_efficient (bool): Whether to use memory-efficient version of swish.
+         """
+         self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
+         for block in self._blocks:
+             block.set_swish(memory_efficient)
+
+     def extract_endpoints(self, inputs):
+         """Use convolution layer to extract features
+         from reduction levels i in [1, 2, 3, 4, 5].
+         Args:
+             inputs (tensor): Input tensor.
+         Returns:
+             Dictionary of last intermediate features
+             with reduction levels i in [1, 2, 3, 4, 5].
+         Example:
+             >>> import torch
+             >>> from efficientnet.model import EfficientNet
+             >>> inputs = torch.rand(1, 3, 224, 224)
+             >>> model = EfficientNet.from_pretrained('efficientnet-b0')
+             >>> endpoints = model.extract_endpoints(inputs)
+             >>> print(endpoints['reduction_1'].shape)  # torch.Size([1, 16, 112, 112])
+             >>> print(endpoints['reduction_2'].shape)  # torch.Size([1, 24, 56, 56])
+             >>> print(endpoints['reduction_3'].shape)  # torch.Size([1, 40, 28, 28])
+             >>> print(endpoints['reduction_4'].shape)  # torch.Size([1, 112, 14, 14])
+             >>> print(endpoints['reduction_5'].shape)  # torch.Size([1, 320, 7, 7])
+             >>> print(endpoints['reduction_6'].shape)  # torch.Size([1, 1280, 7, 7])
+         """
+         endpoints = dict()
+
+         # Stem
+         x = self._swish(self._bn0(self._conv_stem(inputs)))
+         prev_x = x
+
+         # Blocks
+         for idx, block in enumerate(self._blocks):
+             drop_connect_rate = self._global_params.drop_connect_rate
+             if drop_connect_rate:
+                 drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
+             x = block(x, drop_connect_rate=drop_connect_rate)
+             if prev_x.size(2) > x.size(2):
+                 endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x
+             elif idx == len(self._blocks) - 1:
+                 endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
+             prev_x = x
+
+         # Head
+         x = self._swish(self._bn1(self._conv_head(x)))
+         endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
+
+         return endpoints
+
+     def extract_features(self, inputs):
+         """Use convolution layers to extract features.
+         Args:
+             inputs (tensor): Input tensor.
+         Returns:
+             Output of the final convolution
+             layer in the efficientnet model.
+         """
+         # Stem
+         x = self._swish(self._bn0(self._conv_stem(inputs)))
+
+         # Blocks
+         for idx, block in enumerate(self._blocks):
+             drop_connect_rate = self._global_params.drop_connect_rate
+             if drop_connect_rate:
+                 drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
+             x = block(x, drop_connect_rate=drop_connect_rate)
+
+         # Head
+         x = self._swish(self._bn1(self._conv_head(x)))
+
+         return x
+
+     def forward(self, inputs):
+         """EfficientNet's forward function.
+            Calls extract_features to extract features, applies final linear layer, and returns logits.
+         Args:
+             inputs (tensor): Input tensor.
+         Returns:
+             Output of this model after processing.
+         """
+         # Convolution layers
+         x = self.extract_features(inputs)
+         # Pooling and final linear layer
+         x = self._avg_pooling(x)
+         if self._global_params.include_top:
+             x = x.flatten(start_dim=1)
+             x = self._dropout(x)
+             x = self._fc(x)
+         return x
+
+     @classmethod
+     def from_name(cls, model_name, in_channels=3, **override_params):
+         """Create an efficientnet model according to name.
+         Args:
+             model_name (str): Name for efficientnet.
+             in_channels (int): Input data's channel number.
+             override_params (other key word params):
+                 Params to override model's global_params.
+                 Optional key:
+                     'width_coefficient', 'depth_coefficient',
+                     'image_size', 'dropout_rate',
+                     'num_classes', 'batch_norm_momentum',
+                     'batch_norm_epsilon', 'drop_connect_rate',
+                     'depth_divisor', 'min_depth'
+         Returns:
+             An efficientnet model.
+         """
+         cls._check_model_name_is_valid(model_name)
+         blocks_args, global_params = get_model_params(model_name, override_params)
+         model = cls(blocks_args, global_params)
+         model._change_in_channels(in_channels)
+         return model
+
+     @classmethod
+     def from_pretrained(cls, model_name, weights_path=None, advprop=False,
+                         in_channels=3, num_classes=1000, **override_params):
+         """Create an efficientnet model according to name.
+         Args:
+             model_name (str): Name for efficientnet.
+             weights_path (None or str):
+                 str: path to pretrained weights file on the local disk.
+                 None: use pretrained weights downloaded from the Internet.
+             advprop (bool):
+                 Whether to load pretrained weights
+                 trained with advprop (valid when weights_path is None).
+             in_channels (int): Input data's channel number.
+             num_classes (int):
+                 Number of categories for classification.
+                 It controls the output size for final linear layer.
+             override_params (other key word params):
+                 Params to override model's global_params.
+                 Optional key:
+                     'width_coefficient', 'depth_coefficient',
+                     'image_size', 'dropout_rate',
+                     'batch_norm_momentum',
+                     'batch_norm_epsilon', 'drop_connect_rate',
+                     'depth_divisor', 'min_depth'
+         Returns:
+             A pretrained efficientnet model.
+         """
+         model = cls.from_name(model_name, num_classes=num_classes, **override_params)
+         load_pretrained_weights(model, model_name, weights_path=weights_path,
+                                 load_fc=(num_classes == 1000), advprop=advprop)
+         model._change_in_channels(in_channels)
+         return model
+
+     @classmethod
+     def get_image_size(cls, model_name):
+         """Get the input image size for a given efficientnet model.
+         Args:
+             model_name (str): Name for efficientnet.
+         Returns:
+             Input image size (resolution).
+         """
+         cls._check_model_name_is_valid(model_name)
+         _, _, res, _ = efficientnet_params(model_name)
+         return res
+
+     @classmethod
+     def _check_model_name_is_valid(cls, model_name):
+         """Validates model name.
+         Args:
+             model_name (str): Name for efficientnet.
+         Returns:
+             bool: Is a valid name or not.
+         """
+         if model_name not in VALID_MODELS:
+             raise ValueError('model_name should be one of: ' + ', '.join(VALID_MODELS))
+
+     def _change_in_channels(self, in_channels):
+         """Adjust model's first convolution layer to in_channels, if in_channels not equals 3.
+         Args:
+             in_channels (int): Input data's channel number.
+         """
+         if in_channels != 3:
+             Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
+             out_channels = round_filters(32, self._global_params)
+             self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
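
A quick check of the `from_name` construction path (a sketch; note that in this copy `include_top` defaults to False via `utils.efficientnet`, so the forward pass returns pooled features rather than class logits, and the import path assumes this repository's layout):

```python
import torch

from harmonizer.src.model.backbone.efficientnet.model import EfficientNet

# Build b0 from scratch (no weight download); _change_in_channels swaps the
# stem so the first convolution accepts a 4-channel input.
net = EfficientNet.from_name('efficientnet-b0', in_channels=4).eval()
with torch.no_grad():
    feats = net(torch.rand(1, 4, 224, 224))
print(feats.shape)  # torch.Size([1, 1280, 1, 1]) with include_top=False
```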
harmonizer/src/model/backbone/efficientnet/utils.py ADDED
@@ -0,0 +1,586 @@
+ """utils.py - Helper functions for building the model and for loading model parameters.
+    These helper functions are built to mirror those in the official TensorFlow implementation.
+ """
+
+ # Author: lukemelas (github username)
+ # Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
+ # With adjustments and added comments by workingcoder (github username).
+
+ import re
+ import math
+ import collections
+ from functools import partial
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from torch.utils import model_zoo
+
+
+ ################################################################################
+ # Helper functions for model architecture
+ ################################################################################
+
+ # GlobalParams and BlockArgs: Two namedtuples
+ # Swish and MemoryEfficientSwish: Two implementations of the method
+ # round_filters and round_repeats:
+ #     Functions to calculate params for scaling model width and depth ! ! !
+ # get_width_and_height_from_size and calculate_output_image_size
+ # drop_connect: A structural design
+ # get_same_padding_conv2d:
+ #     Conv2dDynamicSamePadding
+ #     Conv2dStaticSamePadding
+ # get_same_padding_maxPool2d:
+ #     MaxPool2dDynamicSamePadding
+ #     MaxPool2dStaticSamePadding
+ #     It's an additional function, not used in EfficientNet,
+ #     but can be used in other models (such as EfficientDet).
+
+ # Parameters for the entire model (stem, all blocks, and head)
+ GlobalParams = collections.namedtuple('GlobalParams', [
+     'width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate',
+     'num_classes', 'batch_norm_momentum', 'batch_norm_epsilon',
+     'drop_connect_rate', 'depth_divisor', 'min_depth', 'include_top'])
+
+ # Parameters for an individual model block
+ BlockArgs = collections.namedtuple('BlockArgs', [
+     'num_repeat', 'kernel_size', 'stride', 'expand_ratio',
+     'input_filters', 'output_filters', 'se_ratio', 'id_skip'])
+
+ # Set GlobalParams and BlockArgs's defaults
+ GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
+ BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
+
+ # Swish activation function
+ if hasattr(nn, 'SiLU'):
+     Swish = nn.SiLU
+ else:
+     # For compatibility with old PyTorch versions
+     class Swish(nn.Module):
+         def forward(self, x):
+             return x * torch.sigmoid(x)
+
+
+ # A memory-efficient implementation of the Swish function
+ class SwishImplementation(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, i):
+         result = i * torch.sigmoid(i)
+         ctx.save_for_backward(i)
+         return result
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         i = ctx.saved_tensors[0]
+         sigmoid_i = torch.sigmoid(i)
+         return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
+
+
+ class MemoryEfficientSwish(nn.Module):
+     def forward(self, x):
+         return SwishImplementation.apply(x)
+
+
+ def round_filters(filters, global_params):
+     """Calculate and round number of filters based on width multiplier.
+        Use width_coefficient, depth_divisor and min_depth of global_params.
+     Args:
+         filters (int): Filters number to be calculated.
+         global_params (namedtuple): Global params of the model.
+     Returns:
+         new_filters: New filters number after calculating.
+     """
+     multiplier = global_params.width_coefficient
+     if not multiplier:
+         return filters
+     # TODO: modify the params names.
+     #       maybe the names (width_divisor, min_width)
+     #       are more suitable than (depth_divisor, min_depth).
+     divisor = global_params.depth_divisor
+     min_depth = global_params.min_depth
+     filters *= multiplier
+     min_depth = min_depth or divisor  # pay attention to this line when using min_depth
+     # follow the formula transferred from official TensorFlow implementation
+     new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
+     if new_filters < 0.9 * filters:  # prevent rounding by more than 10%
+         new_filters += divisor
+     return int(new_filters)
+
+
+ def round_repeats(repeats, global_params):
+     """Calculate module's repeat number of a block based on depth multiplier.
+        Use depth_coefficient of global_params.
+     Args:
+         repeats (int): num_repeat to be calculated.
+         global_params (namedtuple): Global params of the model.
+     Returns:
+         new repeat: New repeat number after calculating.
+     """
+     multiplier = global_params.depth_coefficient
+     if not multiplier:
+         return repeats
+     # follow the formula transferred from official TensorFlow implementation
+     return int(math.ceil(multiplier * repeats))
+
+
+ def drop_connect(inputs, p, training):
+     """Drop connect.
+     Args:
+         inputs (tensor: BCHW): Input of this structure.
+         p (float: 0.0~1.0): Probability of drop connection.
+         training (bool): The running mode.
+     Returns:
+         output: Output after drop connection.
+     """
+     assert 0 <= p <= 1, 'p must be in range of [0,1]'
+
+     if not training:
+         return inputs
+
+     batch_size = inputs.shape[0]
+     keep_prob = 1 - p
+
+     # generate binary_tensor mask according to probability (p for 0, 1-p for 1)
+     random_tensor = keep_prob
+     random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
+     binary_tensor = torch.floor(random_tensor)
+
+     output = inputs / keep_prob * binary_tensor
+     return output
+
+
+ def get_width_and_height_from_size(x):
+     """Obtain height and width from x.
+     Args:
+         x (int, tuple or list): Data size.
+     Returns:
+         size: A tuple or list (H,W).
+     """
+     if isinstance(x, int):
+         return x, x
+     if isinstance(x, list) or isinstance(x, tuple):
+         return x
+     else:
+         raise TypeError()
+
+
+ def calculate_output_image_size(input_image_size, stride):
+     """Calculates the output image size when using Conv2dSamePadding with a stride.
+        Necessary for static padding. Thanks to mannatsingh for pointing this out.
+     Args:
+         input_image_size (int, tuple or list): Size of input image.
+         stride (int, tuple or list): Conv2d operation's stride.
+     Returns:
+         output_image_size: A list [H,W].
+     """
+     if input_image_size is None:
+         return None
+     image_height, image_width = get_width_and_height_from_size(input_image_size)
+     stride = stride if isinstance(stride, int) else stride[0]
+     image_height = int(math.ceil(image_height / stride))
+     image_width = int(math.ceil(image_width / stride))
+     return [image_height, image_width]
+
+
+ # Note:
+ # The following 'SamePadding' functions make output size equal ceil(input size/stride).
+ # Only when stride equals 1 can the output size be the same as the input size.
+ # Don't be confused by their function names ! ! !
+
+ def get_same_padding_conv2d(image_size=None):
+     """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
+        Static padding is necessary for ONNX exporting of models.
+     Args:
+         image_size (int or tuple): Size of the image.
+     Returns:
+         Conv2dDynamicSamePadding or Conv2dStaticSamePadding.
+     """
+     if image_size is None:
+         return Conv2dDynamicSamePadding
+     else:
+         return partial(Conv2dStaticSamePadding, image_size=image_size)
+
+
+ class Conv2dDynamicSamePadding(nn.Conv2d):
+     """2D Convolutions like TensorFlow, for a dynamic image size.
+        The padding is computed dynamically in the forward function.
+     """
+
+     # Tips for 'SAME' mode padding.
+     #     Given the following:
+     #         i: width or height
+     #         s: stride
+     #         k: kernel size
+     #         d: dilation
+     #         p: padding
+     #     Output after Conv2d:
+     #         o = floor((i+p-((k-1)*d+1))/s+1)
+     # If o equals i, i = floor((i+p-((k-1)*d+1))/s+1),
+     # => p = (i-1)*s+((k-1)*d+1)-i
+
+     def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
+         super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
+         self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
+
+     def forward(self, x):
+         ih, iw = x.size()[-2:]
+         kh, kw = self.weight.size()[-2:]
+         sh, sw = self.stride
+         oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)  # change the output size according to stride ! ! !
+         pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+         pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+         if pad_h > 0 or pad_w > 0:
+             x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
+         return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+
+ class Conv2dStaticSamePadding(nn.Conv2d):
+     """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size.
+        The padding module is calculated in the construction function, then used in forward.
+     """
+
+     # With the same calculation as Conv2dDynamicSamePadding
+
+     def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs):
+         super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
+         self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
+
+         # Calculate padding based on image size and save it
+         assert image_size is not None
+         ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
+         kh, kw = self.weight.size()[-2:]
+         sh, sw = self.stride
+         oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
+         pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+         pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+         if pad_h > 0 or pad_w > 0:
+             self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2,
+                                                 pad_h // 2, pad_h - pad_h // 2))
+         else:
+             self.static_padding = nn.Identity()
+
+     def forward(self, x):
+         x = self.static_padding(x)
+         x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+         return x
+
+
+ def get_same_padding_maxPool2d(image_size=None):
+     """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
+        Static padding is necessary for ONNX exporting of models.
+     Args:
+         image_size (int or tuple): Size of the image.
+     Returns:
+         MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding.
+     """
+     if image_size is None:
+         return MaxPool2dDynamicSamePadding
+     else:
+         return partial(MaxPool2dStaticSamePadding, image_size=image_size)
+
+
+ class MaxPool2dDynamicSamePadding(nn.MaxPool2d):
+     """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size.
+        The padding is computed dynamically in the forward function.
+     """
+
+     def __init__(self, kernel_size, stride, padding=0, dilation=1, return_indices=False, ceil_mode=False):
+         super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode)
+         self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
+         self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size
+         self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
+
+     def forward(self, x):
+         ih, iw = x.size()[-2:]
+         kh, kw = self.kernel_size
+         sh, sw = self.stride
+         oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
+         pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+         pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+         if pad_h > 0 or pad_w > 0:
+             x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
+         return F.max_pool2d(x, self.kernel_size, self.stride, self.padding,
+                             self.dilation, self.ceil_mode, self.return_indices)
+
+
+ class MaxPool2dStaticSamePadding(nn.MaxPool2d):
+     """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size.
+        The padding module is calculated in the construction function, then used in forward.
+     """
+
+     def __init__(self, kernel_size, stride, image_size=None, **kwargs):
+         super().__init__(kernel_size, stride, **kwargs)
+         self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
+         self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size
+         self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
+
+         # Calculate padding based on image size and save it
+         assert image_size is not None
+         ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
+         kh, kw = self.kernel_size
+         sh, sw = self.stride
+         oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
+         pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
+         pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
+         if pad_h > 0 or pad_w > 0:
+             self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
+         else:
+             self.static_padding = nn.Identity()
+
+     def forward(self, x):
+         x = self.static_padding(x)
+         x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding,
+                          self.dilation, self.ceil_mode, self.return_indices)
+         return x
+
+
+ ################################################################################
+ # Helper functions for loading model params
+ ################################################################################
+
+ # BlockDecoder: A Class for encoding and decoding BlockArgs
+ # efficientnet_params: A function to query compound coefficient
+ # get_model_params and efficientnet:
+ #     Functions to get BlockArgs and GlobalParams for efficientnet
+ # url_map and url_map_advprop: Dicts of url_map for pretrained weights
+ # load_pretrained_weights: A function to load pretrained weights
+
+ class BlockDecoder(object):
+     """Block Decoder for readability,
+        straight from the official TensorFlow repository.
+     """
+
+     @staticmethod
+     def _decode_block_string(block_string):
+         """Get a block through a string notation of arguments.
+         Args:
+             block_string (str): A string notation of arguments.
+                 Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.
+         Returns:
+             BlockArgs: The namedtuple defined at the top of this file.
+         """
+         assert isinstance(block_string, str)
+
+         ops = block_string.split('_')
+         options = {}
+         for op in ops:
+             splits = re.split(r'(\d.*)', op)
+             if len(splits) >= 2:
+                 key, value = splits[:2]
+                 options[key] = value
+
+         # Check stride
+         assert (('s' in options and len(options['s']) == 1) or
+                 (len(options['s']) == 2 and options['s'][0] == options['s'][1]))
+
+         return BlockArgs(
+             num_repeat=int(options['r']),
+             kernel_size=int(options['k']),
+             stride=[int(options['s'][0])],
+             expand_ratio=int(options['e']),
+             input_filters=int(options['i']),
+             output_filters=int(options['o']),
+             se_ratio=float(options['se']) if 'se' in options else None,
+             id_skip=('noskip' not in block_string))
+
+     @staticmethod
+     def _encode_block_string(block):
+         """Encode a block to a string.
+         Args:
+             block (namedtuple): A BlockArgs type argument.
+         Returns:
+             block_string: A String form of BlockArgs.
+         """
+         args = [
+             'r%d' % block.num_repeat,
+             'k%d' % block.kernel_size,
+             's%d%d' % (block.stride[0], block.stride[0]),  # stride is stored as a one-element list
+             'e%s' % block.expand_ratio,
+             'i%d' % block.input_filters,
+             'o%d' % block.output_filters
+         ]
+         if block.se_ratio is not None and 0 < block.se_ratio <= 1:
+             args.append('se%s' % block.se_ratio)
+         if block.id_skip is False:
+             args.append('noskip')
+         return '_'.join(args)
+
+     @staticmethod
+     def decode(string_list):
+         """Decode a list of string notations to specify blocks inside the network.
+         Args:
+             string_list (list[str]): A list of strings, each string is a notation of block.
+         Returns:
+             blocks_args: A list of BlockArgs namedtuples of block args.
+         """
+         assert isinstance(string_list, list)
+         blocks_args = []
+         for block_string in string_list:
+             blocks_args.append(BlockDecoder._decode_block_string(block_string))
+         return blocks_args
+
+     @staticmethod
+     def encode(blocks_args):
+         """Encode a list of BlockArgs to a list of strings.
+         Args:
+             blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.
+         Returns:
+             block_strings: A list of strings, each string is a notation of block.
+         """
+         block_strings = []
+         for block in blocks_args:
+             block_strings.append(BlockDecoder._encode_block_string(block))
+         return block_strings
+
+
+ def efficientnet_params(model_name):
+     """Map EfficientNet model name to parameter coefficients.
+     Args:
+         model_name (str): Model name to be queried.
+     Returns:
+         params_dict[model_name]: A (width,depth,res,dropout) tuple.
+     """
+     params_dict = {
+         # Coefficients: width,depth,res,dropout
+         'efficientnet-b0': (1.0, 1.0, 224, 0.2),
+         'efficientnet-b1': (1.0, 1.1, 240, 0.2),
+         'efficientnet-b2': (1.1, 1.2, 260, 0.3),
+         'efficientnet-b3': (1.2, 1.4, 300, 0.3),
+         'efficientnet-b4': (1.4, 1.8, 380, 0.4),
+         'efficientnet-b5': (1.6, 2.2, 456, 0.4),
+         'efficientnet-b6': (1.8, 2.6, 528, 0.5),
+         'efficientnet-b7': (2.0, 3.1, 600, 0.5),
+         'efficientnet-b8': (2.2, 3.6, 672, 0.5),
+         'efficientnet-l2': (4.3, 5.3, 800, 0.5),
+     }
+     return params_dict[model_name]
+
+
+ def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None,
+                  dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000, include_top=False):
+     """Create BlockArgs and GlobalParams for efficientnet model.
+     Args:
+         width_coefficient (float)
+         depth_coefficient (float)
+         image_size (int)
+         dropout_rate (float)
+         drop_connect_rate (float)
+         num_classes (int)
+         Meaning as the name suggests.
+     Returns:
+         blocks_args, global_params.
+     """
+
+     # Blocks args for the whole model (efficientnet-b0 by default)
+     # It will be modified in the construction of EfficientNet Class according to model
+     blocks_args = [
+         'r1_k3_s11_e1_i32_o16_se0.25',
+         'r2_k3_s22_e6_i16_o24_se0.25',
+         'r2_k5_s22_e6_i24_o40_se0.25',
+         'r3_k3_s22_e6_i40_o80_se0.25',
+         'r3_k5_s11_e6_i80_o112_se0.25',
+         'r4_k5_s22_e6_i112_o192_se0.25',
+         'r1_k3_s11_e6_i192_o320_se0.25',
+     ]
+     blocks_args = BlockDecoder.decode(blocks_args)
+
+     global_params = GlobalParams(
+         width_coefficient=width_coefficient,
+         depth_coefficient=depth_coefficient,
+         image_size=image_size,
+         dropout_rate=dropout_rate,
+
+         num_classes=num_classes,
+         batch_norm_momentum=0.99,
+         batch_norm_epsilon=1e-3,
+         drop_connect_rate=drop_connect_rate,
+         depth_divisor=8,
+         min_depth=None,
+         include_top=include_top,
+     )
+
+     return blocks_args, global_params
+
+
+ def get_model_params(model_name, override_params):
+     """Get the block args and global params for a given model name.
+     Args:
+         model_name (str): Model's name.
+         override_params (dict): A dict to modify global_params.
+     Returns:
+         blocks_args, global_params
+     """
+     if model_name.startswith('efficientnet'):
+         w, d, s, p = efficientnet_params(model_name)
+         # note: all models have drop connect rate = 0.2
+         blocks_args, global_params = efficientnet(
+             width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s)
+     else:
+         raise NotImplementedError('model name is not pre-defined: {}'.format(model_name))
+     if override_params:
+         # ValueError will be raised here if override_params has fields not included in global_params.
+         global_params = global_params._replace(**override_params)
+     return blocks_args, global_params
+
+
+ # Weights trained with standard methods
+ # (see the paper "EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" for details)
+ url_map = {
+     'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth',
+     'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth',
+     'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth',
+     'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth',
+     'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth',
+     'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth',
+     'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth',
+     'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth',
+ }
+
+ # Weights trained with Adversarial Examples (AdvProp)
+ # (see the paper "Adversarial Examples Improve Image Recognition" for details)
+ url_map_advprop = {
+     'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth',
+     'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth',
+     'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth',
+     'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth',
+     'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth',
+     'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth',
+     'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth',
+     'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth',
+     'efficientnet-b8': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth',
+ }
+
+ # TODO: add the pretrained weights url map of 'efficientnet-l2'
+
+
+ def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False, verbose=True):
+     """Loads pretrained weights from weights path or download using url.
+     Args:
+         model (Module): The whole model of efficientnet.
+         model_name (str): Model name of efficientnet.
+         weights_path (None or str):
+             str: path to pretrained weights file on the local disk.
+             None: use pretrained weights downloaded from the Internet.
+         load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model.
+         advprop (bool): Whether to load pretrained weights
+                         trained with advprop (valid when weights_path is None).
+     """
+     if isinstance(weights_path, str):
+         state_dict = torch.load(weights_path)
+     else:
+         # AutoAugment or Advprop (different preprocessing)
+         url_map_ = url_map_advprop if advprop else url_map
+         state_dict = model_zoo.load_url(url_map_[model_name])
+
+     if load_fc:
+         ret = model.load_state_dict(state_dict, strict=False)
+         # assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
+     else:
+         state_dict.pop('_fc.weight')
+         state_dict.pop('_fc.bias')
+         ret = model.load_state_dict(state_dict, strict=False)
+         # assert set(ret.missing_keys) == set(
+         #     ['_fc.weight', '_fc.bias']), 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
583
+ # assert not ret.unexpected_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.unexpected_keys)
584
+
585
+ if verbose:
586
+ print('Loaded pretrained weights for {}'.format(model_name))
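A minimal usage sketch of the helpers above (illustrative only, not part of the commit; assumes this `utils` module is importable):

```
# Hypothetical usage: look up a variant's scaling coefficients and build its params.
blocks_args, global_params = get_model_params('efficientnet-b0', {'num_classes': 10})
print(global_params.image_size)  # 224 for efficientnet-b0
print(len(blocks_args))          # 7 block definitions
```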
harmonizer/src/model/enhancer.py ADDED
@@ -0,0 +1,40 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torchvision.transforms.functional as tf
+
+from .filter import Filter
+from .backbone import EfficientBackboneCommon
+from .module import CascadeArgumentRegressor, FilterPerformer
+
+
+class Enhancer(nn.Module):
+    def __init__(self):
+        super(Enhancer, self).__init__()
+
+        self.input_size = (256, 256)
+        self.filter_types = [
+            Filter.BRIGHTNESS,
+            Filter.CONTRAST,
+            Filter.SATURATION,
+            Filter.HIGHLIGHT,
+            Filter.SHADOW,
+        ]
+
+        self.backbone = EfficientBackboneCommon.from_name('efficientnet-b0')
+        self.regressor = CascadeArgumentRegressor(1280, 160, 1, len(self.filter_types))
+        self.performer = FilterPerformer(self.filter_types)
+
+    def predict_arguments(self, x, mask):
+        x = F.interpolate(x, self.input_size, mode='bilinear', align_corners=False)
+        enc2x, enc4x, enc8x, enc16x, enc32x = self.backbone(x)
+        arguments = self.regressor(enc32x)
+
+        return arguments
+
+    def restore_image(self, x, mask, arguments):
+        assert len(arguments) == len(self.filter_types)
+
+        arguments = [torch.clamp(arg, -1, 1).view(-1, 1, 1, 1) for arg in arguments]
+        return self.performer.restore(x, mask, arguments)
+
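A hedged inference sketch for `Enhancer` (illustrative, not part of the commit): it predicts one argument per filter from a downscaled view of the image, then applies the filter cascade at the original resolution.

```
# Hypothetical usage of Enhancer (illustration only).
import torch
model = Enhancer().eval()
x = torch.rand(1, 3, 512, 512)      # RGB image in [0, 1]
mask = torch.ones(1, 1, 512, 512)   # enhance the whole image
with torch.no_grad():
    args = model.predict_arguments(x, mask)
    outputs = model.restore_image(x, mask, args)  # one output per filter stage
```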
harmonizer/src/model/filter.py ADDED
@@ -0,0 +1,231 @@
+import math
+from enum import Enum
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import kornia
+
+
+class BrightnessFilter(nn.Module):
+    def __init__(self):
+        super(BrightnessFilter, self).__init__()
+        self.epsilon = 1e-6
+
+    def forward(self, image, x):
+        """
+        Arguments:
+            image (tensor [n, 3, h, w]): RGB image with pixel values in [0, 1]
+            x (tensor [n, 1, 1, 1]): brightness argument with values in [-1, 1]
+        """
+
+        # convert the image from RGB to HSV
+        image = kornia.color.rgb_to_hsv(image)
+        h = image[:, 0:1, :, :]
+        s = image[:, 1:2, :, :]
+        v = image[:, 2:3, :, :]
+
+        # calculate alpha
+        amask = (x >= 0).float()
+        alpha = (1 / ((1 - x) + self.epsilon)) * amask + (x + 1) * (1 - amask)
+
+        # adjust the V channel
+        v = v * alpha
+
+        # convert the image from HSV back to RGB
+        image = torch.cat((h, s, v), dim=1)
+        image = kornia.color.hsv_to_rgb(image)
+
+        # clip pixel values to [0, 1]
+        image = torch.clamp(image, 0.0, 1.0)
+
+        return image
+
+
+class ContrastFilter(nn.Module):
+    def __init__(self):
+        super(ContrastFilter, self).__init__()
+
+    def forward(self, image, x):
+        """
+        Arguments:
+            image (tensor [n, 3, h, w]): RGB image with pixel values in [0, 1]
+            x (tensor [n, 1, 1, 1]): contrast argument with values in [-1, 1]
+        """
+
+        # calculate the mean of the image as the threshold
+        threshold = torch.mean(image, dim=(1, 2, 3), keepdim=True)
+
+        # pre-process x if it is a positive value
+        mask = (x.detach() > 0).float()
+        x_ = 255 / (256 - torch.floor(x * 255)) - 1
+        x_ = x * (1 - mask) + x_ * mask
+
+        # modify the contrast of the image
+        image = image + (image - threshold) * x_
+
+        # clip pixel values to [0, 1]
+        image = torch.clamp(image, 0.0, 1.0)
+
+        return image
+
+
+class SaturationFilter(nn.Module):
+    def __init__(self):
+        super(SaturationFilter, self).__init__()
+
+        self.epsilon = 1e-6
+
+    def forward(self, image, x):
+        """
+        Arguments:
+            image (tensor [n, 3, h, w]): RGB image with pixel values in [0, 1]
+            x (tensor [n, 1, 1, 1]): saturation argument with values in [-1, 1]
+        """
+
+        # calculate the basic properties of the image
+        cmin = torch.min(image, dim=1, keepdim=True)[0]
+        cmax = torch.max(image, dim=1, keepdim=True)[0]
+        var = cmax - cmin
+        ran = cmax + cmin
+        mean = ran / 2
+
+        is_positive = (x.detach() >= 0).float()
+
+        # calculate s
+        m = (mean < 0.5).float()
+        s = (var / (ran + self.epsilon)) * m + (var / (2 - ran + self.epsilon)) * (1 - m)
+
+        # if x is positive
+        m = ((x + s) > 1).float()
+        a_pos = s * m + (1 - x) * (1 - m)
+        a_pos = 1 / (a_pos + self.epsilon) - 1
+
+        # if x is negative
+        a_neg = 1 + x
+
+        a = a_pos * is_positive + a_neg * (1 - is_positive)
+        image = image * is_positive + mean * (1 - is_positive) + (image - mean) * a
+
+        # clip pixel values to [0, 1]
+        image = torch.clamp(image, 0.0, 1.0)
+
+        return image
+
+
+class TemperatureFilter(nn.Module):
+    def __init__(self):
+        super(TemperatureFilter, self).__init__()
+
+        self.epsilon = 1e-6
+
+    def forward(self, image, x):
+        """
+        Arguments:
+            image (tensor [n, 3, h, w]): RGB image with pixel values in [0, 1]
+            x (tensor [n, 1, 1, 1]): color temperature argument with values in [-1, 1]
+        """
+        # split the R/G/B channels
+        R, G, B = image[:, 0:1, ...], image[:, 1:2, ...], image[:, 2:3, ...]
+
+        # calculate the mean of each channel
+        meanR = torch.mean(R, dim=(2, 3), keepdim=True)
+        meanG = torch.mean(G, dim=(2, 3), keepdim=True)
+        meanB = torch.mean(B, dim=(2, 3), keepdim=True)
+
+        # calculate correction factors
+        gray = (meanR + meanG + meanB) / 3
+        coefR = gray / (meanR + self.epsilon)
+        coefG = gray / (meanG + self.epsilon)
+        coefB = gray / (meanB + self.epsilon)
+        aR = 1 - coefR
+        aG = 1 - coefG
+        aB = 1 - coefB
+
+        # adjust temperature
+        is_positive = (x.detach() > 0).float()
+        is_negative = (x.detach() < 0).float()
+        is_zero = (x.detach() == 0).float()
+
+        meanR_ = meanR + x * torch.sign(x) * is_negative
+        meanG_ = meanG + x * torch.sign(x) * 0.5 * (1 - is_zero)
+        meanB_ = meanB + x * torch.sign(x) * is_positive
+        gray_ = (meanR_ + meanG_ + meanB_) / 3
+
+        coefR_ = gray_ / (meanR_ + self.epsilon) + aR
+        coefG_ = gray_ / (meanG_ + self.epsilon) + aG
+        coefB_ = gray_ / (meanB_ + self.epsilon) + aB
+
+        R_ = coefR_ * R
+        G_ = coefG_ * G
+        B_ = coefB_ * B
+
+        # the RGB image with the adjusted color temperature
+        image = torch.cat((R_, G_, B_), dim=1)
+
+        # clip pixel values to [0, 1]
+        image = torch.clamp(image, 0.0, 1.0)
+
+        return image
+
+
+class HighlightFilter(nn.Module):
+    def __init__(self):
+        super(HighlightFilter, self).__init__()
+
+    def forward(self, image, x):
+        """
+        Arguments:
+            image (tensor [n, 3, h, w]): RGB image with pixel values in [0, 1]
+            x (tensor [n, 1, 1, 1]): highlight argument with values in [-1, 1]
+        """
+
+        x = x + 1
+
+        image = kornia.enhance.invert(image, image.detach() * 0 + 1)
+        image = torch.clamp(torch.pow(image + 1e-9, x), 0.0, 1.0)
+        image = kornia.enhance.invert(image, image.detach() * 0 + 1)
+
+        # clip pixel values to [0, 1]
+        image = torch.clamp(image, 0.0, 1.0)
+
+        return image
+
+
+class ShadowFilter(nn.Module):
+    def __init__(self):
+        super(ShadowFilter, self).__init__()
+
+    def forward(self, image, x):
+        """
+        Arguments:
+            image (tensor [n, 3, h, w]): RGB image with pixel values in [0, 1]
+            x (tensor [n, 1, 1, 1]): shadow argument with values in [-1, 1]
+        """
+
+        x = -x + 1
+        image = torch.clamp(torch.pow(image + 1e-9, x), 0.0, 1.0)
+
+        # clip pixel values to [0, 1]
+        image = torch.clamp(image, 0.0, 1.0)
+
+        return image
+
+
+class Filter(Enum):
+    BRIGHTNESS = 1
+    CONTRAST = 2
+    SATURATION = 3
+    TEMPERATURE = 4
+    HIGHLIGHT = 5
+    SHADOW = 6
+
+
+FILTER_MODULES = {
+    Filter.BRIGHTNESS: BrightnessFilter,
+    Filter.CONTRAST: ContrastFilter,
+    Filter.SATURATION: SaturationFilter,
+    Filter.TEMPERATURE: TemperatureFilter,
+    Filter.HIGHLIGHT: HighlightFilter,
+    Filter.SHADOW: ShadowFilter,
+}
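Each filter is a differentiable `nn.Module` that maps an image and a scalar argument in [-1, 1] to an adjusted image of the same shape. A small sketch (illustrative, not part of the commit; requires `kornia` for `BrightnessFilter`):

```
# Hypothetical usage of a single filter (illustration only).
import torch
f = BrightnessFilter()
image = torch.rand(2, 3, 64, 64)                   # batch of RGB images in [0, 1]
arg = torch.tensor([0.5, -0.5]).view(2, 1, 1, 1)   # brighten one, darken the other
out = f(image, arg)                                # same shape, clipped to [0, 1]
```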
harmonizer/src/model/harmonizer.py ADDED
@@ -0,0 +1,44 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torchvision.transforms.functional as tf
+
+from .filter import Filter
+from .backbone import EfficientBackbone
+from .module import CascadeArgumentRegressor, FilterPerformer
+
+
+class Harmonizer(nn.Module):
+    def __init__(self):
+        super(Harmonizer, self).__init__()
+
+        self.input_size = (256, 256)
+        self.filter_types = [
+            Filter.TEMPERATURE,
+            Filter.BRIGHTNESS,
+            Filter.CONTRAST,
+            Filter.SATURATION,
+            Filter.HIGHLIGHT,
+            Filter.SHADOW,
+        ]
+
+        self.backbone = EfficientBackbone.from_name('efficientnet-b0')
+        self.regressor = CascadeArgumentRegressor(1280, 160, 1, len(self.filter_types))
+        self.performer = FilterPerformer(self.filter_types)
+
+    def predict_arguments(self, comp, mask):
+        comp = F.interpolate(comp, self.input_size, mode='bilinear', align_corners=False)
+        mask = F.interpolate(mask, self.input_size, mode='bilinear', align_corners=False)
+
+        fg = torch.cat((comp, mask), dim=1)
+        bg = torch.cat((comp, (1 - mask)), dim=1)
+
+        enc2x, enc4x, enc8x, enc16x, enc32x = self.backbone(fg, bg)
+        arguments = self.regressor(enc32x)
+        return arguments
+
+    def restore_image(self, comp, mask, arguments):
+        assert len(arguments) == len(self.filter_types)
+
+        arguments = [torch.clamp(arg, -1, 1).view(-1, 1, 1, 1) for arg in arguments]
+        return self.performer.restore(comp, mask, arguments)
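A hedged inference sketch for `Harmonizer` (illustrative, not part of the commit): given a composite image and its foreground mask, it predicts one argument per filter and applies the cascade, returning one intermediate image per stage.

```
# Hypothetical usage of Harmonizer (illustration only).
import torch
model = Harmonizer().eval()
comp = torch.rand(1, 3, 512, 512)   # composite image in [0, 1]
mask = torch.ones(1, 1, 512, 512)   # foreground mask in [0, 1]
with torch.no_grad():
    args = model.predict_arguments(comp, mask)
    harmonized = model.restore_image(comp, mask, args)[-1]  # final stage output
```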
harmonizer/src/model/module.py ADDED
@@ -0,0 +1,80 @@
+import cv2
+import math
+from enum import Enum
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from .filter import Filter, FILTER_MODULES
+
+
+class CascadeArgumentRegressor(nn.Module):
+    def __init__(self, in_channels, base_channels, out_channels, head_num):
+        super(CascadeArgumentRegressor, self).__init__()
+        self.in_channels = in_channels
+        self.base_channels = base_channels
+        self.out_channels = out_channels
+        self.head_num = head_num
+
+        self.pool = nn.AdaptiveAvgPool2d((1, 1))
+
+        self.f = nn.Linear(self.in_channels, 160)
+        self.g = nn.Linear(self.in_channels, self.base_channels)
+
+        self.headers = nn.ModuleList()
+        for i in range(0, self.head_num):
+            self.headers.append(
+                nn.ModuleList([
+                    nn.Linear(160 + self.base_channels, self.base_channels),
+                    nn.Linear(self.base_channels, self.out_channels),
+                ])
+            )
+
+    def forward(self, x):
+        x = self.pool(x)
+        n, c, _, _ = x.shape
+        x = x.view(n, c)
+
+        f = self.f(x)
+        g = self.g(x)
+
+        pred_args = []
+        for i in range(0, self.head_num):
+            g = self.headers[i][0](torch.cat((f, g), dim=1))
+            pred_args.append(self.headers[i][1](g))
+
+        return pred_args
+
+
+class FilterPerformer(nn.Module):
+    def __init__(self, filter_types):
+        super(FilterPerformer, self).__init__()
+
+        self.filters = [FILTER_MODULES[filter_type]() for filter_type in filter_types]
+
+    def forward(self):
+        pass
+
+    def restore(self, x, mask, arguments):
+        assert len(self.filters) == len(arguments)
+
+        outputs = []
+        _image = x
+        for filter, arg in zip(self.filters, arguments):
+            _image = filter(_image, arg)
+            outputs.append(_image * mask + x * (1 - mask))
+
+        return outputs
+
+    def adjust(self, image, mask, arguments):
+        assert len(self.filters) == len(arguments)
+
+        outputs = []
+        _image = image
+        for filter, arg in zip(reversed(self.filters), reversed(arguments)):
+            _image = filter(_image, arg)
+            outputs.append(_image * mask + image * (1 - mask))
+
+        return outputs
+
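A quick shape check for `CascadeArgumentRegressor` (illustrative, not part of the commit): the regressor pools a feature map to a vector and emits one argument tensor per cascaded head.

```
# Hypothetical shape check (illustration only).
import torch
reg = CascadeArgumentRegressor(in_channels=1280, base_channels=160, out_channels=1, head_num=6)
feat = torch.rand(2, 1280, 8, 8)
args = reg(feat)
print(len(args), args[0].shape)  # 6 torch.Size([2, 1])
```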
harmonizer/src/requirements.txt ADDED
@@ -0,0 +1,6 @@
+tqdm
+numpy
+Pillow
+argparse
+scikit-image == 0.19.2
+kornia
harmonizer/src/train/README.md ADDED
@@ -0,0 +1,14 @@
+## Quick Start - Training Harmonizer
+
+
+1. Download the iHarmony4 dataset and put it in the folder `./harmonizer/dataset/`.
+2. Pre-process the iHarmony4 dataset for training.
+   We provide the processed Hday2night subset as an example at [this link](https://drive.google.com/drive/folders/1HtrmUlFsT1yIfJ2JkGWwAwFDlv8StD6e?usp=sharing).
+   You should convert the other subsets to the same format for training.
+   Otherwise, you need to implement new dataset loaders in the file `./harmonizer/data.py` to load datasets in other formats.
+3. Run the training script:
+   ```
+   cd ./harmonizer
+   python -m script.train
+   ```
+   You can configure the training arguments in the script.
harmonizer/src/train/harmonizer/__init__.py ADDED
File without changes
harmonizer/src/train/harmonizer/criterion.py ADDED
@@ -0,0 +1,47 @@
+
+import torch
+import torch.nn as nn
+
+import torchtask
+
+
+def add_parser_arguments(parser):
+    torchtask.criterion_template.add_parser_arguments(parser)
+
+
+
+def harmonizer_loss():
+    return HarmonizerLoss
+
+
+class AbsoluteLoss(nn.Module):
+    def __init__(self, epsilon=1e-6):
+        super(AbsoluteLoss, self).__init__()
+        self.epsilon = epsilon
+
+    def forward(self, pred, gt):
+        loss = torch.sqrt((pred - gt) ** 2 + self.epsilon)
+        return loss
+
+
+class HarmonizerLoss(torchtask.criterion_template.TaskCriterion):
+    def __init__(self, args):
+        super(HarmonizerLoss, self).__init__(args)
+
+        self.l1 = AbsoluteLoss()
+        self.l2 = nn.MSELoss(reduction='none')
+
+    def forward(self, pred, gt, inp):
+        pred_outputs, = pred
+        x, mask = inp
+
+        assert len(pred_outputs) == len(gt)
+
+        image_losses = []
+        for pred_, gt_ in zip(pred_outputs, gt):
+            l1_loss = torch.sum(self.l1(pred_, gt_) * mask, dim=(1, 2, 3)) / (torch.sum(mask, dim=(1, 2, 3)) + 1e-6)
+            l2_loss = torch.sum(self.l2(pred_, gt_) * mask, dim=(1, 2, 3)) / (torch.sum(mask, dim=(1, 2, 3)) + 1e-6) * 10
+            loss = (l1_loss + l2_loss)
+            image_losses.append(loss)
+
+        return image_losses
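Both terms are normalized by the foreground area, so the per-sample loss does not depend on how large the composited region is. A small sketch of the masked L1 term (illustrative, not part of the commit):

```
# Hypothetical check of the foreground-normalized L1 term (illustration only).
import torch
pred = torch.rand(2, 3, 8, 8)
gt = torch.rand(2, 3, 8, 8)
mask = (torch.rand(2, 1, 8, 8) > 0.5).float()
l1 = torch.sqrt((pred - gt) ** 2 + 1e-6)
fg_l1 = torch.sum(l1 * mask, dim=(1, 2, 3)) / (torch.sum(mask, dim=(1, 2, 3)) + 1e-6)
print(fg_l1.shape)  # torch.Size([2]) -- one loss value per sample
```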
harmonizer/src/train/harmonizer/data.py ADDED
@@ -0,0 +1,198 @@
+import os
+import cv2
+import random
+import numpy as np
+from PIL import Image
+
+from torchvision import transforms
+
+import torchtask
+
+
+def add_parser_arguments(parser):
+    torchtask.data_template.add_parser_arguments(parser)
+
+
+def harmonizer_iharmony4():
+    return HarmonizerIHarmony4
+
+
+def original_iharmony4():
+    return OriginalIHarmony4
+
+
+def resize(img, size):
+    interp = cv2.INTER_LINEAR
+
+    return Image.fromarray(
+        cv2.resize(np.array(img).astype('uint8'), size, interpolation=interp))
+
+
+im_train_transform = transforms.Compose([
+    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.03),
+    transforms.ToTensor(),
+])
+
+im_val_transform = transforms.Compose([
+    transforms.ToTensor(),
+])
+
+
+class HarmonizerIHarmony4(torchtask.data_template.TaskDataset):
+    def __init__(self, args, is_train):
+        super(HarmonizerIHarmony4, self).__init__(args, is_train)
+
+        self.im_dir = os.path.join(self.root_dir, 'image')
+        self.mask_dir = os.path.join(self.root_dir, 'mask')
+
+        if not os.path.exists(self.mask_dir):
+            self.mask_dir = os.path.join(self.root_dir, 'matte')
+
+        self.sample_list = [_ for _ in os.listdir(self.im_dir)]
+        self.idxs = [_ for _ in range(0, len(self.sample_list))]
+
+        self.im_size = self.args.im_size
+
+        self.rotation = True if self.is_train else False
+        self.fliplr = True if self.is_train else False
+
+    def __getitem__(self, idx):
+        image_path = os.path.join(self.im_dir, self.sample_list[idx])
+        mask_path = os.path.join(self.mask_dir, self.sample_list[idx])
+
+        image = self.im_loader.load(image_path)
+        mask = self.im_loader.load(mask_path)
+
+        width, height = image.size
+
+        # resize to self.im_size
+        image = resize(image, (self.im_size, self.im_size))
+        mask = resize(mask, (self.im_size, self.im_size))
+
+        # convert to np array and scale to [0, 1]
+        image = np.array(image).astype('float32') / 255.0
+        mask = np.array(mask).astype('float32') / 255.0
+
+        # check image shape
+        if len(mask.shape) == 3:
+            mask = mask[:, :, -1]
+
+        if len(image.shape) == 2:
+            image = image[:, :, None]
+        if image.shape[2] == 1:
+            image = np.repeat(image, 3, axis=2)
+        elif image.shape[2] == 4:
+            image = image[:, :, 0:3]
+
+        # random rotate
+        rerotation = 0
+        if self.rotation and random.randint(0, 1) == 0:
+            rotate_num = random.randint(1, 3)
+            rerotation = 4 - rotate_num
+            image = np.rot90(image, k=rotate_num).copy()
+            mask = np.rot90(mask, k=rotate_num).copy()
+
+        # random flip
+        if self.fliplr and (random.randint(0, 1) == 0):
+            image = np.fliplr(image).copy()
+            mask = np.fliplr(mask).copy()
+
+        image = Image.fromarray((image * 255.0).astype('uint8'))
+        if self.is_train:
+            image = im_train_transform(image)
+        else:
+            image = im_val_transform(image)
+
+        mask = mask[None, :, :]
+        adjusted = image.numpy() * -1
+
+        return (adjusted, mask), (image, )
+
+
+class OriginalIHarmony4(torchtask.data_template.TaskDataset):
+    def __init__(self, args, is_train):
+        super(OriginalIHarmony4, self).__init__(args, is_train)
+
+        self.adjusted_dir = os.path.join(self.root_dir, 'comp')
+        self.mask_dir = os.path.join(self.root_dir, 'mask')
+        self.im_dir = os.path.join(self.root_dir, 'image')
+
+        self.sample_list = [_ for _ in os.listdir(self.adjusted_dir)]
+        self.idxs = [_ for _ in range(0, len(self.sample_list))]
+
+        self.im_size = self.args.im_size
+
+        self.rotation = True if self.is_train else False
+        self.fliplr = True if self.is_train else False
+
+    def __getitem__(self, idx):
+        sname = self.sample_list[idx]
+        adjusted_path = os.path.join(self.adjusted_dir, sname)
+        image_path = os.path.join(self.im_dir, sname)
+        mask_path = os.path.join(self.mask_dir, sname)
+
+        if not os.path.exists(image_path):
+            prefix = '_'.join(sname.split('_')[:-1])
+            image_path = os.path.join(self.im_dir, '{0}.jpg'.format(prefix))
+            mask_path = os.path.join(self.mask_dir, '{0}.jpg'.format(prefix))
+
+        adjusted = self.im_loader.load(adjusted_path)
+        image = self.im_loader.load(image_path)
+        mask = self.im_loader.load(mask_path)
+
+        width, height = image.size
+
+
+        # resize to self.im_size
+        adjusted = resize(adjusted, (self.im_size, self.im_size))
+        image = resize(image, (self.im_size, self.im_size))
+        mask = resize(mask, (self.im_size, self.im_size))
+
+        # convert to np array and scale to [0, 1]
+        adjusted = np.array(adjusted).astype('float32') / 255.0
+        image = np.array(image).astype('float32') / 255.0
+        mask = np.array(mask).astype('float32') / 255.0
+
+        # check image shape
+        if len(mask.shape) == 3:
+            mask = mask[:, :, -1]
+
+        if len(image.shape) == 2:
+            image = image[:, :, None]
+        if image.shape[2] == 1:
+            image = np.repeat(image, 3, axis=2)
+        elif image.shape[2] == 4:
+            image = image[:, :, 0:3]
+
+        if len(adjusted.shape) == 2:
+            adjusted = adjusted[:, :, None]
+        if adjusted.shape[2] == 1:
+            adjusted = np.repeat(adjusted, 3, axis=2)
+        elif adjusted.shape[2] == 4:
+            adjusted = adjusted[:, :, 0:3]
+
+        # random rotate
+        rerotation = 0
+        if self.rotation and random.randint(0, 1) == 0:
+            rotate_num = random.randint(1, 3)
+            rerotation = 4 - rotate_num
+            adjusted = np.rot90(adjusted, k=rotate_num).copy()
+            image = np.rot90(image, k=rotate_num).copy()
+            mask = np.rot90(mask, k=rotate_num).copy()
+
+        # random flip
+        if self.fliplr and (random.randint(0, 1) == 0):
+            adjusted = np.fliplr(adjusted).copy()
+            image = np.fliplr(image).copy()
+            mask = np.fliplr(mask).copy()
+
+        adjusted = Image.fromarray((adjusted * 255.0).astype('uint8'))
+        image = Image.fromarray((image * 255.0).astype('uint8'))
+
+        # NOTE: do not add random color adjustment here
+        adjusted = im_val_transform(adjusted)
+        image = im_val_transform(image)
+
+        mask = mask[None, :, :]
+
+        return (adjusted, mask), (image, )
harmonizer/src/train/harmonizer/func.py ADDED
@@ -0,0 +1,41 @@
+import numpy as np
+import torch
+
+import skimage
+
+import torchtask
+
+def task_func():
+    return HarmonizationFunc
+
+
+class HarmonizationFunc(torchtask.func_template.TaskFunc):
+    def __init__(self, args):
+        super(HarmonizationFunc, self).__init__(args)
+
+    def metrics(self, pred_image, gt_image, mask, meters, id_str=''):
+        n, c, h, w = pred_image.shape
+
+        assert n == 1
+
+        total_pixels = h * w
+        fg_pixels = int(torch.sum(mask, dim=(2, 3))[0][0].cpu().numpy())
+
+        pred_image = torch.clamp(pred_image * 255, 0, 255)
+        gt_image = torch.clamp(gt_image * 255, 0, 255)
+
+        pred_image = pred_image[0].permute(1, 2, 0).detach().cpu().numpy()
+        gt_image = gt_image[0].permute(1, 2, 0).detach().cpu().numpy()
+        mask = mask[0].permute(1, 2, 0).detach().cpu().numpy()
+
+        batch_mse = skimage.metrics.mean_squared_error(pred_image, gt_image)
+        meters.update('{0}_{1}_mse'.format(id_str, self.METRIC_STR), batch_mse)
+
+        batch_fmse = skimage.metrics.mean_squared_error(pred_image * mask, gt_image * mask) * total_pixels / fg_pixels
+        meters.update('{0}_{1}_fmse'.format(id_str, self.METRIC_STR), batch_fmse)
+
+        batch_psnr = skimage.metrics.peak_signal_noise_ratio(pred_image, gt_image, data_range=pred_image.max() - pred_image.min())
+        meters.update('{0}_{1}_psnr'.format(id_str, self.METRIC_STR), batch_psnr)
+
+        batch_ssim = skimage.metrics.structural_similarity(pred_image, gt_image, multichannel=True)
+        meters.update('{0}_{1}_ssim'.format(id_str, self.METRIC_STR), batch_ssim)
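Note the fMSE metric above: the squared error of the masked difference is computed over the whole image, then rescaled by `total_pixels / fg_pixels`, which is equivalent to averaging the squared error over foreground pixels only. A small check (illustrative, not part of the commit):

```
# Hypothetical check of the fMSE rescaling (illustration only).
import numpy as np
pred = np.random.rand(8, 8, 3)
gt = np.random.rand(8, 8, 3)
mask = (np.random.rand(8, 8, 1) > 0.5).astype('float32')
total, fg = 8 * 8, mask.sum()
fmse = np.mean(((pred - gt) * mask) ** 2) * total / fg
print(fmse)
```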
harmonizer/src/train/harmonizer/model.py ADDED
@@ -0,0 +1,41 @@
+import torch
+import torch.nn.functional as F
+
+import torchtask
+
+from module import harmonizer as _harmonizer
+
+
+def add_parser_arguments(parser):
+    torchtask.model_template.add_parser_arguments(parser)
+
+
+def harmonizer():
+    return Harmonizer
+
+
+class Harmonizer(torchtask.model_template.TaskModel):
+    def __init__(self, args):
+        super(Harmonizer, self).__init__(args)
+
+        self.model = _harmonizer.Harmonizer()
+        self.param_groups = [
+            {'params': filter(lambda p: p.requires_grad, self.model.backbone.parameters()), 'lr': self.args.lr},
+            {'params': filter(lambda p: p.requires_grad, self.model.regressor.parameters()), 'lr': self.args.lr},
+            {'params': filter(lambda p: p.requires_grad, self.model.performer.parameters()), 'lr': self.args.lr},
+        ]
+
+    def forward(self, inp):
+        resulter, debugger = {}, {}
+        x, mask = inp
+        pred = self.model(x, mask)
+        resulter['outputs'] = pred
+        return resulter, debugger
+
+    def restore(self, x, mask, arguments):
+        with torch.no_grad():
+            return self.model.restore_image(x, mask, arguments)
+
+    def adjust(self, x, mask, arguments):
+        with torch.no_grad():
+            return self.model.adjust_image(x, mask, arguments)
harmonizer/src/train/harmonizer/module/__init__.py ADDED
@@ -0,0 +1 @@
+from .harmonizer import Harmonizer
harmonizer/src/train/harmonizer/module/backbone/__init__.py ADDED
@@ -0,0 +1 @@
+from .efficientnet import EfficientBackbone, EfficientBackboneCommon
harmonizer/src/train/harmonizer/module/backbone/efficientnet/__init__.py ADDED
@@ -0,0 +1,116 @@
+"""
+This EfficientNet implementation comes from:
+    Author: lukemelas (github username)
+    Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
+"""
+
+import torch
+import torch.nn as nn
+
+from .model import EfficientNet
+from .utils import round_filters, get_same_padding_conv2d
+
+
+# for EfficientNet
+class EfficientBackbone(EfficientNet):
+    def __init__(self, blocks_args=None, global_params=None):
+        super(EfficientBackbone, self).__init__(blocks_args, global_params)
+
+        self.enc_channels = [16, 24, 40, 112, 1280]
+
+        # ------------------------------------------------------------
+        # delete the unused layers
+        # ------------------------------------------------------------
+        del self._conv_stem
+        del self._bn0
+        # ------------------------------------------------------------
+
+        # ------------------------------------------------------------
+        # parameters for the input layers
+        # ------------------------------------------------------------
+        bn_mom = 1 - self._global_params.batch_norm_momentum
+        bn_eps = self._global_params.batch_norm_epsilon
+
+        in_channels = 4
+        out_channels = round_filters(32, self._global_params)
+        out_channels = int(out_channels / 2)
+        # ------------------------------------------------------------
+
+        # ------------------------------------------------------------
+        # define the input layers
+        # ------------------------------------------------------------
+        image_size = global_params.image_size
+        Conv2d = get_same_padding_conv2d(image_size=image_size)
+        self._conv_fg = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+        self._bn_fg = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+
+        self._conv_bg = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+        self._bn_bg = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+        # ------------------------------------------------------------
+
+    def forward(self, xfg, xbg):
+        xfg = self._swish(self._bn_fg(self._conv_fg(xfg)))
+        xbg = self._swish(self._bn_bg(self._conv_bg(xbg)))
+
+        x = torch.cat((xfg, xbg), dim=1)
+
+        block_outputs = []
+        for idx, block in enumerate(self._blocks):
+            drop_connect_rate = self._global_params.drop_connect_rate
+            drop_connect_rate *= float(idx) / len(self._blocks)
+            x = block(x, drop_connect_rate=drop_connect_rate)
+            block_outputs.append(x)
+
+        # Head
+        x = self._swish(self._bn1(self._conv_head(x)))
+
+        return block_outputs[0], block_outputs[2], block_outputs[4], block_outputs[10], x
+
+
+# for EfficientNet
+class EfficientBackboneCommon(EfficientNet):
+    def __init__(self, blocks_args=None, global_params=None):
+        super(EfficientBackboneCommon, self).__init__(blocks_args, global_params)
+
+        self.enc_channels = [16, 24, 40, 112, 1280]
+
+        # ------------------------------------------------------------
+        # delete the unused layers
+        # ------------------------------------------------------------
+        del self._conv_stem
+        del self._bn0
+        # ------------------------------------------------------------
+
+        # ------------------------------------------------------------
+        # parameters for the input layers
+        # ------------------------------------------------------------
+        bn_mom = 1 - self._global_params.batch_norm_momentum
+        bn_eps = self._global_params.batch_norm_epsilon
+
+        in_channels = 3
+        out_channels = round_filters(32, self._global_params)
+        # ------------------------------------------------------------
+
+        # ------------------------------------------------------------
+        # define the input layers
+        # ------------------------------------------------------------
+        image_size = global_params.image_size
+        Conv2d = get_same_padding_conv2d(image_size=image_size)
+        self._conv = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+        self._bn = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+        # ------------------------------------------------------------
+
+    def forward(self, x):
+        x = self._swish(self._bn(self._conv(x)))
+
+        block_outputs = []
+        for idx, block in enumerate(self._blocks):
+            drop_connect_rate = self._global_params.drop_connect_rate
+            drop_connect_rate *= float(idx) / len(self._blocks)
+            x = block(x, drop_connect_rate=drop_connect_rate)
+            block_outputs.append(x)
+
+        # Head
+        x = self._swish(self._bn1(self._conv_head(x)))
+
+        return block_outputs[0], block_outputs[2], block_outputs[4], block_outputs[10], x
harmonizer/src/train/harmonizer/module/backbone/efficientnet/model.py ADDED
@@ -0,0 +1,395 @@
+"""model.py - Model and module class for EfficientNet.
+   They are built to mirror those in the official TensorFlow implementation.
+"""
+
+# Author: lukemelas (github username)
+# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
+# With adjustments and added comments by workingcoder (github username).
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from .utils import (
+    round_filters,
+    round_repeats,
+    drop_connect,
+    get_same_padding_conv2d,
+    get_model_params,
+    efficientnet_params,
+    load_pretrained_weights,
+    Swish,
+    MemoryEfficientSwish,
+    calculate_output_image_size
+)
+
+
+VALID_MODELS = (
+    'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3',
+    'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b6', 'efficientnet-b7',
+    'efficientnet-b8',
+
+    # Support the construction of 'efficientnet-l2' without pretrained weights
+    'efficientnet-l2'
+)
+
+
+class MBConvBlock(nn.Module):
+    """Mobile Inverted Residual Bottleneck Block.
+    Args:
+        block_args (namedtuple): BlockArgs, defined in utils.py.
+        global_params (namedtuple): GlobalParams, defined in utils.py.
+        image_size (tuple or list): [image_height, image_width].
+    References:
+        [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
+        [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
+        [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
+    """
+
+    def __init__(self, block_args, global_params, image_size=None):
+        super().__init__()
+        self._block_args = block_args
+        self._bn_mom = 1 - global_params.batch_norm_momentum  # pytorch's difference from tensorflow
+        self._bn_eps = global_params.batch_norm_epsilon
+        self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
+        self.id_skip = block_args.id_skip  # whether to use skip connection and drop connect
+
+        # Expansion phase (Inverted Bottleneck)
+        inp = self._block_args.input_filters  # number of input channels
+        oup = self._block_args.input_filters * self._block_args.expand_ratio  # number of output channels
+        if self._block_args.expand_ratio != 1:
+            Conv2d = get_same_padding_conv2d(image_size=image_size)
+            self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
+            self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
+            # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
+
+        # Depthwise convolution phase
+        k = self._block_args.kernel_size
+        s = self._block_args.stride
+        Conv2d = get_same_padding_conv2d(image_size=image_size)
+        self._depthwise_conv = Conv2d(
+            in_channels=oup, out_channels=oup, groups=oup,  # groups makes it depthwise
+            kernel_size=k, stride=s, bias=False)
+        self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
+        image_size = calculate_output_image_size(image_size, s)
+
+        # Squeeze and Excitation layer, if desired
+        if self.has_se:
+            Conv2d = get_same_padding_conv2d(image_size=(1, 1))
+            num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
+            self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
+            self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
+
+        # Pointwise convolution phase
+        final_oup = self._block_args.output_filters
+        Conv2d = get_same_padding_conv2d(image_size=image_size)
+        self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
+        self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
+        self._swish = MemoryEfficientSwish()
+
+    def forward(self, inputs, drop_connect_rate=None):
+        """MBConvBlock's forward function.
+        Args:
+            inputs (tensor): Input tensor.
+            drop_connect_rate (float): Drop connect rate (between 0 and 1).
+        Returns:
+            Output of this block after processing.
+        """
+
+        # Expansion and Depthwise Convolution
+        x = inputs
+        if self._block_args.expand_ratio != 1:
+            x = self._expand_conv(inputs)
+            x = self._bn0(x)
+            x = self._swish(x)
+
+        x = self._depthwise_conv(x)
+        x = self._bn1(x)
+        x = self._swish(x)
+
+        # Squeeze and Excitation
+        if self.has_se:
+            x_squeezed = F.adaptive_avg_pool2d(x, 1)
+            x_squeezed = self._se_reduce(x_squeezed)
+            x_squeezed = self._swish(x_squeezed)
+            x_squeezed = self._se_expand(x_squeezed)
+            x = torch.sigmoid(x_squeezed) * x
+
+        # Pointwise Convolution
+        x = self._project_conv(x)
+        x = self._bn2(x)
+
+        # Skip connection and drop connect
+        input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
+        if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
+            # The combination of skip connection and drop connect brings about stochastic depth.
+            if drop_connect_rate:
+                x = drop_connect(x, p=drop_connect_rate, training=self.training)
+            x = x + inputs  # skip connection
+        return x
+
+    def set_swish(self, memory_efficient=True):
+        """Sets swish function as memory efficient (for training) or standard (for export).
+        Args:
+            memory_efficient (bool): Whether to use memory-efficient version of swish.
+        """
+        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
+
+
+class EfficientNet(nn.Module):
+    """EfficientNet model.
+       Most easily loaded with the .from_name or .from_pretrained methods.
+    Args:
+        blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
+        global_params (namedtuple): A set of GlobalParams shared between blocks.
+    References:
+        [1] https://arxiv.org/abs/1905.11946 (EfficientNet)
+    Example:
+        >>> import torch
+        >>> from efficientnet.model import EfficientNet
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> model = EfficientNet.from_pretrained('efficientnet-b0')
+        >>> model.eval()
+        >>> outputs = model(inputs)
+    """
+
+    def __init__(self, blocks_args=None, global_params=None):
+        super().__init__()
+        assert isinstance(blocks_args, list), 'blocks_args should be a list'
+        assert len(blocks_args) > 0, 'block args must be greater than 0'
+        self._global_params = global_params
+        self._blocks_args = blocks_args
+
+        # Batch norm parameters
+        bn_mom = 1 - self._global_params.batch_norm_momentum
+        bn_eps = self._global_params.batch_norm_epsilon
+
+        # Get stem static or dynamic convolution depending on image size
+        image_size = global_params.image_size
+        Conv2d = get_same_padding_conv2d(image_size=image_size)
+
+        # Stem
+        in_channels = 3  # rgb
+        out_channels = round_filters(32, self._global_params)  # number of output channels
+        self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+        self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+        image_size = calculate_output_image_size(image_size, 2)
+
+        # Build blocks
+        self._blocks = nn.ModuleList([])
+        for block_args in self._blocks_args:
+
+            # Update block input and output filters based on depth multiplier.
+            block_args = block_args._replace(
+                input_filters=round_filters(block_args.input_filters, self._global_params),
+                output_filters=round_filters(block_args.output_filters, self._global_params),
+                num_repeat=round_repeats(block_args.num_repeat, self._global_params)
+            )
+
+            # The first block needs to take care of stride and filter size increase.
+            self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
+            image_size = calculate_output_image_size(image_size, block_args.stride)
+            if block_args.num_repeat > 1:  # modify block_args to keep same output size
+                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
+            for _ in range(block_args.num_repeat - 1):
+                self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
+                # image_size = calculate_output_image_size(image_size, block_args.stride)  # stride = 1
+
+        # Head
+        in_channels = block_args.output_filters  # output of final block
+        out_channels = round_filters(1280, self._global_params)
+        Conv2d = get_same_padding_conv2d(image_size=image_size)
+        self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
+        self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+
+        # Final linear layer
+        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
+        if self._global_params.include_top:
+            self._dropout = nn.Dropout(self._global_params.dropout_rate)
+            self._fc = nn.Linear(out_channels, self._global_params.num_classes)
+
+        # set activation to memory efficient swish by default
+        self._swish = MemoryEfficientSwish()
+
+    def set_swish(self, memory_efficient=True):
+        """Sets swish function as memory efficient (for training) or standard (for export).
+        Args:
+            memory_efficient (bool): Whether to use memory-efficient version of swish.
+        """
+        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
+        for block in self._blocks:
+            block.set_swish(memory_efficient)
+
+    def extract_endpoints(self, inputs):
+        """Use convolution layers to extract features
+        from reduction levels i in [1, 2, 3, 4, 5].
+        Args:
+            inputs (tensor): Input tensor.
+        Returns:
+            Dictionary of last intermediate features
+            with reduction levels i in [1, 2, 3, 4, 5].
+        Example:
+            >>> import torch
+            >>> from efficientnet.model import EfficientNet
+            >>> inputs = torch.rand(1, 3, 224, 224)
+            >>> model = EfficientNet.from_pretrained('efficientnet-b0')
+            >>> endpoints = model.extract_endpoints(inputs)
+            >>> print(endpoints['reduction_1'].shape)  # torch.Size([1, 16, 112, 112])
+            >>> print(endpoints['reduction_2'].shape)  # torch.Size([1, 24, 56, 56])
+            >>> print(endpoints['reduction_3'].shape)  # torch.Size([1, 40, 28, 28])
+            >>> print(endpoints['reduction_4'].shape)  # torch.Size([1, 112, 14, 14])
+            >>> print(endpoints['reduction_5'].shape)  # torch.Size([1, 320, 7, 7])
+            >>> print(endpoints['reduction_6'].shape)  # torch.Size([1, 1280, 7, 7])
+        """
+        endpoints = dict()
+
+        # Stem
+        x = self._swish(self._bn0(self._conv_stem(inputs)))
+        prev_x = x
+
+        # Blocks
+        for idx, block in enumerate(self._blocks):
+            drop_connect_rate = self._global_params.drop_connect_rate
+            if drop_connect_rate:
+                drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect rate
+            x = block(x, drop_connect_rate=drop_connect_rate)
+            if prev_x.size(2) > x.size(2):
+                endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x
+            elif idx == len(self._blocks) - 1:
+                endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
+            prev_x = x
+
+        # Head
+        x = self._swish(self._bn1(self._conv_head(x)))
+        endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
+
+        return endpoints
+
+    def extract_features(self, inputs):
+        """Use convolution layers to extract features.
+        Args:
+            inputs (tensor): Input tensor.
+        Returns:
+            Output of the final convolution
+            layer in the efficientnet model.
+        """
+        # Stem
+        x = self._swish(self._bn0(self._conv_stem(inputs)))
+
+        # Blocks
+        for idx, block in enumerate(self._blocks):
+            drop_connect_rate = self._global_params.drop_connect_rate
+            if drop_connect_rate:
+                drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect rate
+            x = block(x, drop_connect_rate=drop_connect_rate)
+
+        # Head
+        x = self._swish(self._bn1(self._conv_head(x)))
+
+        return x
+
+    def forward(self, inputs):
+        """EfficientNet's forward function.
+        Calls extract_features to extract features, applies the final linear layer, and returns logits.
+        Args:
+            inputs (tensor): Input tensor.
+        Returns:
+            Output of this model after processing.
+        """
+        # Convolution layers
+        x = self.extract_features(inputs)
+        # Pooling and final linear layer
+        x = self._avg_pooling(x)
+        if self._global_params.include_top:
+            x = x.flatten(start_dim=1)
+            x = self._dropout(x)
+            x = self._fc(x)
+        return x
+
+    @classmethod
+    def from_name(cls, model_name, in_channels=3, **override_params):
+        """Create an efficientnet model according to its name.
+        Args:
+            model_name (str): Name for efficientnet.
+            in_channels (int): Input data's channel number.
+            override_params (other key word params):
+                Params to override model's global_params.
+                Optional keys:
+                    'width_coefficient', 'depth_coefficient',
+                    'image_size', 'dropout_rate',
+                    'num_classes', 'batch_norm_momentum',
+                    'batch_norm_epsilon', 'drop_connect_rate',
+                    'depth_divisor', 'min_depth'
+        Returns:
+            An efficientnet model.
+        """
+        cls._check_model_name_is_valid(model_name)
+        blocks_args, global_params = get_model_params(model_name, override_params)
+        model = cls(blocks_args, global_params)
+        model._change_in_channels(in_channels)
+        return model
+
+    @classmethod
+    def from_pretrained(cls, model_name, weights_path=None, advprop=False,
+                        in_channels=3, num_classes=1000, **override_params):
+        """Create a pretrained efficientnet model according to its name.
+        Args:
+            model_name (str): Name for efficientnet.
+            weights_path (None or str):
+                str: path to pretrained weights file on the local disk.
+                None: use pretrained weights downloaded from the Internet.
+            advprop (bool):
+                Whether to load pretrained weights
+                trained with advprop (valid when weights_path is None).
+            in_channels (int): Input data's channel number.
+            num_classes (int):
+                Number of categories for classification.
+                It controls the output size for the final linear layer.
+            override_params (other key word params):
+                Params to override model's global_params.
+                Optional keys:
+                    'width_coefficient', 'depth_coefficient',
+                    'image_size', 'dropout_rate',
+                    'batch_norm_momentum',
+                    'batch_norm_epsilon', 'drop_connect_rate',
+                    'depth_divisor', 'min_depth'
+        Returns:
+            A pretrained efficientnet model.
+        """
+        model = cls.from_name(model_name, num_classes=num_classes, **override_params)
+        load_pretrained_weights(model, model_name, weights_path=weights_path,
+                                load_fc=(num_classes == 1000), advprop=advprop)
+        model._change_in_channels(in_channels)
+        return model
+
+    @classmethod
+    def get_image_size(cls, model_name):
+        """Get the input image size for a given efficientnet model.
+        Args:
+            model_name (str): Name for efficientnet.
+        Returns:
+            Input image size (resolution).
+        """
+        cls._check_model_name_is_valid(model_name)
+        _, _, res, _ = efficientnet_params(model_name)
+        return res
+
+    @classmethod
+    def _check_model_name_is_valid(cls, model_name):
+        """Validates the model name.
+        Args:
+            model_name (str): Name for efficientnet.
+        Returns:
+            bool: Is a valid name or not.
+        """
+        if model_name not in VALID_MODELS:
+            raise ValueError('model_name should be one of: ' + ', '.join(VALID_MODELS))
+
+    def _change_in_channels(self, in_channels):
+        """Adjust the model's first convolution layer to in_channels, if in_channels does not equal 3.
+        Args:
+            in_channels (int): Input data's channel number.
+        """
+        if in_channels != 3:
+            Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
+            out_channels = round_filters(32, self._global_params)
+            self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
harmonizer/src/train/harmonizer/module/backbone/efficientnet/utils.py ADDED
@@ -0,0 +1,586 @@
1
+ """utils.py - Helper functions for building the model and for loading model parameters.
2
+ These helper functions are built to mirror those in the official TensorFlow implementation.
3
+ """
4
+
5
+ # Author: lukemelas (github username)
6
+ # Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
7
+ # With adjustments and added comments by workingcoder (github username).
8
+
9
+ import re
10
+ import math
11
+ import collections
12
+ from functools import partial
13
+ import torch
14
+ from torch import nn
15
+ from torch.nn import functional as F
16
+ from torch.utils import model_zoo
17
+
18
+
19
+ ################################################################################
20
+ # Help functions for model architecture
21
+ ################################################################################
22
+
23
+ # GlobalParams and BlockArgs: Two namedtuples
24
+ # Swish and MemoryEfficientSwish: Two implementations of the method
25
+ # round_filters and round_repeats:
26
+ # Functions to calculate params for scaling model width and depth ! ! !
27
+ # get_width_and_height_from_size and calculate_output_image_size
28
+ # drop_connect: A structural design
29
+ # get_same_padding_conv2d:
30
+ # Conv2dDynamicSamePadding
31
+ # Conv2dStaticSamePadding
32
+ # get_same_padding_maxPool2d:
33
+ # MaxPool2dDynamicSamePadding
34
+ # MaxPool2dStaticSamePadding
35
+ # It's an additional function, not used in EfficientNet,
36
+ # but can be used in other model (such as EfficientDet).
37
+
38
+ # Parameters for the entire model (stem, all blocks, and head)
39
+ GlobalParams = collections.namedtuple('GlobalParams', [
40
+ 'width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate',
41
+ 'num_classes', 'batch_norm_momentum', 'batch_norm_epsilon',
42
+ 'drop_connect_rate', 'depth_divisor', 'min_depth', 'include_top'])
43
+
44
+ # Parameters for an individual model block
45
+ BlockArgs = collections.namedtuple('BlockArgs', [
46
+ 'num_repeat', 'kernel_size', 'stride', 'expand_ratio',
47
+ 'input_filters', 'output_filters', 'se_ratio', 'id_skip'])
48
+
49
+ # Set GlobalParams and BlockArgs's defaults
50
+ GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
51
+ BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
52
+
53
+ # Swish activation function
54
+ if hasattr(nn, 'SiLU'):
55
+ Swish = nn.SiLU
56
+ else:
57
+ # For compatibility with old PyTorch versions
58
+ class Swish(nn.Module):
59
+ def forward(self, x):
60
+ return x * torch.sigmoid(x)
61
+
62
+
63
+ # A memory-efficient implementation of Swish function
64
+ class SwishImplementation(torch.autograd.Function):
65
+ @staticmethod
66
+ def forward(ctx, i):
67
+ result = i * torch.sigmoid(i)
68
+ ctx.save_for_backward(i)
69
+ return result
70
+
71
+ @staticmethod
72
+ def backward(ctx, grad_output):
73
+ i = ctx.saved_tensors[0]
74
+ sigmoid_i = torch.sigmoid(i)
75
+ return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
76
+
77
+
78
+ class MemoryEfficientSwish(nn.Module):
79
+ def forward(self, x):
80
+ return SwishImplementation.apply(x)
81
+
82
+
83
+ def round_filters(filters, global_params):
84
+ """Calculate and round number of filters based on width multiplier.
85
+ Use width_coefficient, depth_divisor and min_depth of global_params.
86
+ Args:
87
+ filters (int): Filters number to be calculated.
88
+ global_params (namedtuple): Global params of the model.
89
+ Returns:
90
+ new_filters: New filters number after calculating.
91
+ """
92
+ multiplier = global_params.width_coefficient
93
+ if not multiplier:
94
+ return filters
95
+ # TODO: modify the params names.
96
+ # maybe the names (width_divisor,min_width)
97
+ # are more suitable than (depth_divisor,min_depth).
98
+ divisor = global_params.depth_divisor
99
+ min_depth = global_params.min_depth
100
+ filters *= multiplier
101
+ min_depth = min_depth or divisor # pay attention to this line when using min_depth
102
+ # follow the formula transferred from official TensorFlow implementation
103
+ new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
104
+ if new_filters < 0.9 * filters: # prevent rounding by more than 10%
105
+ new_filters += divisor
106
+ return int(new_filters)
107
+
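A quick worked example of the rounding above (values are illustrative): with width_coefficient=1.4 and depth_divisor=8, as in EfficientNet-B4, 32 input filters scale to 32 * 1.4 = 44.8, and the nearest multiple of 8 is 48; since 48 >= 0.9 * 44.8, no correction is applied.

    gp = GlobalParams(width_coefficient=1.4, depth_divisor=8, min_depth=None)
    round_filters(32, gp)  # -> 48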
108
+
109
+ def round_repeats(repeats, global_params):
110
+ """Calculate module's repeat number of a block based on depth multiplier.
111
+ Use depth_coefficient of global_params.
112
+ Args:
113
+ repeats (int): num_repeat to be calculated.
114
+ global_params (namedtuple): Global params of the model.
115
+ Returns:
116
+ new_repeats: New number of repeats after scaling.
117
+ """
118
+ multiplier = global_params.depth_coefficient
119
+ if not multiplier:
120
+ return repeats
121
+ # follow the formula transferred from official TensorFlow implementation
122
+ return int(math.ceil(multiplier * repeats))
123
+
124
+
125
+ def drop_connect(inputs, p, training):
126
+ """Drop connect.
127
+ Args:
128
+ inputs (tensor: BCHW): Input of this structure.
129
+ p (float: 0.0~1.0): Probability of drop connection.
130
+ training (bool): The running mode.
131
+ Returns:
132
+ output: Output after drop connection.
133
+ """
134
+ assert 0 <= p <= 1, 'p must be in range of [0,1]'
135
+
136
+ if not training:
137
+ return inputs
138
+
139
+ batch_size = inputs.shape[0]
140
+ keep_prob = 1 - p
141
+
142
+ # generate binary_tensor mask according to probability (p for 0, 1-p for 1)
143
+ random_tensor = keep_prob
144
+ random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
145
+ binary_tensor = torch.floor(random_tensor)
146
+
147
+ output = inputs / keep_prob * binary_tensor
148
+ return output
149
+
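A minimal sketch of the behavior (shapes and p are illustrative): each sample's entire residual branch is either zeroed or rescaled by 1 / keep_prob, so the expectation is preserved; with training=False the input passes through unchanged.

    x = torch.ones(4, 2, 3, 3)
    y = drop_connect(x, p=0.5, training=True)
    # every sample in y is now either all zeros or all 2.0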
150
+
151
+ def get_width_and_height_from_size(x):
152
+ """Obtain height and width from x.
153
+ Args:
154
+ x (int, tuple or list): Data size.
155
+ Returns:
156
+ size: A tuple or list (H,W).
157
+ """
158
+ if isinstance(x, int):
159
+ return x, x
160
+ if isinstance(x, list) or isinstance(x, tuple):
161
+ return x
162
+ else:
163
+ raise TypeError()
164
+
165
+
166
+ def calculate_output_image_size(input_image_size, stride):
167
+ """Calculates the output image size when using Conv2dSamePadding with a stride.
168
+ Necessary for static padding. Thanks to mannatsingh for pointing this out.
169
+ Args:
170
+ input_image_size (int, tuple or list): Size of input image.
171
+ stride (int, tuple or list): Conv2d operation's stride.
172
+ Returns:
173
+ output_image_size: A list [H,W].
174
+ """
175
+ if input_image_size is None:
176
+ return None
177
+ image_height, image_width = get_width_and_height_from_size(input_image_size)
178
+ stride = stride if isinstance(stride, int) else stride[0]
179
+ image_height = int(math.ceil(image_height / stride))
180
+ image_width = int(math.ceil(image_width / stride))
181
+ return [image_height, image_width]
182
+
183
+
184
+ # Note:
185
+ # The following 'SamePadding' functions make output size equal ceil(input size/stride).
186
+ # Only when stride equals 1 can the output size be the same as the input size.
187
+ # Don't be confused by their function names ! ! !
188
+
189
+ def get_same_padding_conv2d(image_size=None):
190
+ """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
191
+ Static padding is necessary for ONNX exporting of models.
192
+ Args:
193
+ image_size (int or tuple): Size of the image.
194
+ Returns:
195
+ Conv2dDynamicSamePadding or Conv2dStaticSamePadding.
196
+ """
197
+ if image_size is None:
198
+ return Conv2dDynamicSamePadding
199
+ else:
200
+ return partial(Conv2dStaticSamePadding, image_size=image_size)
201
+
202
+
203
+ class Conv2dDynamicSamePadding(nn.Conv2d):
204
+ """2D Convolutions like TensorFlow, for a dynamic image size.
205
+ The padding is operated in forward function by calculating dynamically.
206
+ """
207
+
208
+ # Tips for 'SAME' mode padding.
209
+ # Given the following:
210
+ # i: width or height
211
+ # s: stride
212
+ # k: kernel size
213
+ # d: dilation
214
+ # p: padding
215
+ # Output after Conv2d:
216
+ # o = floor((i+p-((k-1)*d+1))/s+1)
217
+ # If o equals i, i = floor((i+p-((k-1)*d+1))/s+1),
218
+ # => p = (i-1)*s+((k-1)*d+1)-i
219
+
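As a concrete instance of the formula above (numbers chosen for illustration): for i = 7, k = 3, s = 2, d = 1, the target output size is o = ceil(7 / 2) = 4, so the required total padding is p = (4 - 1) * 2 + ((3 - 1) * 1 + 1) - 7 = 2, split below as pad // 2 = 1 on one side and the remaining 1 on the other.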
220
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
221
+ super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
222
+ self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
223
+
224
+ def forward(self, x):
225
+ ih, iw = x.size()[-2:]
226
+ kh, kw = self.weight.size()[-2:]
227
+ sh, sw = self.stride
228
+ oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) # change the output size according to stride ! ! !
229
+ pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
230
+ pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
231
+ if pad_h > 0 or pad_w > 0:
232
+ x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
233
+ return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
234
+
235
+
236
+ class Conv2dStaticSamePadding(nn.Conv2d):
237
+ """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size.
238
+ The padding module is calculated in the constructor, then used in forward.
239
+ """
240
+
241
+ # With the same calculation as Conv2dDynamicSamePadding
242
+
243
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs):
244
+ super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
245
+ self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
246
+
247
+ # Calculate padding based on image size and save it
248
+ assert image_size is not None
249
+ ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
250
+ kh, kw = self.weight.size()[-2:]
251
+ sh, sw = self.stride
252
+ oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
253
+ pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
254
+ pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
255
+ if pad_h > 0 or pad_w > 0:
256
+ self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2,
257
+ pad_h // 2, pad_h - pad_h // 2))
258
+ else:
259
+ self.static_padding = nn.Identity()
260
+
261
+ def forward(self, x):
262
+ x = self.static_padding(x)
263
+ x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
264
+ return x
265
+
266
+
267
+ def get_same_padding_maxPool2d(image_size=None):
268
+ """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
269
+ Static padding is necessary for ONNX exporting of models.
270
+ Args:
271
+ image_size (int or tuple): Size of the image.
272
+ Returns:
273
+ MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding.
274
+ """
275
+ if image_size is None:
276
+ return MaxPool2dDynamicSamePadding
277
+ else:
278
+ return partial(MaxPool2dStaticSamePadding, image_size=image_size)
279
+
280
+
281
+ class MaxPool2dDynamicSamePadding(nn.MaxPool2d):
282
+ """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size.
283
+ The padding is operated in forward function by calculating dynamically.
284
+ """
285
+
286
+ def __init__(self, kernel_size, stride, padding=0, dilation=1, return_indices=False, ceil_mode=False):
287
+ super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode)
288
+ self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
289
+ self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size
290
+ self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
291
+
292
+ def forward(self, x):
293
+ ih, iw = x.size()[-2:]
294
+ kh, kw = self.kernel_size
295
+ sh, sw = self.stride
296
+ oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
297
+ pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
298
+ pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
299
+ if pad_h > 0 or pad_w > 0:
300
+ x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
301
+ return F.max_pool2d(x, self.kernel_size, self.stride, self.padding,
302
+ self.dilation, self.ceil_mode, self.return_indices)
303
+
304
+
305
+ class MaxPool2dStaticSamePadding(nn.MaxPool2d):
306
+ """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size.
307
+ The padding module is calculated in the constructor, then used in forward.
308
+ """
309
+
310
+ def __init__(self, kernel_size, stride, image_size=None, **kwargs):
311
+ super().__init__(kernel_size, stride, **kwargs)
312
+ self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
313
+ self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size
314
+ self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
315
+
316
+ # Calculate padding based on image size and save it
317
+ assert image_size is not None
318
+ ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
319
+ kh, kw = self.kernel_size
320
+ sh, sw = self.stride
321
+ oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
322
+ pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
323
+ pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
324
+ if pad_h > 0 or pad_w > 0:
325
+ self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
326
+ else:
327
+ self.static_padding = nn.Identity()
328
+
329
+ def forward(self, x):
330
+ x = self.static_padding(x)
331
+ x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding,
332
+ self.dilation, self.ceil_mode, self.return_indices)
333
+ return x
334
+
335
+
336
+ ################################################################################
337
+ # Helper functions for loading model params
338
+ ################################################################################
339
+
340
+ # BlockDecoder: A Class for encoding and decoding BlockArgs
341
+ # efficientnet_params: A function to query compound coefficient
342
+ # get_model_params and efficientnet:
343
+ # Functions to get BlockArgs and GlobalParams for efficientnet
344
+ # url_map and url_map_advprop: Dicts of url_map for pretrained weights
345
+ # load_pretrained_weights: A function to load pretrained weights
346
+
347
+ class BlockDecoder(object):
348
+ """Block Decoder for readability,
349
+ straight from the official TensorFlow repository.
350
+ """
351
+
352
+ @staticmethod
353
+ def _decode_block_string(block_string):
354
+ """Get a block through a string notation of arguments.
355
+ Args:
356
+ block_string (str): A string notation of arguments.
357
+ Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.
358
+ Returns:
359
+ BlockArgs: The namedtuple defined at the top of this file.
360
+ """
361
+ assert isinstance(block_string, str)
362
+
363
+ ops = block_string.split('_')
364
+ options = {}
365
+ for op in ops:
366
+ splits = re.split(r'(\d.*)', op)
367
+ if len(splits) >= 2:
368
+ key, value = splits[:2]
369
+ options[key] = value
370
+
371
+ # Check stride
372
+ assert (('s' in options and len(options['s']) == 1) or
373
+ (len(options['s']) == 2 and options['s'][0] == options['s'][1]))
374
+
375
+ return BlockArgs(
376
+ num_repeat=int(options['r']),
377
+ kernel_size=int(options['k']),
378
+ stride=[int(options['s'][0])],
379
+ expand_ratio=int(options['e']),
380
+ input_filters=int(options['i']),
381
+ output_filters=int(options['o']),
382
+ se_ratio=float(options['se']) if 'se' in options else None,
383
+ id_skip=('noskip' not in block_string))
384
+
385
+ @staticmethod
386
+ def _encode_block_string(block):
387
+ """Encode a block to a string.
388
+ Args:
389
+ block (namedtuple): A BlockArgs type argument.
390
+ Returns:
391
+ block_string: A String form of BlockArgs.
392
+ """
393
+ args = [
394
+ 'r%d' % block.num_repeat,
395
+ 'k%d' % block.kernel_size,
396
+ 's%d%d' % (block.stride[0], block.stride[0]),  # BlockArgs stores 'stride' as a one-element list
397
+ 'e%s' % block.expand_ratio,
398
+ 'i%d' % block.input_filters,
399
+ 'o%d' % block.output_filters
400
+ ]
401
+ if block.se_ratio is not None and 0 < block.se_ratio <= 1:
402
+ args.append('se%s' % block.se_ratio)
403
+ if block.id_skip is False:
404
+ args.append('noskip')
405
+ return '_'.join(args)
406
+
407
+ @staticmethod
408
+ def decode(string_list):
409
+ """Decode a list of string notations to specify blocks inside the network.
410
+ Args:
411
+ string_list (list[str]): A list of strings, each string is a notation of block.
412
+ Returns:
413
+ blocks_args: A list of BlockArgs namedtuples of block args.
414
+ """
415
+ assert isinstance(string_list, list)
416
+ blocks_args = []
417
+ for block_string in string_list:
418
+ blocks_args.append(BlockDecoder._decode_block_string(block_string))
419
+ return blocks_args
420
+
421
+ @staticmethod
422
+ def encode(blocks_args):
423
+ """Encode a list of BlockArgs to a list of strings.
424
+ Args:
425
+ blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.
426
+ Returns:
427
+ block_strings: A list of strings, each string is a notation of block.
428
+ """
429
+ block_strings = []
430
+ for block in blocks_args:
431
+ block_strings.append(BlockDecoder._encode_block_string(block))
432
+ return block_strings
433
+
434
+
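For example (illustrative), decoding one of the block strings used further down round-trips through these helpers:

    args = BlockDecoder.decode(['r1_k3_s11_e1_i32_o16_se0.25'])[0]
    # BlockArgs(num_repeat=1, kernel_size=3, stride=[1], expand_ratio=1,
    #           input_filters=32, output_filters=16, se_ratio=0.25, id_skip=True)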
435
+ def efficientnet_params(model_name):
436
+ """Map EfficientNet model name to parameter coefficients.
437
+ Args:
438
+ model_name (str): Model name to be queried.
439
+ Returns:
440
+ params_dict[model_name]: A (width,depth,res,dropout) tuple.
441
+ """
442
+ params_dict = {
443
+ # Coefficients: width,depth,res,dropout
444
+ 'efficientnet-b0': (1.0, 1.0, 224, 0.2),
445
+ 'efficientnet-b1': (1.0, 1.1, 240, 0.2),
446
+ 'efficientnet-b2': (1.1, 1.2, 260, 0.3),
447
+ 'efficientnet-b3': (1.2, 1.4, 300, 0.3),
448
+ 'efficientnet-b4': (1.4, 1.8, 380, 0.4),
449
+ 'efficientnet-b5': (1.6, 2.2, 456, 0.4),
450
+ 'efficientnet-b6': (1.8, 2.6, 528, 0.5),
451
+ 'efficientnet-b7': (2.0, 3.1, 600, 0.5),
452
+ 'efficientnet-b8': (2.2, 3.6, 672, 0.5),
453
+ 'efficientnet-l2': (4.3, 5.3, 800, 0.5),
454
+ }
455
+ return params_dict[model_name]
456
+
457
+
458
+ def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None,
459
+ dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000, include_top=False):
460
+ """Create BlockArgs and GlobalParams for efficientnet model.
461
+ Args:
462
+ width_coefficient (float)
463
+ depth_coefficient (float)
464
+ image_size (int)
465
+ dropout_rate (float)
466
+ drop_connect_rate (float)
467
+ num_classes (int)
468
+ include_top (bool)
+ Meanings as the names suggest.
469
+ Returns:
470
+ blocks_args, global_params.
471
+ """
472
+
473
+ # Block args for the whole model (efficientnet-b0 by default)
474
+ # It will be modified in the constructor of the EfficientNet class according to the model
475
+ blocks_args = [
476
+ 'r1_k3_s11_e1_i32_o16_se0.25',
477
+ 'r2_k3_s22_e6_i16_o24_se0.25',
478
+ 'r2_k5_s22_e6_i24_o40_se0.25',
479
+ 'r3_k3_s22_e6_i40_o80_se0.25',
480
+ 'r3_k5_s11_e6_i80_o112_se0.25',
481
+ 'r4_k5_s22_e6_i112_o192_se0.25',
482
+ 'r1_k3_s11_e6_i192_o320_se0.25',
483
+ ]
484
+ blocks_args = BlockDecoder.decode(blocks_args)
485
+
486
+ global_params = GlobalParams(
487
+ width_coefficient=width_coefficient,
488
+ depth_coefficient=depth_coefficient,
489
+ image_size=image_size,
490
+ dropout_rate=dropout_rate,
491
+
492
+ num_classes=num_classes,
493
+ batch_norm_momentum=0.99,
494
+ batch_norm_epsilon=1e-3,
495
+ drop_connect_rate=drop_connect_rate,
496
+ depth_divisor=8,
497
+ min_depth=None,
498
+ include_top=include_top,
499
+ )
500
+
501
+ return blocks_args, global_params
502
+
503
+
504
+ def get_model_params(model_name, override_params):
505
+ """Get the block args and global params for a given model name.
506
+ Args:
507
+ model_name (str): Model's name.
508
+ override_params (dict): A dict to modify global_params.
509
+ Returns:
510
+ blocks_args, global_params
511
+ """
512
+ if model_name.startswith('efficientnet'):
513
+ w, d, s, p = efficientnet_params(model_name)
514
+ # note: all models have drop connect rate = 0.2
515
+ blocks_args, global_params = efficientnet(
516
+ width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s)
517
+ else:
518
+ raise NotImplementedError('model name is not pre-defined: {}'.format(model_name))
519
+ if override_params:
520
+ # ValueError will be raised here if override_params has fields not included in global_params.
521
+ global_params = global_params._replace(**override_params)
522
+ return blocks_args, global_params
523
+
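A minimal usage sketch (the override dict is illustrative):

    blocks_args, global_params = get_model_params('efficientnet-b0', {'num_classes': 10})
    # len(blocks_args) == 7 (one entry per stage string above),
    # global_params.image_size == 224, global_params.dropout_rate == 0.2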
524
+
525
+ # Weights trained with standard methods
526
+ # see the paper (EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks) for details
527
+ url_map = {
528
+ 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth',
529
+ 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth',
530
+ 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth',
531
+ 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth',
532
+ 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth',
533
+ 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth',
534
+ 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth',
535
+ 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth',
536
+ }
537
+
538
+ # Weights trained with Adversarial Examples (AdvProp)
539
+ # see the paper (Adversarial Examples Improve Image Recognition) for details
540
+ url_map_advprop = {
541
+ 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth',
542
+ 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth',
543
+ 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth',
544
+ 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth',
545
+ 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth',
546
+ 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth',
547
+ 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth',
548
+ 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth',
549
+ 'efficientnet-b8': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth',
550
+ }
551
+
552
+ # TODO: add the pretrained weights url map of 'efficientnet-l2'
553
+
554
+
555
+ def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False, verbose=True):
556
+ """Loads pretrained weights from a local weights path or downloads them from a URL.
557
+ Args:
558
+ model (Module): The whole model of efficientnet.
559
+ model_name (str): Model name of efficientnet.
560
+ weights_path (None or str):
561
+ str: path to pretrained weights file on the local disk.
562
+ None: use pretrained weights downloaded from the Internet.
563
+ load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model.
564
+ advprop (bool): Whether to load pretrained weights
565
+ trained with advprop (valid when weights_path is None).
566
+ """
567
+ if isinstance(weights_path, str):
568
+ state_dict = torch.load(weights_path)
569
+ else:
570
+ # AutoAugment or Advprop (different preprocessing)
571
+ url_map_ = url_map_advprop if advprop else url_map
572
+ state_dict = model_zoo.load_url(url_map_[model_name])
573
+
574
+ if load_fc:
575
+ ret = model.load_state_dict(state_dict, strict=False)
576
+ # assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
577
+ else:
578
+ state_dict.pop('_fc.weight')
579
+ state_dict.pop('_fc.bias')
580
+ ret = model.load_state_dict(state_dict, strict=False)
581
+ # assert set(ret.missing_keys) == set(
582
+ # ['_fc.weight', '_fc.bias']), 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
583
+ # assert not ret.unexpected_keys, 'Unexpected keys when loading pretrained weights: {}'.format(ret.unexpected_keys)
584
+
585
+ if verbose:
586
+ print('Loaded pretrained weights for {}'.format(model_name))
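A typical call, as a sketch (assuming the EfficientNet class from the adjacent model.py, which follows the lukemelas EfficientNet-PyTorch API; the model name is illustrative):

    model = EfficientNet.from_name('efficientnet-b0')
    load_pretrained_weights(model, 'efficientnet-b0', load_fc=False)
    # with load_fc=False the '_fc' weights are popped from the state dict,
    # so only the backbone is initialized from the downloaded checkpoint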
harmonizer/src/train/harmonizer/module/filter.py ADDED
@@ -0,0 +1,231 @@
1
+ import math
2
+ from enum import Enum
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import kornia
8
+
9
+
10
+ class BrightnessFilter(nn.Module):
11
+ def __init__(self):
12
+ super(BrightnessFilter, self).__init__()
13
+ self.epsilon = 1e-6
14
+
15
+ def forward(self, image, x):
16
+ """
17
+ Arguments:
18
+ image (tensor [n, 3, h, w]): RGB image with pixel values between [0, 1]
19
+ x (tensor [n, 1, 1, 1]): brightness argument with values between [-1, 1]
20
+ """
21
+
22
+ # convert image from RGB to HSV
23
+ image = kornia.color.rgb_to_hsv(image)
24
+ h = image[:,0:1,:,:]
25
+ s = image[:,1:2,:,:]
26
+ v = image[:,2:3,:,:]
27
+
28
+ # calculate alpha
29
+ amask = (x >= 0).float()
30
+ alpha = (1 / ((1 - x) + self.epsilon)) * amask + (x + 1) * (1 - amask)
31
+
32
+ # adjust the V channel
33
+ v = v * alpha
34
+
35
+ # convert image from HSV to RGB
36
+ image = torch.cat((h, s, v), dim=1)
37
+ image = kornia.color.hsv_to_rgb(image)
38
+
39
+ # clip pixel values to [0, 1]
40
+ image = torch.clamp(image, 0.0, 1.0)
41
+
42
+ return image
43
+
44
+
45
+ class ContrastFilter(nn.Module):
46
+ def __init__(self):
47
+ super(ContrastFilter, self).__init__()
48
+
49
+ def forward(self, image, x):
50
+ """
51
+ Arguments:
52
+ image(tensor [n, 3, h, w]): RGB image with pixel values between [0, 1]
53
+ x (tensor [n, 1, 1, 1]): contrast argument with values between [-1, 1]
54
+ """
55
+
56
+ # calculate the mean of the image as the threshold
57
+ threshold = torch.mean(image, dim=(1, 2, 3), keepdim=True)
58
+
59
+ # pre-process x if it is a positive value
60
+ mask = (x.detach() > 0).float()
61
+ x_ = 255 / (256 - torch.floor(x * 255)) - 1
62
+ x_ = x * (1 - mask) + x_ * mask
63
+
64
+ # modify the contrast of the image
65
+ image = image + (image - threshold) * x_
66
+
67
+ # clip pixel values to [0, 1]
68
+ image = torch.clamp(image, 0.0, 1.0)
69
+
70
+ return image
71
+
72
+
73
+ class SaturationFilter(nn.Module):
74
+ def __init__(self):
75
+ super(SaturationFilter, self).__init__()
76
+
77
+ self.epsilon = 1e-6
78
+
79
+ def forward(self, image, x):
80
+ """
81
+ Arguments:
82
+ image(tensor [n, 3, h, w]): RGB image with pixel values between [0, 1]
83
+ x (tensor [n, 1, 1, 1]): saturation argument with values between [-1, 1]
84
+ """
85
+
86
+ # calculate the basic properties of the image
87
+ cmin = torch.min(image, dim=1, keepdim=True)[0]
88
+ cmax = torch.max(image, dim=1, keepdim=True)[0]
89
+ var = cmax - cmin
90
+ ran = cmax + cmin
91
+ mean = ran / 2
92
+
93
+ is_positive = (x.detach() >= 0).float()
94
+
95
+ # calculate s
96
+ m = (mean < 0.5).float()
97
+ s = (var / (ran + self.epsilon)) * m + (var / (2 - ran + self.epsilon)) * (1 - m)
98
+
99
+ # if x is positive
100
+ m = ((x + s) > 1).float()
101
+ a_pos = s * m + (1 - x) * (1 - m)
102
+ a_pos = 1 / (a_pos + self.epsilon) - 1
103
+
104
+ # if x is negative
105
+ a_neg = 1 + x
106
+
107
+ a = a_pos * is_positive + a_neg * (1 - is_positive)
108
+ image = image * is_positive + mean * (1 - is_positive) + (image - mean) * a
109
+
110
+ # clip pixel values to [0, 1]
111
+ image = torch.clamp(image, 0.0, 1.0)
112
+
113
+ return image
114
+
115
+
116
+ class TemperatureFilter(nn.Module):
117
+ def __init__(self):
118
+ super(TemperatureFilter, self).__init__()
119
+
120
+ self.epsilon = 1e-6
121
+
122
+ def forward(self, image, x):
123
+ """
124
+ Arguments:
125
+ image(tensor [n, 3, h, w]): RGB image with pixel values between [0, 1]
126
+ x (tensor [n, 1, 1, 1]): color temperature argument with values between [-1, 1]
127
+ """
128
+ # split the R/G/B channels
129
+ R, G, B = image[:, 0:1, ...], image[:, 1:2, ...], image[:, 2:3, ...]
130
+
131
+ # calculate the mean of each channel
132
+ meanR = torch.mean(R, dim=(2, 3), keepdim=True)
133
+ meanG = torch.mean(G, dim=(2, 3), keepdim=True)
134
+ meanB = torch.mean(B, dim=(2, 3), keepdim=True)
135
+
136
+ # calculate correction factors
137
+ gray = (meanR + meanG + meanB) / 3
138
+ coefR = gray / (meanR + self.epsilon)
139
+ coefG = gray / (meanG + self.epsilon)
140
+ coefB = gray / (meanB + self.epsilon)
141
+ aR = 1 - coefR
142
+ aG = 1 - coefG
143
+ aB = 1 - coefB
144
+
145
+ # adjust temperature
146
+ is_positive = (x.detach() > 0).float()
147
+ is_negative = (x.detach() < 0).float()
148
+ is_zero = (x.detach() == 0).float()
149
+
150
+ meanR_ = meanR + x * torch.sign(x) * is_negative
151
+ meanG_ = meanG + x * torch.sign(x) * 0.5 * (1 - is_zero)
152
+ meanB_ = meanB + x * torch.sign(x) * is_positive
153
+ gray_ = (meanR_ + meanG_ + meanB_) / 3
154
+
155
+ coefR_ = gray_ / (meanR_ + self.epsilon) + aR
156
+ coefG_ = gray_ / (meanG_ + self.epsilon) + aG
157
+ coefB_ = gray_ / (meanB_ + self.epsilon) + aB
158
+
159
+ R_ = coefR_ * R
160
+ G_ = coefG_ * G
161
+ B_ = coefB_ * B
162
+
163
+ # the RGB image with the adjusted color temperature
164
+ image = torch.cat((R_, G_, B_), dim=1)
165
+
166
+ # clip pixel values to [0, 1]
167
+ image = torch.clamp(image, 0.0, 1.0)
168
+
169
+ return image
170
+
171
+
172
+ class HighlightFilter(nn.Module):
173
+ def __init__(self):
174
+ super(HighlightFilter, self).__init__()
175
+
176
+ def forward(self, image, x):
177
+ """
178
+ Arguments:
179
+ image(tensor [n, 3, h, w]): RGB image with pixel values between [0, 1]
180
+ x (tensor [n, 1, 1, 1]): highlight argument with values between [-1, 1]
181
+ """
182
+
183
+ x = x + 1
184
+
185
+ image = kornia.enhance.invert(image, image.detach() * 0 + 1)
186
+ image = torch.clamp(torch.pow(image + 1e-9, x), 0.0, 1.0)
187
+ image = kornia.enhance.invert(image, image.detach() * 0 + 1)
188
+
189
+ # clip pixel values to [0, 1]
190
+ image = torch.clamp(image, 0.0, 1.0)
191
+
192
+ return image
193
+
194
+
195
+ class ShadowFilter(nn.Module):
196
+ def __init__(self):
197
+ super(ShadowFilter, self).__init__()
198
+
199
+ def forward(self, image, x):
200
+ """
201
+ Arguments:
202
+ image(tensor [n, 3, h, w]): RGB image with pixel values between [0, 1]
203
+ x (tensor [n, 1, 1, 1]): shadow argument with values between [-1, 1]
204
+ """
205
+
206
+ x = -x + 1
207
+ image = torch.clamp(torch.pow(image + 1e-9, x), 0.0, 1.0)
208
+
209
+ # clip pixel values to [0, 1]
210
+ image = torch.clamp(image, 0.0, 1.0)
211
+
212
+ return image
213
+
214
+
215
+ class Filter(Enum):
216
+ BRIGHTNESS = 1
217
+ CONTRAST = 2
218
+ SATURATION = 3
219
+ TEMPERATURE = 4
220
+ HIGHLIGHT = 5
221
+ SHADOW = 6
222
+
223
+
224
+ FILTER_MODULES = {
225
+ Filter.BRIGHTNESS: BrightnessFilter,
226
+ Filter.CONTRAST: ContrastFilter,
227
+ Filter.SATURATION: SaturationFilter,
228
+ Filter.TEMPERATURE: TemperatureFilter,
229
+ Filter.HIGHLIGHT: HighlightFilter,
230
+ Filter.SHADOW: ShadowFilter,
231
+ }
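A minimal sketch of applying one of these filters directly (shapes and the argument value are illustrative):

    brightness = FILTER_MODULES[Filter.BRIGHTNESS]()
    image = torch.rand(1, 3, 64, 64)        # RGB in [0, 1]
    arg = torch.full((1, 1, 1, 1), 0.5)     # brightness argument in [-1, 1]
    out = brightness(image, arg)            # same shape, clamped back to [0, 1]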
harmonizer/src/train/harmonizer/module/harmonizer.py ADDED
@@ -0,0 +1,83 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ import torchvision.transforms.functional as tf
5
+
6
+ from .filter import Filter
7
+ from .backbone import EfficientBackbone
8
+ from .module import CascadeArgumentRegressor, FilterPerformer
9
+
10
+
11
+ class Harmonizer(nn.Module):
12
+ def __init__(self):
13
+ super(Harmonizer, self).__init__()
14
+
15
+ self.input_size = (256, 256)
16
+ self.filter_types = [
17
+ Filter.TEMPERATURE,
18
+ Filter.BRIGHTNESS,
19
+ Filter.CONTRAST,
20
+ Filter.SATURATION,
21
+ Filter.HIGHLIGHT,
22
+ Filter.SHADOW,
23
+ ]
24
+ self.filter_argument_ranges = [
25
+ 0.3,
26
+ 0.5,
27
+ 0.5,
28
+ 0.6,
29
+ 0.4,
30
+ 0.4,
31
+ ]
32
+
33
+ self.backbone = EfficientBackbone.from_name('efficientnet-b0')
34
+ self.regressor = CascadeArgumentRegressor(1280, 160, 1, len(self.filter_types))
35
+ self.performer = FilterPerformer(self.filter_types)
36
+
37
+ for m in self.modules():
38
+ if isinstance(m, nn.Conv2d):
39
+ self._init_conv(m)
40
+ elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
41
+ self._init_norm(m)
42
+
43
+ self.backbone = EfficientBackbone.from_pretrained('efficientnet-b0')
44
+
45
+ def forward(self, comp, mask):
46
+ arguments = self.predict_arguments(comp, mask)
47
+ pred = self.restore_image(comp, mask, arguments)
48
+ return pred
49
+
50
+ def predict_arguments(self, comp, mask):
51
+ comp = F.interpolate(comp, self.input_size, mode='bilinear', align_corners=False)
52
+ mask = F.interpolate(mask, self.input_size, mode='bilinear', align_corners=False)
53
+
54
+ fg = torch.cat((comp, mask), dim=1)
55
+ bg = torch.cat((comp, (1 - mask)), dim=1)
56
+
57
+ enc2x, enc4x, enc8x, enc16x, enc32x = self.backbone(fg, bg)
58
+ arguments = self.regressor(enc32x)
59
+ return arguments
60
+
61
+ def restore_image(self, comp, mask, arguments):
62
+ assert len(arguments) == len(self.filter_types)
63
+
64
+ arguments = [torch.clamp(arg, -1, 1).view(-1, 1, 1, 1) for arg in arguments]
65
+ return self.performer.restore(comp, mask, arguments)
66
+
67
+ def adjust_image(self, image, mask, arguments):
68
+ assert len(arguments) == len(self.filter_types)
69
+
70
+ arguments = [(torch.clamp(arg, -1, 1) * r).view(-1, 1, 1, 1) \
71
+ for arg, r in zip(arguments, self.filter_argument_ranges)]
72
+ return self.performer.adjust(image, mask, arguments)
73
+
74
+ def _init_conv(self, conv):
75
+ nn.init.kaiming_uniform_(
76
+ conv.weight, a=0, mode='fan_in', nonlinearity='relu')
77
+ if conv.bias is not None:
78
+ nn.init.constant_(conv.bias, 0)
79
+
80
+ def _init_norm(self, bn):
81
+ if bn.weight is not None:
82
+ nn.init.constant_(bn.weight, 1)
83
+ nn.init.constant_(bn.bias, 0)
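A minimal forward-pass sketch (sizes are illustrative; from_pretrained downloads backbone weights on first use):

    model = Harmonizer().eval()
    comp = torch.rand(1, 3, 512, 512)                    # composite image in [0, 1]
    mask = (torch.rand(1, 1, 512, 512) > 0.5).float()    # foreground mask
    with torch.no_grad():
        preds = model(comp, mask)    # one intermediate result per filter in filter_types
    harmonized = preds[-1]           # final output after all six filters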
harmonizer/src/train/harmonizer/module/module.py ADDED
@@ -0,0 +1,80 @@
1
+ import cv2
2
+ import math
3
+ from enum import Enum
4
+
5
+ import torch
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+
9
+ from .filter import Filter, FILTER_MODULES
10
+
11
+
12
+ class CascadeArgumentRegressor(nn.Module):
13
+ def __init__(self, in_channels, base_channels, out_channels, head_num):
14
+ super(CascadeArgumentRegressor, self).__init__()
15
+ self.in_channels = in_channels
16
+ self.base_channels = base_channels
17
+ self.out_channels = out_channels
18
+ self.head_num = head_num
19
+
20
+ self.pool = nn.AdaptiveAvgPool2d((1, 1))
21
+
22
+ self.f = nn.Linear(self.in_channels, 160)
23
+ self.g = nn.Linear(self.in_channels, self.base_channels)
24
+
25
+ self.headers = nn.ModuleList()
26
+ for i in range(0, self.head_num):
27
+ self.headers.append(
28
+ nn.ModuleList([
29
+ nn.Linear(160 + self.base_channels, self.base_channels),
30
+ nn.Linear(self.base_channels, self.out_channels),
31
+ ])
32
+ )
33
+
34
+ def forward(self, x):
35
+ x = self.pool(x)
36
+ n, c, _, _ = x.shape
37
+ x = x.view(n, c)
38
+
39
+ f = self.f(x)
40
+ g = self.g(x)
41
+
42
+ pred_args = []
43
+ for i in range(0, self.head_num):
44
+ g = self.headers[i][0](torch.cat((f, g), dim=1))
45
+ pred_args.append(self.headers[i][1](g))
46
+
47
+ return pred_args
48
+
49
+
50
+ class FilterPerformer(nn.Module):
51
+ def __init__(self, filter_types):
52
+ super(FilterPerformer, self).__init__()
53
+
54
+ self.filters = [FILTER_MODULES[filter_type]() for filter_type in filter_types]
55
+
56
+ def forward(self):
57
+ pass
58
+
59
+ def restore(self, x, mask, arguments):
60
+ assert len(self.filters) == len(arguments)
61
+
62
+ outputs = []
63
+ _image = x
64
+ for filter, arg in zip(self.filters, arguments):
65
+ _image = filter(_image, arg)
66
+ outputs.append(_image * mask + x * (1 - mask))
67
+
68
+ return outputs
69
+
70
+ def adjust(self, image, mask, arguments):
71
+ assert len(self.filters) == len(arguments)
72
+
73
+ outputs = []
74
+ _image = image
75
+ for filter, arg in zip(reversed(self.filters), reversed(arguments)):
76
+ _image = filter(_image, arg)
77
+ outputs.append(_image * mask + image * (1 - mask))
78
+
79
+ return outputs
80
+
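Note the traversal order of the two methods above: restore applies the filters in the order given at construction, while adjust walks them in reverse. This mirrors the training setup in trainer.py, where a ground-truth image is first degraded step by step with adjust and the network then learns to undo those steps, last applied first undone, with restore.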
harmonizer/src/train/harmonizer/proxy.py ADDED
@@ -0,0 +1,20 @@
1
+ import torchtask
2
+
3
+ import func, data, model, criterion, trainer
4
+
5
+
6
+ def add_parser_arguments(parser):
7
+ torchtask.proxy_template.add_parser_arguments(parser)
8
+
9
+ data.add_parser_arguments(parser)
10
+ model.add_parser_arguments(parser)
11
+ criterion.add_parser_arguments(parser)
12
+ trainer.add_parser_arguments(parser)
13
+
14
+
15
+ class HarmonizerProxy(torchtask.proxy_template.TaskProxy):
16
+
17
+ NAME = 'harmonizer'
18
+
19
+ def __init__(self, args):
20
+ super(HarmonizerProxy, self).__init__(args, func, data, model, criterion, trainer)
harmonizer/src/train/harmonizer/script/train.py ADDED
@@ -0,0 +1,85 @@
1
+ import os
2
+ import sys
3
+ import collections
4
+
5
+
6
+ sys.path.append('..')
7
+ import torchtask
8
+
9
+ import proxy
10
+
11
+ config = collections.OrderedDict(
12
+ [
13
+ ('exp_id', os.path.basename(__file__).split(".")[0]),
14
+
15
+ ('trainer', 'harmonizer_trainer'),
16
+
17
+ # arguments - Task Proxy
18
+ ('short_ep', False),
19
+
20
+ # arguments - exp
21
+ ('resume', ''),
22
+ ('validation', False),
23
+
24
+ ('out_path', 'result'),
25
+
26
+ ('visualize', False),
27
+ ('debug', False),
28
+
29
+ ('val_freq', 1),
30
+ ('log_freq', 100),
31
+ ('visual_freq', 100),
32
+ ('checkpoint_freq', 1),
33
+
34
+ # arguments - dataset / dataloader
35
+ ('im_size', 256),
36
+ ('num_workers', 4),
37
+ ('ignore_additional', False),
38
+
39
+ ('trainset', {
40
+ 'harmonizer_iharmony4': [
41
+ './dataset/iHarmony4/HAdobe5k/train',
42
+ './dataset/iHarmony4/HCOCO/train',
43
+ './dataset/iHarmony4/Hday2night/train',
44
+ './dataset/iHarmony4/HFlickr/train',
45
+ ]
46
+ }),
47
+ ('additionalset', {
48
+ 'original_iharmony4': [
49
+ './dataset/iHarmony4/HAdobe5k/train',
50
+ './dataset/iHarmony4/HCOCO/train',
51
+ './dataset/iHarmony4/Hday2night/train',
52
+ './dataset/iHarmony4/HFlickr/train',
53
+ ],
54
+ }),
55
+ ('valset', {
56
+ 'original_iharmony4': [
57
+ './dataset/iHarmony4/HAdobe5k/test',
58
+ './dataset/iHarmony4/HCOCO/test',
59
+ './dataset/iHarmony4/Hday2night/test',
60
+ './dataset/iHarmony4/HFlickr/test',
61
+ ]
62
+ }),
63
+
64
+ # arguments - task specific components
65
+ ('models', {'model': 'harmonizer'}),
66
+ ('optimizers', {'model': 'adam'}),
67
+ ('lrers', {'model': 'multisteplr'}),
68
+ ('criterions', {'model': 'harmonizer_loss'}),
69
+
70
+ # arguments - task specific optimizer / lr scheduler
71
+ ('lr', 0.0003),
72
+
73
+ ('milestones', [25, 50]),
74
+ ('gamma', 0.1),
75
+
76
+ # arguments - training details
77
+ ('epochs', 60),
78
+ ('batch_size', 16),
79
+ ('additional_batch_size', 8),
80
+ ]
81
+ )
82
+
83
+
84
+ if __name__ == '__main__':
85
+ torchtask.run_script(config, proxy, proxy.HarmonizerProxy)
harmonizer/src/train/harmonizer/trainer.py ADDED
@@ -0,0 +1,322 @@
1
+ import os
2
+ import time
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
+ import torch
7
+ from torch.autograd import Variable
8
+
9
+ import torchtask
10
+ from torchtask.utils import logger, cmd, tool
11
+ from torchtask.nn import func
12
+
13
+
14
+ def add_parser_arguments(parser):
15
+ torchtask.trainer_template.add_parser_arguments(parser)
16
+
17
+
18
+ def harmonizer_trainer(args, model_dict, optimizer_dict, lrer_dict, criterion_dict, task_func):
19
+ model_funcs = [model_dict['model']]
20
+ optimizer_funcs = [optimizer_dict['model']]
21
+ lrer_funcs = [lrer_dict['model']]
22
+ criterion_funcs = [criterion_dict['model']]
23
+
24
+ algorithm = HarmonizerTrainer(args)
25
+ algorithm.build(model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs, task_func)
26
+ return algorithm
27
+
28
+
29
+ class HarmonizerTrainer(torchtask.trainer_template.TaskTrainer):
30
+ def __init__(self, args):
31
+ super(HarmonizerTrainer, self).__init__(args)
32
+
33
+ self.model = None
34
+ self.optimizer = None
35
+ self.lrer = None
36
+ self.criterion = None
37
+
38
+ def _build(self, model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs, task_func):
39
+ self.task_func = task_func
40
+
41
+ self.model = func.create_model(model_funcs[0], 'model', args=self.args)
42
+ self.models = {'model': self.model}
43
+
44
+ self.optimizer = optimizer_funcs[0](self.model.module.param_groups)
45
+ self.optimizers = {'optimizer': self.optimizer}
46
+
47
+ self.lrer = lrer_funcs[0](self.optimizer)
48
+ self.lrers = {'lrer': self.lrer}
49
+
50
+ self.criterion = criterion_funcs[0](self.args)
51
+ self.criterions = {'criterion': self.criterion}
52
+
53
+ def _train(self, data_loader, epoch):
54
+ self.meters.reset()
55
+
56
+ lbs = self.args.labeled_batch_size
57
+
58
+ self.model.train()
59
+
60
+ timer = time.time()
61
+ for idx, (inp, gt) in enumerate(data_loader):
62
+ # pre-process input tensor and ground truth tensor
63
+ inp, gt = self._batch_prehandle(inp, gt, True)
64
+ x, mask = inp
65
+
66
+ # forward the model
67
+ self.optimizer.zero_grad()
68
+ resulter, debugger = self.model(inp)
69
+
70
+ pred_outputs = tool.dict_value(resulter, 'outputs')
71
+
72
+ # calculate loss for the fine labeled data
73
+ l_pred_outputs = func.split_tensor_tuple(pred_outputs, 0, lbs)
74
+ l_pred = (l_pred_outputs, )
75
+
76
+ l_gt = func.split_tensor_tuple(gt, 0, lbs)
77
+ l_inp = func.split_tensor_tuple(inp, 0, lbs)
78
+
79
+ l_image_losses = self.criterion(l_pred, l_gt, l_inp)
80
+
81
+ # if self.args.dynamic_loss:
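# How the scaling below works: each filter head is supervised on the increment its
# loss adds over the previous head's, and every term is divided by the detached
# running total, so the scaled fine losses sum to roughly one regardless of the
# absolute error magnitude.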
82
+ sum_losses = l_image_losses[0].detach()
83
+ for i in range(1, len(l_image_losses)):
84
+ sum_losses = sum_losses + \
85
+ (l_image_losses[i].detach() - l_image_losses[i-1].detach()) * ((l_image_losses[i].detach() - l_image_losses[i-1].detach()) > 0).float()
86
+ sum_losses = sum_losses + 1e-9
87
+ sum_losses = sum_losses.detach()
88
+
89
+ scaled_l_image_losses = [torch.mean(l_image_losses[0] / sum_losses)]
90
+ self.meters.update('fine_filter_0_loss', torch.mean(l_image_losses[0] / sum_losses).item())
91
+
92
+ for i in range(1, len(l_image_losses)):
93
+ loss = (l_image_losses[i] - l_image_losses[i-1].detach()) / sum_losses
94
+ loss = loss * (loss > 0).float()
95
+ loss = torch.mean(loss)
96
+ scaled_l_image_losses.append(loss)
97
+ self.meters.update('fine_filter_{0}_loss'.format(i), loss.item())
98
+
99
+ # calculate loss for the coarse labeled data
100
+ if not self.args.ignore_additional:
101
+ u_pred_outputs = func.split_tensor_tuple(pred_outputs, lbs, self.args.batch_size)
102
+ u_pred_outputs = (u_pred_outputs[-1], )
103
+ u_pred = (u_pred_outputs, )
104
+
105
+ u_gt = func.split_tensor_tuple(gt, lbs, self.args.batch_size)
106
+ u_gt = (u_gt[-1], )
107
+
108
+ u_inp = func.split_tensor_tuple(inp, lbs, self.args.batch_size)
109
+
110
+ u_image_losses = self.criterion(u_pred, u_gt, u_inp)
111
+
112
+ u_image_loss = torch.mean(u_image_losses[0]) * 10
113
+
114
+ self.meters.update('coarse_filter_loss', u_image_loss.item())
115
+ else:
116
+ u_image_loss = torch.zeros(1).cuda()  # keep the total loss below well-defined when additional data is ignored
+ self.meters.update('coarse_filter_loss', u_image_loss.item())
117
+
118
+ # calculate the sum of all losses
119
+ loss = 0
120
+ for l_image_loss in scaled_l_image_losses:
121
+ loss = loss + l_image_loss
122
+ loss = loss + u_image_loss
123
+
124
+ # backward and update
125
+ loss.backward()
126
+ self.optimizer.step()
127
+
128
+ # logging
129
+ self.meters.update('batch_time', time.time() - timer)
130
+ if idx % self.args.log_freq == 0:
131
+ logger.log_info('step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}'.format(epoch+1, idx, len(data_loader), meters=self.meters))
132
+ logger.log_info('\tfine-filter-0-loss: {meters[fine_filter_0_loss]:.6f}'.format(meters=self.meters))
133
+ logger.log_info('\tfine-filter-1-loss: {meters[fine_filter_1_loss]:.6f}'.format(meters=self.meters))
134
+ logger.log_info('\tfine-filter-2-loss: {meters[fine_filter_2_loss]:.6f}'.format(meters=self.meters))
135
+ logger.log_info('\tfine-filter-3-loss: {meters[fine_filter_3_loss]:.6f}'.format(meters=self.meters))
136
+ logger.log_info('\tfine-filter-4-loss: {meters[fine_filter_4_loss]:.6f}'.format(meters=self.meters))
137
+ logger.log_info('\tfine-filter-5-loss: {meters[fine_filter_5_loss]:.6f}'.format(meters=self.meters))
138
+ logger.log_info('\tcoarse-filter-loss: {meters[coarse_filter_loss]:.6f}'.format(meters=self.meters))
139
+
140
+ # visualization
141
+ if self.args.visualize and idx % self.args.visual_freq == 0:
142
+ self._visualization(
143
+ epoch, idx, True,
144
+ func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
145
+ func.split_tensor_tuple(pred_outputs, 0, 1, reduce_dim=True),
146
+ func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))
147
+
148
+ # update iteration-based lrers
149
+ if not self.args.is_epoch_lrer:
150
+ self.lrer.step()
151
+
152
+ timer = time.time()
153
+
154
+ # update epoch-based lrers
155
+ if self.args.is_epoch_lrer:
156
+ self.lrer.step()
157
+
158
+ def _validate(self, data_loader, epoch):
159
+ self.meters.reset()
160
+
161
+ self.model.eval()
162
+
163
+ timer = time.time()
164
+ for idx, (inp, gt) in enumerate(data_loader):
165
+ inp, gt = self._batch_prehandle(inp, gt, False)
166
+ x, mask = inp
167
+
168
+ resulter, debugger = self.model(inp)
169
+
170
+ pred_outputs = tool.dict_value(resulter, 'outputs')
171
+
172
+ pred = (pred_outputs[-1], )
173
+ gt = (gt[-1], )
174
+
175
+ # calculate loss for the fine labeled data
176
+ losses = self.criterion.forward(pred, gt, inp)
177
+ loss = 0
178
+ for _loss in losses:
179
+ loss = loss + _loss
180
+ loss = loss / len(losses)
181
+
182
+ self.meters.update('loss', loss.item())
183
+
184
+ self.task_func.metrics(pred_outputs[-1].detach(), gt[-1], mask, self.meters, id_str='IH')
185
+
186
+ self.meters.update('batch_time', time.time() - timer)
187
+ if idx % self.args.log_freq == 0:
188
+ logger.log_info('step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
189
+ 'loss: {meters[loss]:.6f}\n'
190
+ .format(epoch+1, idx, len(data_loader), meters=self.meters))
191
+
192
+ if self.args.visualize:
193
+ self._visualization(
194
+ epoch, idx, False,
195
+ func.split_tensor_tuple(inp, 0, 1, reduce_dim=True),
196
+ func.split_tensor_tuple((pred_outputs[-1], ), 0, 1, reduce_dim=True),
197
+ func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))
198
+
199
+ timer = time.time()
200
+
201
+ metrics_info = {'IH': ''}
202
+ for key in sorted(list(self.meters.keys())):
203
+ if self.task_func.METRIC_STR in key:
204
+ for id_str in metrics_info.keys():
205
+ if key.startswith(id_str):
206
+ metrics_info[id_str] += '{0}: {1:.6}\t'.format(key, self.meters[key])
207
+
208
+ logger.log_info('Validation metrics:\n task-metrics\t=>\t{0}\n'.format(metrics_info['IH'].replace('_', '-')))
209
+
210
+ def _visualization(self, epoch, idx, is_train, inp, pred, gt):
211
+ visualize_path = self.args.visual_train_path if is_train else self.args.visual_val_path
212
+ out_path = os.path.join(visualize_path, '{0}_{1}'.format(epoch, idx))
213
+
214
+ x, mask = inp
215
+
216
+ x = (np.transpose(x.cpu().numpy(), (1, 2, 0)))
217
+ Image.fromarray((x * 255).astype('uint8')).save(out_path + '_1_0_x.jpg')
218
+
219
+ mask = mask[0].data.cpu().numpy()
220
+ Image.fromarray((mask * 255).astype('uint8'), mode='L').save(out_path + '_2_0_mask.jpg')
221
+
222
+ for idx, (pred_, gt_) in enumerate(zip(pred, gt)):
223
+ pred_ = (np.transpose(pred_.detach().cpu().numpy(), (1, 2, 0)))
224
+ Image.fromarray((pred_ * 255).astype('uint8')).save(out_path + '_1_{0}_pred_filter.jpg'.format(idx+1))
225
+
226
+ if torch.mean(gt_) != -999:
227
+ gt_ = (np.transpose(gt_.cpu().numpy(), (1, 2, 0)))
228
+ Image.fromarray((gt_ * 255).astype('uint8')).save(out_path + '_2_{0}_gt_filter.jpg'.format(idx+1))
229
+
230
+ def _save_checkpoint(self, epoch):
231
+ state = {
232
+ 'epoch': epoch,
233
+ 'model': self.model.state_dict(),
234
+ 'optimizer': self.optimizer.state_dict(),
235
+ 'lrer': self.lrer.state_dict(),
236
+ }
237
+ checkpoint = os.path.join(self.args.checkpoint_path, 'checkpoint_{0}.ckpt'.format(epoch))
238
+
239
+ torch.save(state, checkpoint)
240
+
241
+ def _load_checkpoint(self):
242
+ checkpoint = torch.load(self.args.resume)
243
+ self.model.load_state_dict(checkpoint['model'])
244
+ self.optimizer.load_state_dict(checkpoint['optimizer'])
245
+ self.lrer.load_state_dict(checkpoint['lrer'])
246
+ return checkpoint['epoch']
247
+
248
+ def _batch_prehandle(self, inp, gt, is_train):
249
+ lbs = self.args.labeled_batch_size
250
+ ubs = self.args.additional_batch_size
251
+
252
+ # convert all input and ground truth to Variables
253
+ inp_var = []
254
+ for i in inp:
255
+ inp_var.append(Variable(i).cuda())
256
+ inp = tuple(inp_var)
257
+
258
+ gt_var = []
259
+ for g in gt:
260
+ gt_var.append(Variable(g).cuda())
261
+ gt = tuple(gt_var)
262
+
263
+ filter_num = len(self.model.module.model.filter_types)
264
+
265
+ if is_train:
266
+ # ----------------------------------------------------------------
267
+ # for fine labeled data, we generate the adjusted input
268
+ # ----------------------------------------------------------------
269
+ l_inp = func.split_tensor_tuple(inp, 0, lbs)
270
+ l_gt = func.split_tensor_tuple(gt, 0, lbs)
271
+
272
+ _, l_mask = l_inp
273
+ l_gt_image, = l_gt
274
+
275
+ n = l_gt_image.shape[0]
276
+ l_rand_arguments = [self._rand_adjustment_values(n) for _ in range(0, filter_num)]
277
+
278
+ l_x = self.model.module.adjust(l_gt_image, l_mask, l_rand_arguments)
279
+
280
+ l_inp = (l_x[-1], l_mask)
281
+ l_gt = []
282
+ for _ in reversed(l_x[:-1]):
283
+ l_gt.append(_)
284
+ l_gt.append(l_gt_image)
285
+
286
+ if not self.args.ignore_additional:
287
+ # ----------------------------------------------------------------
288
+ # for coarse labeled data, we use the existing adjusted input
289
+ # ----------------------------------------------------------------
290
+ u_inp = func.split_tensor_tuple(inp, lbs, self.args.batch_size)
291
+ u_gt = func.split_tensor_tuple(gt, lbs, self.args.batch_size)
292
+
293
+ u_gt_image, = u_gt
294
+ none_value = torch.ones(ubs).view(ubs, 1).cuda() * -999
295
+ none_im = u_gt_image.cuda() * 0 - 999
296
+
297
+ u_gt = [none_im for _ in range(0, filter_num)]
298
+ u_gt[-1] = u_gt_image
299
+
300
+ inp = func.combine_tensor_tuple(l_inp, u_inp, 0)
301
+ gt = func.combine_tensor_tuple(l_gt, u_gt, 0)
302
+
303
+ else:
304
+ inp = l_inp
305
+ gt = l_gt
306
+
307
+ else:
308
+ gt_image, = gt
309
+
310
+ none_value = torch.ones(1).view(1, 1).cuda() * -999
311
+ none_im = gt_image.cuda() * 0 - 999
312
+
313
+ gt = [none_im for _ in range(0, filter_num)]
314
+ gt[-1] = gt_image
315
+
316
+ return inp, gt
317
+
318
+ def _rand_adjustment_values(self, n):
319
+ x = torch.FloatTensor(np.random.uniform(-1, 1, n))
320
+ x = x.view(n, 1).cuda()
321
+ return x
322
+
harmonizer/src/train/torchtask/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ import os
2
+
3
+ from .utils import *
4
+
5
+ from .nn import *
6
+
7
+ from .template import *
8
+
9
+ from .runner import run_script
harmonizer/src/train/torchtask/nn/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .lrer import VALID_LRER
2
+ from .optimizer import VALID_OPTIMIZER
3
+ from .module import SynchronizedBatchNorm2d
harmonizer/src/train/torchtask/nn/data.py ADDED
@@ -0,0 +1,190 @@
1
+ import os
2
+ import itertools
3
+ import numpy as np
4
+
5
+ from torch.utils.data import Dataset
6
+ from torch.utils.data.sampler import Sampler
7
+
8
+
9
+ """ This file implements dataset wrappers and batch samplers for TorchTask.
10
+ """
11
+
12
+
13
+ class _TorchTaskDatasetWrapper(Dataset):
14
+ """ This is the superclass of TorchTask dataset wrapper.
15
+ """
16
+
17
+ def __init__(self):
18
+ super(_TorchTaskDatasetWrapper, self).__init__()
19
+
20
+ self.labeled_idxs = [] # index of the labeled data
21
+ self.additional_idxs = [] # index of the additional data
22
+
23
+
24
+ class SplitUnlabeledWrapper(_TorchTaskDatasetWrapper):
25
+ """ Split the fully labeled dataset into a labeled subset and an
26
+ additional dataset based on a given sublabeled prefix list.
27
+
28
+ For a fully labeled dataset, a common operation is to remove the labels
29
+ of some samples and treat them as the additional samples.
30
+
31
+ This dataset wrapper implements the dataset-split operation by using
32
+ the given sublabeled prefix list. Samples whose prefix is in the list
33
+ are treated as labeled samples, while the other samples are treated as
34
+ the additional samples.
35
+ """
36
+
37
+ def __init__(self, dataset, sublabeled_prefix, ignore_additional=False):
38
+ super(SplitUnlabeledWrapper, self).__init__()
39
+
40
+ self.dataset = dataset
41
+ self.sublabeled_prefix = sublabeled_prefix
42
+ self.ignore_additional = ignore_additional
43
+
44
+ self._split_labeled()
45
+
46
+ def __len__(self):
47
+ return self.dataset.__len__()
48
+
49
+ def __getitem__(self, idx):
50
+ return self.dataset.__getitem__(idx)
51
+
52
+ def _split_labeled(self):
53
+ labeled_list, additional_list = [], []
54
+ for img in self.dataset.sample_list:
55
+ is_labeled = False
56
+ for pdx, prefix in enumerate(self.sublabeled_prefix):
57
+ if img.startswith(prefix):
58
+ labeled_list.append(img)
59
+ is_labeled = True
60
+ break
61
+
62
+ if not is_labeled:
63
+ additional_list.append(img)
64
+
65
+ labeled_size, additional_size = len(labeled_list), len(additional_list)
66
+ assert labeled_size + additional_size == len(self.dataset.sample_list)
67
+
68
+ if self.ignore_additional:
69
+ self.dataset.sample_list = labeled_list
70
+ self.dataset.idxs = [_ for _ in range(0, len(self.dataset.sample_list))]
71
+ self.labeled_idxs = self.dataset.idxs
72
+ self.additional_idxs = []
73
+ else:
74
+ self.dataset.sample_list = labeled_list + additional_list
75
+ self.dataset.idxs = [_ for _ in range(0, len(self.dataset.sample_list))]
76
+ self.labeled_idxs = [_ for _ in range(0, labeled_size)]
77
+ self.additional_idxs = [_ + labeled_size for _ in range(0, additional_size)]
78
+
79
+
80
+ class JointDatasetsWrapper(_TorchTaskDatasetWrapper):
81
+ """ Combine several datasets (can be labeled or additional) into one dataset.
82
+
83
+ This dataset wrapper combines multiple given datasets into one big dataset.
84
+ The new dataset consists of a labeled subset and an additional subset.
85
+ """
86
+
87
+ def __init__(self, labeled_datasets, additional_datasets, ignore_additional=False):
88
+ super(JointDatasetsWrapper, self).__init__()
89
+
90
+ self.labeled_datasets = labeled_datasets
91
+ self.additional_datasets = additional_datasets
92
+ self.ignore_additional = ignore_additional
93
+
94
+ self.labeled_datasets_size = [len(d) for d in self.labeled_datasets]
95
+ self.additional_datasets_size = [len(d) for d in self.additional_datasets]
96
+
97
+ self.labeled_size = np.sum(np.asarray(self.labeled_datasets_size))
98
+ self.labeled_idxs = [_ for _ in range(0, self.labeled_size)]
99
+
100
+ self.additional_size = 0
101
+ if not self.ignore_additional:
102
+ self.additional_size = np.sum(np.asarray(self.additional_datasets_size))
103
+ self.additional_idxs = [self.labeled_size + _ for _ in range(0, self.additional_size)]
104
+
105
+ def __len__(self):
106
+ return int(self.labeled_size + self.additional_size)
107
+
108
+ def __getitem__(self, idx):
109
+ assert 0 <= idx < self.__len__()
110
+
111
+ if idx >= self.labeled_size:
112
+ idx -= self.labeled_size
113
+ datasets = self.additional_datasets
114
+ datasets_size = self.additional_datasets_size
115
+ else:
116
+ datasets = self.labeled_datasets
117
+ datasets_size = self.labeled_datasets_size
118
+
119
+ accumulated_idxs = 0
120
+ for ddx, dsize in enumerate(datasets_size):
121
+ accumulated_idxs += dsize
122
+ if idx < accumulated_idxs:
123
+ return datasets[ddx].__getitem__(idx - (accumulated_idxs - dsize))
124
+
125
+
126
+ class TwoStreamBatchSampler(Sampler):
127
+ """ This two stream batch sampler is used to read data from '_TorchTaskDatasetWrapper'.
128
+
129
+ It iterates two sets of indices simultaneously to read mini-batches for TorchTask.
130
+ There are two sets of indices:
131
+ labeled_idxs, additional_idxs
132
+ An 'epoch' is defined by going through the longer indices once.
133
+ In each 'epoch', the shorter indices are iterated through as many times as needed.
134
+ """
135
+
136
+ def __init__(self, labeled_idxs, additional_idxs, labeled_batch_size, additional_batch_size, short_ep=False):
137
+ self.labeled_idxs = labeled_idxs
138
+ self.additional_idxs = additional_idxs
139
+ self.labeled_batch_size = labeled_batch_size
140
+ self.additional_batch_size = additional_batch_size
141
+
142
+ assert len(self.labeled_idxs) >= self.labeled_batch_size > 0
143
+ assert len(self.additional_idxs) >= self.additional_batch_size > 0
144
+
145
+ self.additional_batchs = len(self.additional_idxs) // self.additional_batch_size
146
+ self.labeled_batchs = len(self.labeled_idxs) // self.labeled_batch_size
147
+
148
+ self.short_ep = short_ep
149
+
150
+ def __iter__(self):
151
+ if not self.short_ep:
152
+ if self.additional_batchs >= self.labeled_batchs:
153
+ additional_iter = self.iterate_once(self.additional_idxs)
154
+ labeled_iter = self.iterate_eternally(self.labeled_idxs)
155
+ else:
156
+ additional_iter = self.iterate_eternally(self.additional_idxs)
157
+ labeled_iter = self.iterate_once(self.labeled_idxs)
158
+ else:
159
+ if self.additional_batchs >= self.labeled_batchs:
160
+ additional_iter = self.iterate_eternally(self.additional_idxs)
161
+ labeled_iter = self.iterate_once(self.labeled_idxs)
162
+ else:
163
+ additional_iter = self.iterate_once(self.additional_idxs)
164
+ labeled_iter = self.iterate_eternally(self.labeled_idxs)
165
+
166
+ return (labeled_batch + additional_batch
167
+ for (labeled_batch, additional_batch) in zip(
168
+ self.grouper(labeled_iter, self.labeled_batch_size),
169
+ self.grouper(additional_iter, self.additional_batch_size)))
170
+
171
+ def __len__(self):
172
+ if self.short_ep:
173
+ return min(self.additional_batchs, self.labeled_batchs)
174
+ else:
175
+ return max(self.additional_batchs, self.labeled_batchs)
176
+
177
+ def iterate_once(self, iterable):
178
+ return np.random.permutation(iterable)
179
+
180
+ def iterate_eternally(self, indices):
181
+ def infinite_shuffles():
182
+ while True:
183
+ yield np.random.permutation(indices)
184
+
185
+ return itertools.chain.from_iterable(infinite_shuffles())
186
+
187
+ def grouper(self, iterable, n):
188
+ # e.g., grouper('ABCDEFG', 3) --> ABC DEF
189
+ args = [iter(iterable)] * n
190
+ return zip(*args)
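
For context, here is a minimal sketch of plugging `TwoStreamBatchSampler` into a PyTorch `DataLoader` via `batch_sampler`. The import path and the toy `TensorDataset` are assumptions for illustration; in this diff the real datasets come from `_TorchTaskDatasetWrapper`.

```python
# Hedged sketch: TwoStreamBatchSampler feeding a DataLoader.
import torch
from torch.utils.data import DataLoader, TensorDataset

from torchtask.nn.data import TwoStreamBatchSampler  # assumed module path

dataset = TensorDataset(torch.arange(130))  # indices 0..99: labeled,
labeled_idxs = list(range(0, 100))          # 100..129: additional
additional_idxs = list(range(100, 130))

sampler = TwoStreamBatchSampler(labeled_idxs, additional_idxs,
                                labeled_batch_size=8, additional_batch_size=2)
loader = DataLoader(dataset, batch_sampler=sampler)

for (batch,) in loader:
    # each mini-batch is 8 labeled indices followed by 2 additional ones
    assert batch.numel() == 10
```
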
harmonizer/src/train/torchtask/nn/func.py ADDED
@@ -0,0 +1,99 @@
1
+ import numpy as np
2
+
3
+ import torch
4
+
5
+ from torchtask.utils import logger
6
+
7
+
8
+ """ This file provides tool functions for deep learning.
9
+ """
10
+
11
+
12
+ def sigmoid_rampup(current, rampup_length):
13
+ """ Exponential rampup from https://arxiv.org/abs/1610.02242 .
14
+ """
15
+ if rampup_length == 0:
16
+ return 1.0
17
+ else:
18
+ current = np.clip(current, 0.0, rampup_length)
19
+ phase = 1.0 - current / rampup_length
20
+ return float(np.exp(-5.0 * phase * phase))
21
+
22
+
23
+
24
+ def split_tensor_tuple(ttuple, start, end, reduce_dim=False):
25
+ """ Slice each tensor in the input tuple by channel-dim.
26
+
27
+ Arguments:
28
+ ttuple (tuple): tuple of torch.Tensors
29
+ start (int): start index of slicing
30
+ end (int): end index of slicing
31
+ reduce_dim (bool): whether to drop the sliced dimension when end - start == 1
32
+
33
+ Returns:
34
+ tuple: a sliced tensor tuple
35
+ """
36
+
37
+ result = []
38
+
39
+ if reduce_dim:
40
+ assert end - start == 1
41
+
42
+ for t in ttuple:
43
+ if end - start == 1 and reduce_dim:
44
+ result.append(t[start, ...])
45
+ else:
46
+ result.append(t[start:end, ...])
47
+
48
+ return tuple(result)
49
+
50
+
51
+ def combine_tensor_tuple(ttuple1, ttuple2, dim):
52
+ result = []
53
+
54
+ assert len(ttuple1) == len(ttuple2)
55
+
56
+ for t1, t2 in zip(ttuple1, ttuple2):
57
+ result.append(torch.cat((t1, t2), dim=dim))
58
+
59
+ return tuple(result)
60
+
61
+
62
+ def create_model(mclass, mname, **kwargs):
63
+ """ Create a nn.Module and setup it on multiple GPUs.
64
+ """
65
+ model = mclass(**kwargs)
66
+ model = torch.nn.DataParallel(model)
67
+ model = model.cuda()
68
+
69
+ logger.log_info(' ' + '=' * 76 + '\n {0} parameters \n{1}'.format(mname, model_str(model)))
70
+ return model
71
+
72
+
73
+ def model_str(module):
74
+ """ Output model structure and parameters number as strings.
75
+ """
76
+ row_format = ' {name:<40} {shape:>20} = {total_size:>12,d}'
77
+ lines = [' ' + '-' * 76,]
78
+
79
+ params = list(module.named_parameters())
80
+ for name, param in params:
81
+ lines.append(row_format.format(name=name,
82
+ shape=' * '.join(str(p) for p in param.size()), total_size=param.numel()))
83
+
84
+ lines.append(' ' + '-' * 76)
85
+ lines.append(row_format.format(name='all parameters', shape='sum of above',
86
+ total_size=sum(int(param.numel()) for name, param in params)))
87
+ lines.append(' ' + '=' * 76)
88
+ lines.append('')
89
+
90
+ return '\n'.join(lines)
91
+
92
+
93
+ def pytorch_support(required_version='1.0.0', info_str=''):
94
+ if torch.__version__ < required_version:
95
+ logger.log_err('{0} required PyTorch >= {1}\n'
96
+ 'However, current PyTorch == {2}\n'
97
+ .format(info_str, required_version, torch.__version__))
98
+ else:
99
+ return True
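
A quick sketch exercising the pure helpers above (`create_model` is omitted since it requires CUDA). The `torchtask.nn.func` import path is taken from this diff's layout and assumes the package is importable.

```python
import torch
from torchtask.nn.func import (sigmoid_rampup, split_tensor_tuple,
                               combine_tensor_tuple)

# the ramp grows from exp(-5) ~= 0.007 toward 1.0 over `rampup_length` steps
print(sigmoid_rampup(0, 100), sigmoid_rampup(50, 100), sigmoid_rampup(100, 100))

images = torch.randn(4, 3, 8, 8)
labels = torch.randn(4, 1, 8, 8)

# split off samples 0..1 and 2..3 (first dimension), then glue them back
first = split_tensor_tuple((images, labels), 0, 2)
second = split_tensor_tuple((images, labels), 2, 4)
restored = combine_tensor_tuple(first, second, dim=0)
assert restored[0].shape == images.shape
```
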
harmonizer/src/train/torchtask/nn/lrer.py ADDED
@@ -0,0 +1,179 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.optim as optim
5
+
6
+ from torchtask.utils import cmd, logger
7
+ from torchtask.nn.func import pytorch_support
8
+
9
+
10
+ """ This file wraps the learning rate schedulers used in the script.
11
+ """
12
+
13
+
14
+ EPOCH_LRERS = ['steplr', 'multisteplr', 'exponentiallr', 'cosineannealinglr']
15
+ ITER_LRERS = ['polynomiallr']
16
+ VALID_LRER = EPOCH_LRERS + ITER_LRERS
17
+
18
+
19
+ def add_parser_arguments(parser):
20
+ """ Add the arguments related to the learning rate (LR) schedulers.
21
+
22
+ This 'add_parser_arguments' function will be called every time.
23
+ Please do not reuse argument names that are already defined in this function.
24
+ The default value '-1' means that the default value corresponding to
25
+ different LR schedulers will be used.
26
+ """
27
+
28
+ parser.add_argument('--last-epoch', type=int, default=-1, metavar='',
29
+ help='lr scheduler - the index of last epoch required by [all]')
30
+
31
+ parser.add_argument('--step-size', type=int, default=-1, metavar='',
32
+ help='lr scheduler - period (epoch) of learning rate decay required by [steplr]')
33
+ parser.add_argument('--milestones', type=cmd.str2intlist, default=[], metavar='',
34
+ help='lr scheduler - increased list of epoch indices required by [multisteplr]')
35
+ parser.add_argument('--gamma', type=float, default=-1, metavar='',
36
+ help='lr scheduler - multiplicative factor of learning rate decay required by [steplr, multisteplr, exponentiallr]')
37
+
38
+ parser.add_argument('--T-max', type=int, default=-1, metavar='',
39
+ help='lr scheduler - maximum number of epochs required by [cosineannealinglr]')
40
+ parser.add_argument('--eta-min', type=float, default=-1, metavar='',
41
+ help='lr scheduler - minimum learning rate required by [cosineannealinglr]')
42
+
43
+ parser.add_argument('--power', type=float, default=-1, metavar='',
44
+ help='lr scheduler - power factor of learning rate decay required by [polynomiallr]')
45
+
46
+
47
+ # ---------------------------------------------------------------------
48
+ # Wrapper of Learning Rate Scheduler
49
+ # ---------------------------------------------------------------------
50
+
51
+ def steplr(args):
52
+ """ Wrapper of torch.optim.lr_scheduler.StepLR (PyTorch >= 1.0.0).
53
+
54
+ Sets the learning rate of each parameter group to the initial lr decayed by gamma every
55
+ step_size epochs. When last_epoch=-1, sets initial lr as lr.
56
+ """
57
+ args.step_size = args.epochs if args.step_size == -1 else args.step_size
58
+ args.gamma = 0.1 if args.gamma == -1 else args.gamma
59
+ args.last_epoch = -1 if args.last_epoch == -1 else args.last_epoch
60
+
61
+ def steplr_wrapper(optimizer):
62
+ pytorch_support(required_version='1.0.0', info_str='LRScheduler - StepLR')
63
+ return optim.lr_scheduler.StepLR(
64
+ optimizer, step_size=args.step_size, gamma=args.gamma, last_epoch=args.last_epoch)
65
+
66
+ return steplr_wrapper
67
+
68
+
69
+ def multisteplr(args):
70
+ """ Wrapper of torch.optim.lr_scheduler.MultiStepLR (PyTorch >= 1.0.0).
71
+
72
+ Set the learning rate of each parameter group to the initial lr decayed by gamma once the
73
+ number of epoch reaches one of the milestones. When last_epoch=-1, sets initial lr as lr.
74
+ """
75
+ args.milestones = [i for i in range(1, args.epochs)] if args.milestones == [] else args.milestones
76
+ args.gamma = 0.1 if args.gamma == -1 else args.gamma
77
+ args.last_epoch = -1 if args.last_epoch == -1 else args.last_epoch
78
+
79
+ def multisteplr_wrapper(optimizer):
80
+ pytorch_support(required_version='1.0.0', info_str='LRScheduler - MultiStepLR')
81
+ return optim.lr_scheduler.MultiStepLR(
82
+ optimizer, milestones=args.milestones, gamma=args.gamma, last_epoch=args.last_epoch)
83
+
84
+ return multisteplr_wrapper
85
+
86
+
87
+ def exponentiallr(args):
88
+ """ Wrapper of torch.optim.lr_scheduler.ExponentialLR (PyTorch >= 1.0.0).
89
+
90
+ Set the learning rate of each parameter group to the initial lr decayed by gamma every epoch.
91
+ When last_epoch=-1, sets initial lr as lr.
92
+ """
93
+ args.gamma = 0.1 if args.gamma == -1 else args.gamma
94
+ args.last_epoch = -1 if args.last_epoch == -1 else args.last_epoch
95
+
96
+ def exponentiallr_wrapper(optimizer):
97
+ pytorch_support(required_version='1.0.0', info_str='LRScheduler - ExponentialLR')
98
+ return optim.lr_scheduler.ExponentialLR(
99
+ optimizer, gamma=args.gamma, last_epoch=args.last_epoch)
100
+
101
+ return exponentiallr_wrapper
102
+
103
+
104
+ def cosineannealinglr(args):
105
+ """ Wrapper of torch.optim.lr_schduler.CosineAnnealingLR (PyTorch >= 1.0.0).
106
+
107
+ Set the learning rate of each parameter group using a cosine annealing schedule.
108
+ When last_epoch=-1, sets initial lr as lr.
109
+ """
110
+ args.T_max = args.epochs if args.T_max == -1 else args.T_max
111
+ args.eta_min = 0 if args.eta_min == -1 else args.eta_min
112
+ args.last_epoch = -1 if args.last_epoch == -1 else args.last_epoch
113
+
114
+ def cosineannealinglr_wrapper(optimizer):
115
+ pytorch_support(required_version='1.0.0', info_str='LRScheduler - CosineAnnealingLR')
116
+ return optim.lr_scheduler.CosineAnnealingLR(
117
+ optimizer, T_max=args.T_max, eta_min=args.eta_min, last_epoch=args.last_epoch)
118
+
119
+ return cosineannealinglr_wrapper
120
+
121
+
122
+ def polynomiallr(args):
123
+ """ Wrapper of torchtask.nn.lrer.PolynomialLR (PyTorch >= 1.0.0).
124
+
125
+ Set the learning rate of each parameter group to the initial lr decayed by power every
126
+ iteration. When last_epoch=-1, sets initial lr as lr.
127
+ """
128
+ args.power = 0.9 if args.power == -1 else args.power
129
+ args.last_epoch = -1 if args.last_epoch == -1 else args.last_epoch
130
+
131
+ def polynomiallr_wrapper(optimizer):
132
+ pytorch_support(required_version='1.0.0', info_str='LRScheduler - PolynomialLR')
133
+ return PolynomialLR(optimizer, epochs=args.epochs, iters_per_epoch=args.iters_per_epoch,
134
+ power=args.power, last_epoch=args.last_epoch)
135
+
136
+ return polynomiallr_wrapper
137
+
138
+
139
+ # ---------------------------------------------------------------------
140
+ # Implementation of Learning Rate Scheduler
141
+ # ---------------------------------------------------------------------
142
+
143
+ class PolynomialLR(torch.optim.lr_scheduler._LRScheduler):
144
+ """ Polynomial decay learning rate scheduler.
145
+ """
146
+
147
+ def __init__(self, optimizer, epochs, iters_per_epoch, power=0.9, last_epoch=-1):
148
+ self.epochs = epochs
149
+ self.iters_per_epoch = iters_per_epoch
150
+ self.max_iters = self.epochs * self.iters_per_epoch
151
+ self.cur_iter = 0
152
+ self.power = power
153
+ self.is_warn = False
154
+ super(PolynomialLR, self).__init__(optimizer, last_epoch)
155
+
156
+ def get_lr(self):
157
+ return [base_lr * ((1 - float(self.cur_iter) / self.max_iters) ** self.power)
158
+ for base_lr in self.base_lrs]
159
+
160
+ def step(self, epoch=None):
161
+ if epoch is not None and epoch != 0:
162
+ # update lr after each epoch if epoch is given
163
+ # after each epoch, set epoch += 1 and call this function
164
+ if not self.is_warn:
165
+ logger.log_warn('PolynomialLR is designed for updating learning rate after each iteration.\n'
166
+ 'However, it will be updated after each epoch now, please be careful.\n')
167
+ self.is_warn = True
168
+
169
+ self.last_epoch = epoch
170
+ assert self.last_epoch <= self.epochs
171
+ self.cur_iter = self.last_epoch * self.iters_per_epoch
172
+
173
+ elif epoch is None:
174
+ # update lr after each iteration if epoch is None
175
+ self.cur_iter += 1
176
+ self.last_epoch = math.floor(self.cur_iter / self.iters_per_epoch)
177
+
178
+ for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
179
+ param_group['lr'] = lr
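
A minimal sketch of the per-iteration decay implemented by `PolynomialLR` above, using a dummy parameter; the import path is assumed from the diff layout. `max_iters` is chosen larger than the loop length so the schedule stays in range regardless of how often the base `_LRScheduler` constructor calls `step()`.

```python
import torch
from torchtask.nn.lrer import PolynomialLR  # assumed module path

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)

# max_iters = epochs * iters_per_epoch = 15
scheduler = PolynomialLR(optimizer, epochs=3, iters_per_epoch=5, power=0.9)
for it in range(10):
    optimizer.step()
    scheduler.step()  # no epoch argument -> decay once per iteration
    print(it, optimizer.param_groups[0]['lr'])
# lr follows base_lr * (1 - cur_iter / 15) ** 0.9
```
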
harmonizer/src/train/torchtask/nn/module/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .third_party import SynchronizedBatchNorm2d, patch_replication_callback
2
+ from .gaussian_blur import GaussianBlurLayer
3
+ from .gaussian_noise import GaussianNoiseLayer
harmonizer/src/train/torchtask/nn/module/gaussian_blur.py ADDED
@@ -0,0 +1,64 @@
1
+ import math
2
+ import numpy as np
3
+ import scipy.ndimage
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from torchtask.utils import logger
9
+
10
+
11
+ class GaussianBlurLayer(nn.Module):
12
+ """ Add Gaussian Blur to a 4D tensor
13
+
14
+ This layer takes a 4D tensor of {N, C, H, W} as input.
15
+ The Gaussian blur will be performed in given channel number (C) splitly.
16
+ """
17
+
18
+ def __init__(self, channels, kernel_size):
19
+ """
20
+ Arguments:
21
+ channels (int): Channel for input tensor
22
+ kernel_size (int): Size of the kernel used in blurring
23
+ """
24
+
25
+ super(GaussianBlurLayer, self).__init__()
26
+ self.channels = channels
27
+ self.kernel_size = kernel_size
28
+ assert self.kernel_size % 2 != 0
29
+
30
+ self.op = nn.Sequential(
31
+ nn.ReflectionPad2d(math.floor(self.kernel_size / 2)),
32
+ nn.Conv2d(channels, channels, self.kernel_size,
33
+ stride=1, padding=0, bias=None, groups=channels)
34
+ )
35
+
36
+ self._init_kernel()
37
+
38
+ def forward(self, x):
39
+ """
40
+ Arguments:
41
+ x (torch.Tensor): input 4D tensor
42
+
43
+ Returns:
44
+ torch.Tensor: Blurred version of the input
45
+ """
46
+
47
+ if not len(list(x.shape)) == 4:
48
+ logger.log_err('\'GaussianBlurLayer\' requires a 4D tensor as input\n')
49
+ elif not x.shape[1] == self.channels:
50
+ logger.log_err('In \'GaussianBlurLayer\', the required channel ({0}) is'
51
+ 'not the same as input ({1})\n'.format(self.channels, x.shape[1]))
52
+
53
+ return self.op(x)
54
+
55
+ def _init_kernel(self):
56
+ sigma = 0.3 * ((self.kernel_size - 1) * 0.5 - 1) + 0.8
57
+
58
+ n = np.zeros((self.kernel_size, self.kernel_size))
59
+ i = math.floor(self.kernel_size / 2)
60
+ n[i, i] = 1
61
+ kernel = scipy.ndimage.gaussian_filter(n, sigma)
62
+
63
+ for name, param in self.named_parameters():
64
+ param.data.copy_(torch.from_numpy(kernel))
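
A short CPU-only sketch of the layer above; the import relies on the `torchtask.nn.module` re-export added earlier in this diff.

```python
import torch
from torchtask.nn.module import GaussianBlurLayer

blur = GaussianBlurLayer(channels=3, kernel_size=5)  # kernel size must be odd
with torch.no_grad():
    x = torch.rand(1, 3, 32, 32)
    y = blur(x)
assert y.shape == x.shape
# blurring averages neighbours, so the variance of random noise should drop
assert y.var() < x.var()
```
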
harmonizer/src/train/torchtask/nn/module/gaussian_noise.py ADDED
@@ -0,0 +1,40 @@
1
+ import random
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class GaussianNoiseLayer(nn.Module):
8
+ """ Add Gaussian noise to a 4D tensor
9
+ """
10
+
11
+ def __init__(self, std):
12
+ super(GaussianNoiseLayer, self).__init__()
13
+ self.std = std
14
+ self.noise = torch.zeros(0)
15
+ self.enable = False if self.std is None else True
16
+
17
+ def forward(self, inp):
18
+ if not self.enable:
19
+ return inp
20
+
21
+ if self.noise.shape != inp.shape:
22
+ self.noise = torch.zeros(inp.shape).cuda()
23
+ self.noise.data.normal_(0, std=random.uniform(0, self.std))
24
+
25
+ imax = inp.max(dim=3, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
26
+ imin = inp.min(dim=3, keepdim=True)[0].min(dim=2, keepdim=True)[0].min(dim=1, keepdim=True)[0]
27
+
28
+ # normalize to [0, 1]
29
+ inp.sub_(imin).div_(imax - imin + 1e-9)
30
+ # add noise
31
+ inp.add_(self.noise)
32
+ # clip to [0, 1]
33
+ upper_bound = (inp > 1.0).float()
34
+ lower_bound = (inp < 0.0).float()
35
+ inp.mul_(1 - upper_bound).add_(upper_bound)
36
+ inp.mul_(1 - lower_bound)
37
+ # de-normalize
38
+ inp.mul_(imax - imin + 1e-9).add_(imin)
39
+
40
+ return inp
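
A sketch of the noise layer above. Two properties are visible in the code: the noise buffer is allocated with `.cuda()`, so a CUDA device is required as written, and `forward` mutates its input in place.

```python
import torch
from torchtask.nn.module import GaussianNoiseLayer

layer = GaussianNoiseLayer(std=0.1)
x = torch.rand(2, 3, 16, 16).cuda()
y = layer(x.clone())        # clone because forward modifies its input
assert y.shape == x.shape

identity = GaussianNoiseLayer(std=None)  # std=None disables the layer
assert torch.equal(identity(x.clone()), x)
```
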
harmonizer/src/train/torchtask/nn/module/third_party/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .sync_batchnorm import SynchronizedBatchNorm2d, patch_replication_callback
harmonizer/src/train/torchtask/nn/module/third_party/sync_batchnorm/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : __init__.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12
+ from .replicate import DataParallelWithCallback, patch_replication_callback
harmonizer/src/train/torchtask/nn/module/third_party/sync_batchnorm/batchnorm.py ADDED
@@ -0,0 +1,282 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : batchnorm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import collections
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ from torch.nn.modules.batchnorm import _BatchNorm
17
+ from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18
+
19
+ from .comm import SyncMaster
20
+
21
+ __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22
+
23
+
24
+ def _sum_ft(tensor):
25
+ """sum over the first and last dimention"""
26
+ return tensor.sum(dim=0).sum(dim=-1)
27
+
28
+
29
+ def _unsqueeze_ft(tensor):
30
+ """add new dementions at the front and the tail"""
31
+ return tensor.unsqueeze(0).unsqueeze(-1)
32
+
33
+
34
+ _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35
+ _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36
+
37
+
38
+ class _SynchronizedBatchNorm(_BatchNorm):
39
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
40
+ super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41
+
42
+ self._sync_master = SyncMaster(self._data_parallel_master)
43
+
44
+ self._is_parallel = False
45
+ self._parallel_id = None
46
+ self._slave_pipe = None
47
+
48
+ def forward(self, input):
49
+ # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
50
+ if not (self._is_parallel and self.training):
51
+ return F.batch_norm(
52
+ input, self.running_mean, self.running_var, self.weight, self.bias,
53
+ self.training, self.momentum, self.eps)
54
+
55
+ # Resize the input to (B, C, -1).
56
+ input_shape = input.size()
57
+ input = input.view(input.size(0), self.num_features, -1)
58
+
59
+ # Compute the sum and square-sum.
60
+ sum_size = input.size(0) * input.size(2)
61
+ input_sum = _sum_ft(input)
62
+ input_ssum = _sum_ft(input ** 2)
63
+
64
+ # Reduce-and-broadcast the statistics.
65
+ if self._parallel_id == 0:
66
+ mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
67
+ else:
68
+ mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
69
+
70
+ # Compute the output.
71
+ if self.affine:
72
+ # MJY:: Fuse the multiplication for speed.
73
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
74
+ else:
75
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
76
+
77
+ # Reshape it.
78
+ return output.view(input_shape)
79
+
80
+ def __data_parallel_replicate__(self, ctx, copy_id):
81
+ self._is_parallel = True
82
+ self._parallel_id = copy_id
83
+
84
+ # parallel_id == 0 means master device.
85
+ if self._parallel_id == 0:
86
+ ctx.sync_master = self._sync_master
87
+ else:
88
+ self._slave_pipe = ctx.sync_master.register_slave(copy_id)
89
+
90
+ def _data_parallel_master(self, intermediates):
91
+ """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
92
+
93
+ # Always using same "device order" makes the ReduceAdd operation faster.
94
+ # Thanks to:: Tete Xiao (http://tetexiao.com/)
95
+ intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
96
+
97
+ to_reduce = [i[1][:2] for i in intermediates]
98
+ to_reduce = [j for i in to_reduce for j in i] # flatten
99
+ target_gpus = [i[1].sum.get_device() for i in intermediates]
100
+
101
+ sum_size = sum([i[1].sum_size for i in intermediates])
102
+ sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
103
+ mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
104
+
105
+ broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
106
+
107
+ outputs = []
108
+ for i, rec in enumerate(intermediates):
109
+ outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))
110
+
111
+ return outputs
112
+
113
+ def _compute_mean_std(self, sum_, ssum, size):
114
+ """Compute the mean and standard-deviation with sum and square-sum. This method
115
+ also maintains the moving average on the master device."""
116
+ assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
117
+ mean = sum_ / size
118
+ sumvar = ssum - sum_ * mean
119
+ unbias_var = sumvar / (size - 1)
120
+ bias_var = sumvar / size
121
+
122
+ self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
123
+ self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
124
+
125
+ return mean, bias_var.clamp(self.eps) ** -0.5
126
+
127
+
128
+ class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
129
+ r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
130
+ mini-batch.
131
+ .. math::
132
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
133
+ This module differs from the built-in PyTorch BatchNorm1d as the mean and
134
+ standard-deviation are reduced across all devices during training.
135
+ For example, when one uses `nn.DataParallel` to wrap the network during
136
+ training, PyTorch's implementation normalizes the tensor on each device using
137
+ the statistics only on that device, which accelerates the computation and
138
+ is also easy to implement, but the statistics might be inaccurate.
139
+ Instead, in this synchronized version, the statistics will be computed
140
+ over all training samples distributed on multiple devices.
141
+
142
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
143
+ as the built-in PyTorch implementation.
144
+ The mean and standard-deviation are calculated per-dimension over
145
+ the mini-batches and gamma and beta are learnable parameter vectors
146
+ of size C (where C is the input size).
147
+ During training, this layer keeps a running estimate of its computed mean
148
+ and variance. The running sum is kept with a default momentum of 0.1.
149
+ During evaluation, this running mean/variance is used for normalization.
150
+ Because the BatchNorm is done over the `C` dimension, computing statistics
151
+ on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
152
+ Args:
153
+ num_features: num_features from an expected input of size
154
+ `batch_size x num_features [x width]`
155
+ eps: a value added to the denominator for numerical stability.
156
+ Default: 1e-5
157
+ momentum: the value used for the running_mean and running_var
158
+ computation. Default: 0.1
159
+ affine: a boolean value that when set to ``True``, gives the layer learnable
160
+ affine parameters. Default: ``True``
161
+ Shape:
162
+ - Input: :math:`(N, C)` or :math:`(N, C, L)`
163
+ - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
164
+ Examples:
165
+ >>> # With Learnable Parameters
166
+ >>> m = SynchronizedBatchNorm1d(100)
167
+ >>> # Without Learnable Parameters
168
+ >>> m = SynchronizedBatchNorm1d(100, affine=False)
169
+ >>> input = torch.autograd.Variable(torch.randn(20, 100))
170
+ >>> output = m(input)
171
+ """
172
+
173
+ def _check_input_dim(self, input):
174
+ if input.dim() != 2 and input.dim() != 3:
175
+ raise ValueError('expected 2D or 3D input (got {}D input)'
176
+ .format(input.dim()))
177
+ super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
178
+
179
+
180
+ class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
181
+ r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
182
+ of 3d inputs
183
+ .. math::
184
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
185
+ This module differs from the built-in PyTorch BatchNorm2d as the mean and
186
+ standard-deviation are reduced across all devices during training.
187
+ For example, when one uses `nn.DataParallel` to wrap the network during
188
+ training, PyTorch's implementation normalizes the tensor on each device using
189
+ the statistics only on that device, which accelerates the computation and
190
+ is also easy to implement, but the statistics might be inaccurate.
191
+ Instead, in this synchronized version, the statistics will be computed
192
+ over all training samples distributed on multiple devices.
193
+
194
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
195
+ as the built-in PyTorch implementation.
196
+ The mean and standard-deviation are calculated per-dimension over
197
+ the mini-batches and gamma and beta are learnable parameter vectors
198
+ of size C (where C is the input size).
199
+ During training, this layer keeps a running estimate of its computed mean
200
+ and variance. The running sum is kept with a default momentum of 0.1.
201
+ During evaluation, this running mean/variance is used for normalization.
202
+ Because the BatchNorm is done over the `C` dimension, computing statistics
203
+ on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
204
+ Args:
205
+ num_features: num_features from an expected input of
206
+ size batch_size x num_features x height x width
207
+ eps: a value added to the denominator for numerical stability.
208
+ Default: 1e-5
209
+ momentum: the value used for the running_mean and running_var
210
+ computation. Default: 0.1
211
+ affine: a boolean value that when set to ``True``, gives the layer learnable
212
+ affine parameters. Default: ``True``
213
+ Shape:
214
+ - Input: :math:`(N, C, H, W)`
215
+ - Output: :math:`(N, C, H, W)` (same shape as input)
216
+ Examples:
217
+ >>> # With Learnable Parameters
218
+ >>> m = SynchronizedBatchNorm2d(100)
219
+ >>> # Without Learnable Parameters
220
+ >>> m = SynchronizedBatchNorm2d(100, affine=False)
221
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
222
+ >>> output = m(input)
223
+ """
224
+
225
+ def _check_input_dim(self, input):
226
+ if input.dim() != 4:
227
+ raise ValueError('expected 4D input (got {}D input)'
228
+ .format(input.dim()))
229
+ super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
230
+
231
+
232
+ class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
233
+ r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
234
+ of 4d inputs
235
+ .. math::
236
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
237
+ This module differs from the built-in PyTorch BatchNorm3d as the mean and
238
+ standard-deviation are reduced across all devices during training.
239
+ For example, when one uses `nn.DataParallel` to wrap the network during
240
+ training, PyTorch's implementation normalizes the tensor on each device using
241
+ the statistics only on that device, which accelerates the computation and
242
+ is also easy to implement, but the statistics might be inaccurate.
243
+ Instead, in this synchronized version, the statistics will be computed
244
+ over all training samples distributed on multiple devices.
245
+
246
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
247
+ as the built-in PyTorch implementation.
248
+ The mean and standard-deviation are calculated per-dimension over
249
+ the mini-batches and gamma and beta are learnable parameter vectors
250
+ of size C (where C is the input size).
251
+ During training, this layer keeps a running estimate of its computed mean
252
+ and variance. The running sum is kept with a default momentum of 0.1.
253
+ During evaluation, this running mean/variance is used for normalization.
254
+ Because the BatchNorm is done over the `C` dimension, computing statistics
255
+ on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
256
+ or Spatio-temporal BatchNorm
257
+ Args:
258
+ num_features: num_features from an expected input of
259
+ size batch_size x num_features x depth x height x width
260
+ eps: a value added to the denominator for numerical stability.
261
+ Default: 1e-5
262
+ momentum: the value used for the running_mean and running_var
263
+ computation. Default: 0.1
264
+ affine: a boolean value that when set to ``True``, gives the layer learnable
265
+ affine parameters. Default: ``True``
266
+ Shape:
267
+ - Input: :math:`(N, C, D, H, W)`
268
+ - Output: :math:`(N, C, D, H, W)` (same shape as input)
269
+ Examples:
270
+ >>> # With Learnable Parameters
271
+ >>> m = SynchronizedBatchNorm3d(100)
272
+ >>> # Without Learnable Parameters
273
+ >>> m = SynchronizedBatchNorm3d(100, affine=False)
274
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
275
+ >>> output = m(input)
276
+ """
277
+
278
+ def _check_input_dim(self, input):
279
+ if input.dim() != 5:
280
+ raise ValueError('expected 5D input (got {}D input)'
281
+ .format(input.dim()))
282
+ super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
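
A sketch of dropping `SynchronizedBatchNorm2d` into a model. As the docstrings note, on a single device it falls back to `F.batch_norm`, so this also runs without multiple GPUs; the `patch_replication_callback` call is what wires per-device statistics to the master copy under `nn.DataParallel`. The import path relies on the `torchtask.nn.module` re-exports in this diff.

```python
import torch
import torch.nn as nn
from torchtask.nn.module import (SynchronizedBatchNorm2d,
                                 patch_replication_callback)

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    SynchronizedBatchNorm2d(16),
    nn.ReLU(),
)

x = torch.randn(4, 3, 32, 32)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model.cuda(), device_ids=[0, 1])
    patch_replication_callback(model)  # enable cross-device statistics
    x = x.cuda()

out = model(x)
assert out.shape == (4, 16, 32, 32)
```
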
harmonizer/src/train/torchtask/nn/module/third_party/sync_batchnorm/comm.py ADDED
@@ -0,0 +1,129 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : comm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import queue
12
+ import collections
13
+ import threading
14
+
15
+ __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16
+
17
+
18
+ class FutureResult(object):
19
+ """A thread-safe future implementation. Used only as one-to-one pipe."""
20
+
21
+ def __init__(self):
22
+ self._result = None
23
+ self._lock = threading.Lock()
24
+ self._cond = threading.Condition(self._lock)
25
+
26
+ def put(self, result):
27
+ with self._lock:
28
+ assert self._result is None, 'Previous result hasn\'t been fetched.'
29
+ self._result = result
30
+ self._cond.notify()
31
+
32
+ def get(self):
33
+ with self._lock:
34
+ if self._result is None:
35
+ self._cond.wait()
36
+
37
+ res = self._result
38
+ self._result = None
39
+ return res
40
+
41
+
42
+ _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43
+ _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44
+
45
+
46
+ class SlavePipe(_SlavePipeBase):
47
+ """Pipe for master-slave communication."""
48
+
49
+ def run_slave(self, msg):
50
+ self.queue.put((self.identifier, msg))
51
+ ret = self.result.get()
52
+ self.queue.put(True)
53
+ return ret
54
+
55
+
56
+ class SyncMaster(object):
57
+ """An abstract `SyncMaster` object.
58
+ - During the replication, as the data parallel will trigger a callback on each module, all slave devices should
59
+ call `register(id)` and obtain a `SlavePipe` to communicate with the master.
60
+ - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
61
+ and passed to a registered callback.
62
+ - After receiving the messages, the master device should gather the information and determine the message to be passed
63
+ back to each slave device.
64
+ """
65
+
66
+ def __init__(self, master_callback):
67
+ """
68
+ Args:
69
+ master_callback: a callback to be invoked after having collected messages from slave devices.
70
+ """
71
+ self._master_callback = master_callback
72
+ self._queue = queue.Queue()
73
+ self._registry = collections.OrderedDict()
74
+ self._activated = False
75
+
76
+ def __getstate__(self):
77
+ return {'master_callback': self._master_callback}
78
+
79
+ def __setstate__(self, state):
80
+ self.__init__(state['master_callback'])
81
+
82
+ def register_slave(self, identifier):
83
+ """
84
+ Register a slave device.
85
+ Args:
86
+ identifier: an identifier, usually is the device id.
87
+ Returns: a `SlavePipe` object which can be used to communicate with the master device.
88
+ """
89
+ if self._activated:
90
+ assert self._queue.empty(), 'Queue is not clean before next initialization.'
91
+ self._activated = False
92
+ self._registry.clear()
93
+ future = FutureResult()
94
+ self._registry[identifier] = _MasterRegistry(future)
95
+ return SlavePipe(identifier, self._queue, future)
96
+
97
+ def run_master(self, master_msg):
98
+ """
99
+ Main entry for the master device in each forward pass.
100
+ The messages are first collected from each device (including the master device), and then
101
+ a callback will be invoked to compute the message to be sent back to each device
102
+ (including the master device).
103
+ Args:
104
+ master_msg: the message that the master want to send to itself. This will be placed as the first
105
+ message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106
+ Returns: the message to be sent back to the master device.
107
+ """
108
+ self._activated = True
109
+
110
+ intermediates = [(0, master_msg)]
111
+ for i in range(self.nr_slaves):
112
+ intermediates.append(self._queue.get())
113
+
114
+ results = self._master_callback(intermediates)
115
+ assert results[0][0] == 0, 'The first result should belong to the master.'
116
+
117
+ for i, res in results:
118
+ if i == 0:
119
+ continue
120
+ self._registry[i].result.put(res)
121
+
122
+ for i in range(self.nr_slaves):
123
+ assert self._queue.get() is True
124
+
125
+ return results[0][1]
126
+
127
+ @property
128
+ def nr_slaves(self):
129
+ return len(self._registry)
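
A toy, CPU-only walkthrough of the handshake above: one slave thread sends a number, and the master callback doubles every collected message before the results are dispatched back. The import path is assumed from this diff's layout.

```python
import threading
from torchtask.nn.module.third_party.sync_batchnorm.comm import SyncMaster

def master_callback(intermediates):
    # intermediates is [(id, msg), ...], with the master's message (id 0) first
    return [(identifier, msg * 2) for identifier, msg in intermediates]

master = SyncMaster(master_callback)
pipe = master.register_slave(identifier=1)

results = {}

def slave():
    results['slave'] = pipe.run_slave(21)  # blocks until the master replies

worker = threading.Thread(target=slave)
worker.start()
results['master'] = master.run_master(1)  # gathers one slave message
worker.join()

assert results == {'master': 2, 'slave': 42}
```
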
harmonizer/src/train/torchtask/nn/module/third_party/sync_batchnorm/replicate.py ADDED
@@ -0,0 +1,88 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : replicate.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import functools
12
+
13
+ from torch.nn.parallel.data_parallel import DataParallel
14
+
15
+ __all__ = [
16
+ 'CallbackContext',
17
+ 'execute_replication_callbacks',
18
+ 'DataParallelWithCallback',
19
+ 'patch_replication_callback'
20
+ ]
21
+
22
+
23
+ class CallbackContext(object):
24
+ pass
25
+
26
+
27
+ def execute_replication_callbacks(modules):
28
+ """
29
+ Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.
30
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
31
+ Note that, as all modules are isomorphism, we assign each sub-module with a context
32
+ (shared among multiple copies of this module on different devices).
33
+ Through this context, different copies can share some information.
34
+ We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
35
+ of any slave copies.
36
+ """
37
+ master_copy = modules[0]
38
+ nr_modules = len(list(master_copy.modules()))
39
+ ctxs = [CallbackContext() for _ in range(nr_modules)]
40
+
41
+ for i, module in enumerate(modules):
42
+ for j, m in enumerate(module.modules()):
43
+ if hasattr(m, '__data_parallel_replicate__'):
44
+ m.__data_parallel_replicate__(ctxs[j], i)
45
+
46
+
47
+ class DataParallelWithCallback(DataParallel):
48
+ """
49
+ Data Parallel with a replication callback.
50
+ An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
51
+ original `replicate` function.
52
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
53
+ Examples:
54
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
55
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
56
+ # sync_bn.__data_parallel_replicate__ will be invoked.
57
+ """
58
+
59
+ def replicate(self, module, device_ids):
60
+ modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
61
+ execute_replication_callbacks(modules)
62
+ return modules
63
+
64
+
65
+ def patch_replication_callback(data_parallel):
66
+ """
67
+ Monkey-patch an existing `DataParallel` object. Add the replication callback.
68
+ Useful when you have customized `DataParallel` implementation.
69
+ Examples:
70
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
71
+ > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
72
+ > patch_replication_callback(sync_bn)
73
+ # this is equivalent to
74
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
75
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
76
+ """
77
+
78
+ assert isinstance(data_parallel, DataParallel)
79
+
80
+ old_replicate = data_parallel.replicate
81
+
82
+ @functools.wraps(old_replicate)
83
+ def new_replicate(module, device_ids):
84
+ modules = old_replicate(module, device_ids)
85
+ execute_replication_callbacks(modules)
86
+ return modules
87
+
88
+ data_parallel.replicate = new_replicate
harmonizer/src/train/torchtask/nn/module/third_party/sync_batchnorm/unittest.py ADDED
@@ -0,0 +1,29 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : unittest.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import unittest
12
+
13
+ import numpy as np
14
+ from torch.autograd import Variable
15
+
16
+
17
+ def as_numpy(v):
18
+ if isinstance(v, Variable):
19
+ v = v.data
20
+ return v.cpu().numpy()
21
+
22
+
23
+ class TorchTestCase(unittest.TestCase):
24
+ def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25
+ npa, npb = as_numpy(a), as_numpy(b)
26
+ self.assertTrue(
27
+ np.allclose(npa, npb, atol=atol),
28
+ 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29
+ )
harmonizer/src/train/torchtask/nn/optimizer.py ADDED
@@ -0,0 +1,247 @@
1
+
2
+ import math
3
+
4
+ import torch
5
+ import torch.optim as optim
6
+ from torch.optim.optimizer import Optimizer
7
+
8
+ from torchtask.utils import cmd
9
+ from torchtask.nn.func import pytorch_support
10
+
11
+
12
+ """ This file wraps the optimizers used in the script.
13
+ """
14
+
15
+
16
+ VALID_OPTIMIZER = ['sgd', 'rmsprop', 'adam', 'wdadam']
17
+
18
+
19
+ def add_parser_arguments(parser):
20
+ """ Add the arguments related to the optimizer.
21
+
22
+ This 'add_parser_arguments' function will be called every time.
23
+ Please do not use the argument's name that are already defined in is function.
24
+ The default value '-1' means that the default value corresponding to
25
+ different LR schedulers will be used.
26
+ """
27
+
28
+ parser.add_argument('--lr', type=float, default=-1, metavar='',
29
+ help='optimizer - learning rate (required by [all])')
30
+
31
+ parser.add_argument('--dampening', type=float, default=-1, metavar='',
32
+ help='optimizer - dampening for momentum (required by [sgd])')
33
+ parser.add_argument('--nesterov', type=cmd.str2bool, default=False, metavar='',
34
+ help='optimizer - enables Nesterov momentum if True (required by [sgd])')
35
+ parser.add_argument('--weight-decay', type=float, default=-1, metavar='',
36
+ help='optimizer - weight decay (L2 penalty) (required by [sgd, rmsprop, adam, wdadam])')
37
+ parser.add_argument('--momentum', type=float, default=-1, metavar='',
38
+ help='optimizer - momentum factor (required by [sgd, rmsprop])')
39
+ parser.add_argument('--alpha', type=float, default=-1, metavar='',
40
+ help='smoothing constant (required by [rmsprop])')
41
+ parser.add_argument('--centered', type=cmd.str2bool, default=False, metavar='',
42
+ help='if True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance ( required by [rmsprop])')
43
+ parser.add_argument('--eps', type=float, default=-1, metavar='',
44
+ help='optimizer - term added to the denominator to improve numerical stability (required by [rmsprop, adam, wdadam])')
45
+ parser.add_argument('--beta1', type=float, default=-1, metavar='',
46
+ help='optimizer - coefficients used for computing running averages of gradient and its square (required by [adam, wdadam])')
47
+ parser.add_argument('--beta2', type=float, default=-1, metavar='',
48
+ help='optimizer - coefficients used for computing running averages of gradient and its square (required by [adam, wdadam])')
49
+ parser.add_argument('--amsgrad', type=cmd.str2bool, default=False, metavar='',
50
+ help='optimizer - use the AMSGrad variant if True (required by [wdadam])')
51
+
52
+
53
+ # ---------------------------------------------------------------------
54
+ # Wrapper of Optimizer
55
+ # ---------------------------------------------------------------------
56
+
57
+ def sgd(args):
58
+ """ Wrapper of torch.optim.SGD (PyTorch >= 1.0.0).
59
+
60
+ Implements stochastic gradient descent (optionally with momentum).
61
+ """
62
+ args.lr = 0.01 if args.lr == -1 else args.lr
63
+ args.weight_decay = 0 if args.weight_decay == -1 else args.weight_decay
64
+ args.momentum = 0 if args.momentum == -1 else args.momentum
65
+ args.dampening = 0 if args.dampening == -1 else args.dampening
66
+ args.nesterov = False if args.nesterov == False else args.nesterov
67
+
68
+ def sgd_wrapper(param_groups):
69
+ pytorch_support(required_version='1.0.0', info_str='Optimizer - SGD')
70
+ return optim.SGD(
71
+ param_groups,
72
+ lr=args.lr, momentum=args.momentum, dampening=args.dampening,
73
+ weight_decay=args.weight_decay, nesterov=args.nesterov)
74
+
75
+ return sgd_wrapper
76
+
77
+
78
+ def rmsprop(args):
79
+ """ Wrapper of torch.optim.RMSprop (PyTorch >= 1.0.0).
80
+
81
+ Implements RMSprop algorithm.
82
+ Proposed by G. Hinton in his course.
83
+ The centered version first appears in Generating Sequences With Recurrent Neural Networks.
84
+ """
85
+
86
+ args.lr = 0.01 if args.lr == -1 else args.lr
87
+ args.alpha = 0.99 if args.alpha == -1 else args.alpha
88
+ args.eps = 1e-08 if args.eps == -1 else args.eps
89
+ args.weight_decay = 0 if args.weight_decay == -1 else args.weight_decay
90
+ args.momentum = 0 if args.momentum == -1 else args.momentum
91
+ args.centered = False if args.centered == False else args.centered
92
+
93
+ def rmsprop_wrapper(param_groups):
94
+ pytorch_support(required_version='1.0.0', info_str='Optimizer - RMSprop')
95
+ return optim.RMSprop(
96
+ param_groups,
97
+ lr=args.lr, alpha=args.alpha, eps=args.eps, weight_decay=args.weight_decay,
98
+ momentum=args.momentum, centered=args.centered)
99
+
100
+ return rmsprop_wrapper
101
+
102
+
103
+ def adam(args):
104
+ """ Wrapper of torch.optim.Adam (PyTorch >= 1.0.0).
105
+
106
+ Implements Adam algorithm.
107
+ It has been proposed in 'Adam: A Method for Stochastic Optimization'.
108
+ """
109
+ args.lr = 0.001 if args.lr == -1 else args.lr
110
+ args.beta1 = 0.9 if args.beta1 == -1 else args.beta1
111
+ args.beta2 = 0.999 if args.beta2 == -1 else args.beta2
112
+ args.eps = 1e-08 if args.eps == -1 else args.eps
113
+ args.weight_decay = 0.0 if args.weight_decay == -1 else args.weight_decay
114
+
115
+ def adam_wrapper(param_groups):
116
+ pytorch_support(required_version='1.0.0', info_str='Optimizer - Adam')
117
+ return optim.Adam(
118
+ param_groups,
119
+ lr=args.lr, betas=(args.beta1, args.beta2), eps=args.eps,
120
+ weight_decay=args.weight_decay)
121
+
122
+ return adam_wrapper
123
+
124
+
125
+ def wdadam(args):
126
+ """ Wrapper of torchtask.nn.optimizer.WDAdam (PyTorch >= 1.0.0).
127
+
128
+ Implements Adam algorithm with weight decay and AMSGrad.
129
+ """
130
+ args.lr = 0.001 if args.lr == -1 else args.lr
131
+ args.beta1 = 0.9 if args.beta1 == -1 else args.beta1
132
+ args.beta2 = 0.999 if args.beta2 == -1 else args.beta2
133
+ args.eps = 1e-08 if args.eps == -1 else args.eps
134
+ args.weight_decay = 0.0 if args.weight_decay == -1 else args.weight_decay
135
+ args.amsgrad = False if args.amsgrad == False else args.amsgrad
136
+
137
+ def wdadam_wrapper(param_groups):
138
+ pytorch_support(required_version='1.0.0', info_str='Optimizer - WDAdam')
139
+ return WDAdam(
140
+ param_groups,
141
+ lr=args.lr, betas=(args.beta1, args.beta2), eps=args.eps,
142
+ weight_decay=args.weight_decay, amsgrad=args.amsgrad)
143
+
144
+ return wdadam_wrapper
145
+
146
+
147
+ # ---------------------------------------------------------------------
148
+ # Implementation of Optimizer
149
+ # ---------------------------------------------------------------------
150
+
151
+ class WDAdam(Optimizer):
152
+ """ Implements Adam algorithm with weight decay and AMSGrad.
153
+
154
+ It has been proposed in `Adam: A Method for Stochastic Optimization`.
155
+
156
+ Arguments:
157
+ params (iterable): iterable of parameters to optimize or dicts defining
158
+ parameter groups
159
+ lr (float, optional): learning rate (default: 1e-3)
160
+ betas (Tuple[float, float], optional): coefficients used for computing
161
+ running averages of gradient and its square (default: (0.9, 0.999))
162
+ eps (float, optional): term added to the denominator to improve
163
+ numerical stability (default: 1e-8)
164
+ weight_decay (float, optional): weight decay using the method from
165
+ the paper `Fixing Weight Decay Regularization in Adam` (default: 0)
166
+ amsgrad (boolean, optional): whether to use the AMSGrad variant of this
167
+ algorithm from the paper `On the Convergence of Adam and Beyond`_
168
+ """
169
+
170
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False):
171
+ if not 0.0 <= lr:
172
+ raise ValueError("Invalid learning rate: {0}".format(lr))
173
+ if not 0.0 <= eps:
174
+ raise ValueError("Invalid epsilon value: {0}".format(eps))
175
+ if not 0.0 <= betas[0] < 1.0:
176
+ raise ValueError("Invalid beta parameter at index 0: {0}".format(betas[0]))
177
+ if not 0.0 <= betas[1] < 1.0:
178
+ raise ValueError("Invalid beta parameter at index 1: {0}".format(betas[1]))
179
+
180
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay / lr, amsgrad=amsgrad)
181
+ super(WDAdam, self).__init__(params, defaults)
182
+
183
+ def __setstate__(self, state):
184
+ super(WDAdam, self).__setstate__(state)
185
+ for group in self.param_groups:
186
+ group.setdefault('amsgrad', False)
187
+
188
+ def step(self, closure=None):
189
+ """ Performs a single optimization step.
190
+
191
+ Arguments:
192
+ closure (callable, optional): A closure that reevaluates the model
193
+ and returns the loss.
194
+ """
195
+ loss = None
196
+ if closure is not None:
197
+ loss = closure()
198
+
199
+ for group in self.param_groups:
200
+ for p in group['params']:
201
+ if p.grad is None:
202
+ continue
203
+
204
+ grad = p.grad.data
205
+ if grad.is_sparse:
206
+ raise RuntimeError('Adam does not support sparse gradients')
207
+ amsgrad = group['amsgrad']
208
+
209
+ # State initialization
210
+ state = self.state[p]
211
+ if len(state) == 0:
212
+ state['step'] = 0
213
+ # Exponential moving average of gradient values
214
+ state['exp_avg'] = torch.zeros_like(p.data)
215
+ # Exponential moving average of squared gradient values
216
+ state['exp_avg_sq'] = torch.zeros_like(p.data)
217
+ # Maintains max of all exp. moving avg. of sq. grad. values
218
+ if amsgrad:
219
+ state['max_exp_avg_sq'] = torch.zeros_like(p.data)
220
+
221
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
222
+ if amsgrad:
223
+ max_exp_avg_sq = state['max_exp_avg_sq']
224
+ beta1, beta2 = group['betas']
225
+ state['step'] += 1
226
+
227
+ # Decay the first and second moment running average coefficient
228
+ exp_avg.mul_(beta1).add_(1 - beta1, grad)
229
+ exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
230
+ if amsgrad:
231
+ # Maintains the maximum of all 2nd moment running avg. till now
232
+ torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
233
+ # Use the max. for normalizing running avg. of gradient
234
+ denom = max_exp_avg_sq.sqrt().add_(group['eps'])
235
+ else:
236
+ denom = exp_avg_sq.sqrt().add_(group['eps'])
237
+
238
+ bias_correction1 = 1 - beta1 ** state['step']
239
+ bias_correction2 = 1 - beta2 ** state['step']
240
+ step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
241
+
242
+ if group['weight_decay'] != 0:
243
+ p.data.add_(-group['weight_decay'] * group['lr'], p.data)
244
+
245
+ p.data.addcdiv_(-step_size, exp_avg, denom)
246
+
247
+ return loss
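
A minimal sketch of one `WDAdam` update. Two things worth noting from the code above: the decay is decoupled (applied straight to `p.data`, in the spirit of AdamW) after being stored as `weight_decay / lr`, and the tensor-method overloads used (`add_(scalar, tensor)` and friends) date from pre-1.5 PyTorch, so a contemporary PyTorch version may be needed to run it as written.

```python
import torch
from torchtask.nn.optimizer import WDAdam  # assumed module path

w = torch.nn.Parameter(torch.ones(3))
opt = WDAdam([w], lr=1e-2, betas=(0.9, 0.999),
             weight_decay=1e-4, amsgrad=True)

loss = (w ** 2).sum()   # gradient is 2 * w
loss.backward()
opt.step()              # Adam step plus decoupled weight decay
opt.zero_grad()
print(w.data)           # all entries nudged toward zero
```
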
harmonizer/src/train/torchtask/requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ numpy
2
+ scipy
3
+ Pillow
4
+ pyyaml
5
+ opencv-python
harmonizer/src/train/torchtask/runner.py ADDED
@@ -0,0 +1,33 @@
1
+ import sys
2
+ import argparse
3
+
4
+
5
+ from torchtask.utils import cmd
6
+ from torchtask.nn import optimizer, lrer
7
+ from torchtask.nn.func import pytorch_support
8
+
9
+
10
+ def create_parser():
11
+ parser = argparse.ArgumentParser(description='TorchTask Script Parser')
12
+
13
+ optimizer.add_parser_arguments(parser)
14
+ lrer.add_parser_arguments(parser)
15
+
16
+ return parser
17
+
18
+
19
+ def run_script(config, proxy_file, proxy_class):
20
+ # TorchTask requires PyTorch >= 1.0.0
21
+ pytorch_support(required_version='1.0.0', info_str='TorchTask')
22
+
23
+ # help information
24
+ if len(sys.argv) > 1 and sys.argv[1] in ['help', '--help', 'h', '-h']:
25
+ config['h'] = True
26
+
27
+ # create parser and parse args from config
28
+ parser = create_parser()
29
+ proxy_file.add_parser_arguments(parser)
30
+ args = cmd.parse_args(parser, config)
31
+
32
+ task_proxy = proxy_class(args)
33
+ task_proxy.run()
harmonizer/src/train/torchtask/template/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ from . import func as func_template
2
+ from . import data as data_template
3
+ from . import model as model_template
4
+ from . import criterion as criterion_template
5
+ from . import proxy as proxy_template
6
+ from . import trainer as trainer_template
7
+
8
+
9
+ __all__ = [
10
+ 'func_template',
11
+ 'data_template',
12
+ 'model_template',
13
+ 'criterion_template',
14
+ 'proxy_template',
15
+ 'trainer_template',
16
+ ]