SankarSrin vivym committed on
Commit
36239b8
0 Parent(s):

Duplicate from vivym/image-matting-app


Co-authored-by: Ming Yang <vivym@users.noreply.huggingface.co>

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +35 -0
  2. .gitignore +194 -0
  3. README.md +14 -0
  4. app.py +172 -0
  5. configs/modnet-hrnet_w18.yml +5 -0
  6. configs/modnet-mobilenetv2.yml +47 -0
  7. configs/modnet-resnet50_vd.yml +5 -0
  8. configs/ppmatting-1024.yml +29 -0
  9. configs/ppmatting-2048.yml +54 -0
  10. configs/ppmatting-512.yml +44 -0
  11. configs/ppmatting-hrnet_w48-composition.yml +7 -0
  12. configs/ppmatting-hrnet_w48-distinctions.yml +55 -0
  13. models/modnet-hrnet_w18.pdparams +3 -0
  14. models/modnet-mobilenetv2.pdparams +3 -0
  15. models/modnet-resnet50_vd.pdparams +3 -0
  16. models/ppmatting-1024.pdparams +3 -0
  17. models/ppmatting-2048.pdparams +3 -0
  18. models/ppmatting-512.pdparams +3 -0
  19. ppmatting/__init__.py +1 -0
  20. ppmatting/core/__init__.py +4 -0
  21. ppmatting/core/predict.py +58 -0
  22. ppmatting/core/train.py +315 -0
  23. ppmatting/core/val.py +162 -0
  24. ppmatting/core/val_ml.py +162 -0
  25. ppmatting/datasets/__init__.py +17 -0
  26. ppmatting/datasets/composition_1k.py +31 -0
  27. ppmatting/datasets/distinctions_646.py +31 -0
  28. ppmatting/datasets/matting_dataset.py +251 -0
  29. ppmatting/metrics/__init__.py +3 -0
  30. ppmatting/metrics/metric.py +278 -0
  31. ppmatting/ml/__init__.py +1 -0
  32. ppmatting/ml/methods.py +97 -0
  33. ppmatting/models/__init__.py +7 -0
  34. ppmatting/models/backbone/__init__.py +5 -0
  35. ppmatting/models/backbone/gca_enc.py +395 -0
  36. ppmatting/models/backbone/hrnet.py +835 -0
  37. ppmatting/models/backbone/mobilenet_v2.py +242 -0
  38. ppmatting/models/backbone/resnet_vd.py +368 -0
  39. ppmatting/models/backbone/vgg.py +166 -0
  40. ppmatting/models/dim.py +208 -0
  41. ppmatting/models/gca.py +305 -0
  42. ppmatting/models/human_matting.py +454 -0
  43. ppmatting/models/layers/__init__.py +15 -0
  44. ppmatting/models/layers/gca_module.py +211 -0
  45. ppmatting/models/losses/__init__.py +1 -0
  46. ppmatting/models/losses/loss.py +163 -0
  47. ppmatting/models/modnet.py +494 -0
  48. ppmatting/models/ppmatting.py +338 -0
  49. ppmatting/transforms/__init__.py +1 -0
  50. ppmatting/transforms/transforms.py +791 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.pdparams filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,194 @@
+.DS_Store
+
+/images
+
+# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python
+# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python
+
+### Python ###
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+### Python Patch ###
+# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
+poetry.toml
+
+
+### VisualStudioCode ###
+.vscode/*
+!.vscode/settings.json
+!.vscode/tasks.json
+!.vscode/launch.json
+!.vscode/extensions.json
+!.vscode/*.code-snippets
+
+# Local History for Visual Studio Code
+.history/
+
+# Built Visual Studio Code Extensions
+*.vsix
+
+### VisualStudioCode Patch ###
+# Ignore all local history of files
+.history
+.ionide
+
+# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python
README.md ADDED
@@ -0,0 +1,14 @@
+---
+title: Image Matting App
+emoji: 🐨
+colorFrom: yellow
+colorTo: gray
+sdk: gradio
+sdk_version: 3.11.0
+app_file: app.py
+pinned: false
+license: mit
+duplicated_from: vivym/image-matting-app
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,172 @@
+from hashlib import sha1
+from pathlib import Path
+
+import cv2
+import gradio as gr
+import numpy as np
+from PIL import Image
+
+from paddleseg.cvlibs import manager, Config
+from paddleseg.utils import load_entire_model
+
+manager.BACKBONES._components_dict.clear()
+manager.TRANSFORMS._components_dict.clear()
+
+import ppmatting as ppmatting
+from ppmatting.core import predict
+from ppmatting.utils import estimate_foreground_ml
+
+model_names = [
+    "modnet-mobilenetv2",
+    "ppmatting-512",
+    "ppmatting-1024",
+    "ppmatting-2048",
+    "modnet-hrnet_w18",
+    "modnet-resnet50_vd",
+]
+model_dict = {
+    name: None
+    for name in model_names
+}
+
+last_result = {
+    "cache_key": None,
+    "algorithm": None,
+}
+
+
+def image_matting(
+    image: np.ndarray,
+    result_type: str,
+    bg_color: str,
+    algorithm: str,
+    morph_op: str,
+    morph_op_factor: float,
+) -> np.ndarray:
+    image = np.ascontiguousarray(image)
+    cache_key = sha1(image).hexdigest()
+    if cache_key == last_result["cache_key"] and algorithm == last_result["algorithm"]:
+        alpha = last_result["alpha"]
+    else:
+        cfg = Config(f"configs/{algorithm}.yml")
+        if model_dict[algorithm] is not None:
+            model = model_dict[algorithm]
+        else:
+            model = cfg.model
+            load_entire_model(model, f"models/{algorithm}.pdparams")
+            model.eval()
+            model_dict[algorithm] = model
+
+        transforms = ppmatting.transforms.Compose(cfg.val_transforms)
+
+        alpha = predict(
+            model,
+            transforms=transforms,
+            image=image,
+        )
+        last_result["cache_key"] = cache_key
+        last_result["algorithm"] = algorithm
+        last_result["alpha"] = alpha
+
+    alpha = (alpha * 255).astype(np.uint8)
+    kernel = np.ones((5, 5), np.uint8)
+    if morph_op == "Dilate":
+        alpha = cv2.dilate(alpha, kernel, iterations=int(morph_op_factor))
+    else:
+        alpha = cv2.erode(alpha, kernel, iterations=int(morph_op_factor))
+    alpha = (alpha / 255).astype(np.float32)
+
+    image = (image / 255.0).astype("float32")
+    fg = estimate_foreground_ml(image, alpha)
+
+    if result_type == "Remove BG":
+        result = np.concatenate((fg, alpha[:, :, None]), axis=-1)
+    elif result_type == "Replace BG":
+        bg_r = int(bg_color[1:3], base=16)
+        bg_g = int(bg_color[3:5], base=16)
+        bg_b = int(bg_color[5:7], base=16)
+
+        bg = np.zeros_like(fg)
+        bg[:, :, 0] = bg_r / 255.
+        bg[:, :, 1] = bg_g / 255.
+        bg[:, :, 2] = bg_b / 255.
+
+        result = alpha[:, :, None] * fg + (1 - alpha[:, :, None]) * bg
+        result = np.clip(result, 0, 1)
+    else:
+        result = alpha
+
+    return result
+
+
+def main():
+    with gr.Blocks() as app:
+        gr.Markdown("Image Matting Powered By AI")
+
+        with gr.Row(variant="panel"):
+            image_input = gr.Image()
+            image_output = gr.Image()
+
+        with gr.Row(variant="panel"):
+            result_type = gr.Radio(
+                label="Mode",
+                show_label=True,
+                choices=[
+                    "Remove BG",
+                    "Replace BG",
+                    "Generate Mask",
+                ],
+                value="Remove BG",
+            )
+            bg_color = gr.ColorPicker(
+                label="BG Color",
+                show_label=True,
+                value="#000000",
+            )
+            algorithm = gr.Dropdown(
+                label="Algorithm",
+                show_label=True,
+                choices=model_names,
+                value="modnet-hrnet_w18",
+            )
+
+        with gr.Row(variant="panel"):
+            morph_op = gr.Radio(
+                label="Post-process",
+                show_label=True,
+                choices=[
+                    "Dilate",
+                    "Erode",
+                ],
+                value="Dilate",
+            )
+
+            morph_op_factor = gr.Slider(
+                label="Factor",
+                show_label=True,
+                minimum=0,
+                maximum=20,
+                value=0,
+                step=1,
+            )
+
+        run_button = gr.Button("Run")
+
+        run_button.click(
+            image_matting,
+            inputs=[
+                image_input,
+                result_type,
+                bg_color,
+                algorithm,
+                morph_op,
+                morph_op_factor,
+            ],
+            outputs=image_output,
+        )
+
+    app.launch()
+
+
+if __name__ == "__main__":
+    main()
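
The Gradio UI above is a thin wrapper around `image_matting`, which can also be called directly. A minimal sketch (assuming the `configs/` and `models/` directories from this commit are on the working path; `input.jpg` is a hypothetical local photo):

import numpy as np
from PIL import Image

from app import image_matting  # the function defined in app.py above

img = np.array(Image.open("input.jpg").convert("RGB"))
result = image_matting(
    img,
    result_type="Remove BG",
    bg_color="#000000",           # ignored unless result_type == "Replace BG"
    algorithm="modnet-hrnet_w18",
    morph_op="Dilate",
    morph_op_factor=0,            # 0 iterations leaves the alpha untouched
)
# "Remove BG" yields float RGBA in [0, 1]; scale to 8-bit to save
Image.fromarray((result * 255).astype(np.uint8)).save("output.png")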
configs/modnet-hrnet_w18.yml ADDED
@@ -0,0 +1,5 @@
+_base_: modnet-mobilenetv2.yml
+model:
+  backbone:
+    type: HRNet_W18
+    # pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz
configs/modnet-mobilenetv2.yml ADDED
@@ -0,0 +1,47 @@
+batch_size: 16
+iters: 100000
+
+train_dataset:
+  type: MattingDataset
+  dataset_root: data/PPM-100
+  train_file: train.txt
+  transforms:
+    - type: LoadImages
+    - type: RandomCrop
+      crop_size: [512, 512]
+    - type: RandomDistort
+    - type: RandomBlur
+    - type: RandomHorizontalFlip
+    - type: Normalize
+  mode: train
+
+val_dataset:
+  type: MattingDataset
+  dataset_root: data/PPM-100
+  val_file: val.txt
+  transforms:
+    - type: LoadImages
+    - type: ResizeByShort
+      short_size: 512
+    - type: ResizeToIntMult
+      mult_int: 32
+    - type: Normalize
+  mode: val
+  get_trimap: False
+
+model:
+  type: MODNet
+  backbone:
+    type: MobileNetV2
+    # pretrained: https://paddleseg.bj.bcebos.com/matting/models/MobileNetV2_pretrained/model.pdparams
+  pretrained: Null
+
+optimizer:
+  type: sgd
+  momentum: 0.9
+  weight_decay: 4.0e-5
+
+lr_scheduler:
+  type: PiecewiseDecay
+  boundaries: [40000, 80000]
+  values: [0.02, 0.002, 0.0002]
configs/modnet-resnet50_vd.yml ADDED
@@ -0,0 +1,5 @@
+_base_: modnet-mobilenetv2.yml
+model:
+  backbone:
+    type: ResNet50_vd
+    # pretrained: https://bj.bcebos.com/paddleseg/dygraph/resnet50_vd_ssld_v2.tar.gz
configs/ppmatting-1024.yml ADDED
@@ -0,0 +1,29 @@
+_base_: 'ppmatting-hrnet_w18-human_512.yml'
+
+
+train_dataset:
+  transforms:
+    - type: LoadImages
+    - type: LimitShort
+      max_short: 1024
+    - type: RandomCrop
+      crop_size: [1024, 1024]
+    - type: RandomDistort
+    - type: RandomBlur
+      prob: 0.1
+    - type: RandomNoise
+      prob: 0.5
+    - type: RandomReJpeg
+      prob: 0.2
+    - type: RandomHorizontalFlip
+    - type: Normalize
+
+val_dataset:
+  transforms:
+    - type: LoadImages
+    - type: LimitShort
+      max_short: 1024
+    - type: ResizeToIntMult
+      mult_int: 32
+    - type: Normalize
+
configs/ppmatting-2048.yml ADDED
@@ -0,0 +1,54 @@
+batch_size: 4
+iters: 50000
+
+train_dataset:
+  type: MattingDataset
+  dataset_root: data/PPM-100
+  train_file: train.txt
+  transforms:
+    - type: LoadImages
+    - type: RandomResize
+      size: [2048, 2048]
+      scale: [0.3, 1.5]
+    - type: RandomCrop
+      crop_size: [2048, 2048]
+    - type: RandomDistort
+    - type: RandomBlur
+      prob: 0.1
+    - type: RandomHorizontalFlip
+    - type: Padding
+      target_size: [2048, 2048]
+    - type: Normalize
+  mode: train
+
+val_dataset:
+  type: MattingDataset
+  dataset_root: data/PPM-100
+  val_file: val.txt
+  transforms:
+    - type: LoadImages
+    - type: ResizeByShort
+      short_size: 2048
+    - type: ResizeToIntMult
+      mult_int: 128
+    - type: Normalize
+  mode: val
+  get_trimap: False
+
+model:
+  type: HumanMatting
+  backbone:
+    type: ResNet34_vd
+    # pretrained: https://paddleseg.bj.bcebos.com/matting/models/ResNet34_vd_pretrained/model.pdparams
+  pretrained: Null
+  if_refine: True
+
+optimizer:
+  type: sgd
+  momentum: 0.9
+  weight_decay: 4.0e-5
+
+lr_scheduler:
+  type: PiecewiseDecay
+  boundaries: [30000, 40000]
+  values: [0.001, 0.0001, 0.00001]
configs/ppmatting-512.yml ADDED
@@ -0,0 +1,44 @@
+_base_: 'ppmatting-hrnet_w48-distinctions.yml'
+
+batch_size: 4
+iters: 200000
+
+train_dataset:
+  type: MattingDataset
+  dataset_root: data/PPM-100
+  train_file: train.txt
+  transforms:
+    - type: LoadImages
+    - type: LimitShort
+      max_short: 512
+    - type: RandomCrop
+      crop_size: [512, 512]
+    - type: RandomDistort
+    - type: RandomBlur
+      prob: 0.1
+    - type: RandomNoise
+      prob: 0.5
+    - type: RandomReJpeg
+      prob: 0.2
+    - type: RandomHorizontalFlip
+    - type: Normalize
+  mode: train
+
+val_dataset:
+  type: MattingDataset
+  dataset_root: data/PPM-100
+  val_file: val.txt
+  transforms:
+    - type: LoadImages
+    - type: LimitShort
+      max_short: 512
+    - type: ResizeToIntMult
+      mult_int: 32
+    - type: Normalize
+  mode: val
+  get_trimap: False
+
+model:
+  backbone:
+    type: HRNet_W18
+    # pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz
configs/ppmatting-hrnet_w48-composition.yml ADDED
@@ -0,0 +1,7 @@
+_base_: 'ppmatting-hrnet_w48-distinctions.yml'
+
+train_dataset:
+  dataset_root: data/matting/Composition-1k
+
+val_dataset:
+  dataset_root: data/matting/Composition-1k
configs/ppmatting-hrnet_w48-distinctions.yml ADDED
@@ -0,0 +1,55 @@
+batch_size: 4
+iters: 300000
+
+train_dataset:
+  type: MattingDataset
+  dataset_root: data/matting/Distinctions-646
+  train_file: train.txt
+  transforms:
+    - type: LoadImages
+    - type: Padding
+      target_size: [512, 512]
+    - type: RandomCrop
+      crop_size: [[512, 512], [640, 640], [800, 800]]
+    - type: Resize
+      target_size: [512, 512]
+    - type: RandomDistort
+    - type: RandomBlur
+      prob: 0.1
+    - type: RandomHorizontalFlip
+    - type: Normalize
+  mode: train
+  separator: '|'
+
+val_dataset:
+  type: MattingDataset
+  dataset_root: data/matting/Distinctions-646
+  val_file: val.txt
+  transforms:
+    - type: LoadImages
+    - type: LimitShort
+      max_short: 1536
+    - type: ResizeToIntMult
+      mult_int: 32
+    - type: Normalize
+  mode: val
+  get_trimap: False
+  separator: '|'
+
+model:
+  type: PPMatting
+  backbone:
+    type: HRNet_W48
+    # pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w48_ssld.tar.gz
+  pretrained: Null
+
+optimizer:
+  type: sgd
+  momentum: 0.9
+  weight_decay: 4.0e-5
+
+lr_scheduler:
+  type: PolynomialDecay
+  learning_rate: 0.01
+  end_lr: 0
+  power: 0.9
models/modnet-hrnet_w18.pdparams ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:02863c8069c11367cdd7d25469ed66d133e9b835fee1f6adc76086eb33c83ac8
+size 41174502
models/modnet-mobilenetv2.pdparams ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3dbec3ca48dae927354efabd5b9d35e9af9998caf91e544c606af8589ad0528a
+size 26143420
models/modnet-resnet50_vd.pdparams ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:77568bbe3120b2490b1167df7c137402b3d3513617f5e69306e7bafd3d9f525e
+size 368802825
models/ppmatting-1024.pdparams ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f8e8db1e8b20a62f24b5ffae4f7e8f4a89d0db11169647967fecf2c3d17c0f99
+size 98439023
models/ppmatting-2048.pdparams ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7b13d1a1284d61d087cb0dd5e1d02178754053f7a03fd456484c77719b2e3a97
+size 255754333
models/ppmatting-512.pdparams ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0121d1494a4bad8a620e07a935b7bd97374f121ca0f48ba96b56df2972b0e054
+size 98439023
ppmatting/__init__.py ADDED
@@ -0,0 +1 @@
+from . import ml, metrics, transforms, datasets, models
ppmatting/core/__init__.py ADDED
@@ -0,0 +1,4 @@
+from .val import evaluate
+from .val_ml import evaluate_ml
+from .train import train
+from .predict import predict
ppmatting/core/predict.py ADDED
@@ -0,0 +1,58 @@
+from typing import Optional
+
+import numpy as np
+import paddle
+import paddle.nn.functional as F
+
+
+def reverse_transform(alpha, trans_info):
+    """Recover the prediction to the original shape."""
+    for item in trans_info[::-1]:
+        if item[0] == "resize":
+            h, w = item[1][0], item[1][1]
+            alpha = F.interpolate(alpha, [h, w], mode="bilinear")
+        elif item[0] == "padding":
+            h, w = item[1][0], item[1][1]
+            alpha = alpha[:, :, 0:h, 0:w]
+        else:
+            raise Exception(f"Unexpected info '{item[0]}' in im_info")
+
+    return alpha
+
+
+def preprocess(img, transforms, trimap=None):
+    data = {}
+    data["img"] = img
+    if trimap is not None:
+        data["trimap"] = trimap
+        data["gt_fields"] = ["trimap"]
+    data["trans_info"] = []
+    data = transforms(data)
+    data["img"] = paddle.to_tensor(data["img"])
+    data["img"] = data["img"].unsqueeze(0)
+    if trimap is not None:
+        data["trimap"] = paddle.to_tensor(data["trimap"])
+        data["trimap"] = data["trimap"].unsqueeze((0, 1))
+
+    return data
+
+
+def predict(
+    model,
+    transforms,
+    image: np.ndarray,
+    trimap: Optional[np.ndarray] = None,
+):
+    with paddle.no_grad():
+        data = preprocess(img=image, transforms=transforms, trimap=trimap)  # forward the optional trimap
+
+        alpha = model(data)
+
+        alpha = reverse_transform(alpha, data["trans_info"])
+        alpha = alpha.numpy().squeeze()
+
+    if trimap is not None:
+        alpha[trimap == 0] = 0
+        alpha[trimap == 255] = 1.
+
+    return alpha
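
For context, `reverse_transform` undoes the shape changes that the validation transforms record in `trans_info`, walking the list in reverse order. A minimal sketch with a dummy prediction (the `trans_info` entry here is hypothetical):

import paddle
from ppmatting.core.predict import reverse_transform

alpha = paddle.rand([1, 1, 512, 512])    # N, C, H, W network output
trans_info = [("resize", (720, 1280))]   # e.g. the input was resized down from 720x1280
restored = reverse_transform(alpha, trans_info)
print(restored.shape)                    # [1, 1, 720, 1280]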
ppmatting/core/train.py ADDED
@@ -0,0 +1,315 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+from collections import deque, defaultdict
+import pickle
+import shutil
+
+import numpy as np
+import paddle
+import paddle.nn.functional as F
+from paddleseg.utils import TimeAverager, calculate_eta, resume, logger
+
+from .val import evaluate
+
+
+def visual_in_training(log_writer, vis_dict, step):
+    """
+    Visualize in VisualDL.
+
+    Args:
+        log_writer (LogWriter): The log writer of VisualDL.
+        vis_dict (dict): Dict of tensors. The shape of each tensor is (C, H, W).
+    """
+    for key, value in vis_dict.items():
+        value_shape = value.shape
+        if value_shape[0] not in [1, 3]:
+            value = value[0]
+            value = value.unsqueeze(0)
+        value = paddle.transpose(value, (1, 2, 0))
+        min_v = paddle.min(value)
+        max_v = paddle.max(value)
+        if (min_v > 0) and (max_v < 1):
+            value = value * 255
+        elif (min_v < 0 and min_v >= -1) and (max_v <= 1):
+            value = (1 + value) / 2 * 255
+        else:
+            value = (value - min_v) / (max_v - min_v) * 255
+
+        value = value.astype('uint8')
+        value = value.numpy()
+        log_writer.add_image(tag=key, img=value, step=step)
+
+
+def save_best(best_model_dir, metrics_data, iter):
+    with open(os.path.join(best_model_dir, 'best_metrics.txt'), 'w') as f:
+        for key, value in metrics_data.items():
+            line = key + ' ' + str(value) + '\n'
+            f.write(line)
+        f.write('iter' + ' ' + str(iter) + '\n')
+
+
+def get_best(best_file, metrics, resume_model=None):
+    '''Get the best metrics and iter from file.'''
+    best_metrics_data = {}
+    if os.path.exists(best_file) and (resume_model is not None):
+        values = []
+        with open(best_file, 'r') as f:
+            lines = f.readlines()
+            for line in lines:
+                line = line.strip()
+                key, value = line.split(' ')
+                best_metrics_data[key] = eval(value)
+                if key == 'iter':
+                    best_iter = eval(value)
+    else:
+        for key in metrics:
+            best_metrics_data[key] = np.inf
+        best_iter = -1
+    return best_metrics_data, best_iter
+
+
+def train(model,
+          train_dataset,
+          val_dataset=None,
+          optimizer=None,
+          save_dir='output',
+          iters=10000,
+          batch_size=2,
+          resume_model=None,
+          save_interval=1000,
+          log_iters=10,
+          log_image_iters=1000,
+          num_workers=0,
+          use_vdl=False,
+          losses=None,
+          keep_checkpoint_max=5,
+          eval_begin_iters=None,
+          metrics='sad'):
+    """
+    Launch training.
+
+    Args:
+        model (nn.Layer): A matting model.
+        train_dataset (paddle.io.Dataset): Used to read and process training datasets.
+        val_dataset (paddle.io.Dataset, optional): Used to read and process validation datasets.
+        optimizer (paddle.optimizer.Optimizer): The optimizer.
+        save_dir (str, optional): The directory for saving the model snapshot. Default: 'output'.
+        iters (int, optional): How many iters to train the model. Default: 10000.
+        batch_size (int, optional): Mini batch size of one GPU or CPU. Default: 2.
+        resume_model (str, optional): The path of the model to resume from.
+        save_interval (int, optional): How many iters between model snapshots during training. Default: 1000.
+        log_iters (int, optional): Display logging information every log_iters. Default: 10.
+        log_image_iters (int, optional): Log images to VisualDL every log_image_iters. Default: 1000.
+        num_workers (int, optional): Num workers for the data loader. Default: 0.
+        use_vdl (bool, optional): Whether to record the data to VisualDL during training. Default: False.
+        losses (dict, optional): A dict of losses; refer to the loss function of the model for details. Default: None.
+        keep_checkpoint_max (int, optional): Maximum number of checkpoints to save. Default: 5.
+        eval_begin_iters (int): The iteration at which evaluation begins. If None, evaluation begins at iters/2. Default: None.
+        metrics (str|list, optional): The metrics to evaluate; any combination of ("sad", "mse", "grad", "conn").
+    """
+    model.train()
+    nranks = paddle.distributed.ParallelEnv().nranks
+    local_rank = paddle.distributed.ParallelEnv().local_rank
+
+    start_iter = 0
+    if resume_model is not None:
+        start_iter = resume(model, optimizer, resume_model)
+
+    if not os.path.isdir(save_dir):
+        if os.path.exists(save_dir):
+            os.remove(save_dir)
+        os.makedirs(save_dir)
+
+    if nranks > 1:
+        # Initialize the parallel environment if not done already.
+        if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized():
+            paddle.distributed.init_parallel_env()
+        ddp_model = paddle.DataParallel(model)
+    else:
+        ddp_model = paddle.DataParallel(model)
+
+    batch_sampler = paddle.io.DistributedBatchSampler(
+        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
+
+    loader = paddle.io.DataLoader(
+        train_dataset,
+        batch_sampler=batch_sampler,
+        num_workers=num_workers,
+        return_list=True)
+
+    if use_vdl:
+        from visualdl import LogWriter
+        log_writer = LogWriter(save_dir)
+
+    if isinstance(metrics, str):
+        metrics = [metrics]
+    elif not isinstance(metrics, list):
+        metrics = ['sad']
+    best_metrics_data, best_iter = get_best(
+        os.path.join(save_dir, 'best_model', 'best_metrics.txt'),
+        metrics,
+        resume_model=resume_model)
+    avg_loss = defaultdict(float)
+    iters_per_epoch = len(batch_sampler)
+    reader_cost_averager = TimeAverager()
+    batch_cost_averager = TimeAverager()
+    save_models = deque()
+    batch_start = time.time()
+
+    iter = start_iter
+    while iter < iters:
+        for data in loader:
+            iter += 1
+            if iter > iters:
+                break
+            reader_cost_averager.record(time.time() - batch_start)
+
+            logit_dict, loss_dict = ddp_model(data) if nranks > 1 else model(data)
+
+            loss_dict['all'].backward()
+
+            optimizer.step()
+            lr = optimizer.get_lr()
+            if isinstance(optimizer._learning_rate,
+                          paddle.optimizer.lr.LRScheduler):
+                optimizer._learning_rate.step()
+            model.clear_gradients()
+
+            for key, value in loss_dict.items():
+                avg_loss[key] += value.numpy()[0]
+            batch_cost_averager.record(
+                time.time() - batch_start, num_samples=batch_size)
+
+            if (iter) % log_iters == 0 and local_rank == 0:
+                for key, value in avg_loss.items():
+                    avg_loss[key] = value / log_iters
+                remain_iters = iters - iter
+                avg_train_batch_cost = batch_cost_averager.get_average()
+                avg_train_reader_cost = reader_cost_averager.get_average()
+                eta = calculate_eta(remain_iters, avg_train_batch_cost)
+                # Loss info
+                loss_str = ' ' * 26 + '\t[LOSSES]'
+                for key, value in avg_loss.items():
+                    if key != 'all':
+                        loss_str = loss_str + ' ' + key + '={:.4f}'.format(value)
+                logger.info(
+                    "[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.5f}, ips={:.4f} samples/sec | ETA {}\n{}\n"
+                    .format((iter - 1) // iters_per_epoch + 1, iter, iters,
+                            avg_loss['all'], lr, avg_train_batch_cost,
+                            avg_train_reader_cost,
+                            batch_cost_averager.get_ips_average(), eta, loss_str))
+                if use_vdl:
+                    for key, value in avg_loss.items():
+                        log_tag = 'Train/' + key
+                        log_writer.add_scalar(log_tag, value, iter)
+
+                    log_writer.add_scalar('Train/lr', lr, iter)
+                    log_writer.add_scalar('Train/batch_cost',
+                                          avg_train_batch_cost, iter)
+                    log_writer.add_scalar('Train/reader_cost',
+                                          avg_train_reader_cost, iter)
+                    if iter % log_image_iters == 0:
+                        vis_dict = {}
+                        # Ground truth
+                        vis_dict['ground truth/img'] = data['img'][0]
+                        for key in data['gt_fields']:
+                            key = key[0]
+                            vis_dict['/'.join(['ground truth', key])] = data[key][0]
+                        # Prediction
+                        for key, value in logit_dict.items():
+                            vis_dict['/'.join(['predict', key])] = logit_dict[key][0]
+                        visual_in_training(
+                            log_writer=log_writer, vis_dict=vis_dict, step=iter)
+
+                for key in avg_loss.keys():
+                    avg_loss[key] = 0.
+                reader_cost_averager.reset()
+                batch_cost_averager.reset()
+
+            # Save the model
+            if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
+                current_save_dir = os.path.join(save_dir, "iter_{}".format(iter))
+                if not os.path.isdir(current_save_dir):
+                    os.makedirs(current_save_dir)
+                paddle.save(model.state_dict(),
+                            os.path.join(current_save_dir, 'model.pdparams'))
+                paddle.save(optimizer.state_dict(),
+                            os.path.join(current_save_dir, 'model.pdopt'))
+                save_models.append(current_save_dir)
+                if len(save_models) > keep_checkpoint_max > 0:
+                    model_to_remove = save_models.popleft()
+                    shutil.rmtree(model_to_remove)
+
+            # Evaluate the model
+            if eval_begin_iters is None:
+                eval_begin_iters = iters // 2
+            if (iter % save_interval == 0 or iter == iters) and (
+                    val_dataset is not None
+            ) and local_rank == 0 and iter >= eval_begin_iters:
+                num_workers = 1 if num_workers > 0 else 0
+                metrics_data = evaluate(
+                    model,
+                    val_dataset,
+                    num_workers=num_workers,  # use at most one worker while evaluating mid-training
+                    print_detail=True,
+                    save_results=False,
+                    metrics=metrics)
+                model.train()
+
+            # Save the best model and add evaluation results to VisualDL
+            if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
+                if val_dataset is not None and iter >= eval_begin_iters:
+                    if metrics_data[metrics[0]] < best_metrics_data[metrics[0]]:
+                        best_iter = iter
+                        best_metrics_data = metrics_data.copy()
+                        best_model_dir = os.path.join(save_dir, "best_model")
+                        paddle.save(
+                            model.state_dict(),
+                            os.path.join(best_model_dir, 'model.pdparams'))
+                        save_best(best_model_dir, best_metrics_data, iter)
+
+                    show_list = []
+                    for key, value in best_metrics_data.items():
+                        show_list.append((key, value))
+                    log_str = '[EVAL] The model with the best validation {} ({:.4f}) was saved at iter {}.'.format(
+                        show_list[0][0], show_list[0][1], best_iter)
+                    if len(show_list) > 1:
+                        log_str += " While"
+                        for i in range(1, len(show_list)):
+                            log_str = log_str + ' {}: {:.4f},'.format(
+                                show_list[i][0], show_list[i][1])
+                        log_str = log_str[:-1]
+                    logger.info(log_str)
+
+                    if use_vdl:
+                        for key, value in metrics_data.items():
+                            log_writer.add_scalar('Evaluate/' + key, value, iter)
+
+            batch_start = time.time()
+
+    # Sleep for half a second to let the dataloader release resources.
+    time.sleep(0.5)
+    if use_vdl:
+        log_writer.close()
ppmatting/core/val.py ADDED
@@ -0,0 +1,162 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import cv2
+import numpy as np
+import time
+import paddle
+import paddle.nn.functional as F
+from paddleseg.utils import TimeAverager, calculate_eta, logger, progbar
+
+from ppmatting.metrics import metrics_class_dict
+
+np.set_printoptions(suppress=True)
+
+
+def save_alpha_pred(alpha, path):
+    """
+    Save the predicted alpha. The value of alpha is in the range [0, 255]; the shape should be [h, w].
+    """
+    dirname = os.path.dirname(path)
+    if not os.path.exists(dirname):
+        os.makedirs(dirname)
+
+    alpha = alpha.astype('uint8')
+    cv2.imwrite(path, alpha)
+
+
+def reverse_transform(alpha, trans_info):
+    """Recover the prediction to the original shape."""
+    for item in trans_info[::-1]:
+        if item[0][0] == 'resize':
+            h, w = item[1][0], item[1][1]
+            alpha = F.interpolate(alpha, [h, w], mode='bilinear')
+        elif item[0][0] == 'padding':
+            h, w = item[1][0], item[1][1]
+            alpha = alpha[:, :, 0:h, 0:w]
+        else:
+            raise Exception("Unexpected info '{}' in im_info".format(item[0]))
+    return alpha
+
+
+def evaluate(model,
+             eval_dataset,
+             num_workers=0,
+             print_detail=True,
+             save_dir='output/results',
+             save_results=True,
+             metrics='sad'):
+    model.eval()
+    nranks = paddle.distributed.ParallelEnv().nranks
+    local_rank = paddle.distributed.ParallelEnv().local_rank
+    if nranks > 1:
+        # Initialize the parallel environment if not done already.
+        if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized():
+            paddle.distributed.init_parallel_env()
+
+    loader = paddle.io.DataLoader(
+        eval_dataset,
+        batch_size=1,
+        drop_last=False,
+        num_workers=num_workers,
+        return_list=True)
+
+    total_iters = len(loader)
+    # Get metric instances and set up data saving
+    metrics_ins = {}
+    metrics_data = {}
+    if isinstance(metrics, str):
+        metrics = [metrics]
+    elif not isinstance(metrics, list):
+        metrics = ['sad']
+    for key in metrics:
+        key = key.lower()
+        metrics_ins[key] = metrics_class_dict[key]()
+        metrics_data[key] = None
+
+    if print_detail:
+        logger.info("Start evaluating (total_samples: {}, total_iters: {})...".
+                    format(len(eval_dataset), total_iters))
+    progbar_val = progbar.Progbar(
+        target=total_iters, verbose=1 if nranks < 2 else 2)
+    reader_cost_averager = TimeAverager()
+    batch_cost_averager = TimeAverager()
+    batch_start = time.time()
+
+    img_name = ''
+    i = 0
+    with paddle.no_grad():
+        for iter, data in enumerate(loader):
+            reader_cost_averager.record(time.time() - batch_start)
+            alpha_pred = model(data)
+
+            alpha_pred = reverse_transform(alpha_pred, data['trans_info'])
+            alpha_pred = alpha_pred.numpy()
+
+            alpha_gt = data['alpha'].numpy() * 255
+            trimap = data.get('ori_trimap')
+            if trimap is not None:
+                trimap = trimap.numpy().astype('uint8')
+            alpha_pred = np.round(alpha_pred * 255)
+            for key in metrics_ins.keys():
+                metrics_data[key] = metrics_ins[key].update(alpha_pred,
+                                                            alpha_gt, trimap)
+
+            if save_results:
+                alpha_pred_one = alpha_pred[0].squeeze()
+                if trimap is not None:
+                    trimap = trimap.squeeze().astype('uint8')
+                    alpha_pred_one[trimap == 255] = 255
+                    alpha_pred_one[trimap == 0] = 0
+
+                save_name = data['img_name'][0]
+                name, ext = os.path.splitext(save_name)
+                if save_name == img_name:
+                    save_name = name + '_' + str(i) + ext
+                    i += 1
+                else:
+                    img_name = save_name
+                    save_name = name + '_' + str(i) + ext
+                    i = 1
+
+                save_alpha_pred(alpha_pred_one,
+                                os.path.join(save_dir, save_name))
+
+            batch_cost_averager.record(
+                time.time() - batch_start, num_samples=len(alpha_gt))
+            batch_cost = batch_cost_averager.get_average()
+            reader_cost = reader_cost_averager.get_average()
+
+            if local_rank == 0 and print_detail:
+                show_list = [(k, v) for k, v in metrics_data.items()]
+                show_list = show_list + [('batch_cost', batch_cost),
+                                         ('reader cost', reader_cost)]
+                progbar_val.update(iter + 1, show_list)
+
+            reader_cost_averager.reset()
+            batch_cost_averager.reset()
+            batch_start = time.time()
+
+    for key in metrics_ins.keys():
+        metrics_data[key] = metrics_ins[key].evaluate()
+    log_str = '[EVAL] '
+    for key, value in metrics_data.items():
+        log_str = log_str + key + ': {:.4f}, '.format(value)
+    log_str = log_str[:-2]
+
+    logger.info(log_str)
+    return metrics_data
ppmatting/core/val_ml.py ADDED
@@ -0,0 +1,162 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import cv2
+import numpy as np
+import time
+import paddle
+import paddle.nn.functional as F
+from paddleseg.utils import TimeAverager, calculate_eta, logger, progbar
+
+from ppmatting.metrics import metric
+from pymatting.util.util import load_image, save_image, stack_images
+from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
+
+np.set_printoptions(suppress=True)
+
+
+def save_alpha_pred(alpha, path):
+    """
+    Save the predicted alpha. The value of alpha is in the range [0, 255]; the shape should be [h, w].
+    """
+    dirname = os.path.dirname(path)
+    if not os.path.exists(dirname):
+        os.makedirs(dirname)
+
+    alpha = alpha.astype('uint8')
+    cv2.imwrite(path, alpha)
+
+
+def reverse_transform(alpha, trans_info):
+    """Recover the prediction to the original shape."""
+    for item in trans_info[::-1]:
+        if item[0][0] == 'resize':
+            h, w = item[1][0].numpy()[0], item[1][1].numpy()[0]
+            alpha = cv2.resize(alpha, dsize=(w, h))
+        elif item[0][0] == 'padding':
+            h, w = item[1][0].numpy()[0], item[1][1].numpy()[0]
+            alpha = alpha[0:h, 0:w]
+        else:
+            raise Exception("Unexpected info '{}' in im_info".format(item[0]))
+    return alpha
+
+
+def evaluate_ml(model,
+                eval_dataset,
+                num_workers=0,
+                print_detail=True,
+                save_dir='output/results',
+                save_results=True):
+
+    loader = paddle.io.DataLoader(
+        eval_dataset,
+        batch_size=1,
+        drop_last=False,
+        num_workers=num_workers,
+        return_list=True)
+
+    total_iters = len(loader)
+    mse_metric = metric.MSE()
+    sad_metric = metric.SAD()
+    grad_metric = metric.Grad()
+    conn_metric = metric.Conn()
+
+    if print_detail:
+        logger.info("Start evaluating (total_samples: {}, total_iters: {})...".
+                    format(len(eval_dataset), total_iters))
+    progbar_val = progbar.Progbar(target=total_iters, verbose=1)
+    reader_cost_averager = TimeAverager()
+    batch_cost_averager = TimeAverager()
+    batch_start = time.time()
+
+    img_name = ''
+    i = 0
+    ignore_cnt = 0
+    for iter, data in enumerate(loader):
+
+        reader_cost_averager.record(time.time() - batch_start)
+
+        image_rgb_chw = data['img'].numpy()[0]
+        image_rgb_hwc = np.transpose(image_rgb_chw, (1, 2, 0))
+        trimap = data['trimap'].numpy().squeeze() / 255.0
+        image = image_rgb_hwc * 0.5 + 0.5  # reverse normalize (x/255 - mean) / std
+
+        is_fg = trimap >= 0.9
+        is_bg = trimap <= 0.1
+
+        if is_fg.sum() == 0 or is_bg.sum() == 0:
+            ignore_cnt += 1
+            logger.info(str(iter))
+            continue
+
+        alpha_pred = model(image, trimap)
+
+        alpha_pred = reverse_transform(alpha_pred, data['trans_info'])
+
+        alpha_gt = data['alpha'].numpy().squeeze() * 255
+
+        trimap = data['ori_trimap'].numpy().squeeze()
+
+        alpha_pred = np.round(alpha_pred * 255)
+        mse = mse_metric.update(alpha_pred, alpha_gt, trimap)
+        sad = sad_metric.update(alpha_pred, alpha_gt, trimap)
+        grad = grad_metric.update(alpha_pred, alpha_gt, trimap)
+        conn = conn_metric.update(alpha_pred, alpha_gt, trimap)
+
+        if sad > 1000:
+            print(data['img_name'][0])
+
+        if save_results:
+            alpha_pred_one = alpha_pred
+            alpha_pred_one[trimap == 255] = 255
+            alpha_pred_one[trimap == 0] = 0
+
+            save_name = data['img_name'][0]
+            name, ext = os.path.splitext(save_name)
+            if save_name == img_name:
+                save_name = name + '_' + str(i) + ext
+                i += 1
+            else:
+                img_name = save_name
+                save_name = name + '_' + str(0) + ext
+                i = 1
+            save_alpha_pred(alpha_pred_one, os.path.join(save_dir, save_name))
+
+        batch_cost_averager.record(
+            time.time() - batch_start, num_samples=len(alpha_gt))
+        batch_cost = batch_cost_averager.get_average()
+        reader_cost = reader_cost_averager.get_average()
+
+        if print_detail:
+            progbar_val.update(iter + 1,
+                               [('SAD', sad), ('MSE', mse), ('Grad', grad),
+                                ('Conn', conn), ('batch_cost', batch_cost),
+                                ('reader cost', reader_cost)])
+
+        reader_cost_averager.reset()
+        batch_cost_averager.reset()
+        batch_start = time.time()
+
+    mse = mse_metric.evaluate()
+    sad = sad_metric.evaluate()
+    grad = grad_metric.evaluate()
+    conn = conn_metric.evaluate()
+
+    logger.info('[EVAL] SAD: {:.4f}, MSE: {:.4f}, Grad: {:.4f}, Conn: {:.4f}'.
+                format(sad, mse, grad, conn))
+    logger.info('{}'.format(ignore_cnt))
+
+    return sad, mse, grad, conn
ppmatting/datasets/__init__.py ADDED
@@ -0,0 +1,17 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .matting_dataset import MattingDataset
+from .composition_1k import Composition1K
+from .distinctions_646 import Distinctions646
ppmatting/datasets/composition_1k.py ADDED
@@ -0,0 +1,31 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import math
+
+import cv2
+import numpy as np
+import random
+import paddle
+from paddleseg.cvlibs import manager
+
+import ppmatting.transforms as T
+from ppmatting.datasets.matting_dataset import MattingDataset
+
+
+@manager.DATASETS.add_component
+class Composition1K(MattingDataset):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
ppmatting/datasets/distinctions_646.py ADDED
@@ -0,0 +1,31 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import math
+
+import cv2
+import numpy as np
+import random
+import paddle
+from paddleseg.cvlibs import manager
+
+import ppmatting.transforms as T
+from ppmatting.datasets.matting_dataset import MattingDataset
+
+
+@manager.DATASETS.add_component
+class Distinctions646(MattingDataset):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
ppmatting/datasets/matting_dataset.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import math
17
+
18
+ import cv2
19
+ import numpy as np
20
+ import random
21
+ import paddle
22
+ from paddleseg.cvlibs import manager
23
+
24
+ import ppmatting.transforms as T
25
+
26
+
27
+ @manager.DATASETS.add_component
28
+ class MattingDataset(paddle.io.Dataset):
29
+ """
30
+ Pass in a dataset that conforms to the format.
31
+ matting_dataset/
32
+ |--bg/
33
+ |
34
+ |--train/
35
+ | |--fg/
36
+ | |--alpha/
37
+ |
38
+ |--val/
39
+ | |--fg/
40
+ | |--alpha/
41
+ | |--trimap/ (if existing)
42
+ |
43
+ |--train.txt
44
+ |
45
+ |--val.txt
46
+ See README.md for more information of dataset.
47
+
48
+ Args:
49
+ dataset_root(str): The root path of dataset.
50
+ transforms(list): Transforms for image.
51
+ mode (str, optional): which part of dataset to use. it is one of ('train', 'val', 'trainval'). Default: 'train'.
52
+ train_file (str|list, optional): File list is used to train. It should be `foreground_image.png background_image.png`
53
+ or `foreground_image.png`. It shold be provided if mode equal to 'train'. Default: None.
54
+ val_file (str|list, optional): File list is used to evaluation. It should be `foreground_image.png background_image.png`
55
+ or `foreground_image.png` or ``foreground_image.png background_image.png trimap_image.png`.
56
+ It shold be provided if mode equal to 'val'. Default: None.
57
+ get_trimap (bool, optional): Whether to get triamp. Default: True.
58
+ separator (str, optional): The separator of train_file or val_file. If file name contains ' ', '|' may be perfect. Default: ' '.
59
+ key_del (tuple|list, optional): The key which is not need will be delete to accellect data reader. Default: None.
60
+ if_rssn (bool, optional): Whether to use RSSN while Compositing image. Including denoise and blur. Default: False.
61
+ """
62
+
63
+ def __init__(self,
64
+ dataset_root,
65
+ transforms,
66
+ mode='train',
67
+ train_file=None,
68
+ val_file=None,
69
+ get_trimap=True,
70
+ separator=' ',
71
+ key_del=None,
72
+ if_rssn=False):
73
+ super().__init__()
74
+ self.dataset_root = dataset_root
75
+ self.transforms = T.Compose(transforms)
76
+ self.mode = mode
77
+ self.get_trimap = get_trimap
78
+ self.separator = separator
79
+ self.key_del = key_del
80
+ self.if_rssn = if_rssn
81
+
82
+ # check file
83
+ if mode == 'train' or mode == 'trainval':
84
+ if train_file is None:
85
+ raise ValueError(
86
+ "When `mode` is 'train' or 'trainval', `train_file must be provided!"
87
+ )
88
+ if isinstance(train_file, str):
89
+ train_file = [train_file]
90
+ file_list = train_file
91
+
92
+ if mode == 'val' or mode == 'trainval':
93
+ if val_file is None:
94
+ raise ValueError(
95
+ "When `mode` is 'val' or 'trainval', `val_file must be provided!"
96
+ )
97
+ if isinstance(val_file, str):
98
+ val_file = [val_file]
99
+ file_list = val_file
100
+
101
+ if mode == 'trainval':
102
+ file_list = train_file + val_file
103
+
104
+ # read file
105
+ self.fg_bg_list = []
106
+ for file in file_list:
107
+ file = os.path.join(dataset_root, file)
108
+ with open(file, 'r') as f:
109
+ lines = f.readlines()
110
+ for line in lines:
111
+ line = line.strip()
112
+ self.fg_bg_list.append(line)
113
+ if mode != 'val':
114
+ random.shuffle(self.fg_bg_list)
115
+
116
+ def __getitem__(self, idx):
117
+ data = {}
118
+ fg_bg_file = self.fg_bg_list[idx]
119
+ fg_bg_file = fg_bg_file.split(self.separator)
120
+ data['img_name'] = fg_bg_file[0] # using in save prediction results
121
+ fg_file = os.path.join(self.dataset_root, fg_bg_file[0])
122
+ alpha_file = fg_file.replace('/fg', '/alpha')
123
+ fg = cv2.imread(fg_file)
124
+ alpha = cv2.imread(alpha_file, 0)
125
+ data['alpha'] = alpha
126
+ data['gt_fields'] = []
127
+
128
+ # line is: fg [bg] [trimap]
129
+ if len(fg_bg_file) >= 2:
130
+ bg_file = os.path.join(self.dataset_root, fg_bg_file[1])
131
+ bg = cv2.imread(bg_file)
132
+ data['img'], data['fg'], data['bg'] = self.composite(fg, alpha, bg)
133
+ if self.mode in ['train', 'trainval']:
134
+ data['gt_fields'].append('fg')
135
+ data['gt_fields'].append('bg')
136
+ data['gt_fields'].append('alpha')
137
+ if len(fg_bg_file) == 3 and self.get_trimap:
138
+ if self.mode == 'val':
139
+ trimap_path = os.path.join(self.dataset_root, fg_bg_file[2])
140
+ if os.path.exists(trimap_path):
141
+ data['trimap'] = trimap_path
142
+ data['gt_fields'].append('trimap')
143
+ data['ori_trimap'] = cv2.imread(trimap_path, 0)
144
+ else:
145
+ raise FileNotFoundError(
146
+ 'trimap is not Found: {}'.format(fg_bg_file[2]))
147
+ else:
148
+ data['img'] = fg
149
+ if self.mode in ['train', 'trainval']:
150
+ data['fg'] = fg.copy()
151
+ data['bg'] = fg.copy()
152
+ data['gt_fields'].append('fg')
153
+ data['gt_fields'].append('bg')
154
+ data['gt_fields'].append('alpha')
155
+
156
+ data['trans_info'] = [] # Record shape change information
157
+
158
+ # Generate trimap from alpha if no trimap file provided
159
+ if self.get_trimap:
160
+ if 'trimap' not in data:
161
+ data['trimap'] = self.gen_trimap(
162
+ data['alpha'], mode=self.mode).astype('float32')
163
+ data['gt_fields'].append('trimap')
164
+ if self.mode == 'val':
165
+ data['ori_trimap'] = data['trimap'].copy()
166
+
167
+ # Delete keys that are not needed
168
+ if self.key_del is not None:
169
+ for key in self.key_del:
170
+ if key in data.keys():
171
+ data.pop(key)
172
+ if key in data['gt_fields']:
173
+ data['gt_fields'].remove(key)
174
+ data = self.transforms(data)
175
+
176
+ # During evaluation, gt should not be transformed.
177
+ if self.mode == 'val':
178
+ data['gt_fields'].append('alpha')
179
+
180
+ data['img'] = data['img'].astype('float32')
181
+ for key in data.get('gt_fields', []):
182
+ data[key] = data[key].astype('float32')
183
+
184
+ if 'trimap' in data:
185
+ data['trimap'] = data['trimap'][np.newaxis, :, :]
186
+ if 'ori_trimap' in data:
187
+ data['ori_trimap'] = data['ori_trimap'][np.newaxis, :, :]
188
+
189
+ data['alpha'] = data['alpha'][np.newaxis, :, :] / 255.
190
+
191
+ return data
192
+
193
+ def __len__(self):
194
+ return len(self.fg_bg_list)
195
+
196
+ def composite(self, fg, alpha, ori_bg):
197
+ if self.if_rssn:
198
+ if np.random.rand() < 0.5:
199
+ fg = cv2.fastNlMeansDenoisingColored(fg, None, 3, 3, 7, 21)
200
+ ori_bg = cv2.fastNlMeansDenoisingColored(ori_bg, None, 3, 3, 7,
201
+ 21)
202
+ if np.random.rand() < 0.5:
203
+ radius = np.random.choice([19, 29, 39, 49, 59])
204
+ ori_bg = cv2.GaussianBlur(ori_bg, (radius, radius), 0, 0)
205
+ fg_h, fg_w = fg.shape[:2]
206
+ ori_bg_h, ori_bg_w = ori_bg.shape[:2]
207
+
208
+ wratio = fg_w / ori_bg_w
209
+ hratio = fg_h / ori_bg_h
210
+ ratio = wratio if wratio > hratio else hratio
211
+
212
+ # Resize ori_bg if it is smaller than fg.
213
+ if ratio > 1:
214
+ resize_h = math.ceil(ori_bg_h * ratio)
215
+ resize_w = math.ceil(ori_bg_w * ratio)
216
+ bg = cv2.resize(
217
+ ori_bg, (resize_w, resize_h), interpolation=cv2.INTER_LINEAR)
218
+ else:
219
+ bg = ori_bg
220
+
221
+ bg = bg[0:fg_h, 0:fg_w, :]
222
+ alpha = alpha / 255
223
+ alpha = np.expand_dims(alpha, axis=2)
224
+ image = alpha * fg + (1 - alpha) * bg
225
+ image = image.astype(np.uint8)
226
+ return image, fg, bg
227
+
228
+ @staticmethod
229
+ def gen_trimap(alpha, mode='train', eval_kernel=7):
230
+ if mode == 'train':
231
+ k_size = random.choice(range(2, 5))
232
+ iterations = np.random.randint(5, 15)
233
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
234
+ (k_size, k_size))
235
+ dilated = cv2.dilate(alpha, kernel, iterations=iterations)
236
+ eroded = cv2.erode(alpha, kernel, iterations=iterations)
237
+ trimap = np.zeros(alpha.shape)
238
+ trimap.fill(128)
239
+ trimap[eroded > 254.5] = 255
240
+ trimap[dilated < 0.5] = 0
241
+ else:
242
+ k_size = eval_kernel
243
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
244
+ (k_size, k_size))
245
+ dilated = cv2.dilate(alpha, kernel)
246
+ trimap = np.zeros(alpha.shape)
247
+ trimap.fill(128)
248
+ trimap[alpha >= 250] = 255
249
+ trimap[dilated <= 5] = 0
250
+
251
+ return trimap
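A minimal usage sketch of the trimap generation above (assuming the enclosing class is `MattingDataset`, per the file name; the alpha matte here is synthetic and purely illustrative). In 'val' mode `gen_trimap` uses a fixed 7x7 elliptical kernel, while 'train' mode randomizes kernel size and dilation/erosion iterations:

import numpy as np

# Synthetic alpha matte: a solid square on an empty background.
alpha = np.zeros((64, 64), dtype=np.uint8)
alpha[16:48, 16:48] = 255

trimap = MattingDataset.gen_trimap(alpha, mode='val')
# The result only contains background (0), unknown (128), and foreground (255).
assert set(np.unique(trimap)) <= {0.0, 128.0, 255.0}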
ppmatting/metrics/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .metric import MSE, SAD, Grad, Conn
2
+
3
+ metrics_class_dict = {'sad': SAD, 'mse': MSE, 'grad': Grad, 'conn': Conn}
ppmatting/metrics/metric.py ADDED
@@ -0,0 +1,278 @@
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Grad and Conn refer to https://github.com/yucornetto/MGMatting/blob/main/code-base/utils/evaluate.py
16
+ # Output of `Grad` is slightly different from the MATLAB version provided by Adobe (less than 0.1%)
17
+ # Output of `Conn` is smaller than the MATLAB version (~5%, maybe MATLAB has a different algorithm)
18
+ # So do not report results calculated by these functions in your paper.
19
+ # Evaluate your inference with the MATLAB file `DIM_evaluation_code/evaluate.m`.
20
+
21
+ import cv2
22
+ import numpy as np
23
+ from scipy.ndimage import convolve
24
+ from scipy.special import gamma
25
+ from skimage.measure import label
26
+
27
+
28
+ class MSE:
29
+ """
30
+ Only the unknown region is calculated if a trimap is provided.
31
+ """
32
+
33
+ def __init__(self):
34
+ self.mse_diffs = 0
35
+ self.count = 0
36
+
37
+ def update(self, pred, gt, trimap=None):
38
+ """
39
+ Update the metric.
40
+ Args:
41
+ pred (np.ndarray): The value range is [0., 255.].
42
+ gt (np.ndarray): The value range is [0, 255].
43
+ trimap (np.ndarray, optional): The value is in {0, 128, 255}. Default: None.
44
+ """
45
+ if trimap is None:
46
+ trimap = np.ones_like(gt) * 128
47
+ if not (pred.shape == gt.shape == trimap.shape):
48
+ raise ValueError(
49
+ 'The shape of `pred`, `gt` and `trimap` should be equal. '
50
+ 'but they are {}, {} and {}'.format(pred.shape, gt.shape,
51
+ trimap.shape))
52
+ pred[trimap == 0] = 0
53
+ pred[trimap == 255] = 255
54
+
55
+ mask = trimap == 128
56
+ pixels = float(mask.sum())
57
+ pred = pred / 255.
58
+ gt = gt / 255.
59
+ diff = (pred - gt) * mask
60
+ mse_diff = (diff**2).sum() / pixels if pixels > 0 else 0
61
+
62
+ self.mse_diffs += mse_diff
63
+ self.count += 1
64
+
65
+ return mse_diff
66
+
67
+ def evaluate(self):
68
+ mse = self.mse_diffs / self.count if self.count > 0 else 0
69
+ return mse
70
+
71
+
72
+ class SAD:
73
+ """
74
+ Only the unknown region is calculated if a trimap is provided.
75
+ """
76
+
77
+ def __init__(self):
78
+ self.sad_diffs = 0
79
+ self.count = 0
80
+
81
+ def update(self, pred, gt, trimap=None):
82
+ """
83
+ Update the metric.
84
+ Args:
85
+ pred (np.ndarray): The value range is [0., 255.].
86
+ gt (np.ndarray): The value range is [0., 255.].
87
+ trimap (np.ndarray, optional): The value is in {0, 128, 255}. Default: None.
88
+ """
89
+ if trimap is None:
90
+ trimap = np.ones_like(gt) * 128
91
+ if not (pred.shape == gt.shape == trimap.shape):
92
+ raise ValueError(
93
+ 'The shape of `pred`, `gt` and `trimap` should be equal. '
94
+ 'but they are {}, {} and {}'.format(pred.shape, gt.shape,
95
+ trimap.shape))
96
+ pred[trimap == 0] = 0
97
+ pred[trimap == 255] = 255
98
+
99
+ mask = trimap == 128
100
+ pred = pred / 255.
101
+ gt = gt / 255.
102
+ diff = (pred - gt) * mask
103
+ sad_diff = (np.abs(diff)).sum()
104
+
105
+ sad_diff /= 1000
106
+ self.sad_diffs += sad_diff
107
+ self.count += 1
108
+
109
+ return sad_diff
110
+
111
+ def evaluate(self):
112
+ sad = self.sad_diffs / self.count if self.count > 0 else 0
113
+ return sad
114
+
115
+
116
+ class Grad:
117
+ """
118
+ Only the unknown region is calculated if a trimap is provided.
119
+ Refer to: https://github.com/open-mmlab/mmediting/blob/master/mmedit/core/evaluation/metrics.py
120
+ """
121
+
122
+ def __init__(self):
123
+ self.grad_diffs = 0
124
+ self.count = 0
125
+
126
+ def gaussian(self, x, sigma):
127
+ return np.exp(-x**2 / (2 * sigma**2)) / (sigma * np.sqrt(2 * np.pi))
128
+
129
+ def dgaussian(self, x, sigma):
130
+ return -x * self.gaussian(x, sigma) / sigma**2
131
+
132
+ def gauss_filter(self, sigma, epsilon=1e-2):
133
+ half_size = np.ceil(
134
+ sigma * np.sqrt(-2 * np.log(np.sqrt(2 * np.pi) * sigma * epsilon)))
135
+ size = int(2 * half_size + 1)
136
+
137
+ # create filter in x axis
138
+ filter_x = np.zeros((size, size))
139
+ for i in range(size):
140
+ for j in range(size):
141
+ filter_x[i, j] = self.gaussian(
142
+ i - half_size, sigma) * self.dgaussian(j - half_size, sigma)
143
+
144
+ # normalize filter
145
+ norm = np.sqrt((filter_x**2).sum())
146
+ filter_x = filter_x / norm
147
+ filter_y = np.transpose(filter_x)
148
+
149
+ return filter_x, filter_y
150
+
151
+ def gauss_gradient(self, img, sigma):
152
+ filter_x, filter_y = self.gauss_filter(sigma)
153
+ img_filtered_x = cv2.filter2D(
154
+ img, -1, filter_x, borderType=cv2.BORDER_REPLICATE)
155
+ img_filtered_y = cv2.filter2D(
156
+ img, -1, filter_y, borderType=cv2.BORDER_REPLICATE)
157
+ return np.sqrt(img_filtered_x**2 + img_filtered_y**2)
158
+
159
+ def update(self, pred, gt, trimap=None, sigma=1.4):
160
+ """
161
+ Update the metric.
162
+ Args:
163
+ pred (np.ndarray): The value range is [0., 255.].
164
+ gt (np.ndarray): The value range is [0, 255].
165
+ trimap (np.ndarray, optional): The value is in {0, 128, 255}. Default: None.
166
+ sigma (float, optional): Standard deviation of the gaussian kernel. Default: 1.4.
167
+ """
168
+ if trimap is None:
169
+ trimap = np.ones_like(gt) * 128
170
+ if not (pred.shape == gt.shape == trimap.shape):
171
+ raise ValueError(
172
+ 'The shape of `pred`, `gt` and `trimap` should be equal. '
173
+ 'but they are {}, {} and {}'.format(pred.shape, gt.shape,
174
+ trimap.shape))
175
+ pred[trimap == 0] = 0
176
+ pred[trimap == 255] = 255
177
+
178
+ gt = gt.squeeze()
179
+ pred = pred.squeeze()
180
+ gt = gt.astype(np.float64)
181
+ pred = pred.astype(np.float64)
182
+ gt_normed = np.zeros_like(gt)
183
+ pred_normed = np.zeros_like(pred)
184
+ cv2.normalize(gt, gt_normed, 1., 0., cv2.NORM_MINMAX)
185
+ cv2.normalize(pred, pred_normed, 1., 0., cv2.NORM_MINMAX)
186
+
187
+ gt_grad = self.gauss_gradient(gt_normed, sigma).astype(np.float32)
188
+ pred_grad = self.gauss_gradient(pred_normed, sigma).astype(np.float32)
189
+
190
+ grad_diff = ((gt_grad - pred_grad)**2 * (trimap == 128)).sum()
191
+
192
+ grad_diff /= 1000
193
+ self.grad_diffs += grad_diff
194
+ self.count += 1
195
+
196
+ return grad_diff
197
+
198
+ def evaluate(self):
199
+ grad = self.grad_diffs / self.count if self.count > 0 else 0
200
+ return grad
201
+
202
+
203
+ class Conn:
204
+ """
205
+ Only the unknown region is calculated if a trimap is provided.
206
+ Refer to: https://github.com/open-mmlab/mmediting/blob/master/mmedit/core/evaluation/metrics.py
207
+ """
208
+
209
+ def __init__(self):
210
+ self.conn_diffs = 0
211
+ self.count = 0
212
+
213
+ def update(self, pred, gt, trimap=None, step=0.1):
214
+ """
215
+ Update the metric.
216
+ Args:
217
+ pred (np.ndarray): The value range is [0., 255.].
218
+ gt (np.ndarray): The value range is [0, 255].
219
+ trimap (np.ndarray, optional): The value is in {0, 128, 255}. Default: None.
220
+ step (float, optional): Step of threshold when computing intersection between
221
+ `gt` and `pred`. Default: 0.1.
222
+ """
223
+ if trimap is None:
224
+ trimap = np.ones_like(gt) * 128
225
+ if not (pred.shape == gt.shape == trimap.shape):
226
+ raise ValueError(
227
+ 'The shape of `pred`, `gt` and `trimap` should be equal. '
228
+ 'but they are {}, {} and {}'.format(pred.shape, gt.shape,
229
+ trimap.shape))
230
+ pred[trimap == 0] = 0
231
+ pred[trimap == 255] = 255
232
+
233
+ gt = gt.squeeze()
234
+ pred = pred.squeeze()
235
+ gt = gt.astype(np.float32) / 255
236
+ pred = pred.astype(np.float32) / 255
237
+
238
+ thresh_steps = np.arange(0, 1 + step, step)
239
+ round_down_map = -np.ones_like(gt)
240
+ for i in range(1, len(thresh_steps)):
241
+ gt_thresh = gt >= thresh_steps[i]
242
+ pred_thresh = pred >= thresh_steps[i]
243
+ intersection = (gt_thresh & pred_thresh).astype(np.uint8)
244
+
245
+ # connected components
246
+ _, output, stats, _ = cv2.connectedComponentsWithStats(
247
+ intersection, connectivity=4)
248
+ # start from 1 in dim 0 to exclude background
249
+ size = stats[1:, -1]
250
+
251
+ # largest connected component of the intersection
252
+ omega = np.zeros_like(gt)
253
+ if len(size) != 0:
254
+ max_id = np.argmax(size)
255
+ # plus one to include background
256
+ omega[output == max_id + 1] = 1
257
+
258
+ mask = (round_down_map == -1) & (omega == 0)
259
+ round_down_map[mask] = thresh_steps[i - 1]
260
+ round_down_map[round_down_map == -1] = 1
261
+
262
+ gt_diff = gt - round_down_map
263
+ pred_diff = pred - round_down_map
264
+ # only calculate difference larger than or equal to 0.15
265
+ gt_phi = 1 - gt_diff * (gt_diff >= 0.15)
266
+ pred_phi = 1 - pred_diff * (pred_diff >= 0.15)
267
+
268
+ conn_diff = np.sum(np.abs(gt_phi - pred_phi) * (trimap == 128))
269
+
270
+ conn_diff /= 1000
271
+ self.conn_diffs += conn_diff
272
+ self.count += 1
273
+
274
+ return conn_diff
275
+
276
+ def evaluate(self):
277
+ conn = self.conn_diffs / self.count if self.count > 0 else 0
278
+ return conn
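A short sketch of how these metric classes are meant to be driven (the arrays below are synthetic, for illustration only): each `update` call accumulates the error of one prediction, and `evaluate` returns the mean over all updates.

import numpy as np

sad, mse = SAD(), MSE()
gt = np.random.randint(0, 256, (256, 256)).astype(np.float64)
pred = np.clip(gt + np.random.normal(0., 5., gt.shape), 0., 255.)
trimap = np.full(gt.shape, 128.)  # treat the whole image as unknown

sad.update(pred.copy(), gt, trimap)  # copy: update() writes into pred in place
mse.update(pred.copy(), gt, trimap)
print('SAD:', sad.evaluate(), 'MSE:', mse.evaluate())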
ppmatting/ml/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .methods import CloseFormMatting, KNNMatting, LearningBasedMatting, FastMatting, RandomWalksMatting
ppmatting/ml/methods.py ADDED
@@ -0,0 +1,97 @@
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import pymatting
16
+ from paddleseg.cvlibs import manager
17
+
18
+
19
+ class BaseMLMatting(object):
20
+ def __init__(self, alpha_estimator, **kargs):
21
+ self.alpha_estimator = alpha_estimator
22
+ self.kargs = kargs
23
+
24
+ def __call__(self, image, trimap):
25
+ image = self.__to_float64(image)
26
+ trimap = self.__to_float64(trimap)
27
+ alpha_matte = self.alpha_estimator(image, trimap, **self.kargs)
28
+ return alpha_matte
29
+
30
+ def __to_float64(self, x):
31
+ x_dtype = x.dtype
32
+ assert x_dtype in ["float32", "float64"]
33
+ x = x.astype("float64")
34
+ return x
35
+
36
+
37
+ @manager.MODELS.add_component
38
+ class CloseFormMatting(BaseMLMatting):
39
+ def __init__(self, **kargs):
40
+ cf_alpha_estimator = pymatting.estimate_alpha_cf
41
+ super().__init__(cf_alpha_estimator, **kargs)
42
+
43
+
44
+ @manager.MODELS.add_component
45
+ class KNNMatting(BaseMLMatting):
46
+ def __init__(self, **kargs):
47
+ knn_alpha_estimator = pymatting.estimate_alpha_knn
48
+ super().__init__(knn_alpha_estimator, **kargs)
49
+
50
+
51
+ @manager.MODELS.add_component
52
+ class LearningBasedMatting(BaseMLMatting):
53
+ def __init__(self, **kargs):
54
+ lbdm_alpha_estimator = pymatting.estimate_alpha_lbdm
55
+ super().__init__(lbdm_alpha_estimator, **kargs)
56
+
57
+
58
+ @manager.MODELS.add_component
59
+ class FastMatting(BaseMLMatting):
60
+ def __init__(self, **kargs):
61
+ lkm_alpha_estimator = pymatting.estimate_alpha_lkm
62
+ super().__init__(lkm_alpha_estimator, **kargs)
63
+
64
+
65
+ @manager.MODELS.add_component
66
+ class RandomWalksMatting(BaseMLMatting):
67
+ def __init__(self, **kargs):
68
+ rw_alpha_estimator = pymatting.estimate_alpha_rw
69
+ super().__init__(rw_alpha_estimator, **kargs)
70
+
71
+
72
+ if __name__ == "__main__":
73
+ from pymatting.util.util import load_image, save_image, stack_images
74
+ from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
75
+ import cv2
76
+
77
+ root = "/mnt/liuyi22/PaddlePaddle/PaddleSeg/Matting/data/examples/"
78
+ image_path = root + "lemur.png"
79
+ trimap_path = root + "lemur_trimap.png"
80
+ cutout_path = root + "lemur_cutout.png"
81
+ image = cv2.cvtColor(
82
+ cv2.imread(image_path).astype("float64"), cv2.COLOR_BGR2RGB) / 255.0
83
+
84
+ cv2.imwrite("image.png", (image * 255).astype('uint8'))
85
+ trimap = load_image(trimap_path, "GRAY")
86
+ print(image.shape, trimap.shape)
87
+ print(image.dtype, trimap.dtype)
88
+ cf = CloseFormMatting()
89
+ alpha = cf(image, trimap)
90
+
91
+ # alpha = pymatting.estimate_alpha_lkm(image, trimap)
92
+
93
+ foreground = estimate_foreground_ml(image, alpha)
94
+
95
+ cutout = stack_images(foreground, alpha)
96
+
97
+ save_image(cutout_path, cutout)
ppmatting/models/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .backbone import *
2
+ from .losses import *
3
+ from .modnet import MODNet
4
+ from .human_matting import HumanMatting
5
+ from .dim import DIM
6
+ from .ppmatting import PPMatting
7
+ from .gca import GCABaseline, GCA
ppmatting/models/backbone/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .mobilenet_v2 import *
2
+ from .hrnet import *
3
+ from .resnet_vd import *
4
+ from .vgg import *
5
+ from .gca_enc import *
ppmatting/models/backbone/gca_enc.py ADDED
@@ -0,0 +1,395 @@
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # The gca code was heavily based on https://github.com/Yaoyi-Li/GCA-Matting
16
+ # and https://github.com/open-mmlab/mmediting
17
+
18
+ import paddle
19
+ import paddle.nn as nn
20
+ import paddle.nn.functional as F
21
+ from paddleseg.cvlibs import manager, param_init
22
+ from paddleseg.utils import utils
23
+
24
+ from ppmatting.models.layers import GuidedCxtAtten
25
+
26
+
27
+ class ResNet_D(nn.Layer):
28
+ def __init__(self,
29
+ input_channels,
30
+ layers,
31
+ late_downsample=False,
32
+ pretrained=None):
33
+
34
+ super().__init__()
35
+
36
+ self.pretrained = pretrained
37
+
38
+ self._norm_layer = nn.BatchNorm
39
+ self.inplanes = 64
40
+ self.late_downsample = late_downsample
41
+ self.midplanes = 64 if late_downsample else 32
42
+ self.start_stride = [1, 2, 1, 2] if late_downsample else [2, 1, 2, 1]
43
+ self.conv1 = nn.utils.spectral_norm(
44
+ nn.Conv2D(
45
+ input_channels,
46
+ 32,
47
+ kernel_size=3,
48
+ stride=self.start_stride[0],
49
+ padding=1,
50
+ bias_attr=False))
51
+ self.conv2 = nn.utils.spectral_norm(
52
+ nn.Conv2D(
53
+ 32,
54
+ self.midplanes,
55
+ kernel_size=3,
56
+ stride=self.start_stride[1],
57
+ padding=1,
58
+ bias_attr=False))
59
+ self.conv3 = nn.utils.spectral_norm(
60
+ nn.Conv2D(
61
+ self.midplanes,
62
+ self.inplanes,
63
+ kernel_size=3,
64
+ stride=self.start_stride[2],
65
+ padding=1,
66
+ bias_attr=False))
67
+ self.bn1 = self._norm_layer(32)
68
+ self.bn2 = self._norm_layer(self.midplanes)
69
+ self.bn3 = self._norm_layer(self.inplanes)
70
+ self.activation = nn.ReLU()
71
+ self.layer1 = self._make_layer(
72
+ BasicBlock, 64, layers[0], stride=self.start_stride[3])
73
+ self.layer2 = self._make_layer(BasicBlock, 128, layers[1], stride=2)
74
+ self.layer3 = self._make_layer(BasicBlock, 256, layers[2], stride=2)
75
+ self.layer_bottleneck = self._make_layer(
76
+ BasicBlock, 512, layers[3], stride=2)
77
+
78
+ self.init_weight()
79
+
80
+ def _make_layer(self, block, planes, block_num, stride=1):
81
+ if block_num == 0:
82
+ return nn.Sequential(nn.Identity())
83
+ norm_layer = self._norm_layer
84
+ downsample = None
85
+ if stride != 1:
86
+ downsample = nn.Sequential(
87
+ nn.AvgPool2D(2, stride),
88
+ nn.utils.spectral_norm(
89
+ conv1x1(self.inplanes, planes * block.expansion)),
90
+ norm_layer(planes * block.expansion), )
91
+ elif self.inplanes != planes * block.expansion:
92
+ downsample = nn.Sequential(
93
+ nn.utils.spectral_norm(
94
+ conv1x1(self.inplanes, planes * block.expansion, stride)),
95
+ norm_layer(planes * block.expansion), )
96
+
97
+ layers = [block(self.inplanes, planes, stride, downsample, norm_layer)]
98
+ self.inplanes = planes * block.expansion
99
+ for _ in range(1, block_num):
100
+ layers.append(block(self.inplanes, planes, norm_layer=norm_layer))
101
+
102
+ return nn.Sequential(*layers)
103
+
104
+ def forward(self, x):
105
+ x = self.conv1(x)
106
+ x = self.bn1(x)
107
+ x = self.activation(x)
108
+ x = self.conv2(x)
109
+ x = self.bn2(x)
110
+ x1 = self.activation(x) # N x 32 x 256 x 256
111
+ x = self.conv3(x1)
112
+ x = self.bn3(x)
113
+ x2 = self.activation(x) # N x 64 x 128 x 128
114
+
115
+ x3 = self.layer1(x2) # N x 64 x 128 x 128
116
+ x4 = self.layer2(x3) # N x 128 x 64 x 64
117
+ x5 = self.layer3(x4) # N x 256 x 32 x 32
118
+ x = self.layer_bottleneck(x5) # N x 512 x 16 x 16
119
+
120
+ return x, (x1, x2, x3, x4, x5)
121
+
122
+ def init_weight(self):
123
+
124
+ for layer in self.sublayers():
125
+ if isinstance(layer, nn.Conv2D):
126
+
127
+ if hasattr(layer, "weight_orig"):
128
+ param = layer.weight_orig
129
+ else:
130
+ param = layer.weight
131
+ param_init.xavier_uniform(param)
132
+
133
+ elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)):
134
+ param_init.constant_init(layer.weight, value=1.0)
135
+ param_init.constant_init(layer.bias, value=0.0)
136
+
137
+ elif isinstance(layer, BasicBlock):
138
+ param_init.constant_init(layer.bn2.weight, value=0.0)
139
+
140
+ if self.pretrained is not None:
141
+ utils.load_pretrained_model(self, self.pretrained)
142
+
143
+
144
+ @manager.MODELS.add_component
145
+ class ResShortCut_D(ResNet_D):
146
+ def __init__(self,
147
+ input_channels,
148
+ layers,
149
+ late_downsample=False,
150
+ pretrained=None):
151
+ super().__init__(
152
+ input_channels,
153
+ layers,
154
+ late_downsample=late_downsample,
155
+ pretrained=pretrained)
156
+
157
+ self.shortcut_inplane = [input_channels, self.midplanes, 64, 128, 256]
158
+ self.shortcut_plane = [32, self.midplanes, 64, 128, 256]
159
+
160
+ self.shortcut = nn.LayerList()
161
+ for stage, inplane in enumerate(self.shortcut_inplane):
162
+ self.shortcut.append(
163
+ self._make_shortcut(inplane, self.shortcut_plane[stage]))
164
+
165
+ def _make_shortcut(self, inplane, planes):
166
+ return nn.Sequential(
167
+ nn.utils.spectral_norm(
168
+ nn.Conv2D(
169
+ inplane, planes, kernel_size=3, padding=1,
170
+ bias_attr=False)),
171
+ nn.ReLU(),
172
+ self._norm_layer(planes),
173
+ nn.utils.spectral_norm(
174
+ nn.Conv2D(
175
+ planes, planes, kernel_size=3, padding=1, bias_attr=False)),
176
+ nn.ReLU(),
177
+ self._norm_layer(planes))
178
+
179
+ def forward(self, x):
180
+
181
+ out = self.conv1(x)
182
+ out = self.bn1(out)
183
+ out = self.activation(out)
184
+ out = self.conv2(out)
185
+ out = self.bn2(out)
186
+ x1 = self.activation(out) # N x 32 x 256 x 256
187
+ out = self.conv3(x1)
188
+ out = self.bn3(out)
189
+ out = self.activation(out)
190
+
191
+ x2 = self.layer1(out) # N x 64 x 128 x 128
192
+ x3 = self.layer2(x2) # N x 128 x 64 x 64
193
+ x4 = self.layer3(x3) # N x 256 x 32 x 32
194
+ out = self.layer_bottleneck(x4) # N x 512 x 16 x 16
195
+
196
+ fea1 = self.shortcut[0](x) # input image and trimap
197
+ fea2 = self.shortcut[1](x1)
198
+ fea3 = self.shortcut[2](x2)
199
+ fea4 = self.shortcut[3](x3)
200
+ fea5 = self.shortcut[4](x4)
201
+
202
+ return out, {
203
+ 'shortcut': (fea1, fea2, fea3, fea4, fea5),
204
+ 'image': x[:, :3, ...]
205
+ }
206
+
207
+
208
+ @manager.MODELS.add_component
209
+ class ResGuidedCxtAtten(ResNet_D):
210
+ def __init__(self,
211
+ input_channels,
212
+ layers,
213
+ late_downsample=False,
214
+ pretrained=None):
215
+ super().__init__(
216
+ input_channels,
217
+ layers,
218
+ late_downsample=late_downsample,
219
+ pretrained=pretrained)
220
+ self.input_channels = input_channels
221
+ self.shortcut_inplane = [input_channels, self.midplanes, 64, 128, 256]
222
+ self.shortcut_plane = [32, self.midplanes, 64, 128, 256]
223
+
224
+ self.shortcut = nn.LayerList()
225
+ for stage, inplane in enumerate(self.shortcut_inplane):
226
+ self.shortcut.append(
227
+ self._make_shortcut(inplane, self.shortcut_plane[stage]))
228
+
229
+ self.guidance_head = nn.Sequential(
230
+ nn.Pad2D(
231
+ 1, mode="reflect"),
232
+ nn.utils.spectral_norm(
233
+ nn.Conv2D(
234
+ 3, 16, kernel_size=3, padding=0, stride=2,
235
+ bias_attr=False)),
236
+ nn.ReLU(),
237
+ self._norm_layer(16),
238
+ nn.Pad2D(
239
+ 1, mode="reflect"),
240
+ nn.utils.spectral_norm(
241
+ nn.Conv2D(
242
+ 16, 32, kernel_size=3, padding=0, stride=2,
243
+ bias_attr=False)),
244
+ nn.ReLU(),
245
+ self._norm_layer(32),
246
+ nn.Pad2D(
247
+ 1, mode="reflect"),
248
+ nn.utils.spectral_norm(
249
+ nn.Conv2D(
250
+ 32,
251
+ 128,
252
+ kernel_size=3,
253
+ padding=0,
254
+ stride=2,
255
+ bias_attr=False)),
256
+ nn.ReLU(),
257
+ self._norm_layer(128))
258
+
259
+ self.gca = GuidedCxtAtten(128, 128)
260
+
261
+ self.init_weight()
262
+
263
+ def init_weight(self):
264
+
265
+ for layer in self.sublayers():
266
+ if isinstance(layer, nn.Conv2D):
267
+ initializer = nn.initializer.XavierUniform()
268
+ if hasattr(layer, "weight_orig"):
269
+ param = layer.weight_orig
270
+ else:
271
+ param = layer.weight
272
+ initializer(param, param.block)
273
+
274
+ elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)):
275
+ param_init.constant_init(layer.weight, value=1.0)
276
+ param_init.constant_init(layer.bias, value=0.0)
277
+
278
+ elif isinstance(layer, BasicBlock):
279
+ param_init.constant_init(layer.bn2.weight, value=0.0)
280
+
281
+ if self.pretrained is not None:
282
+ utils.load_pretrained_model(self, self.pretrained)
283
+
284
+ def _make_shortcut(self, inplane, planes):
285
+ return nn.Sequential(
286
+ nn.utils.spectral_norm(
287
+ nn.Conv2D(
288
+ inplane, planes, kernel_size=3, padding=1,
289
+ bias_attr=False)),
290
+ nn.ReLU(),
291
+ self._norm_layer(planes),
292
+ nn.utils.spectral_norm(
293
+ nn.Conv2D(
294
+ planes, planes, kernel_size=3, padding=1, bias_attr=False)),
295
+ nn.ReLU(),
296
+ self._norm_layer(planes))
297
+
298
+ def forward(self, x):
299
+
300
+ out = self.conv1(x)
301
+ out = self.bn1(out)
302
+ out = self.activation(out)
303
+ out = self.conv2(out)
304
+ out = self.bn2(out)
305
+ x1 = self.activation(out) # N x 32 x 256 x 256
306
+ out = self.conv3(x1)
307
+ out = self.bn3(out)
308
+ out = self.activation(out)
309
+
310
+ im_fea = self.guidance_head(
311
+ x[:, :3, ...]) # downsample origin image and extract features
312
+ if self.input_channels == 6:
313
+ unknown = F.interpolate(
314
+ x[:, 4:5, ...], scale_factor=1 / 8, mode='nearest')
315
+ else:
316
+ unknown = x[:, 3:, ...].equal(paddle.to_tensor([1.]))
317
+ unknown = paddle.cast(unknown, dtype='float32')
318
+ unknown = F.interpolate(unknown, scale_factor=1 / 8, mode='nearest')
319
+
320
+ x2 = self.layer1(out) # N x 64 x 128 x 128
321
+ x3 = self.layer2(x2) # N x 128 x 64 x 64
322
+ x3 = self.gca(im_fea, x3, unknown) # contextual attention
323
+ x4 = self.layer3(x3) # N x 256 x 32 x 32
324
+ out = self.layer_bottleneck(x4) # N x 512 x 16 x 16
325
+
326
+ fea1 = self.shortcut[0](x) # input image and trimap
327
+ fea2 = self.shortcut[1](x1)
328
+ fea3 = self.shortcut[2](x2)
329
+ fea4 = self.shortcut[3](x3)
330
+ fea5 = self.shortcut[4](x4)
331
+
332
+ return out, {
333
+ 'shortcut': (fea1, fea2, fea3, fea4, fea5),
334
+ 'image_fea': im_fea,
335
+ 'unknown': unknown,
336
+ }
337
+
338
+
339
+ class BasicBlock(nn.Layer):
340
+ expansion = 1
341
+
342
+ def __init__(self,
343
+ inplanes,
344
+ planes,
345
+ stride=1,
346
+ downsample=None,
347
+ norm_layer=None):
348
+ super().__init__()
349
+ if norm_layer is None:
350
+ norm_layer = nn.BatchNorm
351
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
352
+ self.conv1 = nn.utils.spectral_norm(conv3x3(inplanes, planes, stride))
353
+ self.bn1 = norm_layer(planes)
354
+ self.activation = nn.ReLU()
355
+ self.conv2 = nn.utils.spectral_norm(conv3x3(planes, planes))
356
+ self.bn2 = norm_layer(planes)
357
+ self.downsample = downsample
358
+ self.stride = stride
359
+
360
+ def forward(self, x):
361
+ identity = x
362
+
363
+ out = self.conv1(x)
364
+ out = self.bn1(out)
365
+ out = self.activation(out)
366
+
367
+ out = self.conv2(out)
368
+ out = self.bn2(out)
369
+
370
+ if self.downsample is not None:
371
+ identity = self.downsample(x)
372
+
373
+ out += identity
374
+ out = self.activation(out)
375
+
376
+ return out
377
+
378
+
379
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
380
+ """3x3 convolution with padding"""
381
+ return nn.Conv2D(
382
+ in_planes,
383
+ out_planes,
384
+ kernel_size=3,
385
+ stride=stride,
386
+ padding=dilation,
387
+ groups=groups,
388
+ bias_attr=False,
389
+ dilation=dilation)
390
+
391
+
392
+ def conv1x1(in_planes, out_planes, stride=1):
393
+ """1x1 convolution"""
394
+ return nn.Conv2D(
395
+ in_planes, out_planes, kernel_size=1, stride=stride, bias_attr=False)
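For orientation, a minimal sketch of running one of the encoders above (the `layers` configuration and tensor sizes are assumptions for illustration; a 6-channel input corresponds to an RGB image concatenated with a one-hot trimap):

import paddle

encoder = ResShortCut_D(input_channels=6, layers=[3, 4, 4, 2])
x = paddle.rand([1, 6, 512, 512])  # image (3ch) + one-hot trimap (3ch)
out, mid_fea = encoder(x)
print(out.shape)  # bottleneck feature at 1/32 resolution
for fea in mid_fea['shortcut']:
    print(fea.shape)  # skip features consumed by a matching decoder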
ppmatting/models/backbone/hrnet.py ADDED
@@ -0,0 +1,835 @@
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+
17
+ import paddle
18
+ import paddle.nn as nn
19
+ import paddle.nn.functional as F
20
+
21
+ from paddleseg.cvlibs import manager, param_init
22
+ from paddleseg.models import layers
23
+ from paddleseg.utils import utils
24
+
25
+ __all__ = [
26
+ "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", "HRNet_W18", "HRNet_W30",
27
+ "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", "HRNet_W60", "HRNet_W64"
28
+ ]
29
+
30
+
31
+ class HRNet(nn.Layer):
32
+ """
33
+ The HRNet implementation based on PaddlePaddle.
34
+
35
+ The original article refers to
36
+ Jingdong Wang, et al. "HRNet: Deep High-Resolution Representation Learning for Visual Recognition"
37
+ (https://arxiv.org/pdf/1908.07919.pdf).
38
+
39
+ Args:
40
+ pretrained (str, optional): The path of pretrained model.
41
+ stage1_num_modules (int, optional): Number of modules for stage1. Default 1.
42
+ stage1_num_blocks (list, optional): Number of blocks per module for stage1. Default (4).
43
+ stage1_num_channels (list, optional): Number of channels per branch for stage1. Default (64).
44
+ stage2_num_modules (int, optional): Number of modules for stage2. Default 1.
45
+ stage2_num_blocks (list, optional): Number of blocks per module for stage2. Default (4, 4).
46
+ stage2_num_channels (list, optional): Number of channels per branch for stage2. Default (18, 36).
47
+ stage3_num_modules (int, optional): Number of modules for stage3. Default 4.
48
+ stage3_num_blocks (list, optional): Number of blocks per module for stage3. Default (4, 4, 4).
49
+ stage3_num_channels (list, optional): Number of channels per branch for stage3. Default (18, 36, 72).
50
+ stage4_num_modules (int, optional): Number of modules for stage4. Default 3.
51
+ stage4_num_blocks (list, optional): Number of blocks per module for stage4. Default (4, 4, 4, 4).
52
+ stage4_num_channels (list, optional): Number of channels per branch for stage4. Default (18, 36, 72, 144).
53
+ has_se (bool, optional): Whether to use Squeeze-and-Excitation module. Default False.
54
+ align_corners (bool, optional): An argument of F.interpolate. It should be set to False when the feature size is even,
55
+ e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
56
+ """
57
+
58
+ def __init__(self,
59
+ input_channels=3,
60
+ pretrained=None,
61
+ stage1_num_modules=1,
62
+ stage1_num_blocks=(4, ),
63
+ stage1_num_channels=(64, ),
64
+ stage2_num_modules=1,
65
+ stage2_num_blocks=(4, 4),
66
+ stage2_num_channels=(18, 36),
67
+ stage3_num_modules=4,
68
+ stage3_num_blocks=(4, 4, 4),
69
+ stage3_num_channels=(18, 36, 72),
70
+ stage4_num_modules=3,
71
+ stage4_num_blocks=(4, 4, 4, 4),
72
+ stage4_num_channels=(18, 36, 72, 144),
73
+ has_se=False,
74
+ align_corners=False,
75
+ padding_same=True):
76
+ super(HRNet, self).__init__()
77
+ self.pretrained = pretrained
78
+ self.stage1_num_modules = stage1_num_modules
79
+ self.stage1_num_blocks = stage1_num_blocks
80
+ self.stage1_num_channels = stage1_num_channels
81
+ self.stage2_num_modules = stage2_num_modules
82
+ self.stage2_num_blocks = stage2_num_blocks
83
+ self.stage2_num_channels = stage2_num_channels
84
+ self.stage3_num_modules = stage3_num_modules
85
+ self.stage3_num_blocks = stage3_num_blocks
86
+ self.stage3_num_channels = stage3_num_channels
87
+ self.stage4_num_modules = stage4_num_modules
88
+ self.stage4_num_blocks = stage4_num_blocks
89
+ self.stage4_num_channels = stage4_num_channels
90
+ self.has_se = has_se
91
+ self.align_corners = align_corners
92
+
93
+ self.feat_channels = [i for i in stage4_num_channels]
94
+ self.feat_channels = [64] + self.feat_channels
95
+
96
+ self.conv_layer1_1 = layers.ConvBNReLU(
97
+ in_channels=input_channels,
98
+ out_channels=64,
99
+ kernel_size=3,
100
+ stride=2,
101
+ padding=1 if not padding_same else 'same',
102
+ bias_attr=False)
103
+
104
+ self.conv_layer1_2 = layers.ConvBNReLU(
105
+ in_channels=64,
106
+ out_channels=64,
107
+ kernel_size=3,
108
+ stride=2,
109
+ padding=1 if not padding_same else 'same',
110
+ bias_attr=False)
111
+
112
+ self.la1 = Layer1(
113
+ num_channels=64,
114
+ num_blocks=self.stage1_num_blocks[0],
115
+ num_filters=self.stage1_num_channels[0],
116
+ has_se=has_se,
117
+ name="layer2",
118
+ padding_same=padding_same)
119
+
120
+ self.tr1 = TransitionLayer(
121
+ in_channels=[self.stage1_num_channels[0] * 4],
122
+ out_channels=self.stage2_num_channels,
123
+ name="tr1",
124
+ padding_same=padding_same)
125
+
126
+ self.st2 = Stage(
127
+ num_channels=self.stage2_num_channels,
128
+ num_modules=self.stage2_num_modules,
129
+ num_blocks=self.stage2_num_blocks,
130
+ num_filters=self.stage2_num_channels,
131
+ has_se=self.has_se,
132
+ name="st2",
133
+ align_corners=align_corners,
134
+ padding_same=padding_same)
135
+
136
+ self.tr2 = TransitionLayer(
137
+ in_channels=self.stage2_num_channels,
138
+ out_channels=self.stage3_num_channels,
139
+ name="tr2",
140
+ padding_same=padding_same)
141
+ self.st3 = Stage(
142
+ num_channels=self.stage3_num_channels,
143
+ num_modules=self.stage3_num_modules,
144
+ num_blocks=self.stage3_num_blocks,
145
+ num_filters=self.stage3_num_channels,
146
+ has_se=self.has_se,
147
+ name="st3",
148
+ align_corners=align_corners,
149
+ padding_same=padding_same)
150
+
151
+ self.tr3 = TransitionLayer(
152
+ in_channels=self.stage3_num_channels,
153
+ out_channels=self.stage4_num_channels,
154
+ name="tr3",
155
+ padding_same=padding_same)
156
+ self.st4 = Stage(
157
+ num_channels=self.stage4_num_channels,
158
+ num_modules=self.stage4_num_modules,
159
+ num_blocks=self.stage4_num_blocks,
160
+ num_filters=self.stage4_num_channels,
161
+ has_se=self.has_se,
162
+ name="st4",
163
+ align_corners=align_corners,
164
+ padding_same=padding_same)
165
+
166
+ self.init_weight()
167
+
168
+ def forward(self, x):
169
+ feat_list = []
170
+ conv1 = self.conv_layer1_1(x)
171
+ feat_list.append(conv1)
172
+ conv2 = self.conv_layer1_2(conv1)
173
+
174
+ la1 = self.la1(conv2)
175
+
176
+ tr1 = self.tr1([la1])
177
+ st2 = self.st2(tr1)
178
+
179
+ tr2 = self.tr2(st2)
180
+ st3 = self.st3(tr2)
181
+
182
+ tr3 = self.tr3(st3)
183
+ st4 = self.st4(tr3)
184
+
185
+ feat_list = feat_list + st4
186
+
187
+ return feat_list
188
+
189
+ def init_weight(self):
190
+ for layer in self.sublayers():
191
+ if isinstance(layer, nn.Conv2D):
192
+ param_init.normal_init(layer.weight, std=0.001)
193
+ elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)):
194
+ param_init.constant_init(layer.weight, value=1.0)
195
+ param_init.constant_init(layer.bias, value=0.0)
196
+ if self.pretrained is not None:
197
+ utils.load_pretrained_model(self, self.pretrained)
198
+
199
+
200
+ class Layer1(nn.Layer):
201
+ def __init__(self,
202
+ num_channels,
203
+ num_filters,
204
+ num_blocks,
205
+ has_se=False,
206
+ name=None,
207
+ padding_same=True):
208
+ super(Layer1, self).__init__()
209
+
210
+ self.bottleneck_block_list = []
211
+
212
+ for i in range(num_blocks):
213
+ bottleneck_block = self.add_sublayer(
214
+ "bb_{}_{}".format(name, i + 1),
215
+ BottleneckBlock(
216
+ num_channels=num_channels if i == 0 else num_filters * 4,
217
+ num_filters=num_filters,
218
+ has_se=has_se,
219
+ stride=1,
220
+ downsample=True if i == 0 else False,
221
+ name=name + '_' + str(i + 1),
222
+ padding_same=padding_same))
223
+ self.bottleneck_block_list.append(bottleneck_block)
224
+
225
+ def forward(self, x):
226
+ conv = x
227
+ for block_func in self.bottleneck_block_list:
228
+ conv = block_func(conv)
229
+ return conv
230
+
231
+
232
+ class TransitionLayer(nn.Layer):
233
+ def __init__(self, in_channels, out_channels, name=None, padding_same=True):
234
+ super(TransitionLayer, self).__init__()
235
+
236
+ num_in = len(in_channels)
237
+ num_out = len(out_channels)
238
+ self.conv_bn_func_list = []
239
+ for i in range(num_out):
240
+ residual = None
241
+ if i < num_in:
242
+ if in_channels[i] != out_channels[i]:
243
+ residual = self.add_sublayer(
244
+ "transition_{}_layer_{}".format(name, i + 1),
245
+ layers.ConvBNReLU(
246
+ in_channels=in_channels[i],
247
+ out_channels=out_channels[i],
248
+ kernel_size=3,
249
+ padding=1 if not padding_same else 'same',
250
+ bias_attr=False))
251
+ else:
252
+ residual = self.add_sublayer(
253
+ "transition_{}_layer_{}".format(name, i + 1),
254
+ layers.ConvBNReLU(
255
+ in_channels=in_channels[-1],
256
+ out_channels=out_channels[i],
257
+ kernel_size=3,
258
+ stride=2,
259
+ padding=1 if not padding_same else 'same',
260
+ bias_attr=False))
261
+ self.conv_bn_func_list.append(residual)
262
+
263
+ def forward(self, x):
264
+ outs = []
265
+ for idx, conv_bn_func in enumerate(self.conv_bn_func_list):
266
+ if conv_bn_func is None:
267
+ outs.append(x[idx])
268
+ else:
269
+ if idx < len(x):
270
+ outs.append(conv_bn_func(x[idx]))
271
+ else:
272
+ outs.append(conv_bn_func(x[-1]))
273
+ return outs
274
+
275
+
276
+ class Branches(nn.Layer):
277
+ def __init__(self,
278
+ num_blocks,
279
+ in_channels,
280
+ out_channels,
281
+ has_se=False,
282
+ name=None,
283
+ padding_same=True):
284
+ super(Branches, self).__init__()
285
+
286
+ self.basic_block_list = []
287
+
288
+ for i in range(len(out_channels)):
289
+ self.basic_block_list.append([])
290
+ for j in range(num_blocks[i]):
291
+ in_ch = in_channels[i] if j == 0 else out_channels[i]
292
+ basic_block_func = self.add_sublayer(
293
+ "bb_{}_branch_layer_{}_{}".format(name, i + 1, j + 1),
294
+ BasicBlock(
295
+ num_channels=in_ch,
296
+ num_filters=out_channels[i],
297
+ has_se=has_se,
298
+ name=name + '_branch_layer_' + str(i + 1) + '_' +
299
+ str(j + 1),
300
+ padding_same=padding_same))
301
+ self.basic_block_list[i].append(basic_block_func)
302
+
303
+ def forward(self, x):
304
+ outs = []
305
+ for idx, input in enumerate(x):
306
+ conv = input
307
+ for basic_block_func in self.basic_block_list[idx]:
308
+ conv = basic_block_func(conv)
309
+ outs.append(conv)
310
+ return outs
311
+
312
+
313
+ class BottleneckBlock(nn.Layer):
314
+ def __init__(self,
315
+ num_channels,
316
+ num_filters,
317
+ has_se,
318
+ stride=1,
319
+ downsample=False,
320
+ name=None,
321
+ padding_same=True):
322
+ super(BottleneckBlock, self).__init__()
323
+
324
+ self.has_se = has_se
325
+ self.downsample = downsample
326
+
327
+ self.conv1 = layers.ConvBNReLU(
328
+ in_channels=num_channels,
329
+ out_channels=num_filters,
330
+ kernel_size=1,
331
+ bias_attr=False)
332
+
333
+ self.conv2 = layers.ConvBNReLU(
334
+ in_channels=num_filters,
335
+ out_channels=num_filters,
336
+ kernel_size=3,
337
+ stride=stride,
338
+ padding=1 if not padding_same else 'same',
339
+ bias_attr=False)
340
+
341
+ self.conv3 = layers.ConvBN(
342
+ in_channels=num_filters,
343
+ out_channels=num_filters * 4,
344
+ kernel_size=1,
345
+ bias_attr=False)
346
+
347
+ if self.downsample:
348
+ self.conv_down = layers.ConvBN(
349
+ in_channels=num_channels,
350
+ out_channels=num_filters * 4,
351
+ kernel_size=1,
352
+ bias_attr=False)
353
+
354
+ if self.has_se:
355
+ self.se = SELayer(
356
+ num_channels=num_filters * 4,
357
+ num_filters=num_filters * 4,
358
+ reduction_ratio=16,
359
+ name=name + '_fc')
360
+
361
+ self.add = layers.Add()
362
+ self.relu = layers.Activation("relu")
363
+
364
+ def forward(self, x):
365
+ residual = x
366
+ conv1 = self.conv1(x)
367
+ conv2 = self.conv2(conv1)
368
+ conv3 = self.conv3(conv2)
369
+
370
+ if self.downsample:
371
+ residual = self.conv_down(x)
372
+
373
+ if self.has_se:
374
+ conv3 = self.se(conv3)
375
+
376
+ y = self.add(conv3, residual)
377
+ y = self.relu(y)
378
+ return y
379
+
380
+
381
+ class BasicBlock(nn.Layer):
382
+ def __init__(self,
383
+ num_channels,
384
+ num_filters,
385
+ stride=1,
386
+ has_se=False,
387
+ downsample=False,
388
+ name=None,
389
+ padding_same=True):
390
+ super(BasicBlock, self).__init__()
391
+
392
+ self.has_se = has_se
393
+ self.downsample = downsample
394
+
395
+ self.conv1 = layers.ConvBNReLU(
396
+ in_channels=num_channels,
397
+ out_channels=num_filters,
398
+ kernel_size=3,
399
+ stride=stride,
400
+ padding=1 if not padding_same else 'same',
401
+ bias_attr=False)
402
+ self.conv2 = layers.ConvBN(
403
+ in_channels=num_filters,
404
+ out_channels=num_filters,
405
+ kernel_size=3,
406
+ padding=1 if not padding_same else 'same',
407
+ bias_attr=False)
408
+
409
+ if self.downsample:
410
+ self.conv_down = layers.ConvBNReLU(
411
+ in_channels=num_channels,
412
+ out_channels=num_filters,
413
+ kernel_size=1,
414
+ bias_attr=False)
415
+
416
+ if self.has_se:
417
+ self.se = SELayer(
418
+ num_channels=num_filters,
419
+ num_filters=num_filters,
420
+ reduction_ratio=16,
421
+ name=name + '_fc')
422
+
423
+ self.add = layers.Add()
424
+ self.relu = layers.Activation("relu")
425
+
426
+ def forward(self, x):
427
+ residual = x
428
+ conv1 = self.conv1(x)
429
+ conv2 = self.conv2(conv1)
430
+
431
+ if self.downsample:
432
+ residual = self.conv_down(x)
433
+
434
+ if self.has_se:
435
+ conv2 = self.se(conv2)
436
+
437
+ y = self.add(conv2, residual)
438
+ y = self.relu(y)
439
+ return y
440
+
441
+
442
+ class SELayer(nn.Layer):
443
+ def __init__(self, num_channels, num_filters, reduction_ratio, name=None):
444
+ super(SELayer, self).__init__()
445
+
446
+ self.pool2d_gap = nn.AdaptiveAvgPool2D(1)
447
+
448
+ self._num_channels = num_channels
449
+
450
+ med_ch = int(num_channels / reduction_ratio)
451
+ stdv = 1.0 / math.sqrt(num_channels * 1.0)
452
+ self.squeeze = nn.Linear(
453
+ num_channels,
454
+ med_ch,
455
+ weight_attr=paddle.ParamAttr(
456
+ initializer=nn.initializer.Uniform(-stdv, stdv)))
457
+
458
+ stdv = 1.0 / math.sqrt(med_ch * 1.0)
459
+ self.excitation = nn.Linear(
460
+ med_ch,
461
+ num_filters,
462
+ weight_attr=paddle.ParamAttr(
463
+ initializer=nn.initializer.Uniform(-stdv, stdv)))
464
+
465
+ def forward(self, x):
466
+ pool = self.pool2d_gap(x)
467
+ pool = paddle.reshape(pool, shape=[-1, self._num_channels])
468
+ squeeze = self.squeeze(pool)
469
+ squeeze = F.relu(squeeze)
470
+ excitation = self.excitation(squeeze)
471
+ excitation = F.sigmoid(excitation)
472
+ excitation = paddle.reshape(
473
+ excitation, shape=[-1, self._num_channels, 1, 1])
474
+ out = x * excitation
475
+ return out
476
+
477
+
478
+ class Stage(nn.Layer):
479
+ def __init__(self,
480
+ num_channels,
481
+ num_modules,
482
+ num_blocks,
483
+ num_filters,
484
+ has_se=False,
485
+ multi_scale_output=True,
486
+ name=None,
487
+ align_corners=False,
488
+ padding_same=True):
489
+ super(Stage, self).__init__()
490
+
491
+ self._num_modules = num_modules
492
+
493
+ self.stage_func_list = []
494
+ for i in range(num_modules):
495
+ if i == num_modules - 1 and not multi_scale_output:
496
+ stage_func = self.add_sublayer(
497
+ "stage_{}_{}".format(name, i + 1),
498
+ HighResolutionModule(
499
+ num_channels=num_channels,
500
+ num_blocks=num_blocks,
501
+ num_filters=num_filters,
502
+ has_se=has_se,
503
+ multi_scale_output=False,
504
+ name=name + '_' + str(i + 1),
505
+ align_corners=align_corners,
506
+ padding_same=padding_same))
507
+ else:
508
+ stage_func = self.add_sublayer(
509
+ "stage_{}_{}".format(name, i + 1),
510
+ HighResolutionModule(
511
+ num_channels=num_channels,
512
+ num_blocks=num_blocks,
513
+ num_filters=num_filters,
514
+ has_se=has_se,
515
+ name=name + '_' + str(i + 1),
516
+ align_corners=align_corners,
517
+ padding_same=padding_same))
518
+
519
+ self.stage_func_list.append(stage_func)
520
+
521
+ def forward(self, x):
522
+ out = x
523
+ for idx in range(self._num_modules):
524
+ out = self.stage_func_list[idx](out)
525
+ return out
526
+
527
+
528
+ class HighResolutionModule(nn.Layer):
529
+ def __init__(self,
530
+ num_channels,
531
+ num_blocks,
532
+ num_filters,
533
+ has_se=False,
534
+ multi_scale_output=True,
535
+ name=None,
536
+ align_corners=False,
537
+ padding_same=True):
538
+ super(HighResolutionModule, self).__init__()
539
+
540
+ self.branches_func = Branches(
541
+ num_blocks=num_blocks,
542
+ in_channels=num_channels,
543
+ out_channels=num_filters,
544
+ has_se=has_se,
545
+ name=name,
546
+ padding_same=padding_same)
547
+
548
+ self.fuse_func = FuseLayers(
549
+ in_channels=num_filters,
550
+ out_channels=num_filters,
551
+ multi_scale_output=multi_scale_output,
552
+ name=name,
553
+ align_corners=align_corners,
554
+ padding_same=padding_same)
555
+
556
+ def forward(self, x):
557
+ out = self.branches_func(x)
558
+ out = self.fuse_func(out)
559
+ return out
560
+
561
+
562
+ class FuseLayers(nn.Layer):
563
+ def __init__(self,
564
+ in_channels,
565
+ out_channels,
566
+ multi_scale_output=True,
567
+ name=None,
568
+ align_corners=False,
569
+ padding_same=True):
570
+ super(FuseLayers, self).__init__()
571
+
572
+ self._actual_ch = len(in_channels) if multi_scale_output else 1
573
+ self._in_channels = in_channels
574
+ self.align_corners = align_corners
575
+
576
+ self.residual_func_list = []
577
+ for i in range(self._actual_ch):
578
+ for j in range(len(in_channels)):
579
+ if j > i:
580
+ residual_func = self.add_sublayer(
581
+ "residual_{}_layer_{}_{}".format(name, i + 1, j + 1),
582
+ layers.ConvBN(
583
+ in_channels=in_channels[j],
584
+ out_channels=out_channels[i],
585
+ kernel_size=1,
586
+ bias_attr=False))
587
+ self.residual_func_list.append(residual_func)
588
+ elif j < i:
589
+ pre_num_filters = in_channels[j]
590
+ for k in range(i - j):
591
+ if k == i - j - 1:
592
+ residual_func = self.add_sublayer(
593
+ "residual_{}_layer_{}_{}_{}".format(
594
+ name, i + 1, j + 1, k + 1),
595
+ layers.ConvBN(
596
+ in_channels=pre_num_filters,
597
+ out_channels=out_channels[i],
598
+ kernel_size=3,
599
+ stride=2,
600
+ padding=1 if not padding_same else 'same',
601
+ bias_attr=False))
602
+ pre_num_filters = out_channels[i]
603
+ else:
604
+ residual_func = self.add_sublayer(
605
+ "residual_{}_layer_{}_{}_{}".format(
606
+ name, i + 1, j + 1, k + 1),
607
+ layers.ConvBNReLU(
608
+ in_channels=pre_num_filters,
609
+ out_channels=out_channels[j],
610
+ kernel_size=3,
611
+ stride=2,
612
+ padding=1 if not padding_same else 'same',
613
+ bias_attr=False))
614
+ pre_num_filters = out_channels[j]
615
+ self.residual_func_list.append(residual_func)
616
+
617
+ def forward(self, x):
618
+ outs = []
619
+ residual_func_idx = 0
620
+ for i in range(self._actual_ch):
621
+ residual = x[i]
622
+ residual_shape = paddle.shape(residual)[-2:]
623
+ for j in range(len(self._in_channels)):
624
+ if j > i:
625
+ y = self.residual_func_list[residual_func_idx](x[j])
626
+ residual_func_idx += 1
627
+
628
+ y = F.interpolate(
629
+ y,
630
+ residual_shape,
631
+ mode='bilinear',
632
+ align_corners=self.align_corners)
633
+ residual = residual + y
634
+ elif j < i:
635
+ y = x[j]
636
+ for k in range(i - j):
637
+ y = self.residual_func_list[residual_func_idx](y)
638
+ residual_func_idx += 1
639
+
640
+ residual = residual + y
641
+
642
+ residual = F.relu(residual)
643
+ outs.append(residual)
644
+
645
+ return outs
646
+
647
+
648
+ @manager.BACKBONES.add_component
649
+ def HRNet_W18_Small_V1(**kwargs):
650
+ model = HRNet(
651
+ stage1_num_modules=1,
652
+ stage1_num_blocks=[1],
653
+ stage1_num_channels=[32],
654
+ stage2_num_modules=1,
655
+ stage2_num_blocks=[2, 2],
656
+ stage2_num_channels=[16, 32],
657
+ stage3_num_modules=1,
658
+ stage3_num_blocks=[2, 2, 2],
659
+ stage3_num_channels=[16, 32, 64],
660
+ stage4_num_modules=1,
661
+ stage4_num_blocks=[2, 2, 2, 2],
662
+ stage4_num_channels=[16, 32, 64, 128],
663
+ **kwargs)
664
+ return model
665
+
666
+
667
+ @manager.BACKBONES.add_component
668
+ def HRNet_W18_Small_V2(**kwargs):
669
+ model = HRNet(
670
+ stage1_num_modules=1,
671
+ stage1_num_blocks=[2],
672
+ stage1_num_channels=[64],
673
+ stage2_num_modules=1,
674
+ stage2_num_blocks=[2, 2],
675
+ stage2_num_channels=[18, 36],
676
+ stage3_num_modules=3,
677
+ stage3_num_blocks=[2, 2, 2],
678
+ stage3_num_channels=[18, 36, 72],
679
+ stage4_num_modules=2,
680
+ stage4_num_blocks=[2, 2, 2, 2],
681
+ stage4_num_channels=[18, 36, 72, 144],
682
+ **kwargs)
683
+ return model
684
+
685
+
686
+ @manager.BACKBONES.add_component
687
+ def HRNet_W18(**kwargs):
688
+ model = HRNet(
689
+ stage1_num_modules=1,
690
+ stage1_num_blocks=[4],
691
+         stage1_num_channels=[64],
+         stage2_num_modules=1,
+         stage2_num_blocks=[4, 4],
+         stage2_num_channels=[18, 36],
+         stage3_num_modules=4,
+         stage3_num_blocks=[4, 4, 4],
+         stage3_num_channels=[18, 36, 72],
+         stage4_num_modules=3,
+         stage4_num_blocks=[4, 4, 4, 4],
+         stage4_num_channels=[18, 36, 72, 144],
+         **kwargs)
+     return model
+
+
+ @manager.BACKBONES.add_component
+ def HRNet_W30(**kwargs):
+     model = HRNet(
+         stage1_num_modules=1,
+         stage1_num_blocks=[4],
+         stage1_num_channels=[64],
+         stage2_num_modules=1,
+         stage2_num_blocks=[4, 4],
+         stage2_num_channels=[30, 60],
+         stage3_num_modules=4,
+         stage3_num_blocks=[4, 4, 4],
+         stage3_num_channels=[30, 60, 120],
+         stage4_num_modules=3,
+         stage4_num_blocks=[4, 4, 4, 4],
+         stage4_num_channels=[30, 60, 120, 240],
+         **kwargs)
+     return model
+
+
+ @manager.BACKBONES.add_component
+ def HRNet_W32(**kwargs):
+     model = HRNet(
+         stage1_num_modules=1,
+         stage1_num_blocks=[4],
+         stage1_num_channels=[64],
+         stage2_num_modules=1,
+         stage2_num_blocks=[4, 4],
+         stage2_num_channels=[32, 64],
+         stage3_num_modules=4,
+         stage3_num_blocks=[4, 4, 4],
+         stage3_num_channels=[32, 64, 128],
+         stage4_num_modules=3,
+         stage4_num_blocks=[4, 4, 4, 4],
+         stage4_num_channels=[32, 64, 128, 256],
+         **kwargs)
+     return model
+
+
+ @manager.BACKBONES.add_component
+ def HRNet_W40(**kwargs):
+     model = HRNet(
+         stage1_num_modules=1,
+         stage1_num_blocks=[4],
+         stage1_num_channels=[64],
+         stage2_num_modules=1,
+         stage2_num_blocks=[4, 4],
+         stage2_num_channels=[40, 80],
+         stage3_num_modules=4,
+         stage3_num_blocks=[4, 4, 4],
+         stage3_num_channels=[40, 80, 160],
+         stage4_num_modules=3,
+         stage4_num_blocks=[4, 4, 4, 4],
+         stage4_num_channels=[40, 80, 160, 320],
+         **kwargs)
+     return model
+
+
+ @manager.BACKBONES.add_component
+ def HRNet_W44(**kwargs):
+     model = HRNet(
+         stage1_num_modules=1,
+         stage1_num_blocks=[4],
+         stage1_num_channels=[64],
+         stage2_num_modules=1,
+         stage2_num_blocks=[4, 4],
+         stage2_num_channels=[44, 88],
+         stage3_num_modules=4,
+         stage3_num_blocks=[4, 4, 4],
+         stage3_num_channels=[44, 88, 176],
+         stage4_num_modules=3,
+         stage4_num_blocks=[4, 4, 4, 4],
+         stage4_num_channels=[44, 88, 176, 352],
+         **kwargs)
+     return model
+
+
+ @manager.BACKBONES.add_component
+ def HRNet_W48(**kwargs):
+     model = HRNet(
+         stage1_num_modules=1,
+         stage1_num_blocks=[4],
+         stage1_num_channels=[64],
+         stage2_num_modules=1,
+         stage2_num_blocks=[4, 4],
+         stage2_num_channels=[48, 96],
+         stage3_num_modules=4,
+         stage3_num_blocks=[4, 4, 4],
+         stage3_num_channels=[48, 96, 192],
+         stage4_num_modules=3,
+         stage4_num_blocks=[4, 4, 4, 4],
+         stage4_num_channels=[48, 96, 192, 384],
+         **kwargs)
+     return model
+
+
+ @manager.BACKBONES.add_component
+ def HRNet_W60(**kwargs):
+     model = HRNet(
+         stage1_num_modules=1,
+         stage1_num_blocks=[4],
+         stage1_num_channels=[64],
+         stage2_num_modules=1,
+         stage2_num_blocks=[4, 4],
+         stage2_num_channels=[60, 120],
+         stage3_num_modules=4,
+         stage3_num_blocks=[4, 4, 4],
+         stage3_num_channels=[60, 120, 240],
+         stage4_num_modules=3,
+         stage4_num_blocks=[4, 4, 4, 4],
+         stage4_num_channels=[60, 120, 240, 480],
+         **kwargs)
+     return model
+
+
+ @manager.BACKBONES.add_component
+ def HRNet_W64(**kwargs):
+     model = HRNet(
+         stage1_num_modules=1,
+         stage1_num_blocks=[4],
+         stage1_num_channels=[64],
+         stage2_num_modules=1,
+         stage2_num_blocks=[4, 4],
+         stage2_num_channels=[64, 128],
+         stage3_num_modules=4,
+         stage3_num_blocks=[4, 4, 4],
+         stage3_num_channels=[64, 128, 256],
+         stage4_num_modules=3,
+         stage4_num_blocks=[4, 4, 4, 4],
+         stage4_num_channels=[64, 128, 256, 512],
+         **kwargs)
+     return model
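All of the `HRNet_Wxx` factories above configure the same `HRNet` class; only the per-stage channel widths change, with each lower-resolution branch doubling the width of the one above it. A small, self-contained sketch (plain Python, no PaddlePaddle needed; the widths are copied from the factories above) reproduces that progression:

```python
# The width "W" in each HRNet_Wxx factory is the channel count of the
# highest-resolution branch; every coarser branch doubles it, so stage 4
# always holds [W, 2W, 4W, 8W] channels.
for w in (18, 30, 32, 40, 44, 48, 60, 64):
    stage4_channels = [w * (2 ** i) for i in range(4)]
    print(f"HRNet_W{w}: stage4_num_channels = {stage4_channels}")
```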
ppmatting/models/backbone/mobilenet_v2.py ADDED
@@ -0,0 +1,242 @@
+ # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+
+ import numpy as np
+ import paddle
+ from paddle import ParamAttr
+ import paddle.nn as nn
+ import paddle.nn.functional as F
+ from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
+ from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
+
+ from paddleseg import utils
+ from paddleseg.cvlibs import manager
+
+ MODEL_URLS = {
+     "MobileNetV2_x0_25":
+     "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_x0_25_pretrained.pdparams",
+     "MobileNetV2_x0_5":
+     "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_x0_5_pretrained.pdparams",
+     "MobileNetV2_x0_75":
+     "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_x0_75_pretrained.pdparams",
+     "MobileNetV2":
+     "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_pretrained.pdparams",
+     "MobileNetV2_x1_5":
+     "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_x1_5_pretrained.pdparams",
+     "MobileNetV2_x2_0":
+     "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_x2_0_pretrained.pdparams"
+ }
+
+ __all__ = ["MobileNetV2"]
+
+
+ class ConvBNLayer(nn.Layer):
+     def __init__(self,
+                  num_channels,
+                  filter_size,
+                  num_filters,
+                  stride,
+                  padding,
+                  channels=None,
+                  num_groups=1,
+                  name=None,
+                  use_cudnn=True):
+         super(ConvBNLayer, self).__init__()
+
+         self._conv = Conv2D(
+             in_channels=num_channels,
+             out_channels=num_filters,
+             kernel_size=filter_size,
+             stride=stride,
+             padding=padding,
+             groups=num_groups,
+             weight_attr=ParamAttr(name=name + "_weights"),
+             bias_attr=False)
+
+         self._batch_norm = BatchNorm(
+             num_filters,
+             param_attr=ParamAttr(name=name + "_bn_scale"),
+             bias_attr=ParamAttr(name=name + "_bn_offset"),
+             moving_mean_name=name + "_bn_mean",
+             moving_variance_name=name + "_bn_variance")
+
+     def forward(self, inputs, if_act=True):
+         y = self._conv(inputs)
+         y = self._batch_norm(y)
+         if if_act:
+             y = F.relu6(y)
+         return y
+
+
+ class InvertedResidualUnit(nn.Layer):
+     def __init__(self, num_channels, num_in_filter, num_filters, stride,
+                  filter_size, padding, expansion_factor, name):
+         super(InvertedResidualUnit, self).__init__()
+         num_expfilter = int(round(num_in_filter * expansion_factor))
+         self._expand_conv = ConvBNLayer(
+             num_channels=num_channels,
+             num_filters=num_expfilter,
+             filter_size=1,
+             stride=1,
+             padding=0,
+             num_groups=1,
+             name=name + "_expand")
+
+         self._bottleneck_conv = ConvBNLayer(
+             num_channels=num_expfilter,
+             num_filters=num_expfilter,
+             filter_size=filter_size,
+             stride=stride,
+             padding=padding,
+             num_groups=num_expfilter,
+             use_cudnn=False,
+             name=name + "_dwise")
+
+         self._linear_conv = ConvBNLayer(
+             num_channels=num_expfilter,
+             num_filters=num_filters,
+             filter_size=1,
+             stride=1,
+             padding=0,
+             num_groups=1,
+             name=name + "_linear")
+
+     def forward(self, inputs, ifshortcut):
+         y = self._expand_conv(inputs, if_act=True)
+         y = self._bottleneck_conv(y, if_act=True)
+         y = self._linear_conv(y, if_act=False)
+         if ifshortcut:
+             y = paddle.add(inputs, y)
+         return y
+
+
+ class InvresiBlocks(nn.Layer):
+     def __init__(self, in_c, t, c, n, s, name):
+         super(InvresiBlocks, self).__init__()
+
+         self._first_block = InvertedResidualUnit(
+             num_channels=in_c,
+             num_in_filter=in_c,
+             num_filters=c,
+             stride=s,
+             filter_size=3,
+             padding=1,
+             expansion_factor=t,
+             name=name + "_1")
+
+         self._block_list = []
+         for i in range(1, n):
+             block = self.add_sublayer(
+                 name + "_" + str(i + 1),
+                 sublayer=InvertedResidualUnit(
+                     num_channels=c,
+                     num_in_filter=c,
+                     num_filters=c,
+                     stride=1,
+                     filter_size=3,
+                     padding=1,
+                     expansion_factor=t,
+                     name=name + "_" + str(i + 1)))
+             self._block_list.append(block)
+
+     def forward(self, inputs):
+         y = self._first_block(inputs, ifshortcut=False)
+         for block in self._block_list:
+             y = block(y, ifshortcut=True)
+         return y
+
+
+ @manager.BACKBONES.add_component
+ class MobileNet(nn.Layer):
+     def __init__(self,
+                  input_channels=3,
+                  scale=1.0,
+                  pretrained=None,
+                  prefix_name=""):
+         super(MobileNet, self).__init__()
+         self.scale = scale
+
+         bottleneck_params_list = [
+             (1, 16, 1, 1),
+             (6, 24, 2, 2),
+             (6, 32, 3, 2),
+             (6, 64, 4, 2),
+             (6, 96, 3, 1),
+             (6, 160, 3, 2),
+             (6, 320, 1, 1),
+         ]
+
+         self.conv1 = ConvBNLayer(
+             num_channels=input_channels,
+             num_filters=int(32 * scale),
+             filter_size=3,
+             stride=2,
+             padding=1,
+             name=prefix_name + "conv1_1")
+
+         self.block_list = []
+         i = 1
+         in_c = int(32 * scale)
+         for layer_setting in bottleneck_params_list:
+             t, c, n, s = layer_setting
+             i += 1
+             block = self.add_sublayer(
+                 prefix_name + "conv" + str(i),
+                 sublayer=InvresiBlocks(
+                     in_c=in_c,
+                     t=t,
+                     c=int(c * scale),
+                     n=n,
+                     s=s,
+                     name=prefix_name + "conv" + str(i)))
+             self.block_list.append(block)
+             in_c = int(c * scale)
+
+         self.out_c = int(1280 * scale) if scale > 1.0 else 1280
+         self.conv9 = ConvBNLayer(
+             num_channels=in_c,
+             num_filters=self.out_c,
+             filter_size=1,
+             stride=1,
+             padding=0,
+             name=prefix_name + "conv9")
+
+         self.feat_channels = [int(i * scale) for i in [16, 24, 32, 96, 1280]]
+         self.pretrained = pretrained
+         self.init_weight()
+
+     def forward(self, inputs):
+         feat_list = []
+         y = self.conv1(inputs, if_act=True)
+
+         block_index = 0
+         for block in self.block_list:
+             y = block(y)
+             if block_index in [0, 1, 2, 4]:
+                 feat_list.append(y)
+             block_index += 1
+         y = self.conv9(y, if_act=True)
+         feat_list.append(y)
+         return feat_list
+
+     def init_weight(self):
+         utils.load_pretrained_model(self, self.pretrained)
+
+
+ @manager.BACKBONES.add_component
+ def MobileNetV2(**kwargs):
+     model = MobileNet(scale=1.0, **kwargs)
+     return model
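The `bottleneck_params_list` above fully determines which blocks feed `feat_list`: the forward pass taps blocks 0, 1, 2 and 4, plus the final `conv9` output. A quick stride walk (plain Python; the tuples are copied from the constructor above) shows why those indices correspond to the 1/2, 1/4, 1/8 and 1/16 resolution feature maps that match `feat_channels = [16, 24, 32, 96, 1280]`:

```python
# Walk bottleneck_params_list the way MobileNet.__init__ does, tracking the
# cumulative downsampling; conv1 (stride 2) has already halved the input.
bottleneck_params_list = [
    # (expansion t, out channels c, repeats n, first-block stride s)
    (1, 16, 1, 1),
    (6, 24, 2, 2),
    (6, 32, 3, 2),
    (6, 64, 4, 2),
    (6, 96, 3, 1),
    (6, 160, 3, 2),
    (6, 320, 1, 1),
]

stride = 2  # after conv1
for idx, (t, c, n, s) in enumerate(bottleneck_params_list):
    stride *= s
    tapped = idx in (0, 1, 2, 4)
    print(f"block {idx}: c={c:4d} n={n} resolution=1/{stride} tapped={tapped}")
```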
ppmatting/models/backbone/resnet_vd.py ADDED
@@ -0,0 +1,368 @@
+ # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import paddle
+ import paddle.nn as nn
+ import paddle.nn.functional as F
+
+ from paddleseg.cvlibs import manager
+ from paddleseg.models import layers
+ from paddleseg.utils import utils
+
+ __all__ = [
+     "ResNet18_vd", "ResNet34_vd", "ResNet50_vd", "ResNet101_vd", "ResNet152_vd"
+ ]
+
+
+ class ConvBNLayer(nn.Layer):
+     def __init__(
+             self,
+             in_channels,
+             out_channels,
+             kernel_size,
+             stride=1,
+             dilation=1,
+             groups=1,
+             is_vd_mode=False,
+             act=None, ):
+         super(ConvBNLayer, self).__init__()
+
+         self.is_vd_mode = is_vd_mode
+         self._pool2d_avg = nn.AvgPool2D(
+             kernel_size=2, stride=2, padding=0, ceil_mode=True)
+         self._conv = nn.Conv2D(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=(kernel_size - 1) // 2 if dilation == 1 else 0,
+             dilation=dilation,
+             groups=groups,
+             bias_attr=False)
+
+         self._batch_norm = layers.SyncBatchNorm(out_channels)
+         self._act_op = layers.Activation(act=act)
+
+     def forward(self, inputs):
+         if self.is_vd_mode:
+             inputs = self._pool2d_avg(inputs)
+         y = self._conv(inputs)
+         y = self._batch_norm(y)
+         y = self._act_op(y)
+
+         return y
+
+
+ class BottleneckBlock(nn.Layer):
+     def __init__(self,
+                  in_channels,
+                  out_channels,
+                  stride,
+                  shortcut=True,
+                  if_first=False,
+                  dilation=1):
+         super(BottleneckBlock, self).__init__()
+
+         self.conv0 = ConvBNLayer(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             kernel_size=1,
+             act='relu')
+
+         self.dilation = dilation
+
+         self.conv1 = ConvBNLayer(
+             in_channels=out_channels,
+             out_channels=out_channels,
+             kernel_size=3,
+             stride=stride,
+             act='relu',
+             dilation=dilation)
+         self.conv2 = ConvBNLayer(
+             in_channels=out_channels,
+             out_channels=out_channels * 4,
+             kernel_size=1,
+             act=None)
+
+         if not shortcut:
+             self.short = ConvBNLayer(
+                 in_channels=in_channels,
+                 out_channels=out_channels * 4,
+                 kernel_size=1,
+                 stride=1,
+                 is_vd_mode=False if if_first or stride == 1 else True)
+
+         self.shortcut = shortcut
+
+     def forward(self, inputs):
+         y = self.conv0(inputs)
+
+         ####################################################################
+         # If a dilation rate > 1 is given, apply the corresponding padding;
+         # performance drops without this padding.
+         if self.dilation > 1:
+             padding = self.dilation
+             y = F.pad(y, [padding, padding, padding, padding])
+         #####################################################################
+
+         conv1 = self.conv1(y)
+         conv2 = self.conv2(conv1)
+
+         if self.shortcut:
+             short = inputs
+         else:
+             short = self.short(inputs)
+
+         y = paddle.add(x=short, y=conv2)
+         y = F.relu(y)
+         return y
+
+
+ class BasicBlock(nn.Layer):
+     def __init__(self,
+                  in_channels,
+                  out_channels,
+                  stride,
+                  shortcut=True,
+                  if_first=False):
+         super(BasicBlock, self).__init__()
+         self.stride = stride
+         self.conv0 = ConvBNLayer(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             kernel_size=3,
+             stride=stride,
+             act='relu')
+         self.conv1 = ConvBNLayer(
+             in_channels=out_channels,
+             out_channels=out_channels,
+             kernel_size=3,
+             act=None)
+
+         if not shortcut:
+             self.short = ConvBNLayer(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=1,
+                 stride=1,
+                 is_vd_mode=False if if_first or stride == 1 else True)
+
+         self.shortcut = shortcut
+
+     def forward(self, inputs):
+         y = self.conv0(inputs)
+         conv1 = self.conv1(y)
+
+         if self.shortcut:
+             short = inputs
+         else:
+             short = self.short(inputs)
+         y = paddle.add(x=short, y=conv1)
+         y = F.relu(y)
+
+         return y
+
+
+ class ResNet_vd(nn.Layer):
+     """
+     The ResNet_vd implementation based on PaddlePaddle.
+
+     The original article refers to
+     Tong He, et al. "Bag of Tricks for Image Classification with Convolutional Neural Networks"
+     (https://arxiv.org/pdf/1812.01187.pdf).
+
+     Args:
+         layers (int, optional): The layers of ResNet_vd. The supported layers are (18, 34, 50, 101, 152, 200). Default: 50.
+         output_stride (int, optional): The stride of output features compared to input images. It should be 8, 16 or 32. Default: 32.
+         multi_grid (tuple|list, optional): The grid of stage4. Default: (1, 1, 1).
+         pretrained (str, optional): The path of the pretrained model.
+
+     """
+
+     def __init__(self,
+                  input_channels=3,
+                  layers=50,
+                  output_stride=32,
+                  multi_grid=(1, 1, 1),
+                  pretrained=None):
+         super(ResNet_vd, self).__init__()
+
+         self.conv1_logit = None  # for gscnn shape stream
+         self.layers = layers
+         supported_layers = [18, 34, 50, 101, 152, 200]
+         assert layers in supported_layers, \
+             "supported layers are {} but input layer is {}".format(
+                 supported_layers, layers)
+
+         if layers == 18:
+             depth = [2, 2, 2, 2]
+         elif layers == 34 or layers == 50:
+             depth = [3, 4, 6, 3]
+         elif layers == 101:
+             depth = [3, 4, 23, 3]
+         elif layers == 152:
+             depth = [3, 8, 36, 3]
+         elif layers == 200:
+             depth = [3, 12, 48, 3]
+         num_channels = [64, 256, 512,
+                         1024] if layers >= 50 else [64, 64, 128, 256]
+         num_filters = [64, 128, 256, 512]
+
+         # channels of the four returned stages
+         self.feat_channels = [c * 4 for c in num_filters
+                               ] if layers >= 50 else num_filters
+         self.feat_channels = [64] + self.feat_channels
+
+         dilation_dict = None
+         if output_stride == 8:
+             dilation_dict = {2: 2, 3: 4}
+         elif output_stride == 16:
+             dilation_dict = {3: 2}
+
+         self.conv1_1 = ConvBNLayer(
+             in_channels=input_channels,
+             out_channels=32,
+             kernel_size=3,
+             stride=2,
+             act='relu')
+         self.conv1_2 = ConvBNLayer(
+             in_channels=32,
+             out_channels=32,
+             kernel_size=3,
+             stride=1,
+             act='relu')
+         self.conv1_3 = ConvBNLayer(
+             in_channels=32,
+             out_channels=64,
+             kernel_size=3,
+             stride=1,
+             act='relu')
+         self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
+
+         # self.block_list = []
+         self.stage_list = []
+         if layers >= 50:
+             for block in range(len(depth)):
+                 shortcut = False
+                 block_list = []
+                 for i in range(depth[block]):
+                     if layers in [101, 152] and block == 2:
+                         if i == 0:
+                             conv_name = "res" + str(block + 2) + "a"
+                         else:
+                             conv_name = "res" + str(block + 2) + "b" + str(i)
+                     else:
+                         conv_name = "res" + str(block + 2) + chr(97 + i)
+
+                     ###############################################################################
+                     # Add a dilation rate for some segmentation tasks if dilation_dict is not None.
+                     dilation_rate = dilation_dict[
+                         block] if dilation_dict and block in dilation_dict else 1
+
+                     # 'block' here is actually a stage, and 'i' is a block within that stage.
+                     # At stage 4, expand the dilation_rate if multi_grid is given.
+                     if block == 3:
+                         dilation_rate = dilation_rate * multi_grid[i]
+                     ###############################################################################
+
+                     bottleneck_block = self.add_sublayer(
+                         'bb_%d_%d' % (block, i),
+                         BottleneckBlock(
+                             in_channels=num_channels[block]
+                             if i == 0 else num_filters[block] * 4,
+                             out_channels=num_filters[block],
+                             stride=2 if i == 0 and block != 0 and
+                             dilation_rate == 1 else 1,
+                             shortcut=shortcut,
+                             if_first=block == i == 0,
+                             dilation=dilation_rate))
+
+                     block_list.append(bottleneck_block)
+                     shortcut = True
+                 self.stage_list.append(block_list)
+         else:
+             for block in range(len(depth)):
+                 shortcut = False
+                 block_list = []
+                 for i in range(depth[block]):
+                     conv_name = "res" + str(block + 2) + chr(97 + i)
+                     basic_block = self.add_sublayer(
+                         'bb_%d_%d' % (block, i),
+                         BasicBlock(
+                             in_channels=num_channels[block]
+                             if i == 0 else num_filters[block],
+                             out_channels=num_filters[block],
+                             stride=2 if i == 0 and block != 0 else 1,
+                             shortcut=shortcut,
+                             if_first=block == i == 0))
+                     block_list.append(basic_block)
+                     shortcut = True
+                 self.stage_list.append(block_list)
+
+         self.pretrained = pretrained
+         self.init_weight()
+
+     def forward(self, inputs):
+         feat_list = []
+         y = self.conv1_1(inputs)
+         y = self.conv1_2(y)
+         y = self.conv1_3(y)
+         feat_list.append(y)
+
+         y = self.pool2d_max(y)
+
+         # The feature list saves the output feature map of each stage.
+         for stage in self.stage_list:
+             for block in stage:
+                 y = block(y)
+             feat_list.append(y)
+
+         return feat_list
+
+     def init_weight(self):
+         utils.load_pretrained_model(self, self.pretrained)
+
+
+ @manager.BACKBONES.add_component
+ def ResNet18_vd(**args):
+     model = ResNet_vd(layers=18, **args)
+     return model
+
+
+ @manager.BACKBONES.add_component
+ def ResNet34_vd(**args):
+     model = ResNet_vd(layers=34, **args)
+     return model
+
+
+ @manager.BACKBONES.add_component
+ def ResNet50_vd(**args):
+     model = ResNet_vd(layers=50, **args)
+     return model
+
+
+ @manager.BACKBONES.add_component
+ def ResNet101_vd(**args):
+     model = ResNet_vd(layers=101, **args)
+     return model
+
+
+ def ResNet152_vd(**args):
+     model = ResNet_vd(layers=152, **args)
+     return model
+
+
+ def ResNet200_vd(**args):
+     model = ResNet_vd(layers=200, **args)
+     return model
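The `output_stride` handling above is the part that differs from a vanilla classification ResNet: stages listed in `dilation_dict` swap their stride-2 downsampling for dilation. A minimal sketch of just that bookkeeping (plain Python; it ignores `multi_grid` and the per-block details for brevity):

```python
# Reproduce the dilation_dict logic from ResNet_vd.__init__: a stage with a
# dilation entry keeps its resolution, every other non-first stage halves it.
def stage_strides(output_stride):
    dilation_dict = {}
    if output_stride == 8:
        dilation_dict = {2: 2, 3: 4}
    elif output_stride == 16:
        dilation_dict = {3: 2}
    stride = 4  # stem convs (stride 2) + max pool (stride 2)
    stages = []
    for block in range(4):
        dilation = dilation_dict.get(block, 1)
        if block != 0 and dilation == 1:
            stride *= 2
        stages.append((block, f"1/{stride}", dilation))
    return stages

for os_ in (8, 16, 32):
    print(os_, stage_strides(os_))
```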
ppmatting/models/backbone/vgg.py ADDED
@@ -0,0 +1,166 @@
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import paddle
+ from paddle import ParamAttr
+ import paddle.nn as nn
+ import paddle.nn.functional as F
+ from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
+ from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
+
+ from paddleseg.cvlibs import manager
+ from paddleseg.utils import utils
+
+
+ class ConvBlock(nn.Layer):
+     def __init__(self, input_channels, output_channels, groups, name=None):
+         super(ConvBlock, self).__init__()
+
+         self.groups = groups
+         self._conv_1 = Conv2D(
+             in_channels=input_channels,
+             out_channels=output_channels,
+             kernel_size=3,
+             stride=1,
+             padding=1,
+             weight_attr=ParamAttr(name=name + "1_weights"),
+             bias_attr=False)
+         if groups == 2 or groups == 3 or groups == 4:
+             self._conv_2 = Conv2D(
+                 in_channels=output_channels,
+                 out_channels=output_channels,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1,
+                 weight_attr=ParamAttr(name=name + "2_weights"),
+                 bias_attr=False)
+         if groups == 3 or groups == 4:
+             self._conv_3 = Conv2D(
+                 in_channels=output_channels,
+                 out_channels=output_channels,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1,
+                 weight_attr=ParamAttr(name=name + "3_weights"),
+                 bias_attr=False)
+         if groups == 4:
+             self._conv_4 = Conv2D(
+                 in_channels=output_channels,
+                 out_channels=output_channels,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1,
+                 weight_attr=ParamAttr(name=name + "4_weights"),
+                 bias_attr=False)
+
+         self._pool = MaxPool2D(
+             kernel_size=2, stride=2, padding=0, return_mask=True)
+
+     def forward(self, inputs):
+         x = self._conv_1(inputs)
+         x = F.relu(x)
+         if self.groups == 2 or self.groups == 3 or self.groups == 4:
+             x = self._conv_2(x)
+             x = F.relu(x)
+         if self.groups == 3 or self.groups == 4:
+             x = self._conv_3(x)
+             x = F.relu(x)
+         if self.groups == 4:
+             x = self._conv_4(x)
+             x = F.relu(x)
+         skip = x
+         x, max_indices = self._pool(x)
+         return x, max_indices, skip
+
+
+ class VGGNet(nn.Layer):
+     def __init__(self, input_channels=3, layers=11, pretrained=None):
+         super(VGGNet, self).__init__()
+         self.pretrained = pretrained
+
+         self.layers = layers
+         self.vgg_configure = {
+             11: [1, 1, 2, 2, 2],
+             13: [2, 2, 2, 2, 2],
+             16: [2, 2, 3, 3, 3],
+             19: [2, 2, 4, 4, 4]
+         }
+         assert self.layers in self.vgg_configure.keys(), \
+             "supported layers are {} but input layer is {}".format(
+                 self.vgg_configure.keys(), layers)
+         self.groups = self.vgg_configure[self.layers]
+
+         # For matting, the first conv block takes a 4-channel input; the extra
+         # channel's weights are simply initialized to zero.
+         self._conv_block_1 = ConvBlock(
+             input_channels, 64, self.groups[0], name="conv1_")
+         self._conv_block_2 = ConvBlock(64, 128, self.groups[1], name="conv2_")
+         self._conv_block_3 = ConvBlock(128, 256, self.groups[2], name="conv3_")
+         self._conv_block_4 = ConvBlock(256, 512, self.groups[3], name="conv4_")
+         self._conv_block_5 = ConvBlock(512, 512, self.groups[4], name="conv5_")
+
+         # This layer should be initialized from the converted weights of VGG's
+         # fc6; that initialization can be skipped for now.
+         self._conv_6 = Conv2D(
+             512, 512, kernel_size=3, padding=1, bias_attr=False)
+
+         self.init_weight()
+
+     def forward(self, inputs):
+         fea_list = []
+         ids_list = []
+         x, ids, skip = self._conv_block_1(inputs)
+         fea_list.append(skip)
+         ids_list.append(ids)
+         x, ids, skip = self._conv_block_2(x)
+         fea_list.append(skip)
+         ids_list.append(ids)
+         x, ids, skip = self._conv_block_3(x)
+         fea_list.append(skip)
+         ids_list.append(ids)
+         x, ids, skip = self._conv_block_4(x)
+         fea_list.append(skip)
+         ids_list.append(ids)
+         x, ids, skip = self._conv_block_5(x)
+         fea_list.append(skip)
+         ids_list.append(ids)
+         x = F.relu(self._conv_6(x))
+         fea_list.append(x)
+         return fea_list
+
+     def init_weight(self):
+         if self.pretrained is not None:
+             utils.load_pretrained_model(self, self.pretrained)
+
+
+ @manager.BACKBONES.add_component
+ def VGG11(**args):
+     model = VGGNet(layers=11, **args)
+     return model
+
+
+ @manager.BACKBONES.add_component
+ def VGG13(**args):
+     model = VGGNet(layers=13, **args)
+     return model
+
+
+ @manager.BACKBONES.add_component
+ def VGG16(**args):
+     model = VGGNet(layers=16, **args)
+     return model
+
+
+ @manager.BACKBONES.add_component
+ def VGG19(**args):
+     model = VGGNet(layers=19, **args)
+     return model
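A quick sanity check of `vgg_configure` above (plain Python): summing the per-block conv counts and adding VGG's three original fully connected layers recovers the depth in each variant's name, which is why only 11, 13, 16 and 19 are accepted:

```python
vgg_configure = {
    11: [1, 1, 2, 2, 2],
    13: [2, 2, 2, 2, 2],
    16: [2, 2, 3, 3, 3],
    19: [2, 2, 4, 4, 4],
}
for depth, groups in vgg_configure.items():
    # conv layers per block + the 3 fc layers of the original VGG = depth
    assert sum(groups) + 3 == depth, (depth, groups)
    print(f"VGG{depth}: convs per block {groups}, total convs {sum(groups)}")
```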
ppmatting/models/dim.py ADDED
@@ -0,0 +1,208 @@
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from collections import defaultdict
+ import paddle
+ import paddle.nn as nn
+ import paddle.nn.functional as F
+ from paddleseg.models import layers
+ from paddleseg import utils
+ from paddleseg.cvlibs import manager
+
+ from ppmatting.models.losses import MRSD
+
+
+ @manager.MODELS.add_component
+ class DIM(nn.Layer):
+     """
+     The DIM implementation based on PaddlePaddle.
+
+     The original article refers to
+     Ning Xu, et al. "Deep Image Matting"
+     (https://arxiv.org/pdf/1703.03872.pdf).
+
+     Args:
+         backbone: backbone model.
+         stage (int, optional): The training stage of the model. Default: 3.
+         decoder_input_channels (int, optional): The channel of the decoder input. Default: 512.
+         pretrained (str, optional): The path of the pretrained model. Default: None.
+
+     """
+
+     def __init__(self,
+                  backbone,
+                  stage=3,
+                  decoder_input_channels=512,
+                  pretrained=None):
+         super().__init__()
+         self.backbone = backbone
+         self.pretrained = pretrained
+         self.stage = stage
+         self.loss_func_dict = None
+
+         decoder_output_channels = [64, 128, 256, 512]
+         self.decoder = Decoder(
+             input_channels=decoder_input_channels,
+             output_channels=decoder_output_channels)
+         if self.stage == 2:
+             for param in self.backbone.parameters():
+                 param.stop_gradient = True
+             for param in self.decoder.parameters():
+                 param.stop_gradient = True
+         if self.stage >= 2:
+             self.refine = Refine()
+         self.init_weight()
+
+     def forward(self, inputs):
+         input_shape = paddle.shape(inputs['img'])[-2:]
+         x = paddle.concat([inputs['img'], inputs['trimap'] / 255], axis=1)
+         fea_list = self.backbone(x)
+
+         # decoder stage
+         up_shape = []
+         for i in range(5):
+             up_shape.append(paddle.shape(fea_list[i])[-2:])
+         alpha_raw = self.decoder(fea_list, up_shape)
+         alpha_raw = F.interpolate(
+             alpha_raw, input_shape, mode='bilinear', align_corners=False)
+         logit_dict = {'alpha_raw': alpha_raw}
+         if self.stage < 2:
+             return logit_dict
+
+         if self.stage >= 2:
+             # refine stage
+             refine_input = paddle.concat([inputs['img'], alpha_raw], axis=1)
+             alpha_refine = self.refine(refine_input)
+
+             # final alpha
+             alpha_pred = alpha_refine + alpha_raw
+             alpha_pred = F.interpolate(
+                 alpha_pred, input_shape, mode='bilinear', align_corners=False)
+             if not self.training:
+                 alpha_pred = paddle.clip(alpha_pred, min=0, max=1)
+             logit_dict['alpha_pred'] = alpha_pred
+         if self.training:
+             loss_dict = self.loss(logit_dict, inputs)
+             return logit_dict, loss_dict
+         else:
+             return alpha_pred
+
+     def loss(self, logit_dict, label_dict, loss_func_dict=None):
+         if loss_func_dict is None:
+             if self.loss_func_dict is None:
+                 self.loss_func_dict = defaultdict(list)
+                 self.loss_func_dict['alpha_raw'].append(MRSD())
+                 self.loss_func_dict['comp'].append(MRSD())
+                 self.loss_func_dict['alpha_pred'].append(MRSD())
+         else:
+             self.loss_func_dict = loss_func_dict
+
+         loss = {}
+         mask = label_dict['trimap'] == 128
+         loss['all'] = 0
+
+         if self.stage != 2:
+             loss['alpha_raw'] = self.loss_func_dict['alpha_raw'][0](
+                 logit_dict['alpha_raw'], label_dict['alpha'], mask)
+             loss['alpha_raw'] = 0.5 * loss['alpha_raw']
+             loss['all'] = loss['all'] + loss['alpha_raw']
+
+         if self.stage == 1 or self.stage == 3:
+             comp_pred = logit_dict['alpha_raw'] * label_dict['fg'] + \
+                 (1 - logit_dict['alpha_raw']) * label_dict['bg']
+             loss['comp'] = self.loss_func_dict['comp'][0](
+                 comp_pred, label_dict['img'], mask)
+             loss['comp'] = 0.5 * loss['comp']
+             loss['all'] = loss['all'] + loss['comp']
+
+         if self.stage == 2 or self.stage == 3:
+             loss['alpha_pred'] = self.loss_func_dict['alpha_pred'][0](
+                 logit_dict['alpha_pred'], label_dict['alpha'], mask)
+             loss['all'] = loss['all'] + loss['alpha_pred']
+
+         return loss
+
+     def init_weight(self):
+         if self.pretrained is not None:
+             utils.load_entire_model(self, self.pretrained)
+
+
+ # Upsample bilinearly, add the skip connection, then convolve.
+ class Up(nn.Layer):
+     def __init__(self, input_channels, output_channels):
+         super().__init__()
+         self.conv = layers.ConvBNReLU(
+             input_channels,
+             output_channels,
+             kernel_size=5,
+             padding=2,
+             bias_attr=False)
+
+     def forward(self, x, skip, output_shape):
+         x = F.interpolate(
+             x, size=output_shape, mode='bilinear', align_corners=False)
+         x = x + skip
+         x = self.conv(x)
+         x = F.relu(x)
+
+         return x
+
+
+ class Decoder(nn.Layer):
+     def __init__(self, input_channels, output_channels=(64, 128, 256, 512)):
+         super().__init__()
+         self.deconv6 = nn.Conv2D(
+             input_channels, input_channels, kernel_size=1, bias_attr=False)
+         self.deconv5 = Up(input_channels, output_channels[-1])
+         self.deconv4 = Up(output_channels[-1], output_channels[-2])
+         self.deconv3 = Up(output_channels[-2], output_channels[-3])
+         self.deconv2 = Up(output_channels[-3], output_channels[-4])
+         self.deconv1 = Up(output_channels[-4], 64)
+
+         self.alpha_conv = nn.Conv2D(
+             64, 1, kernel_size=5, padding=2, bias_attr=False)
+
+     def forward(self, fea_list, shape_list):
+         x = fea_list[-1]
+         x = self.deconv6(x)
+         x = self.deconv5(x, fea_list[4], shape_list[4])
+         x = self.deconv4(x, fea_list[3], shape_list[3])
+         x = self.deconv3(x, fea_list[2], shape_list[2])
+         x = self.deconv2(x, fea_list[1], shape_list[1])
+         x = self.deconv1(x, fea_list[0], shape_list[0])
+         alpha = self.alpha_conv(x)
+         alpha = F.sigmoid(alpha)
+
+         return alpha
+
+
+ class Refine(nn.Layer):
+     def __init__(self):
+         super().__init__()
+         self.conv1 = layers.ConvBNReLU(
+             4, 64, kernel_size=3, padding=1, bias_attr=False)
+         self.conv2 = layers.ConvBNReLU(
+             64, 64, kernel_size=3, padding=1, bias_attr=False)
+         self.conv3 = layers.ConvBNReLU(
+             64, 64, kernel_size=3, padding=1, bias_attr=False)
+         self.alpha_pred = layers.ConvBNReLU(
+             64, 1, kernel_size=3, padding=1, bias_attr=False)
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.conv2(x)
+         x = self.conv3(x)
+         alpha = self.alpha_pred(x)
+
+         return alpha
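A small numeric sketch of the stage-1/stage-3 composition term in `DIM.loss` above (numpy only; the real criterion is the `MRSD` loss from `ppmatting.models.losses`, so the root-squared error below is only an illustrative stand-in): the predicted alpha recomposes the image from the ground-truth foreground and background, and the error is measured only inside the unknown (trimap == 128) region.

```python
import numpy as np

alpha_pred = np.array([[0.2, 0.8], [1.0, 0.0]])   # predicted alpha matte
fg = np.full((2, 2), 0.9)                          # ground-truth foreground
bg = np.full((2, 2), 0.1)                          # ground-truth background
img = np.array([[0.3, 0.7], [0.9, 0.1]])           # observed composite
trimap = np.array([[128, 128], [255, 0]])

comp_pred = alpha_pred * fg + (1 - alpha_pred) * bg
mask = (trimap == 128).astype(np.float64)          # unknown region only
err = np.sqrt((comp_pred - img) ** 2 + 1e-12) * mask
print(err.sum() / (mask.sum() + 1e-12))            # ~0.04 for this toy input
```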
ppmatting/models/gca.py ADDED
@@ -0,0 +1,305 @@
+ # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # The GCA code was heavily based on https://github.com/Yaoyi-Li/GCA-Matting
+ # and https://github.com/open-mmlab/mmediting
+
+ import paddle
+ import paddle.nn as nn
+ import paddle.nn.functional as F
+ from paddleseg.models import layers
+ from paddleseg import utils
+ from paddleseg.cvlibs import manager, param_init
+
+ from ppmatting.models.layers import GuidedCxtAtten
+
+
+ @manager.MODELS.add_component
+ class GCABaseline(nn.Layer):
+     def __init__(self, backbone, pretrained=None):
+         super().__init__()
+         self.encoder = backbone
+         self.decoder = ResShortCut_D_Dec([2, 3, 3, 2])
+
+     def forward(self, inputs):
+         x = paddle.concat([inputs['img'], inputs['trimap'] / 255], axis=1)
+         embedding, mid_fea = self.encoder(x)
+         alpha_pred = self.decoder(embedding, mid_fea)
+
+         if self.training:
+             logit_dict = {'alpha_pred': alpha_pred, }
+             loss_dict = {}
+             alpha_gt = inputs['alpha']
+             loss_dict["alpha"] = F.l1_loss(alpha_pred, alpha_gt)
+             loss_dict["all"] = loss_dict["alpha"]
+             return logit_dict, loss_dict
+
+         return alpha_pred
+
+
+ @manager.MODELS.add_component
+ class GCA(GCABaseline):
+     def __init__(self, backbone, pretrained=None):
+         super().__init__(backbone, pretrained)
+         self.decoder = ResGuidedCxtAtten_Dec([2, 3, 3, 2])
+
+
+ def conv5x5(in_planes, out_planes, stride=1, groups=1, dilation=1):
+     """5x5 convolution with padding"""
+     return nn.Conv2D(
+         in_planes,
+         out_planes,
+         kernel_size=5,
+         stride=stride,
+         padding=2,
+         groups=groups,
+         bias_attr=False,
+         dilation=dilation)
+
+
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+     """3x3 convolution with padding"""
+     return nn.Conv2D(
+         in_planes,
+         out_planes,
+         kernel_size=3,
+         stride=stride,
+         padding=dilation,
+         groups=groups,
+         bias_attr=False,
+         dilation=dilation)
+
+
+ def conv1x1(in_planes, out_planes, stride=1):
+     """1x1 convolution"""
+     return nn.Conv2D(
+         in_planes, out_planes, kernel_size=1, stride=stride, bias_attr=False)
+
+
+ class BasicBlock(nn.Layer):
+     expansion = 1
+
+     def __init__(self,
+                  inplanes,
+                  planes,
+                  stride=1,
+                  upsample=None,
+                  norm_layer=None,
+                  large_kernel=False):
+         super().__init__()
+         if norm_layer is None:
+             norm_layer = nn.BatchNorm
+         self.stride = stride
+         conv = conv5x5 if large_kernel else conv3x3
+         # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+         if self.stride > 1:
+             self.conv1 = nn.utils.spectral_norm(
+                 nn.Conv2DTranspose(
+                     inplanes,
+                     inplanes,
+                     kernel_size=4,
+                     stride=2,
+                     padding=1,
+                     bias_attr=False))
+         else:
+             self.conv1 = nn.utils.spectral_norm(conv(inplanes, inplanes))
+         self.bn1 = norm_layer(inplanes)
+         self.activation = nn.LeakyReLU(0.2)
+         self.conv2 = nn.utils.spectral_norm(conv(inplanes, planes))
+         self.bn2 = norm_layer(planes)
+         self.upsample = upsample
+
+     def forward(self, x):
+         identity = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.activation(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+
+         if self.upsample is not None:
+             identity = self.upsample(x)
+
+         out += identity
+         out = self.activation(out)
+
+         return out
+
+
+ class ResNet_D_Dec(nn.Layer):
+     def __init__(self,
+                  layers=[3, 4, 4, 2],
+                  norm_layer=None,
+                  large_kernel=False,
+                  late_downsample=False):
+         super().__init__()
+
+         if norm_layer is None:
+             norm_layer = nn.BatchNorm
+         self._norm_layer = norm_layer
+         self.large_kernel = large_kernel
+         self.kernel_size = 5 if self.large_kernel else 3
+
+         self.inplanes = 512 if layers[0] > 0 else 256
+         self.late_downsample = late_downsample
+         self.midplanes = 64 if late_downsample else 32
+
+         self.conv1 = nn.utils.spectral_norm(
+             nn.Conv2DTranspose(
+                 self.midplanes,
+                 32,
+                 kernel_size=4,
+                 stride=2,
+                 padding=1,
+                 bias_attr=False))
+         self.bn1 = norm_layer(32)
+         self.leaky_relu = nn.LeakyReLU(0.2)
+         self.conv2 = nn.Conv2D(
+             32,
+             1,
+             kernel_size=self.kernel_size,
+             stride=1,
+             padding=self.kernel_size // 2)
+         self.upsample = nn.UpsamplingNearest2D(scale_factor=2)
+         self.tanh = nn.Tanh()
+         self.layer1 = self._make_layer(BasicBlock, 256, layers[0], stride=2)
+         self.layer2 = self._make_layer(BasicBlock, 128, layers[1], stride=2)
+         self.layer3 = self._make_layer(BasicBlock, 64, layers[2], stride=2)
+         self.layer4 = self._make_layer(
+             BasicBlock, self.midplanes, layers[3], stride=2)
+
+         self.init_weight()
+
+     def _make_layer(self, block, planes, blocks, stride=1):
+         if blocks == 0:
+             return nn.Sequential(nn.Identity())
+         norm_layer = self._norm_layer
+         upsample = None
+         if stride != 1:
+             upsample = nn.Sequential(
+                 nn.UpsamplingNearest2D(scale_factor=2),
+                 nn.utils.spectral_norm(
+                     conv1x1(self.inplanes, planes * block.expansion)),
+                 norm_layer(planes * block.expansion), )
+         elif self.inplanes != planes * block.expansion:
+             upsample = nn.Sequential(
+                 nn.utils.spectral_norm(
+                     conv1x1(self.inplanes, planes * block.expansion)),
+                 norm_layer(planes * block.expansion), )
+
+         layers = [
+             block(self.inplanes, planes, stride, upsample, norm_layer,
+                   self.large_kernel)
+         ]
+         self.inplanes = planes * block.expansion
+         for _ in range(1, blocks):
+             layers.append(
+                 block(
+                     self.inplanes,
+                     planes,
+                     norm_layer=norm_layer,
+                     large_kernel=self.large_kernel))
+
+         return nn.Sequential(*layers)
+
+     def forward(self, x, mid_fea):
+         x = self.layer1(x)  # N x 256 x 32 x 32
+         x = self.layer2(x)  # N x 128 x 64 x 64
+         x = self.layer3(x)  # N x 64 x 128 x 128
+         x = self.layer4(x)  # N x 32 x 256 x 256
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.leaky_relu(x)
+         x = self.conv2(x)
+
+         alpha = (self.tanh(x) + 1.0) / 2.0
+
+         return alpha
+
+     def init_weight(self):
+         for layer in self.sublayers():
+             if isinstance(layer, nn.Conv2D):
+                 if hasattr(layer, "weight_orig"):
+                     param = layer.weight_orig
+                 else:
+                     param = layer.weight
+                 param_init.xavier_uniform(param)
+
+             elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)):
+                 param_init.constant_init(layer.weight, value=1.0)
+                 param_init.constant_init(layer.bias, value=0.0)
+
+             elif isinstance(layer, BasicBlock):
+                 param_init.constant_init(layer.bn2.weight, value=0.0)
+
+
+ class ResShortCut_D_Dec(ResNet_D_Dec):
+     def __init__(self,
+                  layers,
+                  norm_layer=None,
+                  large_kernel=False,
+                  late_downsample=False):
+         super().__init__(
+             layers, norm_layer, large_kernel, late_downsample=late_downsample)
+
+     def forward(self, x, mid_fea):
+         fea1, fea2, fea3, fea4, fea5 = mid_fea['shortcut']
+         x = self.layer1(x) + fea5
+         x = self.layer2(x) + fea4
+         x = self.layer3(x) + fea3
+         x = self.layer4(x) + fea2
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.leaky_relu(x) + fea1
+         x = self.conv2(x)
+
+         alpha = (self.tanh(x) + 1.0) / 2.0
+
+         return alpha
+
+
+ class ResGuidedCxtAtten_Dec(ResNet_D_Dec):
+     def __init__(self,
+                  layers,
+                  norm_layer=None,
+                  large_kernel=False,
+                  late_downsample=False):
+         super().__init__(
+             layers, norm_layer, large_kernel, late_downsample=late_downsample)
+         self.gca = GuidedCxtAtten(128, 128)
+
+     def forward(self, x, mid_fea):
+         fea1, fea2, fea3, fea4, fea5 = mid_fea['shortcut']
+         im = mid_fea['image_fea']
+         x = self.layer1(x) + fea5  # N x 256 x 32 x 32
+         x = self.layer2(x) + fea4  # N x 128 x 64 x 64
+         x = self.gca(im, x, mid_fea['unknown'])  # contextual attention
+         x = self.layer3(x) + fea3  # N x 64 x 128 x 128
+         x = self.layer4(x) + fea2  # N x 32 x 256 x 256
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.leaky_relu(x) + fea1
+         x = self.conv2(x)
+
+         alpha = (self.tanh(x) + 1.0) / 2.0
+
+         return alpha
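All three decoders above map their final logits through `(tanh(x) + 1) / 2` rather than a sigmoid. A tiny check (assuming a working `paddle` install) that the two choices differ only by a rescaling of the logits, since tanh(x) = 2*sigmoid(2x) - 1:

```python
import paddle
import paddle.nn.functional as F

x = paddle.linspace(-4.0, 4.0, 9)
via_tanh = (paddle.tanh(x) + 1.0) / 2.0
via_sigmoid = F.sigmoid(2.0 * x)  # identical by the tanh/sigmoid identity
print(paddle.allclose(via_tanh, via_sigmoid, atol=1e-6).item())  # True
```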
ppmatting/models/human_matting.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from collections import defaultdict
16
+ import time
17
+
18
+ import paddle
19
+ import paddle.nn as nn
20
+ import paddle.nn.functional as F
21
+ import paddleseg
22
+ from paddleseg.models import layers
23
+ from paddleseg import utils
24
+ from paddleseg.cvlibs import manager
25
+
26
+ from ppmatting.models.losses import MRSD
27
+
28
+
29
+ def conv_up_psp(in_channels, out_channels, up_sample):
30
+ return nn.Sequential(
31
+ layers.ConvBNReLU(
32
+ in_channels, out_channels, 3, padding=1),
33
+ nn.Upsample(
34
+ scale_factor=up_sample, mode='bilinear', align_corners=False))
35
+
36
+
37
+ @manager.MODELS.add_component
38
+ class HumanMatting(nn.Layer):
39
+ """A model for """
40
+
41
+ def __init__(self,
42
+ backbone,
43
+ pretrained=None,
44
+ backbone_scale=0.25,
45
+ refine_kernel_size=3,
46
+ if_refine=True):
47
+ super().__init__()
48
+ if if_refine:
49
+ if backbone_scale > 0.5:
50
+ raise ValueError(
51
+ 'Backbone_scale should not be greater than 1/2, but it is {}'
52
+ .format(backbone_scale))
53
+ else:
54
+ backbone_scale = 1
55
+
56
+ self.backbone = backbone
57
+ self.backbone_scale = backbone_scale
58
+ self.pretrained = pretrained
59
+ self.if_refine = if_refine
60
+ if if_refine:
61
+ self.refiner = Refiner(kernel_size=refine_kernel_size)
62
+ self.loss_func_dict = None
63
+
64
+ self.backbone_channels = backbone.feat_channels
65
+ ######################
66
+ ### Decoder part - Glance
67
+ ######################
68
+ self.psp_module = layers.PPModule(
69
+ self.backbone_channels[-1],
70
+ 512,
71
+ bin_sizes=(1, 3, 5),
72
+ dim_reduction=False,
73
+ align_corners=False)
74
+ self.psp4 = conv_up_psp(512, 256, 2)
75
+ self.psp3 = conv_up_psp(512, 128, 4)
76
+ self.psp2 = conv_up_psp(512, 64, 8)
77
+ self.psp1 = conv_up_psp(512, 64, 16)
78
+ # stage 5g
79
+ self.decoder5_g = nn.Sequential(
80
+ layers.ConvBNReLU(
81
+ 512 + self.backbone_channels[-1], 512, 3, padding=1),
82
+ layers.ConvBNReLU(
83
+ 512, 512, 3, padding=2, dilation=2),
84
+ layers.ConvBNReLU(
85
+ 512, 256, 3, padding=2, dilation=2),
86
+ nn.Upsample(
87
+ scale_factor=2, mode='bilinear', align_corners=False))
88
+ # stage 4g
89
+ self.decoder4_g = nn.Sequential(
90
+ layers.ConvBNReLU(
91
+ 512, 256, 3, padding=1),
92
+ layers.ConvBNReLU(
93
+ 256, 256, 3, padding=1),
94
+ layers.ConvBNReLU(
95
+ 256, 128, 3, padding=1),
96
+ nn.Upsample(
97
+ scale_factor=2, mode='bilinear', align_corners=False))
98
+ # stage 3g
99
+ self.decoder3_g = nn.Sequential(
100
+ layers.ConvBNReLU(
101
+ 256, 128, 3, padding=1),
102
+ layers.ConvBNReLU(
103
+ 128, 128, 3, padding=1),
104
+ layers.ConvBNReLU(
105
+ 128, 64, 3, padding=1),
106
+ nn.Upsample(
107
+ scale_factor=2, mode='bilinear', align_corners=False))
108
+ # stage 2g
109
+ self.decoder2_g = nn.Sequential(
110
+ layers.ConvBNReLU(
111
+ 128, 128, 3, padding=1),
112
+ layers.ConvBNReLU(
113
+ 128, 128, 3, padding=1),
114
+ layers.ConvBNReLU(
115
+ 128, 64, 3, padding=1),
116
+ nn.Upsample(
117
+ scale_factor=2, mode='bilinear', align_corners=False))
118
+ # stage 1g
119
+ self.decoder1_g = nn.Sequential(
120
+ layers.ConvBNReLU(
121
+ 128, 64, 3, padding=1),
122
+ layers.ConvBNReLU(
123
+ 64, 64, 3, padding=1),
124
+ layers.ConvBNReLU(
125
+ 64, 64, 3, padding=1),
126
+ nn.Upsample(
127
+ scale_factor=2, mode='bilinear', align_corners=False))
128
+ # stage 0g
129
+ self.decoder0_g = nn.Sequential(
130
+ layers.ConvBNReLU(
131
+ 64, 64, 3, padding=1),
132
+ layers.ConvBNReLU(
133
+ 64, 64, 3, padding=1),
134
+ nn.Conv2D(
135
+ 64, 3, 3, padding=1))
136
+
137
+ ##########################
138
+ ### Decoder part - FOCUS
139
+ ##########################
140
+ self.bridge_block = nn.Sequential(
141
+ layers.ConvBNReLU(
142
+ self.backbone_channels[-1], 512, 3, dilation=2, padding=2),
143
+ layers.ConvBNReLU(
144
+ 512, 512, 3, dilation=2, padding=2),
145
+ layers.ConvBNReLU(
146
+ 512, 512, 3, dilation=2, padding=2))
147
+ # stage 5f
148
+ self.decoder5_f = nn.Sequential(
149
+ layers.ConvBNReLU(
150
+ 512 + self.backbone_channels[-1], 512, 3, padding=1),
151
+ layers.ConvBNReLU(
152
+ 512, 512, 3, padding=2, dilation=2),
153
+ layers.ConvBNReLU(
154
+ 512, 256, 3, padding=2, dilation=2),
155
+ nn.Upsample(
156
+ scale_factor=2, mode='bilinear', align_corners=False))
157
+ # stage 4f
158
+ self.decoder4_f = nn.Sequential(
159
+ layers.ConvBNReLU(
160
+ 256 + self.backbone_channels[-2], 256, 3, padding=1),
161
+ layers.ConvBNReLU(
162
+ 256, 256, 3, padding=1),
163
+ layers.ConvBNReLU(
164
+ 256, 128, 3, padding=1),
165
+ nn.Upsample(
166
+ scale_factor=2, mode='bilinear', align_corners=False))
167
+ # stage 3f
168
+ self.decoder3_f = nn.Sequential(
169
+ layers.ConvBNReLU(
170
+ 128 + self.backbone_channels[-3], 128, 3, padding=1),
171
+ layers.ConvBNReLU(
172
+ 128, 128, 3, padding=1),
173
+ layers.ConvBNReLU(
174
+ 128, 64, 3, padding=1),
175
+ nn.Upsample(
176
+ scale_factor=2, mode='bilinear', align_corners=False))
177
+ # stage 2f
178
+ self.decoder2_f = nn.Sequential(
179
+ layers.ConvBNReLU(
180
+ 64 + self.backbone_channels[-4], 128, 3, padding=1),
181
+ layers.ConvBNReLU(
182
+ 128, 128, 3, padding=1),
183
+ layers.ConvBNReLU(
184
+ 128, 64, 3, padding=1),
185
+ nn.Upsample(
186
+ scale_factor=2, mode='bilinear', align_corners=False))
187
+ # stage 1f
188
+ self.decoder1_f = nn.Sequential(
189
+ layers.ConvBNReLU(
190
+ 64 + self.backbone_channels[-5], 64, 3, padding=1),
191
+ layers.ConvBNReLU(
192
+ 64, 64, 3, padding=1),
193
+ layers.ConvBNReLU(
194
+ 64, 64, 3, padding=1),
195
+ nn.Upsample(
196
+ scale_factor=2, mode='bilinear', align_corners=False))
197
+ # stage 0f
198
+ self.decoder0_f = nn.Sequential(
199
+ layers.ConvBNReLU(
200
+ 64, 64, 3, padding=1),
201
+ layers.ConvBNReLU(
202
+ 64, 64, 3, padding=1),
203
+ nn.Conv2D(
204
+ 64, 1 + 1 + 32, 3, padding=1))
205
+ self.init_weight()
206
+
207
+ def forward(self, data):
208
+ src = data['img']
209
+ src_h, src_w = paddle.shape(src)[2:]
210
+ if self.if_refine:
211
+ # It is not need when exporting.
212
+ if isinstance(src_h, paddle.Tensor):
213
+ if (src_h % 4 != 0) or (src_w % 4) != 0:
214
+ raise ValueError(
215
+ 'The input image must have width and height that are divisible by 4'
216
+ )
217
+
218
+ # Downsample src for backbone
219
+ src_sm = F.interpolate(
220
+ src,
221
+ scale_factor=self.backbone_scale,
222
+ mode='bilinear',
223
+ align_corners=False)
224
+
225
+ # Base
226
+ fea_list = self.backbone(src_sm)
227
+ ##########################
228
+ ### Decoder part - GLANCE
229
+ ##########################
230
+ #psp: N, 512, H/32, W/32
231
+ psp = self.psp_module(fea_list[-1])
232
+ #d6_g: N, 512, H/16, W/16
233
+ d5_g = self.decoder5_g(paddle.concat((psp, fea_list[-1]), 1))
234
+ #d5_g: N, 512, H/8, W/8
235
+ d4_g = self.decoder4_g(paddle.concat((self.psp4(psp), d5_g), 1))
236
+ #d4_g: N, 256, H/4, W/4
237
+ d3_g = self.decoder3_g(paddle.concat((self.psp3(psp), d4_g), 1))
238
+ #d4_g: N, 128, H/2, W/2
239
+ d2_g = self.decoder2_g(paddle.concat((self.psp2(psp), d3_g), 1))
240
+ #d2_g: N, 64, H, W
241
+ d1_g = self.decoder1_g(paddle.concat((self.psp1(psp), d2_g), 1))
242
+ #d0_g: N, 3, H, W
243
+ d0_g = self.decoder0_g(d1_g)
244
+ # The 1st channel is foreground. The 2nd is transition region. The 3rd is background.
245
+ # glance_sigmoid = F.sigmoid(d0_g)
246
+ glance_sigmoid = F.softmax(d0_g, axis=1)
247
+
248
+ ##########################
249
+ ### Decoder part - FOCUS
250
+ ##########################
251
+ bb = self.bridge_block(fea_list[-1])
252
+ #bg: N, 512, H/32, W/32
253
+ d5_f = self.decoder5_f(paddle.concat((bb, fea_list[-1]), 1))
254
+ #d5_f: N, 256, H/16, W/16
255
+ d4_f = self.decoder4_f(paddle.concat((d5_f, fea_list[-2]), 1))
256
+ #d4_f: N, 128, H/8, W/8
257
+ d3_f = self.decoder3_f(paddle.concat((d4_f, fea_list[-3]), 1))
258
+ #d3_f: N, 64, H/4, W/4
259
+ d2_f = self.decoder2_f(paddle.concat((d3_f, fea_list[-4]), 1))
260
+ #d2_f: N, 64, H/2, W/2
261
+ d1_f = self.decoder1_f(paddle.concat((d2_f, fea_list[-5]), 1))
262
+ #d1_f: N, 64, H, W
263
+ d0_f = self.decoder0_f(d1_f)
264
+ #d0_f: N, 1, H, W
265
+ focus_sigmoid = F.sigmoid(d0_f[:, 0:1, :, :])
266
+ pha_sm = self.fusion(glance_sigmoid, focus_sigmoid)
267
+ err_sm = d0_f[:, 1:2, :, :]
268
+ err_sm = paddle.clip(err_sm, 0., 1.)
269
+ hid_sm = F.relu(d0_f[:, 2:, :, :])
270
+
271
+ # Refiner
272
+ if self.if_refine:
273
+ pha = self.refiner(
274
+ src=src, pha=pha_sm, err=err_sm, hid=hid_sm, tri=glance_sigmoid)
275
+ # Clamp outputs
276
+ pha = paddle.clip(pha, 0., 1.)
277
+
278
+ if self.training:
279
+ logit_dict = {
280
+ 'glance': glance_sigmoid,
281
+ 'focus': focus_sigmoid,
282
+ 'fusion': pha_sm,
283
+ 'error': err_sm
284
+ }
285
+ if self.if_refine:
286
+ logit_dict['refine'] = pha
287
+ loss_dict = self.loss(logit_dict, data)
288
+ return logit_dict, loss_dict
289
+ else:
290
+ return pha if self.if_refine else pha_sm
291
+
292
+ def loss(self, logit_dict, label_dict, loss_func_dict=None):
293
+ if loss_func_dict is None:
294
+ if self.loss_func_dict is None:
+                 self.loss_func_dict = defaultdict(list)
+                 self.loss_func_dict['glance'].append(nn.NLLLoss())
+                 self.loss_func_dict['focus'].append(MRSD())
+                 self.loss_func_dict['cm'].append(MRSD())
+                 self.loss_func_dict['err'].append(paddleseg.models.MSELoss())
+                 self.loss_func_dict['refine'].append(paddleseg.models.L1Loss())
+         else:
+             self.loss_func_dict = loss_func_dict
+
+         loss = {}
+
+         # glance loss computation
+         # get glance label
+         glance_label = F.interpolate(
+             label_dict['trimap'],
+             logit_dict['glance'].shape[2:],
+             mode='nearest',
+             align_corners=False)
+         glance_label_trans = (glance_label == 128).astype('int64')
+         glance_label_bg = (glance_label == 0).astype('int64')
+         glance_label = glance_label_trans + glance_label_bg * 2
+         loss_glance = self.loss_func_dict['glance'][0](
+             paddle.log(logit_dict['glance'] + 1e-6), glance_label.squeeze(1))
+         loss['glance'] = loss_glance
+
+         # focus loss computation
+         focus_label = F.interpolate(
+             label_dict['alpha'],
+             logit_dict['focus'].shape[2:],
+             mode='bilinear',
+             align_corners=False)
+         loss_focus = self.loss_func_dict['focus'][0](
+             logit_dict['focus'], focus_label, glance_label_trans)
+         loss['focus'] = loss_focus
+
+         # collaborative matting loss
+         loss_cm_func = self.loss_func_dict['cm']
+         # fusion_sigmoid loss
+         loss_cm = loss_cm_func[0](logit_dict['fusion'], focus_label)
+         loss['cm'] = loss_cm
+
+         # error loss
+         err = F.interpolate(
+             logit_dict['error'],
+             label_dict['alpha'].shape[2:],
+             mode='bilinear',
+             align_corners=False)
+         err_label = (F.interpolate(
+             logit_dict['fusion'],
+             label_dict['alpha'].shape[2:],
+             mode='bilinear',
+             align_corners=False) - label_dict['alpha']).abs()
+         loss_err = self.loss_func_dict['err'][0](err, err_label)
+         loss['err'] = loss_err
+
+         loss_all = 0.25 * loss_glance + 0.25 * loss_focus + 0.25 * loss_cm + loss_err
+
+         # refine loss
+         if self.if_refine:
+             loss_refine = self.loss_func_dict['refine'][0](logit_dict['refine'],
+                                                            label_dict['alpha'])
+             loss['refine'] = loss_refine
+             loss_all = loss_all + loss_refine
+
+         loss['all'] = loss_all
+         return loss
+
+     def fusion(self, glance_sigmoid, focus_sigmoid):
+         # glance_sigmoid [N, 3, H, W].
+         # In index, 0 is foreground, 1 is transition, 2 is background.
+         # After fusion, the foreground is 1, the background is 0,
+         # and the transition is between (0, 1).
+         index = paddle.argmax(glance_sigmoid, axis=1, keepdim=True)
+         transition_mask = (index == 1).astype('float32')
+         fg = (index == 0).astype('float32')
+         fusion_sigmoid = focus_sigmoid * transition_mask + fg
+         return fusion_sigmoid
+
+     def init_weight(self):
+         if self.pretrained is not None:
+             utils.load_entire_model(self, self.pretrained)
+
+
+ class Refiner(nn.Layer):
+     '''
+     Refiner refines the coarse output to full resolution.
+
+     Args:
+         kernel_size: The convolution kernel_size. Options: [1, 3]. Default: 3.
+     '''
+
+     def __init__(self, kernel_size=3):
+         super().__init__()
+         if kernel_size not in [1, 3]:
+             raise ValueError("kernel_size must be in [1, 3]")
+
+         self.kernel_size = kernel_size
+
+         channels = [32, 24, 16, 12, 1]
+         self.conv1 = layers.ConvBNReLU(
+             channels[0] + 4 + 3,
+             channels[1],
+             kernel_size,
+             padding=0,
+             bias_attr=False)
+         self.conv2 = layers.ConvBNReLU(
+             channels[1], channels[2], kernel_size, padding=0, bias_attr=False)
+         self.conv3 = layers.ConvBNReLU(
+             channels[2] + 3,
+             channels[3],
+             kernel_size,
+             padding=0,
+             bias_attr=False)
+         self.conv4 = nn.Conv2D(
+             channels[3], channels[4], kernel_size, padding=0, bias_attr=True)
+
+     def forward(self, src, pha, err, hid, tri):
+         '''
+         Args:
+             src: (B, 3, H, W) full resolution source image.
+             pha: (B, 1, Hc, Wc) coarse alpha prediction.
+             err: (B, 1, Hc, Wc) coarse error prediction.
+             hid: (B, 32, Hc, Wc) coarse hidden encoding.
+             tri: (B, 1, Hc, Wc) trimap prediction.
+         '''
+         h_full, w_full = paddle.shape(src)[2:]
+         h_half, w_half = h_full // 2, w_full // 2
+         h_quat, w_quat = h_full // 4, w_full // 4
+
+         x = paddle.concat([hid, pha, tri], axis=1)
+         x = F.interpolate(
+             x,
+             paddle.concat((h_half, w_half)),
+             mode='bilinear',
+             align_corners=False)
+         y = F.interpolate(
+             src,
+             paddle.concat((h_half, w_half)),
+             mode='bilinear',
+             align_corners=False)
+
+         if self.kernel_size == 3:
+             x = F.pad(x, [3, 3, 3, 3])
+             y = F.pad(y, [3, 3, 3, 3])
+
+         x = self.conv1(paddle.concat([x, y], axis=1))
+         x = self.conv2(x)
+
+         if self.kernel_size == 3:
+             x = F.interpolate(x, paddle.concat((h_full + 4, w_full + 4)))
+             y = F.pad(src, [2, 2, 2, 2])
+         else:
+             x = F.interpolate(
+                 x, paddle.concat((h_full, w_full)), mode='nearest')
+             y = src
+
+         x = self.conv3(paddle.concat([x, y], axis=1))
+         x = self.conv4(x)
+
+         pha = x
+         return pha
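
For reference, a minimal standalone sketch (not part of this commit) of the glance/focus fusion rule implemented above, run on dummy tensors; shapes and values are illustrative only:

import paddle
import paddle.nn.functional as F

# dummy glance probabilities (N, 3, H, W): channel 0 = fg, 1 = transition, 2 = bg
glance_sigmoid = F.softmax(paddle.rand([1, 3, 8, 8]), axis=1)
# dummy focus prediction (N, 1, H, W), values in [0, 1]
focus_sigmoid = paddle.rand([1, 1, 8, 8])

index = paddle.argmax(glance_sigmoid, axis=1, keepdim=True)
transition_mask = (index == 1).astype('float32')
fg = (index == 0).astype('float32')
# alpha is 1 on confident foreground, 0 on background,
# and takes the focus prediction inside the transition region
fusion_sigmoid = focus_sigmoid * transition_mask + fg
print(fusion_sigmoid.shape)  # [1, 1, 8, 8]
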
ppmatting/models/layers/__init__.py ADDED
@@ -0,0 +1,15 @@
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .gca_module import GuidedCxtAtten
ppmatting/models/layers/gca_module.py ADDED
@@ -0,0 +1,211 @@
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # The gca code was heavily based on https://github.com/Yaoyi-Li/GCA-Matting
+ # and https://github.com/open-mmlab/mmediting
+
+ import paddle
+ import paddle.nn as nn
+ import paddle.nn.functional as F
+
+ from paddleseg.cvlibs import param_init
+
+
+ class GuidedCxtAtten(nn.Layer):
+     def __init__(self,
+                  out_channels,
+                  guidance_channels,
+                  kernel_size=3,
+                  stride=1,
+                  rate=2):
+         super().__init__()
+
+         self.kernel_size = kernel_size
+         self.rate = rate
+         self.stride = stride
+         self.guidance_conv = nn.Conv2D(
+             in_channels=guidance_channels,
+             out_channels=guidance_channels // 2,
+             kernel_size=1)
+
+         self.out_conv = nn.Sequential(
+             nn.Conv2D(
+                 in_channels=out_channels,
+                 out_channels=out_channels,
+                 kernel_size=1,
+                 bias_attr=False),
+             nn.BatchNorm(out_channels))
+
+         self.init_weight()
+
+     def init_weight(self):
+         param_init.xavier_uniform(self.guidance_conv.weight)
+         param_init.constant_init(self.guidance_conv.bias, value=0.0)
+         param_init.xavier_uniform(self.out_conv[0].weight)
+         param_init.constant_init(self.out_conv[1].weight, value=1e-3)
+         param_init.constant_init(self.out_conv[1].bias, value=0.0)
+
+     def forward(self, img_feat, alpha_feat, unknown=None, softmax_scale=1.):
+
+         img_feat = self.guidance_conv(img_feat)
+         img_feat = F.interpolate(
+             img_feat, scale_factor=1 / self.rate, mode='nearest')
+
+         # process unknown mask
+         unknown, softmax_scale = self.process_unknown_mask(unknown, img_feat,
+                                                            softmax_scale)
+
+         img_ps, alpha_ps, unknown_ps = self.extract_feature_maps_patches(
+             img_feat, alpha_feat, unknown)
+
+         self_mask = self.get_self_correlation_mask(img_feat)
+
+         # split tensors by batch dimension; tuple is returned
+         img_groups = paddle.split(img_feat, 1, axis=0)
+         img_ps_groups = paddle.split(img_ps, 1, axis=0)
+         alpha_ps_groups = paddle.split(alpha_ps, 1, axis=0)
+         unknown_ps_groups = paddle.split(unknown_ps, 1, axis=0)
+         scale_groups = paddle.split(softmax_scale, 1, axis=0)
+         groups = (img_groups, img_ps_groups, alpha_ps_groups, unknown_ps_groups,
+                   scale_groups)
+
+         y = []
+
+         for img_i, img_ps_i, alpha_ps_i, unknown_ps_i, scale_i in zip(*groups):
+             # convolve to compare patches against the feature map
+             similarity_map = self.compute_similarity_map(img_i, img_ps_i)
+
+             gca_score = self.compute_guided_attention_score(
+                 similarity_map, unknown_ps_i, scale_i, self_mask)
+
+             yi = self.propagate_alpha_feature(gca_score, alpha_ps_i)
+
+             y.append(yi)
+
+         y = paddle.concat(y, axis=0)  # back to the mini-batch
+         y = paddle.reshape(y, alpha_feat.shape)
+
+         y = self.out_conv(y) + alpha_feat
+
+         return y
+
+     def extract_feature_maps_patches(self, img_feat, alpha_feat, unknown):
+
+         # extract image feature patches with shape:
+         # (N, img_h*img_w, img_c, img_ks, img_ks)
+         img_ks = self.kernel_size
+         img_ps = self.extract_patches(img_feat, img_ks, self.stride)
+
+         # extract alpha feature patches with shape:
+         # (N, img_h*img_w, alpha_c, alpha_ks, alpha_ks)
+         alpha_ps = self.extract_patches(alpha_feat, self.rate * 2, self.rate)
+
+         # extract unknown mask patches with shape: (N, img_h*img_w, 1, 1)
+         unknown_ps = self.extract_patches(unknown, img_ks, self.stride)
+         unknown_ps = unknown_ps.squeeze(axis=2)  # squeeze channel dimension
+         unknown_ps = unknown_ps.mean(axis=[2, 3], keepdim=True)
+
+         return img_ps, alpha_ps, unknown_ps
+
+     def extract_patches(self, x, kernel_size, stride):
+         n, c, _, _ = x.shape
+         x = self.pad(x, kernel_size, stride)
+         x = F.unfold(x, [kernel_size, kernel_size], strides=[stride, stride])
+         x = paddle.transpose(x, (0, 2, 1))
+         x = paddle.reshape(x, (n, -1, c, kernel_size, kernel_size))
+
+         return x
+
+     def pad(self, x, kernel_size, stride):
+         left = (kernel_size - stride + 1) // 2
+         right = (kernel_size - stride) // 2
+         pad = (left, right, left, right)
+         return F.pad(x, pad, mode='reflect')
+
+     def compute_guided_attention_score(self, similarity_map, unknown_ps, scale,
+                                        self_mask):
+         # scale the correlation with the predicted scale factors for the
+         # known and unknown areas
+         unknown_scale, known_scale = scale[0]
+         out = similarity_map * (
+             unknown_scale * paddle.greater_than(unknown_ps,
+                                                 paddle.to_tensor([0.])) +
+             known_scale * paddle.less_equal(unknown_ps, paddle.to_tensor([0.])))
+         # mask itself; the self-mask is only applied to the unknown area
+         out = out + self_mask * unknown_ps
+         gca_score = F.softmax(out, axis=1)
+
+         return gca_score
+
+     def propagate_alpha_feature(self, gca_score, alpha_ps):
+
+         alpha_ps = alpha_ps[0]  # squeeze dim 0
+         if self.rate == 1:
+             gca_score = self.pad(gca_score, kernel_size=2, stride=1)
+             alpha_ps = paddle.transpose(alpha_ps, (1, 0, 2, 3))
+             out = F.conv2d(gca_score, alpha_ps) / 4.
+         else:
+             out = F.conv2d_transpose(
+                 gca_score, alpha_ps, stride=self.rate, padding=1) / 4.
+
+         return out
+
+     def compute_similarity_map(self, img_feat, img_ps):
+         img_ps = img_ps[0]  # squeeze dim 0
+         # convolve the feature to get correlation (similarity) map
+         img_ps_normed = img_ps / paddle.clip(self.l2_norm(img_ps), 1e-4)
+         img_feat = F.pad(img_feat, (1, 1, 1, 1), mode='reflect')
+         similarity_map = F.conv2d(img_feat, img_ps_normed)
+
+         return similarity_map
+
+     def get_self_correlation_mask(self, img_feat):
+         _, _, h, w = img_feat.shape
+         self_mask = F.one_hot(
+             paddle.reshape(paddle.arange(h * w), (h, w)),
+             num_classes=int(h * w))
+
+         self_mask = paddle.transpose(self_mask, (2, 0, 1))
+         self_mask = paddle.reshape(self_mask, (1, h * w, h, w))
+
+         return self_mask * (-1e4)
+
+     def process_unknown_mask(self, unknown, img_feat, softmax_scale):
+
+         n, _, h, w = img_feat.shape
+
+         if unknown is not None:
+             unknown = unknown.clone()
+             unknown = F.interpolate(
+                 unknown, scale_factor=1 / self.rate, mode='nearest')
+             unknown_mean = unknown.mean(axis=[2, 3])
+             known_mean = 1 - unknown_mean
+             unknown_scale = paddle.clip(
+                 paddle.sqrt(unknown_mean / known_mean), 0.1, 10)
+             known_scale = paddle.clip(
+                 paddle.sqrt(known_mean / unknown_mean), 0.1, 10)
+             softmax_scale = paddle.concat([unknown_scale, known_scale], axis=1)
+         else:
+             unknown = paddle.ones([n, 1, h, w])
+             softmax_scale = paddle.reshape(
+                 paddle.to_tensor([softmax_scale, softmax_scale]), (1, 2))
+             softmax_scale = paddle.expand(softmax_scale, (n, 2))
+
+         return unknown, softmax_scale
+
+     @staticmethod
+     def l2_norm(x):
+         x = x**2
+         x = x.sum(axis=[1, 2, 3], keepdim=True)
+         return paddle.sqrt(x)
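
As a quick sanity check, a hedged usage sketch (not part of this commit) of GuidedCxtAtten on dummy feature maps. The channel sizes are hypothetical; with the default rate=2 the guidance and alpha features share the same spatial size, and the module downsamples the guidance internally:

import paddle
from ppmatting.models.layers import GuidedCxtAtten

gca = GuidedCxtAtten(out_channels=128, guidance_channels=16)
img_feat = paddle.rand([1, 16, 64, 64])     # guidance (image) features
alpha_feat = paddle.rand([1, 128, 64, 64])  # alpha features
unknown = paddle.ones([1, 1, 64, 64])       # unknown-region mask
out = gca(img_feat, alpha_feat, unknown=unknown)
print(out.shape)  # [1, 128, 64, 64], same as alpha_feat (residual output)
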
ppmatting/models/losses/__init__.py ADDED
@@ -0,0 +1 @@
+ from .loss import *
ppmatting/models/losses/loss.py ADDED
@@ -0,0 +1,163 @@
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import numpy as np
+ import paddle
+ import paddle.nn as nn
+ import paddle.nn.functional as F
+
+ from paddleseg.cvlibs import manager
+
+
+ @manager.LOSSES.add_component
+ class MRSD(nn.Layer):
+     def __init__(self, eps=1e-6):
+         super().__init__()
+         self.eps = eps
+
+     def forward(self, logit, label, mask=None):
+         """
+         Forward computation.
+
+         Args:
+             logit (Tensor): Logit tensor, the data type is float32 or float64.
+             label (Tensor): Label tensor, the data type is float32 or float64. The shape should be equal to that of logit.
+             mask (Tensor, optional): The mask where the loss is valid. Default: None.
+         """
+         if len(label.shape) == 3:
+             label = label.unsqueeze(1)
+         sd = paddle.square(logit - label)
+         loss = paddle.sqrt(sd + self.eps)
+         if mask is not None:
+             mask = mask.astype('float32')
+             if len(mask.shape) == 3:
+                 mask = mask.unsqueeze(1)
+             loss = loss * mask
+             loss = loss.sum() / (mask.sum() + self.eps)
+             mask.stop_gradient = True
+         else:
+             loss = loss.mean()
+
+         return loss
+
+
+ @manager.LOSSES.add_component
+ class GradientLoss(nn.Layer):
+     def __init__(self, eps=1e-6):
+         super().__init__()
+         self.kernel_x, self.kernel_y = self.sobel_kernel()
+         self.eps = eps
+
+     def forward(self, logit, label, mask=None):
+         if len(label.shape) == 3:
+             label = label.unsqueeze(1)
+         if mask is not None:
+             if len(mask.shape) == 3:
+                 mask = mask.unsqueeze(1)
+             logit = logit * mask
+             label = label * mask
+             loss = paddle.sum(
+                 F.l1_loss(self.sobel(logit), self.sobel(label), 'none')) / (
+                     mask.sum() + self.eps)
+         else:
+             loss = F.l1_loss(self.sobel(logit), self.sobel(label), 'mean')
+
+         return loss
+
+     def sobel(self, input):
+         """Use Sobel filters to compute the gradient. Return the magnitude."""
+         if not len(input.shape) == 4:
+             raise ValueError(
+                 "Invalid input shape, we expect NCHW, but it is {}".format(
+                     input.shape))
+
+         n, c, h, w = input.shape
+
+         input_pad = paddle.reshape(input, (n * c, 1, h, w))
+         input_pad = F.pad(input_pad, pad=[1, 1, 1, 1], mode='replicate')
+
+         grad_x = F.conv2d(input_pad, self.kernel_x, padding=0)
+         grad_y = F.conv2d(input_pad, self.kernel_y, padding=0)
+
+         mag = paddle.sqrt(grad_x * grad_x + grad_y * grad_y + self.eps)
+         mag = paddle.reshape(mag, (n, c, h, w))
+
+         return mag
+
+     def sobel_kernel(self):
+         kernel_x = paddle.to_tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0],
+                                      [-1.0, 0.0, 1.0]]).astype('float32')
+         kernel_x = kernel_x / kernel_x.abs().sum()
+         kernel_y = kernel_x.transpose([1, 0])
+         kernel_x = kernel_x.unsqueeze(0).unsqueeze(0)
+         kernel_y = kernel_y.unsqueeze(0).unsqueeze(0)
+         kernel_x.stop_gradient = True
+         kernel_y.stop_gradient = True
+         return kernel_x, kernel_y
+
+
+ @manager.LOSSES.add_component
+ class LaplacianLoss(nn.Layer):
+     """
+     Laplacian loss refers to
+     https://github.com/JizhiziLi/AIM/blob/master/core/evaluate.py#L83
+     """
+
+     def __init__(self):
+         super().__init__()
+         self.gauss_kernel = self.build_gauss_kernel(
+             size=5, sigma=1.0, n_channels=1)
+
+     def forward(self, logit, label, mask=None):
+         if len(label.shape) == 3:
+             label = label.unsqueeze(1)
+         if mask is not None:
+             if len(mask.shape) == 3:
+                 mask = mask.unsqueeze(1)
+             logit = logit * mask
+             label = label * mask
+         pyr_label = self.laplacian_pyramid(label, self.gauss_kernel, 5)
+         pyr_logit = self.laplacian_pyramid(logit, self.gauss_kernel, 5)
+         loss = sum(F.l1_loss(a, b) for a, b in zip(pyr_label, pyr_logit))
+
+         return loss
+
+     def build_gauss_kernel(self, size=5, sigma=1.0, n_channels=1):
+         if size % 2 != 1:
+             raise ValueError("kernel size must be odd")
+         grid = np.float32(np.mgrid[0:size, 0:size].T)
+         gaussian = lambda x: np.exp((x - size // 2)**2 / (-2 * sigma**2))**2
+         kernel = np.sum(gaussian(grid), axis=2)
+         kernel /= np.sum(kernel)
+         kernel = np.tile(kernel, (n_channels, 1, 1))
+         kernel = paddle.to_tensor(kernel[:, None, :, :])
+         kernel.stop_gradient = True
+         return kernel
+
+     def conv_gauss(self, input, kernel):
+         n_channels, _, kh, kw = kernel.shape
+         x = F.pad(input, (kw // 2, kw // 2, kh // 2, kh // 2), mode='replicate')
+         x = F.conv2d(x, kernel, groups=n_channels)
+
+         return x
+
+     def laplacian_pyramid(self, input, kernel, max_levels=5):
+         current = input
+         pyr = []
+         for level in range(max_levels):
+             filtered = self.conv_gauss(current, kernel)
+             diff = current - filtered
+             pyr.append(diff)
+             current = F.avg_pool2d(filtered, 2)
+         pyr.append(current)
+         return pyr
@@ -0,0 +1,494 @@
 
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from collections import defaultdict
+
+ import paddle
+ import paddle.nn as nn
+ import paddle.nn.functional as F
+ import numpy as np
+ import scipy.ndimage
+ import paddleseg
+ from paddleseg.models import layers, losses
+ from paddleseg import utils
+ from paddleseg.cvlibs import manager, param_init
+
+
+ @manager.MODELS.add_component
+ class MODNet(nn.Layer):
+     """
+     The MODNet implementation based on PaddlePaddle.
+
+     The original article refers to
+     Zhanghan Ke, et al. "Is a Green Screen Really Necessary for Real-Time Portrait Matting?"
+     (https://arxiv.org/pdf/2011.11961.pdf).
+
+     Args:
+         backbone: backbone model.
+         hr_channels (int, optional): The channels of the high resolution branch. Default: 32.
+         pretrained (str, optional): The path of the pretrained model. Default: None.
+
+     """
+
+     def __init__(self, backbone, hr_channels=32, pretrained=None):
+         super().__init__()
+         self.backbone = backbone
+         self.pretrained = pretrained
+         self.head = MODNetHead(
+             hr_channels=hr_channels, backbone_channels=backbone.feat_channels)
+         self.init_weight()
+         self.blurer = GaussianBlurLayer(1, 3)
+         self.loss_func_dict = None
+
+     def forward(self, inputs):
+         """
+         If training, return a dict.
+         If evaluating, return the final alpha prediction.
+         """
+         x = inputs['img']
+         feat_list = self.backbone(x)
+         y = self.head(inputs=inputs, feat_list=feat_list)
+         if self.training:
+             loss = self.loss(y, inputs)
+             return y, loss
+         else:
+             return y
+
+     def loss(self, logit_dict, label_dict, loss_func_dict=None):
+         if loss_func_dict is None:
+             if self.loss_func_dict is None:
+                 self.loss_func_dict = defaultdict(list)
+                 self.loss_func_dict['semantic'].append(
+                     paddleseg.models.MSELoss())
+                 self.loss_func_dict['detail'].append(paddleseg.models.L1Loss())
+                 self.loss_func_dict['fusion'].append(paddleseg.models.L1Loss())
+                 self.loss_func_dict['fusion'].append(paddleseg.models.L1Loss())
+         else:
+             self.loss_func_dict = loss_func_dict
+
+         loss = {}
+         # semantic loss
+         semantic_gt = F.interpolate(
+             label_dict['alpha'],
+             scale_factor=1 / 16,
+             mode='bilinear',
+             align_corners=False)
+         semantic_gt = self.blurer(semantic_gt)
+         # semantic_gt.stop_gradient=True
+         loss['semantic'] = self.loss_func_dict['semantic'][0](
+             logit_dict['semantic'], semantic_gt)
+
+         # detail loss
+         trimap = label_dict['trimap']
+         mask = (trimap == 128).astype('float32')
+         logit_detail = logit_dict['detail'] * mask
+         label_detail = label_dict['alpha'] * mask
+         loss_detail = self.loss_func_dict['detail'][0](logit_detail,
+                                                        label_detail)
+         loss_detail = loss_detail / (mask.mean() + 1e-6)
+         loss['detail'] = 10 * loss_detail
+
+         # fusion loss
+         matte = logit_dict['matte']
+         alpha = label_dict['alpha']
+         transition_mask = label_dict['trimap'] == 128
+         matte_boundary = paddle.where(transition_mask, matte, alpha)
+         # l1 loss
+         loss_fusion_l1 = self.loss_func_dict['fusion'][0](
+             matte, alpha) + 4 * self.loss_func_dict['fusion'][0](matte_boundary,
+                                                                  alpha)
+         # composition loss
+         loss_fusion_comp = self.loss_func_dict['fusion'][1](
+             matte * label_dict['img'], alpha *
+             label_dict['img']) + 4 * self.loss_func_dict['fusion'][1](
+                 matte_boundary * label_dict['img'], alpha * label_dict['img'])
+         # consistency loss with the semantic branch
+         transition_mask = F.interpolate(
+             label_dict['trimap'],
+             scale_factor=1 / 16,
+             mode='nearest',
+             align_corners=False)
+         transition_mask = transition_mask == 128
+         matte_con_sem = F.interpolate(
+             matte, scale_factor=1 / 16, mode='bilinear', align_corners=False)
+         matte_con_sem = self.blurer(matte_con_sem)
+         logit_semantic = logit_dict['semantic'].clone()
+         logit_semantic.stop_gradient = True
+         matte_con_sem = paddle.where(transition_mask, logit_semantic,
+                                      matte_con_sem)
+         if False:  # debug visualization, disabled by default
+             import cv2
+             matte_con_sem_num = matte_con_sem.numpy()
+             matte_con_sem_num = matte_con_sem_num[0].squeeze()
+             matte_con_sem_num = (matte_con_sem_num * 255).astype('uint8')
+             semantic = logit_dict['semantic'].numpy()
+             semantic = semantic[0].squeeze()
+             semantic = (semantic * 255).astype('uint8')
+             transition_mask = transition_mask.astype('uint8')
+             transition_mask = transition_mask.numpy()
+             transition_mask = (transition_mask[0].squeeze()) * 255
+             cv2.imwrite('matte_con.png', matte_con_sem_num)
+             cv2.imwrite('semantic.png', semantic)
+             cv2.imwrite('transition.png', transition_mask)
+         mse_loss = paddleseg.models.MSELoss()
+         loss_fusion_con_sem = mse_loss(matte_con_sem, logit_dict['semantic'])
+         loss_fusion = loss_fusion_l1 + loss_fusion_comp + loss_fusion_con_sem
+         loss['fusion'] = loss_fusion
+         loss['fusion_l1'] = loss_fusion_l1
+         loss['fusion_comp'] = loss_fusion_comp
+         loss['fusion_con_sem'] = loss_fusion_con_sem
+
+         loss['all'] = loss['semantic'] + loss['detail'] + loss['fusion']
+
+         return loss
+
+     def init_weight(self):
+         if self.pretrained is not None:
+             utils.load_entire_model(self, self.pretrained)
+
+
+ class MODNetHead(nn.Layer):
+     def __init__(self, hr_channels, backbone_channels):
+         super().__init__()
+
+         self.lr_branch = LRBranch(backbone_channels)
+         self.hr_branch = HRBranch(hr_channels, backbone_channels)
+         self.f_branch = FusionBranch(hr_channels, backbone_channels)
+         self.init_weight()
+
+     def forward(self, inputs, feat_list):
+         pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(feat_list)
+         pred_detail, hr2x = self.hr_branch(inputs['img'], enc2x, enc4x, lr8x)
+         pred_matte = self.f_branch(inputs['img'], lr8x, hr2x)
+
+         if self.training:
+             logit_dict = {
+                 'semantic': pred_semantic,
+                 'detail': pred_detail,
+                 'matte': pred_matte
+             }
+             return logit_dict
+         else:
+             return pred_matte
+
+     def init_weight(self):
+         for layer in self.sublayers():
+             if isinstance(layer, nn.Conv2D):
+                 param_init.kaiming_uniform(layer.weight)
+
+
+ class FusionBranch(nn.Layer):
+     def __init__(self, hr_channels, enc_channels):
+         super().__init__()
+         self.conv_lr4x = Conv2dIBNormRelu(
+             enc_channels[2], hr_channels, 5, stride=1, padding=2)
+
+         self.conv_f2x = Conv2dIBNormRelu(
+             2 * hr_channels, hr_channels, 3, stride=1, padding=1)
+         self.conv_f = nn.Sequential(
+             Conv2dIBNormRelu(
+                 hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),
+             Conv2dIBNormRelu(
+                 int(hr_channels / 2),
+                 1,
+                 1,
+                 stride=1,
+                 padding=0,
+                 with_ibn=False,
+                 with_relu=False))
+
+     def forward(self, img, lr8x, hr2x):
+         lr4x = F.interpolate(
+             lr8x, scale_factor=2, mode='bilinear', align_corners=False)
+         lr4x = self.conv_lr4x(lr4x)
+         lr2x = F.interpolate(
+             lr4x, scale_factor=2, mode='bilinear', align_corners=False)
+
+         f2x = self.conv_f2x(paddle.concat((lr2x, hr2x), axis=1))
+         f = F.interpolate(
+             f2x, scale_factor=2, mode='bilinear', align_corners=False)
+         f = self.conv_f(paddle.concat((f, img), axis=1))
+         pred_matte = F.sigmoid(f)
+
+         return pred_matte
+
+
+ class HRBranch(nn.Layer):
+     """
+     High Resolution Branch of MODNet
+     """
+
+     def __init__(self, hr_channels, enc_channels):
+         super().__init__()
+
+         self.tohr_enc2x = Conv2dIBNormRelu(
+             enc_channels[0], hr_channels, 1, stride=1, padding=0)
+         self.conv_enc2x = Conv2dIBNormRelu(
+             hr_channels + 3, hr_channels, 3, stride=2, padding=1)
+
+         self.tohr_enc4x = Conv2dIBNormRelu(
+             enc_channels[1], hr_channels, 1, stride=1, padding=0)
+         self.conv_enc4x = Conv2dIBNormRelu(
+             2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1)
+
+         self.conv_hr4x = nn.Sequential(
+             Conv2dIBNormRelu(
+                 2 * hr_channels + enc_channels[2] + 3,
+                 2 * hr_channels,
+                 3,
+                 stride=1,
+                 padding=1),
+             Conv2dIBNormRelu(
+                 2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
+             Conv2dIBNormRelu(
+                 2 * hr_channels, hr_channels, 3, stride=1, padding=1))
+
+         self.conv_hr2x = nn.Sequential(
+             Conv2dIBNormRelu(
+                 2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
+             Conv2dIBNormRelu(
+                 2 * hr_channels, hr_channels, 3, stride=1, padding=1),
+             Conv2dIBNormRelu(
+                 hr_channels, hr_channels, 3, stride=1, padding=1),
+             Conv2dIBNormRelu(
+                 hr_channels, hr_channels, 3, stride=1, padding=1))
+
+         self.conv_hr = nn.Sequential(
+             Conv2dIBNormRelu(
+                 hr_channels + 3, hr_channels, 3, stride=1, padding=1),
+             Conv2dIBNormRelu(
+                 hr_channels,
+                 1,
+                 1,
+                 stride=1,
+                 padding=0,
+                 with_ibn=False,
+                 with_relu=False))
+
+     def forward(self, img, enc2x, enc4x, lr8x):
+         img2x = F.interpolate(
+             img, scale_factor=1 / 2, mode='bilinear', align_corners=False)
+         img4x = F.interpolate(
+             img, scale_factor=1 / 4, mode='bilinear', align_corners=False)
+
+         enc2x = self.tohr_enc2x(enc2x)
+         hr4x = self.conv_enc2x(paddle.concat((img2x, enc2x), axis=1))
+
+         enc4x = self.tohr_enc4x(enc4x)
+         hr4x = self.conv_enc4x(paddle.concat((hr4x, enc4x), axis=1))
+
+         lr4x = F.interpolate(
+             lr8x, scale_factor=2, mode='bilinear', align_corners=False)
+         hr4x = self.conv_hr4x(paddle.concat((hr4x, lr4x, img4x), axis=1))
+
+         hr2x = F.interpolate(
+             hr4x, scale_factor=2, mode='bilinear', align_corners=False)
+         hr2x = self.conv_hr2x(paddle.concat((hr2x, enc2x), axis=1))
+
+         pred_detail = None
+         if self.training:
+             hr = F.interpolate(
+                 hr2x, scale_factor=2, mode='bilinear', align_corners=False)
+             hr = self.conv_hr(paddle.concat((hr, img), axis=1))
+             pred_detail = F.sigmoid(hr)
+
+         return pred_detail, hr2x
+
+
+ class LRBranch(nn.Layer):
+     def __init__(self, backbone_channels):
+         super().__init__()
+         self.se_block = SEBlock(backbone_channels[4], reduction=4)
+         self.conv_lr16x = Conv2dIBNormRelu(
+             backbone_channels[4], backbone_channels[3], 5, stride=1, padding=2)
+         self.conv_lr8x = Conv2dIBNormRelu(
+             backbone_channels[3], backbone_channels[2], 5, stride=1, padding=2)
+         self.conv_lr = Conv2dIBNormRelu(
+             backbone_channels[2],
+             1,
+             3,
+             stride=2,
+             padding=1,
+             with_ibn=False,
+             with_relu=False)
+
+     def forward(self, feat_list):
+         enc2x, enc4x, enc32x = feat_list[0], feat_list[1], feat_list[4]
+
+         enc32x = self.se_block(enc32x)
+         lr16x = F.interpolate(
+             enc32x, scale_factor=2, mode='bilinear', align_corners=False)
+         lr16x = self.conv_lr16x(lr16x)
+         lr8x = F.interpolate(
+             lr16x, scale_factor=2, mode='bilinear', align_corners=False)
+         lr8x = self.conv_lr8x(lr8x)
+
+         pred_semantic = None
+         if self.training:
+             lr = self.conv_lr(lr8x)
+             pred_semantic = F.sigmoid(lr)
+
+         return pred_semantic, lr8x, [enc2x, enc4x]
+
+
+ class IBNorm(nn.Layer):
+     """
+     Combine Instance Norm and Batch Norm into one layer.
+     """
+
+     def __init__(self, in_channels):
+         super().__init__()
+         self.bnorm_channels = in_channels // 2
+         self.inorm_channels = in_channels - self.bnorm_channels
+
+         self.bnorm = nn.BatchNorm2D(self.bnorm_channels)
+         self.inorm = nn.InstanceNorm2D(self.inorm_channels)
+
+     def forward(self, x):
+         bn_x = self.bnorm(x[:, :self.bnorm_channels, :, :])
+         in_x = self.inorm(x[:, self.bnorm_channels:, :, :])
+
+         return paddle.concat((bn_x, in_x), 1)
+
+
+ class Conv2dIBNormRelu(nn.Layer):
+     """
+     Convolution + IBNorm + ReLU
+     """
+
+     def __init__(self,
+                  in_channels,
+                  out_channels,
+                  kernel_size,
+                  stride=1,
+                  padding=0,
+                  dilation=1,
+                  groups=1,
+                  bias_attr=None,
+                  with_ibn=True,
+                  with_relu=True):
+
+         super().__init__()
+
+         layers = [
+             nn.Conv2D(
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=stride,
+                 padding=padding,
+                 dilation=dilation,
+                 groups=groups,
+                 bias_attr=bias_attr)
+         ]
+
+         if with_ibn:
+             layers.append(IBNorm(out_channels))
+
+         if with_relu:
+             layers.append(nn.ReLU())
+
+         self.layers = nn.Sequential(*layers)
+
+     def forward(self, x):
+         return self.layers(x)
+
+
+ class SEBlock(nn.Layer):
+     """
+     SE Block proposed in https://arxiv.org/pdf/1709.01507.pdf
+     """
+
+     def __init__(self, num_channels, reduction=1):
+         super().__init__()
+         self.pool = nn.AdaptiveAvgPool2D(1)
+         self.conv = nn.Sequential(
+             nn.Conv2D(
+                 num_channels,
+                 int(num_channels // reduction),
+                 1,
+                 bias_attr=False),
+             nn.ReLU(),
+             nn.Conv2D(
+                 int(num_channels // reduction),
+                 num_channels,
+                 1,
+                 bias_attr=False),
+             nn.Sigmoid())
+
+     def forward(self, x):
+         w = self.pool(x)
+         w = self.conv(w)
+         return w * x
+
+
+ class GaussianBlurLayer(nn.Layer):
+     """ Add Gaussian blur to a 4D tensor.
+     This layer takes a 4D tensor of {N, C, H, W} as input.
+     The Gaussian blur is performed on each of the C channels separately.
+     """
+
+     def __init__(self, channels, kernel_size):
+         """
+         Args:
+             channels (int): Channels of the input tensor
+             kernel_size (int): Size of the kernel used in blurring
+         """
+
+         super(GaussianBlurLayer, self).__init__()
+         self.channels = channels
+         self.kernel_size = kernel_size
+         assert self.kernel_size % 2 != 0
+
+         self.op = nn.Sequential(
+             nn.Pad2D(
+                 int(self.kernel_size / 2), mode='reflect'),
+             nn.Conv2D(
+                 channels,
+                 channels,
+                 self.kernel_size,
+                 stride=1,
+                 padding=0,
+                 bias_attr=False,
+                 groups=channels))
+
+         self._init_kernel()
+         self.op[1].weight.stop_gradient = True
+
+     def forward(self, x):
+         """
+         Args:
+             x (paddle.Tensor): input 4D tensor
+         Returns:
+             paddle.Tensor: Blurred version of the input
+         """
+
+         if not len(list(x.shape)) == 4:
+             raise ValueError(
+                 "'GaussianBlurLayer' requires a 4D tensor as input")
+         elif not x.shape[1] == self.channels:
+             raise ValueError(
+                 "In 'GaussianBlurLayer', the required channel ({0}) is "
+                 "not the same as input ({1})".format(self.channels,
+                                                      x.shape[1]))
+
+         return self.op(x)
+
+     def _init_kernel(self):
+         sigma = 0.3 * ((self.kernel_size - 1) * 0.5 - 1) + 0.8
+
+         n = np.zeros((self.kernel_size, self.kernel_size))
+         i = int(self.kernel_size / 2)
+         n[i, i] = 1
+         kernel = scipy.ndimage.gaussian_filter(n, sigma)
+         kernel = kernel.astype('float32')
+         kernel = kernel[np.newaxis, np.newaxis, :, :]
+         paddle.assign(kernel, self.op[1].weight)
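
The `_init_kernel` above obtains the Gaussian weights by filtering a unit impulse, using OpenCV's default sigma rule sigma = 0.3*((kernel_size - 1)*0.5 - 1) + 0.8. A standalone sketch (not part of this commit) that reproduces the kernel and checks it is normalized:

import numpy as np
import scipy.ndimage

kernel_size = 3  # MODNet's blurer uses a 3x3 kernel
sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8  # = 0.8 for kernel_size 3
impulse = np.zeros((kernel_size, kernel_size))
impulse[kernel_size // 2, kernel_size // 2] = 1
kernel = scipy.ndimage.gaussian_filter(impulse, sigma)
print(kernel.sum())  # ~1.0: the filtered impulse keeps unit mass
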
ppmatting/models/ppmatting.py ADDED
@@ -0,0 +1,338 @@
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from collections import defaultdict
+
+ import paddle
+ import paddle.nn as nn
+ import paddle.nn.functional as F
+ import paddleseg
+ from paddleseg.models import layers
+ from paddleseg import utils
+ from paddleseg.cvlibs import manager
+
+ from ppmatting.models.losses import MRSD, GradientLoss
+ from ppmatting.models.backbone import resnet_vd
+
+
+ @manager.MODELS.add_component
+ class PPMatting(nn.Layer):
+     """
+     The PP-Matting implementation based on PaddlePaddle.
+
+     The original article refers to
+     Guowei Chen, et al. "PP-Matting: High-Accuracy Natural Image Matting"
+     (https://arxiv.org/pdf/2204.09433.pdf).
+
+     Args:
+         backbone: backbone model.
+         pretrained (str, optional): The path of the pretrained model. Default: None.
+
+     """
+
+     def __init__(self, backbone, pretrained=None):
+         super().__init__()
+         self.backbone = backbone
+         self.pretrained = pretrained
+         self.loss_func_dict = self.get_loss_func_dict()
+
+         self.backbone_channels = backbone.feat_channels
+
+         self.scb = SCB(self.backbone_channels[-1])
+
+         self.hrdb = HRDB(
+             self.backbone_channels[0] + self.backbone_channels[1],
+             scb_channels=self.scb.out_channels,
+             gf_index=[0, 2, 4])
+
+         self.init_weight()
+
+     def forward(self, inputs):
+         x = inputs['img']
+         input_shape = paddle.shape(x)
+         fea_list = self.backbone(x)
+
+         scb_logits = self.scb(fea_list[-1])
+         semantic_map = F.softmax(scb_logits[-1], axis=1)
+
+         fea0 = F.interpolate(
+             fea_list[0], input_shape[2:], mode='bilinear', align_corners=False)
+         fea1 = F.interpolate(
+             fea_list[1], input_shape[2:], mode='bilinear', align_corners=False)
+         hrdb_input = paddle.concat([fea0, fea1], 1)
+         hrdb_logit = self.hrdb(hrdb_input, scb_logits)
+         detail_map = F.sigmoid(hrdb_logit)
+         fusion = self.fusion(semantic_map, detail_map)
+
+         if self.training:
+             logit_dict = {
+                 'semantic': semantic_map,
+                 'detail': detail_map,
+                 'fusion': fusion
+             }
+             loss_dict = self.loss(logit_dict, inputs)
+             return logit_dict, loss_dict
+         else:
+             return fusion
+
+     def get_loss_func_dict(self):
+         loss_func_dict = defaultdict(list)
+         loss_func_dict['semantic'].append(nn.NLLLoss())
+         loss_func_dict['detail'].append(MRSD())
+         loss_func_dict['detail'].append(GradientLoss())
+         loss_func_dict['fusion'].append(MRSD())
+         loss_func_dict['fusion'].append(MRSD())
+         loss_func_dict['fusion'].append(GradientLoss())
+         return loss_func_dict
+
+     def loss(self, logit_dict, label_dict):
+         loss = {}
+
+         # semantic loss computation
+         # get semantic label
+         semantic_label = label_dict['trimap']
+         semantic_label_trans = (semantic_label == 128).astype('int64')
+         semantic_label_bg = (semantic_label == 0).astype('int64')
+         semantic_label = semantic_label_trans + semantic_label_bg * 2
+         loss_semantic = self.loss_func_dict['semantic'][0](
+             paddle.log(logit_dict['semantic'] + 1e-6),
+             semantic_label.squeeze(1))
+         loss['semantic'] = loss_semantic
+
+         # detail loss computation
+         transparent = label_dict['trimap'] == 128
+         detail_alpha_loss = self.loss_func_dict['detail'][0](
+             logit_dict['detail'], label_dict['alpha'], transparent)
+         # gradient loss
+         detail_gradient_loss = self.loss_func_dict['detail'][1](
+             logit_dict['detail'], label_dict['alpha'], transparent)
+         loss_detail = detail_alpha_loss + detail_gradient_loss
+         loss['detail'] = loss_detail
+         loss['detail_alpha'] = detail_alpha_loss
+         loss['detail_gradient'] = detail_gradient_loss
+
+         # fusion loss
+         loss_fusion_func = self.loss_func_dict['fusion']
+         # fusion_sigmoid loss
+         fusion_alpha_loss = loss_fusion_func[0](logit_dict['fusion'],
+                                                 label_dict['alpha'])
+         # composition loss
+         comp_pred = logit_dict['fusion'] * label_dict['fg'] + (
+             1 - logit_dict['fusion']) * label_dict['bg']
+         comp_gt = label_dict['alpha'] * label_dict['fg'] + (
+             1 - label_dict['alpha']) * label_dict['bg']
+         fusion_composition_loss = loss_fusion_func[1](comp_pred, comp_gt)
+         # gradient loss
+         fusion_grad_loss = loss_fusion_func[2](logit_dict['fusion'],
+                                                label_dict['alpha'])
+         # fusion loss
+         loss_fusion = fusion_alpha_loss + fusion_composition_loss + fusion_grad_loss
+         loss['fusion'] = loss_fusion
+         loss['fusion_alpha'] = fusion_alpha_loss
+         loss['fusion_composition'] = fusion_composition_loss
+         loss['fusion_gradient'] = fusion_grad_loss
+
+         loss['all'] = 0.25 * loss_semantic + 0.25 * loss_detail + 0.25 * loss_fusion
+
+         return loss
+
+     def fusion(self, semantic_map, detail_map):
+         # semantic_map [N, 3, H, W]
+         # In index, 0 is foreground, 1 is transition, 2 is background.
+         # After fusion, the foreground is 1, the background is 0,
+         # and the transition is between [0, 1].
+         index = paddle.argmax(semantic_map, axis=1, keepdim=True)
+         transition_mask = (index == 1).astype('float32')
+         fg = (index == 0).astype('float32')
+         alpha = detail_map * transition_mask + fg
+         return alpha
+
+     def init_weight(self):
+         if self.pretrained is not None:
+             utils.load_entire_model(self, self.pretrained)
+
+
+ class SCB(nn.Layer):
+     def __init__(self, in_channels):
+         super().__init__()
+         self.in_channels = [512 + in_channels, 512, 256, 128, 128, 64]
+         self.mid_channels = [512, 256, 128, 128, 64, 64]
+         self.out_channels = [256, 128, 64, 64, 64, 3]
+
+         self.psp_module = layers.PPModule(
+             in_channels,
+             512,
+             bin_sizes=(1, 3, 5),
+             dim_reduction=False,
+             align_corners=False)
+
+         psp_upsamples = [2, 4, 8, 16]
+         self.psps = nn.LayerList([
+             self.conv_up_psp(512, self.out_channels[i], psp_upsamples[i])
+             for i in range(4)
+         ])
+
+         scb_list = [
+             self._make_stage(
+                 self.in_channels[i],
+                 self.mid_channels[i],
+                 self.out_channels[i],
+                 padding=int(i == 0) + 1,
+                 dilation=int(i == 0) + 1)
+             for i in range(len(self.in_channels) - 1)
+         ]
+         scb_list += [
+             nn.Sequential(
+                 layers.ConvBNReLU(
+                     self.in_channels[-1], self.mid_channels[-1], 3, padding=1),
+                 layers.ConvBNReLU(
+                     self.mid_channels[-1], self.mid_channels[-1], 3, padding=1),
+                 nn.Conv2D(
+                     self.mid_channels[-1], self.out_channels[-1], 3, padding=1))
+         ]
+         self.scb_stages = nn.LayerList(scb_list)
+
+     def forward(self, x):
+         psp_x = self.psp_module(x)
+         psps = [psp(psp_x) for psp in self.psps]
+
+         scb_logits = []
+         for i, scb_stage in enumerate(self.scb_stages):
+             if i == 0:
+                 x = scb_stage(paddle.concat((psp_x, x), 1))
+             elif i <= len(psps):
+                 x = scb_stage(paddle.concat((psps[i - 1], x), 1))
+             else:
+                 x = scb_stage(x)
+             scb_logits.append(x)
+         return scb_logits
+
+     def conv_up_psp(self, in_channels, out_channels, up_sample):
+         return nn.Sequential(
+             layers.ConvBNReLU(
+                 in_channels, out_channels, 3, padding=1),
+             nn.Upsample(
+                 scale_factor=up_sample, mode='bilinear', align_corners=False))
+
+     def _make_stage(self,
+                     in_channels,
+                     mid_channels,
+                     out_channels,
+                     padding=1,
+                     dilation=1):
+         layer_list = [
+             layers.ConvBNReLU(
+                 in_channels, mid_channels, 3, padding=1),
+             layers.ConvBNReLU(
+                 mid_channels,
+                 mid_channels,
+                 3,
+                 padding=padding,
+                 dilation=dilation),
+             layers.ConvBNReLU(
+                 mid_channels,
+                 out_channels,
+                 3,
+                 padding=padding,
+                 dilation=dilation),
+             nn.Upsample(
+                 scale_factor=2, mode='bilinear', align_corners=False)
+         ]
+         return nn.Sequential(*layer_list)
+
+
+ class HRDB(nn.Layer):
+     """
+     The High-Resolution Detail Branch
+
+     Args:
+         in_channels (int): The number of input channels.
+         scb_channels (list|tuple): The channels of the scb logits.
+         gf_index (list|tuple, optional): Which logits are selected as the guidance flow from the scb logits. Default: (0, 2, 4).
+     """
+
+     def __init__(self, in_channels, scb_channels, gf_index=(0, 2, 4)):
+         super().__init__()
+         self.gf_index = gf_index
+         self.gf_list = nn.LayerList(
+             [nn.Conv2D(scb_channels[i], 1, 1) for i in gf_index])
+
+         channels = [64, 32, 16, 8]
+         self.res_list = [
+             resnet_vd.BasicBlock(
+                 in_channels, channels[0], stride=1, shortcut=False)
+         ]
+         self.res_list += [
+             resnet_vd.BasicBlock(
+                 i, i, stride=1) for i in channels[1:-1]
+         ]
+         self.res_list = nn.LayerList(self.res_list)
+
+         self.convs = nn.LayerList([
+             nn.Conv2D(
+                 channels[i], channels[i + 1], kernel_size=1)
+             for i in range(len(channels) - 1)
+         ])
+         self.gates = nn.LayerList(
+             [GatedSpatailConv2d(i, i) for i in channels[1:]])
+
+         self.detail_conv = nn.Conv2D(channels[-1], 1, 1, bias_attr=False)
+
+     def forward(self, x, scb_logits):
+         for i in range(len(self.res_list)):
+             x = self.res_list[i](x)
+             x = self.convs[i](x)
+             gf = self.gf_list[i](scb_logits[self.gf_index[i]])
+             gf = F.interpolate(
+                 gf, paddle.shape(x)[-2:], mode='bilinear', align_corners=False)
+             x = self.gates[i](x, gf)
+         return self.detail_conv(x)
+
+
+ class GatedSpatailConv2d(nn.Layer):
+     def __init__(self,
+                  in_channels,
+                  out_channels,
+                  kernel_size=1,
+                  stride=1,
+                  padding=0,
+                  dilation=1,
+                  groups=1,
+                  bias_attr=False):
+         super().__init__()
+         self._gate_conv = nn.Sequential(
+             layers.SyncBatchNorm(in_channels + 1),
+             nn.Conv2D(
+                 in_channels + 1, in_channels + 1, kernel_size=1),
+             nn.ReLU(),
+             nn.Conv2D(
+                 in_channels + 1, 1, kernel_size=1),
+             layers.SyncBatchNorm(1),
+             nn.Sigmoid())
+         self.conv = nn.Conv2D(
+             in_channels,
+             out_channels,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             groups=groups,
+             bias_attr=bias_attr)
+
+     def forward(self, input_features, gating_features):
+         cat = paddle.concat([input_features, gating_features], axis=1)
+         alphas = self._gate_conv(cat)
+         x = input_features * (alphas + 1)
+         x = self.conv(x)
+         return x
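
A small usage sketch (not part of this commit) of the gated guidance-flow step used by HRDB above; the gate turns an SCB logit into a per-pixel attention map that rescales the detail features. Shapes are illustrative, and it assumes paddleseg's SyncBatchNorm falls back to plain batch norm outside distributed training:

import paddle
from ppmatting.models.ppmatting import GatedSpatailConv2d

gate = GatedSpatailConv2d(in_channels=32, out_channels=32)
feats = paddle.rand([1, 32, 64, 64])  # detail-branch features
gf = paddle.rand([1, 1, 64, 64])      # guidance flow from an SCB logit
out = gate(feats, gf)                 # feats * (gate + 1), then a 1x1 conv
print(out.shape)  # [1, 32, 64, 64]
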
ppmatting/transforms/__init__.py ADDED
@@ -0,0 +1 @@
+ from .transforms import *
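
For orientation, a hedged sketch (not part of this commit) of composing these transforms for inference, following the `data` dict convention used in transforms.py below; the image path is hypothetical:

from ppmatting.transforms import (Compose, LoadImages, Normalize,
                                  ResizeByShort, ResizeToIntMult)

transforms = Compose([
    LoadImages(),
    ResizeByShort(512),
    ResizeToIntMult(32),  # keep H/W divisible by the network stride
    Normalize(),
])
data = transforms({'img': 'path/to/image.jpg'})  # hypothetical path
print(data['img'].shape)  # CHW float array ready for the model
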
ppmatting/transforms/transforms.py ADDED
@@ -0,0 +1,791 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import random
17
+ import string
18
+
19
+ import cv2
20
+ import numpy as np
21
+ from paddleseg.transforms import functional
22
+ from paddleseg.cvlibs import manager
23
+ from paddleseg.utils import seg_env
24
+ from PIL import Image
25
+
26
+
27
+ @manager.TRANSFORMS.add_component
28
+ class Compose:
29
+ """
30
+ Do transformation on input data with corresponding pre-processing and augmentation operations.
31
+ The shape of input data to all operations is [height, width, channels].
32
+ """
33
+
34
+ def __init__(self, transforms, to_rgb=True):
35
+ if not isinstance(transforms, list):
36
+ raise TypeError('The transforms must be a list!')
37
+ self.transforms = transforms
38
+ self.to_rgb = to_rgb
39
+
40
+ def __call__(self, data):
41
+ """
42
+ Args:
43
+ data (dict): The data to transform.
44
+
45
+ Returns:
46
+ dict: Data after transformation
47
+ """
48
+ if 'trans_info' not in data:
49
+ data['trans_info'] = []
50
+ for op in self.transforms:
51
+ data = op(data)
52
+ if data is None:
53
+ return None
54
+
55
+ data['img'] = np.transpose(data['img'], (2, 0, 1))
56
+ for key in data.get('gt_fields', []):
57
+ if len(data[key].shape) == 2:
58
+ continue
59
+ data[key] = np.transpose(data[key], (2, 0, 1))
60
+
61
+ return data
62
+
63
+
64
+ @manager.TRANSFORMS.add_component
65
+ class LoadImages:
66
+ def __init__(self, to_rgb=False):
67
+ self.to_rgb = to_rgb
68
+
69
+ def __call__(self, data):
70
+ if isinstance(data['img'], str):
71
+ data['img'] = cv2.imread(data['img'])
72
+ for key in data.get('gt_fields', []):
73
+ if isinstance(data[key], str):
74
+ data[key] = cv2.imread(data[key], cv2.IMREAD_UNCHANGED)
75
+ # if alpha and trimap has 3 channels, extract one.
76
+ if key in ['alpha', 'trimap']:
77
+ if len(data[key].shape) > 2:
78
+ data[key] = data[key][:, :, 0]
79
+
80
+ if self.to_rgb:
81
+ data['img'] = cv2.cvtColor(data['img'], cv2.COLOR_BGR2RGB)
82
+ for key in data.get('gt_fields', []):
83
+ if len(data[key].shape) == 2:
84
+ continue
85
+ data[key] = cv2.cvtColor(data[key], cv2.COLOR_BGR2RGB)
86
+
87
+ return data
88
+
89
+
90
+ @manager.TRANSFORMS.add_component
91
+ class Resize:
92
+ def __init__(self, target_size=(512, 512), random_interp=False):
93
+ if isinstance(target_size, list) or isinstance(target_size, tuple):
94
+ if len(target_size) != 2:
95
+ raise ValueError(
96
+ '`target_size` should include 2 elements, but it is {}'.
97
+ format(target_size))
98
+ else:
99
+ raise TypeError(
100
+ "Type of `target_size` is invalid. It should be list or tuple, but it is {}"
101
+ .format(type(target_size)))
102
+
103
+ self.target_size = target_size
104
+ self.random_interp = random_interp
105
+ self.interps = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC]
106
+
107
+ def __call__(self, data):
108
+ if self.random_interp:
109
+ interp = np.random.choice(self.interps)
110
+ else:
111
+ interp = cv2.INTER_LINEAR
112
+ data['trans_info'].append(('resize', data['img'].shape[0:2]))
113
+ data['img'] = functional.resize(data['img'], self.target_size, interp)
114
+ for key in data.get('gt_fields', []):
115
+ if key == 'trimap':
116
+ data[key] = functional.resize(data[key], self.target_size,
117
+ cv2.INTER_NEAREST)
118
+ else:
119
+ data[key] = functional.resize(data[key], self.target_size,
120
+ interp)
121
+ return data
122
+
123
+
124
+ @manager.TRANSFORMS.add_component
125
+ class RandomResize:
126
+ """
127
+ Resize image to a size determinned by `scale` and `size`.
128
+
129
+ Args:
130
+ size(tuple|list): The reference size to resize. A tuple or list with length 2.
131
+ scale(tupel|list, optional): A range of scale base on `size`. A tuple or list with length 2. Default: None.
132
+ """
133
+
134
+ def __init__(self, size=None, scale=None):
135
+ if isinstance(size, list) or isinstance(size, tuple):
136
+ if len(size) != 2:
137
+ raise ValueError(
138
+ '`size` should include 2 elements, but it is {}'.format(
139
+ size))
140
+ elif size is not None:
141
+ raise TypeError(
142
+ "Type of `size` is invalid. It should be list or tuple, but it is {}"
143
+ .format(type(size)))
144
+
145
+ if scale is not None:
146
+ if isinstance(scale, list) or isinstance(scale, tuple):
147
+ if len(scale) != 2:
148
+ raise ValueError(
149
+ '`scale` should include 2 elements, but it is {}'.
150
+ format(scale))
151
+ else:
152
+ raise TypeError(
153
+ "Type of `scale` is invalid. It should be list or tuple, but it is {}"
154
+ .format(type(scale)))
155
+ self.size = size
156
+ self.scale = scale
157
+
158
+ def __call__(self, data):
159
+ h, w = data['img'].shape[:2]
160
+ if self.scale is not None:
161
+ scale = np.random.uniform(self.scale[0], self.scale[1])
162
+ else:
163
+ scale = 1.
164
+ if self.size is not None:
165
+ scale_factor = max(self.size[0] / w, self.size[1] / h)
166
+ else:
167
+ scale_factor = 1
168
+ scale = scale * scale_factor
169
+
170
+ w = int(round(w * scale))
171
+ h = int(round(h * scale))
172
+ data['img'] = functional.resize(data['img'], (w, h))
173
+ for key in data.get('gt_fields', []):
174
+ if key == 'trimap':
175
+ data[key] = functional.resize(data[key], (w, h),
176
+ cv2.INTER_NEAREST)
177
+ else:
178
+ data[key] = functional.resize(data[key], (w, h))
179
+ return data
180
+
181
+
182
+ @manager.TRANSFORMS.add_component
183
+ class ResizeByLong:
184
+ """
185
+ Resize the long side of an image to given size, and then scale the other side proportionally.
186
+
187
+ Args:
188
+ long_size (int): The target size of long side.
189
+ """
190
+
191
+ def __init__(self, long_size):
192
+ self.long_size = long_size
193
+
194
+ def __call__(self, data):
195
+ data['trans_info'].append(('resize', data['img'].shape[0:2]))
196
+ data['img'] = functional.resize_long(data['img'], self.long_size)
197
+ for key in data.get('gt_fields', []):
198
+ if key == 'trimap':
199
+ data[key] = functional.resize_long(data[key], self.long_size,
200
+ cv2.INTER_NEAREST)
201
+ else:
202
+ data[key] = functional.resize_long(data[key], self.long_size)
203
+ return data
204
+
205
+
206
+ @manager.TRANSFORMS.add_component
207
+ class ResizeByShort:
208
+ """
209
+ Resize the short side of an image to given size, and then scale the other side proportionally.
210
+
211
+ Args:
212
+ short_size (int): The target size of short side.
213
+ """
214
+
215
+ def __init__(self, short_size):
216
+ self.short_size = short_size
217
+
218
+ def __call__(self, data):
219
+ data['trans_info'].append(('resize', data['img'].shape[0:2]))
220
+ data['img'] = functional.resize_short(data['img'], self.short_size)
221
+ for key in data.get('gt_fields', []):
222
+ if key == 'trimap':
223
+ data[key] = functional.resize_short(data[key], self.short_size,
224
+ cv2.INTER_NEAREST)
225
+ else:
226
+ data[key] = functional.resize_short(data[key], self.short_size)
227
+ return data
228
+
229
+
230
+ @manager.TRANSFORMS.add_component
231
+ class ResizeToIntMult:
232
+ """
233
+ Resize to some int muitple, d.g. 32.
234
+ """
235
+
236
+ def __init__(self, mult_int=32):
237
+ self.mult_int = mult_int
238
+
239
+ def __call__(self, data):
240
+ data['trans_info'].append(('resize', data['img'].shape[0:2]))
241
+
242
+ h, w = data['img'].shape[0:2]
243
+ rw = w - w % self.mult_int
244
+ rh = h - h % self.mult_int
245
+ data['img'] = functional.resize(data['img'], (rw, rh))
246
+ for key in data.get('gt_fields', []):
247
+ if key == 'trimap':
248
+ data[key] = functional.resize(data[key], (rw, rh),
249
+ cv2.INTER_NEAREST)
250
+ else:
251
+ data[key] = functional.resize(data[key], (rw, rh))
252
+
253
+ return data
254
+
255
+
256
+ @manager.TRANSFORMS.add_component
257
+ class Normalize:
258
+ """
259
+ Normalize an image.
260
+
261
+ Args:
262
+ mean (list, optional): The mean value of a data set. Default: [0.5, 0.5, 0.5].
263
+ std (list, optional): The standard deviation of a data set. Default: [0.5, 0.5, 0.5].
264
+
265
+ Raises:
266
+ ValueError: When mean/std is not list or any value in std is 0.
267
+ """
268
+
269
+ def __init__(self, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
270
+ self.mean = mean
271
+ self.std = std
272
+ if not (isinstance(self.mean,
273
+ (list, tuple)) and isinstance(self.std,
274
+ (list, tuple))):
275
+ raise ValueError(
276
+ "{}: input type is invalid. It should be list or tuple".format(
277
+ self))
278
+ from functools import reduce
279
+ if reduce(lambda x, y: x * y, self.std) == 0:
280
+ raise ValueError('{}: std is invalid!'.format(self))
281
+
282
+ def __call__(self, data):
283
+ mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
284
+ std = np.array(self.std)[np.newaxis, np.newaxis, :]
285
+ data['img'] = functional.normalize(data['img'], mean, std)
286
+ if 'fg' in data.get('gt_fields', []):
287
+ data['fg'] = functional.normalize(data['fg'], mean, std)
288
+ if 'bg' in data.get('gt_fields', []):
289
+ data['bg'] = functional.normalize(data['bg'], mean, std)
290
+
291
+ return data
292
+
293
+
294
+ @manager.TRANSFORMS.add_component
295
+ class RandomCropByAlpha:
296
+ """
297
+ Randomly crop while centered on uncertain area by a certain probability.
298
+
299
+ Args:
300
+ crop_size (tuple|list): The size you want to crop from image.
301
+ p (float): The probability centered on uncertain area.
302
+
303
+ """
304
+
305
+ def __init__(self, crop_size=((320, 320), (480, 480), (640, 640)),
306
+ prob=0.5):
307
+ self.crop_size = crop_size
308
+ self.prob = prob
309
+
310
+ def __call__(self, data):
311
+ idx = np.random.randint(low=0, high=len(self.crop_size))
312
+ crop_w, crop_h = self.crop_size[idx]
313
+
314
+ img_h = data['img'].shape[0]
315
+ img_w = data['img'].shape[1]
316
+ if np.random.rand() < self.prob:
317
+ crop_center = np.where((data['alpha'] > 0) & (data['alpha'] < 255))
318
+ center_h_array, center_w_array = crop_center
319
+ if len(center_h_array) == 0:
320
+ return data
321
+ rand_ind = np.random.randint(len(center_h_array))
322
+ center_h = center_h_array[rand_ind]
323
+ center_w = center_w_array[rand_ind]
324
+ delta_h = crop_h // 2
325
+ delta_w = crop_w // 2
326
+ start_h = max(0, center_h - delta_h)
327
+ start_w = max(0, center_w - delta_w)
328
+ else:
329
+ start_h = 0
330
+ start_w = 0
331
+ if img_h > crop_h:
332
+ start_h = np.random.randint(img_h - crop_h + 1)
333
+ if img_w > crop_w:
334
+ start_w = np.random.randint(img_w - crop_w + 1)
335
+
336
+ end_h = min(img_h, start_h + crop_h)
337
+ end_w = min(img_w, start_w + crop_w)
338
+
339
+ data['img'] = data['img'][start_h:end_h, start_w:end_w]
340
+ for key in data.get('gt_fields', []):
341
+ data[key] = data[key][start_h:end_h, start_w:end_w]
342
+
343
+ return data
344
+
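# Why `(alpha > 0) & (alpha < 255)`: those are the semi-transparent pixels,
# the hardest region in matting, so with probability `prob` the crop is
# centered on one of them. Toy illustration of the candidate-center lookup:
import numpy as np

alpha = np.zeros((8, 8), np.uint8)
alpha[3:5, 3:5] = 128  # an "uncertain" band
ys, xs = np.where((alpha > 0) & (alpha < 255))
# ys/xs hold rows/cols 3-4 here; one pair is drawn as the crop center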
345
+
346
+ @manager.TRANSFORMS.add_component
347
+ class RandomCrop:
348
+ """
349
+ Randomly crop an image to a size randomly selected from crop_size.
350
+
351
+ Args:
352
+ crop_size (tuple|list): Candidate sizes to crop from the image.
353
+ """
354
+
355
+ def __init__(self, crop_size=((320, 320), (480, 480), (640, 640))):
356
+ if not isinstance(crop_size[0], (list, tuple)):
357
+ crop_size = [crop_size]
358
+ self.crop_size = crop_size
359
+
360
+ def __call__(self, data):
361
+ idx = np.random.randint(low=0, high=len(self.crop_size))
362
+ crop_w, crop_h = self.crop_size[idx]
363
+ img_h, img_w = data['img'].shape[0:2]
364
+
365
+ start_h = 0
366
+ start_w = 0
367
+ if img_h > crop_h:
368
+ start_h = np.random.randint(img_h - crop_h + 1)
369
+ if img_w > crop_w:
370
+ start_w = np.random.randint(img_w - crop_w + 1)
371
+
372
+ end_h = min(img_h, start_h + crop_h)
373
+ end_w = min(img_w, start_w + crop_w)
374
+
375
+ data['img'] = data['img'][start_h:end_h, start_w:end_w]
376
+ for key in data.get('gt_fields', []):
377
+ data[key] = data[key][start_h:end_h, start_w:end_w]
378
+
379
+ return data
380
+
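# Note on the isinstance check in __init__: a single (w, h) pair is wrapped
# into a one-element list, so both call styles below are equivalent to the
# multi-size form used elsewhere in this file.
t1 = RandomCrop(crop_size=(512, 512))    # normalized to [(512, 512)]
t2 = RandomCrop(crop_size=[(512, 512)])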
381
+
382
+ @manager.TRANSFORMS.add_component
383
+ class LimitLong:
384
+ """
385
+ Limit the long edge of an image.
386
+
387
+ If the long edge is larger than max_long, resize the long edge
388
+ to max_long, scaling the short edge proportionally.
389
+
390
+ If the long edge is smaller than min_long, resize the long edge
391
+ to min_long, scaling the short edge proportionally.
392
+
393
+ Args:
394
+ max_long (int, optional): If the long edge of the image is larger than max_long,
395
+ it will be resized to max_long. Default: None.
396
+ min_long (int, optional): If the long edge of the image is smaller than min_long,
397
+ it will be resized to min_long. Default: None.
398
+ """
399
+
400
+ def __init__(self, max_long=None, min_long=None):
401
+ if max_long is not None:
402
+ if not isinstance(max_long, int):
403
+ raise TypeError(
404
+ "Type of `max_long` is invalid. It should be int, but it is {}"
405
+ .format(type(max_long)))
406
+ if min_long is not None:
407
+ if not isinstance(min_long, int):
408
+ raise TypeError(
409
+ "Type of `min_long` is invalid. It should be int, but it is {}"
410
+ .format(type(min_long)))
411
+ if (max_long is not None) and (min_long is not None):
412
+ if min_long > max_long:
413
+ raise ValueError(
414
+ '`max_long should not smaller than min_long, but they are {} and {}'
415
+ .format(max_long, min_long))
416
+ self.max_long = max_long
417
+ self.min_long = min_long
418
+
419
+ def __call__(self, data):
420
+ h, w = data['img'].shape[:2]
421
+ long_edge = max(h, w)
422
+ target = long_edge
423
+ if (self.max_long is not None) and (long_edge > self.max_long):
424
+ target = self.max_long
425
+ elif (self.min_long is not None) and (long_edge < self.min_long):
426
+ target = self.min_long
427
+
428
+ data['trans_info'].append(('resize', data['img'].shape[0:2]))
429
+ if target != long_edge:
430
+ data['img'] = functional.resize_long(data['img'], target)
431
+ for key in data.get('gt_fields', []):
432
+ if key == 'trimap':
433
+ data[key] = functional.resize_long(data[key], target,
434
+ cv2.INTER_NEAREST)
435
+ else:
436
+ data[key] = functional.resize_long(data[key], target)
437
+
438
+ return data
439
+
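# The clamp logic above in isolation: only out-of-range long edges trigger
# a resize; anything already in [min_long, max_long] passes through.
def limit_target(long_edge, max_long=None, min_long=None):
    if max_long is not None and long_edge > max_long:
        return max_long
    if min_long is not None and long_edge < min_long:
        return min_long
    return long_edge

assert limit_target(3000, max_long=2048) == 2048
assert limit_target(300, min_long=512) == 512
assert limit_target(1024, max_long=2048, min_long=512) == 1024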
440
+
441
+ @manager.TRANSFORMS.add_component
442
+ class LimitShort:
443
+ """
444
+ Limit the short edge of an image.
444
+
445
+ If the short edge is larger than max_short, resize the short edge
446
+ to max_short, scaling the long edge proportionally.
447
+
448
+ If the short edge is smaller than min_short, resize the short edge
449
+ to min_short, scaling the long edge proportionally.
451
+
452
+ Args:
453
+ max_short (int, optional): If the short edge of the image is larger than max_short,
454
+ it will be resized to max_short. Default: None.
455
+ min_short (int, optional): If the short edge of the image is smaller than min_short,
456
+ it will be resized to min_short. Default: None.
457
+ """
458
+
459
+ def __init__(self, max_short=None, min_short=None):
460
+ if max_short is not None:
461
+ if not isinstance(max_short, int):
462
+ raise TypeError(
463
+ "Type of `max_short` is invalid. It should be int, but it is {}"
464
+ .format(type(max_short)))
465
+ if min_short is not None:
466
+ if not isinstance(min_short, int):
467
+ raise TypeError(
468
+ "Type of `min_short` is invalid. It should be int, but it is {}"
469
+ .format(type(min_short)))
470
+ if (max_short is not None) and (min_short is not None):
471
+ if min_short > max_short:
472
+ raise ValueError(
473
+ '`max_short should not smaller than min_short, but they are {} and {}'
474
+ .format(max_short, min_short))
475
+ self.max_short = max_short
476
+ self.min_short = min_short
477
+
478
+ def __call__(self, data):
479
+ h, w = data['img'].shape[:2]
480
+ short_edge = min(h, w)
481
+ target = short_edge
482
+ if (self.max_short is not None) and (short_edge > self.max_short):
483
+ target = self.max_short
484
+ elif (self.min_short is not None) and (short_edge < self.min_short):
485
+ target = self.min_short
486
+
487
+ data['trans_info'].append(('resize', data['img'].shape[0:2]))
488
+ if target != short_edge:
489
+ data['img'] = functional.resize_short(data['img'], target)
490
+ for key in data.get('gt_fields', []):
491
+ if key == 'trimap':
492
+ data[key] = functional.resize_short(data[key], target,
493
+ cv2.INTER_NEAREST)
494
+ else:
495
+ data[key] = functional.resize_short(data[key], target)
496
+
497
+ return data
498
+
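# LimitShort mirrors LimitLong but keys off min(h, w). Sketch: with
# max_short=1024, a 1500x3000 image scales by 1024/1500 to 1024x2048,
# while an 800x3000 image is untouched (its short edge is already <= 1024).
import numpy as np

t = LimitShort(max_short=1024)
data = {'img': np.zeros((1500, 3000, 3), np.uint8), 'trans_info': [], 'gt_fields': []}
out = t(data)  # out['img'].shape[:2] == (1024, 2048)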
499
+
500
+ @manager.TRANSFORMS.add_component
501
+ class RandomHorizontalFlip:
502
+ """
503
+ Flip an image horizontally with a certain probability.
504
+
505
+ Args:
506
+ prob (float, optional): A probability of horizontally flipping. Default: 0.5.
507
+ """
508
+
509
+ def __init__(self, prob=0.5):
510
+ self.prob = prob
511
+
512
+ def __call__(self, data):
513
+ if random.random() < self.prob:
514
+ data['img'] = functional.horizontal_flip(data['img'])
515
+ for key in data.get('gt_fields', []):
516
+ data[key] = functional.horizontal_flip(data[key])
517
+
518
+ return data
519
+
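# Sketch of the assumed functional.horizontal_flip: a width-axis reversal.
# Flipping the image and every gt field in lockstep keeps them aligned.
def horizontal_flip_sketch(im):
    return im[:, ::-1, ...].copy()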
520
+
521
+ @manager.TRANSFORMS.add_component
522
+ class RandomBlur:
523
+ """
524
+ Blur an image with a Gaussian filter with a certain probability.
525
+
526
+ Args:
527
+ prob (float, optional): A probability of blurring an image. Default: 0.1.
528
+ """
529
+
530
+ def __init__(self, prob=0.1):
531
+ self.prob = prob
532
+
533
+ def __call__(self, data):
534
+ if self.prob <= 0:
535
+ n = 0
536
+ elif self.prob >= 1:
537
+ n = 1
538
+ else:
539
+ n = int(1.0 / self.prob)
540
+ if n > 0:
541
+ if np.random.randint(0, n) == 0:
542
+ radius = np.random.randint(3, 10)
543
+ if radius % 2 != 1:
544
+ radius = radius + 1
545
+ if radius > 9:
546
+ radius = 9
547
+ data['img'] = cv2.GaussianBlur(data['img'], (radius, radius), 0,
548
+ 0)
549
+ for key in data.get('gt_fields', []):
550
+ if key == 'trimap':
551
+ continue
552
+ data[key] = cv2.GaussianBlur(data[key], (radius, radius), 0,
553
+ 0)
554
+ return data
555
+
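# The n = int(1.0 / prob) trick above realizes probability `prob` as a
# 1-in-n event via np.random.randint(0, n) == 0; note it is only exact
# when 1/prob is an integer (prob=0.3 gives n=3, i.e. an effective 1/3).
for prob, n in [(0.1, 10), (0.25, 4), (0.3, 3)]:
    assert int(1.0 / prob) == n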
556
+
557
+ @manager.TRANSFORMS.add_component
558
+ class RandomDistort:
559
+ """
560
+ Distort an image with random configurations.
561
+
562
+ Args:
563
+ brightness_range (float, optional): A range of brightness. Default: 0.5.
564
+ brightness_prob (float, optional): A probability of adjusting brightness. Default: 0.5.
565
+ contrast_range (float, optional): A range of contrast. Default: 0.5.
566
+ contrast_prob (float, optional): A probability of adjusting contrast. Default: 0.5.
567
+ saturation_range (float, optional): A range of saturation. Default: 0.5.
568
+ saturation_prob (float, optional): A probability of adjusting saturation. Default: 0.5.
569
+ hue_range (int, optional): A range of hue. Default: 18.
570
+ hue_prob (float, optional): A probability of adjusting hue. Default: 0.5.
571
+ """
572
+
573
+ def __init__(self,
574
+ brightness_range=0.5,
575
+ brightness_prob=0.5,
576
+ contrast_range=0.5,
577
+ contrast_prob=0.5,
578
+ saturation_range=0.5,
579
+ saturation_prob=0.5,
580
+ hue_range=18,
581
+ hue_prob=0.5):
582
+ self.brightness_range = brightness_range
583
+ self.brightness_prob = brightness_prob
584
+ self.contrast_range = contrast_range
585
+ self.contrast_prob = contrast_prob
586
+ self.saturation_range = saturation_range
587
+ self.saturation_prob = saturation_prob
588
+ self.hue_range = hue_range
589
+ self.hue_prob = hue_prob
590
+
591
+ def __call__(self, data):
592
+ brightness_lower = 1 - self.brightness_range
593
+ brightness_upper = 1 + self.brightness_range
594
+ contrast_lower = 1 - self.contrast_range
595
+ contrast_upper = 1 + self.contrast_range
596
+ saturation_lower = 1 - self.saturation_range
597
+ saturation_upper = 1 + self.saturation_range
598
+ hue_lower = -self.hue_range
599
+ hue_upper = self.hue_range
600
+ ops = [
601
+ functional.brightness, functional.contrast, functional.saturation,
602
+ functional.hue
603
+ ]
604
+ random.shuffle(ops)
605
+ params_dict = {
606
+ 'brightness': {
607
+ 'brightness_lower': brightness_lower,
608
+ 'brightness_upper': brightness_upper
609
+ },
610
+ 'contrast': {
611
+ 'contrast_lower': contrast_lower,
612
+ 'contrast_upper': contrast_upper
613
+ },
614
+ 'saturation': {
615
+ 'saturation_lower': saturation_lower,
616
+ 'saturation_upper': saturation_upper
617
+ },
618
+ 'hue': {
619
+ 'hue_lower': hue_lower,
620
+ 'hue_upper': hue_upper
621
+ }
622
+ }
623
+ prob_dict = {
624
+ 'brightness': self.brightness_prob,
625
+ 'contrast': self.contrast_prob,
626
+ 'saturation': self.saturation_prob,
627
+ 'hue': self.hue_prob
628
+ }
629
+
630
+ im = data['img'].astype('uint8')
631
+ im = Image.fromarray(im)
632
+ for op in ops:
633
+ params = params_dict[op.__name__]
634
+ params['im'] = im
635
+ prob = prob_dict[op.__name__]
636
+ if np.random.uniform(0, 1) < prob:
637
+ im = op(**params)
638
+ data['img'] = np.asarray(im)
639
+
640
+ for key in data.get('gt_fields', []):
641
+ if key in ['alpha', 'trimap']:
642
+ continue
643
+ else:
644
+ im = data[key].astype('uint8')
645
+ im = Image.fromarray(im)
646
+ for op in ops:
647
+ params = params_dict[op.__name__]
648
+ params['im'] = im
649
+ prob = prob_dict[op.__name__]
650
+ if np.random.uniform(0, 1) < prob:
651
+ im = op(**params)
652
+ data[key] = np.asarray(im)
653
+ return data
654
+
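# RandomDistort round-trips through PIL because the functional.brightness/
# contrast/saturation/hue helpers are assumed (PaddleSeg convention) to take
# PIL Images; 'alpha' and 'trimap' are skipped so photometric jitter never
# corrupts label values. Hypothetical shape of one such op:
from PIL import ImageEnhance
import numpy as np

def brightness_sketch(im, brightness_lower, brightness_upper):
    factor = np.random.uniform(brightness_lower, brightness_upper)
    return ImageEnhance.Brightness(im).enhance(factor)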
655
+
656
+ @manager.TRANSFORMS.add_component
657
+ class Padding:
658
+ """
659
+ Add bottom-right padding to a raw image or annotation image.
660
+
661
+ Args:
662
+ target_size (list|tuple): The target size after padding.
663
+ im_padding_value (list|tuple, optional): The padding value of the raw image.
664
+ Default: (127.5, 127.5, 127.5). Ground-truth 'trimap' and 'alpha'
665
+ fields are always padded with 0.
666
+
667
+ Raises:
668
+ TypeError: When target_size is neither list nor tuple.
669
+ ValueError: When the length of target_size is not 2.
670
+ """
671
+
672
+ def __init__(self, target_size, im_padding_value=(127.5, 127.5, 127.5)):
673
+ if isinstance(target_size, list) or isinstance(target_size, tuple):
674
+ if len(target_size) != 2:
675
+ raise ValueError(
676
+ '`target_size` should include 2 elements, but it is {}'.
677
+ format(target_size))
678
+ else:
679
+ raise TypeError(
680
+ "Type of target_size is invalid. It should be list or tuple, now is {}"
681
+ .format(type(target_size)))
682
+
683
+ self.target_size = target_size
684
+ self.im_padding_value = im_padding_value
685
+
686
+ def __call__(self, data):
687
+ im_height, im_width = data['img'].shape[0], data['img'].shape[1]
688
+ target_height = self.target_size[1]
689
+ target_width = self.target_size[0]
690
+ pad_height = max(0, target_height - im_height)
691
+ pad_width = max(0, target_width - im_width)
692
+ data['trans_info'].append(('padding', data['img'].shape[0:2]))
693
+ if (pad_height == 0) and (pad_width == 0):
694
+ return data
695
+ else:
696
+ data['img'] = cv2.copyMakeBorder(
697
+ data['img'],
698
+ 0,
699
+ pad_height,
700
+ 0,
701
+ pad_width,
702
+ cv2.BORDER_CONSTANT,
703
+ value=self.im_padding_value)
704
+ for key in data.get('gt_fields', []):
705
+ if key in ['trimap', 'alpha']:
706
+ value = 0
707
+ else:
708
+ value = self.im_padding_value
709
+ data[key] = cv2.copyMakeBorder(
710
+ data[key],
711
+ 0,
712
+ pad_height,
713
+ 0,
714
+ pad_width,
715
+ cv2.BORDER_CONSTANT,
716
+ value=value)
717
+ return data
718
+
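# Worked example of the bottom-right padding: a 500x400 (h x w) image with
# target_size=(512, 512) gets pad_height=12 and pad_width=112, all added at
# the bottom/right so original pixel coordinates survive (which is what lets
# 'trans_info' undo the padding later).
import cv2
import numpy as np

img = np.zeros((500, 400, 3), np.uint8)
padded = cv2.copyMakeBorder(img, 0, 12, 0, 112,
                            cv2.BORDER_CONSTANT, value=(127.5, 127.5, 127.5))
assert padded.shape[:2] == (512, 512)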
719
+
720
+ @manager.TRANSFORMS.add_component
721
+ class RandomSharpen:
722
+ def __init__(self, prob=0.1):
723
+ if prob < 0:
724
+ self.prob = 0
725
+ elif prob > 1:
726
+ self.prob = 1
727
+ else:
728
+ self.prob = prob
729
+
730
+ def __call__(self, data):
731
+ if np.random.rand() > self.prob:
732
+ return data
733
+
734
+ radius = np.random.choice([0, 3, 5, 7, 9])
735
+ w = np.random.uniform(0.1, 0.5)
736
+ blur_img = cv2.GaussianBlur(data['img'], (radius, radius), 5)
737
+ data['img'] = cv2.addWeighted(data['img'], 1 + w, blur_img, -w, 0)
738
+ for key in data.get('gt_fields', []):
739
+ if key == 'trimap' or key == 'alpha':
740
+ continue
741
+ blur_img = cv2.GaussianBlur(data[key], (0, 0), 5)
742
+ data[key] = cv2.addWeighted(data[key], 1.5, blur_img, -0.5, 0)
743
+
744
+ return data
745
+
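# RandomSharpen is an unsharp mask: out = (1 + w) * img - w * blur(img),
# i.e. the original plus w times the detail layer (img - blur). With w in
# [0.1, 0.5] the effect stays mild; gt fields reuse a fixed weight of 0.5.
import cv2
import numpy as np

img = np.random.randint(0, 256, (64, 64, 3), np.uint8)
w = 0.3
blur = cv2.GaussianBlur(img, (5, 5), 5)
sharp = cv2.addWeighted(img, 1 + w, blur, -w, 0)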
746
+
747
+ @manager.TRANSFORMS.add_component
748
+ class RandomNoise:
749
+ def __init__(self, prob=0.1):
750
+ if prob < 0:
751
+ self.prob = 0
752
+ elif prob > 1:
753
+ self.prob = 1
754
+ else:
755
+ self.prob = prob
756
+
757
+ def __call__(self, data):
758
+ if np.random.rand() > self.prob:
759
+ return data
760
+ mean = np.random.uniform(0, 0.04)
761
+ var = np.random.uniform(0, 0.001)
762
+ noise = np.random.normal(mean, var**0.5, data['img'].shape) * 255
763
+ data['img'] = data['img'] + noise
764
+ data['img'] = np.clip(data['img'], 0, 255)
765
+
766
+ return data
767
+
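# RandomNoise adds per-pixel Gaussian noise sampled in [0, 1] units,
# scaled to the [0, 255] image range, then clipped. Equivalent standalone:
import numpy as np

img = np.random.randint(0, 256, (64, 64, 3)).astype(np.float32)
mean, var = 0.02, 0.0005
noisy = np.clip(img + np.random.normal(mean, var ** 0.5, img.shape) * 255, 0, 255)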
768
+
769
+ @manager.TRANSFORMS.add_component
770
+ class RandomReJpeg:
771
+ def __init__(self, prob=0.1):
772
+ if prob < 0:
773
+ self.prob = 0
774
+ elif prob > 1:
775
+ self.prob = 1
776
+ else:
777
+ self.prob = prob
778
+
779
+ def __call__(self, data):
780
+ if np.random.rand() > self.prob:
781
+ return data
782
+ q = np.random.randint(70, 95)
783
+ img = data['img'].astype('uint8')
784
+
785
+ # Ensure no conflicts between processes
786
+ tmp_name = str(os.getpid()) + '.jpg'
787
+ tmp_name = os.path.join(seg_env.TMP_HOME, tmp_name)
788
+ cv2.imwrite(tmp_name, img, [int(cv2.IMWRITE_JPEG_QUALITY), q])
789
+ data['img'] = cv2.imread(tmp_name)
790
+
791
+ return data
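# An in-memory alternative sketch (not the code above): cv2.imencode /
# cv2.imdecode give the same JPEG round-trip without a temp file, and
# without depending on seg_env.TMP_HOME.
import cv2
import numpy as np

img = np.random.randint(0, 256, (64, 64, 3), np.uint8)
q = 80
ok, buf = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), q])
recompressed = cv2.imdecode(buf, cv2.IMREAD_COLOR)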