vivym committed on
Commit
4a582ec
1 Parent(s): d0444de
Files changed (45)
  1. .gitignore +195 -0
  2. app.py +176 -0
  3. configs/modnet-hrnet_w18.yml +5 -0
  4. configs/modnet-mobilenetv2.yml +47 -0
  5. configs/modnet-resnet50_vd.yml +5 -0
  6. configs/ppmatting-1024.yml +29 -0
  7. configs/ppmatting-2048.yml +54 -0
  8. configs/ppmatting-512.yml +44 -0
  9. configs/ppmatting-hrnet_w48-composition.yml +7 -0
  10. configs/ppmatting-hrnet_w48-distinctions.yml +55 -0
  11. ppmatting/__init__.py +1 -0
  12. ppmatting/core/__init__.py +4 -0
  13. ppmatting/core/predict.py +58 -0
  14. ppmatting/core/train.py +315 -0
  15. ppmatting/core/val.py +162 -0
  16. ppmatting/core/val_ml.py +162 -0
  17. ppmatting/datasets/__init__.py +17 -0
  18. ppmatting/datasets/composition_1k.py +31 -0
  19. ppmatting/datasets/distinctions_646.py +31 -0
  20. ppmatting/datasets/matting_dataset.py +251 -0
  21. ppmatting/metrics/__init__.py +3 -0
  22. ppmatting/metrics/metric.py +278 -0
  23. ppmatting/ml/__init__.py +1 -0
  24. ppmatting/ml/methods.py +97 -0
  25. ppmatting/models/__init__.py +7 -0
  26. ppmatting/models/backbone/__init__.py +5 -0
  27. ppmatting/models/backbone/gca_enc.py +395 -0
  28. ppmatting/models/backbone/hrnet.py +835 -0
  29. ppmatting/models/backbone/mobilenet_v2.py +242 -0
  30. ppmatting/models/backbone/resnet_vd.py +368 -0
  31. ppmatting/models/backbone/vgg.py +166 -0
  32. ppmatting/models/dim.py +208 -0
  33. ppmatting/models/gca.py +305 -0
  34. ppmatting/models/human_matting.py +454 -0
  35. ppmatting/models/layers/__init__.py +15 -0
  36. ppmatting/models/layers/gca_module.py +211 -0
  37. ppmatting/models/losses/__init__.py +1 -0
  38. ppmatting/models/losses/loss.py +163 -0
  39. ppmatting/models/modnet.py +494 -0
  40. ppmatting/models/ppmatting.py +338 -0
  41. ppmatting/transforms/__init__.py +1 -0
  42. ppmatting/transforms/transforms.py +791 -0
  43. ppmatting/utils/__init__.py +2 -0
  44. ppmatting/utils/estimate_foreground_ml.py +236 -0
  45. ppmatting/utils/utils.py +71 -0
.gitignore ADDED
@@ -0,0 +1,195 @@
.DS_Store

/models
/images

# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python
# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml


### VisualStudioCode ###
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets

# Local History for Visual Studio Code
.history/

# Built Visual Studio Code Extensions
*.vsix

### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide

# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python
app.py ADDED
@@ -0,0 +1,176 @@
from hashlib import sha1
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
from PIL import Image

from paddleseg.cvlibs import manager, Config
from paddleseg.utils import load_entire_model

manager.BACKBONES._components_dict.clear()
manager.TRANSFORMS._components_dict.clear()

import ppmatting as ppmatting
from ppmatting.core import predict
from ppmatting.utils import estimate_foreground_ml

model_names = [
    "modnet-mobilenetv2",
    "ppmatting-512",
    "ppmatting-1024",
    "ppmatting-2048",
    "modnet-hrnet_w18",
    "modnet-resnet50_vd",
]
model_dict = {
    name: None
    for name in model_names
}

last_result = {
    "cache_key": None,
    "algorithm": None,
}


def image_matting(
    image: np.ndarray,
    result_type: str,
    bg_color: str,
    algorithm: str,
    morph_op: str,
    morph_op_factor: float,
) -> np.ndarray:
    image = np.ascontiguousarray(image)
    cache_key = sha1(image).hexdigest()
    if cache_key == last_result["cache_key"] and algorithm == last_result["algorithm"]:
        alpha = last_result["alpha"]
    else:
        cfg = Config(f"configs/{algorithm}.yml")
        if model_dict[algorithm] is not None:
            model = model_dict[algorithm]
        else:
            model = cfg.model
            load_entire_model(model, f"models/{algorithm}.pdparams")
            model.eval()
            model_dict[algorithm] = model

        transforms = ppmatting.transforms.Compose(cfg.val_transforms)

        alpha = predict(
            model,
            transforms=transforms,
            image=image,
        )
        last_result["cache_key"] = cache_key
        last_result["algorithm"] = algorithm
        last_result["alpha"] = alpha

    alpha = (alpha * 255).astype(np.uint8)
    kernel = np.ones((5, 5), np.uint8)
    if morph_op.lower() == "dilate":  # the Radio sends "Dilate"/"Erode", so compare case-insensitively
        alpha = cv2.dilate(alpha, kernel, iterations=int(morph_op_factor))
    else:
        alpha = cv2.erode(alpha, kernel, iterations=int(morph_op_factor))
    alpha = (alpha / 255).astype(np.float32)

    image = (image / 255.0).astype("float32")
    fg = estimate_foreground_ml(image, alpha)

    if result_type == "Remove BG":
        result = np.concatenate((fg, alpha[:, :, None]), axis=-1)
    elif result_type == "Replace BG":
        bg_r = int(bg_color[1:3], base=16)
        bg_g = int(bg_color[3:5], base=16)
        bg_b = int(bg_color[5:7], base=16)

        bg = np.zeros_like(fg)
        bg[:, :, 0] = bg_r / 255.
        bg[:, :, 1] = bg_g / 255.
        bg[:, :, 2] = bg_b / 255.

        result = alpha[:, :, None] * fg + (1 - alpha[:, :, None]) * bg
        result = np.clip(result, 0, 1)
    else:
        result = alpha

    return result


def main():
    images_path = Path("images")
    if not images_path.exists():
        images_path.mkdir()

    with gr.Blocks() as app:
        gr.Markdown("Image Matting Powered By AI")

        with gr.Row(variant="panel"):
            image_input = gr.Image()
            image_output = gr.Image()

        with gr.Row(variant="panel"):
            result_type = gr.Radio(
                label="Mode",
                show_label=True,
                choices=[
                    "Remove BG",
                    "Replace BG",
                    "Generate Mask",
                ],
                value="Remove BG",
            )
            bg_color = gr.ColorPicker(
                label="BG Color",
                show_label=True,
                value="#000000",
            )
            algorithm = gr.Dropdown(
                label="Algorithm",
                show_label=True,
                choices=model_names,
                value="modnet-hrnet_w18"
            )

        with gr.Row(variant="panel"):
            morph_op = gr.Radio(
                label="Post-process",
                show_label=True,
                choices=[
                    "Dilate",
                    "Erode",
                ],
                value="Dilate",
            )

            morph_op_factor = gr.Slider(
                label="Factor",
                show_label=True,
                minimum=0,
                maximum=20,
                value=0,
                step=1,
            )

        run_button = gr.Button("Run")

        run_button.click(
            image_matting,
            inputs=[
                image_input,
                result_type,
                bg_color,
                algorithm,
                morph_op,
                morph_op_factor,
            ],
            outputs=image_output,
        )

    app.launch(share=True)


if __name__ == "__main__":
    main()
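A minimal sketch of driving image_matting() outside the Gradio UI, assuming a weight file has been placed under models/ and a test image under images/ (both directories are gitignored above and are not part of this commit; the paths are placeholders):

# Sketch only: exercises the handler above without launching the UI.
import cv2
from app import image_matting

img_bgr = cv2.imread("images/input.jpg")            # hypothetical input path
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)  # the Gradio handler receives RGB arrays

mask = image_matting(
    image=img_rgb,
    result_type="Generate Mask",
    bg_color="#000000",
    algorithm="modnet-hrnet_w18",
    morph_op="Dilate",
    morph_op_factor=0,
)
cv2.imwrite("images/mask.png", (mask * 255).astype("uint8"))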
configs/modnet-hrnet_w18.yml ADDED
@@ -0,0 +1,5 @@
_base_: modnet-mobilenetv2.yml
model:
  backbone:
    type: HRNet_W18
    # pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz
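This file only overrides the backbone; the datasets, optimizer and schedule come from modnet-mobilenetv2.yml (below) via _base_. A small sketch of how app.py consumes such a config, assuming PaddleSeg's Config resolves the _base_ chain:

# Sketch only, mirroring app.py.
from paddleseg.cvlibs import Config

import ppmatting  # noqa: F401  (registers MODNet / HRNet_W18 with PaddleSeg's component manager)

cfg = Config("configs/modnet-hrnet_w18.yml")  # _base_ fields are filled in from modnet-mobilenetv2.yml
model = cfg.model                             # MODNet with an HRNet_W18 backbone
val_transforms = cfg.val_transforms           # inherited val pipeline (LoadImages, ResizeByShort, ...)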
configs/modnet-mobilenetv2.yml ADDED
@@ -0,0 +1,47 @@
batch_size: 16
iters: 100000

train_dataset:
  type: MattingDataset
  dataset_root: data/PPM-100
  train_file: train.txt
  transforms:
    - type: LoadImages
    - type: RandomCrop
      crop_size: [512, 512]
    - type: RandomDistort
    - type: RandomBlur
    - type: RandomHorizontalFlip
    - type: Normalize
  mode: train

val_dataset:
  type: MattingDataset
  dataset_root: data/PPM-100
  val_file: val.txt
  transforms:
    - type: LoadImages
    - type: ResizeByShort
      short_size: 512
    - type: ResizeToIntMult
      mult_int: 32
    - type: Normalize
  mode: val
  get_trimap: False

model:
  type: MODNet
  backbone:
    type: MobileNetV2
    # pretrained: https://paddleseg.bj.bcebos.com/matting/models/MobileNetV2_pretrained/model.pdparams
    pretrained: Null

optimizer:
  type: sgd
  momentum: 0.9
  weight_decay: 4.0e-5

lr_scheduler:
  type: PiecewiseDecay
  boundaries: [40000, 80000]
  values: [0.02, 0.002, 0.0002]
configs/modnet-resnet50_vd.yml ADDED
@@ -0,0 +1,5 @@
_base_: modnet-mobilenetv2.yml
model:
  backbone:
    type: ResNet50_vd
    # pretrained: https://bj.bcebos.com/paddleseg/dygraph/resnet50_vd_ssld_v2.tar.gz
configs/ppmatting-1024.yml ADDED
@@ -0,0 +1,29 @@
_base_: 'ppmatting-hrnet_w18-human_512.yml'


train_dataset:
  transforms:
    - type: LoadImages
    - type: LimitShort
      max_short: 1024
    - type: RandomCrop
      crop_size: [1024, 1024]
    - type: RandomDistort
    - type: RandomBlur
      prob: 0.1
    - type: RandomNoise
      prob: 0.5
    - type: RandomReJpeg
      prob: 0.2
    - type: RandomHorizontalFlip
    - type: Normalize

val_dataset:
  transforms:
    - type: LoadImages
    - type: LimitShort
      max_short: 1024
    - type: ResizeToIntMult
      mult_int: 32
    - type: Normalize
configs/ppmatting-2048.yml ADDED
@@ -0,0 +1,54 @@
batch_size: 4
iters: 50000

train_dataset:
  type: MattingDataset
  dataset_root: data/PPM-100
  train_file: train.txt
  transforms:
    - type: LoadImages
    - type: RandomResize
      size: [2048, 2048]
      scale: [0.3, 1.5]
    - type: RandomCrop
      crop_size: [2048, 2048]
    - type: RandomDistort
    - type: RandomBlur
      prob: 0.1
    - type: RandomHorizontalFlip
    - type: Padding
      target_size: [2048, 2048]
    - type: Normalize
  mode: train

val_dataset:
  type: MattingDataset
  dataset_root: data/PPM-100
  val_file: val.txt
  transforms:
    - type: LoadImages
    - type: ResizeByShort
      short_size: 2048
    - type: ResizeToIntMult
      mult_int: 128
    - type: Normalize
  mode: val
  get_trimap: False

model:
  type: HumanMatting
  backbone:
    type: ResNet34_vd
    # pretrained: https://paddleseg.bj.bcebos.com/matting/models/ResNet34_vd_pretrained/model.pdparams
    pretrained: Null
  if_refine: True

optimizer:
  type: sgd
  momentum: 0.9
  weight_decay: 4.0e-5

lr_scheduler:
  type: PiecewiseDecay
  boundaries: [30000, 40000]
  values: [0.001, 0.0001, 0.00001]
configs/ppmatting-512.yml ADDED
@@ -0,0 +1,44 @@
_base_: 'ppmatting-hrnet_w48-distinctions.yml'

batch_size: 4
iters: 200000

train_dataset:
  type: MattingDataset
  dataset_root: data/PPM-100
  train_file: train.txt
  transforms:
    - type: LoadImages
    - type: LimitShort
      max_short: 512
    - type: RandomCrop
      crop_size: [512, 512]
    - type: RandomDistort
    - type: RandomBlur
      prob: 0.1
    - type: RandomNoise
      prob: 0.5
    - type: RandomReJpeg
      prob: 0.2
    - type: RandomHorizontalFlip
    - type: Normalize
  mode: train

val_dataset:
  type: MattingDataset
  dataset_root: data/PPM-100
  val_file: val.txt
  transforms:
    - type: LoadImages
    - type: LimitShort
      max_short: 512
    - type: ResizeToIntMult
      mult_int: 32
    - type: Normalize
  mode: val
  get_trimap: False

model:
  backbone:
    type: HRNet_W18
    # pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz
configs/ppmatting-hrnet_w48-composition.yml ADDED
@@ -0,0 +1,7 @@
_base_: 'ppmatting-hrnet_w48-distinctions.yml'

train_dataset:
  dataset_root: data/matting/Composition-1k

val_dataset:
  dataset_root: data/matting/Composition-1k
configs/ppmatting-hrnet_w48-distinctions.yml ADDED
@@ -0,0 +1,55 @@
batch_size: 4
iters: 300000

train_dataset:
  type: MattingDataset
  dataset_root: data/matting/Distinctions-646
  train_file: train.txt
  transforms:
    - type: LoadImages
    - type: Padding
      target_size: [512, 512]
    - type: RandomCrop
      crop_size: [[512, 512], [640, 640], [800, 800]]
    - type: Resize
      target_size: [512, 512]
    - type: RandomDistort
    - type: RandomBlur
      prob: 0.1
    - type: RandomHorizontalFlip
    - type: Normalize
  mode: train
  separator: '|'

val_dataset:
  type: MattingDataset
  dataset_root: data/matting/Distinctions-646
  val_file: val.txt
  transforms:
    - type: LoadImages
    - type: LimitShort
      max_short: 1536
    - type: ResizeToIntMult
      mult_int: 32
    - type: Normalize
  mode: val
  get_trimap: False
  separator: '|'

model:
  type: PPMatting
  backbone:
    type: HRNet_W48
    # pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w48_ssld.tar.gz
    pretrained: Null

optimizer:
  type: sgd
  momentum: 0.9
  weight_decay: 4.0e-5

lr_scheduler:
  type: PolynomialDecay
  learning_rate: 0.01
  end_lr: 0
  power: 0.9
ppmatting/__init__.py ADDED
@@ -0,0 +1 @@
from . import ml, metrics, transforms, datasets, models
ppmatting/core/__init__.py ADDED
@@ -0,0 +1,4 @@
from .val import evaluate
from .val_ml import evaluate_ml
from .train import train
from .predict import predict
ppmatting/core/predict.py ADDED
@@ -0,0 +1,58 @@
from typing import Optional

import numpy as np
import paddle
import paddle.nn.functional as F


def reverse_transform(alpha, trans_info):
    """Recover the prediction to its original shape."""
    for item in trans_info[::-1]:
        if item[0] == "resize":
            h, w = item[1][0], item[1][1]
            alpha = F.interpolate(alpha, [h, w], mode="bilinear")
        elif item[0] == "padding":
            h, w = item[1][0], item[1][1]
            alpha = alpha[:, :, 0:h, 0:w]
        else:
            raise Exception(f"Unexpected info '{item[0]}' in im_info")

    return alpha


def preprocess(img, transforms, trimap=None):
    data = {}
    data["img"] = img
    if trimap is not None:
        data["trimap"] = trimap
        data["gt_fields"] = ["trimap"]
    data["trans_info"] = []
    data = transforms(data)
    data["img"] = paddle.to_tensor(data["img"])
    data["img"] = data["img"].unsqueeze(0)
    if trimap is not None:
        data["trimap"] = paddle.to_tensor(data["trimap"])
        data["trimap"] = data["trimap"].unsqueeze((0, 1))

    return data


def predict(
    model,
    transforms,
    image: np.ndarray,
    trimap: Optional[np.ndarray] = None,
):
    with paddle.no_grad():
        # pass the caller's trimap through (was hard-coded to None)
        data = preprocess(img=image, transforms=transforms, trimap=trimap)

        alpha = model(data)

        alpha = reverse_transform(alpha, data["trans_info"])
        alpha = alpha.numpy().squeeze()

    if trimap is not None:
        alpha[trimap == 0] = 0
        alpha[trimap == 255] = 1.

    return alpha
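A hedged sketch of calling predict() directly with a trimap, following the same Config / load_entire_model pattern as app.py; the weight and image paths below are placeholders, and it assumes the trimap=trimap fix noted above:

# Sketch only: trimap-guided prediction outside the Gradio app.
import cv2
from paddleseg.cvlibs import Config
from paddleseg.utils import load_entire_model

import ppmatting
from ppmatting.core import predict

cfg = Config("configs/ppmatting-hrnet_w48-distinctions.yml")
model = cfg.model
load_entire_model(model, "models/ppmatting-hrnet_w48-distinctions.pdparams")  # hypothetical weight path
model.eval()

transforms = ppmatting.transforms.Compose(cfg.val_transforms)
image = cv2.cvtColor(cv2.imread("demo/image.png"), cv2.COLOR_BGR2RGB)  # placeholder paths
trimap = cv2.imread("demo/trimap.png", cv2.IMREAD_GRAYSCALE)
alpha = predict(model, transforms=transforms, image=image, trimap=trimap)  # HxW float in [0, 1]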
ppmatting/core/train.py ADDED
@@ -0,0 +1,315 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
from collections import deque, defaultdict
import pickle
import shutil

import numpy as np
import paddle
import paddle.nn.functional as F
from paddleseg.utils import TimeAverager, calculate_eta, resume, logger

from .val import evaluate


def visual_in_traning(log_writer, vis_dict, step):
    """
    Visualize tensors in VisualDL during training.

    Args:
        log_writer (LogWriter): The log writer of vdl.
        vis_dict (dict): Dict of tensors. The shape of each tensor is (C, H, W).
    """
    for key, value in vis_dict.items():
        value_shape = value.shape
        if value_shape[0] not in [1, 3]:
            value = value[0]
            value = value.unsqueeze(0)
        value = paddle.transpose(value, (1, 2, 0))
        min_v = paddle.min(value)
        max_v = paddle.max(value)
        if (min_v > 0) and (max_v < 1):
            value = value * 255
        elif (min_v < 0 and min_v >= -1) and (max_v <= 1):
            value = (1 + value) / 2 * 255
        else:
            value = (value - min_v) / (max_v - min_v) * 255

        value = value.astype('uint8')
        value = value.numpy()
        log_writer.add_image(tag=key, img=value, step=step)


def save_best(best_model_dir, metrics_data, iter):
    with open(os.path.join(best_model_dir, 'best_metrics.txt'), 'w') as f:
        for key, value in metrics_data.items():
            line = key + ' ' + str(value) + '\n'
            f.write(line)
        f.write('iter' + ' ' + str(iter) + '\n')


def get_best(best_file, metrics, resume_model=None):
    '''Get best metrics and iter from file'''
    best_metrics_data = {}
    if os.path.exists(best_file) and (resume_model is not None):
        values = []
        with open(best_file, 'r') as f:
            lines = f.readlines()
            for line in lines:
                line = line.strip()
                key, value = line.split(' ')
                best_metrics_data[key] = eval(value)
                if key == 'iter':
                    best_iter = eval(value)
    else:
        for key in metrics:
            best_metrics_data[key] = np.inf
        best_iter = -1
    return best_metrics_data, best_iter


def train(model,
          train_dataset,
          val_dataset=None,
          optimizer=None,
          save_dir='output',
          iters=10000,
          batch_size=2,
          resume_model=None,
          save_interval=1000,
          log_iters=10,
          log_image_iters=1000,
          num_workers=0,
          use_vdl=False,
          losses=None,
          keep_checkpoint_max=5,
          eval_begin_iters=None,
          metrics='sad'):
    """
    Launch training.

    Args:
        model (nn.Layer): A matting model.
        train_dataset (paddle.io.Dataset): Used to read and process training datasets.
        val_dataset (paddle.io.Dataset, optional): Used to read and process validation datasets.
        optimizer (paddle.optimizer.Optimizer): The optimizer.
        save_dir (str, optional): The directory for saving the model snapshot. Default: 'output'.
        iters (int, optional): How many iters to train the model. Default: 10000.
        batch_size (int, optional): Mini batch size of one gpu or cpu. Default: 2.
        resume_model (str, optional): The path of the model to resume from.
        save_interval (int, optional): How many iters to save a model snapshot once during training. Default: 1000.
        log_iters (int, optional): Display logging information every log_iters. Default: 10.
        log_image_iters (int, optional): Log images to vdl every log_image_iters. Default: 1000.
        num_workers (int, optional): Num workers for data loader. Default: 0.
        use_vdl (bool, optional): Whether to record the data to VisualDL during training. Default: False.
        losses (dict, optional): A dict of losses; refer to the loss function of the model for details. Default: None.
        keep_checkpoint_max (int, optional): Maximum number of checkpoints to save. Default: 5.
        eval_begin_iters (int, optional): The iter at which evaluation begins. If None, evaluation starts at iters // 2. Default: None.
        metrics (str|list, optional): The metrics to evaluate; may be any combination of ("sad", "mse", "grad", "conn").
    """
    model.train()
    nranks = paddle.distributed.ParallelEnv().nranks
    local_rank = paddle.distributed.ParallelEnv().local_rank

    start_iter = 0
    if resume_model is not None:
        start_iter = resume(model, optimizer, resume_model)

    if not os.path.isdir(save_dir):
        if os.path.exists(save_dir):
            os.remove(save_dir)
        os.makedirs(save_dir)

    if nranks > 1:
        # Initialize parallel environment if not done.
        if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
        ):
            paddle.distributed.init_parallel_env()
            ddp_model = paddle.DataParallel(model)
        else:
            ddp_model = paddle.DataParallel(model)

    batch_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    loader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        return_list=True, )

    if use_vdl:
        from visualdl import LogWriter
        log_writer = LogWriter(save_dir)

    if isinstance(metrics, str):
        metrics = [metrics]
    elif not isinstance(metrics, list):
        metrics = ['sad']
    best_metrics_data, best_iter = get_best(
        os.path.join(save_dir, 'best_model', 'best_metrics.txt'),
        metrics,
        resume_model=resume_model)
    avg_loss = defaultdict(float)
    iters_per_epoch = len(batch_sampler)
    reader_cost_averager = TimeAverager()
    batch_cost_averager = TimeAverager()
    save_models = deque()
    batch_start = time.time()

    iter = start_iter
    while iter < iters:
        for data in loader:
            iter += 1
            if iter > iters:
                break
            reader_cost_averager.record(time.time() - batch_start)

            logit_dict, loss_dict = ddp_model(data) if nranks > 1 else model(
                data)

            loss_dict['all'].backward()

            optimizer.step()
            lr = optimizer.get_lr()
            if isinstance(optimizer._learning_rate,
                          paddle.optimizer.lr.LRScheduler):
                optimizer._learning_rate.step()
            model.clear_gradients()

            for key, value in loss_dict.items():
                avg_loss[key] += value.numpy()[0]
            batch_cost_averager.record(
                time.time() - batch_start, num_samples=batch_size)

            if (iter) % log_iters == 0 and local_rank == 0:
                for key, value in avg_loss.items():
                    avg_loss[key] = value / log_iters
                remain_iters = iters - iter
                avg_train_batch_cost = batch_cost_averager.get_average()
                avg_train_reader_cost = reader_cost_averager.get_average()
                eta = calculate_eta(remain_iters, avg_train_batch_cost)
                # loss info
                loss_str = ' ' * 26 + '\t[LOSSES]'
                loss_str = loss_str
                for key, value in avg_loss.items():
                    if key != 'all':
                        loss_str = loss_str + ' ' + key + '={:.4f}'.format(
                            value)
                logger.info(
                    "[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.5f}, ips={:.4f} samples/sec | ETA {}\n{}\n"
                    .format((iter - 1) // iters_per_epoch + 1, iter, iters,
                            avg_loss['all'], lr, avg_train_batch_cost,
                            avg_train_reader_cost,
                            batch_cost_averager.get_ips_average(
                            ), eta, loss_str))
                if use_vdl:
                    for key, value in avg_loss.items():
                        log_tag = 'Train/' + key
                        log_writer.add_scalar(log_tag, value, iter)

                    log_writer.add_scalar('Train/lr', lr, iter)
                    log_writer.add_scalar('Train/batch_cost',
                                          avg_train_batch_cost, iter)
                    log_writer.add_scalar('Train/reader_cost',
                                          avg_train_reader_cost, iter)
                    if iter % log_image_iters == 0:
                        vis_dict = {}
                        # ground truth
                        vis_dict['ground truth/img'] = data['img'][0]
                        for key in data['gt_fields']:
                            key = key[0]
                            vis_dict['/'.join(['ground truth', key])] = data[
                                key][0]
                        # predict
                        for key, value in logit_dict.items():
                            vis_dict['/'.join(['predict', key])] = logit_dict[
                                key][0]
                        visual_in_traning(
                            log_writer=log_writer, vis_dict=vis_dict, step=iter)

                for key in avg_loss.keys():
                    avg_loss[key] = 0.
                reader_cost_averager.reset()
                batch_cost_averager.reset()

            # save model
            if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
                current_save_dir = os.path.join(save_dir,
                                                "iter_{}".format(iter))
                if not os.path.isdir(current_save_dir):
                    os.makedirs(current_save_dir)
                paddle.save(model.state_dict(),
                            os.path.join(current_save_dir, 'model.pdparams'))
                paddle.save(optimizer.state_dict(),
                            os.path.join(current_save_dir, 'model.pdopt'))
                save_models.append(current_save_dir)
                if len(save_models) > keep_checkpoint_max > 0:
                    model_to_remove = save_models.popleft()
                    shutil.rmtree(model_to_remove)

            # eval model
            if eval_begin_iters is None:
                eval_begin_iters = iters // 2
            if (iter % save_interval == 0 or iter == iters) and (
                    val_dataset is not None
            ) and local_rank == 0 and iter >= eval_begin_iters:
                num_workers = 1 if num_workers > 0 else 0
                metrics_data = evaluate(
                    model,
                    val_dataset,
                    num_workers=1,
                    print_detail=True,
                    save_results=False,
                    metrics=metrics)
                model.train()

            # save best model and add evaluation results to vdl
            if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
                if val_dataset is not None and iter >= eval_begin_iters:
                    if metrics_data[metrics[0]] < best_metrics_data[metrics[0]]:
                        best_iter = iter
                        best_metrics_data = metrics_data.copy()
                        best_model_dir = os.path.join(save_dir, "best_model")
                        paddle.save(
                            model.state_dict(),
                            os.path.join(best_model_dir, 'model.pdparams'))
                        save_best(best_model_dir, best_metrics_data, iter)

                    show_list = []
                    for key, value in best_metrics_data.items():
                        show_list.append((key, value))
                    log_str = '[EVAL] The model with the best validation {} ({:.4f}) was saved at iter {}.'.format(
                        show_list[0][0], show_list[0][1], best_iter)
                    if len(show_list) > 1:
                        log_str += " While"
                        for i in range(1, len(show_list)):
                            log_str = log_str + ' {}: {:.4f},'.format(
                                show_list[i][0], show_list[i][1])
                        log_str = log_str[:-1]
                    logger.info(log_str)

                    if use_vdl:
                        for key, value in metrics_data.items():
                            log_writer.add_scalar('Evaluate/' + key, value,
                                                  iter)

            batch_start = time.time()

    # Sleep for half a second to let dataloader release resources.
    time.sleep(0.5)
    if use_vdl:
        log_writer.close()
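A hedged sketch of launching this train() loop from one of the configs above; it assumes PaddleSeg's Config exposes the model, datasets and optimizer the same way they are used elsewhere in this commit, and that the dataset paths in the YAML exist locally:

# Sketch only: end-to-end training driven by a config file.
from paddleseg.cvlibs import Config

import ppmatting  # noqa: F401  (registers matting models, datasets and transforms)
from ppmatting.core import train

cfg = Config("configs/modnet-mobilenetv2.yml")
train(
    model=cfg.model,
    train_dataset=cfg.train_dataset,   # assumed Config properties, as in PaddleSeg's Matting tools
    val_dataset=cfg.val_dataset,
    optimizer=cfg.optimizer,
    iters=cfg.iters,
    batch_size=cfg.batch_size,
    save_dir="output/modnet-mobilenetv2",
    use_vdl=False,
    metrics="sad",
)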
ppmatting/core/val.py ADDED
@@ -0,0 +1,162 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import cv2
import numpy as np
import time
import paddle
import paddle.nn.functional as F
from paddleseg.utils import TimeAverager, calculate_eta, logger, progbar

from ppmatting.metrics import metrics_class_dict

np.set_printoptions(suppress=True)


def save_alpha_pred(alpha, path):
    """
    The value of alpha is range [0, 1], shape should be [h,w]
    """
    dirname = os.path.dirname(path)
    if not os.path.exists(dirname):
        os.makedirs(dirname)

    alpha = (alpha).astype('uint8')
    cv2.imwrite(path, alpha)


def reverse_transform(alpha, trans_info):
    """recover pred to origin shape"""
    for item in trans_info[::-1]:
        if item[0][0] == 'resize':
            h, w = item[1][0], item[1][1]
            alpha = F.interpolate(alpha, [h, w], mode='bilinear')
        elif item[0][0] == 'padding':
            h, w = item[1][0], item[1][1]
            alpha = alpha[:, :, 0:h, 0:w]
        else:
            raise Exception("Unexpected info '{}' in im_info".format(item[0]))
    return alpha


def evaluate(model,
             eval_dataset,
             num_workers=0,
             print_detail=True,
             save_dir='output/results',
             save_results=True,
             metrics='sad'):
    model.eval()
    nranks = paddle.distributed.ParallelEnv().nranks
    local_rank = paddle.distributed.ParallelEnv().local_rank
    if nranks > 1:
        # Initialize parallel environment if not done.
        if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
        ):
            paddle.distributed.init_parallel_env()

    loader = paddle.io.DataLoader(
        eval_dataset,
        batch_size=1,
        drop_last=False,
        num_workers=num_workers,
        return_list=True, )

    total_iters = len(loader)
    # Get metric instances and data saving
    metrics_ins = {}
    metrics_data = {}
    if isinstance(metrics, str):
        metrics = [metrics]
    elif not isinstance(metrics, list):
        metrics = ['sad']
    for key in metrics:
        key = key.lower()
        metrics_ins[key] = metrics_class_dict[key]()
        metrics_data[key] = None

    if print_detail:
        logger.info("Start evaluating (total_samples: {}, total_iters: {})...".
                    format(len(eval_dataset), total_iters))
    progbar_val = progbar.Progbar(
        target=total_iters, verbose=1 if nranks < 2 else 2)
    reader_cost_averager = TimeAverager()
    batch_cost_averager = TimeAverager()
    batch_start = time.time()

    img_name = ''
    i = 0
    with paddle.no_grad():
        for iter, data in enumerate(loader):
            reader_cost_averager.record(time.time() - batch_start)
            alpha_pred = model(data)

            alpha_pred = reverse_transform(alpha_pred, data['trans_info'])
            alpha_pred = alpha_pred.numpy()

            alpha_gt = data['alpha'].numpy() * 255
            trimap = data.get('ori_trimap')
            if trimap is not None:
                trimap = trimap.numpy().astype('uint8')
            alpha_pred = np.round(alpha_pred * 255)
            for key in metrics_ins.keys():
                metrics_data[key] = metrics_ins[key].update(alpha_pred,
                                                            alpha_gt, trimap)

            if save_results:
                alpha_pred_one = alpha_pred[0].squeeze()
                if trimap is not None:
                    trimap = trimap.squeeze().astype('uint8')
                    alpha_pred_one[trimap == 255] = 255
                    alpha_pred_one[trimap == 0] = 0

                save_name = data['img_name'][0]
                name, ext = os.path.splitext(save_name)
                if save_name == img_name:
                    save_name = name + '_' + str(i) + ext
                    i += 1
                else:
                    img_name = save_name
                    save_name = name + '_' + str(i) + ext
                    i = 1

                save_alpha_pred(alpha_pred_one,
                                os.path.join(save_dir, save_name))

            batch_cost_averager.record(
                time.time() - batch_start, num_samples=len(alpha_gt))
            batch_cost = batch_cost_averager.get_average()
            reader_cost = reader_cost_averager.get_average()

            if local_rank == 0 and print_detail:
                show_list = [(k, v) for k, v in metrics_data.items()]
                show_list = show_list + [('batch_cost', batch_cost),
                                         ('reader cost', reader_cost)]
                progbar_val.update(iter + 1, show_list)

            reader_cost_averager.reset()
            batch_cost_averager.reset()
            batch_start = time.time()

    for key in metrics_ins.keys():
        metrics_data[key] = metrics_ins[key].evaluate()
    log_str = '[EVAL] '
    for key, value in metrics_data.items():
        log_str = log_str + key + ': {:.4f}, '.format(value)
    log_str = log_str[:-2]

    logger.info(log_str)
    return metrics_data
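A hedged sketch of running this evaluate() loop standalone, assuming the dataset layout from configs/ppmatting-512.yml exists under data/PPM-100 and a weight file has been downloaded locally (placeholder path):

# Sketch only: offline evaluation against the configured val split.
from paddleseg.cvlibs import Config
from paddleseg.utils import load_entire_model

import ppmatting  # noqa: F401
from ppmatting.core import evaluate

cfg = Config("configs/ppmatting-512.yml")
model = cfg.model
load_entire_model(model, "models/ppmatting-512.pdparams")  # placeholder weight path

metrics_data = evaluate(
    model,
    cfg.val_dataset,  # assumed Config property; built from the val_dataset block in the YAML
    save_results=False,
    metrics=["sad", "mse", "grad", "conn"],
)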
ppmatting/core/val_ml.py ADDED
@@ -0,0 +1,162 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import cv2
import numpy as np
import time
import paddle
import paddle.nn.functional as F
from paddleseg.utils import TimeAverager, calculate_eta, logger, progbar

from ppmatting.metrics import metric
from pymatting.util.util import load_image, save_image, stack_images
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml

np.set_printoptions(suppress=True)


def save_alpha_pred(alpha, path):
    """
    The value of alpha is range [0, 1], shape should be [h,w]
    """
    dirname = os.path.dirname(path)
    if not os.path.exists(dirname):
        os.makedirs(dirname)

    alpha = (alpha).astype('uint8')
    cv2.imwrite(path, alpha)


def reverse_transform(alpha, trans_info):
    """recover pred to origin shape"""
    for item in trans_info[::-1]:
        if item[0][0] == 'resize':
            h, w = item[1][0].numpy()[0], item[1][1].numpy()[0]
            alpha = cv2.resize(alpha, dsize=(w, h))
        elif item[0][0] == 'padding':
            h, w = item[1][0].numpy()[0], item[1][1].numpy()[0]
            alpha = alpha[0:h, 0:w]
        else:
            raise Exception("Unexpected info '{}' in im_info".format(item[0]))
    return alpha


def evaluate_ml(model,
                eval_dataset,
                num_workers=0,
                print_detail=True,
                save_dir='output/results',
                save_results=True):

    loader = paddle.io.DataLoader(
        eval_dataset,
        batch_size=1,
        drop_last=False,
        num_workers=num_workers,
        return_list=True, )

    total_iters = len(loader)
    mse_metric = metric.MSE()
    sad_metric = metric.SAD()
    grad_metric = metric.Grad()
    conn_metric = metric.Conn()

    if print_detail:
        logger.info("Start evaluating (total_samples: {}, total_iters: {})...".
                    format(len(eval_dataset), total_iters))
    progbar_val = progbar.Progbar(target=total_iters, verbose=1)
    reader_cost_averager = TimeAverager()
    batch_cost_averager = TimeAverager()
    batch_start = time.time()

    img_name = ''
    i = 0
    ignore_cnt = 0
    for iter, data in enumerate(loader):

        reader_cost_averager.record(time.time() - batch_start)

        image_rgb_chw = data['img'].numpy()[0]
        image_rgb_hwc = np.transpose(image_rgb_chw, (1, 2, 0))
        trimap = data['trimap'].numpy().squeeze() / 255.0
        image = image_rgb_hwc * 0.5 + 0.5  # reverse normalize (x/255 - mean) / std

        is_fg = trimap >= 0.9
        is_bg = trimap <= 0.1

        if is_fg.sum() == 0 or is_bg.sum() == 0:
            ignore_cnt += 1
            logger.info(str(iter))
            continue

        alpha_pred = model(image, trimap)

        alpha_pred = reverse_transform(alpha_pred, data['trans_info'])

        alpha_gt = data['alpha'].numpy().squeeze() * 255

        trimap = data['ori_trimap'].numpy().squeeze()

        alpha_pred = np.round(alpha_pred * 255)
        mse = mse_metric.update(alpha_pred, alpha_gt, trimap)
        sad = sad_metric.update(alpha_pred, alpha_gt, trimap)
        grad = grad_metric.update(alpha_pred, alpha_gt, trimap)
        conn = conn_metric.update(alpha_pred, alpha_gt, trimap)

        if sad > 1000:
            print(data['img_name'][0])

        if save_results:
            alpha_pred_one = alpha_pred
            alpha_pred_one[trimap == 255] = 255
            alpha_pred_one[trimap == 0] = 0

            save_name = data['img_name'][0]
            name, ext = os.path.splitext(save_name)
            if save_name == img_name:
                save_name = name + '_' + str(i) + ext
                i += 1
            else:
                img_name = save_name
                save_name = name + '_' + str(0) + ext
                i = 1
            save_alpha_pred(alpha_pred_one, os.path.join(save_dir, save_name))

        batch_cost_averager.record(
            time.time() - batch_start, num_samples=len(alpha_gt))
        batch_cost = batch_cost_averager.get_average()
        reader_cost = reader_cost_averager.get_average()

        if print_detail:
            progbar_val.update(iter + 1,
                               [('SAD', sad), ('MSE', mse), ('Grad', grad),
                                ('Conn', conn), ('batch_cost', batch_cost),
                                ('reader cost', reader_cost)])

        reader_cost_averager.reset()
        batch_cost_averager.reset()
        batch_start = time.time()

    mse = mse_metric.evaluate()
    sad = sad_metric.evaluate()
    grad = grad_metric.evaluate()
    conn = conn_metric.evaluate()

    logger.info('[EVAL] SAD: {:.4f}, MSE: {:.4f}, Grad: {:.4f}, Conn: {:.4f}'.
                format(sad, mse, grad, conn))
    logger.info('{}'.format(ignore_cnt))

    return sad, mse, grad, conn
ppmatting/datasets/__init__.py ADDED
@@ -0,0 +1,17 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .matting_dataset import MattingDataset
from .composition_1k import Composition1K
from .distinctions_646 import Distinctions646
ppmatting/datasets/composition_1k.py ADDED
@@ -0,0 +1,31 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import math

import cv2
import numpy as np
import random
import paddle
from paddleseg.cvlibs import manager

import ppmatting.transforms as T
from ppmatting.datasets.matting_dataset import MattingDataset