jgurzoni committed on
Commit
d7713d2
1 Parent(s): a381978

creating gradio app

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitignore +168 -0
  2. README.md +3 -4
  3. app.py +40 -0
  4. big-lama/config.yaml +157 -0
  5. big-lama/models/best.ckpt +3 -0
  6. image_swapper.py +36 -0
  7. inpainter.py +110 -0
  8. maskformer.py +168 -0
  9. models/ade20k/__init__.py +1 -0
  10. models/ade20k/base.py +627 -0
  11. models/ade20k/color150.mat +0 -0
  12. models/ade20k/mobilenet.py +154 -0
  13. models/ade20k/object150_info.csv +151 -0
  14. models/ade20k/resnet.py +181 -0
  15. models/ade20k/segm_lib/nn/__init__.py +2 -0
  16. models/ade20k/segm_lib/nn/modules/__init__.py +12 -0
  17. models/ade20k/segm_lib/nn/modules/batchnorm.py +329 -0
  18. models/ade20k/segm_lib/nn/modules/comm.py +131 -0
  19. models/ade20k/segm_lib/nn/modules/replicate.py +94 -0
  20. models/ade20k/segm_lib/nn/modules/tests/test_numeric_batchnorm.py +56 -0
  21. models/ade20k/segm_lib/nn/modules/tests/test_sync_batchnorm.py +111 -0
  22. models/ade20k/segm_lib/nn/modules/unittest.py +29 -0
  23. models/ade20k/segm_lib/nn/parallel/__init__.py +1 -0
  24. models/ade20k/segm_lib/nn/parallel/data_parallel.py +112 -0
  25. models/ade20k/segm_lib/utils/__init__.py +1 -0
  26. models/ade20k/segm_lib/utils/data/__init__.py +3 -0
  27. models/ade20k/segm_lib/utils/data/dataloader.py +425 -0
  28. models/ade20k/segm_lib/utils/data/dataset.py +118 -0
  29. models/ade20k/segm_lib/utils/data/distributed.py +58 -0
  30. models/ade20k/segm_lib/utils/data/sampler.py +131 -0
  31. models/ade20k/segm_lib/utils/th.py +41 -0
  32. models/ade20k/utils.py +40 -0
  33. requirements.txt +71 -0
  34. saicinpainting/__init__.py +0 -0
  35. saicinpainting/evaluation/__init__.py +33 -0
  36. saicinpainting/evaluation/data.py +168 -0
  37. saicinpainting/evaluation/evaluator.py +220 -0
  38. saicinpainting/evaluation/losses/__init__.py +0 -0
  39. saicinpainting/evaluation/losses/base_loss.py +528 -0
  40. saicinpainting/evaluation/losses/fid/__init__.py +0 -0
  41. saicinpainting/evaluation/losses/fid/fid_score.py +328 -0
  42. saicinpainting/evaluation/losses/fid/inception.py +323 -0
  43. saicinpainting/evaluation/losses/lpips.py +891 -0
  44. saicinpainting/evaluation/losses/ssim.py +74 -0
  45. saicinpainting/evaluation/masks/README.md +27 -0
  46. saicinpainting/evaluation/masks/__init__.py +0 -0
  47. saicinpainting/evaluation/masks/countless/.gitignore +1 -0
  48. saicinpainting/evaluation/masks/countless/README.md +25 -0
  49. saicinpainting/evaluation/masks/countless/__init__.py +0 -0
  50. saicinpainting/evaluation/masks/countless/countless2d.py +529 -0
.gitignore ADDED
@@ -0,0 +1,168 @@
+ # Dataset dirs
+ ocus_images/
+ testimage/
+ results/
+
+ # Vstudio files
+ .vscode/
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
README.md CHANGED
@@ -1,13 +1,12 @@
  ---
  title: Image Background Swapper
- emoji: 🐨
- colorFrom: purple
- colorTo: blue
+ emoji:
+ colorFrom: blue
+ colorTo: indigo
  sdk: gradio
  sdk_version: 3.41.0
  app_file: app.py
  pinned: false
- license: cc
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,40 @@
+ import gradio as gr
+ from maskformer import Mask2FormerSegmenter
+ from inpainter import Inpainter
+
+ import os
+ from tqdm import tqdm
+ from PIL import Image
+
+ def main(image1, image2, refine='False'):
+
+     segmenter = Mask2FormerSegmenter()
+     segmenter.load_models(checkpoint_name = "facebook/mask2former-swin-large-ade-semantic")
+     inpainter = Inpainter({'scale_factor': None, 'pad_out_to_modulo': 8, 'predict': {'out_key': 'inpainted'}})
+     inpainter.load_model_from_checkpoint('big-lama', 'best.ckpt')
+
+     fg_img1, mask_img1 = segmenter.retrieve_fg_image_and_mask(image1, verbose=False)
+     new_bg_img1 = inpainter.inpaint_img(image1, mask_img1, refine=refine)
+     fg_img2, mask_img2 = segmenter.retrieve_fg_image_and_mask(image2, verbose=False)
+     new_bg_img2 = inpainter.inpaint_img(image2, mask_img2, refine=refine)
+
+     image_a = Image.alpha_composite(new_bg_img2.convert('RGBA'), fg_img1)
+     image_b = Image.alpha_composite(new_bg_img1.convert('RGBA'), fg_img2)
+
+     return image_a, image_b
+
+
+ def process_image(image1, image2, refine=False):
+     img1 = Image.fromarray(image1.astype('uint8'), 'RGB')
+     img2 = Image.fromarray(image2.astype('uint8'), 'RGB')
+     return main(img1, img2, refine)
+
+
+ iface = gr.Interface(
+     fn=process_image,
+     inputs=["image", "image", gr.inputs.Checkbox(label="Use Refiner on background")],
+     outputs=["image", "image"],
+     title="Background Swapper App",
+     description="Upload two images to see their backgrounds swapped.",
+ )
+ iface.launch()
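A note on the interface definition above: with the pinned sdk_version 3.41.0, gr.inputs.Checkbox still works but is deprecated in Gradio 3.x in favour of the top-level component classes. An equivalent, warning-free input list would look roughly like the sketch below (illustrative only, not part of this commit):

import gradio as gr

# Equivalent inputs for gr.Interface using the non-deprecated top-level
# components of Gradio 3.x (sketch only; the labels are illustrative).
inputs = [
    gr.Image(label="Image 1"),
    gr.Image(label="Image 2"),
    gr.Checkbox(label="Use Refiner on background"),
]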
big-lama/config.yaml ADDED
@@ -0,0 +1,157 @@
+ run_title: b18_ffc075_batch8x15
+ training_model:
+ kind: default
+ visualize_each_iters: 1000
+ concat_mask: true
+ store_discr_outputs_for_vis: true
+ losses:
+ l1:
+ weight_missing: 0
+ weight_known: 10
+ perceptual:
+ weight: 0
+ adversarial:
+ kind: r1
+ weight: 10
+ gp_coef: 0.001
+ mask_as_fake_target: true
+ allow_scale_mask: true
+ feature_matching:
+ weight: 100
+ resnet_pl:
+ weight: 30
+ weights_path: ${env:TORCH_HOME}
+
+ optimizers:
+ generator:
+ kind: adam
+ lr: 0.001
+ discriminator:
+ kind: adam
+ lr: 0.0001
+ visualizer:
+ key_order:
+ - image
+ - predicted_image
+ - discr_output_fake
+ - discr_output_real
+ - inpainted
+ rescale_keys:
+ - discr_output_fake
+ - discr_output_real
+ kind: directory
+ outdir: /group-volume/User-Driven-Content-Generation/r.suvorov/inpainting/experiments/r.suvorov_2021-04-30_14-41-12_train_simple_pix2pix2_gap_sdpl_novgg_large_b18_ffc075_batch8x15/samples
+ location:
+ data_root_dir: /group-volume/User-Driven-Content-Generation/datasets/inpainting_data_root_large
+ out_root_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/experiments
+ tb_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/tb_logs
+ data:
+ batch_size: 15
+ val_batch_size: 2
+ num_workers: 3
+ train:
+ indir: ${location.data_root_dir}/train
+ out_size: 256
+ mask_gen_kwargs:
+ irregular_proba: 1
+ irregular_kwargs:
+ max_angle: 4
+ max_len: 200
+ max_width: 100
+ max_times: 5
+ min_times: 1
+ box_proba: 1
+ box_kwargs:
+ margin: 10
+ bbox_min_size: 30
+ bbox_max_size: 150
+ max_times: 3
+ min_times: 1
+ segm_proba: 0
+ segm_kwargs:
+ confidence_threshold: 0.5
+ max_object_area: 0.5
+ min_mask_area: 0.07
+ downsample_levels: 6
+ num_variants_per_mask: 1
+ rigidness_mode: 1
+ max_foreground_coverage: 0.3
+ max_foreground_intersection: 0.7
+ max_mask_intersection: 0.1
+ max_hidden_area: 0.1
+ max_scale_change: 0.25
+ horizontal_flip: true
+ max_vertical_shift: 0.2
+ position_shuffle: true
+ transform_variant: distortions
+ dataloader_kwargs:
+ batch_size: ${data.batch_size}
+ shuffle: true
+ num_workers: ${data.num_workers}
+ val:
+ indir: ${location.data_root_dir}/val
+ img_suffix: .png
+ dataloader_kwargs:
+ batch_size: ${data.val_batch_size}
+ shuffle: false
+ num_workers: ${data.num_workers}
+ visual_test:
+ indir: ${location.data_root_dir}/korean_test
+ img_suffix: _input.png
+ pad_out_to_modulo: 32
+ dataloader_kwargs:
+ batch_size: 1
+ shuffle: false
+ num_workers: ${data.num_workers}
+ generator:
+ kind: ffc_resnet
+ input_nc: 4
+ output_nc: 3
+ ngf: 64
+ n_downsampling: 3
+ n_blocks: 18
+ add_out_act: sigmoid
+ init_conv_kwargs:
+ ratio_gin: 0
+ ratio_gout: 0
+ enable_lfu: false
+ downsample_conv_kwargs:
+ ratio_gin: ${generator.init_conv_kwargs.ratio_gout}
+ ratio_gout: ${generator.downsample_conv_kwargs.ratio_gin}
+ enable_lfu: false
+ resnet_conv_kwargs:
+ ratio_gin: 0.75
+ ratio_gout: ${generator.resnet_conv_kwargs.ratio_gin}
+ enable_lfu: false
+ discriminator:
+ kind: pix2pixhd_nlayer
+ input_nc: 3
+ ndf: 64
+ n_layers: 4
+ evaluator:
+ kind: default
+ inpainted_key: inpainted
+ integral_kind: ssim_fid100_f1
+ trainer:
+ kwargs:
+ gpus: -1
+ accelerator: ddp
+ max_epochs: 200
+ gradient_clip_val: 1
+ log_gpu_memory: None
+ limit_train_batches: 25000
+ val_check_interval: ${trainer.kwargs.limit_train_batches}
+ log_every_n_steps: 1000
+ precision: 32
+ terminate_on_nan: false
+ check_val_every_n_epoch: 1
+ num_sanity_val_steps: 8
+ limit_val_batches: 1000
+ replace_sampler_ddp: false
+ checkpoint_kwargs:
+ verbose: true
+ save_top_k: 5
+ save_last: true
+ period: 1
+ monitor: val_ssim_fid100_f1_total_mean
+ mode: max
big-lama/models/best.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fccb7adffd53ec0974ee5503c3731c2c2f1e7e07856fd9228cdcc0b46fd5d423
+ size 410046389
image_swapper.py ADDED
@@ -0,0 +1,36 @@
+ from maskformer import Mask2FormerSegmenter
+ from inpainter import Inpainter
+
+ import os
+ from tqdm import tqdm
+ from PIL import Image
+
+ if __name__ == "__main__":
+     dirpath = "ocus_images/final_images"
+     segmenter = Mask2FormerSegmenter()
+     segmenter.load_models(checkpoint_name = "facebook/mask2former-swin-large-ade-semantic")
+     inpainter = Inpainter({'scale_factor': None, 'pad_out_to_modulo': 8, 'predict': {'out_key': 'inpainted'}})
+     inpainter.load_model_from_checkpoint('big-lama', 'best.ckpt')
+
+     # List image files in the input directory
+     image_files = [file for file in os.listdir(dirpath) if file.lower().endswith(('.jpg', '.jpeg', '.png'))]
+
+     #for file in tqdm(image_files, desc="Processing images"):
+     for i in tqdm(range(1, len(image_files), 2), desc="Processing image pairs"):
+         filepath1 = os.path.join(dirpath, image_files[i-1])
+         filepath2 = os.path.join(dirpath, image_files[i])
+         image1 = Image.open(filepath1).convert('RGB')
+         image2 = Image.open(filepath2).convert('RGB')
+
+         fg_img1, mask_img1 = segmenter.retrieve_fg_image_and_mask(image1, verbose=False)
+         new_bg_img1 = inpainter.inpaint_img(image1, mask_img1, refine=False)
+         fg_img2, mask_img2 = segmenter.retrieve_fg_image_and_mask(image2, verbose=False)
+         new_bg_img2 = inpainter.inpaint_img(image2, mask_img2, refine=False)
+
+         image_a = Image.alpha_composite(new_bg_img2.convert('RGBA'), fg_img1)
+         image_b = Image.alpha_composite(new_bg_img1.convert('RGBA'), fg_img2)
+         image_a.save(f"results/joint/{os.path.basename(filepath1).split('.')[0]}_swapped.png")
+         image_b.save(f"results/joint/{os.path.basename(filepath2).split('.')[0]}_swapped.png")
+
+
+
inpainter.py ADDED
@@ -0,0 +1,110 @@
+ import os
+ import cv2
+ import numpy as np
+ import torch
+ import tqdm
+ import yaml
+ from omegaconf import OmegaConf
+ from PIL import Image
+ from torch.utils.data._utils.collate import default_collate
+ from saicinpainting.training.trainers import load_checkpoint
+ from saicinpainting.evaluation.utils import move_to_device, load_image, prepare_image, pad_img_to_modulo, scale_image
+ from saicinpainting.evaluation.refinement import refine_predict
+
+ refiner_config = {
+     'gpu_ids': '0,',
+     'modulo': 8,
+     'n_iters': 15,
+     'lr': 0.002,
+     'min_side': 512,
+     'max_scales': 3,
+     'px_budget': 1800000
+ }
+
+ class Inpainter():
+     def __init__(self, config):
+         self.model = None
+         self.config = config
+         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         self.scale_factor = config['scale_factor']
+         self.pad_out_to_modulo = config['pad_out_to_modulo']
+         self.predict_config = config['predict']
+         self.predict_config['model_path'] = 'big-lama'
+         self.predict_config['model_checkpoint'] = 'best.ckpt'
+         self.refiner_config = refiner_config
+
+     def load_model_from_checkpoint(self, model_path, checkpoint):
+         train_config_path = os.path.join(model_path, 'config.yaml')
+         with open(train_config_path, 'r') as f:
+             train_config = OmegaConf.create(yaml.safe_load(f))
+
+         train_config.training_model.predict_only = True
+         train_config.visualizer.kind = 'noop'
+
+         checkpoint_path = os.path.join(model_path,
+                                        'models',
+                                        checkpoint)
+         self.model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location='cpu')
+
+
+     def load_batch_data(self, img_, mask_):
+         """Loads the image and mask from the given filenames.
+         """
+         image = prepare_image(img_, mode='RGB')
+         mask = prepare_image(mask_, mode='L')
+
+         result = dict(image=image, mask=mask[None, ...])
+
+         if self.scale_factor is not None:
+             result['image'] = scale_image(result['image'], self.scale_factor)
+             result['mask'] = scale_image(result['mask'], self.scale_factor, interpolation=cv2.INTER_NEAREST)
+
+         if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
+             result['unpad_to_size'] = result['image'].shape[1:]
+             result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo)
+             result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo)
+
+         return result
+
+     def inpaint_img(self, original_img, mask_img, refine=False) -> Image:
+         """ Inpaints the image region defined by the given mask.
+         White pixels are to be masked and black pixels kept.
+         args:
+             refine: if True, uses the refinement model to enhance the inpainting result, at the cost of speed.
+
+         returns: the inpainted image
+         """
+         # in case we are given filenames instead of images
+         if isinstance(original_img, str):
+             original_img = load_image(original_img, mode='RGB')
+             mask_img = load_image(mask_img, mode='L')
+
+         self.model.eval()
+         if not refine:
+             self.model.to(self.device)
+         # load the image and mask
+         batch = default_collate([self.load_batch_data(original_img, mask_img)])
+
+         if refine:
+             assert 'unpad_to_size' in batch, "Unpadded size is required for the refinement"
+             # image unpadding is taken care of in the refiner, so that output image
+             # is same size as the input image
+             cur_res = refine_predict(batch, self.model, **self.refiner_config)
+             cur_res = cur_res[0].permute(1,2,0).detach().cpu().numpy()
+         else:
+             with torch.no_grad():
+                 batch = move_to_device(batch, self.device)
+                 batch['mask'] = (batch['mask'] > 0) * 1
+                 batch = self.model(batch)
+                 cur_res = batch[self.predict_config['out_key']][0].permute(1, 2, 0).detach().cpu().numpy()
+                 unpad_to_size = batch.get('unpad_to_size', None)
+                 if unpad_to_size is not None:
+                     orig_height, orig_width = unpad_to_size
+                     cur_res = cur_res[:orig_height, :orig_width]
+
+         cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
+         rslt_image = Image.fromarray(cur_res, 'RGB')
+         #cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
+
+         return rslt_image
+
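For context, Inpainter wraps LaMa-style prediction; a minimal standalone use of the class above would look roughly like this (the input file paths are hypothetical, and it assumes big-lama/config.yaml and big-lama/models/best.ckpt from this commit are present):

from PIL import Image
from inpainter import Inpainter

# Same config dict that app.py and image_swapper.py pass in.
inpainter = Inpainter({'scale_factor': None, 'pad_out_to_modulo': 8,
                       'predict': {'out_key': 'inpainted'}})
inpainter.load_model_from_checkpoint('big-lama', 'best.ckpt')

# Hypothetical inputs: white pixels in the mask get inpainted, black pixels are kept.
image = Image.open('testimage/example.jpg').convert('RGB')
mask = Image.open('testimage/example_mask.png').convert('L')
result = inpainter.inpaint_img(image, mask, refine=False)
result.save('results/example_inpainted.png')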
maskformer.py ADDED
@@ -0,0 +1,168 @@
+ import os
+ import time
+ import numpy as np
+ from tqdm import tqdm
+ from PIL import Image, ImageDraw
+ from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
+ import torch
+ import cv2
+
+ def dilate_image_mask(image_mask: Image, dilate_siz=50):
+     # Convert the PIL image to a NumPy array
+     image_np = np.array(image_mask)
+     kernel = np.ones((dilate_siz, dilate_siz),np.uint8)
+     dilated_image_np = cv2.dilate(image_np, kernel, iterations = 1)
+     # Convert the expanded NumPy array back to PIL format
+     dilated_image = Image.fromarray(dilated_image_np)
+
+     return dilated_image
+
+ def get_foreground_image(image: Image, mask_array: np.ndarray):
+     """Returns a PIL RGBA image with the mask applied to the original image."""
+
+     # resize the overlay mask to the original image size
+     resized_mask = Image.fromarray(mask_array.astype(np.uint8)).resize(image.size)
+     resized_mask = np.array(resized_mask)
+
+     image_array = np.array(image)
+     # Apply binary mask element-wise using NumPy for each color channel
+     fg_array = image_array * resized_mask[:, :, np.newaxis]
+     # Create a new ndarray with 4 channels (R, G, B, A)
+     result_array = np.zeros((*fg_array.shape[:2], 4), dtype=np.uint8)
+     # Assign RGB values from the original image
+     result_array[:, :, :3] = fg_array
+     # Assign alpha values from the resized mask
+     result_array[:, :, 3] = resized_mask*255
+     result_image = Image.fromarray(result_array, mode='RGBA')
+
+     return result_image
+
+
+ def overlay_mask_on_image(image: Image, mask_array: np.ndarray, alpha=0.5):
+     original_image = image
+     overlay_image = Image.new('RGBA', image.size, (0, 0, 0, 0))
+
+     # resize the overlay mask to the original image size
+     overlay_mask = Image.fromarray(mask_array.astype(np.uint8)*255).resize(original_image.size, resample=Image.LANCZOS)
+
+     # dilates the mask a bit to cover the edges of the objects
+     dilate_image_mask(overlay_mask, dilate_siz=50)
+
+     # Apply the overlay color to the overlayed array
+     overlay_color = (0, 240, 0, int(255*alpha))  # RGBA
+     draw = ImageDraw.Draw(overlay_image)
+     draw.bitmap((0, 0), overlay_mask, fill=overlay_color)
+
+     result_image = Image.alpha_composite(original_image.convert('RGBA'), overlay_image)
+     return result_image
+
+ def filter_segment_classes(segmentation, filter_classes, mode='filt_out') -> np.ndarray:
+     """ Returns a boolean mask removing the values in filter_classes from the segmentation array.
+     mode: 'filt_out' - filter out the classes in filter_classes
+           'filt_in' - keeps only the classes in filter_classes
+     """
+     # Create a boolean mask removing the values in filter_classes
+     if mode=='filt_out':
+         overlay_mask = ~np.isin(segmentation, filter_classes)
+     elif mode=='filt_in':
+         overlay_mask = np.isin(segmentation, filter_classes)
+     else:
+         raise ValueError(f'Invalid mode: {mode}')
+     return overlay_mask
+
+ class Mask2FormerSegmenter:
+     def __init__(self):
+         self.processor = None
+         self.model = None
+         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         # TODO - train a classifier to learn this from the dataset
+         #      - classes that appear much less frequently are good candidates
+         self.filter_classes = [0,1,2,3,5,6,10,11,12,13,14,15,18,19,22,24,36,38,40,45,46,47,69,105,128]
+
+     def load_models(self, checkpoint_name):
+         self.processor = AutoImageProcessor.from_pretrained(checkpoint_name)
+         self.model = Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint_name)
+         self.model.to(self.device)
+
+     @torch.no_grad()
+     def run_semantic_inference(self, image, model, processor) -> torch.Tensor:
+         """Runs semantic segmentation inference on a single image file."""
+
+         if (model is None) or (processor is None):
+             raise ValueError(f'Model or Processor not loaded.')
+
+         funcstart_time = time.time()
+
+         inputs = processor(image, return_tensors="pt")
+         inputs = inputs.to(self.device)
+         # Forward pass - to segment the image
+         outputs = model(**inputs)
+         # measures the time taken for the processing and forward pass
+         model_time = time.time() - funcstart_time
+         print(f'Model time: {model_time:.2f}')
+
+         # Post Processing - Semantic Segmentation
+         semantic_segmentation = processor.post_process_semantic_segmentation(outputs)[0]
+         return semantic_segmentation
+
+     def batch_inference_demo(self, dirpath):
+
+         # List image files in the input directory
+         image_files = [file for file in os.listdir(dirpath) if file.lower().endswith(('.jpg', '.jpeg', '.png'))]
+
+         for file in tqdm(image_files, desc="Processing images"):
+             filepath = os.path.join(dirpath, file)
+             image = Image.open(filepath)
+             semantic_segmentation = self.run_semantic_inference(image, self.model, self.processor)
+
+             labels_ids = torch.unique(semantic_segmentation).tolist()
+             valid_ids = [label_id for label_id in labels_ids if label_id not in self.filter_classes]
+             print(f'{os.path.basename(file)}: {valid_ids}')
+
+             # filter out the classes in filter_classes
+             binary_mask = filter_segment_classes(semantic_segmentation.numpy(), self.filter_classes)
+
+             overlaid_img = overlay_mask_on_image(image, binary_mask)
+             foreground_img = get_foreground_image(image, binary_mask)
+             mask_img = Image.fromarray(binary_mask.astype(np.uint8)*255).resize(image.size)
+             # dilates the mask a bit
+             mask_img = dilate_image_mask(mask_img, dilate_siz=50)
+
+             # saves the images in the results folder
+             outp_folder = 'results/mask2former_masked'
+             overlaid_img.save(f"{outp_folder}/{os.path.basename(file).split('.')[0]}_overlay.png")
+             foreground_img.save(f"{outp_folder}/{os.path.basename(file).split('.')[0]}_foreground.png")
+             mask_img.save(f"{outp_folder}/{os.path.basename(file).split('.')[0]}_mask.png")
+
+     def retrieve_fg_image_and_mask(self, input_image: Image,
+                                    dilate_siz=50,
+                                    verbose=False
+                                    ) -> (Image, Image):
+         """Generates a RGBA image with the foreground objects of the input image
+         and a binary mask for the given image file.
+         input_image: PIL image
+         dilate_siz: size in pixels of the dilation kernel to apply on the objects' mask
+         verbose: if True, prints the list of classes in the image that have not been filtered
+         returns: foreground_img (RGBA), mask_img (L)
+         """
+
+         # runs the semantic segmentation model
+         semantic_segmentation = self.run_semantic_inference(input_image,
+                                                             self.model,
+                                                             self.processor)
+         semantic_segmentation = semantic_segmentation.cpu()
+
+         if (verbose):
+             labels_ids = torch.unique(semantic_segmentation).tolist()
+             valid_ids = [label_id for label_id in labels_ids if label_id not in self.filter_classes]
+             print(f'valid classes detected: {valid_ids}')
+
+         # filter out the classes in filter_classes
+         binary_mask = filter_segment_classes(semantic_segmentation.numpy(),
+                                              self.filter_classes)
+         foreground_img = get_foreground_image(input_image, binary_mask)
+         mask_img = Image.fromarray(binary_mask.astype(np.uint8)*255).resize(input_image.size, resample=Image.LANCZOS)
+         # dilates the mask a bit to cover the edges of the objects. This helps the inpainting model
+         mask_img = dilate_image_mask(mask_img, dilate_siz=dilate_siz)
+
+         return foreground_img, mask_img
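The segmenter is driven the same way in app.py and image_swapper.py; a minimal sketch of calling it on a single image (the file names are hypothetical):

from PIL import Image
from maskformer import Mask2FormerSegmenter

segmenter = Mask2FormerSegmenter()
segmenter.load_models(checkpoint_name="facebook/mask2former-swin-large-ade-semantic")

# Hypothetical input; returns an RGBA foreground cut-out and an 'L'-mode mask
# (dilated by dilate_siz pixels) that can be fed to Inpainter.inpaint_img.
image = Image.open("testimage/example.jpg").convert("RGB")
foreground, mask = segmenter.retrieve_fg_image_and_mask(image, dilate_siz=50, verbose=True)
foreground.save("results/example_foreground.png")
mask.save("results/example_mask.png")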
models/ade20k/__init__.py ADDED
@@ -0,0 +1 @@
+ from .base import *
models/ade20k/base.py ADDED
@@ -0,0 +1,627 @@
+ """Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch"""
+
+ import os
+
+ import pandas as pd
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from scipy.io import loadmat
+ from torch.nn.modules import BatchNorm2d
+
+ from . import resnet
+ from . import mobilenet
+
+
+ NUM_CLASS = 150
+ base_path = os.path.dirname(os.path.abspath(__file__)) # current file path
+ colors_path = os.path.join(base_path, 'color150.mat')
+ classes_path = os.path.join(base_path, 'object150_info.csv')
+
+ segm_options = dict(colors=loadmat(colors_path)['colors'],
+ classes=pd.read_csv(classes_path),)
+
+
+ class NormalizeTensor:
+ def __init__(self, mean, std, inplace=False):
+ """Normalize a tensor image with mean and standard deviation.
+ .. note::
+ This transform acts out of place by default, i.e., it does not mutates the input tensor.
+ See :class:`~torchvision.transforms.Normalize` for more details.
+ Args:
+ tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
+ mean (sequence): Sequence of means for each channel.
+ std (sequence): Sequence of standard deviations for each channel.
+ inplace(bool,optional): Bool to make this operation inplace.
+ Returns:
+ Tensor: Normalized Tensor image.
+ """
+
+ self.mean = mean
+ self.std = std
+ self.inplace = inplace
+
+ def __call__(self, tensor):
+ if not self.inplace:
+ tensor = tensor.clone()
+
+ dtype = tensor.dtype
+ mean = torch.as_tensor(self.mean, dtype=dtype, device=tensor.device)
+ std = torch.as_tensor(self.std, dtype=dtype, device=tensor.device)
+ tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
+ return tensor
+
+
+ # Model Builder
+ class ModelBuilder:
+ # custom weights initialization
+ @staticmethod
+ def weights_init(m):
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ nn.init.kaiming_normal_(m.weight.data)
+ elif classname.find('BatchNorm') != -1:
+ m.weight.data.fill_(1.)
+ m.bias.data.fill_(1e-4)
+
+ @staticmethod
+ def build_encoder(arch='resnet50dilated', fc_dim=512, weights=''):
+ pretrained = True if len(weights) == 0 else False
+ arch = arch.lower()
+ if arch == 'mobilenetv2dilated':
+ orig_mobilenet = mobilenet.__dict__['mobilenetv2'](pretrained=pretrained)
+ net_encoder = MobileNetV2Dilated(orig_mobilenet, dilate_scale=8)
+ elif arch == 'resnet18':
+ orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained)
+ net_encoder = Resnet(orig_resnet)
+ elif arch == 'resnet18dilated':
+ orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained)
+ net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
+ elif arch == 'resnet50dilated':
+ orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
+ net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
+ elif arch == 'resnet50':
+ orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
+ net_encoder = Resnet(orig_resnet)
+ else:
+ raise Exception('Architecture undefined!')
+
+ # encoders are usually pretrained
+ # net_encoder.apply(ModelBuilder.weights_init)
+ if len(weights) > 0:
+ print('Loading weights for net_encoder')
+ net_encoder.load_state_dict(
+ torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
+ return net_encoder
+
+ @staticmethod
+ def build_decoder(arch='ppm_deepsup',
+ fc_dim=512, num_class=NUM_CLASS,
+ weights='', use_softmax=False, drop_last_conv=False):
+ arch = arch.lower()
+ if arch == 'ppm_deepsup':
+ net_decoder = PPMDeepsup(
+ num_class=num_class,
+ fc_dim=fc_dim,
+ use_softmax=use_softmax,
+ drop_last_conv=drop_last_conv)
+ elif arch == 'c1_deepsup':
+ net_decoder = C1DeepSup(
+ num_class=num_class,
+ fc_dim=fc_dim,
+ use_softmax=use_softmax,
+ drop_last_conv=drop_last_conv)
+ else:
+ raise Exception('Architecture undefined!')
+
+ net_decoder.apply(ModelBuilder.weights_init)
+ if len(weights) > 0:
+ print('Loading weights for net_decoder')
+ net_decoder.load_state_dict(
+ torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
+ return net_decoder
+
+ @staticmethod
+ def get_decoder(weights_path, arch_encoder, arch_decoder, fc_dim, drop_last_conv, *arts, **kwargs):
+ path = os.path.join(weights_path, 'ade20k', f'ade20k-{arch_encoder}-{arch_decoder}/decoder_epoch_20.pth')
+ return ModelBuilder.build_decoder(arch=arch_decoder, fc_dim=fc_dim, weights=path, use_softmax=True, drop_last_conv=drop_last_conv)
+
+ @staticmethod
+ def get_encoder(weights_path, arch_encoder, arch_decoder, fc_dim, segmentation,
+ *arts, **kwargs):
+ if segmentation:
+ path = os.path.join(weights_path, 'ade20k', f'ade20k-{arch_encoder}-{arch_decoder}/encoder_epoch_20.pth')
+ else:
+ path = ''
+ return ModelBuilder.build_encoder(arch=arch_encoder, fc_dim=fc_dim, weights=path)
+
+
+ def conv3x3_bn_relu(in_planes, out_planes, stride=1):
+ return nn.Sequential(
+ nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False),
+ BatchNorm2d(out_planes),
+ nn.ReLU(inplace=True),
+ )
+
+
+ class SegmentationModule(nn.Module):
+ def __init__(self,
+ weights_path,
+ num_classes=150,
+ arch_encoder="resnet50dilated",
+ drop_last_conv=False,
+ net_enc=None, # None for Default encoder
+ net_dec=None, # None for Default decoder
+ encode=None, # {None, 'binary', 'color', 'sky'}
+ use_default_normalization=False,
+ return_feature_maps=False,
+ return_feature_maps_level=3, # {0, 1, 2, 3}
+ return_feature_maps_only=True,
+ **kwargs,
+ ):
+ super().__init__()
+ self.weights_path = weights_path
+ self.drop_last_conv = drop_last_conv
+ self.arch_encoder = arch_encoder
+ if self.arch_encoder == "resnet50dilated":
+ self.arch_decoder = "ppm_deepsup"
+ self.fc_dim = 2048
+ elif self.arch_encoder == "mobilenetv2dilated":
+ self.arch_decoder = "c1_deepsup"
+ self.fc_dim = 320
+ else:
+ raise NotImplementedError(f"No such arch_encoder={self.arch_encoder}")
+ model_builder_kwargs = dict(arch_encoder=self.arch_encoder,
+ arch_decoder=self.arch_decoder,
+ fc_dim=self.fc_dim,
+ drop_last_conv=drop_last_conv,
+ weights_path=self.weights_path)
+
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ self.encoder = ModelBuilder.get_encoder(**model_builder_kwargs) if net_enc is None else net_enc
+ self.decoder = ModelBuilder.get_decoder(**model_builder_kwargs) if net_dec is None else net_dec
+ self.use_default_normalization = use_default_normalization
+ self.default_normalization = NormalizeTensor(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+
+ self.encode = encode
+
+ self.return_feature_maps = return_feature_maps
+
+ assert 0 <= return_feature_maps_level <= 3
+ self.return_feature_maps_level = return_feature_maps_level
+
+ def normalize_input(self, tensor):
+ if tensor.min() < 0 or tensor.max() > 1:
+ raise ValueError("Tensor should be 0..1 before using normalize_input")
+ return self.default_normalization(tensor)
+
+ @property
+ def feature_maps_channels(self):
+ return 256 * 2**(self.return_feature_maps_level) # 256, 512, 1024, 2048
+
+ def forward(self, img_data, segSize=None):
+ if segSize is None:
+ raise NotImplementedError("Please pass segSize param. By default: (300, 300)")
+
+ fmaps = self.encoder(img_data, return_feature_maps=True)
+ pred = self.decoder(fmaps, segSize=segSize)
+
+ if self.return_feature_maps:
+ return pred, fmaps
+ # print("BINARY", img_data.shape, pred.shape)
+ return pred
+
+ def multi_mask_from_multiclass(self, pred, classes):
+ def isin(ar1, ar2):
+ return (ar1[..., None] == ar2).any(-1).float()
+ return isin(pred, torch.LongTensor(classes).to(self.device))
+
+ @staticmethod
+ def multi_mask_from_multiclass_probs(scores, classes):
+ res = None
+ for c in classes:
+ if res is None:
+ res = scores[:, c]
+ else:
+ res += scores[:, c]
+ return res
+
+ def predict(self, tensor, imgSizes=(-1,), # (300, 375, 450, 525, 600)
+ segSize=None):
+ """Entry-point for segmentation. Use this methods instead of forward
+ Arguments:
+ tensor {torch.Tensor} -- BCHW
+ Keyword Arguments:
+ imgSizes {tuple or list} -- imgSizes for segmentation input.
+ default: (300, 450)
+ original implementation: (300, 375, 450, 525, 600)
+
+ """
+ if segSize is None:
+ segSize = tensor.shape[-2:]
+ segSize = (tensor.shape[2], tensor.shape[3])
+ with torch.no_grad():
+ if self.use_default_normalization:
+ tensor = self.normalize_input(tensor)
+ scores = torch.zeros(1, NUM_CLASS, segSize[0], segSize[1]).to(self.device)
+ features = torch.zeros(1, self.feature_maps_channels, segSize[0], segSize[1]).to(self.device)
+
+ result = []
+ for img_size in imgSizes:
+ if img_size != -1:
+ img_data = F.interpolate(tensor.clone(), size=img_size)
+ else:
+ img_data = tensor.clone()
+
+ if self.return_feature_maps:
+ pred_current, fmaps = self.forward(img_data, segSize=segSize)
+ else:
+ pred_current = self.forward(img_data, segSize=segSize)
+
+
+ result.append(pred_current)
+ scores = scores + pred_current / len(imgSizes)
+
+ # Disclaimer: We use and aggregate only last fmaps: fmaps[3]
+ if self.return_feature_maps:
+ features = features + F.interpolate(fmaps[self.return_feature_maps_level], size=segSize) / len(imgSizes)
+
+ _, pred = torch.max(scores, dim=1)
+
+ if self.return_feature_maps:
+ return features
+
+ return pred, result
+
+ def get_edges(self, t):
+ edge = torch.cuda.ByteTensor(t.size()).zero_()
+ edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
+ edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])
+ edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
+ edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
+
+ if True:
+ return edge.half()
+ return edge.float()
+
+
+ # pyramid pooling, deep supervision
+ class PPMDeepsup(nn.Module):
+ def __init__(self, num_class=NUM_CLASS, fc_dim=4096,
+ use_softmax=False, pool_scales=(1, 2, 3, 6),
+ drop_last_conv=False):
+ super().__init__()
+ self.use_softmax = use_softmax
+ self.drop_last_conv = drop_last_conv
+
+ self.ppm = []
+ for scale in pool_scales:
+ self.ppm.append(nn.Sequential(
+ nn.AdaptiveAvgPool2d(scale),
+ nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
+ BatchNorm2d(512),
+ nn.ReLU(inplace=True)
+ ))
+ self.ppm = nn.ModuleList(self.ppm)
+ self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)
+
+ self.conv_last = nn.Sequential(
+ nn.Conv2d(fc_dim + len(pool_scales) * 512, 512,
+ kernel_size=3, padding=1, bias=False),
+ BatchNorm2d(512),
+ nn.ReLU(inplace=True),
+ nn.Dropout2d(0.1),
+ nn.Conv2d(512, num_class, kernel_size=1)
+ )
+ self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
+ self.dropout_deepsup = nn.Dropout2d(0.1)
+
+ def forward(self, conv_out, segSize=None):
+ conv5 = conv_out[-1]
+
+ input_size = conv5.size()
+ ppm_out = [conv5]
+ for pool_scale in self.ppm:
+ ppm_out.append(nn.functional.interpolate(
+ pool_scale(conv5),
+ (input_size[2], input_size[3]),
+ mode='bilinear', align_corners=False))
+ ppm_out = torch.cat(ppm_out, 1)
+
+ if self.drop_last_conv:
+ return ppm_out
+ else:
+ x = self.conv_last(ppm_out)
+
+ if self.use_softmax: # is True during inference
+ x = nn.functional.interpolate(
+ x, size=segSize, mode='bilinear', align_corners=False)
+ x = nn.functional.softmax(x, dim=1)
+ return x
+
+ # deep sup
+ conv4 = conv_out[-2]
+ _ = self.cbr_deepsup(conv4)
+ _ = self.dropout_deepsup(_)
+ _ = self.conv_last_deepsup(_)
+
+ x = nn.functional.log_softmax(x, dim=1)
+ _ = nn.functional.log_softmax(_, dim=1)
+
+ return (x, _)
+
+
+ class Resnet(nn.Module):
+ def __init__(self, orig_resnet):
+ super(Resnet, self).__init__()
+
+ # take pretrained resnet, except AvgPool and FC
+ self.conv1 = orig_resnet.conv1
+ self.bn1 = orig_resnet.bn1
+ self.relu1 = orig_resnet.relu1
+ self.conv2 = orig_resnet.conv2
+ self.bn2 = orig_resnet.bn2
+ self.relu2 = orig_resnet.relu2
+ self.conv3 = orig_resnet.conv3
+ self.bn3 = orig_resnet.bn3
+ self.relu3 = orig_resnet.relu3
+ self.maxpool = orig_resnet.maxpool
+ self.layer1 = orig_resnet.layer1
+ self.layer2 = orig_resnet.layer2
+ self.layer3 = orig_resnet.layer3
+ self.layer4 = orig_resnet.layer4
+
+ def forward(self, x, return_feature_maps=False):
+ conv_out = []
+
+ x = self.relu1(self.bn1(self.conv1(x)))
+ x = self.relu2(self.bn2(self.conv2(x)))
+ x = self.relu3(self.bn3(self.conv3(x)))
+ x = self.maxpool(x)
+
+ x = self.layer1(x); conv_out.append(x);
+ x = self.layer2(x); conv_out.append(x);
+ x = self.layer3(x); conv_out.append(x);
+ x = self.layer4(x); conv_out.append(x);
+
+ if return_feature_maps:
+ return conv_out
+ return [x]
+
+ # Resnet Dilated
+ class ResnetDilated(nn.Module):
+ def __init__(self, orig_resnet, dilate_scale=8):
+ super().__init__()
+ from functools import partial
+
+ if dilate_scale == 8:
+ orig_resnet.layer3.apply(
+ partial(self._nostride_dilate, dilate=2))
+ orig_resnet.layer4.apply(
+ partial(self._nostride_dilate, dilate=4))
+ elif dilate_scale == 16:
+ orig_resnet.layer4.apply(
+ partial(self._nostride_dilate, dilate=2))
+
+ # take pretrained resnet, except AvgPool and FC
+ self.conv1 = orig_resnet.conv1
+ self.bn1 = orig_resnet.bn1
+ self.relu1 = orig_resnet.relu1
+ self.conv2 = orig_resnet.conv2
+ self.bn2 = orig_resnet.bn2
+ self.relu2 = orig_resnet.relu2
+ self.conv3 = orig_resnet.conv3
+ self.bn3 = orig_resnet.bn3
+ self.relu3 = orig_resnet.relu3
+ self.maxpool = orig_resnet.maxpool
+ self.layer1 = orig_resnet.layer1
+ self.layer2 = orig_resnet.layer2
+ self.layer3 = orig_resnet.layer3
+ self.layer4 = orig_resnet.layer4
+
+ def _nostride_dilate(self, m, dilate):
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ # the convolution with stride
+ if m.stride == (2, 2):
+ m.stride = (1, 1)
+ if m.kernel_size == (3, 3):
+ m.dilation = (dilate // 2, dilate // 2)
+ m.padding = (dilate // 2, dilate // 2)
+ # other convoluions
+ else:
+ if m.kernel_size == (3, 3):
+ m.dilation = (dilate, dilate)
+ m.padding = (dilate, dilate)
+
+ def forward(self, x, return_feature_maps=False):
+ conv_out = []
+
+ x = self.relu1(self.bn1(self.conv1(x)))
+ x = self.relu2(self.bn2(self.conv2(x)))
+ x = self.relu3(self.bn3(self.conv3(x)))
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ conv_out.append(x)
+ x = self.layer2(x)
+ conv_out.append(x)
+ x = self.layer3(x)
+ conv_out.append(x)
+ x = self.layer4(x)
+ conv_out.append(x)
+
+ if return_feature_maps:
+ return conv_out
+ return [x]
+
+ class MobileNetV2Dilated(nn.Module):
+ def __init__(self, orig_net, dilate_scale=8):
+ super(MobileNetV2Dilated, self).__init__()
+ from functools import partial
+
+ # take pretrained mobilenet features
+ self.features = orig_net.features[:-1]
+
+ self.total_idx = len(self.features)
+ self.down_idx = [2, 4, 7, 14]
+
+ if dilate_scale == 8:
+ for i in range(self.down_idx[-2], self.down_idx[-1]):
+ self.features[i].apply(
+ partial(self._nostride_dilate, dilate=2)
+ )
+ for i in range(self.down_idx[-1], self.total_idx):
+ self.features[i].apply(
+ partial(self._nostride_dilate, dilate=4)
+ )
+ elif dilate_scale == 16:
+ for i in range(self.down_idx[-1], self.total_idx):
+ self.features[i].apply(
+ partial(self._nostride_dilate, dilate=2)
+ )
+
+ def _nostride_dilate(self, m, dilate):
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ # the convolution with stride
+ if m.stride == (2, 2):
+ m.stride = (1, 1)
+ if m.kernel_size == (3, 3):
+ m.dilation = (dilate//2, dilate//2)
+ m.padding = (dilate//2, dilate//2)
+ # other convoluions
+ else:
+ if m.kernel_size == (3, 3):
+ m.dilation = (dilate, dilate)
+ m.padding = (dilate, dilate)
+
+ def forward(self, x, return_feature_maps=False):
+ if return_feature_maps:
+ conv_out = []
+ for i in range(self.total_idx):
+ x = self.features[i](x)
+ if i in self.down_idx:
+ conv_out.append(x)
+ conv_out.append(x)
+ return conv_out
+
+ else:
+ return [self.features(x)]
+
+
+ # last conv, deep supervision
+ class C1DeepSup(nn.Module):
+ def __init__(self, num_class=150, fc_dim=2048, use_softmax=False, drop_last_conv=False):
+ super(C1DeepSup, self).__init__()
+ self.use_softmax = use_softmax
+ self.drop_last_conv = drop_last_conv
+
+ self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)
+ self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)
+
+ # last conv
+ self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
+ self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
+
+ def forward(self, conv_out, segSize=None):
+ conv5 = conv_out[-1]
+
+ x = self.cbr(conv5)
+
+ if self.drop_last_conv:
+ return x
+ else:
+ x = self.conv_last(x)
+
+ if self.use_softmax: # is True during inference
+ x = nn.functional.interpolate(
+ x, size=segSize, mode='bilinear', align_corners=False)
+ x = nn.functional.softmax(x, dim=1)
+ return x
+
+ # deep sup
+ conv4 = conv_out[-2]
+ _ = self.cbr_deepsup(conv4)
+ _ = self.conv_last_deepsup(_)
+
+ x = nn.functional.log_softmax(x, dim=1)
+ _ = nn.functional.log_softmax(_, dim=1)
+
+ return (x, _)
+
+
+ # last conv
+ class C1(nn.Module):
+ def __init__(self, num_class=150, fc_dim=2048, use_softmax=False):
+ super(C1, self).__init__()
+ self.use_softmax = use_softmax
+
+ self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)
+
+ # last conv
+ self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
+
+ def forward(self, conv_out, segSize=None):
+ conv5 = conv_out[-1]
+ x = self.cbr(conv5)
+ x = self.conv_last(x)
+
+ if self.use_softmax: # is True during inference
+ x = nn.functional.interpolate(
+ x, size=segSize, mode='bilinear', align_corners=False)
+ x = nn.functional.softmax(x, dim=1)
+ else:
+ x = nn.functional.log_softmax(x, dim=1)
+
+ return x
+
+
+ # pyramid pooling
+ class PPM(nn.Module):
+ def __init__(self, num_class=150, fc_dim=4096,
+ use_softmax=False, pool_scales=(1, 2, 3, 6)):
+ super(PPM, self).__init__()
+ self.use_softmax = use_softmax
+
+ self.ppm = []
+ for scale in pool_scales:
+ self.ppm.append(nn.Sequential(
+ nn.AdaptiveAvgPool2d(scale),
+ nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
+ BatchNorm2d(512),
+ nn.ReLU(inplace=True)
+ ))
+ self.ppm = nn.ModuleList(self.ppm)
+
+ self.conv_last = nn.Sequential(
+ nn.Conv2d(fc_dim+len(pool_scales)*512, 512,
+ kernel_size=3, padding=1, bias=False),
+ BatchNorm2d(512),
+ nn.ReLU(inplace=True),
+ nn.Dropout2d(0.1),
+ nn.Conv2d(512, num_class, kernel_size=1)
+ )
+
+ def forward(self, conv_out, segSize=None):
+ conv5 = conv_out[-1]
+
+ input_size = conv5.size()
+ ppm_out = [conv5]
+ for pool_scale in self.ppm:
+ ppm_out.append(nn.functional.interpolate(
+ pool_scale(conv5),
+ (input_size[2], input_size[3]),
+ mode='bilinear', align_corners=False))
+ ppm_out = torch.cat(ppm_out, 1)
+
+ x = self.conv_last(ppm_out)
+
+ if self.use_softmax: # is True during inference
+ x = nn.functional.interpolate(
+ x, size=segSize, mode='bilinear', align_corners=False)
+ x = nn.functional.softmax(x, dim=1)
+ else:
+ x = nn.functional.log_softmax(x, dim=1)
+ return x
models/ade20k/color150.mat ADDED
Binary file (502 Bytes).
 
models/ade20k/mobilenet.py ADDED
@@ -0,0 +1,154 @@
+ """
+ This MobileNetV2 implementation is modified from the following repository:
+ https://github.com/tonylins/pytorch-mobilenet-v2
+ """
+
+ import torch.nn as nn
+ import math
+ from .utils import load_url
+ from .segm_lib.nn import SynchronizedBatchNorm2d
+
+ BatchNorm2d = SynchronizedBatchNorm2d
+
+
+ __all__ = ['mobilenetv2']
+
+
+ model_urls = {
+ 'mobilenetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/mobilenet_v2.pth.tar',
+ }
+
+
+ def conv_bn(inp, oup, stride):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
+ BatchNorm2d(oup),
+ nn.ReLU6(inplace=True)
+ )
+
+
+ def conv_1x1_bn(inp, oup):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
+ BatchNorm2d(oup),
+ nn.ReLU6(inplace=True)
+ )
+
+
+ class InvertedResidual(nn.Module):
+ def __init__(self, inp, oup, stride, expand_ratio):
+ super(InvertedResidual, self).__init__()
+ self.stride = stride
+ assert stride in [1, 2]
+
+ hidden_dim = round(inp * expand_ratio)
+ self.use_res_connect = self.stride == 1 and inp == oup
+
+ if expand_ratio == 1:
+ self.conv = nn.Sequential(
+ # dw
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
+ BatchNorm2d(hidden_dim),
+ nn.ReLU6(inplace=True),
+ # pw-linear
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+ BatchNorm2d(oup),
+ )
+ else:
+ self.conv = nn.Sequential(
+ # pw
+ nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
+ BatchNorm2d(hidden_dim),
+ nn.ReLU6(inplace=True),
+ # dw
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
+ BatchNorm2d(hidden_dim),
+ nn.ReLU6(inplace=True),
+ # pw-linear
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+ BatchNorm2d(oup),
+ )
+
+ def forward(self, x):
+ if self.use_res_connect:
+ return x + self.conv(x)
+ else:
+ return self.conv(x)
+
+
+ class MobileNetV2(nn.Module):
+ def __init__(self, n_class=1000, input_size=224, width_mult=1.):
+ super(MobileNetV2, self).__init__()
+ block = InvertedResidual
+ input_channel = 32
+ last_channel = 1280
+ interverted_residual_setting = [
+ # t, c, n, s
+ [1, 16, 1, 1],
+ [6, 24, 2, 2],
+ [6, 32, 3, 2],
+ [6, 64, 4, 2],
+ [6, 96, 3, 1],
+ [6, 160, 3, 2],
+ [6, 320, 1, 1],
+ ]
+
+ # building first layer
+ assert input_size % 32 == 0
+ input_channel = int(input_channel * width_mult)
+ self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
+ self.features = [conv_bn(3, input_channel, 2)]
+ # building inverted residual blocks
+ for t, c, n, s in interverted_residual_setting:
+ output_channel = int(c * width_mult)
+ for i in range(n):
+ if i == 0:
+ self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
+ else:
+ self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
+ input_channel = output_channel
+ # building last several layers
+ self.features.append(conv_1x1_bn(input_channel, self.last_channel))
+ # make it nn.Sequential
+ self.features = nn.Sequential(*self.features)
+
+ # building classifier
+ self.classifier = nn.Sequential(
+ nn.Dropout(0.2),
+ nn.Linear(self.last_channel, n_class),
+ )
+
+ self._initialize_weights()
+
+ def forward(self, x):
+ x = self.features(x)
+ x = x.mean(3).mean(2)
+ x = self.classifier(x)
+ return x
+
+ def _initialize_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Linear):
+ n = m.weight.size(1)
+ m.weight.data.normal_(0, 0.01)
+ m.bias.data.zero_()
+
+
+ def mobilenetv2(pretrained=False, **kwargs):
+ """Constructs a MobileNet_V2 model.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = MobileNetV2(n_class=1000, **kwargs)
+ if pretrained:
+ model.load_state_dict(load_url(model_urls['mobilenetv2']), strict=False)
+ return model
models/ade20k/object150_info.csv ADDED
@@ -0,0 +1,151 @@
+ Idx,Ratio,Train,Val,Stuff,Name
+ 1,0.1576,11664,1172,1,wall
+ 2,0.1072,6046,612,1,building;edifice
+ 3,0.0878,8265,796,1,sky
+ 4,0.0621,9336,917,1,floor;flooring
+ 5,0.0480,6678,641,0,tree
+ 6,0.0450,6604,643,1,ceiling
+ 7,0.0398,4023,408,1,road;route
+ 8,0.0231,1906,199,0,bed
+ 9,0.0198,4688,460,0,windowpane;window
+ 10,0.0183,2423,225,1,grass
+ 11,0.0181,2874,294,0,cabinet
+ 12,0.0166,3068,310,1,sidewalk;pavement
+ 13,0.0160,5075,526,0,person;individual;someone;somebody;mortal;soul
+ 14,0.0151,1804,190,1,earth;ground
+ 15,0.0118,6666,796,0,door;double;door
+ 16,0.0110,4269,411,0,table
+ 17,0.0109,1691,160,1,mountain;mount
+ 18,0.0104,3999,441,0,plant;flora;plant;life
+ 19,0.0104,2149,217,0,curtain;drape;drapery;mantle;pall
+ 20,0.0103,3261,318,0,chair
+ 21,0.0098,3164,306,0,car;auto;automobile;machine;motorcar
+ 22,0.0074,709,75,1,water
+ 23,0.0067,3296,315,0,painting;picture
+ 24,0.0065,1191,106,0,sofa;couch;lounge
+ 25,0.0061,1516,162,0,shelf
+ 26,0.0060,667,69,1,house
+ 27,0.0053,651,57,1,sea
+ 28,0.0052,1847,224,0,mirror
+ 29,0.0046,1158,128,1,rug;carpet;carpeting
+ 30,0.0044,480,44,1,field
+ 31,0.0044,1172,98,0,armchair
+ 32,0.0044,1292,184,0,seat
+ 33,0.0033,1386,138,0,fence;fencing
+ 34,0.0031,698,61,0,desk
+ 35,0.0030,781,73,0,rock;stone
+ 36,0.0027,380,43,0,wardrobe;closet;press
+ 37,0.0026,3089,302,0,lamp
+ 38,0.0024,404,37,0,bathtub;bathing;tub;bath;tub
+ 39,0.0024,804,99,0,railing;rail
+ 40,0.0023,1453,153,0,cushion
+ 41,0.0023,411,37,0,base;pedestal;stand
+ 42,0.0022,1440,162,0,box
+ 43,0.0022,800,77,0,column;pillar
+ 44,0.0020,2650,298,0,signboard;sign
+ 45,0.0019,549,46,0,chest;of;drawers;chest;bureau;dresser
+ 46,0.0019,367,36,0,counter
+ 47,0.0018,311,30,1,sand
+ 48,0.0018,1181,122,0,sink
+ 49,0.0018,287,23,1,skyscraper
+ 50,0.0018,468,38,0,fireplace;hearth;open;fireplace
+ 51,0.0018,402,43,0,refrigerator;icebox
+ 52,0.0018,130,12,1,grandstand;covered;stand
+ 53,0.0018,561,64,1,path
+ 54,0.0017,880,102,0,stairs;steps
+ 55,0.0017,86,12,1,runway
+ 56,0.0017,172,11,0,case;display;case;showcase;vitrine
+ 57,0.0017,198,18,0,pool;table;billiard;table;snooker;table
+ 58,0.0017,930,109,0,pillow
+ 59,0.0015,139,18,0,screen;door;screen
+ 60,0.0015,564,52,1,stairway;staircase
+ 61,0.0015,320,26,1,river
+ 62,0.0015,261,29,1,bridge;span
+ 63,0.0014,275,22,0,bookcase
+ 64,0.0014,335,60,0,blind;screen
+ 65,0.0014,792,75,0,coffee;table;cocktail;table
+ 66,0.0014,395,49,0,toilet;can;commode;crapper;pot;potty;stool;throne
+ 67,0.0014,1309,138,0,flower
+ 68,0.0013,1112,113,0,book
+ 69,0.0013,266,27,1,hill
+ 70,0.0013,659,66,0,bench
+ 71,0.0012,331,31,0,countertop
+ 72,0.0012,531,56,0,stove;kitchen;stove;range;kitchen;range;cooking;stove
+ 73,0.0012,369,36,0,palm;palm;tree
+ 74,0.0012,144,9,0,kitchen;island
+ 75,0.0011,265,29,0,computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system
+ 76,0.0010,324,33,0,swivel;chair
+ 77,0.0009,304,27,0,boat
+ 78,0.0009,170,20,0,bar
+ 79,0.0009,68,6,0,arcade;machine
+ 80,0.0009,65,8,1,hovel;hut;hutch;shack;shanty
+ 81,0.0009,248,25,0,bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle
+ 82,0.0008,492,49,0,towel
+ 83,0.0008,2510,269,0,light;light;source
+ 84,0.0008,440,39,0,truck;motortruck
+ 85,0.0008,147,18,1,tower
+ 86,0.0008,583,56,0,chandelier;pendant;pendent
+ 87,0.0007,533,61,0,awning;sunshade;sunblind
+ 88,0.0007,1989,239,0,streetlight;street;lamp
+ 89,0.0007,71,5,0,booth;cubicle;stall;kiosk
+ 90,0.0007,618,53,0,television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box
+ 91,0.0007,135,12,0,airplane;aeroplane;plane
+ 92,0.0007,83,5,1,dirt;track
+ 93,0.0007,178,17,0,apparel;wearing;apparel;dress;clothes
95
+ 94,0.0006,1003,104,0,pole
96
+ 95,0.0006,182,12,1,land;ground;soil
97
+ 96,0.0006,452,50,0,bannister;banister;balustrade;balusters;handrail
98
+ 97,0.0006,42,6,1,escalator;moving;staircase;moving;stairway
99
+ 98,0.0006,307,31,0,ottoman;pouf;pouffe;puff;hassock
100
+ 99,0.0006,965,114,0,bottle
101
+ 100,0.0006,117,13,0,buffet;counter;sideboard
102
+ 101,0.0006,354,35,0,poster;posting;placard;notice;bill;card
103
+ 102,0.0006,108,9,1,stage
104
+ 103,0.0006,557,55,0,van
105
+ 104,0.0006,52,4,0,ship
106
+ 105,0.0005,99,5,0,fountain
107
+ 106,0.0005,57,4,1,conveyer;belt;conveyor;belt;conveyer;conveyor;transporter
108
+ 107,0.0005,292,31,0,canopy
109
+ 108,0.0005,77,9,0,washer;automatic;washer;washing;machine
110
+ 109,0.0005,340,38,0,plaything;toy
111
+ 110,0.0005,66,3,1,swimming;pool;swimming;bath;natatorium
112
+ 111,0.0005,465,49,0,stool
113
+ 112,0.0005,50,4,0,barrel;cask
114
+ 113,0.0005,622,75,0,basket;handbasket
115
+ 114,0.0005,80,9,1,waterfall;falls
116
+ 115,0.0005,59,3,0,tent;collapsible;shelter
117
+ 116,0.0005,531,72,0,bag
118
+ 117,0.0005,282,30,0,minibike;motorbike
119
+ 118,0.0005,73,7,0,cradle
120
+ 119,0.0005,435,44,0,oven
121
+ 120,0.0005,136,25,0,ball
122
+ 121,0.0005,116,24,0,food;solid;food
123
+ 122,0.0004,266,31,0,step;stair
124
+ 123,0.0004,58,12,0,tank;storage;tank
125
+ 124,0.0004,418,83,0,trade;name;brand;name;brand;marque
126
+ 125,0.0004,319,43,0,microwave;microwave;oven
127
+ 126,0.0004,1193,139,0,pot;flowerpot
128
+ 127,0.0004,97,23,0,animal;animate;being;beast;brute;creature;fauna
129
+ 128,0.0004,347,36,0,bicycle;bike;wheel;cycle
130
+ 129,0.0004,52,5,1,lake
131
+ 130,0.0004,246,22,0,dishwasher;dish;washer;dishwashing;machine
132
+ 131,0.0004,108,13,0,screen;silver;screen;projection;screen
133
+ 132,0.0004,201,30,0,blanket;cover
134
+ 133,0.0004,285,21,0,sculpture
135
+ 134,0.0004,268,27,0,hood;exhaust;hood
136
+ 135,0.0003,1020,108,0,sconce
137
+ 136,0.0003,1282,122,0,vase
138
+ 137,0.0003,528,65,0,traffic;light;traffic;signal;stoplight
139
+ 138,0.0003,453,57,0,tray
140
+ 139,0.0003,671,100,0,ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin
141
+ 140,0.0003,397,44,0,fan
142
+ 141,0.0003,92,8,1,pier;wharf;wharfage;dock
143
+ 142,0.0003,228,18,0,crt;screen
144
+ 143,0.0003,570,59,0,plate
145
+ 144,0.0003,217,22,0,monitor;monitoring;device
146
+ 145,0.0003,206,19,0,bulletin;board;notice;board
147
+ 146,0.0003,130,14,0,shower
148
+ 147,0.0003,178,28,0,radiator
149
+ 148,0.0002,504,57,0,glass;drinking;glass
150
+ 149,0.0002,775,96,0,clock
151
+ 150,0.0002,421,56,0,flag
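A small illustrative snippet (not part of the commit; it assumes pandas is available and the working directory is the repository root) showing how this table maps predicted ADE20K class indices to names:

```python
# Hedged sketch: read the class table above and look up one label by index.
import pandas as pd

info = pd.read_csv('models/ade20k/object150_info.csv')
idx_to_name = dict(zip(info['Idx'], info['Name']))
print(idx_to_name[21])  # car;auto;automobile;machine;motorcar
```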
models/ade20k/resnet.py ADDED
@@ -0,0 +1,181 @@
1
+ """Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch"""
2
+
3
+ import math
4
+
5
+ import torch.nn as nn
6
+ from torch.nn import BatchNorm2d
7
+
8
+ from .utils import load_url
9
+
10
+ __all__ = ['ResNet', 'resnet50']
11
+
12
+
13
+ model_urls = {
14
+ 'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth',
15
+ }
16
+
17
+
18
+ def conv3x3(in_planes, out_planes, stride=1):
19
+ "3x3 convolution with padding"
20
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
21
+ padding=1, bias=False)
22
+
23
+
24
+ class BasicBlock(nn.Module):
25
+ expansion = 1
26
+
27
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
28
+ super(BasicBlock, self).__init__()
29
+ self.conv1 = conv3x3(inplanes, planes, stride)
30
+ self.bn1 = BatchNorm2d(planes)
31
+ self.relu = nn.ReLU(inplace=True)
32
+ self.conv2 = conv3x3(planes, planes)
33
+ self.bn2 = BatchNorm2d(planes)
34
+ self.downsample = downsample
35
+ self.stride = stride
36
+
37
+ def forward(self, x):
38
+ residual = x
39
+
40
+ out = self.conv1(x)
41
+ out = self.bn1(out)
42
+ out = self.relu(out)
43
+
44
+ out = self.conv2(out)
45
+ out = self.bn2(out)
46
+
47
+ if self.downsample is not None:
48
+ residual = self.downsample(x)
49
+
50
+ out += residual
51
+ out = self.relu(out)
52
+
53
+ return out
54
+
55
+
56
+ class Bottleneck(nn.Module):
57
+ expansion = 4
58
+
59
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
60
+ super(Bottleneck, self).__init__()
61
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
62
+ self.bn1 = BatchNorm2d(planes)
63
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
64
+ padding=1, bias=False)
65
+ self.bn2 = BatchNorm2d(planes)
66
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
67
+ self.bn3 = BatchNorm2d(planes * 4)
68
+ self.relu = nn.ReLU(inplace=True)
69
+ self.downsample = downsample
70
+ self.stride = stride
71
+
72
+ def forward(self, x):
73
+ residual = x
74
+
75
+ out = self.conv1(x)
76
+ out = self.bn1(out)
77
+ out = self.relu(out)
78
+
79
+ out = self.conv2(out)
80
+ out = self.bn2(out)
81
+ out = self.relu(out)
82
+
83
+ out = self.conv3(out)
84
+ out = self.bn3(out)
85
+
86
+ if self.downsample is not None:
87
+ residual = self.downsample(x)
88
+
89
+ out += residual
90
+ out = self.relu(out)
91
+
92
+ return out
93
+
94
+
95
+ class ResNet(nn.Module):
96
+
97
+ def __init__(self, block, layers, num_classes=1000):
98
+ self.inplanes = 128
99
+ super(ResNet, self).__init__()
100
+ self.conv1 = conv3x3(3, 64, stride=2)
101
+ self.bn1 = BatchNorm2d(64)
102
+ self.relu1 = nn.ReLU(inplace=True)
103
+ self.conv2 = conv3x3(64, 64)
104
+ self.bn2 = BatchNorm2d(64)
105
+ self.relu2 = nn.ReLU(inplace=True)
106
+ self.conv3 = conv3x3(64, 128)
107
+ self.bn3 = BatchNorm2d(128)
108
+ self.relu3 = nn.ReLU(inplace=True)
109
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
110
+
111
+ self.layer1 = self._make_layer(block, 64, layers[0])
112
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
113
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
114
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
115
+ self.avgpool = nn.AvgPool2d(7, stride=1)
116
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
117
+
118
+ for m in self.modules():
119
+ if isinstance(m, nn.Conv2d):
120
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
121
+ m.weight.data.normal_(0, math.sqrt(2. / n))
122
+ elif isinstance(m, BatchNorm2d):
123
+ m.weight.data.fill_(1)
124
+ m.bias.data.zero_()
125
+
126
+ def _make_layer(self, block, planes, blocks, stride=1):
127
+ downsample = None
128
+ if stride != 1 or self.inplanes != planes * block.expansion:
129
+ downsample = nn.Sequential(
130
+ nn.Conv2d(self.inplanes, planes * block.expansion,
131
+ kernel_size=1, stride=stride, bias=False),
132
+ BatchNorm2d(planes * block.expansion),
133
+ )
134
+
135
+ layers = []
136
+ layers.append(block(self.inplanes, planes, stride, downsample))
137
+ self.inplanes = planes * block.expansion
138
+ for i in range(1, blocks):
139
+ layers.append(block(self.inplanes, planes))
140
+
141
+ return nn.Sequential(*layers)
142
+
143
+ def forward(self, x):
144
+ x = self.relu1(self.bn1(self.conv1(x)))
145
+ x = self.relu2(self.bn2(self.conv2(x)))
146
+ x = self.relu3(self.bn3(self.conv3(x)))
147
+ x = self.maxpool(x)
148
+
149
+ x = self.layer1(x)
150
+ x = self.layer2(x)
151
+ x = self.layer3(x)
152
+ x = self.layer4(x)
153
+
154
+ x = self.avgpool(x)
155
+ x = x.view(x.size(0), -1)
156
+ x = self.fc(x)
157
+
158
+ return x
159
+
160
+
161
+ def resnet50(pretrained=False, **kwargs):
162
+ """Constructs a ResNet-50 model.
163
+
164
+ Args:
165
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
166
+ """
167
+ model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
168
+ if pretrained:
169
+ model.load_state_dict(load_url(model_urls['resnet50']), strict=False)
170
+ return model
171
+
172
+
173
+ def resnet18(pretrained=False, **kwargs):
174
+ """Constructs a ResNet-18 model.
175
+ Args:
176
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
177
+ """
178
+ model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
179
+ if pretrained:
180
+ model.load_state_dict(load_url(model_urls['resnet18']))
181
+ return model
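A minimal usage sketch for the ResNet defined above (not part of the commit; it assumes the repository root is on PYTHONPATH):

```python
# Hedged sketch: this variant uses a deep stem (three 3x3 convolutions instead
# of a single 7x7), but the call signature follows torchvision-style constructors.
import torch
from models.ade20k.resnet import resnet50

model = resnet50(pretrained=False, num_classes=1000)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```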
models/ade20k/segm_lib/nn/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .modules import *
2
+ from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
models/ade20k/segm_lib/nn/modules/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : __init__.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12
+ from .replicate import DataParallelWithCallback, patch_replication_callback
models/ade20k/segm_lib/nn/modules/batchnorm.py ADDED
@@ -0,0 +1,329 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : batchnorm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import collections
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ from torch.nn.modules.batchnorm import _BatchNorm
17
+ from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18
+
19
+ from .comm import SyncMaster
20
+
21
+ __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22
+
23
+
24
+ def _sum_ft(tensor):
25
+ """sum over the first and last dimention"""
26
+ return tensor.sum(dim=0).sum(dim=-1)
27
+
28
+
29
+ def _unsqueeze_ft(tensor):
30
+ """add new dementions at the front and the tail"""
31
+ return tensor.unsqueeze(0).unsqueeze(-1)
32
+
33
+
34
+ _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35
+ _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36
+
37
+
38
+ class _SynchronizedBatchNorm(_BatchNorm):
39
+ def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True):
40
+ super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41
+
42
+ self._sync_master = SyncMaster(self._data_parallel_master)
43
+
44
+ self._is_parallel = False
45
+ self._parallel_id = None
46
+ self._slave_pipe = None
47
+
48
+ # custom batch norm statistics
49
+ self._moving_average_fraction = 1. - momentum
50
+ self.register_buffer('_tmp_running_mean', torch.zeros(self.num_features))
51
+ self.register_buffer('_tmp_running_var', torch.ones(self.num_features))
52
+ self.register_buffer('_running_iter', torch.ones(1))
53
+ self._tmp_running_mean = self.running_mean.clone() * self._running_iter
54
+ self._tmp_running_var = self.running_var.clone() * self._running_iter
55
+
56
+ def forward(self, input):
57
+ # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
58
+ if not (self._is_parallel and self.training):
59
+ return F.batch_norm(
60
+ input, self.running_mean, self.running_var, self.weight, self.bias,
61
+ self.training, self.momentum, self.eps)
62
+
63
+ # Resize the input to (B, C, -1).
64
+ input_shape = input.size()
65
+ input = input.view(input.size(0), self.num_features, -1)
66
+
67
+ # Compute the sum and square-sum.
68
+ sum_size = input.size(0) * input.size(2)
69
+ input_sum = _sum_ft(input)
70
+ input_ssum = _sum_ft(input ** 2)
71
+
72
+ # Reduce-and-broadcast the statistics.
73
+ if self._parallel_id == 0:
74
+ mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
75
+ else:
76
+ mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
77
+
78
+ # Compute the output.
79
+ if self.affine:
80
+ # MJY:: Fuse the multiplication for speed.
81
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
82
+ else:
83
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
84
+
85
+ # Reshape it.
86
+ return output.view(input_shape)
87
+
88
+ def __data_parallel_replicate__(self, ctx, copy_id):
89
+ self._is_parallel = True
90
+ self._parallel_id = copy_id
91
+
92
+ # parallel_id == 0 means master device.
93
+ if self._parallel_id == 0:
94
+ ctx.sync_master = self._sync_master
95
+ else:
96
+ self._slave_pipe = ctx.sync_master.register_slave(copy_id)
97
+
98
+ def _data_parallel_master(self, intermediates):
99
+ """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
100
+ intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
101
+
102
+ to_reduce = [i[1][:2] for i in intermediates]
103
+ to_reduce = [j for i in to_reduce for j in i] # flatten
104
+ target_gpus = [i[1].sum.get_device() for i in intermediates]
105
+
106
+ sum_size = sum([i[1].sum_size for i in intermediates])
107
+ sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
108
+
109
+ mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
110
+
111
+ broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
112
+
113
+ outputs = []
114
+ for i, rec in enumerate(intermediates):
115
+ outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
116
+
117
+ return outputs
118
+
119
+ def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0):
120
+ """return *dest* by `dest := dest*alpha + delta*beta + bias`"""
121
+ return dest * alpha + delta * beta + bias
122
+
123
+ def _compute_mean_std(self, sum_, ssum, size):
124
+ """Compute the mean and standard-deviation with sum and square-sum. This method
125
+ also maintains the moving average on the master device."""
126
+ assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
127
+ mean = sum_ / size
128
+ sumvar = ssum - sum_ * mean
129
+ unbias_var = sumvar / (size - 1)
130
+ bias_var = sumvar / size
131
+
132
+ self._tmp_running_mean = self._add_weighted(self._tmp_running_mean, mean.data, alpha=self._moving_average_fraction)
133
+ self._tmp_running_var = self._add_weighted(self._tmp_running_var, unbias_var.data, alpha=self._moving_average_fraction)
134
+ self._running_iter = self._add_weighted(self._running_iter, 1, alpha=self._moving_average_fraction)
135
+
136
+ self.running_mean = self._tmp_running_mean / self._running_iter
137
+ self.running_var = self._tmp_running_var / self._running_iter
138
+
139
+ return mean, bias_var.clamp(self.eps) ** -0.5
140
+
141
+
142
+ class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
143
+ r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
144
+ mini-batch.
145
+
146
+ .. math::
147
+
148
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
149
+
150
+ This module differs from the built-in PyTorch BatchNorm1d as the mean and
151
+ standard-deviation are reduced across all devices during training.
152
+
153
+ For example, when one uses `nn.DataParallel` to wrap the network during
154
+ training, PyTorch's implementation normalizes the tensor on each device using
155
+ the statistics only on that device, which accelerates the computation and
156
+ is also easy to implement, but the statistics might be inaccurate.
157
+ Instead, in this synchronized version, the statistics will be computed
158
+ over all training samples distributed on multiple devices.
159
+
160
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
161
+ as the built-in PyTorch implementation.
162
+
163
+ The mean and standard-deviation are calculated per-dimension over
164
+ the mini-batches and gamma and beta are learnable parameter vectors
165
+ of size C (where C is the input size).
166
+
167
+ During training, this layer keeps a running estimate of its computed mean
168
+ and variance. The running sum is kept with a default momentum of 0.1.
169
+
170
+ During evaluation, this running mean/variance is used for normalization.
171
+
172
+ Because the BatchNorm is done over the `C` dimension, computing statistics
173
+ on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
174
+
175
+ Args:
176
+ num_features: num_features from an expected input of size
177
+ `batch_size x num_features [x width]`
178
+ eps: a value added to the denominator for numerical stability.
179
+ Default: 1e-5
180
+ momentum: the value used for the running_mean and running_var
181
+ computation. Default: 0.1
182
+ affine: a boolean value that when set to ``True``, gives the layer learnable
183
+ affine parameters. Default: ``True``
184
+
185
+ Shape:
186
+ - Input: :math:`(N, C)` or :math:`(N, C, L)`
187
+ - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
188
+
189
+ Examples:
190
+ >>> # With Learnable Parameters
191
+ >>> m = SynchronizedBatchNorm1d(100)
192
+ >>> # Without Learnable Parameters
193
+ >>> m = SynchronizedBatchNorm1d(100, affine=False)
194
+ >>> input = torch.autograd.Variable(torch.randn(20, 100))
195
+ >>> output = m(input)
196
+ """
197
+
198
+ def _check_input_dim(self, input):
199
+ if input.dim() != 2 and input.dim() != 3:
200
+ raise ValueError('expected 2D or 3D input (got {}D input)'
201
+ .format(input.dim()))
202
+ super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
203
+
204
+
205
+ class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
206
+ r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
207
+ of 3d inputs
208
+
209
+ .. math::
210
+
211
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
212
+
213
+ This module differs from the built-in PyTorch BatchNorm2d as the mean and
214
+ standard-deviation are reduced across all devices during training.
215
+
216
+ For example, when one uses `nn.DataParallel` to wrap the network during
217
+ training, PyTorch's implementation normalizes the tensor on each device using
218
+ the statistics only on that device, which accelerates the computation and
219
+ is also easy to implement, but the statistics might be inaccurate.
220
+ Instead, in this synchronized version, the statistics will be computed
221
+ over all training samples distributed on multiple devices.
222
+
223
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
224
+ as the built-in PyTorch implementation.
225
+
226
+ The mean and standard-deviation are calculated per-dimension over
227
+ the mini-batches and gamma and beta are learnable parameter vectors
228
+ of size C (where C is the input size).
229
+
230
+ During training, this layer keeps a running estimate of its computed mean
231
+ and variance. The running sum is kept with a default momentum of 0.1.
232
+
233
+ During evaluation, this running mean/variance is used for normalization.
234
+
235
+ Because the BatchNorm is done over the `C` dimension, computing statistics
236
+ on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
237
+
238
+ Args:
239
+ num_features: num_features from an expected input of
240
+ size batch_size x num_features x height x width
241
+ eps: a value added to the denominator for numerical stability.
242
+ Default: 1e-5
243
+ momentum: the value used for the running_mean and running_var
244
+ computation. Default: 0.1
245
+ affine: a boolean value that when set to ``True``, gives the layer learnable
246
+ affine parameters. Default: ``True``
247
+
248
+ Shape:
249
+ - Input: :math:`(N, C, H, W)`
250
+ - Output: :math:`(N, C, H, W)` (same shape as input)
251
+
252
+ Examples:
253
+ >>> # With Learnable Parameters
254
+ >>> m = SynchronizedBatchNorm2d(100)
255
+ >>> # Without Learnable Parameters
256
+ >>> m = SynchronizedBatchNorm2d(100, affine=False)
257
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
258
+ >>> output = m(input)
259
+ """
260
+
261
+ def _check_input_dim(self, input):
262
+ if input.dim() != 4:
263
+ raise ValueError('expected 4D input (got {}D input)'
264
+ .format(input.dim()))
265
+ super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
266
+
267
+
268
+ class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
269
+ r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
270
+ of 4d inputs
271
+
272
+ .. math::
273
+
274
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
275
+
276
+ This module differs from the built-in PyTorch BatchNorm3d as the mean and
277
+ standard-deviation are reduced across all devices during training.
278
+
279
+ For example, when one uses `nn.DataParallel` to wrap the network during
280
+ training, PyTorch's implementation normalizes the tensor on each device using
281
+ the statistics only on that device, which accelerates the computation and
282
+ is also easy to implement, but the statistics might be inaccurate.
283
+ Instead, in this synchronized version, the statistics will be computed
284
+ over all training samples distributed on multiple devices.
285
+
286
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
287
+ as the built-in PyTorch implementation.
288
+
289
+ The mean and standard-deviation are calculated per-dimension over
290
+ the mini-batches and gamma and beta are learnable parameter vectors
291
+ of size C (where C is the input size).
292
+
293
+ During training, this layer keeps a running estimate of its computed mean
294
+ and variance. The running sum is kept with a default momentum of 0.1.
295
+
296
+ During evaluation, this running mean/variance is used for normalization.
297
+
298
+ Because the BatchNorm is done over the `C` dimension, computing statistics
299
+ on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
300
+ or Spatio-temporal BatchNorm
301
+
302
+ Args:
303
+ num_features: num_features from an expected input of
304
+ size batch_size x num_features x depth x height x width
305
+ eps: a value added to the denominator for numerical stability.
306
+ Default: 1e-5
307
+ momentum: the value used for the running_mean and running_var
308
+ computation. Default: 0.1
309
+ affine: a boolean value that when set to ``True``, gives the layer learnable
310
+ affine parameters. Default: ``True``
311
+
312
+ Shape:
313
+ - Input: :math:`(N, C, D, H, W)`
314
+ - Output: :math:`(N, C, D, H, W)` (same shape as input)
315
+
316
+ Examples:
317
+ >>> # With Learnable Parameters
318
+ >>> m = SynchronizedBatchNorm3d(100)
319
+ >>> # Without Learnable Parameters
320
+ >>> m = SynchronizedBatchNorm3d(100, affine=False)
321
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
322
+ >>> output = m(input)
323
+ """
324
+
325
+ def _check_input_dim(self, input):
326
+ if input.dim() != 5:
327
+ raise ValueError('expected 5D input (got {}D input)'
328
+ .format(input.dim()))
329
+ super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
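A quick single-device sanity check for the modules above (not part of the commit). On one device the forward path falls back to `F.batch_norm`, so the layer behaves like `nn.BatchNorm2d`; synchronized statistics only kick in when the model is wrapped with `DataParallelWithCallback` from `replicate.py`:

```python
# Hedged sketch, CPU-only path; assumes the repository root is on PYTHONPATH.
import torch
from models.ade20k.segm_lib.nn.modules.batchnorm import SynchronizedBatchNorm2d

sync_bn = SynchronizedBatchNorm2d(16)  # note: momentum defaults to 0.001 here
x = torch.randn(4, 16, 8, 8)
y = sync_bn(x)                         # same semantics as nn.BatchNorm2d on one device
print(y.shape)                         # torch.Size([4, 16, 8, 8])
```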
models/ade20k/segm_lib/nn/modules/comm.py ADDED
@@ -0,0 +1,131 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : comm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import queue
12
+ import collections
13
+ import threading
14
+
15
+ __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16
+
17
+
18
+ class FutureResult(object):
19
+ """A thread-safe future implementation. Used only as one-to-one pipe."""
20
+
21
+ def __init__(self):
22
+ self._result = None
23
+ self._lock = threading.Lock()
24
+ self._cond = threading.Condition(self._lock)
25
+
26
+ def put(self, result):
27
+ with self._lock:
28
+ assert self._result is None, 'Previous result hasn\'t been fetched.'
29
+ self._result = result
30
+ self._cond.notify()
31
+
32
+ def get(self):
33
+ with self._lock:
34
+ if self._result is None:
35
+ self._cond.wait()
36
+
37
+ res = self._result
38
+ self._result = None
39
+ return res
40
+
41
+
42
+ _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43
+ _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44
+
45
+
46
+ class SlavePipe(_SlavePipeBase):
47
+ """Pipe for master-slave communication."""
48
+
49
+ def run_slave(self, msg):
50
+ self.queue.put((self.identifier, msg))
51
+ ret = self.result.get()
52
+ self.queue.put(True)
53
+ return ret
54
+
55
+
56
+ class SyncMaster(object):
57
+ """An abstract `SyncMaster` object.
58
+
59
+ - During the replication, as the data parallel will trigger a callback on each module, all slave devices should
60
+ call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
61
+ - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
62
+ and passed to a registered callback.
63
+ - After receiving the messages, the master device should gather the information and determine the message to be passed
64
+ back to each slave device.
65
+ """
66
+
67
+ def __init__(self, master_callback):
68
+ """
69
+
70
+ Args:
71
+ master_callback: a callback to be invoked after having collected messages from slave devices.
72
+ """
73
+ self._master_callback = master_callback
74
+ self._queue = queue.Queue()
75
+ self._registry = collections.OrderedDict()
76
+ self._activated = False
77
+
78
+ def register_slave(self, identifier):
79
+ """
80
+ Register a slave device.
81
+
82
+ Args:
83
+ identifier: an identifier, usually the device id.
84
+
85
+ Returns: a `SlavePipe` object which can be used to communicate with the master device.
86
+
87
+ """
88
+ if self._activated:
89
+ assert self._queue.empty(), 'Queue is not clean before next initialization.'
90
+ self._activated = False
91
+ self._registry.clear()
92
+ future = FutureResult()
93
+ self._registry[identifier] = _MasterRegistry(future)
94
+ return SlavePipe(identifier, self._queue, future)
95
+
96
+ def run_master(self, master_msg):
97
+ """
98
+ Main entry for the master device in each forward pass.
99
+ The messages are first collected from each device (including the master device), and then
100
+ a callback will be invoked to compute the message to be sent back to each device
101
+ (including the master device).
102
+
103
+ Args:
104
+ master_msg: the message that the master wants to send to itself. This will be placed as the first
105
+ message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106
+
107
+ Returns: the message to be sent back to the master device.
108
+
109
+ """
110
+ self._activated = True
111
+
112
+ intermediates = [(0, master_msg)]
113
+ for i in range(self.nr_slaves):
114
+ intermediates.append(self._queue.get())
115
+
116
+ results = self._master_callback(intermediates)
117
+ assert results[0][0] == 0, 'The first result should belong to the master.'
118
+
119
+ for i, res in results:
120
+ if i == 0:
121
+ continue
122
+ self._registry[i].result.put(res)
123
+
124
+ for i in range(self.nr_slaves):
125
+ assert self._queue.get() is True
126
+
127
+ return results[0][1]
128
+
129
+ @property
130
+ def nr_slaves(self):
131
+ return len(self._registry)
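A tiny illustration of the `SyncMaster` protocol (not part of the commit; assumes the repository root is on PYTHONPATH). With no registered slaves, `run_master` simply feeds the master's own message to the callback and returns the first reply, which makes the data flow easy to see in isolation:

```python
# Hedged sketch: single-process, zero slaves registered.
from models.ade20k.segm_lib.nn.modules.comm import SyncMaster

def callback(intermediates):
    # intermediates is a list of (copy_id, message) pairs, master (id 0) first
    return [(copy_id, msg * 2) for copy_id, msg in intermediates]

master = SyncMaster(callback)
print(master.run_master(21))  # 42
```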
models/ade20k/segm_lib/nn/modules/replicate.py ADDED
@@ -0,0 +1,94 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : replicate.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import functools
12
+
13
+ from torch.nn.parallel.data_parallel import DataParallel
14
+
15
+ __all__ = [
16
+ 'CallbackContext',
17
+ 'execute_replication_callbacks',
18
+ 'DataParallelWithCallback',
19
+ 'patch_replication_callback'
20
+ ]
21
+
22
+
23
+ class CallbackContext(object):
24
+ pass
25
+
26
+
27
+ def execute_replication_callbacks(modules):
28
+ """
29
+ Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.
30
+
31
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32
+
33
+ Note that, as all modules are isomorphism, we assign each sub-module with a context
34
+ (shared among multiple copies of this module on different devices).
35
+ Through this context, different copies can share some information.
36
+
37
+ We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
38
+ of any slave copies.
39
+ """
40
+ master_copy = modules[0]
41
+ nr_modules = len(list(master_copy.modules()))
42
+ ctxs = [CallbackContext() for _ in range(nr_modules)]
43
+
44
+ for i, module in enumerate(modules):
45
+ for j, m in enumerate(module.modules()):
46
+ if hasattr(m, '__data_parallel_replicate__'):
47
+ m.__data_parallel_replicate__(ctxs[j], i)
48
+
49
+
50
+ class DataParallelWithCallback(DataParallel):
51
+ """
52
+ Data Parallel with a replication callback.
53
+
54
+ An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
55
+ original `replicate` function.
56
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57
+
58
+ Examples:
59
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61
+ # sync_bn.__data_parallel_replicate__ will be invoked.
62
+ """
63
+
64
+ def replicate(self, module, device_ids):
65
+ modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66
+ execute_replication_callbacks(modules)
67
+ return modules
68
+
69
+
70
+ def patch_replication_callback(data_parallel):
71
+ """
72
+ Monkey-patch an existing `DataParallel` object. Add the replication callback.
73
+ Useful when you have customized `DataParallel` implementation.
74
+
75
+ Examples:
76
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77
+ > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78
+ > patch_replication_callback(sync_bn)
79
+ # this is equivalent to
80
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82
+ """
83
+
84
+ assert isinstance(data_parallel, DataParallel)
85
+
86
+ old_replicate = data_parallel.replicate
87
+
88
+ @functools.wraps(old_replicate)
89
+ def new_replicate(module, device_ids):
90
+ modules = old_replicate(module, device_ids)
91
+ execute_replication_callbacks(modules)
92
+ return modules
93
+
94
+ data_parallel.replicate = new_replicate
models/ade20k/segm_lib/nn/modules/tests/test_numeric_batchnorm.py ADDED
@@ -0,0 +1,56 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : test_numeric_batchnorm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+
9
+ import unittest
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.autograd import Variable
14
+
15
+ from sync_batchnorm.unittest import TorchTestCase
16
+
17
+
18
+ def handy_var(a, unbias=True):
19
+ n = a.size(0)
20
+ asum = a.sum(dim=0)
21
+ as_sum = (a ** 2).sum(dim=0) # a square sum
22
+ sumvar = as_sum - asum * asum / n
23
+ if unbias:
24
+ return sumvar / (n - 1)
25
+ else:
26
+ return sumvar / n
27
+
28
+
29
+ class NumericTestCase(TorchTestCase):
30
+ def testNumericBatchNorm(self):
31
+ a = torch.rand(16, 10)
32
+ bn = nn.BatchNorm2d(10, momentum=1, eps=1e-5, affine=False)
33
+ bn.train()
34
+
35
+ a_var1 = Variable(a, requires_grad=True)
36
+ b_var1 = bn(a_var1)
37
+ loss1 = b_var1.sum()
38
+ loss1.backward()
39
+
40
+ a_var2 = Variable(a, requires_grad=True)
41
+ a_mean2 = a_var2.mean(dim=0, keepdim=True)
42
+ a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5))
43
+ # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5)
44
+ b_var2 = (a_var2 - a_mean2) / a_std2
45
+ loss2 = b_var2.sum()
46
+ loss2.backward()
47
+
48
+ self.assertTensorClose(bn.running_mean, a.mean(dim=0))
49
+ self.assertTensorClose(bn.running_var, handy_var(a))
50
+ self.assertTensorClose(a_var1.data, a_var2.data)
51
+ self.assertTensorClose(b_var1.data, b_var2.data)
52
+ self.assertTensorClose(a_var1.grad, a_var2.grad)
53
+
54
+
55
+ if __name__ == '__main__':
56
+ unittest.main()
models/ade20k/segm_lib/nn/modules/tests/test_sync_batchnorm.py ADDED
@@ -0,0 +1,111 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : test_sync_batchnorm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+
9
+ import unittest
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.autograd import Variable
14
+
15
+ from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback
16
+ from sync_batchnorm.unittest import TorchTestCase
17
+
18
+
19
+ def handy_var(a, unbias=True):
20
+ n = a.size(0)
21
+ asum = a.sum(dim=0)
22
+ as_sum = (a ** 2).sum(dim=0) # a square sum
23
+ sumvar = as_sum - asum * asum / n
24
+ if unbias:
25
+ return sumvar / (n - 1)
26
+ else:
27
+ return sumvar / n
28
+
29
+
30
+ def _find_bn(module):
31
+ for m in module.modules():
32
+ if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)):
33
+ return m
34
+
35
+
36
+ class SyncTestCase(TorchTestCase):
37
+ def _syncParameters(self, bn1, bn2):
38
+ bn1.reset_parameters()
39
+ bn2.reset_parameters()
40
+ if bn1.affine and bn2.affine:
41
+ bn2.weight.data.copy_(bn1.weight.data)
42
+ bn2.bias.data.copy_(bn1.bias.data)
43
+
44
+ def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False):
45
+ """Check the forward and backward for the customized batch normalization."""
46
+ bn1.train(mode=is_train)
47
+ bn2.train(mode=is_train)
48
+
49
+ if cuda:
50
+ input = input.cuda()
51
+
52
+ self._syncParameters(_find_bn(bn1), _find_bn(bn2))
53
+
54
+ input1 = Variable(input, requires_grad=True)
55
+ output1 = bn1(input1)
56
+ output1.sum().backward()
57
+ input2 = Variable(input, requires_grad=True)
58
+ output2 = bn2(input2)
59
+ output2.sum().backward()
60
+
61
+ self.assertTensorClose(input1.data, input2.data)
62
+ self.assertTensorClose(output1.data, output2.data)
63
+ self.assertTensorClose(input1.grad, input2.grad)
64
+ self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
65
+ self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var)
66
+
67
+ def testSyncBatchNormNormalTrain(self):
68
+ bn = nn.BatchNorm1d(10)
69
+ sync_bn = SynchronizedBatchNorm1d(10)
70
+
71
+ self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True)
72
+
73
+ def testSyncBatchNormNormalEval(self):
74
+ bn = nn.BatchNorm1d(10)
75
+ sync_bn = SynchronizedBatchNorm1d(10)
76
+
77
+ self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False)
78
+
79
+ def testSyncBatchNormSyncTrain(self):
80
+ bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
81
+ sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
82
+ sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
83
+
84
+ bn.cuda()
85
+ sync_bn.cuda()
86
+
87
+ self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True)
88
+
89
+ def testSyncBatchNormSyncEval(self):
90
+ bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
91
+ sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
92
+ sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
93
+
94
+ bn.cuda()
95
+ sync_bn.cuda()
96
+
97
+ self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True)
98
+
99
+ def testSyncBatchNorm2DSyncTrain(self):
100
+ bn = nn.BatchNorm2d(10)
101
+ sync_bn = SynchronizedBatchNorm2d(10)
102
+ sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
103
+
104
+ bn.cuda()
105
+ sync_bn.cuda()
106
+
107
+ self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True)
108
+
109
+
110
+ if __name__ == '__main__':
111
+ unittest.main()
models/ade20k/segm_lib/nn/modules/unittest.py ADDED
@@ -0,0 +1,29 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : unittest.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import unittest
12
+
13
+ import numpy as np
14
+ from torch.autograd import Variable
15
+
16
+
17
+ def as_numpy(v):
18
+ if isinstance(v, Variable):
19
+ v = v.data
20
+ return v.cpu().numpy()
21
+
22
+
23
+ class TorchTestCase(unittest.TestCase):
24
+ def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25
+ npa, npb = as_numpy(a), as_numpy(b)
26
+ self.assertTrue(
27
+ np.allclose(npa, npb, atol=atol),
28
+ 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29
+ )
models/ade20k/segm_lib/nn/parallel/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
models/ade20k/segm_lib/nn/parallel/data_parallel.py ADDED
@@ -0,0 +1,112 @@
1
+ # -*- coding: utf8 -*-
2
+
3
+ import torch.cuda as cuda
4
+ import torch.nn as nn
5
+ import torch
6
+ import collections
7
+ from torch.nn.parallel._functions import Gather
8
+
9
+
10
+ __all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to']
11
+
12
+
13
+ def async_copy_to(obj, dev, main_stream=None):
14
+ if torch.is_tensor(obj):
15
+ v = obj.cuda(dev, non_blocking=True)
16
+ if main_stream is not None:
17
+ v.data.record_stream(main_stream)
18
+ return v
19
+ elif isinstance(obj, collections.Mapping):
20
+ return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
21
+ elif isinstance(obj, collections.Sequence):
22
+ return [async_copy_to(o, dev, main_stream) for o in obj]
23
+ else:
24
+ return obj
25
+
26
+
27
+ def dict_gather(outputs, target_device, dim=0):
28
+ """
29
+ Gathers variables from different GPUs on a specified device
30
+ (-1 means the CPU), with dictionary support.
31
+ """
32
+ def gather_map(outputs):
33
+ out = outputs[0]
34
+ if torch.is_tensor(out):
35
+ # MJY(20180330) HACK:: force nr_dims > 0
36
+ if out.dim() == 0:
37
+ outputs = [o.unsqueeze(0) for o in outputs]
38
+ return Gather.apply(target_device, dim, *outputs)
39
+ elif out is None:
40
+ return None
41
+ elif isinstance(out, collections.Mapping):
42
+ return {k: gather_map([o[k] for o in outputs]) for k in out}
43
+ elif isinstance(out, collections.Sequence):
44
+ return type(out)(map(gather_map, zip(*outputs)))
45
+ return gather_map(outputs)
46
+
47
+
48
+ class DictGatherDataParallel(nn.DataParallel):
49
+ def gather(self, outputs, output_device):
50
+ return dict_gather(outputs, output_device, dim=self.dim)
51
+
52
+
53
+ class UserScatteredDataParallel(DictGatherDataParallel):
54
+ def scatter(self, inputs, kwargs, device_ids):
55
+ assert len(inputs) == 1
56
+ inputs = inputs[0]
57
+ inputs = _async_copy_stream(inputs, device_ids)
58
+ inputs = [[i] for i in inputs]
59
+ assert len(kwargs) == 0
60
+ kwargs = [{} for _ in range(len(inputs))]
61
+
62
+ return inputs, kwargs
63
+
64
+
65
+ def user_scattered_collate(batch):
66
+ return batch
67
+
68
+
69
+ def _async_copy(inputs, device_ids):
70
+ nr_devs = len(device_ids)
71
+ assert type(inputs) in (tuple, list)
72
+ assert len(inputs) == nr_devs
73
+
74
+ outputs = []
75
+ for i, dev in zip(inputs, device_ids):
76
+ with cuda.device(dev):
77
+ outputs.append(async_copy_to(i, dev))
78
+
79
+ return tuple(outputs)
80
+
81
+
82
+ def _async_copy_stream(inputs, device_ids):
83
+ nr_devs = len(device_ids)
84
+ assert type(inputs) in (tuple, list)
85
+ assert len(inputs) == nr_devs
86
+
87
+ outputs = []
88
+ streams = [_get_stream(d) for d in device_ids]
89
+ for i, dev, stream in zip(inputs, device_ids, streams):
90
+ with cuda.device(dev):
91
+ main_stream = cuda.current_stream()
92
+ with cuda.stream(stream):
93
+ outputs.append(async_copy_to(i, dev, main_stream=main_stream))
94
+ main_stream.wait_stream(stream)
95
+
96
+ return outputs
97
+
98
+
99
+ """Adapted from: torch/nn/parallel/_functions.py"""
100
+ # background streams used for copying
101
+ _streams = None
102
+
103
+
104
+ def _get_stream(device):
105
+ """Gets a background stream for copying between CPU and GPU"""
106
+ global _streams
107
+ if device == -1:
108
+ return None
109
+ if _streams is None:
110
+ _streams = [None] * cuda.device_count()
111
+ if _streams[device] is None: _streams[device] = cuda.Stream(device)
112
+ return _streams[device]
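A hedged sketch of the intended multi-GPU flow (not part of the commit; it assumes at least two CUDA devices and the older PyTorch/Python versions this vendored code targets, since `async_copy_to` still relies on `collections.Mapping`). The key idea: `user_scattered_collate` keeps the batch as a plain list, and `UserScatteredDataParallel.scatter` routes one list element, whole, to each replica instead of splitting tensors along dim 0:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from models.ade20k.segm_lib.nn.parallel.data_parallel import (
    UserScatteredDataParallel, user_scattered_collate)

class Head(nn.Module):
    def __init__(self):
        super(Head, self).__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, feed):
        # each replica receives one user-prepared dict, untouched
        return self.fc(feed['img'])

loader = DataLoader(
    [{'img': torch.randn(4, 8)} for _ in range(16)],
    batch_size=2,                       # one dict per GPU
    collate_fn=user_scattered_collate)  # keep samples as a plain list

net = UserScatteredDataParallel(Head().cuda(), device_ids=[0, 1])
for batch in loader:
    out = net(batch)    # scatter() async-copies each dict onto its own device
    print(out.shape)    # gathered back on device 0: torch.Size([8, 2])
    break
```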
models/ade20k/segm_lib/utils/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .th import *
models/ade20k/segm_lib/utils/data/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+
2
+ from .dataset import Dataset, TensorDataset, ConcatDataset
3
+ from .dataloader import DataLoader
models/ade20k/segm_lib/utils/data/dataloader.py ADDED
@@ -0,0 +1,425 @@
1
+ import torch
2
+ import torch.multiprocessing as multiprocessing
3
+ from torch._C import _set_worker_signal_handlers, \
4
+ _remove_worker_pids, _error_if_any_worker_fails
5
+ try:
6
+ from torch._C import _set_worker_pids
7
+ except:
8
+ from torch._C import _update_worker_pids as _set_worker_pids
9
+ from .sampler import SequentialSampler, RandomSampler, BatchSampler
10
+ import signal
11
+ import collections
12
+ import re
13
+ import sys
14
+ import threading
15
+ import traceback
16
+ from torch._six import string_classes, int_classes
17
+ import numpy as np
18
+
19
+ if sys.version_info[0] == 2:
20
+ import Queue as queue
21
+ else:
22
+ import queue
23
+
24
+
25
+ class ExceptionWrapper(object):
26
+ r"Wraps an exception plus traceback to communicate across threads"
27
+
28
+ def __init__(self, exc_info):
29
+ self.exc_type = exc_info[0]
30
+ self.exc_msg = "".join(traceback.format_exception(*exc_info))
31
+
32
+
33
+ _use_shared_memory = False
34
+ """Whether to use shared memory in default_collate"""
35
+
36
+
37
+ def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
38
+ global _use_shared_memory
39
+ _use_shared_memory = True
40
+
41
+ # Initialize C-side signal handlers for SIGBUS and SIGSEGV. Python signal
42
+ # module's handlers are executed after Python returns from C low-level
43
+ # handlers, likely when the same fatal signal happened again already.
44
+ # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
45
+ _set_worker_signal_handlers()
46
+
47
+ torch.set_num_threads(1)
48
+ torch.manual_seed(seed)
49
+ np.random.seed(seed)
50
+
51
+ if init_fn is not None:
52
+ init_fn(worker_id)
53
+
54
+ while True:
55
+ r = index_queue.get()
56
+ if r is None:
57
+ break
58
+ idx, batch_indices = r
59
+ try:
60
+ samples = collate_fn([dataset[i] for i in batch_indices])
61
+ except Exception:
62
+ data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
63
+ else:
64
+ data_queue.put((idx, samples))
65
+
66
+
67
+ def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id):
68
+ if pin_memory:
69
+ torch.cuda.set_device(device_id)
70
+
71
+ while True:
72
+ try:
73
+ r = in_queue.get()
74
+ except Exception:
75
+ if done_event.is_set():
76
+ return
77
+ raise
78
+ if r is None:
79
+ break
80
+ if isinstance(r[1], ExceptionWrapper):
81
+ out_queue.put(r)
82
+ continue
83
+ idx, batch = r
84
+ try:
85
+ if pin_memory:
86
+ batch = pin_memory_batch(batch)
87
+ except Exception:
88
+ out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
89
+ else:
90
+ out_queue.put((idx, batch))
91
+
92
+ numpy_type_map = {
93
+ 'float64': torch.DoubleTensor,
94
+ 'float32': torch.FloatTensor,
95
+ 'float16': torch.HalfTensor,
96
+ 'int64': torch.LongTensor,
97
+ 'int32': torch.IntTensor,
98
+ 'int16': torch.ShortTensor,
99
+ 'int8': torch.CharTensor,
100
+ 'uint8': torch.ByteTensor,
101
+ }
102
+
103
+
104
+ def default_collate(batch):
105
+ "Puts each data field into a tensor with outer dimension batch size"
106
+
107
+ error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
108
+ elem_type = type(batch[0])
109
+ if torch.is_tensor(batch[0]):
110
+ out = None
111
+ if _use_shared_memory:
112
+ # If we're in a background process, concatenate directly into a
113
+ # shared memory tensor to avoid an extra copy
114
+ numel = sum([x.numel() for x in batch])
115
+ storage = batch[0].storage()._new_shared(numel)
116
+ out = batch[0].new(storage)
117
+ return torch.stack(batch, 0, out=out)
118
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
119
+ and elem_type.__name__ != 'string_':
120
+ elem = batch[0]
121
+ if elem_type.__name__ == 'ndarray':
122
+ # array of string classes and object
123
+ if re.search('[SaUO]', elem.dtype.str) is not None:
124
+ raise TypeError(error_msg.format(elem.dtype))
125
+
126
+ return torch.stack([torch.from_numpy(b) for b in batch], 0)
127
+ if elem.shape == (): # scalars
128
+ py_type = float if elem.dtype.name.startswith('float') else int
129
+ return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
130
+ elif isinstance(batch[0], int_classes):
131
+ return torch.LongTensor(batch)
132
+ elif isinstance(batch[0], float):
133
+ return torch.DoubleTensor(batch)
134
+ elif isinstance(batch[0], string_classes):
135
+ return batch
136
+ elif isinstance(batch[0], collections.Mapping):
137
+ return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
138
+ elif isinstance(batch[0], collections.Sequence):
139
+ transposed = zip(*batch)
140
+ return [default_collate(samples) for samples in transposed]
141
+
142
+ raise TypeError((error_msg.format(type(batch[0]))))
143
+
144
+
145
+ def pin_memory_batch(batch):
146
+ if torch.is_tensor(batch):
147
+ return batch.pin_memory()
148
+ elif isinstance(batch, string_classes):
149
+ return batch
150
+ elif isinstance(batch, collections.Mapping):
151
+ return {k: pin_memory_batch(sample) for k, sample in batch.items()}
152
+ elif isinstance(batch, collections.Sequence):
153
+ return [pin_memory_batch(sample) for sample in batch]
154
+ else:
155
+ return batch
156
+
157
+
158
+ _SIGCHLD_handler_set = False
159
+ """Whether SIGCHLD handler is set for DataLoader worker failures. Only one
160
+ handler needs to be set for all DataLoaders in a process."""
161
+
162
+
163
+ def _set_SIGCHLD_handler():
164
+ # Windows doesn't support SIGCHLD handler
165
+ if sys.platform == 'win32':
166
+ return
167
+ # can't set signal in child threads
168
+ if not isinstance(threading.current_thread(), threading._MainThread):
169
+ return
170
+ global _SIGCHLD_handler_set
171
+ if _SIGCHLD_handler_set:
172
+ return
173
+ previous_handler = signal.getsignal(signal.SIGCHLD)
174
+ if not callable(previous_handler):
175
+ previous_handler = None
176
+
177
+ def handler(signum, frame):
178
+ # This following call uses `waitid` with WNOHANG from C side. Therefore,
179
+ # Python can still get and update the process status successfully.
180
+ _error_if_any_worker_fails()
181
+ if previous_handler is not None:
182
+ previous_handler(signum, frame)
183
+
184
+ signal.signal(signal.SIGCHLD, handler)
185
+ _SIGCHLD_handler_set = True
186
+
187
+
188
+ class DataLoaderIter(object):
189
+ "Iterates once over the DataLoader's dataset, as specified by the sampler"
190
+
191
+ def __init__(self, loader):
192
+ self.dataset = loader.dataset
193
+ self.collate_fn = loader.collate_fn
194
+ self.batch_sampler = loader.batch_sampler
195
+ self.num_workers = loader.num_workers
196
+ self.pin_memory = loader.pin_memory and torch.cuda.is_available()
197
+ self.timeout = loader.timeout
198
+ self.done_event = threading.Event()
199
+
200
+ self.sample_iter = iter(self.batch_sampler)
201
+
202
+ if self.num_workers > 0:
203
+ self.worker_init_fn = loader.worker_init_fn
204
+ self.index_queue = multiprocessing.SimpleQueue()
205
+ self.worker_result_queue = multiprocessing.SimpleQueue()
206
+ self.batches_outstanding = 0
207
+ self.worker_pids_set = False
208
+ self.shutdown = False
209
+ self.send_idx = 0
210
+ self.rcvd_idx = 0
211
+ self.reorder_dict = {}
212
+
213
+ base_seed = torch.LongTensor(1).random_(0, 2**31-1)[0]
214
+ self.workers = [
215
+ multiprocessing.Process(
216
+ target=_worker_loop,
217
+ args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
218
+ base_seed + i, self.worker_init_fn, i))
219
+ for i in range(self.num_workers)]
220
+
221
+ if self.pin_memory or self.timeout > 0:
222
+ self.data_queue = queue.Queue()
223
+ if self.pin_memory:
224
+ maybe_device_id = torch.cuda.current_device()
225
+ else:
226
+ # do not initialize cuda context if not necessary
227
+ maybe_device_id = None
228
+ self.worker_manager_thread = threading.Thread(
229
+ target=_worker_manager_loop,
230
+ args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
231
+ maybe_device_id))
232
+ self.worker_manager_thread.daemon = True
233
+ self.worker_manager_thread.start()
234
+ else:
235
+ self.data_queue = self.worker_result_queue
236
+
237
+ for w in self.workers:
238
+ w.daemon = True # ensure that the worker exits on process exit
239
+ w.start()
240
+
241
+ _set_worker_pids(id(self), tuple(w.pid for w in self.workers))
242
+ _set_SIGCHLD_handler()
243
+ self.worker_pids_set = True
244
+
245
+ # prime the prefetch loop
246
+ for _ in range(2 * self.num_workers):
247
+ self._put_indices()
248
+
249
+ def __len__(self):
250
+ return len(self.batch_sampler)
251
+
252
+ def _get_batch(self):
253
+ if self.timeout > 0:
254
+ try:
255
+ return self.data_queue.get(timeout=self.timeout)
256
+ except queue.Empty:
257
+ raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
258
+ else:
259
+ return self.data_queue.get()
260
+
261
+ def __next__(self):
262
+ if self.num_workers == 0: # same-process loading
263
+ indices = next(self.sample_iter) # may raise StopIteration
264
+ batch = self.collate_fn([self.dataset[i] for i in indices])
265
+ if self.pin_memory:
266
+ batch = pin_memory_batch(batch)
267
+ return batch
268
+
269
+ # check if the next sample has already been generated
270
+ if self.rcvd_idx in self.reorder_dict:
271
+ batch = self.reorder_dict.pop(self.rcvd_idx)
272
+ return self._process_next_batch(batch)
273
+
274
+ if self.batches_outstanding == 0:
275
+ self._shutdown_workers()
276
+ raise StopIteration
277
+
278
+ while True:
279
+ assert (not self.shutdown and self.batches_outstanding > 0)
280
+ idx, batch = self._get_batch()
281
+ self.batches_outstanding -= 1
282
+ if idx != self.rcvd_idx:
283
+ # store out-of-order samples
284
+ self.reorder_dict[idx] = batch
285
+ continue
286
+ return self._process_next_batch(batch)
287
+
288
+ next = __next__ # Python 2 compatibility
289
+
290
+ def __iter__(self):
291
+ return self
292
+
293
+ def _put_indices(self):
294
+ assert self.batches_outstanding < 2 * self.num_workers
295
+ indices = next(self.sample_iter, None)
296
+ if indices is None:
297
+ return
298
+ self.index_queue.put((self.send_idx, indices))
299
+ self.batches_outstanding += 1
300
+ self.send_idx += 1
301
+
302
+ def _process_next_batch(self, batch):
303
+ self.rcvd_idx += 1
304
+ self._put_indices()
305
+ if isinstance(batch, ExceptionWrapper):
306
+ raise batch.exc_type(batch.exc_msg)
307
+ return batch
308
+
309
+ def __getstate__(self):
310
+ # TODO: add limited pickling support for sharing an iterator
311
+ # across multiple threads for HOGWILD.
312
+ # Probably the best way to do this is by moving the sample pushing
313
+ # to a separate thread and then just sharing the data queue
314
+ # but signalling the end is tricky without a non-blocking API
315
+ raise NotImplementedError("DataLoaderIterator cannot be pickled")
316
+
317
+ def _shutdown_workers(self):
318
+ try:
319
+ if not self.shutdown:
320
+ self.shutdown = True
321
+ self.done_event.set()
322
+ # if worker_manager_thread is waiting to put
323
+ while not self.data_queue.empty():
324
+ self.data_queue.get()
325
+ for _ in self.workers:
326
+ self.index_queue.put(None)
327
+ # done_event should be sufficient to exit worker_manager_thread,
328
+ # but be safe here and put another None
329
+ self.worker_result_queue.put(None)
330
+ finally:
331
+ # removes pids no matter what
332
+ if self.worker_pids_set:
333
+ _remove_worker_pids(id(self))
334
+ self.worker_pids_set = False
335
+
336
+ def __del__(self):
337
+ if self.num_workers > 0:
338
+ self._shutdown_workers()
339
+
340
+
341
+ class DataLoader(object):
342
+ """
343
+ Data loader. Combines a dataset and a sampler, and provides
344
+ single- or multi-process iterators over the dataset.
345
+
346
+ Arguments:
347
+ dataset (Dataset): dataset from which to load the data.
348
+ batch_size (int, optional): how many samples per batch to load
349
+ (default: 1).
350
+ shuffle (bool, optional): set to ``True`` to have the data reshuffled
351
+ at every epoch (default: False).
352
+ sampler (Sampler, optional): defines the strategy to draw samples from
353
+ the dataset. If specified, ``shuffle`` must be False.
354
+ batch_sampler (Sampler, optional): like sampler, but returns a batch of
355
+ indices at a time. Mutually exclusive with batch_size, shuffle,
356
+ sampler, and drop_last.
357
+ num_workers (int, optional): how many subprocesses to use for data
358
+ loading. 0 means that the data will be loaded in the main process.
359
+ (default: 0)
360
+ collate_fn (callable, optional): merges a list of samples to form a mini-batch.
361
+ pin_memory (bool, optional): If ``True``, the data loader will copy tensors
362
+ into CUDA pinned memory before returning them.
363
+ drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
364
+ if the dataset size is not divisible by the batch size. If ``False`` and
365
+ the size of dataset is not divisible by the batch size, then the last batch
366
+ will be smaller. (default: False)
367
+ timeout (numeric, optional): if positive, the timeout value for collecting a batch
368
+ from workers. Should always be non-negative. (default: 0)
369
+ worker_init_fn (callable, optional): If not None, this will be called on each
370
+ worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
371
+ input, after seeding and before data loading. (default: None)
372
+
373
+ .. note:: By default, each worker will have its PyTorch seed set to
374
+ ``base_seed + worker_id``, where ``base_seed`` is a long generated
375
+ by main process using its RNG. You may use ``torch.initial_seed()`` to access
376
+ this value in :attr:`worker_init_fn`, which can be used to set other seeds
377
+ (e.g. NumPy) before data loading.
378
+
379
+ .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
380
+ unpicklable object, e.g., a lambda function.
381
+ """
382
+
383
+ def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
384
+ num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
385
+ timeout=0, worker_init_fn=None):
386
+ self.dataset = dataset
387
+ self.batch_size = batch_size
388
+ self.num_workers = num_workers
389
+ self.collate_fn = collate_fn
390
+ self.pin_memory = pin_memory
391
+ self.drop_last = drop_last
392
+ self.timeout = timeout
393
+ self.worker_init_fn = worker_init_fn
394
+
395
+ if timeout < 0:
396
+ raise ValueError('timeout option should be non-negative')
397
+
398
+ if batch_sampler is not None:
399
+ if batch_size > 1 or shuffle or sampler is not None or drop_last:
400
+ raise ValueError('batch_sampler is mutually exclusive with '
401
+ 'batch_size, shuffle, sampler, and drop_last')
402
+
403
+ if sampler is not None and shuffle:
404
+ raise ValueError('sampler is mutually exclusive with shuffle')
405
+
406
+ if self.num_workers < 0:
407
+ raise ValueError('num_workers cannot be negative; '
408
+ 'use num_workers=0 to disable multiprocessing.')
409
+
410
+ if batch_sampler is None:
411
+ if sampler is None:
412
+ if shuffle:
413
+ sampler = RandomSampler(dataset)
414
+ else:
415
+ sampler = SequentialSampler(dataset)
416
+ batch_sampler = BatchSampler(sampler, batch_size, drop_last)
417
+
418
+ self.sampler = sampler
419
+ self.batch_sampler = batch_sampler
420
+
421
+ def __iter__(self):
422
+ return DataLoaderIter(self)
423
+
424
+ def __len__(self):
425
+ return len(self.batch_sampler)
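A minimal usage sketch for the vendored loader above (the import path is an assumption, adjust it to wherever this package re-exports Dataset and DataLoader; num_workers=0 keeps all loading in the main process):

    import torch
    # assumed re-export path for this vendored copy
    from models.ade20k.segm_lib.utils.data import Dataset, DataLoader

    class ToyDataset(Dataset):
        def __init__(self, n=10):
            self.values = torch.arange(n, dtype=torch.float32)
        def __len__(self):
            return len(self.values)
        def __getitem__(self, i):
            return self.values[i]

    loader = DataLoader(ToyDataset(), batch_size=4, shuffle=True, num_workers=0)
    for batch in loader:
        print(batch.shape)   # torch.Size([4]) for full batches, torch.Size([2]) for the last one

With num_workers > 0 the same call spawns worker processes, prefetches 2 * num_workers batches, and reorders out-of-order results by index before yielding them.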
models/ade20k/segm_lib/utils/data/dataset.py ADDED
@@ -0,0 +1,118 @@
1
+ import bisect
2
+ import warnings
3
+
4
+ from torch._utils import _accumulate
5
+ from torch import randperm
6
+
7
+
8
+ class Dataset(object):
9
+ """An abstract class representing a Dataset.
10
+
11
+ All other datasets should subclass it. All subclasses should override
12
+ ``__len__``, that provides the size of the dataset, and ``__getitem__``,
13
+ supporting integer indexing in range from 0 to len(self) exclusive.
14
+ """
15
+
16
+ def __getitem__(self, index):
17
+ raise NotImplementedError
18
+
19
+ def __len__(self):
20
+ raise NotImplementedError
21
+
22
+ def __add__(self, other):
23
+ return ConcatDataset([self, other])
24
+
25
+
26
+ class TensorDataset(Dataset):
27
+ """Dataset wrapping data and target tensors.
28
+
29
+ Each sample will be retrieved by indexing both tensors along the first
30
+ dimension.
31
+
32
+ Arguments:
33
+ data_tensor (Tensor): contains sample data.
34
+ target_tensor (Tensor): contains sample targets (labels).
35
+ """
36
+
37
+ def __init__(self, data_tensor, target_tensor):
38
+ assert data_tensor.size(0) == target_tensor.size(0)
39
+ self.data_tensor = data_tensor
40
+ self.target_tensor = target_tensor
41
+
42
+ def __getitem__(self, index):
43
+ return self.data_tensor[index], self.target_tensor[index]
44
+
45
+ def __len__(self):
46
+ return self.data_tensor.size(0)
47
+
48
+
49
+ class ConcatDataset(Dataset):
50
+ """
51
+ Dataset to concatenate multiple datasets.
52
+ Purpose: useful to assemble different existing datasets, possibly
53
+ large-scale datasets as the concatenation operation is done in an
54
+ on-the-fly manner.
55
+
56
+ Arguments:
57
+ datasets (iterable): List of datasets to be concatenated
58
+ """
59
+
60
+ @staticmethod
61
+ def cumsum(sequence):
62
+ r, s = [], 0
63
+ for e in sequence:
64
+ l = len(e)
65
+ r.append(l + s)
66
+ s += l
67
+ return r
68
+
69
+ def __init__(self, datasets):
70
+ super(ConcatDataset, self).__init__()
71
+ assert len(datasets) > 0, 'datasets should not be an empty iterable'
72
+ self.datasets = list(datasets)
73
+ self.cumulative_sizes = self.cumsum(self.datasets)
74
+
75
+ def __len__(self):
76
+ return self.cumulative_sizes[-1]
77
+
78
+ def __getitem__(self, idx):
79
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
80
+ if dataset_idx == 0:
81
+ sample_idx = idx
82
+ else:
83
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
84
+ return self.datasets[dataset_idx][sample_idx]
85
+
86
+ @property
87
+ def cummulative_sizes(self):
88
+ warnings.warn("cummulative_sizes attribute is renamed to "
89
+ "cumulative_sizes", DeprecationWarning, stacklevel=2)
90
+ return self.cumulative_sizes
91
+
92
+
93
+ class Subset(Dataset):
94
+ def __init__(self, dataset, indices):
95
+ self.dataset = dataset
96
+ self.indices = indices
97
+
98
+ def __getitem__(self, idx):
99
+ return self.dataset[self.indices[idx]]
100
+
101
+ def __len__(self):
102
+ return len(self.indices)
103
+
104
+
105
+ def random_split(dataset, lengths):
106
+ """
107
+ Randomly split a dataset into non-overlapping new datasets of given lengths.
108
+
109
+
110
+ Arguments:
111
+ dataset (Dataset): Dataset to be split
112
+ lengths (iterable): lengths of splits to be produced
113
+ """
114
+ if sum(lengths) != len(dataset):
115
+ raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
116
+
117
+ indices = randperm(sum(lengths))
118
+ return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]
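A short sketch of the helpers defined in this file (TensorDataset, random_split, ConcatDataset):

    import torch

    full = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))
    train_ds, val_ds = random_split(full, [80, 20])   # two non-overlapping Subsets
    combined = ConcatDataset([train_ds, val_ds])      # equivalently: train_ds + val_ds
    print(len(train_ds), len(val_ds), len(combined))  # 80 20 100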
models/ade20k/segm_lib/utils/data/distributed.py ADDED
@@ -0,0 +1,58 @@
1
+ import math
2
+ import torch
3
+ from .sampler import Sampler
4
+ from torch.distributed import get_world_size, get_rank
5
+
6
+
7
+ class DistributedSampler(Sampler):
8
+ """Sampler that restricts data loading to a subset of the dataset.
9
+
10
+ It is especially useful in conjunction with
11
+ :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
12
+ process can pass a DistributedSampler instance as a DataLoader sampler,
13
+ and load a subset of the original dataset that is exclusive to it.
14
+
15
+ .. note::
16
+ Dataset is assumed to be of constant size.
17
+
18
+ Arguments:
19
+ dataset: Dataset used for sampling.
20
+ num_replicas (optional): Number of processes participating in
21
+ distributed training.
22
+ rank (optional): Rank of the current process within num_replicas.
23
+ """
24
+
25
+ def __init__(self, dataset, num_replicas=None, rank=None):
26
+ if num_replicas is None:
27
+ num_replicas = get_world_size()
28
+ if rank is None:
29
+ rank = get_rank()
30
+ self.dataset = dataset
31
+ self.num_replicas = num_replicas
32
+ self.rank = rank
33
+ self.epoch = 0
34
+ self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
35
+ self.total_size = self.num_samples * self.num_replicas
36
+
37
+ def __iter__(self):
38
+ # deterministically shuffle based on epoch
39
+ g = torch.Generator()
40
+ g.manual_seed(self.epoch)
41
+ indices = list(torch.randperm(len(self.dataset), generator=g))
42
+
43
+ # add extra samples to make it evenly divisible
44
+ indices += indices[:(self.total_size - len(indices))]
45
+ assert len(indices) == self.total_size
46
+
47
+ # subsample
48
+ offset = self.num_samples * self.rank
49
+ indices = indices[offset:offset + self.num_samples]
50
+ assert len(indices) == self.num_samples
51
+
52
+ return iter(indices)
53
+
54
+ def __len__(self):
55
+ return self.num_samples
56
+
57
+ def set_epoch(self, epoch):
58
+ self.epoch = epoch
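A hedged sketch of how this sampler is typically wired into a distributed training loop (assumes torch.distributed is already initialized so get_world_size()/get_rank() succeed; dataset and num_epochs are placeholders):

    sampler = DistributedSampler(dataset)        # rank and world size read from torch.distributed
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)                 # re-seeds the deterministic shuffle each epoch
        for batch in loader:
            ...                                  # each rank sees a disjoint ~1/world_size slice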
models/ade20k/segm_lib/utils/data/sampler.py ADDED
@@ -0,0 +1,131 @@
1
+ import torch
2
+
3
+
4
+ class Sampler(object):
5
+ """Base class for all Samplers.
6
+
7
+ Every Sampler subclass has to provide an __iter__ method, providing a way
8
+ to iterate over indices of dataset elements, and a __len__ method that
9
+ returns the length of the returned iterators.
10
+ """
11
+
12
+ def __init__(self, data_source):
13
+ pass
14
+
15
+ def __iter__(self):
16
+ raise NotImplementedError
17
+
18
+ def __len__(self):
19
+ raise NotImplementedError
20
+
21
+
22
+ class SequentialSampler(Sampler):
23
+ """Samples elements sequentially, always in the same order.
24
+
25
+ Arguments:
26
+ data_source (Dataset): dataset to sample from
27
+ """
28
+
29
+ def __init__(self, data_source):
30
+ self.data_source = data_source
31
+
32
+ def __iter__(self):
33
+ return iter(range(len(self.data_source)))
34
+
35
+ def __len__(self):
36
+ return len(self.data_source)
37
+
38
+
39
+ class RandomSampler(Sampler):
40
+ """Samples elements randomly, without replacement.
41
+
42
+ Arguments:
43
+ data_source (Dataset): dataset to sample from
44
+ """
45
+
46
+ def __init__(self, data_source):
47
+ self.data_source = data_source
48
+
49
+ def __iter__(self):
50
+ return iter(torch.randperm(len(self.data_source)).long())
51
+
52
+ def __len__(self):
53
+ return len(self.data_source)
54
+
55
+
56
+ class SubsetRandomSampler(Sampler):
57
+ """Samples elements randomly from a given list of indices, without replacement.
58
+
59
+ Arguments:
60
+ indices (list): a list of indices
61
+ """
62
+
63
+ def __init__(self, indices):
64
+ self.indices = indices
65
+
66
+ def __iter__(self):
67
+ return (self.indices[i] for i in torch.randperm(len(self.indices)))
68
+
69
+ def __len__(self):
70
+ return len(self.indices)
71
+
72
+
73
+ class WeightedRandomSampler(Sampler):
74
+ """Samples elements from [0,..,len(weights)-1] with given probabilities (weights).
75
+
76
+ Arguments:
77
+ weights (list): a list of weights, not necessarily summing up to one
78
+ num_samples (int): number of samples to draw
79
+ replacement (bool): if ``True``, samples are drawn with replacement.
80
+ If not, they are drawn without replacement, which means that when a
81
+ sample index is drawn for a row, it cannot be drawn again for that row.
82
+ """
83
+
84
+ def __init__(self, weights, num_samples, replacement=True):
85
+ self.weights = torch.DoubleTensor(weights)
86
+ self.num_samples = num_samples
87
+ self.replacement = replacement
88
+
89
+ def __iter__(self):
90
+ return iter(torch.multinomial(self.weights, self.num_samples, self.replacement))
91
+
92
+ def __len__(self):
93
+ return self.num_samples
94
+
95
+
96
+ class BatchSampler(object):
97
+ """Wraps another sampler to yield a mini-batch of indices.
98
+
99
+ Args:
100
+ sampler (Sampler): Base sampler.
101
+ batch_size (int): Size of mini-batch.
102
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
103
+ its size would be less than ``batch_size``
104
+
105
+ Example:
106
+ >>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
107
+ [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
108
+ >>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
109
+ [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
110
+ """
111
+
112
+ def __init__(self, sampler, batch_size, drop_last):
113
+ self.sampler = sampler
114
+ self.batch_size = batch_size
115
+ self.drop_last = drop_last
116
+
117
+ def __iter__(self):
118
+ batch = []
119
+ for idx in self.sampler:
120
+ batch.append(idx)
121
+ if len(batch) == self.batch_size:
122
+ yield batch
123
+ batch = []
124
+ if len(batch) > 0 and not self.drop_last:
125
+ yield batch
126
+
127
+ def __len__(self):
128
+ if self.drop_last:
129
+ return len(self.sampler) // self.batch_size
130
+ else:
131
+ return (len(self.sampler) + self.batch_size - 1) // self.batch_size
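A small sketch combining the samplers above, for example to oversample a rare index (torch.multinomial yields 0-dim tensors, hence the int() cast):

    weights = [0.1, 0.1, 0.1, 0.1, 0.6]          # per-index draw weights, need not sum to 1
    sampler = WeightedRandomSampler(weights, num_samples=8, replacement=True)
    print([int(i) for i in sampler])             # e.g. [4, 0, 4, 4, 2, 4, 1, 4]
    batches = list(BatchSampler(sampler, batch_size=4, drop_last=False))
    print(len(batches))                          # 2 batches of 4 indices each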
models/ade20k/segm_lib/utils/th.py ADDED
@@ -0,0 +1,41 @@
1
+ import torch
2
+ from torch.autograd import Variable
3
+ import numpy as np
4
+ import collections.abc
5
+
6
+ __all__ = ['as_variable', 'as_numpy', 'mark_volatile']
7
+
8
+ def as_variable(obj):
9
+ if isinstance(obj, Variable):
10
+ return obj
11
+ if isinstance(obj, collections.abc.Sequence):
12
+ return [as_variable(v) for v in obj]
13
+ elif isinstance(obj, collections.abc.Mapping):
14
+ return {k: as_variable(v) for k, v in obj.items()}
15
+ else:
16
+ return Variable(obj)
17
+
18
+ def as_numpy(obj):
19
+ if isinstance(obj, collections.abc.Sequence):
20
+ return [as_numpy(v) for v in obj]
21
+ elif isinstance(obj, collections.abc.Mapping):
22
+ return {k: as_numpy(v) for k, v in obj.items()}
23
+ elif isinstance(obj, Variable):
24
+ return obj.data.cpu().numpy()
25
+ elif torch.is_tensor(obj):
26
+ return obj.cpu().numpy()
27
+ else:
28
+ return np.array(obj)
29
+
30
+ def mark_volatile(obj):
31
+ if torch.is_tensor(obj):
32
+ obj = Variable(obj)
33
+ if isinstance(obj, Variable):
34
+ obj.no_grad = True
35
+ return obj
36
+ elif isinstance(obj, collections.abc.Mapping):
37
+ return {k: mark_volatile(o) for k, o in obj.items()}
38
+ elif isinstance(obj, collections.abc.Sequence):
39
+ return [mark_volatile(o) for o in obj]
40
+ else:
41
+ return obj
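A quick sketch of how these helpers flatten nested batches (output values are NumPy arrays regardless of whether the inputs held Tensors or Variables):

    import torch

    batch = {'image': torch.rand(2, 3, 4, 4), 'meta': [torch.tensor(1), torch.tensor(2)]}
    out = as_numpy(batch)                  # recurses through the dict and the inner list
    print(type(out['image']).__name__)     # ndarray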
models/ade20k/utils.py ADDED
@@ -0,0 +1,40 @@
1
+ """Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch"""
2
+
3
+ import os
4
+ import sys
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ try:
10
+ from urllib import urlretrieve
11
+ except ImportError:
12
+ from urllib.request import urlretrieve
13
+
14
+
15
+ def load_url(url, model_dir='./pretrained', map_location=None):
16
+ if not os.path.exists(model_dir):
17
+ os.makedirs(model_dir)
18
+ filename = url.split('/')[-1]
19
+ cached_file = os.path.join(model_dir, filename)
20
+ if not os.path.exists(cached_file):
21
+ sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
22
+ urlretrieve(url, cached_file)
23
+ return torch.load(cached_file, map_location=map_location)
24
+
25
+
26
+ def color_encode(labelmap, colors, mode='RGB'):
27
+ labelmap = labelmap.astype('int')
28
+ labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
29
+ dtype=np.uint8)
30
+ for label in np.unique(labelmap):
31
+ if label < 0:
32
+ continue
33
+ labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \
34
+ np.tile(colors[label],
35
+ (labelmap.shape[0], labelmap.shape[1], 1))
36
+
37
+ if mode == 'BGR':
38
+ return labelmap_rgb[:, :, ::-1]
39
+ else:
40
+ return labelmap_rgb
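A toy sketch of color_encode (in practice the ADE20K palette is loaded from color150.mat rather than hand-written as here):

    import numpy as np

    labelmap = np.array([[0, 1], [1, 2]])
    colors = np.array([[255, 0, 0], [0, 255, 0], [0, 0, 255]], dtype=np.uint8)
    rgb = color_encode(labelmap, colors)
    print(rgb.shape)                       # (2, 2, 3); pixel (0, 0) is pure red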
requirements.txt ADDED
@@ -0,0 +1,71 @@
1
+ aiohttp==3.8.5
2
+ aiosignal==1.3.1
3
+ albumentations==1.3.1
4
+ antlr4-python3-runtime==4.9.3
5
+ async-timeout==4.0.3
6
+ attrs==23.1.0
7
+ braceexpand==0.1.7
8
+ certifi==2023.7.22
9
+ charset-normalizer==3.2.0
10
+ colorama==0.4.6
11
+ contourpy==1.1.0
12
+ cycler==0.11.0
13
+ easydict==1.10
14
+ filelock==3.12.2
15
+ fonttools==4.42.0
16
+ frozenlist==1.4.0
17
+ fsspec==2023.6.0
18
+ huggingface-hub==0.16.4
19
+ idna==3.4
20
+ imageio==2.31.1
21
+ imgaug==0.4.0
22
+ Jinja2==3.1.2
23
+ joblib==1.3.2
24
+ kiwisolver==1.4.4
25
+ kornia==0.5.0
26
+ lazy_loader==0.3
27
+ lightning-utilities==0.9.0
28
+ MarkupSafe==2.1.3
29
+ matplotlib==3.7.2
30
+ mpmath==1.3.0
31
+ multidict==6.0.4
32
+ networkx==3.1
33
+ numpy==1.25.2
34
+ omegaconf==2.3.0
35
+ opencv-python==4.8.0.76
36
+ opencv-python-headless==4.8.0.76
37
+ packaging==23.1
38
+ pandas==2.0.3
39
+ Pillow==10.0.0
40
+ pip==23.2.1
41
+ pyparsing==3.0.9
42
+ python-dateutil==2.8.2
43
+ pytorch-lightning==2.0.7
44
+ pytz==2023.3
45
+ PyWavelets==1.4.1
46
+ PyYAML==6.0.1
47
+ qudida==0.0.4
48
+ regex==2023.8.8
49
+ requests==2.31.0
50
+ safetensors==0.3.2
51
+ scikit-image==0.21.0
52
+ scikit-learn==1.3.0
53
+ scipy==1.11.2
54
+ setuptools==68.0.0
55
+ shapely==2.0.1
56
+ six==1.16.0
57
+ sympy==1.12
58
+ threadpoolctl==3.2.0
59
+ tifffile==2023.8.12
60
+ tokenizers==0.13.3
61
+ torch==2.0.1
62
+ torchmetrics==1.0.3
63
+ torchvision==0.15.2
64
+ tqdm==4.66.1
65
+ transformers==4.31.0
66
+ typing_extensions==4.7.1
67
+ tzdata==2023.3
68
+ urllib3==2.0.4
69
+ webdataset==0.2.48
70
+ wheel==0.38.4
71
+ yarl==1.9.2
saicinpainting/__init__.py ADDED
File without changes
saicinpainting/evaluation/__init__.py ADDED
@@ -0,0 +1,33 @@
1
+ import logging
2
+
3
+ import torch
4
+
5
+ from saicinpainting.evaluation.evaluator import InpaintingEvaluatorOnline, ssim_fid100_f1, lpips_fid100_f1
6
+ from saicinpainting.evaluation.losses.base_loss import SSIMScore, LPIPSScore, FIDScore
7
+
8
+
9
+ def make_evaluator(kind='default', ssim=True, lpips=True, fid=True, integral_kind=None, **kwargs):
10
+ logging.info(f'Make evaluator {kind}')
11
+ device = "cuda" if torch.cuda.is_available() else "cpu"
12
+ metrics = {}
13
+ if ssim:
14
+ metrics['ssim'] = SSIMScore()
15
+ if lpips:
16
+ metrics['lpips'] = LPIPSScore()
17
+ if fid:
18
+ metrics['fid'] = FIDScore().to(device)
19
+
20
+ if integral_kind is None:
21
+ integral_func = None
22
+ elif integral_kind == 'ssim_fid100_f1':
23
+ integral_func = ssim_fid100_f1
24
+ elif integral_kind == 'lpips_fid100_f1':
25
+ integral_func = lpips_fid100_f1
26
+ else:
27
+ raise ValueError(f'Unexpected integral_kind={integral_kind}')
28
+
29
+ if kind == 'default':
30
+ return InpaintingEvaluatorOnline(scores=metrics,
31
+ integral_func=integral_func,
32
+ integral_title=integral_kind,
33
+ **kwargs)
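A minimal sketch of using this factory (constructing the LPIPS and FID scores loads their pretrained weights, so this assumes those weights can be fetched in the current environment):

    evaluator = make_evaluator(kind='default', integral_kind='ssim_fid100_f1')
    # The result is an InpaintingEvaluatorOnline (an nn.Module): call it on batches that
    # contain 'image', 'mask' and 'inpainted' tensors, then call evaluator.evaluation_end()
    # to obtain the aggregated per-metric and per-mask-area statistics.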
saicinpainting/evaluation/data.py ADDED
@@ -0,0 +1,168 @@
1
+ import glob
2
+ import os
3
+
4
+ import cv2
5
+ import PIL.Image as Image
6
+ import numpy as np
7
+
8
+ from torch.utils.data import Dataset
9
+ import torch.nn.functional as F
10
+
11
+
12
+ def load_image(fname, mode='RGB', return_orig=False):
13
+ img = np.array(Image.open(fname).convert(mode))
14
+ if img.ndim == 3:
15
+ img = np.transpose(img, (2, 0, 1))
16
+ out_img = img.astype('float32') / 255
17
+ if return_orig:
18
+ return out_img, img
19
+ else:
20
+ return out_img
21
+
22
+
23
+ def ceil_modulo(x, mod):
24
+ if x % mod == 0:
25
+ return x
26
+ return (x // mod + 1) * mod
27
+
28
+
29
+ def pad_img_to_modulo(img, mod):
30
+ channels, height, width = img.shape
31
+ out_height = ceil_modulo(height, mod)
32
+ out_width = ceil_modulo(width, mod)
33
+ return np.pad(img, ((0, 0), (0, out_height - height), (0, out_width - width)), mode='symmetric')
34
+
35
+
36
+ def pad_tensor_to_modulo(img, mod):
37
+ batch_size, channels, height, width = img.shape
38
+ out_height = ceil_modulo(height, mod)
39
+ out_width = ceil_modulo(width, mod)
40
+ return F.pad(img, pad=(0, out_width - width, 0, out_height - height), mode='reflect')
41
+
42
+
43
+ def scale_image(img, factor, interpolation=cv2.INTER_AREA):
44
+ if img.shape[0] == 1:
45
+ img = img[0]
46
+ else:
47
+ img = np.transpose(img, (1, 2, 0))
48
+
49
+ img = cv2.resize(img, dsize=None, fx=factor, fy=factor, interpolation=interpolation)
50
+
51
+ if img.ndim == 2:
52
+ img = img[None, ...]
53
+ else:
54
+ img = np.transpose(img, (2, 0, 1))
55
+ return img
56
+
57
+
58
+ class InpaintingDataset(Dataset):
59
+ def __init__(self, datadir, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None):
60
+ self.datadir = datadir
61
+ self.mask_filenames = sorted(list(glob.glob(os.path.join(self.datadir, '**', '*mask*.png'), recursive=True)))
62
+ self.img_filenames = [fname.rsplit('_mask', 1)[0] + img_suffix for fname in self.mask_filenames]
63
+ self.pad_out_to_modulo = pad_out_to_modulo
64
+ self.scale_factor = scale_factor
65
+
66
+ def __len__(self):
67
+ return len(self.mask_filenames)
68
+
69
+ def __getitem__(self, i):
70
+ image = load_image(self.img_filenames[i], mode='RGB')
71
+ mask = load_image(self.mask_filenames[i], mode='L')
72
+ result = dict(image=image, mask=mask[None, ...])
73
+
74
+ if self.scale_factor is not None:
75
+ result['image'] = scale_image(result['image'], self.scale_factor)
76
+ result['mask'] = scale_image(result['mask'], self.scale_factor, interpolation=cv2.INTER_NEAREST)
77
+
78
+ if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
79
+ result['unpad_to_size'] = result['image'].shape[1:]
80
+ result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo)
81
+ result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo)
82
+
83
+ return result
84
+
85
+ class OurInpaintingDataset(Dataset):
86
+ def __init__(self, datadir, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None):
87
+ self.datadir = datadir
88
+ self.mask_filenames = sorted(list(glob.glob(os.path.join(self.datadir, 'mask', '**', '*mask*.png'), recursive=True)))
89
+ self.img_filenames = [os.path.join(self.datadir, 'img', os.path.basename(fname.rsplit('-', 1)[0].rsplit('_', 1)[0]) + '.png') for fname in self.mask_filenames]
90
+ self.pad_out_to_modulo = pad_out_to_modulo
91
+ self.scale_factor = scale_factor
92
+
93
+ def __len__(self):
94
+ return len(self.mask_filenames)
95
+
96
+ def __getitem__(self, i):
97
+ result = dict(image=load_image(self.img_filenames[i], mode='RGB'),
98
+ mask=load_image(self.mask_filenames[i], mode='L')[None, ...])
99
+
100
+ if self.scale_factor is not None:
101
+ result['image'] = scale_image(result['image'], self.scale_factor)
102
+ result['mask'] = scale_image(result['mask'], self.scale_factor)
103
+
104
+ if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
105
+ result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo)
106
+ result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo)
107
+
108
+ return result
109
+
110
+ class PrecomputedInpaintingResultsDataset(InpaintingDataset):
111
+ def __init__(self, datadir, predictdir, inpainted_suffix='_inpainted.jpg', **kwargs):
112
+ super().__init__(datadir, **kwargs)
113
+ if not datadir.endswith('/'):
114
+ datadir += '/'
115
+ self.predictdir = predictdir
116
+ self.pred_filenames = [os.path.join(predictdir, os.path.splitext(fname[len(datadir):])[0] + inpainted_suffix)
117
+ for fname in self.mask_filenames]
118
+
119
+ def __getitem__(self, i):
120
+ result = super().__getitem__(i)
121
+ result['inpainted'] = load_image(self.pred_filenames[i])
122
+ if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
123
+ result['inpainted'] = pad_img_to_modulo(result['inpainted'], self.pad_out_to_modulo)
124
+ return result
125
+
126
+ class OurPrecomputedInpaintingResultsDataset(OurInpaintingDataset):
127
+ def __init__(self, datadir, predictdir, inpainted_suffix="png", **kwargs):
128
+ super().__init__(datadir, **kwargs)
129
+ if not datadir.endswith('/'):
130
+ datadir += '/'
131
+ self.predictdir = predictdir
132
+ self.pred_filenames = [os.path.join(predictdir, os.path.basename(os.path.splitext(fname)[0]) + f'_inpainted.{inpainted_suffix}')
133
+ for fname in self.mask_filenames]
134
+ # self.pred_filenames = [os.path.join(predictdir, os.path.splitext(fname[len(datadir):])[0] + inpainted_suffix)
135
+ # for fname in self.mask_filenames]
136
+
137
+ def __getitem__(self, i):
138
+ result = super().__getitem__(i)
139
+ result['inpainted'] = load_image(self.pred_filenames[i])
140
+
141
+ if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
142
+ result['inpainted'] = pad_img_to_modulo(result['inpainted'], self.pad_out_to_modulo)
143
+ return result
144
+
145
+ class InpaintingEvalOnlineDataset(Dataset):
146
+ def __init__(self, indir, mask_generator, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None, **kwargs):
147
+ self.indir = indir
148
+ self.mask_generator = mask_generator
149
+ self.img_filenames = sorted(list(glob.glob(os.path.join(self.indir, '**', f'*{img_suffix}' ), recursive=True)))
150
+ self.pad_out_to_modulo = pad_out_to_modulo
151
+ self.scale_factor = scale_factor
152
+
153
+ def __len__(self):
154
+ return len(self.img_filenames)
155
+
156
+ def __getitem__(self, i):
157
+ img, raw_image = load_image(self.img_filenames[i], mode='RGB', return_orig=True)
158
+ mask = self.mask_generator(img, raw_image=raw_image)
159
+ result = dict(image=img, mask=mask)
160
+
161
+ if self.scale_factor is not None:
162
+ result['image'] = scale_image(result['image'], self.scale_factor)
163
+ result['mask'] = scale_image(result['mask'], self.scale_factor, interpolation=cv2.INTER_NEAREST)
164
+
165
+ if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
166
+ result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo)
167
+ result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo)
168
+ return result
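The padding helpers above exist because the inpainting network needs spatial sizes divisible by a fixed modulo (commonly 8); a small sketch:

    import numpy as np

    img = np.random.rand(3, 250, 333).astype('float32')   # CHW in [0, 1], as load_image returns
    padded = pad_img_to_modulo(img, 8)
    print(padded.shape)                                    # (3, 256, 336)
    cropped_back = padded[:, :250, :333]                   # what 'unpad_to_size' is stored for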
saicinpainting/evaluation/evaluator.py ADDED
@@ -0,0 +1,220 @@
1
+ import logging
2
+ import math
3
+ from typing import Dict
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import tqdm
9
+ from torch.utils.data import DataLoader
10
+
11
+ from saicinpainting.evaluation.utils import move_to_device
12
+
13
+ LOGGER = logging.getLogger(__name__)
14
+
15
+
16
+ class InpaintingEvaluator():
17
+ def __init__(self, dataset, scores, area_grouping=True, bins=10, batch_size=32, device='cuda',
18
+ integral_func=None, integral_title=None, clamp_image_range=None):
19
+ """
20
+ :param dataset: torch.utils.data.Dataset which contains images and masks
21
+ :param scores: dict {score_name: EvaluatorScore object}
22
+ :param area_grouping: in addition to the overall scores, allows to compute score for the groups of samples
23
+ which are defined by share of area occluded by mask
24
+ :param bins: number of groups, partition is generated by np.linspace(0., 1., bins + 1)
25
+ :param batch_size: batch_size for the dataloader
26
+ :param device: device to use
27
+ """
28
+ self.scores = scores
29
+ self.dataset = dataset
30
+
31
+ self.area_grouping = area_grouping
32
+ self.bins = bins
33
+
34
+ self.device = torch.device(device)
35
+
36
+ self.dataloader = DataLoader(self.dataset, shuffle=False, batch_size=batch_size)
37
+
38
+ self.integral_func = integral_func
39
+ self.integral_title = integral_title
40
+ self.clamp_image_range = clamp_image_range
41
+
42
+ def _get_bin_edges(self):
43
+ bin_edges = np.linspace(0, 1, self.bins + 1)
44
+
45
+ num_digits = max(0, math.ceil(math.log10(self.bins)) - 1)
46
+ interval_names = []
47
+ for idx_bin in range(self.bins):
48
+ start_percent, end_percent = round(100 * bin_edges[idx_bin], num_digits), \
49
+ round(100 * bin_edges[idx_bin + 1], num_digits)
50
+ start_percent = '{:.{n}f}'.format(start_percent, n=num_digits)
51
+ end_percent = '{:.{n}f}'.format(end_percent, n=num_digits)
52
+ interval_names.append("{0}-{1}%".format(start_percent, end_percent))
53
+
54
+ groups = []
55
+ for batch in self.dataloader:
56
+ mask = batch['mask']
57
+ batch_size = mask.shape[0]
58
+ area = mask.to(self.device).reshape(batch_size, -1).mean(dim=-1)
59
+ bin_indices = np.searchsorted(bin_edges, area.detach().cpu().numpy(), side='right') - 1
60
+ # corner case: when area is equal to 1, bin_indices should return bins - 1, not bins for that element
61
+ bin_indices[bin_indices == self.bins] = self.bins - 1
62
+ groups.append(bin_indices)
63
+ groups = np.hstack(groups)
64
+
65
+ return groups, interval_names
66
+
67
+ def evaluate(self, model=None):
68
+ """
69
+ :param model: callable with signature (image_batch, mask_batch); should return inpainted_batch
70
+ :return: dict with (score_name, group_type) as keys, where group_type can be either 'overall' or
71
+ name of the particular group arranged by area of mask (e.g. '10-20%')
72
+ and score statistics for the group as values.
73
+ """
74
+ results = dict()
75
+ if self.area_grouping:
76
+ groups, interval_names = self._get_bin_edges()
77
+ else:
78
+ groups = None
79
+
80
+ for score_name, score in tqdm.auto.tqdm(self.scores.items(), desc='scores'):
81
+ score.to(self.device)
82
+ with torch.no_grad():
83
+ score.reset()
84
+ for batch in tqdm.auto.tqdm(self.dataloader, desc=score_name, leave=False):
85
+ batch = move_to_device(batch, self.device)
86
+ image_batch, mask_batch = batch['image'], batch['mask']
87
+ if self.clamp_image_range is not None:
88
+ image_batch = torch.clamp(image_batch,
89
+ min=self.clamp_image_range[0],
90
+ max=self.clamp_image_range[1])
91
+ if model is None:
92
+ assert 'inpainted' in batch, \
93
+ 'Model is None, so we expected precomputed inpainting results at key "inpainted"'
94
+ inpainted_batch = batch['inpainted']
95
+ else:
96
+ inpainted_batch = model(image_batch, mask_batch)
97
+ score(inpainted_batch, image_batch, mask_batch)
98
+ total_results, group_results = score.get_value(groups=groups)
99
+
100
+ results[(score_name, 'total')] = total_results
101
+ if groups is not None:
102
+ for group_index, group_values in group_results.items():
103
+ group_name = interval_names[group_index]
104
+ results[(score_name, group_name)] = group_values
105
+
106
+ if self.integral_func is not None:
107
+ results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results))
108
+
109
+ return results
110
+
111
+
112
+ def ssim_fid100_f1(metrics, fid_scale=100):
113
+ ssim = metrics[('ssim', 'total')]['mean']
114
+ fid = metrics[('fid', 'total')]['mean']
115
+ fid_rel = max(0, fid_scale - fid) / fid_scale
116
+ f1 = 2 * ssim * fid_rel / (ssim + fid_rel + 1e-3)
117
+ return f1
118
+
119
+
120
+ def lpips_fid100_f1(metrics, fid_scale=100):
121
+ neg_lpips = 1 - metrics[('lpips', 'total')]['mean'] # invert, so bigger is better
122
+ fid = metrics[('fid', 'total')]['mean']
123
+ fid_rel = max(0, fid_scale - fid) / fid_scale
124
+ f1 = 2 * neg_lpips * fid_rel / (neg_lpips + fid_rel + 1e-3)
125
+ return f1
126
+
127
+
128
+
129
+ class InpaintingEvaluatorOnline(nn.Module):
130
+ def __init__(self, scores, bins=10, image_key='image', inpainted_key='inpainted',
131
+ integral_func=None, integral_title=None, clamp_image_range=None):
132
+ """
133
+ :param scores: dict {score_name: EvaluatorScore object}
134
+ :param bins: number of groups, partition is generated by np.linspace(0., 1., bins + 1)
135
+ :param device: device to use
136
+ """
137
+ super().__init__()
138
+ LOGGER.info(f'{type(self)} init called')
139
+ self.scores = nn.ModuleDict(scores)
140
+ self.image_key = image_key
141
+ self.inpainted_key = inpainted_key
142
+ self.bins_num = bins
143
+ self.bin_edges = np.linspace(0, 1, self.bins_num + 1)
144
+
145
+ num_digits = max(0, math.ceil(math.log10(self.bins_num)) - 1)
146
+ self.interval_names = []
147
+ for idx_bin in range(self.bins_num):
148
+ start_percent, end_percent = round(100 * self.bin_edges[idx_bin], num_digits), \
149
+ round(100 * self.bin_edges[idx_bin + 1], num_digits)
150
+ start_percent = '{:.{n}f}'.format(start_percent, n=num_digits)
151
+ end_percent = '{:.{n}f}'.format(end_percent, n=num_digits)
152
+ self.interval_names.append("{0}-{1}%".format(start_percent, end_percent))
153
+
154
+ self.groups = []
155
+
156
+ self.integral_func = integral_func
157
+ self.integral_title = integral_title
158
+ self.clamp_image_range = clamp_image_range
159
+
160
+ LOGGER.info(f'{type(self)} init done')
161
+
162
+ def _get_bins(self, mask_batch):
163
+ batch_size = mask_batch.shape[0]
164
+ area = mask_batch.view(batch_size, -1).mean(dim=-1).detach().cpu().numpy()
165
+ bin_indices = np.clip(np.searchsorted(self.bin_edges, area) - 1, 0, self.bins_num - 1)
166
+ return bin_indices
167
+
168
+ def forward(self, batch: Dict[str, torch.Tensor]):
169
+ """
170
+ Calculate and accumulate metrics for batch. To finalize evaluation and obtain final metrics, call evaluation_end
171
+ :param batch: batch dict with mandatory fields mask, image, inpainted (the latter two can be overridden via self.image_key / self.inpainted_key)
172
+ """
173
+ result = {}
174
+ with torch.no_grad():
175
+ image_batch, mask_batch, inpainted_batch = batch[self.image_key], batch['mask'], batch[self.inpainted_key]
176
+ if self.clamp_image_range is not None:
177
+ image_batch = torch.clamp(image_batch,
178
+ min=self.clamp_image_range[0],
179
+ max=self.clamp_image_range[1])
180
+ self.groups.extend(self._get_bins(mask_batch))
181
+
182
+ for score_name, score in self.scores.items():
183
+ result[score_name] = score(inpainted_batch, image_batch, mask_batch)
184
+ return result
185
+
186
+ def process_batch(self, batch: Dict[str, torch.Tensor]):
187
+ return self(batch)
188
+
189
+ def evaluation_end(self, states=None):
190
+ """:return: dict with (score_name, group_type) as keys, where group_type can be either 'overall' or
191
+ name of the particular group arranged by area of mask (e.g. '10-20%')
192
+ and score statistics for the group as values.
193
+ """
194
+ LOGGER.info(f'{type(self)}: evaluation_end called')
195
+
196
+ self.groups = np.array(self.groups)
197
+
198
+ results = {}
199
+ for score_name, score in self.scores.items():
200
+ LOGGER.info(f'Getting value of {score_name}')
201
+ cur_states = [s[score_name] for s in states] if states is not None else None
202
+ total_results, group_results = score.get_value(groups=self.groups, states=cur_states)
203
+ LOGGER.info(f'Getting value of {score_name} done')
204
+ results[(score_name, 'total')] = total_results
205
+
206
+ for group_index, group_values in group_results.items():
207
+ group_name = self.interval_names[group_index]
208
+ results[(score_name, group_name)] = group_values
209
+
210
+ if self.integral_func is not None:
211
+ results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results))
212
+
213
+ LOGGER.info(f'{type(self)}: reset scores')
214
+ self.groups = []
215
+ for sc in self.scores.values():
216
+ sc.reset()
217
+ LOGGER.info(f'{type(self)}: reset scores done')
218
+
219
+ LOGGER.info(f'{type(self)}: evaluation_end done')
220
+ return results
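A worked example of the combined metric above: with mean SSIM 0.90 and FID 10, fid_rel = (100 - 10) / 100 = 0.90, so the score is roughly their harmonic mean:

    metrics = {('ssim', 'total'): {'mean': 0.90},
               ('fid', 'total'): {'mean': 10.0}}
    print(ssim_fid100_f1(metrics))    # ~0.90, i.e. 2 * 0.9 * 0.9 / (0.9 + 0.9 + 1e-3)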
saicinpainting/evaluation/losses/__init__.py ADDED
File without changes
saicinpainting/evaluation/losses/base_loss.py ADDED
@@ -0,0 +1,528 @@
1
+ import logging
2
+ from abc import abstractmethod, ABC
3
+
4
+ import numpy as np
5
+ import sklearn
6
+ import sklearn.svm
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from joblib import Parallel, delayed
11
+ from scipy import linalg
12
+
13
+ from models.ade20k import SegmentationModule, NUM_CLASS, segm_options
14
+ from .fid.inception import InceptionV3
15
+ from .lpips import PerceptualLoss
16
+ from .ssim import SSIM
17
+
18
+ LOGGER = logging.getLogger(__name__)
19
+
20
+
21
+ def get_groupings(groups):
22
+ """
23
+ :param groups: group numbers for respective elements
24
+ :return: dict of kind {group_idx: indices of the corresponding group elements}
25
+ """
26
+ label_groups, count_groups = np.unique(groups, return_counts=True)
27
+
28
+ indices = np.argsort(groups)
29
+
30
+ grouping = dict()
31
+ cur_start = 0
32
+ for label, count in zip(label_groups, count_groups):
33
+ cur_end = cur_start + count
34
+ cur_indices = indices[cur_start:cur_end]
35
+ grouping[label] = cur_indices
36
+ cur_start = cur_end
37
+ return grouping
38
+
39
+
40
+ class EvaluatorScore(nn.Module):
41
+ @abstractmethod
42
+ def forward(self, pred_batch, target_batch, mask):
43
+ pass
44
+
45
+ @abstractmethod
46
+ def get_value(self, groups=None, states=None):
47
+ pass
48
+
49
+ @abstractmethod
50
+ def reset(self):
51
+ pass
52
+
53
+
54
+ class PairwiseScore(EvaluatorScore, ABC):
55
+ def __init__(self):
56
+ super().__init__()
57
+ self.individual_values = None
58
+
59
+ def get_value(self, groups=None, states=None):
60
+ """
61
+ :param groups:
62
+ :return:
63
+ total_results: dict of kind {'mean': score mean, 'std': score std}
64
+ group_results: None, if groups is None;
65
+ else dict {group_idx: {'mean': score mean among group, 'std': score std among group}}
66
+ """
67
+ individual_values = torch.cat(states, dim=-1).reshape(-1).cpu().numpy() if states is not None \
68
+ else self.individual_values
69
+
70
+ total_results = {
71
+ 'mean': individual_values.mean(),
72
+ 'std': individual_values.std()
73
+ }
74
+
75
+ if groups is None:
76
+ return total_results, None
77
+
78
+ group_results = dict()
79
+ grouping = get_groupings(groups)
80
+ for label, index in grouping.items():
81
+ group_scores = individual_values[index]
82
+ group_results[label] = {
83
+ 'mean': group_scores.mean(),
84
+ 'std': group_scores.std()
85
+ }
86
+ return total_results, group_results
87
+
88
+ def reset(self):
89
+ self.individual_values = []
90
+
91
+
92
+ class SSIMScore(PairwiseScore):
93
+ def __init__(self, window_size=11):
94
+ super().__init__()
95
+ self.score = SSIM(window_size=window_size, size_average=False).eval()
96
+ self.reset()
97
+
98
+ def forward(self, pred_batch, target_batch, mask=None):
99
+ batch_values = self.score(pred_batch, target_batch)
100
+ self.individual_values = np.hstack([
101
+ self.individual_values, batch_values.detach().cpu().numpy()
102
+ ])
103
+ return batch_values
104
+
105
+
106
+ class LPIPSScore(PairwiseScore):
107
+ def __init__(self, model='net-lin', net='vgg', model_path=None, use_gpu=True):
108
+ super().__init__()
109
+ self.score = PerceptualLoss(model=model, net=net, model_path=model_path,
110
+ use_gpu=use_gpu, spatial=False).eval()
111
+ self.reset()
112
+
113
+ def forward(self, pred_batch, target_batch, mask=None):
114
+ batch_values = self.score(pred_batch, target_batch).flatten()
115
+ self.individual_values = np.hstack([
116
+ self.individual_values, batch_values.detach().cpu().numpy()
117
+ ])
118
+ return batch_values
119
+
120
+
121
+ def fid_calculate_activation_statistics(act):
122
+ mu = np.mean(act, axis=0)
123
+ sigma = np.cov(act, rowvar=False)
124
+ return mu, sigma
125
+
126
+
127
+ def calculate_frechet_distance(activations_pred, activations_target, eps=1e-6):
128
+ mu1, sigma1 = fid_calculate_activation_statistics(activations_pred)
129
+ mu2, sigma2 = fid_calculate_activation_statistics(activations_target)
130
+
131
+ diff = mu1 - mu2
132
+
133
+ # Product might be almost singular
134
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
135
+ if not np.isfinite(covmean).all():
136
+ msg = ('fid calculation produces singular product; '
137
+ 'adding %s to diagonal of cov estimates') % eps
138
+ LOGGER.warning(msg)
139
+ offset = np.eye(sigma1.shape[0]) * eps
140
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
141
+
142
+ # Numerical error might give slight imaginary component
143
+ if np.iscomplexobj(covmean):
144
+ # if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
145
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-2):
146
+ m = np.max(np.abs(covmean.imag))
147
+ raise ValueError('Imaginary component {}'.format(m))
148
+ covmean = covmean.real
149
+
150
+ tr_covmean = np.trace(covmean)
151
+
152
+ return (diff.dot(diff) + np.trace(sigma1) +
153
+ np.trace(sigma2) - 2 * tr_covmean)
154
+
155
+
156
+ class FIDScore(EvaluatorScore):
157
+ def __init__(self, dims=2048, eps=1e-6):
158
+ LOGGER.info("FIDscore init called")
159
+ super().__init__()
160
+ if getattr(FIDScore, '_MODEL', None) is None:
161
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
162
+ FIDScore._MODEL = InceptionV3([block_idx]).eval()
163
+ self.model = FIDScore._MODEL
164
+ self.eps = eps
165
+ self.reset()
166
+ LOGGER.info("FIDscore init done")
167
+
168
+ def forward(self, pred_batch, target_batch, mask=None):
169
+ activations_pred = self._get_activations(pred_batch)
170
+ activations_target = self._get_activations(target_batch)
171
+
172
+ self.activations_pred.append(activations_pred.detach().cpu())
173
+ self.activations_target.append(activations_target.detach().cpu())
174
+
175
+ return activations_pred, activations_target
176
+
177
+ def get_value(self, groups=None, states=None):
178
+ LOGGER.info("FIDscore get_value called")
179
+ activations_pred, activations_target = zip(*states) if states is not None \
180
+ else (self.activations_pred, self.activations_target)
181
+ activations_pred = torch.cat(activations_pred).cpu().numpy()
182
+ activations_target = torch.cat(activations_target).cpu().numpy()
183
+
184
+ total_distance = calculate_frechet_distance(activations_pred, activations_target, eps=self.eps)
185
+ total_results = dict(mean=total_distance)
186
+
187
+ if groups is None:
188
+ group_results = None
189
+ else:
190
+ group_results = dict()
191
+ grouping = get_groupings(groups)
192
+ for label, index in grouping.items():
193
+ if len(index) > 1:
194
+ group_distance = calculate_frechet_distance(activations_pred[index], activations_target[index],
195
+ eps=self.eps)
196
+ group_results[label] = dict(mean=group_distance)
197
+
198
+ else:
199
+ group_results[label] = dict(mean=float('nan'))
200
+
201
+ self.reset()
202
+
203
+ LOGGER.info("FIDscore get_value done")
204
+
205
+ return total_results, group_results
206
+
207
+ def reset(self):
208
+ self.activations_pred = []
209
+ self.activations_target = []
210
+
211
+ def _get_activations(self, batch):
212
+ activations = self.model(batch)[0]
213
+ if activations.shape[2] != 1 or activations.shape[3] != 1:
214
+ assert False, \
215
+ 'We should not have got here, because Inception always scales inputs to 299x299'
216
+ # activations = F.adaptive_avg_pool2d(activations, output_size=(1, 1))
217
+ activations = activations.squeeze(-1).squeeze(-1)
218
+ return activations
219
+
220
+
221
+ class SegmentationAwareScore(EvaluatorScore):
222
+ def __init__(self, weights_path):
223
+ super().__init__()
224
+ self.segm_network = SegmentationModule(weights_path=weights_path, use_default_normalization=True).eval()
225
+ self.target_class_freq_by_image_total = []
226
+ self.target_class_freq_by_image_mask = []
227
+ self.pred_class_freq_by_image_mask = []
228
+
229
+ def forward(self, pred_batch, target_batch, mask):
230
+ pred_segm_flat = self.segm_network.predict(pred_batch)[0].view(pred_batch.shape[0], -1).long().detach().cpu().numpy()
231
+ target_segm_flat = self.segm_network.predict(target_batch)[0].view(pred_batch.shape[0], -1).long().detach().cpu().numpy()
232
+ mask_flat = (mask.view(mask.shape[0], -1) > 0.5).detach().cpu().numpy()
233
+
234
+ batch_target_class_freq_total = []
235
+ batch_target_class_freq_mask = []
236
+ batch_pred_class_freq_mask = []
237
+
238
+ for cur_pred_segm, cur_target_segm, cur_mask in zip(pred_segm_flat, target_segm_flat, mask_flat):
239
+ cur_target_class_freq_total = np.bincount(cur_target_segm, minlength=NUM_CLASS)[None, ...]
240
+ cur_target_class_freq_mask = np.bincount(cur_target_segm[cur_mask], minlength=NUM_CLASS)[None, ...]
241
+ cur_pred_class_freq_mask = np.bincount(cur_pred_segm[cur_mask], minlength=NUM_CLASS)[None, ...]
242
+
243
+ self.target_class_freq_by_image_total.append(cur_target_class_freq_total)
244
+ self.target_class_freq_by_image_mask.append(cur_target_class_freq_mask)
245
+ self.pred_class_freq_by_image_mask.append(cur_pred_class_freq_mask)
246
+
247
+ batch_target_class_freq_total.append(cur_target_class_freq_total)
248
+ batch_target_class_freq_mask.append(cur_target_class_freq_mask)
249
+ batch_pred_class_freq_mask.append(cur_pred_class_freq_mask)
250
+
251
+ batch_target_class_freq_total = np.concatenate(batch_target_class_freq_total, axis=0)
252
+ batch_target_class_freq_mask = np.concatenate(batch_target_class_freq_mask, axis=0)
253
+ batch_pred_class_freq_mask = np.concatenate(batch_pred_class_freq_mask, axis=0)
254
+ return batch_target_class_freq_total, batch_target_class_freq_mask, batch_pred_class_freq_mask
255
+
256
+ def reset(self):
257
+ super().reset()
258
+ self.target_class_freq_by_image_total = []
259
+ self.target_class_freq_by_image_mask = []
260
+ self.pred_class_freq_by_image_mask = []
261
+
262
+
263
+ def distribute_values_to_classes(target_class_freq_by_image_mask, values, idx2name):
264
+ assert target_class_freq_by_image_mask.ndim == 2 and target_class_freq_by_image_mask.shape[0] == values.shape[0]
265
+ total_class_freq = target_class_freq_by_image_mask.sum(0)
266
+ distr_values = (target_class_freq_by_image_mask * values[..., None]).sum(0)
267
+ result = distr_values / (total_class_freq + 1e-3)
268
+ return {idx2name[i]: val for i, val in enumerate(result) if total_class_freq[i] > 0}
269
+
270
+
271
+ def get_segmentation_idx2name():
272
+ return {i - 1: name for i, name in segm_options['classes'].set_index('Idx', drop=True)['Name'].to_dict().items()}
273
+
274
+
275
+ class SegmentationAwarePairwiseScore(SegmentationAwareScore):
276
+ def __init__(self, *args, **kwargs):
277
+ super().__init__(*args, **kwargs)
278
+ self.individual_values = []
279
+ self.segm_idx2name = get_segmentation_idx2name()
280
+
281
+ def forward(self, pred_batch, target_batch, mask):
282
+ cur_class_stats = super().forward(pred_batch, target_batch, mask)
283
+ score_values = self.calc_score(pred_batch, target_batch, mask)
284
+ self.individual_values.append(score_values)
285
+ return cur_class_stats + (score_values,)
286
+
287
+ @abstractmethod
288
+ def calc_score(self, pred_batch, target_batch, mask):
289
+ raise NotImplementedError()
290
+
291
+ def get_value(self, groups=None, states=None):
292
+ """
293
+ :param groups:
294
+ :return:
295
+ total_results: dict of kind {'mean': score mean, 'std': score std}
296
+ group_results: None, if groups is None;
297
+ else dict {group_idx: {'mean': score mean among group, 'std': score std among group}}
298
+ """
299
+ if states is not None:
300
+ (target_class_freq_by_image_total,
301
+ target_class_freq_by_image_mask,
302
+ pred_class_freq_by_image_mask,
303
+ individual_values) = states
304
+ else:
305
+ target_class_freq_by_image_total = self.target_class_freq_by_image_total
306
+ target_class_freq_by_image_mask = self.target_class_freq_by_image_mask
307
+ pred_class_freq_by_image_mask = self.pred_class_freq_by_image_mask
308
+ individual_values = self.individual_values
309
+
310
+ target_class_freq_by_image_total = np.concatenate(target_class_freq_by_image_total, axis=0)
311
+ target_class_freq_by_image_mask = np.concatenate(target_class_freq_by_image_mask, axis=0)
312
+ pred_class_freq_by_image_mask = np.concatenate(pred_class_freq_by_image_mask, axis=0)
313
+ individual_values = np.concatenate(individual_values, axis=0)
314
+
315
+ total_results = {
316
+ 'mean': individual_values.mean(),
317
+ 'std': individual_values.std(),
318
+ **distribute_values_to_classes(target_class_freq_by_image_mask, individual_values, self.segm_idx2name)
319
+ }
320
+
321
+ if groups is None:
322
+ return total_results, None
323
+
324
+ group_results = dict()
325
+ grouping = get_groupings(groups)
326
+ for label, index in grouping.items():
327
+ group_class_freq = target_class_freq_by_image_mask[index]
328
+ group_scores = individual_values[index]
329
+ group_results[label] = {
330
+ 'mean': group_scores.mean(),
331
+ 'std': group_scores.std(),
332
+ ** distribute_values_to_classes(group_class_freq, group_scores, self.segm_idx2name)
333
+ }
334
+ return total_results, group_results
335
+
336
+ def reset(self):
337
+ super().reset()
338
+ self.individual_values = []
339
+
340
+
341
+ class SegmentationClassStats(SegmentationAwarePairwiseScore):
342
+ def calc_score(self, pred_batch, target_batch, mask):
343
+ return 0
344
+
345
+ def get_value(self, groups=None, states=None):
346
+ """
347
+ :param groups:
348
+ :return:
349
+ total_results: dict of kind {'mean': score mean, 'std': score std}
350
+ group_results: None, if groups is None;
351
+ else dict {group_idx: {'mean': score mean among group, 'std': score std among group}}
352
+ """
353
+ if states is not None:
354
+ (target_class_freq_by_image_total,
355
+ target_class_freq_by_image_mask,
356
+ pred_class_freq_by_image_mask,
357
+ _) = states
358
+ else:
359
+ target_class_freq_by_image_total = self.target_class_freq_by_image_total
360
+ target_class_freq_by_image_mask = self.target_class_freq_by_image_mask
361
+ pred_class_freq_by_image_mask = self.pred_class_freq_by_image_mask
362
+
363
+ target_class_freq_by_image_total = np.concatenate(target_class_freq_by_image_total, axis=0)
364
+ target_class_freq_by_image_mask = np.concatenate(target_class_freq_by_image_mask, axis=0)
365
+ pred_class_freq_by_image_mask = np.concatenate(pred_class_freq_by_image_mask, axis=0)
366
+
367
+ target_class_freq_by_image_total_marginal = target_class_freq_by_image_total.sum(0).astype('float32')
368
+ target_class_freq_by_image_total_marginal /= target_class_freq_by_image_total_marginal.sum()
369
+
370
+ target_class_freq_by_image_mask_marginal = target_class_freq_by_image_mask.sum(0).astype('float32')
371
+ target_class_freq_by_image_mask_marginal /= target_class_freq_by_image_mask_marginal.sum()
372
+
373
+ pred_class_freq_diff = (pred_class_freq_by_image_mask - target_class_freq_by_image_mask).sum(0) / (target_class_freq_by_image_mask.sum(0) + 1e-3)
374
+
375
+ total_results = dict()
376
+ total_results.update({f'total_freq/{self.segm_idx2name[i]}': v
377
+ for i, v in enumerate(target_class_freq_by_image_total_marginal)
378
+ if v > 0})
379
+ total_results.update({f'mask_freq/{self.segm_idx2name[i]}': v
380
+ for i, v in enumerate(target_class_freq_by_image_mask_marginal)
381
+ if v > 0})
382
+ total_results.update({f'mask_freq_diff/{self.segm_idx2name[i]}': v
383
+ for i, v in enumerate(pred_class_freq_diff)
384
+ if target_class_freq_by_image_total_marginal[i] > 0})
385
+
386
+ if groups is None:
387
+ return total_results, None
388
+
389
+ group_results = dict()
390
+ grouping = get_groupings(groups)
391
+ for label, index in grouping.items():
392
+ group_target_class_freq_by_image_total = target_class_freq_by_image_total[index]
393
+ group_target_class_freq_by_image_mask = target_class_freq_by_image_mask[index]
394
+ group_pred_class_freq_by_image_mask = pred_class_freq_by_image_mask[index]
395
+
396
+ group_target_class_freq_by_image_total_marginal = group_target_class_freq_by_image_total.sum(0).astype('float32')
397
+ group_target_class_freq_by_image_total_marginal /= group_target_class_freq_by_image_total_marginal.sum()
398
+
399
+ group_target_class_freq_by_image_mask_marginal = group_target_class_freq_by_image_mask.sum(0).astype('float32')
400
+ group_target_class_freq_by_image_mask_marginal /= group_target_class_freq_by_image_mask_marginal.sum()
401
+
402
+ group_pred_class_freq_diff = (group_pred_class_freq_by_image_mask - group_target_class_freq_by_image_mask).sum(0) / (
403
+ group_target_class_freq_by_image_mask.sum(0) + 1e-3)
404
+
405
+ cur_group_results = dict()
406
+ cur_group_results.update({f'total_freq/{self.segm_idx2name[i]}': v
407
+ for i, v in enumerate(group_target_class_freq_by_image_total_marginal)
408
+ if v > 0})
409
+ cur_group_results.update({f'mask_freq/{self.segm_idx2name[i]}': v
410
+ for i, v in enumerate(group_target_class_freq_by_image_mask_marginal)
411
+ if v > 0})
412
+ cur_group_results.update({f'mask_freq_diff/{self.segm_idx2name[i]}': v
413
+ for i, v in enumerate(group_pred_class_freq_diff)
414
+ if group_target_class_freq_by_image_total_marginal[i] > 0})
415
+
416
+ group_results[label] = cur_group_results
417
+ return total_results, group_results
418
+
419
+
420
+ class SegmentationAwareSSIM(SegmentationAwarePairwiseScore):
421
+ def __init__(self, *args, window_size=11, **kwargs):
422
+ super().__init__(*args, **kwargs)
423
+ self.score_impl = SSIM(window_size=window_size, size_average=False).eval()
424
+
425
+ def calc_score(self, pred_batch, target_batch, mask):
426
+ return self.score_impl(pred_batch, target_batch).detach().cpu().numpy()
427
+
428
+
429
+ class SegmentationAwareLPIPS(SegmentationAwarePairwiseScore):
430
+ def __init__(self, *args, model='net-lin', net='vgg', model_path=None, use_gpu=True, **kwargs):
431
+ super().__init__(*args, **kwargs)
432
+ self.score_impl = PerceptualLoss(model=model, net=net, model_path=model_path,
433
+ use_gpu=use_gpu, spatial=False).eval()
434
+
435
+ def calc_score(self, pred_batch, target_batch, mask):
436
+ return self.score_impl(pred_batch, target_batch).flatten().detach().cpu().numpy()
437
+
438
+
439
+ def calculade_fid_no_img(img_i, activations_pred, activations_target, eps=1e-6):
440
+ activations_pred = activations_pred.copy()
441
+ activations_pred[img_i] = activations_target[img_i]
442
+ return calculate_frechet_distance(activations_pred, activations_target, eps=eps)
443
+
444
+
445
+ class SegmentationAwareFID(SegmentationAwarePairwiseScore):
446
+ def __init__(self, *args, dims=2048, eps=1e-6, n_jobs=-1, **kwargs):
447
+ super().__init__(*args, **kwargs)
448
+ if getattr(FIDScore, '_MODEL', None) is None:
449
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
450
+ FIDScore._MODEL = InceptionV3([block_idx]).eval()
451
+ self.model = FIDScore._MODEL
452
+ self.eps = eps
453
+ self.n_jobs = n_jobs
454
+
455
+ def calc_score(self, pred_batch, target_batch, mask):
456
+ activations_pred = self._get_activations(pred_batch)
457
+ activations_target = self._get_activations(target_batch)
458
+ return activations_pred, activations_target
459
+
460
+ def get_value(self, groups=None, states=None):
461
+ """
462
+ :param groups:
463
+ :return:
464
+ total_results: dict of kind {'mean': score mean, 'std': score std}
465
+ group_results: None, if groups is None;
466
+ else dict {group_idx: {'mean': score mean among group, 'std': score std among group}}
467
+ """
468
+ if states is not None:
469
+ (target_class_freq_by_image_total,
470
+ target_class_freq_by_image_mask,
471
+ pred_class_freq_by_image_mask,
472
+ activation_pairs) = states
473
+ else:
474
+ target_class_freq_by_image_total = self.target_class_freq_by_image_total
475
+ target_class_freq_by_image_mask = self.target_class_freq_by_image_mask
476
+ pred_class_freq_by_image_mask = self.pred_class_freq_by_image_mask
477
+ activation_pairs = self.individual_values
478
+
479
+ target_class_freq_by_image_total = np.concatenate(target_class_freq_by_image_total, axis=0)
480
+ target_class_freq_by_image_mask = np.concatenate(target_class_freq_by_image_mask, axis=0)
481
+ pred_class_freq_by_image_mask = np.concatenate(pred_class_freq_by_image_mask, axis=0)
482
+ activations_pred, activations_target = zip(*activation_pairs)
483
+ activations_pred = np.concatenate(activations_pred, axis=0)
484
+ activations_target = np.concatenate(activations_target, axis=0)
485
+
486
+ total_results = {
487
+ 'mean': calculate_frechet_distance(activations_pred, activations_target, eps=self.eps),
488
+ 'std': 0,
489
+ **self.distribute_fid_to_classes(target_class_freq_by_image_mask, activations_pred, activations_target)
490
+ }
491
+
492
+ if groups is None:
493
+ return total_results, None
494
+
495
+ group_results = dict()
496
+ grouping = get_groupings(groups)
497
+ for label, index in grouping.items():
498
+ if len(index) > 1:
499
+ group_activations_pred = activations_pred[index]
500
+ group_activations_target = activations_target[index]
501
+ group_class_freq = target_class_freq_by_image_mask[index]
502
+ group_results[label] = {
503
+ 'mean': calculate_frechet_distance(group_activations_pred, group_activations_target, eps=self.eps),
504
+ 'std': 0,
505
+ **self.distribute_fid_to_classes(group_class_freq,
506
+ group_activations_pred,
507
+ group_activations_target)
508
+ }
509
+ else:
510
+ group_results[label] = dict(mean=float('nan'), std=0)
511
+ return total_results, group_results
512
+
513
+ def distribute_fid_to_classes(self, class_freq, activations_pred, activations_target):
514
+ real_fid = calculate_frechet_distance(activations_pred, activations_target, eps=self.eps)
515
+
516
+ fid_no_images = Parallel(n_jobs=self.n_jobs)(
517
+ delayed(calculade_fid_no_img)(img_i, activations_pred, activations_target, eps=self.eps)
518
+ for img_i in range(activations_pred.shape[0])
519
+ )
520
+ errors = real_fid - fid_no_images
521
+ return distribute_values_to_classes(class_freq, errors, self.segm_idx2name)
522
+
523
+ def _get_activations(self, batch):
524
+ activations = self.model(batch)[0]
525
+ if activations.shape[2] != 1 or activations.shape[3] != 1:
526
+ activations = F.adaptive_avg_pool2d(activations, output_size=(1, 1))
527
+ activations = activations.squeeze(-1).squeeze(-1).detach().cpu().numpy()
528
+ return activations
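# Note: a minimal, self-contained sketch of the leave-one-out attribution that
# distribute_fid_to_classes performs above. The tiny random arrays and the final
# class-weighting step are stand-ins: the repository's distribute_values_to_classes
# helper is not shown here, so treat the weighting as an assumption about the idea,
# not the actual implementation.
import numpy as np
from scipy import linalg


def frechet_distance_from_activations(act1, act2):
    # Stand-in for the calculate_frechet_distance used in base_loss.py,
    # which works directly on activation matrices.
    mu1, mu2 = act1.mean(0), act2.mean(0)
    sigma1 = np.cov(act1, rowvar=False)
    sigma2 = np.cov(act2, rowvar=False)
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    diff = mu1 - mu2
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)


rng = np.random.default_rng(0)
n_images, dims, n_classes = 32, 8, 3                 # toy sizes; the real dims is 2048
acts_pred = rng.normal(size=(n_images, dims))
acts_target = rng.normal(size=(n_images, dims))
class_freq = rng.random(size=(n_images, n_classes))  # per-image mask class frequencies

real_fid = frechet_distance_from_activations(acts_pred, acts_target)

# Leave one image out at a time: replace its predicted activations with the target ones
# (what calculade_fid_no_img does) and record how much the FID changes.
errors = np.empty(n_images)
for img_i in range(n_images):
    patched = acts_pred.copy()
    patched[img_i] = acts_target[img_i]
    errors[img_i] = real_fid - frechet_distance_from_activations(patched, acts_target)

# Credit each image's error to the classes present in its mask (approximate weighting).
weights = class_freq / (class_freq.sum(0, keepdims=True) + 1e-9)
per_class_error = (weights * errors[:, None]).sum(0)
print(per_class_error)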
saicinpainting/evaluation/losses/fid/__init__.py ADDED
File without changes
saicinpainting/evaluation/losses/fid/fid_score.py ADDED
@@ -0,0 +1,328 @@
1
+ #!/usr/bin/env python3
2
+ """Calculates the Frechet Inception Distance (FID) to evalulate GANs
3
+
4
+ The FID metric calculates the distance between two distributions of images.
5
+ Typically, we have summary statistics (mean & covariance matrix) of one
6
+ of these distributions, while the 2nd distribution is given by a GAN.
7
+
8
+ When run as a stand-alone program, it compares the distribution of
9
+ images that are stored as PNG/JPEG at a specified location with a
10
+ distribution given by summary statistics (in pickle format).
11
+
12
+ The FID is calculated by assuming that X_1 and X_2 are the activations of
13
+ the pool_3 layer of the inception net for generated samples and real world
14
+ samples respectively.
15
+
16
+ See --help to see further details.
17
+
18
+ Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
19
+ of Tensorflow
20
+
21
+ Copyright 2018 Institute of Bioinformatics, JKU Linz
22
+
23
+ Licensed under the Apache License, Version 2.0 (the "License");
24
+ you may not use this file except in compliance with the License.
25
+ You may obtain a copy of the License at
26
+
27
+ http://www.apache.org/licenses/LICENSE-2.0
28
+
29
+ Unless required by applicable law or agreed to in writing, software
30
+ distributed under the License is distributed on an "AS IS" BASIS,
31
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32
+ See the License for the specific language governing permissions and
33
+ limitations under the License.
34
+ """
35
+ import os
36
+ import pathlib
37
+ from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
38
+
39
+ import numpy as np
40
+ import torch
41
+ # from scipy.misc import imread
42
+ from imageio import imread
43
+ from PIL import Image, JpegImagePlugin
44
+ from scipy import linalg
45
+ from torch.nn.functional import adaptive_avg_pool2d
46
+ from torchvision.transforms import CenterCrop, Compose, Resize, ToTensor
47
+
48
+ try:
49
+ from tqdm import tqdm
50
+ except ImportError:
51
+ # If tqdm is not available, provide a mock version of it
52
+ def tqdm(x): return x
53
+
54
+ try:
55
+ from .inception import InceptionV3
56
+ except ModuleNotFoundError:
57
+ from inception import InceptionV3
58
+
59
+ parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
60
+ parser.add_argument('path', type=str, nargs=2,
61
+ help=('Path to the generated images or '
62
+ 'to .npz statistic files'))
63
+ parser.add_argument('--batch-size', type=int, default=50,
64
+ help='Batch size to use')
65
+ parser.add_argument('--dims', type=int, default=2048,
66
+ choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
67
+ help=('Dimensionality of Inception features to use. '
68
+ 'By default, uses pool3 features'))
69
+ parser.add_argument('-c', '--gpu', default='', type=str,
70
+ help='GPU to use (leave blank for CPU only)')
71
+ parser.add_argument('--resize', default=256)
72
+
73
+ transform = Compose([Resize(256), CenterCrop(256), ToTensor()])
74
+
75
+
76
+ def get_activations(files, model, batch_size=50, dims=2048,
77
+ cuda=False, verbose=False, keep_size=False):
78
+ """Calculates the activations of the pool_3 layer for all images.
79
+
80
+ Params:
81
+ -- files : List of image files paths
82
+ -- model : Instance of inception model
83
+ -- batch_size : Batch size of images for the model to process at once.
84
+ Make sure that the number of samples is a multiple of
85
+ the batch size, otherwise some samples are ignored. This
86
+ behavior is retained to match the original FID score
87
+ implementation.
88
+ -- dims : Dimensionality of features returned by Inception
89
+ -- cuda : If set to True, use GPU
90
+ -- verbose : If set to True and parameter out_step is given, the number
91
+ of calculated batches is reported.
92
+ Returns:
93
+ -- A numpy array of dimension (num images, dims) that contains the
94
+ activations of the given tensor when feeding inception with the
95
+ query tensor.
96
+ """
97
+ model.eval()
98
+
99
+ if len(files) % batch_size != 0:
100
+ print(('Warning: number of images is not a multiple of the '
101
+ 'batch size. Some samples are going to be ignored.'))
102
+ if batch_size > len(files):
103
+ print(('Warning: batch size is bigger than the data size. '
104
+ 'Setting batch size to data size'))
105
+ batch_size = len(files)
106
+
107
+ n_batches = len(files) // batch_size
108
+ n_used_imgs = n_batches * batch_size
109
+
110
+ pred_arr = np.empty((n_used_imgs, dims))
111
+
112
+ for i in tqdm(range(n_batches)):
113
+ if verbose:
114
+ print('\rPropagating batch %d/%d' % (i + 1, n_batches),
115
+ end='', flush=True)
116
+ start = i * batch_size
117
+ end = start + batch_size
118
+
119
+ # # Official code goes below
120
+ # images = np.array([imread(str(f)).astype(np.float32)
121
+ # for f in files[start:end]])
122
+
123
+ # # Reshape to (n_images, 3, height, width)
124
+ # images = images.transpose((0, 3, 1, 2))
125
+ # images /= 255
126
+ # batch = torch.from_numpy(images).type(torch.FloatTensor)
127
+ # #
128
+
129
+ t = transform if not keep_size else ToTensor()
130
+
131
+ if isinstance(files[0], pathlib.Path):  # accept any pathlib path, not just PosixPath
132
+ images = [t(Image.open(str(f))) for f in files[start:end]]
133
+
134
+ elif isinstance(files[0], Image.Image):
135
+ images = [t(f) for f in files[start:end]]
136
+
137
+ else:
138
+ raise ValueError(f"Unknown data type for image: {type(files[0])}")
139
+
140
+ batch = torch.stack(images)
141
+
142
+ if cuda:
143
+ batch = batch.cuda()
144
+
145
+ pred = model(batch)[0]
146
+
147
+ # If model output is not scalar, apply global spatial average pooling.
148
+ # This happens if you choose a dimensionality not equal 2048.
149
+ if pred.shape[2] != 1 or pred.shape[3] != 1:
150
+ pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
151
+
152
+ pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1)
153
+
154
+ if verbose:
155
+ print(' done')
156
+
157
+ return pred_arr
158
+
159
+
160
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
161
+ """Numpy implementation of the Frechet Distance.
162
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
163
+ and X_2 ~ N(mu_2, C_2) is
164
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
165
+
166
+ Stable version by Dougal J. Sutherland.
167
+
168
+ Params:
169
+ -- mu1 : The sample mean over activations of the pool_3 layer of the
170
+ inception net for generated samples (i.e. the first value returned
171
+ by 'calculate_activation_statistics').
172
+ -- mu2 : The sample mean over activations, precalculated on a
173
+ representative data set.
174
+ -- sigma1: The covariance matrix over activations for generated samples.
175
+ -- sigma2: The covariance matrix over activations, precalculated on a
176
+ representative data set.
177
+
178
+ Returns:
179
+ -- : The Frechet Distance.
180
+ """
181
+
182
+ mu1 = np.atleast_1d(mu1)
183
+ mu2 = np.atleast_1d(mu2)
184
+
185
+ sigma1 = np.atleast_2d(sigma1)
186
+ sigma2 = np.atleast_2d(sigma2)
187
+
188
+ assert mu1.shape == mu2.shape, \
189
+ 'Training and test mean vectors have different lengths'
190
+ assert sigma1.shape == sigma2.shape, \
191
+ 'Training and test covariances have different dimensions'
192
+
193
+ diff = mu1 - mu2
194
+
195
+ # Product might be almost singular
196
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
197
+ if not np.isfinite(covmean).all():
198
+ msg = ('fid calculation produces singular product; '
199
+ 'adding %s to diagonal of cov estimates') % eps
200
+ print(msg)
201
+ offset = np.eye(sigma1.shape[0]) * eps
202
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
203
+
204
+ # Numerical error might give slight imaginary component
205
+ if np.iscomplexobj(covmean):
206
+ # if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
207
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-2):
208
+ m = np.max(np.abs(covmean.imag))
209
+ raise ValueError('Imaginary component {}'.format(m))
210
+ covmean = covmean.real
211
+
212
+ tr_covmean = np.trace(covmean)
213
+
214
+ return (diff.dot(diff) + np.trace(sigma1) +
215
+ np.trace(sigma2) - 2 * tr_covmean)
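# Sanity check (illustrative only): for one-dimensional Gaussians the closed form above
# reduces to (mu_1 - mu_2)^2 + (sigma_1 - sigma_2)^2. This calls the
# calculate_frechet_distance defined just above.
import numpy as np

mu1, sigma1 = np.array([0.0]), np.array([[4.0]])    # N(0, 2^2)
mu2, sigma2 = np.array([1.0]), np.array([[9.0]])    # N(1, 3^2)
print(calculate_frechet_distance(mu1, sigma1, mu2, sigma2))   # 1 + (2 - 3)^2 = 2.0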
216
+
217
+
218
+ def calculate_activation_statistics(files, model, batch_size=50,
219
+ dims=2048, cuda=False, verbose=False, keep_size=False):
220
+ """Calculation of the statistics used by the FID.
221
+ Params:
222
+ -- files : List of image files paths
223
+ -- model : Instance of inception model
224
+ -- batch_size : The images numpy array is split into batches with
225
+ batch size batch_size. A reasonable batch size
226
+ depends on the hardware.
227
+ -- dims : Dimensionality of features returned by Inception
228
+ -- cuda : If set to True, use GPU
229
+ -- verbose : If set to True and parameter out_step is given, the
230
+ number of calculated batches is reported.
231
+ Returns:
232
+ -- mu : The mean over samples of the activations of the pool_3 layer of
233
+ the inception model.
234
+ -- sigma : The covariance matrix of the activations of the pool_3 layer of
235
+ the inception model.
236
+ """
237
+ act = get_activations(files, model, batch_size, dims, cuda, verbose, keep_size=keep_size)
238
+ mu = np.mean(act, axis=0)
239
+ sigma = np.cov(act, rowvar=False)
240
+ return mu, sigma
241
+
242
+
243
+ def _compute_statistics_of_path(path, model, batch_size, dims, cuda):
244
+ if path.endswith('.npz'):
245
+ f = np.load(path)
246
+ m, s = f['mu'][:], f['sigma'][:]
247
+ f.close()
248
+ else:
249
+ path = pathlib.Path(path)
250
+ files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
251
+ m, s = calculate_activation_statistics(files, model, batch_size,
252
+ dims, cuda)
253
+
254
+ return m, s
255
+
256
+
257
+ def _compute_statistics_of_images(images, model, batch_size, dims, cuda, keep_size=False):
258
+ if isinstance(images, list): # exact paths to files are provided
259
+ m, s = calculate_activation_statistics(images, model, batch_size,
260
+ dims, cuda, keep_size=keep_size)
261
+
262
+ return m, s
263
+
264
+ else:
265
+ raise ValueError
266
+
267
+
268
+ def calculate_fid_given_paths(paths, batch_size, cuda, dims):
269
+ """Calculates the FID of two paths"""
270
+ for p in paths:
271
+ if not os.path.exists(p):
272
+ raise RuntimeError('Invalid path: %s' % p)
273
+
274
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
275
+
276
+ model = InceptionV3([block_idx])
277
+ if cuda:
278
+ model.cuda()
279
+
280
+ m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size,
281
+ dims, cuda)
282
+ m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size,
283
+ dims, cuda)
284
+ fid_value = calculate_frechet_distance(m1, s1, m2, s2)
285
+
286
+ return fid_value
287
+
288
+
289
+ def calculate_fid_given_images(images, batch_size, cuda, dims, use_globals=False, keep_size=False):
290
+ if use_globals:
291
+ global FID_MODEL # for multiprocessing
292
+
293
+ for imgs in images:
294
+ if isinstance(imgs, list) and isinstance(imgs[0], (Image.Image, JpegImagePlugin.JpegImageFile)):
295
+ pass
296
+ else:
297
+ raise RuntimeError('Invalid images')
298
+
299
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
300
+
301
+ if 'FID_MODEL' not in globals() or not use_globals:
302
+ model = InceptionV3([block_idx])
303
+ if cuda:
304
+ model.cuda()
305
+
306
+ if use_globals:
307
+ FID_MODEL = model
308
+
309
+ else:
310
+ model = FID_MODEL
311
+
312
+ m1, s1 = _compute_statistics_of_images(images[0], model, batch_size,
313
+ dims, cuda, keep_size=False)
314
+ m2, s2 = _compute_statistics_of_images(images[1], model, batch_size,
315
+ dims, cuda, keep_size=False)
316
+ fid_value = calculate_frechet_distance(m1, s1, m2, s2)
317
+ return fid_value
318
+
319
+
320
+ if __name__ == '__main__':
321
+ args = parser.parse_args()
322
+ os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
323
+
324
+ fid_value = calculate_fid_given_paths(args.path,
325
+ args.batch_size,
326
+ args.gpu != '',
327
+ args.dims)
328
+ print('FID: ', fid_value)
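# Usage sketch (hedged): how the two public helpers above might be called from other
# code. The directory and file names are placeholders, and FID is only meaningful with
# many images; cuda=False keeps the sketch CPU-only.
from PIL import Image

from saicinpainting.evaluation.losses.fid.fid_score import (
    calculate_fid_given_images,
    calculate_fid_given_paths,
)

# 1) Two folders of *.jpg / *.png images (paths are placeholders).
fid_from_dirs = calculate_fid_given_paths(
    ['results/fake_samples', 'results/real_samples'],
    batch_size=50,
    cuda=False,
    dims=2048,            # pool3 features, the standard FID setting
)

# 2) Two equally long lists of PIL images already in memory (file names are placeholders).
fakes = [Image.open(p) for p in ['fake_0.png', 'fake_1.png']]
reals = [Image.open(p) for p in ['real_0.png', 'real_1.png']]
fid_from_images = calculate_fid_given_images([fakes, reals], batch_size=2, cuda=False, dims=2048)

print(fid_from_dirs, fid_from_images)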
saicinpainting/evaluation/losses/fid/inception.py ADDED
@@ -0,0 +1,323 @@
1
+ import logging
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torchvision import models
7
+
8
+ try:
9
+ from torchvision.models.utils import load_state_dict_from_url
10
+ except ImportError:
11
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
12
+
13
+ # Inception weights ported to Pytorch from
14
+ # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
15
+ FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
16
+
17
+
18
+ LOGGER = logging.getLogger(__name__)
19
+
20
+
21
+ class InceptionV3(nn.Module):
22
+ """Pretrained InceptionV3 network returning feature maps"""
23
+
24
+ # Index of default block of inception to return,
25
+ # corresponds to output of final average pooling
26
+ DEFAULT_BLOCK_INDEX = 3
27
+
28
+ # Maps feature dimensionality to their output blocks indices
29
+ BLOCK_INDEX_BY_DIM = {
30
+ 64: 0, # First max pooling features
31
+ 192: 1, # Second max pooling features
32
+ 768: 2, # Pre-aux classifier features
33
+ 2048: 3 # Final average pooling features
34
+ }
35
+
36
+ def __init__(self,
37
+ output_blocks=[DEFAULT_BLOCK_INDEX],
38
+ resize_input=True,
39
+ normalize_input=True,
40
+ requires_grad=False,
41
+ use_fid_inception=True):
42
+ """Build pretrained InceptionV3
43
+
44
+ Parameters
45
+ ----------
46
+ output_blocks : list of int
47
+ Indices of blocks to return features of. Possible values are:
48
+ - 0: corresponds to output of first max pooling
49
+ - 1: corresponds to output of second max pooling
50
+ - 2: corresponds to output which is fed to aux classifier
51
+ - 3: corresponds to output of final average pooling
52
+ resize_input : bool
53
+ If true, bilinearly resizes input to width and height 299 before
54
+ feeding input to model. As the network without fully connected
55
+ layers is fully convolutional, it should be able to handle inputs
56
+ of arbitrary size, so resizing might not be strictly needed
57
+ normalize_input : bool
58
+ If true, scales the input from range (0, 1) to the range the
59
+ pretrained Inception network expects, namely (-1, 1)
60
+ requires_grad : bool
61
+ If true, parameters of the model require gradients. Possibly useful
62
+ for finetuning the network
63
+ use_fid_inception : bool
64
+ If true, uses the pretrained Inception model used in Tensorflow's
65
+ FID implementation. If false, uses the pretrained Inception model
66
+ available in torchvision. The FID Inception model has different
67
+ weights and a slightly different structure from torchvision's
68
+ Inception model. If you want to compute FID scores, you are
69
+ strongly advised to set this parameter to true to get comparable
70
+ results.
71
+ """
72
+ super(InceptionV3, self).__init__()
73
+
74
+ self.resize_input = resize_input
75
+ self.normalize_input = normalize_input
76
+ self.output_blocks = sorted(output_blocks)
77
+ self.last_needed_block = max(output_blocks)
78
+
79
+ assert self.last_needed_block <= 3, \
80
+ 'Last possible output block index is 3'
81
+
82
+ self.blocks = nn.ModuleList()
83
+
84
+ if use_fid_inception:
85
+ inception = fid_inception_v3()
86
+ else:
87
+ inception = models.inception_v3(pretrained=True)
88
+
89
+ # Block 0: input to maxpool1
90
+ block0 = [
91
+ inception.Conv2d_1a_3x3,
92
+ inception.Conv2d_2a_3x3,
93
+ inception.Conv2d_2b_3x3,
94
+ nn.MaxPool2d(kernel_size=3, stride=2)
95
+ ]
96
+ self.blocks.append(nn.Sequential(*block0))
97
+
98
+ # Block 1: maxpool1 to maxpool2
99
+ if self.last_needed_block >= 1:
100
+ block1 = [
101
+ inception.Conv2d_3b_1x1,
102
+ inception.Conv2d_4a_3x3,
103
+ nn.MaxPool2d(kernel_size=3, stride=2)
104
+ ]
105
+ self.blocks.append(nn.Sequential(*block1))
106
+
107
+ # Block 2: maxpool2 to aux classifier
108
+ if self.last_needed_block >= 2:
109
+ block2 = [
110
+ inception.Mixed_5b,
111
+ inception.Mixed_5c,
112
+ inception.Mixed_5d,
113
+ inception.Mixed_6a,
114
+ inception.Mixed_6b,
115
+ inception.Mixed_6c,
116
+ inception.Mixed_6d,
117
+ inception.Mixed_6e,
118
+ ]
119
+ self.blocks.append(nn.Sequential(*block2))
120
+
121
+ # Block 3: aux classifier to final avgpool
122
+ if self.last_needed_block >= 3:
123
+ block3 = [
124
+ inception.Mixed_7a,
125
+ inception.Mixed_7b,
126
+ inception.Mixed_7c,
127
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
128
+ ]
129
+ self.blocks.append(nn.Sequential(*block3))
130
+
131
+ for param in self.parameters():
132
+ param.requires_grad = requires_grad
133
+
134
+ def forward(self, inp):
135
+ """Get Inception feature maps
136
+
137
+ Parameters
138
+ ----------
139
+ inp : torch.autograd.Variable
140
+ Input tensor of shape Bx3xHxW. Values are expected to be in
141
+ range (0, 1)
142
+
143
+ Returns
144
+ -------
145
+ List of torch.autograd.Variable, corresponding to the selected output
146
+ block, sorted ascending by index
147
+ """
148
+ outp = []
149
+ x = inp
150
+
151
+ if self.resize_input:
152
+ x = F.interpolate(x,
153
+ size=(299, 299),
154
+ mode='bilinear',
155
+ align_corners=False)
156
+
157
+ if self.normalize_input:
158
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
159
+
160
+ for idx, block in enumerate(self.blocks):
161
+ x = block(x)
162
+ if idx in self.output_blocks:
163
+ outp.append(x)
164
+
165
+ if idx == self.last_needed_block:
166
+ break
167
+
168
+ return outp
169
+
170
+
171
+ def fid_inception_v3():
172
+ """Build pretrained Inception model for FID computation
173
+
174
+ The Inception model for FID computation uses a different set of weights
175
+ and has a slightly different structure than torchvision's Inception.
176
+
177
+ This method first constructs torchvision's Inception and then patches the
178
+ necessary parts that are different in the FID Inception model.
179
+ """
180
+ LOGGER.info('fid_inception_v3 called')
181
+ inception = models.inception_v3(num_classes=1008,
182
+ aux_logits=False,
183
+ pretrained=False)
184
+ LOGGER.info('models.inception_v3 done')
185
+ inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
186
+ inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
187
+ inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
188
+ inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
189
+ inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
190
+ inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
191
+ inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
192
+ inception.Mixed_7b = FIDInceptionE_1(1280)
193
+ inception.Mixed_7c = FIDInceptionE_2(2048)
194
+
195
+ LOGGER.info('fid_inception_v3 patching done')
196
+
197
+ state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
198
+ LOGGER.info('fid_inception_v3 weights downloaded')
199
+
200
+ inception.load_state_dict(state_dict)
201
+ LOGGER.info('fid_inception_v3 weights loaded into model')
202
+
203
+ return inception
204
+
205
+
206
+ class FIDInceptionA(models.inception.InceptionA):
207
+ """InceptionA block patched for FID computation"""
208
+ def __init__(self, in_channels, pool_features):
209
+ super(FIDInceptionA, self).__init__(in_channels, pool_features)
210
+
211
+ def forward(self, x):
212
+ branch1x1 = self.branch1x1(x)
213
+
214
+ branch5x5 = self.branch5x5_1(x)
215
+ branch5x5 = self.branch5x5_2(branch5x5)
216
+
217
+ branch3x3dbl = self.branch3x3dbl_1(x)
218
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
219
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
220
+
221
+ # Patch: Tensorflow's average pool does not use the padded zeros in
222
+ # its average calculation
223
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
224
+ count_include_pad=False)
225
+ branch_pool = self.branch_pool(branch_pool)
226
+
227
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
228
+ return torch.cat(outputs, 1)
229
+
230
+
231
+ class FIDInceptionC(models.inception.InceptionC):
232
+ """InceptionC block patched for FID computation"""
233
+ def __init__(self, in_channels, channels_7x7):
234
+ super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
235
+
236
+ def forward(self, x):
237
+ branch1x1 = self.branch1x1(x)
238
+
239
+ branch7x7 = self.branch7x7_1(x)
240
+ branch7x7 = self.branch7x7_2(branch7x7)
241
+ branch7x7 = self.branch7x7_3(branch7x7)
242
+
243
+ branch7x7dbl = self.branch7x7dbl_1(x)
244
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
245
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
246
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
247
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
248
+
249
+ # Patch: Tensorflow's average pool does not use the padded zeros in
250
+ # its average calculation
251
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
252
+ count_include_pad=False)
253
+ branch_pool = self.branch_pool(branch_pool)
254
+
255
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
256
+ return torch.cat(outputs, 1)
257
+
258
+
259
+ class FIDInceptionE_1(models.inception.InceptionE):
260
+ """First InceptionE block patched for FID computation"""
261
+ def __init__(self, in_channels):
262
+ super(FIDInceptionE_1, self).__init__(in_channels)
263
+
264
+ def forward(self, x):
265
+ branch1x1 = self.branch1x1(x)
266
+
267
+ branch3x3 = self.branch3x3_1(x)
268
+ branch3x3 = [
269
+ self.branch3x3_2a(branch3x3),
270
+ self.branch3x3_2b(branch3x3),
271
+ ]
272
+ branch3x3 = torch.cat(branch3x3, 1)
273
+
274
+ branch3x3dbl = self.branch3x3dbl_1(x)
275
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
276
+ branch3x3dbl = [
277
+ self.branch3x3dbl_3a(branch3x3dbl),
278
+ self.branch3x3dbl_3b(branch3x3dbl),
279
+ ]
280
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
281
+
282
+ # Patch: Tensorflow's average pool does not use the padded zeros in
283
+ # its average calculation
284
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
285
+ count_include_pad=False)
286
+ branch_pool = self.branch_pool(branch_pool)
287
+
288
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
289
+ return torch.cat(outputs, 1)
290
+
291
+
292
+ class FIDInceptionE_2(models.inception.InceptionE):
293
+ """Second InceptionE block patched for FID computation"""
294
+ def __init__(self, in_channels):
295
+ super(FIDInceptionE_2, self).__init__(in_channels)
296
+
297
+ def forward(self, x):
298
+ branch1x1 = self.branch1x1(x)
299
+
300
+ branch3x3 = self.branch3x3_1(x)
301
+ branch3x3 = [
302
+ self.branch3x3_2a(branch3x3),
303
+ self.branch3x3_2b(branch3x3),
304
+ ]
305
+ branch3x3 = torch.cat(branch3x3, 1)
306
+
307
+ branch3x3dbl = self.branch3x3dbl_1(x)
308
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
309
+ branch3x3dbl = [
310
+ self.branch3x3dbl_3a(branch3x3dbl),
311
+ self.branch3x3dbl_3b(branch3x3dbl),
312
+ ]
313
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
314
+
315
+ # Patch: The FID Inception model uses max pooling instead of average
316
+ # pooling. This is likely an error in this specific Inception
317
+ # implementation, as other Inception models use average pooling here
318
+ # (which matches the description in the paper).
319
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
320
+ branch_pool = self.branch_pool(branch_pool)
321
+
322
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
323
+ return torch.cat(outputs, 1)
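# Usage sketch (hedged): extracting pool_3 activations with the InceptionV3 wrapper above,
# e.g. as input to calculate_frechet_distance. The random batch is a placeholder; building
# the model downloads the FID Inception weights on first use.
import torch

from saicinpainting.evaluation.losses.fid.inception import InceptionV3

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]      # final average-pooling block
model = InceptionV3([block_idx]).eval()

batch = torch.rand(4, 3, 256, 256)                    # RGB images in [0, 1]; resized to 299x299 internally
with torch.no_grad():
    features = model(batch)[0]                        # shape (4, 2048, 1, 1)
activations = features.squeeze(-1).squeeze(-1).numpy()
print(activations.shape)                              # (4, 2048)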
saicinpainting/evaluation/losses/lpips.py ADDED
@@ -0,0 +1,891 @@
1
+ ############################################################
2
+ # The contents below have been combined using files in the #
3
+ # following repository: #
4
+ # https://github.com/richzhang/PerceptualSimilarity #
5
+ ############################################################
6
+
7
+ ############################################################
8
+ # __init__.py #
9
+ ############################################################
10
+
11
+ import numpy as np
12
+ from skimage.metrics import structural_similarity
13
+ import torch
14
+
15
+ from saicinpainting.utils import get_shape
16
+
17
+
18
+ class PerceptualLoss(torch.nn.Module):
19
+ def __init__(self, model='net-lin', net='alex', colorspace='rgb', model_path=None, spatial=False, use_gpu=True):
20
+ # VGG using our perceptually-learned weights (LPIPS metric)
21
+ # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
22
+ super(PerceptualLoss, self).__init__()
23
+ self.use_gpu = use_gpu
24
+ self.spatial = spatial
25
+ self.model = DistModel()
26
+ self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace,
27
+ model_path=model_path, spatial=self.spatial)
28
+
29
+ def forward(self, pred, target, normalize=True):
30
+ """
31
+ Pred and target are Variables.
32
+ If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
33
+ If normalize is False, assumes the images are already between [-1,+1]
34
+ Inputs pred and target are Nx3xHxW
35
+ Output pytorch Variable N long
36
+ """
37
+
38
+ if normalize:
39
+ target = 2 * target - 1
40
+ pred = 2 * pred - 1
41
+
42
+ return self.model(target, pred)
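# Usage sketch (hedged): calling the PerceptualLoss wrapper defined above. The random
# tensors are placeholders; with model_path=None, DistModel.initialize (further down
# this file) looks for the LPIPS weights at models/lpips_models/vgg.pth.
import torch

loss_fn = PerceptualLoss(model='net-lin', net='vgg', use_gpu=False)

pred = torch.rand(2, 3, 128, 128)      # images in [0, 1]; normalize=True rescales them to [-1, 1]
target = torch.rand(2, 3, 128, 128)
with torch.no_grad():
    dist = loss_fn(pred, target)       # one LPIPS distance per image, shape (2, 1, 1, 1)
print(dist.flatten())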
43
+
44
+
45
+ def normalize_tensor(in_feat, eps=1e-10):
46
+ norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True))
47
+ return in_feat / (norm_factor + eps)
48
+
49
+
50
+ def l2(p0, p1, range=255.):
51
+ return .5 * np.mean((p0 / range - p1 / range) ** 2)
52
+
53
+
54
+ def psnr(p0, p1, peak=255.):
55
+ return 10 * np.log10(peak ** 2 / np.mean((1. * p0 - 1. * p1) ** 2))
56
+
57
+
58
+ def dssim(p0, p1, range=255.):
59
+ return (1 - structural_similarity(p0, p1, data_range=range, multichannel=True)) / 2.
60
+
61
+
62
+ def rgb2lab(in_img, mean_cent=False):
63
+ from skimage import color
64
+ img_lab = color.rgb2lab(in_img)
65
+ if (mean_cent):
66
+ img_lab[:, :, 0] = img_lab[:, :, 0] - 50
67
+ return img_lab
68
+
69
+
70
+ def tensor2np(tensor_obj):
71
+ # change dimension of a tensor object into a numpy array
72
+ return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0))
73
+
74
+
75
+ def np2tensor(np_obj):
76
+ # change dimension of np array into tensor array
77
+ return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
78
+
79
+
80
+ def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False):
81
+ # image tensor to lab tensor
82
+ from skimage import color
83
+
84
+ img = tensor2im(image_tensor)
85
+ img_lab = color.rgb2lab(img)
86
+ if (mc_only):
87
+ img_lab[:, :, 0] = img_lab[:, :, 0] - 50
88
+ if (to_norm and not mc_only):
89
+ img_lab[:, :, 0] = img_lab[:, :, 0] - 50
90
+ img_lab = img_lab / 100.
91
+
92
+ return np2tensor(img_lab)
93
+
94
+
95
+ def tensorlab2tensor(lab_tensor, return_inbnd=False):
96
+ from skimage import color
97
+ import warnings
98
+ warnings.filterwarnings("ignore")
99
+
100
+ lab = tensor2np(lab_tensor) * 100.
101
+ lab[:, :, 0] = lab[:, :, 0] + 50
102
+
103
+ rgb_back = 255. * np.clip(color.lab2rgb(lab.astype('float')), 0, 1)
104
+ if (return_inbnd):
105
+ # convert back to lab, see if we match
106
+ lab_back = color.rgb2lab(rgb_back.astype('uint8'))
107
+ mask = 1. * np.isclose(lab_back, lab, atol=2.)
108
+ mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis])
109
+ return (im2tensor(rgb_back), mask)
110
+ else:
111
+ return im2tensor(rgb_back)
112
+
113
+
114
+ def rgb2lab(input):
115
+ from skimage import color
116
+ return color.rgb2lab(input / 255.)
117
+
118
+
119
+ def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.):
120
+ image_numpy = image_tensor[0].cpu().float().numpy()
121
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
122
+ return image_numpy.astype(imtype)
123
+
124
+
125
+ def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.):
126
+ return torch.Tensor((image / factor - cent)
127
+ [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
128
+
129
+
130
+ def tensor2vec(vector_tensor):
131
+ return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
132
+
133
+
134
+ def voc_ap(rec, prec, use_07_metric=False):
135
+ """ ap = voc_ap(rec, prec, [use_07_metric])
136
+ Compute VOC AP given precision and recall.
137
+ If use_07_metric is true, uses the
138
+ VOC 07 11 point method (default:False).
139
+ """
140
+ if use_07_metric:
141
+ # 11 point metric
142
+ ap = 0.
143
+ for t in np.arange(0., 1.1, 0.1):
144
+ if np.sum(rec >= t) == 0:
145
+ p = 0
146
+ else:
147
+ p = np.max(prec[rec >= t])
148
+ ap = ap + p / 11.
149
+ else:
150
+ # correct AP calculation
151
+ # first append sentinel values at the end
152
+ mrec = np.concatenate(([0.], rec, [1.]))
153
+ mpre = np.concatenate(([0.], prec, [0.]))
154
+
155
+ # compute the precision envelope
156
+ for i in range(mpre.size - 1, 0, -1):
157
+ mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
158
+
159
+ # to calculate area under PR curve, look for points
160
+ # where X axis (recall) changes value
161
+ i = np.where(mrec[1:] != mrec[:-1])[0]
162
+
163
+ # and sum (\Delta recall) * prec
164
+ ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
165
+ return ap
166
+
167
+
168
+ def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.):
169
+ # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
170
+ image_numpy = image_tensor[0].cpu().float().numpy()
171
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
172
+ return image_numpy.astype(imtype)
173
+
174
+
175
+ def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.):
176
+ # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
177
+ return torch.Tensor((image / factor - cent)
178
+ [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
179
+
180
+
181
+ ############################################################
182
+ # base_model.py #
183
+ ############################################################
184
+
185
+
186
+ class BaseModel(torch.nn.Module):
187
+ def __init__(self):
188
+ super().__init__()
189
+
190
+ def name(self):
191
+ return 'BaseModel'
192
+
193
+ def initialize(self, use_gpu=True):
194
+ self.use_gpu = use_gpu
195
+
196
+ def forward(self):
197
+ pass
198
+
199
+ def get_image_paths(self):
200
+ pass
201
+
202
+ def optimize_parameters(self):
203
+ pass
204
+
205
+ def get_current_visuals(self):
206
+ return self.input
207
+
208
+ def get_current_errors(self):
209
+ return {}
210
+
211
+ def save(self, label):
212
+ pass
213
+
214
+ # helper saving function that can be used by subclasses
215
+ def save_network(self, network, path, network_label, epoch_label):
216
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
217
+ save_path = os.path.join(path, save_filename)
218
+ torch.save(network.state_dict(), save_path)
219
+
220
+ # helper loading function that can be used by subclasses
221
+ def load_network(self, network, network_label, epoch_label):
222
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
223
+ save_path = os.path.join(self.save_dir, save_filename)
224
+ print('Loading network from %s' % save_path)
225
+ network.load_state_dict(torch.load(save_path, map_location='cpu'))
226
+
227
+ def update_learning_rate(self):
228
+ pass
229
+
230
+ def get_image_paths(self):
231
+ return self.image_paths
232
+
233
+ def save_done(self, flag=False):
234
+ np.save(os.path.join(self.save_dir, 'done_flag'), flag)
235
+ np.savetxt(os.path.join(self.save_dir, 'done_flag'), [flag, ], fmt='%i')
236
+
237
+
238
+ ############################################################
239
+ # dist_model.py #
240
+ ############################################################
241
+
242
+ import os
243
+ from collections import OrderedDict
244
+ from scipy.ndimage import zoom
245
+ from tqdm import tqdm
246
+
247
+
248
+ class DistModel(BaseModel):
249
+ def name(self):
250
+ return self.model_name
251
+
252
+ def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False,
253
+ model_path=None,
254
+ use_gpu=True, printNet=False, spatial=False,
255
+ is_train=False, lr=.0001, beta1=0.5, version='0.1'):
256
+ '''
257
+ INPUTS
258
+ model - ['net-lin'] for linearly calibrated network
259
+ ['net'] for off-the-shelf network
260
+ ['L2'] for L2 distance in Lab colorspace
261
+ ['SSIM'] for ssim in RGB colorspace
262
+ net - ['squeeze','alex','vgg']
263
+ model_path - if None, will look in weights/[NET_NAME].pth
264
+ colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
265
+ use_gpu - bool - whether or not to use a GPU
266
+ printNet - bool - whether or not to print network architecture out
267
+ spatial - bool - whether to output an array containing varying distances across spatial dimensions
268
+ spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
269
+ spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
270
+ spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
271
+ is_train - bool - [True] for training mode
272
+ lr - float - initial learning rate
273
+ beta1 - float - initial momentum term for adam
274
+ version - 0.1 for latest, 0.0 was original (with a bug)
275
+ '''
276
+ BaseModel.initialize(self, use_gpu=use_gpu)
277
+
278
+ self.model = model
279
+ self.net = net
280
+ self.is_train = is_train
281
+ self.spatial = spatial
282
+ self.model_name = '%s [%s]' % (model, net)
283
+
284
+ if (self.model == 'net-lin'): # pretrained net + linear layer
285
+ self.net = PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
286
+ use_dropout=True, spatial=spatial, version=version, lpips=True)
287
+ kw = dict(map_location='cpu')
288
+ if (model_path is None):
289
+ import inspect
290
+ model_path = os.path.abspath(
291
+ os.path.join(os.path.dirname(__file__), '..', '..', '..', 'models', 'lpips_models', f'{net}.pth'))
292
+
293
+ if (not is_train):
294
+ self.net.load_state_dict(torch.load(model_path, **kw), strict=False)
295
+
296
+ elif (self.model == 'net'): # pretrained network
297
+ self.net = PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
298
+ elif (self.model in ['L2', 'l2']):
299
+ self.net = L2(use_gpu=use_gpu, colorspace=colorspace) # not really a network, only for testing
300
+ self.model_name = 'L2'
301
+ elif (self.model in ['DSSIM', 'dssim', 'SSIM', 'ssim']):
302
+ self.net = DSSIM(use_gpu=use_gpu, colorspace=colorspace)
303
+ self.model_name = 'SSIM'
304
+ else:
305
+ raise ValueError("Model [%s] not recognized." % self.model)
306
+
307
+ self.trainable_parameters = list(self.net.parameters())
308
+
309
+ if self.is_train: # training mode
310
+ # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
311
+ self.rankLoss = BCERankingLoss()
312
+ self.trainable_parameters += list(self.rankLoss.net.parameters())
313
+ self.lr = lr
314
+ self.old_lr = lr
315
+ self.optimizer_net = torch.optim.Adam(self.trainable_parameters, lr=lr, betas=(beta1, 0.999))
316
+ else: # test mode
317
+ self.net.eval()
318
+
319
+ # if (use_gpu):
320
+ # self.net.to(gpu_ids[0])
321
+ # self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
322
+ # if (self.is_train):
323
+ # self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0
324
+
325
+ if (printNet):
326
+ print('---------- Networks initialized -------------')
327
+ print_network(self.net)
328
+ print('-----------------------------------------------')
329
+
330
+ def forward(self, in0, in1, retPerLayer=False):
331
+ ''' Function computes the distance between image patches in0 and in1
332
+ INPUTS
333
+ in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
334
+ OUTPUT
335
+ computed distances between in0 and in1
336
+ '''
337
+
338
+ return self.net(in0, in1, retPerLayer=retPerLayer)
339
+
340
+ # ***** TRAINING FUNCTIONS *****
341
+ def optimize_parameters(self):
342
+ self.forward_train()
343
+ self.optimizer_net.zero_grad()
344
+ self.backward_train()
345
+ self.optimizer_net.step()
346
+ self.clamp_weights()
347
+
348
+ def clamp_weights(self):
349
+ for module in self.net.modules():
350
+ if (hasattr(module, 'weight') and module.kernel_size == (1, 1)):
351
+ module.weight.data = torch.clamp(module.weight.data, min=0)
352
+
353
+ def set_input(self, data):
354
+ self.input_ref = data['ref']
355
+ self.input_p0 = data['p0']
356
+ self.input_p1 = data['p1']
357
+ self.input_judge = data['judge']
358
+
359
+ # if (self.use_gpu):
360
+ # self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
361
+ # self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
362
+ # self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
363
+ # self.input_judge = self.input_judge.to(device=self.gpu_ids[0])
364
+
365
+ # self.var_ref = Variable(self.input_ref, requires_grad=True)
366
+ # self.var_p0 = Variable(self.input_p0, requires_grad=True)
367
+ # self.var_p1 = Variable(self.input_p1, requires_grad=True)
368
+
369
+ def forward_train(self): # run forward pass
370
+ # print(self.net.module.scaling_layer.shift)
371
+ # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())
372
+
373
+ assert False, "We shouldn't end up here when using LPIPS as a metric"
374
+
375
+ self.d0 = self(self.var_ref, self.var_p0)
376
+ self.d1 = self(self.var_ref, self.var_p1)
377
+ self.acc_r = self.compute_accuracy(self.d0, self.d1, self.input_judge)
378
+
379
+ self.var_judge = Variable(1. * self.input_judge).view(self.d0.size())
380
+
381
+ self.loss_total = self.rankLoss(self.d0, self.d1, self.var_judge * 2. - 1.)
382
+
383
+ return self.loss_total
384
+
385
+ def backward_train(self):
386
+ torch.mean(self.loss_total).backward()
387
+
388
+ def compute_accuracy(self, d0, d1, judge):
389
+ ''' d0, d1 are Variables, judge is a Tensor '''
390
+ d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten()
391
+ judge_per = judge.cpu().numpy().flatten()
392
+ return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per)
393
+
394
+ def get_current_errors(self):
395
+ retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
396
+ ('acc_r', self.acc_r)])
397
+
398
+ for key in retDict.keys():
399
+ retDict[key] = np.mean(retDict[key])
400
+
401
+ return retDict
402
+
403
+ def get_current_visuals(self):
404
+ zoom_factor = 256 / self.var_ref.data.size()[2]
405
+
406
+ ref_img = tensor2im(self.var_ref.data)
407
+ p0_img = tensor2im(self.var_p0.data)
408
+ p1_img = tensor2im(self.var_p1.data)
409
+
410
+ ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0)
411
+ p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0)
412
+ p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0)
413
+
414
+ return OrderedDict([('ref', ref_img_vis),
415
+ ('p0', p0_img_vis),
416
+ ('p1', p1_img_vis)])
417
+
418
+ def save(self, path, label):
419
+ if (self.use_gpu):
420
+ self.save_network(self.net.module, path, '', label)
421
+ else:
422
+ self.save_network(self.net, path, '', label)
423
+ self.save_network(self.rankLoss.net, path, 'rank', label)
424
+
425
+ def update_learning_rate(self, nepoch_decay):
426
+ lrd = self.lr / nepoch_decay
427
+ lr = self.old_lr - lrd
428
+
429
+ for param_group in self.optimizer_net.param_groups:
430
+ param_group['lr'] = lr
431
+
432
+ print('update lr [%s] decay: %f -> %f' % ('net', self.old_lr, lr))
433
+ self.old_lr = lr
434
+
435
+
436
+ def score_2afc_dataset(data_loader, func, name=''):
437
+ ''' Function computes Two Alternative Forced Choice (2AFC) score using
438
+ distance function 'func' in dataset 'data_loader'
439
+ INPUTS
440
+ data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
441
+ func - callable distance function - calling d=func(in0,in1) should take 2
442
+ pytorch tensors with shape Nx3xXxY, and return numpy array of length N
443
+ OUTPUTS
444
+ [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
445
+ [1] - dictionary with following elements
446
+ d0s,d1s - N arrays containing distances between reference patch to perturbed patches
447
+ gts - N array in [0,1], preferred patch selected by human evaluators
448
+ (closer to "0" for left patch p0, "1" for right patch p1,
449
+ "0.6" means 60pct people preferred right patch, 40pct preferred left)
450
+ scores - N array in [0,1], corresponding to what percentage function agreed with humans
451
+ CONSTS
452
+ N - number of test triplets in data_loader
453
+ '''
454
+
455
+ d0s = []
456
+ d1s = []
457
+ gts = []
458
+
459
+ for data in tqdm(data_loader.load_data(), desc=name):
460
+ d0s += func(data['ref'], data['p0']).data.cpu().numpy().flatten().tolist()
461
+ d1s += func(data['ref'], data['p1']).data.cpu().numpy().flatten().tolist()
462
+ gts += data['judge'].cpu().numpy().flatten().tolist()
463
+
464
+ d0s = np.array(d0s)
465
+ d1s = np.array(d1s)
466
+ gts = np.array(gts)
467
+ scores = (d0s < d1s) * (1. - gts) + (d1s < d0s) * gts + (d1s == d0s) * .5
468
+
469
+ return (np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores))
470
+
471
+
472
+ def score_jnd_dataset(data_loader, func, name=''):
473
+ ''' Function computes JND score using distance function 'func' in dataset 'data_loader'
474
+ INPUTS
475
+ data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
476
+ func - callable distance function - calling d=func(in0,in1) should take 2
477
+ pytorch tensors with shape Nx3xXxY, and return pytorch array of length N
478
+ OUTPUTS
479
+ [0] - JND score in [0,1], mAP score (area under precision-recall curve)
480
+ [1] - dictionary with following elements
481
+ ds - N array containing distances between two patches shown to human evaluator
482
+ sames - N array containing fraction of people who thought the two patches were identical
483
+ CONSTS
484
+ N - number of test triplets in data_loader
485
+ '''
486
+
487
+ ds = []
488
+ gts = []
489
+
490
+ for data in tqdm(data_loader.load_data(), desc=name):
491
+ ds += func(data['p0'], data['p1']).data.cpu().numpy().tolist()
492
+ gts += data['same'].cpu().numpy().flatten().tolist()
493
+
494
+ sames = np.array(gts)
495
+ ds = np.array(ds)
496
+
497
+ sorted_inds = np.argsort(ds)
498
+ ds_sorted = ds[sorted_inds]
499
+ sames_sorted = sames[sorted_inds]
500
+
501
+ TPs = np.cumsum(sames_sorted)
502
+ FPs = np.cumsum(1 - sames_sorted)
503
+ FNs = np.sum(sames_sorted) - TPs
504
+
505
+ precs = TPs / (TPs + FPs)
506
+ recs = TPs / (TPs + FNs)
507
+ score = voc_ap(recs, precs)
508
+
509
+ return (score, dict(ds=ds, sames=sames))
510
+
511
+
512
+ ############################################################
513
+ # networks_basic.py #
514
+ ############################################################
515
+
516
+ import torch.nn as nn
517
+ from torch.autograd import Variable
518
+ import numpy as np
519
+
520
+
521
+ def spatial_average(in_tens, keepdim=True):
522
+ return in_tens.mean([2, 3], keepdim=keepdim)
523
+
524
+
525
+ def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
526
+ in_H = in_tens.shape[2]
527
+ scale_factor = 1. * out_H / in_H
528
+
529
+ return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens)
530
+
531
+
532
+ # Learned perceptual metric
533
+ class PNetLin(nn.Module):
534
+ def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False,
535
+ version='0.1', lpips=True):
536
+ super(PNetLin, self).__init__()
537
+
538
+ self.pnet_type = pnet_type
539
+ self.pnet_tune = pnet_tune
540
+ self.pnet_rand = pnet_rand
541
+ self.spatial = spatial
542
+ self.lpips = lpips
543
+ self.version = version
544
+ self.scaling_layer = ScalingLayer()
545
+
546
+ if (self.pnet_type in ['vgg', 'vgg16']):
547
+ net_type = vgg16
548
+ self.chns = [64, 128, 256, 512, 512]
549
+ elif (self.pnet_type == 'alex'):
550
+ net_type = alexnet
551
+ self.chns = [64, 192, 384, 256, 256]
552
+ elif (self.pnet_type == 'squeeze'):
553
+ net_type = squeezenet
554
+ self.chns = [64, 128, 256, 384, 384, 512, 512]
555
+ self.L = len(self.chns)
556
+
557
+ self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
558
+
559
+ if (lpips):
560
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
561
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
562
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
563
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
564
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
565
+ self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
566
+ if (self.pnet_type == 'squeeze'): # 7 layers for squeezenet
567
+ self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
568
+ self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
569
+ self.lins += [self.lin5, self.lin6]
570
+
571
+ def forward(self, in0, in1, retPerLayer=False):
572
+ # v0.0 - original release had a bug, where input was not scaled
573
+ in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == '0.1' else (
574
+ in0, in1)
575
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
576
+ feats0, feats1, diffs = {}, {}, {}
577
+
578
+ for kk in range(self.L):
579
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
580
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
581
+
582
+ if (self.lpips):
583
+ if (self.spatial):
584
+ res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
585
+ else:
586
+ res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
587
+ else:
588
+ if (self.spatial):
589
+ res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
590
+ else:
591
+ res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)]
592
+
593
+ val = res[0]
594
+ for l in range(1, self.L):
595
+ val += res[l]
596
+
597
+ if (retPerLayer):
598
+ return (val, res)
599
+ else:
600
+ return val
601
+
602
+
603
+ class ScalingLayer(nn.Module):
604
+ def __init__(self):
605
+ super(ScalingLayer, self).__init__()
606
+ self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
607
+ self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
608
+
609
+ def forward(self, inp):
610
+ return (inp - self.shift) / self.scale
611
+
612
+
613
+ class NetLinLayer(nn.Module):
614
+ ''' A single linear layer which does a 1x1 conv '''
615
+
616
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
617
+ super(NetLinLayer, self).__init__()
618
+
619
+ layers = [nn.Dropout(), ] if (use_dropout) else []
620
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
621
+ self.model = nn.Sequential(*layers)
622
+
623
+
624
+ class Dist2LogitLayer(nn.Module):
625
+ ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
626
+
627
+ def __init__(self, chn_mid=32, use_sigmoid=True):
628
+ super(Dist2LogitLayer, self).__init__()
629
+
630
+ layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True), ]
631
+ layers += [nn.LeakyReLU(0.2, True), ]
632
+ layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True), ]
633
+ layers += [nn.LeakyReLU(0.2, True), ]
634
+ layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True), ]
635
+ if (use_sigmoid):
636
+ layers += [nn.Sigmoid(), ]
637
+ self.model = nn.Sequential(*layers)
638
+
639
+ def forward(self, d0, d1, eps=0.1):
640
+ return self.model(torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1))
641
+
642
+
643
+ class BCERankingLoss(nn.Module):
644
+ def __init__(self, chn_mid=32):
645
+ super(BCERankingLoss, self).__init__()
646
+ self.net = Dist2LogitLayer(chn_mid=chn_mid)
647
+ # self.parameters = list(self.net.parameters())
648
+ self.loss = torch.nn.BCELoss()
649
+
650
+ def forward(self, d0, d1, judge):
651
+ per = (judge + 1.) / 2.
652
+ self.logit = self.net(d0, d1)
653
+ return self.loss(self.logit, per)
654
+
655
+
656
+ # L2, DSSIM metrics
657
+ class FakeNet(nn.Module):
658
+ def __init__(self, use_gpu=True, colorspace='Lab'):
659
+ super(FakeNet, self).__init__()
660
+ self.use_gpu = use_gpu
661
+ self.colorspace = colorspace
662
+
663
+
664
+ class L2(FakeNet):
665
+
666
+ def forward(self, in0, in1, retPerLayer=None):
667
+ assert (in0.size()[0] == 1) # currently only supports batchSize 1
668
+
669
+ if (self.colorspace == 'RGB'):
670
+ (N, C, X, Y) = in0.size()
671
+ value = torch.mean(torch.mean(torch.mean((in0 - in1) ** 2, dim=1).view(N, 1, X, Y), dim=2).view(N, 1, 1, Y),
672
+ dim=3).view(N)
673
+ return value
674
+ elif (self.colorspace == 'Lab'):
675
+ value = l2(tensor2np(tensor2tensorlab(in0.data, to_norm=False)),
676
+ tensor2np(tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype('float')
677
+ ret_var = Variable(torch.Tensor((value,)))
678
+ # if (self.use_gpu):
679
+ # ret_var = ret_var.cuda()
680
+ return ret_var
681
+
682
+
683
+ class DSSIM(FakeNet):
684
+
685
+ def forward(self, in0, in1, retPerLayer=None):
686
+ assert (in0.size()[0] == 1) # currently only supports batchSize 1
687
+
688
+ if (self.colorspace == 'RGB'):
689
+ value = dssim(1. * tensor2im(in0.data), 1. * tensor2im(in1.data), range=255.).astype('float')
690
+ elif (self.colorspace == 'Lab'):
691
+ value = dssim(tensor2np(tensor2tensorlab(in0.data, to_norm=False)),
692
+ tensor2np(tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype('float')
693
+ ret_var = Variable(torch.Tensor((value,)))
694
+ # if (self.use_gpu):
695
+ # ret_var = ret_var.cuda()
696
+ return ret_var
697
+
698
+
699
+ def print_network(net):
700
+ num_params = 0
701
+ for param in net.parameters():
702
+ num_params += param.numel()
703
+ print('Network', net)
704
+ print('Total number of parameters: %d' % num_params)
705
+
706
+
707
+ ############################################################
708
+ # pretrained_networks.py #
709
+ ############################################################
710
+
711
+ from collections import namedtuple
712
+ import torch
713
+ from torchvision import models as tv
714
+
715
+
716
+ class squeezenet(torch.nn.Module):
717
+ def __init__(self, requires_grad=False, pretrained=True):
718
+ super(squeezenet, self).__init__()
719
+ pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
720
+ self.slice1 = torch.nn.Sequential()
721
+ self.slice2 = torch.nn.Sequential()
722
+ self.slice3 = torch.nn.Sequential()
723
+ self.slice4 = torch.nn.Sequential()
724
+ self.slice5 = torch.nn.Sequential()
725
+ self.slice6 = torch.nn.Sequential()
726
+ self.slice7 = torch.nn.Sequential()
727
+ self.N_slices = 7
728
+ for x in range(2):
729
+ self.slice1.add_module(str(x), pretrained_features[x])
730
+ for x in range(2, 5):
731
+ self.slice2.add_module(str(x), pretrained_features[x])
732
+ for x in range(5, 8):
733
+ self.slice3.add_module(str(x), pretrained_features[x])
734
+ for x in range(8, 10):
735
+ self.slice4.add_module(str(x), pretrained_features[x])
736
+ for x in range(10, 11):
737
+ self.slice5.add_module(str(x), pretrained_features[x])
738
+ for x in range(11, 12):
739
+ self.slice6.add_module(str(x), pretrained_features[x])
740
+ for x in range(12, 13):
741
+ self.slice7.add_module(str(x), pretrained_features[x])
742
+ if not requires_grad:
743
+ for param in self.parameters():
744
+ param.requires_grad = False
745
+
746
+ def forward(self, X):
747
+ h = self.slice1(X)
748
+ h_relu1 = h
749
+ h = self.slice2(h)
750
+ h_relu2 = h
751
+ h = self.slice3(h)
752
+ h_relu3 = h
753
+ h = self.slice4(h)
754
+ h_relu4 = h
755
+ h = self.slice5(h)
756
+ h_relu5 = h
757
+ h = self.slice6(h)
758
+ h_relu6 = h
759
+ h = self.slice7(h)
760
+ h_relu7 = h
761
+ vgg_outputs = namedtuple("SqueezeOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5', 'relu6', 'relu7'])
762
+ out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7)
763
+
764
+ return out
765
+
766
+
767
+ class alexnet(torch.nn.Module):
768
+ def __init__(self, requires_grad=False, pretrained=True):
769
+ super(alexnet, self).__init__()
770
+ alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
771
+ self.slice1 = torch.nn.Sequential()
772
+ self.slice2 = torch.nn.Sequential()
773
+ self.slice3 = torch.nn.Sequential()
774
+ self.slice4 = torch.nn.Sequential()
775
+ self.slice5 = torch.nn.Sequential()
776
+ self.N_slices = 5
777
+ for x in range(2):
778
+ self.slice1.add_module(str(x), alexnet_pretrained_features[x])
779
+ for x in range(2, 5):
780
+ self.slice2.add_module(str(x), alexnet_pretrained_features[x])
781
+ for x in range(5, 8):
782
+ self.slice3.add_module(str(x), alexnet_pretrained_features[x])
783
+ for x in range(8, 10):
784
+ self.slice4.add_module(str(x), alexnet_pretrained_features[x])
785
+ for x in range(10, 12):
786
+ self.slice5.add_module(str(x), alexnet_pretrained_features[x])
787
+ if not requires_grad:
788
+ for param in self.parameters():
789
+ param.requires_grad = False
790
+
791
+ def forward(self, X):
792
+ h = self.slice1(X)
793
+ h_relu1 = h
794
+ h = self.slice2(h)
795
+ h_relu2 = h
796
+ h = self.slice3(h)
797
+ h_relu3 = h
798
+ h = self.slice4(h)
799
+ h_relu4 = h
800
+ h = self.slice5(h)
801
+ h_relu5 = h
802
+ alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
803
+ out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
804
+
805
+ return out
806
+
807
+
808
+ class vgg16(torch.nn.Module):
809
+ def __init__(self, requires_grad=False, pretrained=True):
810
+ super(vgg16, self).__init__()
811
+ vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
812
+ self.slice1 = torch.nn.Sequential()
813
+ self.slice2 = torch.nn.Sequential()
814
+ self.slice3 = torch.nn.Sequential()
815
+ self.slice4 = torch.nn.Sequential()
816
+ self.slice5 = torch.nn.Sequential()
817
+ self.N_slices = 5
818
+ for x in range(4):
819
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
820
+ for x in range(4, 9):
821
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
822
+ for x in range(9, 16):
823
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
824
+ for x in range(16, 23):
825
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
826
+ for x in range(23, 30):
827
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
828
+ if not requires_grad:
829
+ for param in self.parameters():
830
+ param.requires_grad = False
831
+
832
+ def forward(self, X):
833
+ h = self.slice1(X)
834
+ h_relu1_2 = h
835
+ h = self.slice2(h)
836
+ h_relu2_2 = h
837
+ h = self.slice3(h)
838
+ h_relu3_3 = h
839
+ h = self.slice4(h)
840
+ h_relu4_3 = h
841
+ h = self.slice5(h)
842
+ h_relu5_3 = h
843
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
844
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
845
+
846
+ return out
847
+
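The wrappers in this section only expose intermediate activations of torchvision backbones. A minimal usage sketch for `vgg16` above, with `pretrained=False` so no weights are downloaded:

```python
import torch

net = vgg16(pretrained=False, requires_grad=False)
feats = net(torch.randn(1, 3, 64, 64))
# Five feature maps, relu1_2 ... relu5_3, with decreasing spatial size.
print([tuple(f.shape) for f in feats])
```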
848
+
849
+ class resnet(torch.nn.Module):
850
+ def __init__(self, requires_grad=False, pretrained=True, num=18):
851
+ super(resnet, self).__init__()
852
+ if (num == 18):
853
+ self.net = tv.resnet18(pretrained=pretrained)
854
+ elif (num == 34):
855
+ self.net = tv.resnet34(pretrained=pretrained)
856
+ elif (num == 50):
857
+ self.net = tv.resnet50(pretrained=pretrained)
858
+ elif (num == 101):
859
+ self.net = tv.resnet101(pretrained=pretrained)
860
+ elif (num == 152):
861
+ self.net = tv.resnet152(pretrained=pretrained)
862
+ self.N_slices = 5
863
+
864
+ self.conv1 = self.net.conv1
865
+ self.bn1 = self.net.bn1
866
+ self.relu = self.net.relu
867
+ self.maxpool = self.net.maxpool
868
+ self.layer1 = self.net.layer1
869
+ self.layer2 = self.net.layer2
870
+ self.layer3 = self.net.layer3
871
+ self.layer4 = self.net.layer4
872
+
873
+ def forward(self, X):
874
+ h = self.conv1(X)
875
+ h = self.bn1(h)
876
+ h = self.relu(h)
877
+ h_relu1 = h
878
+ h = self.maxpool(h)
879
+ h = self.layer1(h)
880
+ h_conv2 = h
881
+ h = self.layer2(h)
882
+ h_conv3 = h
883
+ h = self.layer3(h)
884
+ h_conv4 = h
885
+ h = self.layer4(h)
886
+ h_conv5 = h
887
+
888
+ outputs = namedtuple("Outputs", ['relu1', 'conv2', 'conv3', 'conv4', 'conv5'])
889
+ out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
890
+
891
+ return out
saicinpainting/evaluation/losses/ssim.py ADDED
@@ -0,0 +1,74 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class SSIM(torch.nn.Module):
7
+ """SSIM. Modified from:
8
+ https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py
9
+ """
10
+
11
+ def __init__(self, window_size=11, size_average=True):
12
+ super().__init__()
13
+ self.window_size = window_size
14
+ self.size_average = size_average
15
+ self.channel = 1
16
+ self.register_buffer('window', self._create_window(window_size, self.channel))
17
+
18
+ def forward(self, img1, img2):
19
+ assert len(img1.shape) == 4
20
+
21
+ channel = img1.size()[1]
22
+
23
+ if channel == self.channel and self.window.data.type() == img1.data.type():
24
+ window = self.window
25
+ else:
26
+ window = self._create_window(self.window_size, channel)
27
+
28
+ # window = window.to(img1.get_device())
29
+ window = window.type_as(img1)
30
+
31
+ self.window = window
32
+ self.channel = channel
33
+
34
+ return self._ssim(img1, img2, window, self.window_size, channel, self.size_average)
35
+
36
+ def _gaussian(self, window_size, sigma):
37
+ gauss = torch.Tensor([
38
+ np.exp(-(x - (window_size // 2)) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)
39
+ ])
40
+ return gauss / gauss.sum()
41
+
42
+ def _create_window(self, window_size, channel):
43
+ _1D_window = self._gaussian(window_size, 1.5).unsqueeze(1)
44
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
45
+ return _2D_window.expand(channel, 1, window_size, window_size).contiguous()
46
+
47
+ def _ssim(self, img1, img2, window, window_size, channel, size_average=True):
48
+ mu1 = F.conv2d(img1, window, padding=(window_size // 2), groups=channel)
49
+ mu2 = F.conv2d(img2, window, padding=(window_size // 2), groups=channel)
50
+
51
+ mu1_sq = mu1.pow(2)
52
+ mu2_sq = mu2.pow(2)
53
+ mu1_mu2 = mu1 * mu2
54
+
55
+ sigma1_sq = F.conv2d(
56
+ img1 * img1, window, padding=(window_size // 2), groups=channel) - mu1_sq
57
+ sigma2_sq = F.conv2d(
58
+ img2 * img2, window, padding=(window_size // 2), groups=channel) - mu2_sq
59
+ sigma12 = F.conv2d(
60
+ img1 * img2, window, padding=(window_size // 2), groups=channel) - mu1_mu2
61
+
62
+ C1 = 0.01 ** 2
63
+ C2 = 0.03 ** 2
64
+
65
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
66
+ ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
67
+
68
+ if size_average:
69
+ return ssim_map.mean()
70
+
71
+ return ssim_map.mean(1).mean(1).mean(1)
72
+
73
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
74
+ return
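A minimal usage sketch for the `SSIM` module above, assuming image batches of shape (N, C, H, W) with values in [0, 1]:

```python
import torch

ssim = SSIM(window_size=11, size_average=True)
img = torch.rand(2, 1, 64, 64)
print(ssim(img, img))                   # ~1.0 for identical images
print(ssim(img, torch.rand_like(img)))  # noticeably lower for unrelated noise
```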
saicinpainting/evaluation/masks/README.md ADDED
@@ -0,0 +1,27 @@
1
+ # Current algorithm
2
+
3
+ ## Choice of mask objects
4
+
5
+ To identify objects suitable for mask generation, a panoptic segmentation model
6
+ from [detectron2](https://github.com/facebookresearch/detectron2) trained on COCO is used. Categories of the detected instances
7
+ belong either to the "stuff" or the "things" type. We require that object instances have a category belonging
8
+ to "things". In addition, we set an upper bound on the area occupied by an object &mdash; too large an area
9
+ indicates that the instance is either the background or the main subject, which should not be removed.
10
+
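A hypothetical sketch of this selection rule; the segment fields (`isthing`, `area`) and the area threshold are illustrative assumptions rather than the exact interface used by the masking code:

```python
def select_mask_candidates(segments_info, image_area, max_area_frac=0.5):
    """Keep only "things" instances that do not cover too much of the image."""
    return [
        seg for seg in segments_info
        if seg["isthing"] and seg["area"] / image_area < max_area_frac
    ]
```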
11
+ ## Choice of position for mask
12
+
13
+ We assume the input image has size 2^n x 2^m. We downsample it using the
14
+ [COUNTLESS](https://github.com/william-silversmith/countless) algorithm so that the width equals
15
+ 64 = 2^6 = 2^{downsample_levels}.
16
+
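A rough sketch of the implied loop, assuming the 2x2 `countless` step defined later in `countless2d.py`; the actual driver code may differ:

```python
def downsample_to_width(mask, target_width=64):
    """Halve a 2D label map with COUNTLESS until its width reaches target_width."""
    while mask.shape[1] > target_width:
        mask = countless(mask)
    return mask
```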
17
+ ### Augmentation
18
+
19
+ There are several parameters for augmentation:
20
+ - Scaling factor. We limit scaling to the case where the mask, scaled about a pivot point at its center, still fits
21
+ inside the image completely.
22
+ -
23
+
24
+ ### Shift
25
+
26
+
27
+ ## Select
saicinpainting/evaluation/masks/__init__.py ADDED
File without changes
saicinpainting/evaluation/masks/countless/.gitignore ADDED
@@ -0,0 +1 @@
1
+ results
saicinpainting/evaluation/masks/countless/README.md ADDED
@@ -0,0 +1,25 @@
1
+ [![Build Status](https://travis-ci.org/william-silversmith/countless.svg?branch=master)](https://travis-ci.org/william-silversmith/countless)
2
+
3
+ Python COUNTLESS Downsampling
4
+ =============================
5
+
6
+ To install:
7
+
8
+ `pip install -r requirements.txt`
9
+
10
+ To test:
11
+
12
+ `python test.py`
13
+
14
+ To benchmark countless2d:
15
+
16
+ `python python/countless2d.py python/images/gray_segmentation.png`
17
+
18
+ To benchmark countless3d:
19
+
20
+ `python python/countless3d.py`
21
+
22
+ Adjust N and the list of algorithms inside each script to modify the run parameters.
23
+
24
+
25
+ Python3 is slightly faster than Python2.
saicinpainting/evaluation/masks/countless/__init__.py ADDED
File without changes
saicinpainting/evaluation/masks/countless/countless2d.py ADDED
@@ -0,0 +1,529 @@
1
+ from __future__ import print_function, division
2
+
3
+ """
4
+ COUNTLESS performance test in Python.
5
+
6
+ python countless2d.py ./images/NAMEOFIMAGE
7
+ """
8
+
9
+ import six
10
+ from six.moves import range
11
+ from collections import defaultdict
12
+ from functools import reduce
13
+ import operator
14
+ import io
15
+ import os
16
+ from PIL import Image
17
+ import math
18
+ import numpy as np
19
+ import random
20
+ import sys
21
+ import time
22
+ from tqdm import tqdm
23
+ from scipy import ndimage
24
+
25
+ def simplest_countless(data):
26
+ """
27
+ Vectorized implementation of downsampling a 2D
28
+ image by 2 on each side using the COUNTLESS algorithm.
29
+
30
+ data is a 2D numpy array with even dimensions.
31
+ """
32
+ sections = []
33
+
34
+ # This loop splits the 2D array apart into four arrays that are
35
+ # all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
36
+ # and (1,1) representing the A, B, C, and D positions from Figure 1.
37
+ factor = (2,2)
38
+ for offset in np.ndindex(factor):
39
+ part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
40
+ sections.append(part)
41
+
42
+ a, b, c, d = sections
43
+
44
+ ab = a * (a == b) # PICK(A,B)
45
+ ac = a * (a == c) # PICK(A,C)
46
+ bc = b * (b == c) # PICK(B,C)
47
+
48
+ a = ab | ac | bc # Bitwise OR, safe b/c non-matches are zeroed
49
+
50
+ return a + (a == 0) * d # AB || AC || BC || D
51
+
52
+ def quick_countless(data):
53
+ """
54
+ Vectorized implementation of downsampling a 2D
55
+ image by 2 on each side using the COUNTLESS algorithm.
56
+
57
+ data is a 2D numpy array with even dimensions.
58
+ """
59
+ sections = []
60
+
61
+ # This loop splits the 2D array apart into four arrays that are
62
+ # all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
63
+ # and (1,1) representing the A, B, C, and D positions from Figure 1.
64
+ factor = (2,2)
65
+ for offset in np.ndindex(factor):
66
+ part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
67
+ sections.append(part)
68
+
69
+ a, b, c, d = sections
70
+
71
+ ab_ac = a * ((a == b) | (a == c)) # PICK(A,B) || PICK(A,C) w/ optimization
72
+ bc = b * (b == c) # PICK(B,C)
73
+
74
+ a = ab_ac | bc # (PICK(A,B) || PICK(A,C)) or PICK(B,C)
75
+ return a + (a == 0) * d # AB || AC || BC || D
76
+
77
+ def quickest_countless(data):
78
+ """
79
+ Vectorized implementation of downsampling a 2D
80
+ image by 2 on each side using the COUNTLESS algorithm.
81
+
82
+ data is a 2D numpy array with even dimensions.
83
+ """
84
+ sections = []
85
+
86
+ # This loop splits the 2D array apart into four arrays that are
87
+ # all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
88
+ # and (1,1) representing the A, B, C, and D positions from Figure 1.
89
+ factor = (2,2)
90
+ for offset in np.ndindex(factor):
91
+ part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
92
+ sections.append(part)
93
+
94
+ a, b, c, d = sections
95
+
96
+ ab_ac = a * ((a == b) | (a == c)) # PICK(A,B) || PICK(A,C) w/ optimization
97
+ ab_ac |= b * (b == c) # PICK(B,C)
98
+ return ab_ac + (ab_ac == 0) * d # AB || AC || BC || D
99
+
100
+ def quick_countless_xor(data):
101
+ """
102
+ Vectorized implementation of downsampling a 2D
103
+ image by 2 on each side using the COUNTLESS algorithm.
104
+
105
+ data is a 2D numpy array with even dimensions.
106
+ """
107
+ sections = []
108
+
109
+ # This loop splits the 2D array apart into four arrays that are
110
+ # all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
111
+ # and (1,1) representing the A, B, C, and D positions from Figure 1.
112
+ factor = (2,2)
113
+ for offset in np.ndindex(factor):
114
+ part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
115
+ sections.append(part)
116
+
117
+ a, b, c, d = sections
118
+
119
+ ab = a ^ (a ^ b) # a or b
120
+ ab += (ab != a) * ((ab ^ (ab ^ c)) - b) # b or c
121
+ ab += (ab == c) * ((ab ^ (ab ^ d)) - c) # c or d
122
+ return ab
123
+
124
+ def stippled_countless(data):
125
+ """
126
+ Vectorized implementation of downsampling a 2D
127
+ image by 2 on each side using the COUNTLESS algorithm
128
+ that treats zero as "background" and inflates lone
129
+ pixels.
130
+
131
+ data is a 2D numpy array with even dimensions.
132
+ """
133
+ sections = []
134
+
135
+ # This loop splits the 2D array apart into four arrays that are
136
+ # all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
137
+ # and (1,1) representing the A, B, C, and D positions from Figure 1.
138
+ factor = (2,2)
139
+ for offset in np.ndindex(factor):
140
+ part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
141
+ sections.append(part)
142
+
143
+ a, b, c, d = sections
144
+
145
+ ab_ac = a * ((a == b) | (a == c)) # PICK(A,B) || PICK(A,C) w/ optimization
146
+ ab_ac |= b * (b == c) # PICK(B,C)
147
+
148
+ nonzero = a + (a == 0) * (b + (b == 0) * c)
149
+ return ab_ac + (ab_ac == 0) * (d + (d == 0) * nonzero) # AB || AC || BC || D
150
+
151
+ def zero_corrected_countless(data):
152
+ """
153
+ Vectorized implementation of downsampling a 2D
154
+ image by 2 on each side using the COUNTLESS algorithm.
155
+
156
+ data is a 2D numpy array with even dimensions.
157
+ """
158
+ # allows us to prevent losing 1/2 a bit of information
159
+ # at the top end by using a bigger type. Without this 255 is handled incorrectly.
160
+ data, upgraded = upgrade_type(data)
161
+
162
+ # offset from zero, raw countless doesn't handle 0 correctly
163
+ # we'll remove the extra 1 at the end.
164
+ data += 1
165
+
166
+ sections = []
167
+
168
+ # This loop splits the 2D array apart into four arrays that are
169
+ # all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
170
+ # and (1,1) representing the A, B, C, and D positions from Figure 1.
171
+ factor = (2,2)
172
+ for offset in np.ndindex(factor):
173
+ part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
174
+ sections.append(part)
175
+
176
+ a, b, c, d = sections
177
+
178
+ ab = a * (a == b) # PICK(A,B)
179
+ ac = a * (a == c) # PICK(A,C)
180
+ bc = b * (b == c) # PICK(B,C)
181
+
182
+ a = ab | ac | bc # Bitwise OR, safe b/c non-matches are zeroed
183
+
184
+ result = a + (a == 0) * d - 1 # a or d - 1
185
+
186
+ if upgraded:
187
+ return downgrade_type(result)
188
+
189
+ # only need to reset data if we weren't upgraded
190
+ # b/c no copy was made in that case
191
+ data -= 1
192
+
193
+ return result
194
+
195
+ def countless_extreme(data):
196
+ nonzeros = np.count_nonzero(data)
197
+ # print("nonzeros", nonzeros)
198
+
199
+ N = reduce(operator.mul, data.shape)
200
+
201
+ if nonzeros == N:
202
+ print("quick")
203
+ return quick_countless(data)
204
+ elif np.count_nonzero(data + 1) == N:
205
+ print("quick")
206
+ # print("upper", nonzeros)
207
+ return quick_countless(data)
208
+ else:
209
+ return countless(data)
210
+
211
+
212
+ def countless(data):
213
+ """
214
+ Vectorized implementation of downsampling a 2D
215
+ image by 2 on each side using the COUNTLESS algorithm.
216
+
217
+ data is a 2D numpy array with even dimensions.
218
+ """
219
+ # allows us to prevent losing 1/2 a bit of information
220
+ # at the top end by using a bigger type. Without this 255 is handled incorrectly.
221
+ data, upgraded = upgrade_type(data)
222
+
223
+ # offset from zero, raw countless doesn't handle 0 correctly
224
+ # we'll remove the extra 1 at the end.
225
+ data += 1
226
+
227
+ sections = []
228
+
229
+ # This loop splits the 2D array apart into four arrays that are
230
+ # all the result of striding by 2 and offset by (0,0), (0,1), (1,0),
231
+ # and (1,1) representing the A, B, C, and D positions from Figure 1.
232
+ factor = (2,2)
233
+ for offset in np.ndindex(factor):
234
+ part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
235
+ sections.append(part)
236
+
237
+ a, b, c, d = sections
238
+
239
+ ab_ac = a * ((a == b) | (a == c)) # PICK(A,B) || PICK(A,C) w/ optimization
240
+ ab_ac |= b * (b == c) # PICK(B,C)
241
+ result = ab_ac + (ab_ac == 0) * d - 1 # (matches or d) - 1
242
+
243
+ if upgraded:
244
+ return downgrade_type(result)
245
+
246
+ # only need to reset data if we weren't upgraded
247
+ # b/c no copy was made in that case
248
+ data -= 1
249
+
250
+ return result
251
+
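The +1 offset is what makes `countless` safe for label 0; a small comparison against `simplest_countless`, which treats zeros as non-matches:

```python
import numpy as np

block = np.array([[0, 0],
                  [5, 7]], dtype=np.uint8)
print(simplest_countless(block))  # -> [[7]]  (the zero majority is lost)
print(countless(block))           # -> [[0]]  (the offset trick recovers the mode)
```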
252
+ def upgrade_type(arr):
253
+ dtype = arr.dtype
254
+
255
+ if dtype == np.uint8:
256
+ return arr.astype(np.uint16), True
257
+ elif dtype == np.uint16:
258
+ return arr.astype(np.uint32), True
259
+ elif dtype == np.uint32:
260
+ return arr.astype(np.uint64), True
261
+
262
+ return arr, False
263
+
264
+ def downgrade_type(arr):
265
+ dtype = arr.dtype
266
+
267
+ if dtype == np.uint64:
268
+ return arr.astype(np.uint32)
269
+ elif dtype == np.uint32:
270
+ return arr.astype(np.uint16)
271
+ elif dtype == np.uint16:
272
+ return arr.astype(np.uint8)
273
+
274
+ return arr
275
+
276
+ def odd_to_even(image):
277
+ """
278
+ To facilitate 2x2 downsampling segmentation, change an odd sized image into an even sized one.
279
+ Works by mirroring the starting 1 pixel edge of the image on odd shaped sides.
280
+
281
+ e.g. turn a 3x3x5 image into a 4x4x5 (the x and y are what are getting downsampled)
282
+
283
+ For example: [ 3, 2, 4 ] => [ 3, 3, 2, 4 ] which is now easy to downsample.
284
+
285
+ """
286
+ shape = np.array(image.shape)
287
+
288
+ offset = (shape % 2)[:2] # x,y offset
289
+
290
+ # detect if we're dealing with an even
291
+ # image. if so it's fine, just return.
292
+ if not np.any(offset):
293
+ return image
294
+
295
+ oddshape = image.shape[:2] + offset
296
+ oddshape = np.append(oddshape, shape[2:])
297
+ oddshape = oddshape.astype(int)
298
+
299
+ newimg = np.empty(shape=oddshape, dtype=image.dtype)
300
+
301
+ ox,oy = offset
302
+ sx,sy = oddshape
303
+
304
+ newimg[0,0] = image[0,0] # corner
305
+ newimg[ox:sx,0] = image[:,0] # x axis line
306
+ newimg[0,oy:sy] = image[0,:] # y axis line
+ newimg[ox:sx,oy:sy] = image[:,:] # fill the remaining body (np.empty leaves it uninitialized otherwise)
307
+
308
+ return newimg
309
+
310
+ def counting(array):
311
+ factor = (2, 2, 1)
312
+ shape = array.shape
313
+
314
+ while len(shape) < 4:
315
+ array = np.expand_dims(array, axis=-1)
316
+ shape = array.shape
317
+
318
+ output_shape = tuple(int(math.ceil(s / f)) for s, f in zip(shape, factor))
319
+ output = np.zeros(output_shape, dtype=array.dtype)
320
+
321
+ for chan in range(0, shape[3]):
322
+ for z in range(0, shape[2]):
323
+ for x in range(0, shape[0], 2):
324
+ for y in range(0, shape[1], 2):
325
+ block = array[ x:x+2, y:y+2, z, chan ] # 2x2 block
326
+
327
+ hashtable = defaultdict(int)
328
+ for subx, suby in np.ndindex(block.shape[0], block.shape[1]):
329
+ hashtable[block[subx, suby]] += 1
330
+
331
+ best = (0, 0)
332
+ for segid, val in six.iteritems(hashtable):
333
+ if best[1] < val:
334
+ best = (segid, val)
335
+
336
+ output[ x // 2, y // 2, chan ] = best[0]
337
+
338
+ return output
339
+
340
+ def ndzoom(array):
341
+ if len(array.shape) == 3:
342
+ ratio = ( 1 / 2.0, 1 / 2.0, 1.0 )
343
+ else:
344
+ ratio = ( 1 / 2.0, 1 / 2.0)
345
+ return ndimage.zoom(array, ratio, order=1)
346
+
347
+ def countless_if(array):
348
+ factor = (2, 2, 1)
349
+ shape = array.shape
350
+
351
+ if len(shape) < 3:
352
+ array = array[ :,:, np.newaxis ]
353
+ shape = array.shape
354
+
355
+ output_shape = tuple(int(math.ceil(s / f)) for s, f in zip(shape, factor))
356
+ output = np.zeros(output_shape, dtype=array.dtype)
357
+
358
+ for chan in range(0, shape[2]):
359
+ for x in range(0, shape[0], 2):
360
+ for y in range(0, shape[1], 2):
361
+ block = array[ x:x+2, y:y+2, chan ] # 2x2 block
362
+
363
+ if block[0,0] == block[1,0]:
364
+ pick = block[0,0]
365
+ elif block[0,0] == block[0,1]:
366
+ pick = block[0,0]
367
+ elif block[1,0] == block[0,1]:
368
+ pick = block[1,0]
369
+ else:
370
+ pick = block[1,1]
371
+
372
+ output[ x // 2, y // 2, chan ] = pick
373
+
374
+ return np.squeeze(output)
375
+
376
+ def downsample_with_averaging(array):
377
+ """
378
+ Downsample x by factor using averaging.
379
+
380
+ @return: The downsampled array, of the same type as x.
381
+ """
382
+
383
+ if len(array.shape) == 3:
384
+ factor = (2,2,1)
385
+ else:
386
+ factor = (2,2)
387
+
388
+ if np.array_equal(factor[:3], np.array([1,1,1])):
389
+ return array
390
+
391
+ output_shape = tuple(int(math.ceil(s / f)) for s, f in zip(array.shape, factor))
392
+ temp = np.zeros(output_shape, float)
393
+ counts = np.zeros(output_shape, int)
394
+ for offset in np.ndindex(factor):
395
+ part = array[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
396
+ indexing_expr = tuple(np.s_[:s] for s in part.shape)
397
+ temp[indexing_expr] += part
398
+ counts[indexing_expr] += 1
399
+ return (temp / counts).astype(array.dtype)
400
+
401
+ def downsample_with_max_pooling(array):
402
+
403
+ factor = (2,2)
404
+
405
+ if np.all(np.array(factor, int) == 1):
406
+ return array
407
+
408
+ sections = []
409
+
410
+ for offset in np.ndindex(factor):
411
+ part = array[tuple(np.s_[o::f] for o, f in zip(offset, factor))]
412
+ sections.append(part)
413
+
414
+ output = sections[0].copy()
415
+
416
+ for section in sections[1:]:
417
+ np.maximum(output, section, output)
418
+
419
+ return output
420
+
421
+ def striding(array):
422
+ """Downsample x by factor using striding.
423
+
424
+ @return: The downsampled array, of the same type as x.
425
+ """
426
+ factor = (2,2)
427
+ if np.all(np.array(factor, int) == 1):
428
+ return array
429
+ return array[tuple(np.s_[::f] for f in factor)]
430
+
431
+ def benchmark():
432
+ filename = sys.argv[1]
433
+ img = Image.open(filename)
434
+ data = np.array(img.getdata(), dtype=np.uint8)
435
+
436
+ if len(data.shape) == 1:
437
+ n_channels = 1
438
+ reshape = (img.height, img.width)
439
+ else:
440
+ n_channels = min(data.shape[1], 3)
441
+ data = data[:, :n_channels]
442
+ reshape = (img.height, img.width, n_channels)
443
+
444
+ data = data.reshape(reshape).astype(np.uint8)
445
+
446
+ methods = [
447
+ simplest_countless,
448
+ quick_countless,
449
+ quick_countless_xor,
450
+ quickest_countless,
451
+ stippled_countless,
452
+ zero_corrected_countless,
453
+ countless,
454
+ downsample_with_averaging,
455
+ downsample_with_max_pooling,
456
+ ndzoom,
457
+ striding,
458
+ # countless_if,
459
+ # counting,
460
+ ]
461
+
462
+ formats = {
463
+ 1: 'L',
464
+ 3: 'RGB',
465
+ 4: 'RGBA'
466
+ }
467
+
468
+ if not os.path.exists('./results'):
469
+ os.mkdir('./results')
470
+
471
+ N = 500
472
+ img_size = float(img.width * img.height) / 1024.0 / 1024.0
473
+ print("N = %d, %dx%d (%.2f MPx) %d chan, %s" % (N, img.width, img.height, img_size, n_channels, filename))
474
+ print("Algorithm\tMPx/sec\tMB/sec\tSec")
475
+ for fn in methods:
476
+ print(fn.__name__, end='')
477
+ sys.stdout.flush()
478
+
479
+ start = time.time()
480
+ # tqdm is here to show you what's going on the first time you run it.
481
+ # Feel free to remove it to get slightly more accurate timing results.
482
+ for _ in tqdm(range(N), desc=fn.__name__, disable=True):
483
+ result = fn(data)
484
+ end = time.time()
485
+ print("\r", end='')
486
+
487
+ total_time = (end - start)
488
+ mpx = N * img_size / total_time
489
+ mbytes = N * img_size * n_channels / total_time
490
+ # Output in tab separated format to enable copy-paste into excel/numbers
491
+ print("%s\t%.3f\t%.3f\t%.2f" % (fn.__name__, mpx, mbytes, total_time))
492
+ outimg = Image.fromarray(np.squeeze(result), formats[n_channels])
493
+ outimg.save('./results/{}.png'.format(fn.__name__, "PNG"))
494
+
495
+ if __name__ == '__main__':
496
+ benchmark()
497
+
498
+
499
+ # Example results:
500
+ # N = 5, 1024x1024 (1.00 MPx) 1 chan, images/gray_segmentation.png
501
+ # Function MPx/sec MB/sec Sec
502
+ # simplest_countless 752.855 752.855 0.01
503
+ # quick_countless 920.328 920.328 0.01
504
+ # zero_corrected_countless 534.143 534.143 0.01
505
+ # countless 644.247 644.247 0.01
506
+ # downsample_with_averaging 372.575 372.575 0.01
507
+ # downsample_with_max_pooling 974.060 974.060 0.01
508
+ # ndzoom 137.517 137.517 0.04
509
+ # striding 38550.588 38550.588 0.00
510
+ # countless_if 4.377 4.377 1.14
511
+ # counting 0.117 0.117 42.85
512
+
513
+ # Run without non-numpy implementations:
514
+ # N = 2000, 1024x1024 (1.00 MPx) 1 chan, images/gray_segmentation.png
515
+ # Algorithm MPx/sec MB/sec Sec
516
+ # simplest_countless 800.522 800.522 2.50
517
+ # quick_countless 945.420 945.420 2.12
518
+ # quickest_countless 947.256 947.256 2.11
519
+ # stippled_countless 544.049 544.049 3.68
520
+ # zero_corrected_countless 575.310 575.310 3.48
521
+ # countless 646.684 646.684 3.09
522
+ # downsample_with_averaging 385.132 385.132 5.19
523
+ # downsample_with_max_poolin 988.361 988.361 2.02
524
+ # ndzoom 163.104 163.104 12.26
525
+ # striding 81589.340 81589.340 0.02
526
+
527
+
528
+
529
+