zhigangjiang committed on
Commit
88b0dcb
1 Parent(s): 46e6683

no message

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitignore +4 -0
  2. LICENSE +21 -0
  3. Post-Porcessing.md +35 -0
  4. app.py +139 -0
  5. config/__init__.py +4 -0
  6. config/defaults.py +289 -0
  7. convert_ckpt.py +61 -0
  8. dataset/__init__.py +0 -0
  9. dataset/build.py +115 -0
  10. dataset/communal/__init__.py +4 -0
  11. dataset/communal/base_dataset.py +127 -0
  12. dataset/communal/data_augmentation.py +279 -0
  13. dataset/communal/read.py +214 -0
  14. dataset/mp3d_dataset.py +110 -0
  15. dataset/pano_s2d3d_dataset.py +107 -0
  16. dataset/pano_s2d3d_mix_dataset.py +91 -0
  17. dataset/zind_dataset.py +138 -0
  18. evaluation/__init__.py +4 -0
  19. evaluation/accuracy.py +249 -0
  20. evaluation/analyse_layout_type.py +83 -0
  21. evaluation/eval_visible_iou.py +56 -0
  22. evaluation/f1_score.py +78 -0
  23. evaluation/iou.py +148 -0
  24. inference.py +261 -0
  25. loss/__init__.py +10 -0
  26. loss/boundary_loss.py +51 -0
  27. loss/grad_loss.py +57 -0
  28. loss/led_loss.py +47 -0
  29. loss/object_loss.py +42 -0
  30. main.py +401 -0
  31. models/__init__.py +1 -0
  32. models/base_model.py +150 -0
  33. models/build.py +81 -0
  34. models/lgt_net.py +213 -0
  35. models/modules/__init__.py +8 -0
  36. models/modules/conv_transformer.py +128 -0
  37. models/modules/horizon_net_feature_extractor.py +267 -0
  38. models/modules/patch_feature_extractor.py +57 -0
  39. models/modules/swg_transformer.py +49 -0
  40. models/modules/swin_transformer.py +43 -0
  41. models/modules/transformer.py +44 -0
  42. models/modules/transformer_modules.py +250 -0
  43. models/other/__init__.py +4 -0
  44. models/other/criterion.py +72 -0
  45. models/other/init_env.py +37 -0
  46. models/other/optimizer.py +24 -0
  47. models/other/scheduler.py +51 -0
  48. postprocessing/__init__.py +4 -0
  49. postprocessing/dula/__init__.py +4 -0
  50. postprocessing/dula/layout.py +226 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
+ checkpoints
+ src/output
+ visualization/visualizer
+ flagged
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2021 ZhiGang Jiang
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
Post-Porcessing.md ADDED
@@ -0,0 +1,35 @@
+ # Post-Processing
+ ## Step
+
+ 1. Simplify the polygon with the [DP algorithm](https://en.wikipedia.org/wiki/Ramer%E2%80%93Douglas%E2%80%93Peucker_algorithm)
+
+ ![img.png](src/fig/post_processing/img_0.png)
+
+ 2. Detect occlusion, calculating a box filled with 1
+
+ ![img.png](src/fig/post_processing/img_1.png)
+
+ 3. Fill in a reasonable sampling section
+
+ ![img.png](src/fig/post_processing/img_2.png)
+
+ 4. Output the processed polygon
+
+ ![img.png](src/fig/post_processing/img_3.png)
+
+ ## Performance
+ It works; here is a performance comparison on the MatterportLayout dataset:
+
+ | Method | 2D IoU(%) | 3D IoU(%) | RMSE | $\mathbf{\delta_{1}}$ |
+ |--|--|--|--|--|
+ | without post-proc | 83.52 | 81.11 | 0.204 | 0.951 |
+ | original post-proc | 83.12 | 80.71 | 0.230 | 0.936 |
+ | optimized post-proc | 83.48 | 81.08 | 0.214 | 0.940 |
+
+ original:
+
+ ![img.png](src/fig/post_processing/original.png)
+
+ optimized:
+
+ ![img.png](src/fig/post_processing/optimized.png)
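Step 1 above refers to Ramer-Douglas-Peucker (DP) simplification. The repository's actual implementation lives in `postprocessing/` (see `postprocessing/dula/layout.py` in the file list above); the following is only a minimal, self-contained sketch of the idea, with the function name `rdp_simplify` and the `epsilon` tolerance chosen purely for illustration:

```python
import numpy as np

def rdp_simplify(points, epsilon=0.01):
    """Ramer-Douglas-Peucker simplification of an open polyline.

    points: (N, 2) array; epsilon: maximum allowed perpendicular deviation.
    Returns a subset of the input points (both endpoints are always kept).
    """
    points = np.asarray(points, dtype=np.float64)
    if len(points) < 3:
        return points
    start, end = points[0], points[-1]
    chord = end - start
    chord_len = np.linalg.norm(chord)
    if chord_len == 0:
        dists = np.linalg.norm(points - start, axis=1)
    else:
        # perpendicular distance of each point to the start-end chord (2D cross product)
        diff = points - start
        dists = np.abs(chord[0] * diff[:, 1] - chord[1] * diff[:, 0]) / chord_len
    idx = int(np.argmax(dists))
    if dists[idx] > epsilon:
        # keep the farthest point and recurse on both halves
        left = rdp_simplify(points[:idx + 1], epsilon)
        right = rdp_simplify(points[idx:], epsilon)
        return np.vstack([left[:-1], right])
    return np.vstack([start, end])
```

A call such as `rdp_simplify(floor_polygon_xz, epsilon=0.01)` (the variable name is hypothetical) would collapse nearly collinear boundary samples into a small set of corner points before the occlusion handling in step 2.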
app.py ADDED
@@ -0,0 +1,139 @@
+ '''
+ @author: Zhigang Jiang
+ @time: 2022/05/23
+ @description:
+ '''
+
+ import gradio as gr
+ import numpy as np
+ import os
+ import torch
+
+ from PIL import Image
+
+ from utils.logger import get_logger
+ from config.defaults import get_config
+ from inference import preprocess, run_one_inference
+ from models.build import build_model
+ from argparse import Namespace
+ import gdown
+
+
+ def down_ckpt(model_cfg, ckpt_dir):
+     model_ids = [
+         ['src/config/mp3d.yaml', '1o97oAmd-yEP5bQrM0eAWFPLq27FjUDbh'],
+         ['src/config/zind.yaml', '1PzBj-dfDfH_vevgSkRe5kczW0GVl_43I'],
+         ['src/config/pano.yaml', '1JoeqcPbm_XBPOi6O9GjjWi3_rtyPZS8m'],
+         ['src/config/s2d3d.yaml', '1PfJzcxzUsbwwMal7yTkBClIFgn8IdEzI'],
+         ['src/config/ablation_study/full.yaml', '1U16TxUkvZlRwJNaJnq9nAUap-BhCVIha']
+     ]
+
+     for model_id in model_ids:
+         if model_id[0] != model_cfg:
+             continue
+         path = os.path.join(ckpt_dir, 'best.pkl')
+         if not os.path.exists(path):
+             logger.info(f"Downloading {model_id}")
+             os.makedirs(ckpt_dir, exist_ok=True)
+             gdown.download(f"https://drive.google.com/uc?id={model_id[1]}", path, False)
+
+
+ def greet(img_path, pre_processing, weight_name, post_processing, visualization, mesh_format, mesh_resolution):
+     args.pre_processing = pre_processing
+     args.post_processing = post_processing
+     if weight_name == 'mp3d':
+         model = mp3d_model
+     elif weight_name == 'zind':
+         model = zind_model
+     else:
+         logger.error("unknown pre-trained weight name")
+         raise NotImplementedError
+
+     img_name = os.path.basename(img_path).split('.')[0]
+     img = np.array(Image.open(img_path).resize((1024, 512), Image.Resampling.BICUBIC))[..., :3]
+
+     vp_cache_path = 'src/demo/default_vp.txt'
+     if args.pre_processing:
+         vp_cache_path = os.path.join('src/output', f'{img_name}_vp.txt')
+         logger.info("pre-processing ...")
+         img, vp = preprocess(img, vp_cache_path=vp_cache_path)
+
+     img = (img / 255.0).astype(np.float32)
+     run_one_inference(img, model, args, img_name,
+                       logger=logger, show=False,
+                       show_depth='depth-normal-gradient' in visualization,
+                       show_floorplan='2d-floorplan' in visualization,
+                       mesh_format=mesh_format, mesh_resolution=int(mesh_resolution))
+
+     return [os.path.join(args.output_dir, f"{img_name}_pred.png"),
+             os.path.join(args.output_dir, f"{img_name}_3d{mesh_format}"),
+             os.path.join(args.output_dir, f"{img_name}_3d{mesh_format}"),
+             vp_cache_path,
+             os.path.join(args.output_dir, f"{img_name}_pred.json")]
+
+
+ def get_model(args):
+     config = get_config(args)
+     down_ckpt(args.cfg, config.CKPT.DIR)
+     if ('cuda' in args.device or 'cuda' in config.TRAIN.DEVICE) and not torch.cuda.is_available():
+         logger.info(f'The {args.device} is not available, will use cpu ...')
+         config.defrost()
+         args.device = "cpu"
+         config.TRAIN.DEVICE = "cpu"
+         config.freeze()
+     model, _, _, _ = build_model(config, logger)
+     return model
+
+
+ if __name__ == '__main__':
+     logger = get_logger()
+     args = Namespace(device='cuda', output_dir='src/output', visualize_3d=False, output_3d=True)
+     os.makedirs(args.output_dir, exist_ok=True)
+
+     args.cfg = 'src/config/mp3d.yaml'
+     mp3d_model = get_model(args)
+
+     args.cfg = 'src/config/zind.yaml'
+     zind_model = get_model(args)
+
+     description = "This is a demo of the project " \
+                   "<a href='https://github.com/zhigangjiang/LGT-Net' target='_blank'>LGT-Net</a>. " \
+                   "It uses the Geometry-Aware Transformer Network to predict the 3D room layout of an RGB panorama."
+
+     demo = gr.Interface(fn=greet,
+                         inputs=[gr.Image(type='filepath', label='input rgb panorama', value='src/demo/pano_demo1.png'),
+                                 gr.Checkbox(label='pre-processing', value=True),
+                                 gr.Radio(['mp3d', 'zind'],
+                                          label='pre-trained weight',
+                                          value='mp3d'),
+                                 gr.Radio(['manhattan', 'atalanta', 'original'],
+                                          label='post-processing method',
+                                          value='manhattan'),
+                                 gr.CheckboxGroup(['depth-normal-gradient', '2d-floorplan'],
+                                                  label='2d-visualization',
+                                                  value=['depth-normal-gradient', '2d-floorplan']),
+                                 gr.Radio(['.gltf', '.obj', '.glb'],
+                                          label='output format of 3d mesh',
+                                          value='.gltf'),
+                                 gr.Radio(['128', '256', '512', '1024'],
+                                          label='output resolution of 3d mesh',
+                                          value='256'),
+                                 ],
+                         outputs=[gr.Image(label='predicted result 2d-visualization', type='filepath'),
+                                  gr.Model3D(label='3d mesh reconstruction', clear_color=[1.0, 1.0, 1.0, 1.0]),
+                                  gr.File(label='3d mesh file'),
+                                  gr.File(label='vanishing point information'),
+                                  gr.File(label='layout json')],
+                         examples=[
+                             ['src/demo/pano_demo1.png', True, 'mp3d', 'manhattan', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
+                             ['src/demo/mp3d_demo1.png', False, 'mp3d', 'manhattan', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
+                             ['src/demo/mp3d_demo2.png', False, 'mp3d', 'manhattan', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
+                             ['src/demo/mp3d_demo3.png', False, 'mp3d', 'manhattan', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
+                             ['src/demo/zind_demo1.png', True, 'zind', 'manhattan', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
+                             ['src/demo/zind_demo2.png', False, 'zind', 'atalanta', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
+                             ['src/demo/zind_demo3.png', True, 'zind', 'manhattan', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
+                             ['src/demo/other_demo1.png', False, 'mp3d', 'manhattan', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
+                             ['src/demo/other_demo2.png', True, 'mp3d', 'manhattan', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
+                         ], title='LGT-Net', allow_flagging="never", cache_examples=False, description=description)
+
+     demo.launch(debug=True, enable_queue=False)
config/__init__.py ADDED
@@ -0,0 +1,4 @@
+ """
+ @Date: 2021/07/17
+ @description:
+ """
config/defaults.py ADDED
@@ -0,0 +1,289 @@
+ """
+ @Date: 2021/07/17
+ @description:
+ """
+ import os
+ import logging
+ from yacs.config import CfgNode as CN
+
+ _C = CN()
+ _C.DEBUG = False
+ _C.MODE = 'train'
+ _C.VAL_NAME = 'val'
+ _C.TAG = 'default'
+ _C.COMMENT = 'add some comments to help you understand'
+ _C.SHOW_BAR = True
+ _C.SAVE_EVAL = False
+ _C.MODEL = CN()
+ _C.MODEL.NAME = 'model_name'
+ _C.MODEL.SAVE_BEST = True
+ _C.MODEL.SAVE_LAST = True
+ _C.MODEL.ARGS = []
+ _C.MODEL.FINE_TUNE = []
+
+ # -----------------------------------------------------------------------------
+ # Training settings
+ # -----------------------------------------------------------------------------
+ _C.TRAIN = CN()
+ _C.TRAIN.SCRATCH = False
+ _C.TRAIN.START_EPOCH = 0
+ _C.TRAIN.EPOCHS = 300
+ _C.TRAIN.DETERMINISTIC = False
+ _C.TRAIN.SAVE_FREQ = 5
+
+ _C.TRAIN.BASE_LR = 5e-4
+
+ _C.TRAIN.WARMUP_EPOCHS = 20
+ _C.TRAIN.WEIGHT_DECAY = 0
+ _C.TRAIN.WARMUP_LR = 5e-7
+ _C.TRAIN.MIN_LR = 5e-6
+ # Clip gradient norm
+ _C.TRAIN.CLIP_GRAD = 5.0
+ # Auto resume from latest checkpoint
+ _C.TRAIN.RESUME_LAST = True
+ # Gradient accumulation steps
+ # could be overwritten by command line argument
+ _C.TRAIN.ACCUMULATION_STEPS = 0
+ # Whether to use gradient checkpointing to save memory
+ # could be overwritten by command line argument
+ _C.TRAIN.USE_CHECKPOINT = False
+ # 'cpu' or 'cuda:0, 1, 2, 3' or 'cuda'
+ _C.TRAIN.DEVICE = 'cuda'
+
+ # LR scheduler
+ _C.TRAIN.LR_SCHEDULER = CN()
+ _C.TRAIN.LR_SCHEDULER.NAME = ''
+ _C.TRAIN.LR_SCHEDULER.ARGS = []
+
+
+ # Optimizer
+ _C.TRAIN.OPTIMIZER = CN()
+ _C.TRAIN.OPTIMIZER.NAME = 'adam'
+ # Optimizer Epsilon
+ _C.TRAIN.OPTIMIZER.EPS = 1e-8
+ # Optimizer Betas
+ _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
+ # SGD momentum
+ _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
+
+ # Criterion
+ _C.TRAIN.CRITERION = CN()
+ # Boundary loss (Horizon-Net)
+ _C.TRAIN.CRITERION.BOUNDARY = CN()
+ _C.TRAIN.CRITERION.BOUNDARY.NAME = 'boundary'
+ _C.TRAIN.CRITERION.BOUNDARY.LOSS = 'BoundaryLoss'
+ _C.TRAIN.CRITERION.BOUNDARY.WEIGHT = 0.0
+ _C.TRAIN.CRITERION.BOUNDARY.WEIGHTS = []
+ _C.TRAIN.CRITERION.BOUNDARY.NEED_ALL = True
+ # Up and Down depth loss (LED2-Net)
+ _C.TRAIN.CRITERION.LEDDepth = CN()
+ _C.TRAIN.CRITERION.LEDDepth.NAME = 'led_depth'
+ _C.TRAIN.CRITERION.LEDDepth.LOSS = 'LEDLoss'
+ _C.TRAIN.CRITERION.LEDDepth.WEIGHT = 0.0
+ _C.TRAIN.CRITERION.LEDDepth.WEIGHTS = []
+ _C.TRAIN.CRITERION.LEDDepth.NEED_ALL = True
+ # Depth loss
+ _C.TRAIN.CRITERION.DEPTH = CN()
+ _C.TRAIN.CRITERION.DEPTH.NAME = 'depth'
+ _C.TRAIN.CRITERION.DEPTH.LOSS = 'L1Loss'
+ _C.TRAIN.CRITERION.DEPTH.WEIGHT = 0.0
+ _C.TRAIN.CRITERION.DEPTH.WEIGHTS = []
+ _C.TRAIN.CRITERION.DEPTH.NEED_ALL = False
+ # Ratio (room height) loss
+ _C.TRAIN.CRITERION.RATIO = CN()
+ _C.TRAIN.CRITERION.RATIO.NAME = 'ratio'
+ _C.TRAIN.CRITERION.RATIO.LOSS = 'L1Loss'
+ _C.TRAIN.CRITERION.RATIO.WEIGHT = 0.0
+ _C.TRAIN.CRITERION.RATIO.WEIGHTS = []
+ _C.TRAIN.CRITERION.RATIO.NEED_ALL = False
+ # Grad (normal) loss
+ _C.TRAIN.CRITERION.GRAD = CN()
+ _C.TRAIN.CRITERION.GRAD.NAME = 'grad'
+ _C.TRAIN.CRITERION.GRAD.LOSS = 'GradLoss'
+ _C.TRAIN.CRITERION.GRAD.WEIGHT = 0.0
+ _C.TRAIN.CRITERION.GRAD.WEIGHTS = [1.0, 1.0]
+ _C.TRAIN.CRITERION.GRAD.NEED_ALL = True
+ # Object loss
+ _C.TRAIN.CRITERION.OBJECT = CN()
+ _C.TRAIN.CRITERION.OBJECT.NAME = 'object'
+ _C.TRAIN.CRITERION.OBJECT.LOSS = 'ObjectLoss'
+ _C.TRAIN.CRITERION.OBJECT.WEIGHT = 0.0
+ _C.TRAIN.CRITERION.OBJECT.WEIGHTS = []
+ _C.TRAIN.CRITERION.OBJECT.NEED_ALL = True
+ # Heatmap loss
+ _C.TRAIN.CRITERION.CHM = CN()
+ _C.TRAIN.CRITERION.CHM.NAME = 'corner_heat_map'
+ _C.TRAIN.CRITERION.CHM.LOSS = 'HeatmapLoss'
+ _C.TRAIN.CRITERION.CHM.WEIGHT = 0.0
+ _C.TRAIN.CRITERION.CHM.WEIGHTS = []
+ _C.TRAIN.CRITERION.CHM.NEED_ALL = False
+
+ _C.TRAIN.VIS_MERGE = True
+ _C.TRAIN.VIS_WEIGHT = 1024
+ # -----------------------------------------------------------------------------
+ # Output settings
+ # -----------------------------------------------------------------------------
+ _C.CKPT = CN()
+ _C.CKPT.PYTORCH = './'
+ _C.CKPT.ROOT = "./checkpoints"
+ _C.CKPT.DIR = os.path.join(_C.CKPT.ROOT, _C.MODEL.NAME, _C.TAG)
+ _C.CKPT.RESULT_DIR = os.path.join(_C.CKPT.DIR, 'results', _C.MODE)
+
+ _C.LOGGER = CN()
+ _C.LOGGER.DIR = os.path.join(_C.CKPT.DIR, "logs")
+ _C.LOGGER.LEVEL = logging.DEBUG
+
+ # -----------------------------------------------------------------------------
+ # Misc
+ # -----------------------------------------------------------------------------
+ # Mixed precision opt level; if O0, no amp is used ('O0', 'O1', 'O2'). Please confirm your device supports FP16 (half precision).
+ # overwritten by command line argument
+ _C.AMP_OPT_LEVEL = 'O1'
+ # Path to output folder, overwritten by command line argument
+ _C.OUTPUT = ''
+ # Tag of experiment, overwritten by command line argument
+ _C.TAG = 'default'
+ # Frequency to save checkpoint
+ _C.SAVE_FREQ = 1
+ # Frequency to log info
+ _C.PRINT_FREQ = 10
+ # Fixed random seed
+ _C.SEED = 0
+ # Perform evaluation only, overwritten by command line argument
+ _C.EVAL_MODE = False
+ # Test throughput only, overwritten by command line argument
+ _C.THROUGHPUT_MODE = False
+
+ # -----------------------------------------------------------------------------
+ # FIX
+ # -----------------------------------------------------------------------------
+ _C.LOCAL_RANK = 0
+ _C.WORLD_SIZE = 0
+
+ # -----------------------------------------------------------------------------
+ # Data settings
+ # -----------------------------------------------------------------------------
+ _C.DATA = CN()
+ # Sub dataset of pano_s2d3d
+ _C.DATA.SUBSET = None
+ # Dataset name
+ _C.DATA.DATASET = 'mp3d'
+ # Path to dataset, could be overwritten by command line argument
+ _C.DATA.DIR = ''
+ # Max wall number
+ _C.DATA.WALL_NUM = 0  # all
+ # Panorama image size
+ _C.DATA.SHAPE = [512, 1024]
+ # Real camera height
+ _C.DATA.CAMERA_HEIGHT = 1.6
+ # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
+ _C.DATA.PIN_MEMORY = True
+ # Debug use, fast test of model performance
+ _C.DATA.FOR_TEST_INDEX = None
+
+ # Batch size for a single GPU, could be overwritten by command line argument
+ _C.DATA.BATCH_SIZE = 8
+ # Number of data loading threads
+ _C.DATA.NUM_WORKERS = 8
+
+ # Training augment
+ _C.DATA.AUG = CN()
+ # Flip the panorama horizontally
+ _C.DATA.AUG.FLIP = True
+ # Pano Stretch Data Augmentation by HorizonNet
+ _C.DATA.AUG.STRETCH = True
+ # Rotate the panorama horizontally
+ _C.DATA.AUG.ROTATE = True
+ # Gamma adjusting
+ _C.DATA.AUG.GAMMA = True
+
+ _C.DATA.KEYS = []
+
+
+ _C.EVAL = CN()
+ _C.EVAL.POST_PROCESSING = None
+ _C.EVAL.NEED_CPE = False
+ _C.EVAL.NEED_F1 = False
+ _C.EVAL.NEED_RMSE = False
+ _C.EVAL.FORCE_CUBE = False
+
+
+ def merge_from_file(cfg_path):
+     config = _C.clone()
+     config.merge_from_file(cfg_path)
+     return config
+
+
+ def get_config(args=None):
+     config = _C.clone()
+     if args:
+         if 'cfg' in args and args.cfg:
+             config.merge_from_file(args.cfg)
+
+         if 'mode' in args and args.mode:
+             config.MODE = args.mode
+
+         if 'debug' in args and args.debug:
+             config.DEBUG = args.debug
+
+         if 'hidden_bar' in args and args.hidden_bar:
+             config.SHOW_BAR = False
+
+         if 'bs' in args and args.bs:
+             config.DATA.BATCH_SIZE = args.bs
+
+         if 'save_eval' in args and args.save_eval:
+             config.SAVE_EVAL = True
+
+         if 'val_name' in args and args.val_name:
+             config.VAL_NAME = args.val_name
+
+         if 'post_processing' in args and args.post_processing:
+             config.EVAL.POST_PROCESSING = args.post_processing
+
+         if 'need_cpe' in args and args.need_cpe:
+             config.EVAL.NEED_CPE = args.need_cpe
+
+         if 'need_f1' in args and args.need_f1:
+             config.EVAL.NEED_F1 = args.need_f1
+
+         if 'need_rmse' in args and args.need_rmse:
+             config.EVAL.NEED_RMSE = args.need_rmse
+
+         if 'force_cube' in args and args.force_cube:
+             config.EVAL.FORCE_CUBE = args.force_cube
+
+         if 'wall_num' in args and args.wall_num:
+             config.DATA.WALL_NUM = args.wall_num
+
+     args = config.MODEL.ARGS[0]
+     config.CKPT.DIR = os.path.join(config.CKPT.ROOT, f"{args['decoder_name']}_{args['output_name']}_Net",
+                                    config.TAG, 'debug' if config.DEBUG else '')
+     config.CKPT.RESULT_DIR = os.path.join(config.CKPT.DIR, 'results', config.MODE)
+     config.LOGGER.DIR = os.path.join(config.CKPT.DIR, "logs")
+
+     core_number = os.popen("grep 'physical id' /proc/cpuinfo | sort | uniq | wc -l").read()
+
+     try:
+         config.DATA.NUM_WORKERS = int(core_number) * 2
+         print(f"System core number: {config.DATA.NUM_WORKERS}")
+     except ValueError:
+         print(f"Can't get system core number, will use config: {config.DATA.NUM_WORKERS}")
+     config.freeze()
+     return config
+
+
+ def get_rank_config(cfg, local_rank, world_size):
+     local_rank = 0 if local_rank is None else local_rank
+     config = cfg.clone()
+     config.defrost()
+     if world_size > 1:
+         ids = config.TRAIN.DEVICE.split(':')[-1].split(',') if ':' in config.TRAIN.DEVICE else range(world_size)
+         config.TRAIN.DEVICE = f'cuda:{ids[local_rank]}'
+
+     config.LOCAL_RANK = local_rank
+     config.WORLD_SIZE = world_size
+     config.SEED = config.SEED + local_rank
+
+     config.freeze()
+     return config
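A minimal sketch of how `get_config` is typically consumed (this mirrors the pattern used by `app.py` above; the YAML path is illustrative and is assumed to define `MODEL.ARGS[0]` with `decoder_name` and `output_name`, since the checkpoint directory is derived from them):

```python
from argparse import Namespace
from config.defaults import get_config, get_rank_config

# Only attributes present on the Namespace are consulted; missing ones are simply skipped.
args = Namespace(cfg='src/config/mp3d.yaml', mode='test', post_processing='manhattan')
config = get_config(args)                      # merged with the YAML, then frozen
print(config.CKPT.DIR, config.EVAL.POST_PROCESSING)

# Per-process copy for (optional) distributed training
rank_cfg = get_rank_config(config, local_rank=0, world_size=1)
```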
convert_ckpt.py ADDED
@@ -0,0 +1,61 @@
+ """
+ @date: 2021/11/22
+ @description: Convert a training ckpt into an inference ckpt
+ """
+ import argparse
+ import os
+
+ import torch
+
+ from config.defaults import merge_from_file
+
+
+ def parse_option():
+     parser = argparse.ArgumentParser(description='Convert a training ckpt into an inference ckpt')
+     parser.add_argument('--cfg',
+                         type=str,
+                         required=True,
+                         metavar='FILE',
+                         help='path of config file')
+
+     parser.add_argument('--output_path',
+                         type=str,
+                         help='path of output ckpt')
+
+     args = parser.parse_args()
+
+     print("arguments:")
+     for arg in vars(args):
+         print(arg, ":", getattr(args, arg))
+     print("-" * 50)
+     return args
+
+
+ def convert_ckpt():
+     args = parse_option()
+     config = merge_from_file(args.cfg)
+     ck_dir = os.path.join("checkpoints", f"{config.MODEL.ARGS[0]['decoder_name']}_{config.MODEL.ARGS[0]['output_name']}_Net",
+                           config.TAG)
+     print(f"Processing {ck_dir}")
+     model_paths = [name for name in os.listdir(ck_dir) if '_best_' in name]
+     if len(model_paths) == 0:
+         print("Best ckpt not found")
+         return
+     model_path = os.path.join(ck_dir, model_paths[0])
+     print(f"Loading {model_path}")
+     checkpoint = torch.load(model_path, map_location=torch.device('cuda:0'))
+     net = checkpoint['net']
+     output_path = None
+     if args.output_path is None:
+         output_path = os.path.join(ck_dir, 'best.pkl')
+     else:
+         output_path = args.output_path
+     if output_path is None:
+         print("Output path is invalid")
+     print(f"Save on: {output_path}")
+     os.makedirs(os.path.dirname(output_path), exist_ok=True)
+     torch.save(net, output_path)
+
+
+ if __name__ == '__main__':
+     convert_ckpt()
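Usage sketch (paths are illustrative): running `python convert_ckpt.py --cfg src/config/mp3d.yaml` looks for a `*_best_*` checkpoint under `checkpoints/<decoder_name>_<output_name>_Net/<TAG>/`, extracts its `net` state, and writes it to `best.pkl` in the same directory; pass `--output_path` to choose a different destination.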
dataset/__init__.py ADDED
File without changes
dataset/build.py ADDED
@@ -0,0 +1,115 @@
+ """
+ @Date: 2021/07/18
+ @description:
+ """
+ import numpy as np
+ import torch.utils.data
+ from dataset.mp3d_dataset import MP3DDataset
+ from dataset.pano_s2d3d_dataset import PanoS2D3DDataset
+ from dataset.pano_s2d3d_mix_dataset import PanoS2D3DMixDataset
+ from dataset.zind_dataset import ZindDataset
+
+
+ def build_loader(config, logger):
+     name = config.DATA.DATASET
+     ddp = config.WORLD_SIZE > 1
+     train_dataset = None
+     train_data_loader = None
+     if config.MODE == 'train':
+         train_dataset = build_dataset(mode='train', config=config, logger=logger)
+
+     val_dataset = build_dataset(mode=config.VAL_NAME if config.MODE != 'test' else 'test', config=config, logger=logger)
+
+     train_sampler = None
+     val_sampler = None
+     if ddp:
+         if train_dataset:
+             train_sampler = torch.utils.data.DistributedSampler(train_dataset, shuffle=True)
+         val_sampler = torch.utils.data.DistributedSampler(val_dataset, shuffle=False)
+
+     batch_size = config.DATA.BATCH_SIZE
+     num_workers = 0 if config.DEBUG else config.DATA.NUM_WORKERS
+     pin_memory = config.DATA.PIN_MEMORY
+     if train_dataset:
+         logger.info(f'Train data loader batch size: {batch_size}')
+         train_data_loader = torch.utils.data.DataLoader(
+             train_dataset, sampler=train_sampler,
+             batch_size=batch_size,
+             shuffle=True,
+             num_workers=num_workers,
+             pin_memory=pin_memory,
+             drop_last=True,
+         )
+     batch_size = batch_size - (len(val_dataset) % np.arange(batch_size, 0, -1)).tolist().index(0)
+     logger.info(f'Val data loader batch size: {batch_size}')
+     val_data_loader = torch.utils.data.DataLoader(
+         val_dataset, sampler=val_sampler,
+         batch_size=batch_size,
+         shuffle=False,
+         num_workers=num_workers,
+         pin_memory=pin_memory,
+         drop_last=False
+     )
+     logger.info(f'Build data loader: num_workers:{num_workers} pin_memory:{pin_memory}')
+     return train_data_loader, val_data_loader
+
+
+ def build_dataset(mode, config, logger):
+     name = config.DATA.DATASET
+     if name == 'mp3d':
+         dataset = MP3DDataset(
+             root_dir=config.DATA.DIR,
+             mode=mode,
+             shape=config.DATA.SHAPE,
+             max_wall_num=config.DATA.WALL_NUM,
+             aug=config.DATA.AUG if mode == 'train' else None,
+             camera_height=config.DATA.CAMERA_HEIGHT,
+             logger=logger,
+             for_test_index=config.DATA.FOR_TEST_INDEX,
+             keys=config.DATA.KEYS
+         )
+     elif name == 'pano_s2d3d':
+         dataset = PanoS2D3DDataset(
+             root_dir=config.DATA.DIR,
+             mode=mode,
+             shape=config.DATA.SHAPE,
+             max_wall_num=config.DATA.WALL_NUM,
+             aug=config.DATA.AUG if mode == 'train' else None,
+             camera_height=config.DATA.CAMERA_HEIGHT,
+             logger=logger,
+             for_test_index=config.DATA.FOR_TEST_INDEX,
+             subset=config.DATA.SUBSET,
+             keys=config.DATA.KEYS
+         )
+     elif name == 'pano_s2d3d_mix':
+         dataset = PanoS2D3DMixDataset(
+             root_dir=config.DATA.DIR,
+             mode=mode,
+             shape=config.DATA.SHAPE,
+             max_wall_num=config.DATA.WALL_NUM,
+             aug=config.DATA.AUG if mode == 'train' else None,
+             camera_height=config.DATA.CAMERA_HEIGHT,
+             logger=logger,
+             for_test_index=config.DATA.FOR_TEST_INDEX,
+             subset=config.DATA.SUBSET,
+             keys=config.DATA.KEYS
+         )
+     elif name == 'zind':
+         dataset = ZindDataset(
+             root_dir=config.DATA.DIR,
+             mode=mode,
+             shape=config.DATA.SHAPE,
+             max_wall_num=config.DATA.WALL_NUM,
+             aug=config.DATA.AUG if mode == 'train' else None,
+             camera_height=config.DATA.CAMERA_HEIGHT,
+             logger=logger,
+             for_test_index=config.DATA.FOR_TEST_INDEX,
+             is_simple=True,
+             is_ceiling_flat=False,
+             keys=config.DATA.KEYS,
+             vp_align=config.EVAL.POST_PROCESSING is not None and 'manhattan' in config.EVAL.POST_PROCESSING
+         )
+     else:
+         raise NotImplementedError(f"Unknown dataset: {name}")
+
+     return dataset
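A minimal sketch of how these builders are driven (hypothetical; `main.py` in this commit is the real entry point, and it is assumed the YAML sets `DATA.DATASET`, `DATA.DIR`, and `MODEL.ARGS` appropriately):

```python
from argparse import Namespace
from config.defaults import get_config
from utils.logger import get_logger
from dataset.build import build_loader

args = Namespace(cfg='src/config/mp3d.yaml', mode='test')
config = get_config(args)
logger = get_logger()
train_loader, val_loader = build_loader(config, logger)  # train_loader is None in 'test' mode
for batch in val_loader:
    print(batch['id'], batch['image'].shape)             # available fields follow config.DATA.KEYS
    break
```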
dataset/communal/__init__.py ADDED
@@ -0,0 +1,4 @@
+ """
+ @Date: 2021/09/22
+ @description:
+ """
dataset/communal/base_dataset.py ADDED
@@ -0,0 +1,127 @@
+ """
+ @Date: 2021/07/26
+ @description:
+ """
+ import numpy as np
+ import torch
+
+ from utils.boundary import corners2boundary, visibility_corners, get_heat_map
+ from utils.conversion import xyz2depth, uv2xyz, uv2pixel
+ from dataset.communal.data_augmentation import PanoDataAugmentation
+
+
+ class BaseDataset(torch.utils.data.Dataset):
+     def __init__(self, mode, shape=None, max_wall_num=999, aug=None, camera_height=1.6, patch_num=256, keys=None):
+         if keys is None or len(keys) == 0:
+             keys = ['image', 'depth', 'ratio', 'id', 'corners']
+         if shape is None:
+             shape = [512, 1024]
+
+         assert mode == 'train' or mode == 'val' or mode == 'test' or mode is None, 'unknown mode!'
+         self.mode = mode
+         self.keys = keys
+         self.shape = shape
+         self.pano_aug = None if aug is None or mode == 'val' else PanoDataAugmentation(aug)
+         self.camera_height = camera_height
+         self.max_wall_num = max_wall_num
+         self.patch_num = patch_num
+         self.data = None
+
+     def __len__(self):
+         return len(self.data)
+
+     @staticmethod
+     def get_depth(corners, plan_y=1, length=256, visible=True):
+         visible_floor_boundary = corners2boundary(corners, length=length, visible=visible)
+         # The horizon-depth relative to plan_y
+         visible_depth = xyz2depth(uv2xyz(visible_floor_boundary, plan_y), plan_y)
+         return visible_depth
+
+     def process_data(self, label, image, patch_num):
+         """
+         :param label:
+         :param image:
+         :param patch_num:
+         :return:
+         """
+         corners = label['corners']
+         if self.pano_aug is not None:
+             corners, image = self.pano_aug.execute_aug(corners, image if 'image' in self.keys else None)
+         eps = 1e-3
+         corners[:, 1] = np.clip(corners[:, 1], 0.5 + eps, 1 - eps)
+
+         output = {}
+         if 'image' in self.keys:
+             image = image.transpose(2, 0, 1)
+             output['image'] = image
+
+         visible_corners = None
+         if 'corner_class' in self.keys or 'depth' in self.keys:
+             visible_corners = visibility_corners(corners)
+
+         if 'depth' in self.keys:
+             depth = self.get_depth(visible_corners, length=patch_num, visible=False)
+             assert len(depth) == patch_num, f"{label['id']}, {len(depth)}, {self.pano_aug.parameters}, {corners}"
+             output['depth'] = depth
+
+         if 'ratio' in self.keys:
+             # Why use ratio? Because when floor_height = plan_y = 1, we only need to predict ceil_height (the ratio).
+             output['ratio'] = label['ratio']
+
+         if 'id' in self.keys:
+             output['id'] = label['id']
+
+         if 'corners' in self.keys:
+             # all corners for evaluating Full_IoU
+             assert len(label['corners']) <= 32, f"len(label['corners']): {len(label['corners'])}"
+             output['corners'] = np.zeros((32, 2), dtype=np.float32)
+             output['corners'][:len(label['corners'])] = label['corners']
+
+         if 'corner_heat_map' in self.keys:
+             output['corner_heat_map'] = get_heat_map(visible_corners[..., 0])
+
+         if 'object' in self.keys and 'objects' in label:
+             output['object_heat_map'] = np.zeros((3, patch_num), dtype=np.float32)
+             output['object_size'] = np.zeros((3, patch_num), dtype=np.float32)  # width, height, bottom_height
+             for i, type in enumerate(label['objects']):
+                 if len(label['objects'][type]) == 0:
+                     continue
+
+                 u_s = []
+                 for obj in label['objects'][type]:
+                     center_u = obj['center_u']
+                     u_s.append(center_u)
+                     center_pixel_u = uv2pixel(np.array([center_u]), w=patch_num, axis=0)[0]
+                     output['object_size'][0, center_pixel_u] = obj['width_u']
+                     output['object_size'][1, center_pixel_u] = obj['height_v']
+                     output['object_size'][2, center_pixel_u] = obj['boundary_v']
+                 output['object_heat_map'][i] = get_heat_map(np.array(u_s))
+
+         return output
+
+
+ if __name__ == '__main__':
+     from dataset.communal.read import read_image, read_label
+     from visualization.boundary import draw_boundaries
+     from utils.boundary import depth2boundaries
+     from tqdm import trange
+
+     # np.random.seed(0)
+     dataset = BaseDataset(mode=None)
+     dataset.pano_aug = PanoDataAugmentation(aug={
+         'STRETCH': True,
+         'ROTATE': True,
+         'FLIP': True,
+     })
+     # pano_img = read_image("../src/demo.png")
+     # label = read_label("../src/demo.json")
+     pano_img_path = "../../src/dataset/mp3d/image/yqstnuAEVhm_6589ad7a5a0444b59adbf501c0f0fe53.png"
+     label_path = "../../src/dataset/mp3d/label/yqstnuAEVhm_6589ad7a5a0444b59adbf501c0f0fe53.json"
+     pano_img = read_image(pano_img_path)
+     label = read_label(label_path)
+
+     # batch test
+     for i in trange(1):
+         output = dataset.process_data(label, pano_img, 256)
+         boundary_list = depth2boundaries(output['ratio'], output['depth'], step=None)
+         draw_boundaries(output['image'].transpose(1, 2, 0), boundary_list=boundary_list, show=True)
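As a concrete check of the ratio target used in `process_data`: with the plan fixed at the floor (floor_height = plan_y = 1), the only remaining unknown of the room height is the ceiling side, so the label stores ratio = (room_height - camera_height) / camera_height (this is exactly how `read_label` in `dataset/communal/read.py` below computes it). For example, a 2.9 m room captured at a 1.6 m camera height gives ratio = (2.9 - 1.6) / 1.6 ≈ 0.81.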
dataset/communal/data_augmentation.py ADDED
@@ -0,0 +1,279 @@
+ """
+ @Date: 2021/07/27
+ @description:
+ """
+ import numpy as np
+ import cv2
+ import functools
+
+ from utils.conversion import pixel2lonlat, lonlat2pixel, uv2lonlat, lonlat2uv, pixel2uv
+
+
+ @functools.lru_cache()
+ def prepare_stretch(w, h):
+     lon = pixel2lonlat(np.array(range(w)), w=w, axis=0)
+     lat = pixel2lonlat(np.array(range(h)), h=h, axis=1)
+     sin_lon = np.sin(lon)
+     cos_lon = np.cos(lon)
+     tan_lat = np.tan(lat)
+     return sin_lon, cos_lon, tan_lat
+
+
+ def pano_stretch_image(pano_img, kx, ky, kz):
+     """
+     Note that this is the inverse mapping, which refers to Equation 3 in the HorizonNet paper (the coordinate system in
+     the paper is different from here, xz needs to be swapped)
+     :param pano_img: a panorama image, shape must be [h,w,c]
+     :param kx: stretching along left-right direction
+     :param ky: stretching along up-down direction
+     :param kz: stretching along front-back direction
+     :return:
+     """
+     w = pano_img.shape[1]
+     h = pano_img.shape[0]
+
+     sin_lon, cos_lon, tan_lat = prepare_stretch(w, h)
+
+     n_lon = np.arctan2(sin_lon * kz / kx, cos_lon)
+     n_lat = np.arctan(tan_lat[..., None] * np.sin(n_lon) / sin_lon * kx / ky)
+     n_pu = lonlat2pixel(n_lon, w=w, axis=0, need_round=False)
+     n_pv = lonlat2pixel(n_lat, h=h, axis=1, need_round=False)
+
+     pixel_map = np.empty((h, w, 2), dtype=np.float32)
+     pixel_map[..., 0] = n_pu
+     pixel_map[..., 1] = n_pv
+     map1 = pixel_map[..., 0]
+     map2 = pixel_map[..., 1]
+     # using wrap mode because the panorama is continuous at its left and right edges
+     new_img = cv2.remap(pano_img, map1, map2, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_WRAP)
+     return new_img
+
+
+ def pano_stretch_conner(corners, kx, ky, kz):
+     """
+     :param corners:
+     :param kx: stretching along left-right direction
+     :param ky: stretching along up-down direction
+     :param kz: stretching along front-back direction
+     :return:
+     """
+
+     lonlat = uv2lonlat(corners)
+     sin_lon = np.sin(lonlat[..., 0:1])
+     cos_lon = np.cos(lonlat[..., 0:1])
+     tan_lat = np.tan(lonlat[..., 1:2])
+
+     n_lon = np.arctan2(sin_lon * kx / kz, cos_lon)
+
+     a = np.bitwise_or(corners[..., 0] == 0.5, corners[..., 0] == 1)
+     b = np.bitwise_not(a)
+     w = np.zeros_like(n_lon)
+     w[b] = np.sin(n_lon[b]) / sin_lon[b]
+     w[a] = kx / kz
+
+     n_lat = np.arctan(tan_lat * w / kx * ky)
+
+     lst = [n_lon, n_lat]
+     lonlat = np.concatenate(lst, axis=-1)
+     new_corners = lonlat2uv(lonlat)
+     return new_corners
+
+
+ def pano_stretch(pano_img, corners, kx, ky, kz):
+     """
+     :param pano_img: a panorama image, shape must be [h,w,c]
+     :param corners:
+     :param kx: stretching along left-right direction
+     :param ky: stretching along up-down direction
+     :param kz: stretching along front-back direction
+     :return:
+     """
+     new_img = pano_stretch_image(pano_img, kx, ky, kz)
+     new_corners = pano_stretch_conner(corners, kx, ky, kz)
+     return new_img, new_corners
+
+
+ class PanoDataAugmentation:
+     def __init__(self, aug):
+         self.aug = aug
+         self.parameters = {}
+
+     def need_aug(self, name):
+         return name in self.aug and self.aug[name]
+
+     def execute_space_aug(self, corners, image):
+         if image is None:
+             return image
+
+         if self.aug is None:
+             return corners, image
+         w = image.shape[1]
+         h = image.shape[0]
+
+         if self.need_aug('STRETCH'):
+             kx = np.random.uniform(1, 2)
+             kx = 1 / kx if np.random.randint(2) == 0 else kx
+             # we found that the ky transform may cause IoU to drop (HorizonNet also only uses the x and z transforms)
+             # ky = np.random.uniform(1, 2)
+             # ky = 1 / ky if np.random.randint(2) == 0 else ky
+             ky = 1
+             kz = np.random.uniform(1, 2)
+             kz = 1 / kz if np.random.randint(2) == 0 else kz
+             image, corners = pano_stretch(image, corners, kx, ky, kz)
+             self.parameters['STRETCH'] = {'kx': kx, 'ky': ky, 'kz': kz}
+         else:
+             self.parameters['STRETCH'] = None
+
+         if self.need_aug('ROTATE'):
+             d_pu = np.random.randint(w)
+             image = np.roll(image, d_pu, axis=1)
+             corners[..., 0] = (corners[..., 0] + pixel2uv(np.array([d_pu]), w, h)) % pixel2uv(np.array([w]), w, h)
+             self.parameters['ROTATE'] = d_pu
+         else:
+             self.parameters['ROTATE'] = None
+
+         if self.need_aug('FLIP') and np.random.randint(2) == 0:
+             image = np.flip(image, axis=1).copy()
+             corners[..., 0] = pixel2uv(np.array([w]), w, h) - corners[..., 0]
+             corners = corners[::-1]
+             self.parameters['FLIP'] = True
+         else:
+             self.parameters['FLIP'] = None
+
+         return corners, image
+
+     def execute_visual_aug(self, image):
+         if self.need_aug('GAMMA'):
+             p = np.random.uniform(1, 2)
+             if np.random.randint(2) == 0:
+                 p = 1 / p
+             image = image ** p
+             self.parameters['GAMMA'] = p
+         else:
+             self.parameters['GAMMA'] = None
+
+         # The following visual augmentation methods are only implemented but not tested
+         if self.need_aug('HUE') or self.need_aug('SATURATION'):
+             image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
+
+             if self.need_aug('HUE') and np.random.randint(2) == 0:
+                 p = np.random.uniform(-0.1, 0.1)
+                 image[..., 0] = np.mod(image[..., 0] + p * 180, 180)
+                 self.parameters['HUE'] = p
+             else:
+                 self.parameters['HUE'] = None
+
+             if self.need_aug('SATURATION') and np.random.randint(2) == 0:
+                 p = np.random.uniform(0.5, 1.5)
+                 image[..., 1] = np.clip(image[..., 1] * p, 0, 1)
+                 self.parameters['SATURATION'] = p
+             else:
+                 self.parameters['SATURATION'] = None
+
+             image = cv2.cvtColor(image, cv2.COLOR_HSV2RGB)
+
+         if self.need_aug('CONTRAST') and np.random.randint(2) == 0:
+             p = np.random.uniform(0.9, 1.1)
+             mean = image.mean(axis=0).mean(axis=0)
+             image = (image - mean) * p + mean
+             image = np.clip(image, 0, 1)
+             self.parameters['CONTRAST'] = p
+         else:
+             self.parameters['CONTRAST'] = None
+
+         return image
+
+     def execute_aug(self, corners, image):
+         corners, image = self.execute_space_aug(corners, image)
+         if image is not None:
+             image = self.execute_visual_aug(image)
+         return corners, image
+
+
+ if __name__ == '__main__1':
+     from tqdm import trange
+     from visualization.floorplan import draw_floorplan
+     from dataset.communal.read import read_image, read_label
+     from utils.time_watch import TimeWatch
+     from utils.conversion import uv2xyz
+     from utils.boundary import corners2boundary
+
+     np.random.seed(123)
+     pano_img_path = "../../src/dataset/mp3d/image/TbHJrupSAjP_f320ae084f3a447da3e8ab11dd5f9320.png"
+     label_path = "../../src/dataset/mp3d/label/TbHJrupSAjP_f320ae084f3a447da3e8ab11dd5f9320.json"
+     pano_img = read_image(pano_img_path)
+     label = read_label(label_path)
+
+     corners = label['corners']
+     ratio = label['ratio']
+
+     pano_aug = PanoDataAugmentation(aug={
+         'STRETCH': True,
+         'ROTATE': True,
+         'FLIP': True,
+         'GAMMA': True,
+         # 'HUE': True,
+         # 'SATURATION': True,
+         # 'CONTRAST': True
+     })
+
+     # draw_floorplan(corners, show=True, marker_color=0.5, center_color=0.8, plan_y=1.6, show_radius=8)
+     # draw_boundaries(pano_img, corners_list=[corners], show=True, length=1024, ratio=ratio)
+
+     w = TimeWatch("test")
+     for i in trange(50000):
+         new_corners, new_pano_img = pano_aug.execute_aug(corners.copy(), pano_img.copy())
+         # draw_floorplan(uv2xyz(new_corners, plan_y=1.6)[..., ::2], show=True, marker_color=0.5, center_color=0.8,
+         #                show_radius=8)
+         # draw_boundaries(new_pano_img, corners_list=[new_corners], show=True, length=1024, ratio=ratio)
+
+
+ if __name__ == '__main__':
+     from utils.boundary import corners2boundary
+     from visualization.floorplan import draw_floorplan
+     from utils.boundary import visibility_corners
+
+     corners = np.array([[0.7664539, 0.7416811],
+                         [0.06641078, 0.6521386],
+                         [0.30997428, 0.57855356],
+                         [0.383300784, 0.58726823],
+                         [0.383300775, 0.8005296],
+                         [0.5062902, 0.74822706]])
+     corners = visibility_corners(corners)
+     print(corners)
+     # draw_floorplan(uv2xyz(corners, plan_y=1.6)[..., ::2], show=True, marker_color=0.5, center_color=0.8,
+     #                show_radius=8)
+     visible_floor_boundary = corners2boundary(corners, length=256, visible=True)
+     # visible_depth = xyz2depth(uv2xyz(visible_floor_boundary, 1), 1)
+     print(len(visible_floor_boundary))
+
+
+ if __name__ == '__main__0':
+     from visualization.floorplan import draw_floorplan
+
+     from dataset.communal.read import read_image, read_label
+     from utils.time_watch import TimeWatch
+     from utils.conversion import uv2xyz
+
+     # np.random.seed(1234)
+     pano_img_path = "../../src/dataset/mp3d/image/VVfe2KiqLaN_35b41dcbfcf84f96878f6ca28c70e5af.png"
+     label_path = "../../src/dataset/mp3d/label/VVfe2KiqLaN_35b41dcbfcf84f96878f6ca28c70e5af.json"
+     pano_img = read_image(pano_img_path)
+     label = read_label(label_path)
+
+     corners = label['corners']
+     ratio = label['ratio']
+
+     # draw_floorplan(corners, show=True, marker_color=0.5, center_color=0.8, plan_y=1.6, show_radius=8)
+
+     w = TimeWatch()
+     for i in range(5):
+         kx = np.random.uniform(1, 2)
+         kx = 1 / kx if np.random.randint(2) == 0 else kx
+         ky = np.random.uniform(1, 2)
+         ky = 1 / ky if np.random.randint(2) == 0 else ky
+         kz = np.random.uniform(1, 2)
+         kz = 1 / kz if np.random.randint(2) == 0 else kz
+         new_corners = pano_stretch_conner(corners.copy(), kx, ky, kz)
+         draw_floorplan(uv2xyz(new_corners, plan_y=1.6)[..., ::2], show=True, marker_color=0.5, center_color=0.8,
+                        show_radius=8)
dataset/communal/read.py ADDED
@@ -0,0 +1,214 @@
+ """
+ @Date: 2021/07/28
+ @description:
+ """
+ import os
+ import numpy as np
+ import cv2
+ import json
+ from PIL import Image
+ from utils.conversion import xyz2uv, pixel2uv
+ from utils.height import calc_ceil_ratio
+
+
+ def read_image(image_path, shape=None):
+     if shape is None:
+         shape = [512, 1024]
+     img = np.array(Image.open(image_path)).astype(np.float32) / 255
+     if img.shape[0] != shape[0] or img.shape[1] != shape[1]:
+         img = cv2.resize(img, dsize=tuple(shape[::-1]), interpolation=cv2.INTER_AREA)
+
+     return np.array(img)
+
+
+ def read_label(label_path, data_type='MP3D'):
+
+     if data_type == 'MP3D':
+         with open(label_path, 'r') as f:
+             label = json.load(f)
+         point_idx = [one['pointsIdx'][0] for one in label['layoutWalls']['walls']]
+         camera_height = label['cameraHeight']
+         room_height = label['layoutHeight']
+         camera_ceiling_height = room_height - camera_height
+         ratio = camera_ceiling_height / camera_height
+
+         xyz = [one['xyz'] for one in label['layoutPoints']['points']]
+         assert len(xyz) == len(point_idx), "len(xyz) != len(point_idx)"
+         xyz = [xyz[i] for i in point_idx]
+         xyz = np.asarray(xyz, dtype=np.float32)
+         xyz[:, 2] *= -1
+         xyz[:, 1] = camera_height
+         corners = xyz2uv(xyz)
+     elif data_type == 'Pano_S2D3D':
+         with open(label_path, 'r') as f:
+             lines = [line for line in f.readlines() if
+                      len([c for c in line.split(' ') if c[0].isnumeric()]) > 1]
+
+         corners_list = np.array([line.strip().split() for line in lines], np.float32)
+         uv_list = pixel2uv(corners_list)
+         ceil_uv = uv_list[::2]
+         floor_uv = uv_list[1::2]
+         ratio = calc_ceil_ratio([ceil_uv, floor_uv], mode='mean')
+         corners = floor_uv
+     else:
+         return None
+
+     output = {
+         'ratio': np.array([ratio], dtype=np.float32),
+         'corners': corners,
+         'id': os.path.basename(label_path).split('.')[0]
+     }
+     return output
+
+
+ def move_not_simple_image(data_dir, simple_panos):
+     import shutil
+     for house_index in os.listdir(data_dir):
+         house_path = os.path.join(data_dir, house_index)
+         if not os.path.isdir(house_path) or house_index == 'visualization':
+             continue
+
+         floor_plan_path = os.path.join(house_path, 'floor_plans')
+         if os.path.exists(floor_plan_path):
+             print(f'move:{floor_plan_path}')
+             dst_floor_plan_path = floor_plan_path.replace('zind', 'zind2')
+             os.makedirs(dst_floor_plan_path, exist_ok=True)
+             shutil.move(floor_plan_path, dst_floor_plan_path)
+
+         panos_path = os.path.join(house_path, 'panos')
+         for pano in os.listdir(panos_path):
+             pano_path = os.path.join(panos_path, pano)
+             pano_index = '_'.join(pano.split('.')[0].split('_')[-2:])
+             if f'{house_index}_{pano_index}' not in simple_panos and os.path.exists(pano_path):
+                 print(f'move:{pano_path}')
+                 dst_pano_path = pano_path.replace('zind', 'zind2')
+                 os.makedirs(os.path.dirname(dst_pano_path), exist_ok=True)
+                 shutil.move(pano_path, dst_pano_path)
+
+
+ def read_zind(partition_path, simplicity_path, data_dir, mode, is_simple=True,
+               layout_type='layout_raw', is_ceiling_flat=False, plan_y=1):
+     with open(simplicity_path, 'r') as f:
+         simple_tag = json.load(f)
+         simple_panos = {}
+         for k in simple_tag.keys():
+             if not simple_tag[k]:
+                 continue
+             split = k.split('_')
+             house_index = split[0]
+             pano_index = '_'.join(split[-2:])
+             simple_panos[f'{house_index}_{pano_index}'] = True
+
+     # move_not_simple_image(data_dir, simple_panos)
+
+     pano_list = []
+     with open(partition_path, 'r') as f1:
+         house_list = json.load(f1)[mode]
+
+     for house_index in house_list:
+         with open(os.path.join(data_dir, house_index, f"zind_data.json"), 'r') as f2:
+             data = json.load(f2)
+
+         panos = []
+         merger = data['merger']
+         for floor in merger.values():
+             for complete_room in floor.values():
+                 for partial_room in complete_room.values():
+                     for pano_index in partial_room:
+                         pano = partial_room[pano_index]
+                         pano['index'] = pano_index
+                         panos.append(pano)
+
+         for pano in panos:
+             if layout_type not in pano:
+                 continue
+             pano_index = pano['index']
+
+             if is_simple and f'{house_index}_{pano_index}' not in simple_panos.keys():
+                 continue
+
+             if is_ceiling_flat and not pano['is_ceiling_flat']:
+                 continue
+
+             layout = pano[layout_type]
+             # corners
+             corner_xz = np.array(layout['vertices'])
+             corner_xz[..., 0] = -corner_xz[..., 0]
+             corner_xyz = np.insert(corner_xz, 1, pano['camera_height'], axis=1)
+             corners = xyz2uv(corner_xyz).astype(np.float32)
+
+             # ratio
+             ratio = np.array([(pano['ceiling_height'] - pano['camera_height']) / pano['camera_height']], dtype=np.float32)
+
+             # Our future work: detecting windows, doors, openings
+             objects = {
+                 'windows': [],
+                 'doors': [],
+                 'openings': [],
+             }
+             for label_index, wdo_type in enumerate(["windows", "doors", "openings"]):
+                 if wdo_type not in layout:
+                     continue
+
+                 wdo_vertices = np.array(layout[wdo_type])
+                 if len(wdo_vertices) == 0:
+                     continue
+
+                 assert len(wdo_vertices) % 3 == 0
+
+                 for i in range(0, len(wdo_vertices), 3):
+                     # In the ZInD dataset, the camera height is 1, and the default camera height in our code is also 1,
+                     # so the xyz coordinates here can be used directly.
+                     # Since we take the opposite z-axis, we swap the order of left and right.
+
+                     left_bottom_xyz = np.array(
+                         [-wdo_vertices[i + 1][0], -wdo_vertices[i + 2][0], wdo_vertices[i + 1][1]])
+                     right_bottom_xyz = np.array(
+                         [-wdo_vertices[i][0], -wdo_vertices[i + 2][0], wdo_vertices[i][1]])
+                     center_bottom_xyz = (left_bottom_xyz + right_bottom_xyz) / 2
+
+                     center_top_xyz = center_bottom_xyz.copy()
+                     center_top_xyz[1] = -wdo_vertices[i + 2][1]
+
+                     center_boundary_xyz = center_bottom_xyz.copy()
+                     center_boundary_xyz[1] = plan_y
+
+                     uv = xyz2uv(np.array([left_bottom_xyz, right_bottom_xyz,
+                                           center_bottom_xyz, center_top_xyz,
+                                           center_boundary_xyz]))
+
+                     left_bottom_uv = uv[0]
+                     right_bottom_uv = uv[1]
+                     width_u = abs(right_bottom_uv[0] - left_bottom_uv[0])
+                     width_u = 1 - width_u if width_u > 0.5 else width_u
+                     assert width_u > 0, width_u
+
+                     center_bottom_uv = uv[2]
+                     center_top_uv = uv[3]
+                     height_v = center_bottom_uv[1] - center_top_uv[1]
+
+                     if height_v < 0:
+                         continue
+
+                     center_boundary_uv = uv[4]
+                     boundary_v = center_boundary_uv[1] - center_bottom_uv[1] if wdo_type == 'windows' else 0
+                     boundary_v = 0 if boundary_v < 0 else boundary_v
+
+                     center_u = center_bottom_uv[0]
+
+                     objects[wdo_type].append({
+                         'width_u': width_u,
+                         'height_v': height_v,
+                         'boundary_v': boundary_v,
+                         'center_u': center_u
+                     })
+
+             pano_list.append({
+                 'img_path': os.path.join(data_dir, house_index, pano['image_path']),
+                 'corners': corners,
+                 'objects': objects,
+                 'ratio': ratio,
+                 'id': f'{house_index}_{pano_index}',
+                 'is_inside': pano['is_inside']
+             })
+     return pano_list
dataset/mp3d_dataset.py ADDED
@@ -0,0 +1,110 @@
+ """
+ @date: 2021/6/25
+ @description:
+ """
+ import os
+ import json
+
+ from dataset.communal.read import read_image, read_label
+ from dataset.communal.base_dataset import BaseDataset
+ from utils.logger import get_logger
+
+
+ class MP3DDataset(BaseDataset):
+     def __init__(self, root_dir, mode, shape=None, max_wall_num=0, aug=None, camera_height=1.6, logger=None,
+                  split_list=None, patch_num=256, keys=None, for_test_index=None):
+         super().__init__(mode, shape, max_wall_num, aug, camera_height, patch_num, keys)
+
+         if logger is None:
+             logger = get_logger()
+         self.root_dir = root_dir
+
+         split_dir = os.path.join(root_dir, 'split')
+         label_dir = os.path.join(root_dir, 'label')
+         img_dir = os.path.join(root_dir, 'image')
+
+         if split_list is None:
+             with open(os.path.join(split_dir, f"{mode}.txt"), 'r') as f:
+                 split_list = [x.rstrip().split() for x in f]
+
+         split_list.sort()
+         if for_test_index is not None:
+             split_list = split_list[:for_test_index]
+
+         self.data = []
+         invalid_num = 0
+         for name in split_list:
+             name = "_".join(name)
+             img_path = os.path.join(img_dir, f"{name}.png")
+             label_path = os.path.join(label_dir, f"{name}.json")
+
+             if not os.path.exists(img_path):
+                 logger.warning(f"{img_path} not exists")
+                 invalid_num += 1
+                 continue
+             if not os.path.exists(label_path):
+                 logger.warning(f"{label_path} not exists")
+                 invalid_num += 1
+                 continue
+
+             with open(label_path, 'r') as f:
+                 label = json.load(f)
+
+             if self.max_wall_num >= 10:
+                 if label['layoutWalls']['num'] < self.max_wall_num:
+                     invalid_num += 1
+                     continue
+             elif self.max_wall_num != 0 and label['layoutWalls']['num'] != self.max_wall_num:
+                 invalid_num += 1
+                 continue
+
+             # print(label['layoutWalls']['num'])
+             self.data.append([img_path, label_path])
+
+         logger.info(
+             f"Build dataset mode: {self.mode} max_wall_num: {self.max_wall_num} valid: {len(self.data)} invalid: {invalid_num}")
+
+     def __getitem__(self, idx):
+         rgb_path, label_path = self.data[idx]
+         label = read_label(label_path, data_type='MP3D')
+         image = read_image(rgb_path, self.shape)
+         output = self.process_data(label, image, self.patch_num)
+         return output
+
+
+ if __name__ == "__main__":
+     import numpy as np
+     from PIL import Image
+
+     from tqdm import tqdm
+     from visualization.boundary import draw_boundaries
+     from visualization.floorplan import draw_floorplan
+     from utils.boundary import depth2boundaries
+     from utils.conversion import uv2xyz
+
+     modes = ['test', 'val']
+     for i in range(1):
+         for mode in modes:
+             print(mode)
+             mp3d_dataset = MP3DDataset(root_dir='../src/dataset/mp3d', mode=mode, aug={
+                 'STRETCH': True,
+                 'ROTATE': True,
+                 'FLIP': True,
+                 'GAMMA': True
+             })
+             save_dir = f'../src/dataset/mp3d/visualization/{mode}'
+             if not os.path.isdir(save_dir):
+                 os.makedirs(save_dir)
+
+             bar = tqdm(mp3d_dataset, ncols=100)
+             for data in bar:
+                 bar.set_description(f"Processing {data['id']}")
+                 boundary_list = depth2boundaries(data['ratio'], data['depth'], step=None)
+                 pano_img = draw_boundaries(data['image'].transpose(1, 2, 0), boundary_list=boundary_list, show=True)
+                 Image.fromarray((pano_img * 255).astype(np.uint8)).save(
+                     os.path.join(save_dir, f"{data['id']}_boundary.png"))
+
+                 floorplan = draw_floorplan(uv2xyz(boundary_list[0])[..., ::2], show=True,
+                                            marker_color=None, center_color=0.8, show_radius=None)
+                 Image.fromarray((floorplan.squeeze() * 255).astype(np.uint8)).save(
+                     os.path.join(save_dir, f"{data['id']}_floorplan.png"))
dataset/pano_s2d3d_dataset.py ADDED
@@ -0,0 +1,107 @@
+ """
+ @date: 2021/6/16
+ @description:
+ """
+ import math
+ import os
+ import numpy as np
+
+ from dataset.communal.read import read_image, read_label
+ from dataset.communal.base_dataset import BaseDataset
+ from utils.logger import get_logger
+
+
+ class PanoS2D3DDataset(BaseDataset):
+     def __init__(self, root_dir, mode, shape=None, max_wall_num=0, aug=None, camera_height=1.6, logger=None,
+                  split_list=None, patch_num=256, keys=None, for_test_index=None, subset=None):
+         super().__init__(mode, shape, max_wall_num, aug, camera_height, patch_num, keys)
+
+         if logger is None:
+             logger = get_logger()
+         self.root_dir = root_dir
+
+         if mode is None:
+             return
+         label_dir = os.path.join(root_dir, 'valid' if mode == 'val' else mode, 'label_cor')
+         img_dir = os.path.join(root_dir, 'valid' if mode == 'val' else mode, 'img')
+
+         if split_list is None:
+             split_list = [name.split('.')[0] for name in os.listdir(label_dir) if
+                           not name.startswith('.') and name.endswith('txt')]
+
+         split_list.sort()
+
+         assert subset == 'pano' or subset == 's2d3d' or subset is None, 'error subset'
+         if subset == 'pano':
+             split_list = [name for name in split_list if 'pano_' in name]
+             logger.info(f"Use PanoContext Dataset")
+         elif subset == 's2d3d':
+             split_list = [name for name in split_list if 'camera_' in name]
+             logger.info(f"Use Stanford2D3D Dataset")
+
+         if for_test_index is not None:
+             split_list = split_list[:for_test_index]
+
+         self.data = []
+         invalid_num = 0
+         for name in split_list:
+             img_path = os.path.join(img_dir, f"{name}.png")
+             label_path = os.path.join(label_dir, f"{name}.txt")
+
+             if not os.path.exists(img_path):
+                 logger.warning(f"{img_path} not exists")
+                 invalid_num += 1
+                 continue
+             if not os.path.exists(label_path):
+                 logger.warning(f"{label_path} not exists")
+                 invalid_num += 1
+                 continue
+
+             with open(label_path, 'r') as f:
+                 lines = [line for line in f.readlines() if
+                          len([c for c in line.split(' ') if c[0].isnumeric()]) > 1]
+                 if len(lines) % 2 != 0:
+                     invalid_num += 1
+                     continue
+             self.data.append([img_path, label_path])
+
+         logger.info(
+             f"Build dataset mode: {self.mode} valid: {len(self.data)} invalid: {invalid_num}")
+
+     def __getitem__(self, idx):
+         rgb_path, label_path = self.data[idx]
+         label = read_label(label_path, data_type='Pano_S2D3D')
+         image = read_image(rgb_path, self.shape)
+         output = self.process_data(label, image, self.patch_num)
+         return output
+
+
+ if __name__ == '__main__':
+     # These imports are needed by the (currently skipped) visualization code below
+     from PIL import Image
+     from tqdm import tqdm
+     from visualization.boundary import draw_boundaries
+     from visualization.floorplan import draw_floorplan
+     from utils.boundary import depth2boundaries
+     from utils.conversion import uv2xyz
+
+     modes = ['test', 'val', 'train']
+     for i in range(1):
+         for mode in modes:
+             print(mode)
+             mp3d_dataset = PanoS2D3DDataset(root_dir='../src/dataset/pano_s2d3d', mode=mode, aug={
+                 # 'STRETCH': True,
+                 # 'ROTATE': True,
+                 # 'FLIP': True,
+                 # 'GAMMA': True
+             })
+             continue
+             save_dir = f'../src/dataset/pano_s2d3d/visualization/{mode}'
+             if not os.path.isdir(save_dir):
+                 os.makedirs(save_dir)
+
+             bar = tqdm(mp3d_dataset, ncols=100)
+             for data in bar:
+                 bar.set_description(f"Processing {data['id']}")
+                 boundary_list = depth2boundaries(data['ratio'], data['depth'], step=None)
+                 pano_img = draw_boundaries(data['image'].transpose(1, 2, 0), boundary_list=boundary_list, show=False)
+                 Image.fromarray((pano_img * 255).astype(np.uint8)).save(
+                     os.path.join(save_dir, f"{data['id']}_boundary.png"))
+
+                 floorplan = draw_floorplan(uv2xyz(boundary_list[0])[..., ::2], show=False,
+                                            marker_color=None, center_color=0.8, show_radius=None)
+                 Image.fromarray((floorplan.squeeze() * 255).astype(np.uint8)).save(
+                     os.path.join(save_dir, f"{data['id']}_floorplan.png"))
dataset/pano_s2d3d_mix_dataset.py ADDED
@@ -0,0 +1,91 @@
1
+ """
2
+ @date: 2021/6/16
3
+ @description:
4
+ """
5
+
6
+ import os
7
+
8
+ from dataset.pano_s2d3d_dataset import PanoS2D3DDataset
9
+ from utils.logger import get_logger
10
+
11
+
12
+ class PanoS2D3DMixDataset(PanoS2D3DDataset):
13
+ def __init__(self, root_dir, mode, shape=None, max_wall_num=0, aug=None, camera_height=1.6, logger=None,
14
+ split_list=None, patch_num=256, keys=None, for_test_index=None, subset=None):
15
+ assert subset == 's2d3d' or subset == 'pano', 'error subset'
16
+ super().__init__(root_dir, None, shape, max_wall_num, aug, camera_height, logger,
17
+ split_list, patch_num, keys, None, subset)
18
+ if logger is None:
19
+ logger = get_logger()
20
+ self.mode = mode
21
+ if mode == 'train':
22
+ if subset == 'pano':
23
+ s2d3d_train_data = PanoS2D3DDataset(root_dir, 'train', shape, max_wall_num, aug, camera_height, logger,
24
+ split_list, patch_num, keys, None, 's2d3d').data
25
+ s2d3d_val_data = PanoS2D3DDataset(root_dir, 'val', shape, max_wall_num, aug, camera_height, logger,
26
+ split_list, patch_num, keys, None, 's2d3d').data
27
+ s2d3d_test_data = PanoS2D3DDataset(root_dir, 'test', shape, max_wall_num, aug, camera_height, logger,
28
+ split_list, patch_num, keys, None, 's2d3d').data
29
+ s2d3d_all_data = s2d3d_train_data + s2d3d_val_data + s2d3d_test_data
30
+
31
+ pano_train_data = PanoS2D3DDataset(root_dir, 'train', shape, max_wall_num, aug, camera_height, logger,
32
+ split_list, patch_num, keys, None, 'pano').data
33
+ self.data = s2d3d_all_data + pano_train_data
34
+ elif subset == 's2d3d':
35
+ pano_train_data = PanoS2D3DDataset(root_dir, 'train', shape, max_wall_num, aug, camera_height, logger,
36
+ split_list, patch_num, keys, None, 'pano').data
37
+ pano_val_data = PanoS2D3DDataset(root_dir, 'val', shape, max_wall_num, aug, camera_height, logger,
38
+ split_list, patch_num, keys, None, 'pano').data
39
+ pano_test_data = PanoS2D3DDataset(root_dir, 'test', shape, max_wall_num, aug, camera_height, logger,
40
+ split_list, patch_num, keys, None, 'pano').data
41
+ pano_all_data = pano_train_data + pano_val_data + pano_test_data
42
+
43
+ s2d3d_train_data = PanoS2D3DDataset(root_dir, 'train', shape, max_wall_num, aug, camera_height, logger,
44
+ split_list, patch_num, keys, None, 's2d3d').data
45
+ self.data = pano_all_data + s2d3d_train_data
46
+ else:
47
+ self.data = PanoS2D3DDataset(root_dir, mode, shape, max_wall_num, aug, camera_height, logger,
48
+ split_list, patch_num, keys, None, subset).data
49
+
50
+ if for_test_index is not None:
51
+ self.data = self.data[:for_test_index]
52
+ logger.info(f"Build dataset mode: {self.mode} valid: {len(self.data)}")
53
+
54
+
55
+ if __name__ == '__main__':
56
+ import numpy as np
57
+ from PIL import Image
58
+
59
+ from tqdm import tqdm
60
+ from visualization.boundary import draw_boundaries
61
+ from visualization.floorplan import draw_floorplan
62
+ from utils.boundary import depth2boundaries
63
+ from utils.conversion import uv2xyz
64
+
65
+ modes = ['test', 'val', 'train']
66
+ for i in range(1):
67
+ for mode in modes:
68
+ print(mode)
69
+ mp3d_dataset = PanoS2D3DMixDataset(root_dir='../src/dataset/pano_s2d3d', mode=mode, aug={
70
+ # 'STRETCH': True,
71
+ # 'ROTATE': True,
72
+ # 'FLIP': True,
73
+ # 'GAMMA': True
74
+ }, subset='pano')
75
+ continue
76
+ save_dir = f'../src/dataset/pano_s2d3d/visualization1/{mode}'
77
+ if not os.path.isdir(save_dir):
78
+ os.makedirs(save_dir)
79
+
80
+ bar = tqdm(mp3d_dataset, ncols=100)
81
+ for data in bar:
82
+ bar.set_description(f"Processing {data['id']}")
83
+ boundary_list = depth2boundaries(data['ratio'], data['depth'], step=None)
84
+ pano_img = draw_boundaries(data['image'].transpose(1, 2, 0), boundary_list=boundary_list, show=False)
85
+ Image.fromarray((pano_img * 255).astype(np.uint8)).save(
86
+ os.path.join(save_dir, f"{data['id']}_boundary.png"))
87
+
88
+ floorplan = draw_floorplan(uv2xyz(boundary_list[0])[..., ::2], show=False,
89
+ marker_color=None, center_color=0.8, show_radius=None)
90
+ Image.fromarray((floorplan.squeeze() * 255).astype(np.uint8)).save(
91
+ os.path.join(save_dir, f"{data['id']}_floorplan.png"))
dataset/zind_dataset.py ADDED
@@ -0,0 +1,138 @@
1
+ """
2
+ @Date: 2021/09/22
3
+ @description:
4
+ """
5
+ import os
6
+ import json
7
+ import math
8
+ import numpy as np
9
+
10
+ from dataset.communal.read import read_image, read_label, read_zind
11
+ from dataset.communal.base_dataset import BaseDataset
12
+ from utils.logger import get_logger
13
+ from preprocessing.filter import filter_center, filter_boundary, filter_self_intersection
14
+ from utils.boundary import calc_rotation
15
+
16
+
17
+ class ZindDataset(BaseDataset):
18
+ def __init__(self, root_dir, mode, shape=None, max_wall_num=0, aug=None, camera_height=1.6, logger=None,
19
+ split_list=None, patch_num=256, keys=None, for_test_index=None,
20
+ is_simple=True, is_ceiling_flat=False, vp_align=False):
21
+ # if keys is None:
22
+ # keys = ['image', 'depth', 'ratio', 'id', 'corners', 'corner_heat_map', 'object']
23
+ super().__init__(mode, shape, max_wall_num, aug, camera_height, patch_num, keys)
24
+ if logger is None:
25
+ logger = get_logger()
26
+ self.root_dir = root_dir
27
+ self.vp_align = vp_align
28
+
29
+ data_dir = os.path.join(root_dir)
30
+ img_dir = os.path.join(root_dir, 'image')
31
+
32
+ pano_list = read_zind(partition_path=os.path.join(data_dir, f"zind_partition.json"),
33
+ simplicity_path=os.path.join(data_dir, f"room_shape_simplicity_labels.json"),
34
+ data_dir=data_dir, mode=mode, is_simple=is_simple, is_ceiling_flat=is_ceiling_flat)
35
+
36
+ if for_test_index is not None:
37
+ pano_list = pano_list[:for_test_index]
38
+ if split_list:
39
+ pano_list = [pano for pano in pano_list if pano['id'] in split_list]
40
+ self.data = []
41
+ invalid_num = 0
42
+ for pano in pano_list:
43
+ if not os.path.exists(pano['img_path']):
44
+ logger.warning(f"{pano['img_path']} not exists")
45
+ invalid_num += 1
46
+ continue
47
+
48
+ if not filter_center(pano['corners']):
49
+ # logger.warning(f"{pano['id']} camera center not in layout")
50
+ # invalid_num += 1
51
+ continue
52
+
53
+ if self.max_wall_num >= 10:
54
+ if len(pano['corners']) < self.max_wall_num:
55
+ invalid_num += 1
56
+ continue
57
+ elif self.max_wall_num != 0 and len(pano['corners']) != self.max_wall_num:
58
+ invalid_num += 1
59
+ continue
60
+
61
+ if not filter_boundary(pano['corners']):
62
+ logger.warning(f"{pano['id']} boundary cross")
63
+ invalid_num += 1
64
+ continue
65
+
66
+ if not filter_self_intersection(pano['corners']):
67
+ logger.warning(f"{pano['id']} self_intersection")
68
+ invalid_num += 1
69
+ continue
70
+
71
+ self.data.append(pano)
72
+
73
+ logger.info(
74
+ f"Build dataset mode: {self.mode} max_wall_num: {self.max_wall_num} valid: {len(self.data)} invalid: {invalid_num}")
75
+
76
+ def __getitem__(self, idx):
77
+ pano = self.data[idx]
78
+ rgb_path = pano['img_path']
79
+ label = pano
80
+ image = read_image(rgb_path, self.shape)
81
+
82
+ if self.vp_align:
83
+ # Equivalent to vanishing point alignment step
84
+ rotation = calc_rotation(corners=label['corners'])
85
+ shift = math.modf(rotation / (2 * np.pi) + 1)[0]
86
+ image = np.roll(image, round(shift * self.shape[1]), axis=1)
87
+ label['corners'][:, 0] = np.modf(label['corners'][:, 0] + shift)[0]
88
+
89
+ output = self.process_data(label, image, self.patch_num)
90
+ return output
91
+
92
+
93
+ if __name__ == "__main__":
94
+ import numpy as np
95
+ from PIL import Image
96
+
97
+ from tqdm import tqdm
98
+ from visualization.boundary import draw_boundaries, draw_object
99
+ from visualization.floorplan import draw_floorplan
100
+ from utils.boundary import depth2boundaries, calc_rotation
101
+ from utils.conversion import uv2xyz
102
+ from models.other.init_env import init_env
103
+
104
+ init_env(123)
105
+
106
+ modes = ['val']
107
+ for i in range(1):
108
+ for mode in modes:
109
+ print(mode)
110
+ mp3d_dataset = ZindDataset(root_dir='../src/dataset/zind', mode=mode, aug={
111
+ 'STRETCH': False,
112
+ 'ROTATE': False,
113
+ 'FLIP': False,
114
+ 'GAMMA': False
115
+ })
116
+ # continue
117
+ # save_dir = f'../src/dataset/zind/visualization/{mode}'
118
+ # if not os.path.isdir(save_dir):
119
+ # os.makedirs(save_dir)
120
+
121
+ bar = tqdm(mp3d_dataset, ncols=100)
122
+ for data in bar:
123
+ # if data['id'] != '1079_pano_18':
124
+ # continue
125
+ bar.set_description(f"Processing {data['id']}")
126
+ boundary_list = depth2boundaries(data['ratio'], data['depth'], step=None)
127
+
128
+ pano_img = draw_boundaries(data['image'].transpose(1, 2, 0), boundary_list=boundary_list, show=True)
129
+ # Image.fromarray((pano_img * 255).astype(np.uint8)).save(
130
+ # os.path.join(save_dir, f"{data['id']}_boundary.png"))
131
+ # draw_object(pano_img, heat_maps=data['object_heat_map'], depth=data['depth'],
132
+ # size=data['object_size'], show=True)
133
+ # pass
134
+ #
135
+ floorplan = draw_floorplan(uv2xyz(boundary_list[0])[..., ::2], show=True,
136
+ marker_color=None, center_color=0.2)
137
+ # Image.fromarray((floorplan.squeeze() * 255).astype(np.uint8)).save(
138
+ # os.path.join(save_dir, f"{data['id']}_floorplan.png"))
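The vp_align branch in ZindDataset.__getitem__ above converts the estimated rotation into a fractional horizontal shift, rolls the equirectangular image by that fraction of its width, and wraps the corner u coordinates by the same amount. A minimal sketch of that shift on dummy data (calc_rotation itself is not reproduced here; the rotation value is made up):

```python
import math
import numpy as np

rotation = 0.3                                     # radians, e.g. from calc_rotation
shift = math.modf(rotation / (2 * np.pi) + 1)[0]   # fraction of a full turn in [0, 1)

image = np.zeros((512, 1024, 3), dtype=np.float32)             # H x W x C panorama
image = np.roll(image, round(shift * image.shape[1]), axis=1)  # horizontal roll by shift * W

corners_u = np.array([0.10, 0.55, 0.97])           # normalized u coordinates of corners
corners_u = np.modf(corners_u + shift)[0]          # shift and wrap back into [0, 1)
```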
evaluation/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ """
2
+ @date: 2021/6/29
3
+ @description:
4
+ """
evaluation/accuracy.py ADDED
@@ -0,0 +1,249 @@
1
+ """
2
+ @date: 2021/8/4
3
+ @description:
4
+ """
5
+ import numpy as np
6
+ import cv2
7
+ import scipy
8
+
9
+ from evaluation.f1_score import f1_score_2d
10
+ from loss import GradLoss
11
+ from utils.boundary import corners2boundaries, layout2depth
12
+ from utils.conversion import depth2xyz, uv2xyz, get_u, depth2uv, xyz2uv, uv2pixel
13
+ from utils.height import calc_ceil_ratio
14
+ from evaluation.iou import calc_IoU, calc_Iou_height
15
+ from visualization.boundary import draw_boundaries
16
+ from visualization.floorplan import draw_iou_floorplan
17
+ from visualization.grad import show_grad
18
+
19
+
20
+ def calc_accuracy(dt, gt, visualization=False, h=512):
21
+ visb_iou_2ds = []
22
+ visb_iou_3ds = []
23
+ full_iou_2ds = []
24
+ full_iou_3ds = []
25
+ iou_heights = []
26
+
27
+ visb_iou_floodplans = []
28
+ full_iou_floodplans = []
29
+ pano_bds = []
30
+
31
+ if 'depth' not in dt.keys():
32
+ dt['depth'] = gt['depth']
33
+
34
+ for i in range(len(gt['depth'])):
35
+ # print(i)
36
+ dt_xyz = dt['processed_xyz'][i] if 'processed_xyz' in dt else depth2xyz(np.abs(dt['depth'][i]))
37
+ visb_gt_xyz = depth2xyz(np.abs(gt['depth'][i]))
38
+ corners = gt['corners'][i]
39
+ full_gt_corners = corners[corners[..., 0] + corners[..., 1] != 0] # Take effective corners
40
+ full_gt_xyz = uv2xyz(full_gt_corners)
41
+
42
+ dt_xz = dt_xyz[..., ::2]
43
+ visb_gt_xz = visb_gt_xyz[..., ::2]
44
+ full_gt_xz = full_gt_xyz[..., ::2]
45
+
46
+ gt_ratio = gt['ratio'][i][0]
47
+
48
+ if 'ratio' not in dt.keys():
49
+ if 'boundary' in dt.keys():
50
+ w = len(dt['boundary'][i])
51
+ boundary = np.clip(dt['boundary'][i], 0.0001, 0.4999)
52
+ depth = np.clip(dt['depth'][i], 0.001, 9999)
53
+ dt_ceil_boundary = np.concatenate([get_u(w, is_np=True)[..., None], boundary], axis=-1)
54
+ dt_floor_boundary = depth2uv(depth)
55
+ dt_ratio = calc_ceil_ratio(boundaries=[dt_ceil_boundary, dt_floor_boundary])
56
+ else:
57
+ dt_ratio = gt_ratio
58
+ else:
59
+ dt_ratio = dt['ratio'][i][0]
60
+
61
+ visb_iou_2d, visb_iou_3d = calc_IoU(dt_xz, visb_gt_xz, dt_height=1 + dt_ratio, gt_height=1 + gt_ratio)
62
+ full_iou_2d, full_iou_3d = calc_IoU(dt_xz, full_gt_xz, dt_height=1 + dt_ratio, gt_height=1 + gt_ratio)
63
+ iou_height = calc_Iou_height(dt_height=1 + dt_ratio, gt_height=1 + gt_ratio)
64
+
65
+ visb_iou_2ds.append(visb_iou_2d)
66
+ visb_iou_3ds.append(visb_iou_3d)
67
+ full_iou_2ds.append(full_iou_2d)
68
+ full_iou_3ds.append(full_iou_3d)
69
+ iou_heights.append(iou_height)
70
+
71
+ if visualization:
72
+ pano_img = cv2.resize(gt['image'][i].transpose(1, 2, 0), (h*2, h))
73
+ # visb_iou_floodplans.append(draw_iou_floorplan(dt_xz, visb_gt_xz, iou_2d=visb_iou_2d, iou_3d=visb_iou_3d, side_l=h))
74
+ # full_iou_floodplans.append(draw_iou_floorplan(dt_xz, full_gt_xz, iou_2d=full_iou_2d, iou_3d=full_iou_3d, side_l=h))
75
+ visb_iou_floodplans.append(draw_iou_floorplan(dt_xz, visb_gt_xz, side_l=h))
76
+ full_iou_floodplans.append(draw_iou_floorplan(dt_xz, full_gt_xz, side_l=h))
77
+ gt_boundaries = corners2boundaries(gt_ratio, corners_xyz=full_gt_xyz, step=None, length=1024, visible=False)
78
+ dt_boundaries = corners2boundaries(dt_ratio, corners_xyz=dt_xyz, step=None, visible=False,
79
+ length=1024)#visb_gt_xyz.shape[0] if dt_xyz.shape[0] != visb_gt_xyz.shape[0] else None)
80
+
81
+ pano_bd = draw_boundaries(pano_img, boundary_list=gt_boundaries, boundary_color=[0, 0, 1])
82
+ pano_bd = draw_boundaries(pano_bd, boundary_list=dt_boundaries, boundary_color=[0, 1, 0])
83
+ pano_bds.append(pano_bd)
84
+
85
+ visb_iou_2d = np.array(visb_iou_2ds).mean()
86
+ visb_iou_3d = np.array(visb_iou_3ds).mean()
87
+ full_iou_2d = np.array(full_iou_2ds).mean()
88
+ full_iou_3d = np.array(full_iou_3ds).mean()
89
+ iou_height = np.array(iou_heights).mean()
90
+
91
+ if visualization:
92
+ visb_iou_floodplans = np.array(visb_iou_floodplans).transpose(0, 3, 1, 2) # NCHW
93
+ full_iou_floodplans = np.array(full_iou_floodplans).transpose(0, 3, 1, 2) # NCHW
94
+ pano_bds = np.array(pano_bds).transpose(0, 3, 1, 2)
95
+ return [visb_iou_2d, visb_iou_3d, visb_iou_floodplans],\
96
+ [full_iou_2d, full_iou_3d, full_iou_floodplans], iou_height, pano_bds, full_iou_2ds
97
+
98
+
99
+ def calc_ce(dt, gt):
100
+ w = 1024
101
+ h = 512
102
+ ce_s = []
103
+ for i in range(len(gt['corners'])):
104
+ floor_gt_corners = gt['corners'][i]
105
+ # Take effective corners
106
+ floor_gt_corners = floor_gt_corners[floor_gt_corners[..., 0] + floor_gt_corners[..., 1] != 0]
107
+ floor_gt_corners = np.roll(floor_gt_corners, -np.argmin(floor_gt_corners[..., 0]), 0)
108
+ gt_ratio = gt['ratio'][i][0]
109
+ ceil_gt_corners = corners2boundaries(gt_ratio, corners_uv=floor_gt_corners, step=None)[1]
110
+ gt_corners = np.concatenate((floor_gt_corners, ceil_gt_corners))
111
+ gt_corners = uv2pixel(gt_corners, w, h)
112
+
113
+ floor_dt_corners = xyz2uv(dt['processed_xyz'][i])
114
+ floor_dt_corners = np.roll(floor_dt_corners, -np.argmin(floor_dt_corners[..., 0]), 0)
115
+ dt_ratio = dt['ratio'][i][0]
116
+ ceil_dt_corners = corners2boundaries(dt_ratio, corners_uv=floor_dt_corners, step=None)[1]
117
+ dt_corners = np.concatenate((floor_dt_corners, ceil_dt_corners))
118
+ dt_corners = uv2pixel(dt_corners, w, h)
119
+
120
+ mse = np.sqrt(((gt_corners - dt_corners) ** 2).sum(1)).mean()
121
+ ce = 100 * mse / np.sqrt(w ** 2 + h ** 2)
122
+ ce_s.append(ce)
123
+
124
+ return np.array(ce_s).mean()
125
+
126
+
127
+ def calc_pe(dt, gt):
128
+ w = 1024
129
+ h = 512
130
+ pe_s = []
131
+ for i in range(len(gt['corners'])):
132
+ floor_gt_corners = gt['corners'][i]
133
+ # Take effective corners
134
+ floor_gt_corners = floor_gt_corners[floor_gt_corners[..., 0] + floor_gt_corners[..., 1] != 0]
135
+ floor_gt_corners = np.roll(floor_gt_corners, -np.argmin(floor_gt_corners[..., 0]), 0)
136
+ gt_ratio = gt['ratio'][i][0]
137
+ gt_floor_boundary, gt_ceil_boundary = corners2boundaries(gt_ratio, corners_uv=floor_gt_corners, length=w)
138
+ gt_floor_boundary = uv2pixel(gt_floor_boundary, w, h)
139
+ gt_ceil_boundary = uv2pixel(gt_ceil_boundary, w, h)
140
+
141
+ floor_dt_corners = xyz2uv(dt['processed_xyz'][i])
142
+ floor_dt_corners = np.roll(floor_dt_corners, -np.argmin(floor_dt_corners[..., 0]), 0)
143
+ dt_ratio = dt['ratio'][i][0]
144
+ dt_floor_boundary, dt_ceil_boundary = corners2boundaries(dt_ratio, corners_uv=floor_dt_corners, length=w)
145
+ dt_floor_boundary = uv2pixel(dt_floor_boundary, w, h)
146
+ dt_ceil_boundary = uv2pixel(dt_ceil_boundary, w, h)
147
+
148
+ gt_surface = np.zeros((h, w), dtype=np.int32)
149
+ gt_surface[gt_ceil_boundary[..., 1], np.arange(w)] = 1
150
+ gt_surface[gt_floor_boundary[..., 1], np.arange(w)] = 1
151
+ gt_surface = np.cumsum(gt_surface, axis=0)
152
+
153
+ dt_surface = np.zeros((h, w), dtype=np.int32)
154
+ dt_surface[dt_ceil_boundary[..., 1], np.arange(w)] = 1
155
+ dt_surface[dt_floor_boundary[..., 1], np.arange(w)] = 1
156
+ dt_surface = np.cumsum(dt_surface, axis=0)
157
+
158
+ pe = 100 * (dt_surface != gt_surface).sum() / (h * w)
159
+ pe_s.append(pe)
160
+ return np.array(pe_s).mean()
161
+
162
+
163
+ def calc_rmse_delta_1(dt, gt):
164
+ rmse_s = []
165
+ delta_1_s = []
166
+ for i in range(len(gt['depth'])):
167
+ gt_boundaries = corners2boundaries(gt['ratio'][i], corners_xyz=depth2xyz(gt['depth'][i]), step=None,
168
+ visible=False)
169
+ dt_xyz = dt['processed_xyz'][i] if 'processed_xyz' in dt else depth2xyz(np.abs(dt['depth'][i]))
170
+
171
+ dt_boundaries = corners2boundaries(dt['ratio'][i], corners_xyz=dt_xyz, step=None,
172
+ length=256 if 'processed_xyz' in dt else None,
173
+ visible=True if 'processed_xyz' in dt else False)
174
+ gt_layout_depth = layout2depth(gt_boundaries, show=False)
175
+ dt_layout_depth = layout2depth(dt_boundaries, show=False)
176
+
177
+ rmse = ((gt_layout_depth - dt_layout_depth) ** 2).mean() ** 0.5
178
+ threshold = np.maximum(gt_layout_depth / dt_layout_depth, dt_layout_depth / gt_layout_depth)
179
+ delta_1 = (threshold < 1.25).mean()
180
+ rmse_s.append(rmse)
181
+ delta_1_s.append(delta_1)
182
+ return np.array(rmse_s).mean(), np.array(delta_1_s).mean()
183
+
184
+
185
+ def calc_f1_score(dt, gt, threshold=10):
186
+ w = 1024
187
+ h = 512
188
+ f1_s = []
189
+ precision_s = []
190
+ recall_s = []
191
+ for i in range(len(gt['corners'])):
192
+ floor_gt_corners = gt['corners'][i]
193
+ # Take effective corners
194
+ floor_gt_corners = floor_gt_corners[floor_gt_corners[..., 0] + floor_gt_corners[..., 1] != 0]
195
+ floor_gt_corners = np.roll(floor_gt_corners, -np.argmin(floor_gt_corners[..., 0]), 0)
196
+ gt_ratio = gt['ratio'][i][0]
197
+ ceil_gt_corners = corners2boundaries(gt_ratio, corners_uv=floor_gt_corners, step=None)[1]
198
+ gt_corners = np.concatenate((floor_gt_corners, ceil_gt_corners))
199
+ gt_corners = uv2pixel(gt_corners, w, h)
200
+
201
+ floor_dt_corners = xyz2uv(dt['processed_xyz'][i])
202
+ floor_dt_corners = np.roll(floor_dt_corners, -np.argmin(floor_dt_corners[..., 0]), 0)
203
+ dt_ratio = dt['ratio'][i][0]
204
+ ceil_dt_corners = corners2boundaries(dt_ratio, corners_uv=floor_dt_corners, step=None)[1]
205
+ dt_corners = np.concatenate((floor_dt_corners, ceil_dt_corners))
206
+ dt_corners = uv2pixel(dt_corners, w, h)
207
+
208
+ Fs, Ps, Rs = f1_score_2d(gt_corners, dt_corners, [threshold])
209
+ f1_s.append(Fs[0])
210
+ precision_s.append(Ps[0])
211
+ recall_s.append(Rs[0])
212
+
213
+ return np.array(f1_s).mean(), np.array(precision_s).mean(), np.array(recall_s).mean()
214
+
215
+
216
+ def show_heat_map(dt, gt, vis_w=1024):
217
+ dt_heat_map = dt['corner_heat_map'].detach().cpu().numpy()
218
+ gt_heat_map = gt['corner_heat_map'].detach().cpu().numpy()
219
+ dt_heat_map_imgs = []
220
+ gt_heat_map_imgs = []
221
+ for i in range(len(gt['depth'])):
222
+ dt_heat_map_img = dt_heat_map[..., np.newaxis].repeat(3, axis=-1).repeat(20, axis=0)
223
+ gt_heat_map_img = gt_heat_map[..., np.newaxis].repeat(3, axis=-1).repeat(20, axis=0)
224
+ dt_heat_map_imgs.append(cv2.resize(dt_heat_map_img, (vis_w, dt_heat_map_img.shape[0])).transpose(2, 0, 1))
225
+ gt_heat_map_imgs.append(cv2.resize(gt_heat_map_img, (vis_w, dt_heat_map_img.shape[0])).transpose(2, 0, 1))
226
+ return dt_heat_map_imgs, gt_heat_map_imgs
227
+
228
+
229
+ def show_depth_normal_grad(dt, gt, device, vis_w=1024):
230
+ grad_conv = GradLoss().to(device).grad_conv
231
+ gt_grad_imgs = []
232
+ dt_grad_imgs = []
233
+
234
+ if 'depth' not in dt.keys():
235
+ dt['depth'] = gt['depth']
236
+
237
+ if vis_w == 1024:
238
+ h = 5
239
+ else:
240
+ h = int(vis_w / (12 * 10))
241
+
242
+ for i in range(len(gt['depth'])):
243
+ gt_grad_img = show_grad(gt['depth'][i], grad_conv, h)
244
+ dt_grad_img = show_grad(dt['depth'][i], grad_conv, h)
245
+ vis_h = dt_grad_img.shape[0] * (vis_w // dt_grad_img.shape[1])
246
+ gt_grad_imgs.append(cv2.resize(gt_grad_img, (vis_w, vis_h), interpolation=cv2.INTER_NEAREST).transpose(2, 0, 1))
247
+ dt_grad_imgs.append(cv2.resize(dt_grad_img, (vis_w, vis_h), interpolation=cv2.INTER_NEAREST).transpose(2, 0, 1))
248
+
249
+ return gt_grad_imgs, dt_grad_imgs
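In calc_ce and calc_pe above, the corner error (CE) is the mean pixel distance between matched ground-truth and predicted corners expressed as a percentage of the 1024x512 image diagonal, and the pixel error (PE) is the percentage of pixels whose ceiling/wall/floor region differs between the two layouts. A minimal numeric sketch of the CE normalisation with made-up pixel corners:

```python
import numpy as np

w, h = 1024, 512
gt_corners = np.array([[100, 400], [500, 410], [900, 395]], dtype=np.float64)  # (u, v) pixels
dt_corners = np.array([[104, 403], [498, 415], [905, 392]], dtype=np.float64)

mean_dist = np.sqrt(((gt_corners - dt_corners) ** 2).sum(1)).mean()  # mean corner distance
ce = 100 * mean_dist / np.sqrt(w ** 2 + h ** 2)                      # percent of image diagonal
print(f"CE = {ce:.3f}%")
```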
evaluation/analyse_layout_type.py ADDED
@@ -0,0 +1,83 @@
1
+ """
2
+ @Date: 2022/01/31
3
+ @description:
4
+ ZInd:
5
+ {'test': {'mw': 2789, 'aw': 381}, 'train': {'mw': 21228, 'aw': 3654}, 'val': {'mw': 2647, 'aw': 433}}
6
+
7
+ """
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ import json
11
+
12
+ from tqdm import tqdm
13
+ from evaluation.iou import calc_IoU_2D
14
+ from visualization.floorplan import draw_floorplan
15
+ from visualization.boundary import draw_boundaries
16
+ from utils.conversion import depth2xyz, uv2xyz
17
+
18
+
19
+ def analyse_layout_type(dataset, show=False):
20
+ bar = tqdm(dataset, total=len(dataset), ncols=100)
21
+ manhattan = 0
22
+ atlanta = 0
23
+ corner_type = {}
24
+ for data in bar:
25
+ bar.set_description(f"Processing {data['id']}")
26
+ corners = data['corners']
27
+ corners = corners[corners[..., 0] + corners[..., 1] != 0] # Take effective corners
28
+ corners_count = str(len(corners)) if len(corners) < 10 else "10"
29
+ if corners_count not in corner_type:
30
+ corner_type[corners_count] = 0
31
+ corner_type[corners_count] += 1
32
+
33
+ all_xz = uv2xyz(corners)[..., ::2]
34
+
35
+ c = len(all_xz)
36
+ flag = False
37
+ for i in range(c - 1):
38
+ l1 = all_xz[i + 1] - all_xz[i]
39
+ l2 = all_xz[(i + 2) % c] - all_xz[i + 1]
40
+ a = (np.linalg.norm(l1)*np.linalg.norm(l2))
41
+ if a == 0:
42
+ continue
43
+ dot = np.dot(l1, l2)/a
44
+ if 0.9 > abs(dot) > 0.1:
45
+ # cos-1(0.1)=84.26 > angle > cos-1(0.9)=25.84 or
46
+ # cos-1(-0.9)=154.16 > angle > cos-1(-0.1)=95.74
47
+ flag = True
48
+ break
49
+ if flag:
50
+ atlanta += 1
51
+ else:
52
+ manhattan += 1
53
+
54
+ if flag and show:
55
+ draw_floorplan(all_xz, show=True)
56
+ draw_boundaries(data['image'].transpose(1, 2, 0), [corners], ratio=data['ratio'], show=True)
57
+
58
+ corner_type = dict(sorted(corner_type.items(), key=lambda item: int(item[0])))
59
+ return {'manhattan': manhattan, "atlanta": atlanta, "corner_type": corner_type}
60
+
61
+
62
+ def execute_analyse_layout_type(root_dir, dataset, modes=None):
63
+ if modes is None:
64
+ modes = ["train", "val", "test"]
65
+
66
+ iou2d_d = {}
67
+ for mode in modes:
68
+ print("mode: {}".format(mode))
69
+ types = analyse_layout_type(dataset(root_dir, mode), show=False)
70
+ iou2d_d[mode] = types
71
+ print(json.dumps(types, indent=4))
72
+ return iou2d_d
73
+
74
+
75
+ if __name__ == '__main__':
76
+ from dataset.zind_dataset import ZindDataset
77
+ from dataset.mp3d_dataset import MP3DDataset
78
+
79
+ iou2d_d = execute_analyse_layout_type(root_dir='../src/dataset/mp3d',
80
+ dataset=MP3DDataset)
81
+ # iou2d_d = execute_analyse_layout_type(root_dir='../src/dataset/zind',
82
+ # dataset=ZindDataset)
83
+ print(json.dumps(iou2d_d, indent=4))
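The Manhattan/Atlanta decision above checks the cosine between consecutive wall direction vectors on the floor plan: if any pair has 0.1 < |cos| < 0.9, some corner angle is clearly neither close to 90° nor close to 0°/180°, and the room is counted as Atlanta. A small stand-alone sketch of the same test on made-up floor plans:

```python
import numpy as np

def is_atlanta(xz: np.ndarray) -> bool:
    """xz: (N, 2) floor-plan corners; same cosine test as analyse_layout_type."""
    c = len(xz)
    for i in range(c - 1):
        l1 = xz[i + 1] - xz[i]
        l2 = xz[(i + 2) % c] - xz[i + 1]
        norm = np.linalg.norm(l1) * np.linalg.norm(l2)
        if norm == 0:
            continue
        cos = np.dot(l1, l2) / norm
        if 0.9 > abs(cos) > 0.1:
            return True
    return False

rect = np.array([[0, 0], [2, 0], [2, 1], [0, 1]], dtype=float)    # Manhattan layout
skewed = np.array([[0, 0], [2, 0], [3, 1], [0, 1]], dtype=float)  # Atlanta layout
print(is_atlanta(rect), is_atlanta(skewed))  # False True
```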
evaluation/eval_visible_iou.py ADDED
@@ -0,0 +1,56 @@
1
+ """
2
+ @Date: 2021/08/02
3
+ @description:
4
+ Computes the 2D IoU between the visible boundary and the full boundary. For example, on the MP3D dataset the values are
5
+ {'train': 0.9775843958583535, 'test': 0.9828616219607289, 'val': 0.9883810438132491},
6
+ indicating that any approach relying only on the visible boundary is limited to below 98.29% 2D IoU on the test split.
7
+ """
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+
11
+ from tqdm import tqdm
12
+ from evaluation.iou import calc_IoU_2D
13
+ from visualization.floorplan import draw_iou_floorplan
14
+ from utils.conversion import depth2xyz, uv2xyz
15
+
16
+
17
+ def eval_dataset_visible_IoU(dataset, show=False):
18
+ bar = tqdm(dataset, total=len(dataset), ncols=100)
19
+ iou2ds = []
20
+ for data in bar:
21
+ bar.set_description(f"Processing {data['id']}")
22
+ corners = data['corners']
23
+ corners = corners[corners[..., 0] + corners[..., 1] != 0] # Take effective corners
24
+ all_xz = uv2xyz(corners)[..., ::2]
25
+ visible_xz = depth2xyz(data['depth'])[..., ::2]
26
+ iou2d = calc_IoU_2D(all_xz, visible_xz)
27
+ iou2ds.append(iou2d)
28
+ if show:
29
+ layout_floorplan = draw_iou_floorplan(all_xz, visible_xz, iou2d=iou2d)
30
+ plt.imshow(layout_floorplan)
31
+ plt.show()
32
+
33
+ mean_iou2d = np.array(iou2ds).mean()
34
+ return mean_iou2d
35
+
36
+
37
+ def execute_eval_dataset_visible_IoU(root_dir, dataset, modes=None):
38
+ if modes is None:
39
+ modes = ["train", "test", "valid"]
40
+
41
+ iou2d_d = {}
42
+ for mode in modes:
43
+ print("mode: {}".format(mode))
44
+ iou2d = eval_dataset_visible_IoU(dataset(root_dir, mode, patch_num=1024,
45
+ keys=['depth', 'visible_corners', 'corners', 'id']), show=False)
46
+ iou2d_d[mode] = iou2d
47
+ return iou2d_d
48
+
49
+
50
+ if __name__ == '__main__':
51
+ from dataset.mp3d_dataset import MP3DDataset
52
+
53
+ iou2d_d = execute_eval_dataset_visible_IoU(root_dir='../src/dataset/mp3d',
54
+ dataset=MP3DDataset,
55
+ modes=['train', 'test', 'val'])
56
+ print(iou2d_d)
evaluation/f1_score.py ADDED
@@ -0,0 +1,78 @@
1
+ """
2
+ @author: Zhigang Jiang
3
+ @time: 2022/01/28
4
+ @description:
5
+ Holistic 3D Vision Challenge on General Room Layout Estimation Track Evaluation Package
6
+ Reference: https://github.com/bertjiazheng/indoor-layout-evaluation
7
+ """
8
+
9
+ from scipy.optimize import linear_sum_assignment
10
+ import numpy as np
11
+ import scipy
12
+
13
+ HEIGHT, WIDTH = 512, 1024
14
+ MAX_DISTANCE = np.sqrt(HEIGHT**2 + WIDTH**2)
15
+
16
+
17
+ def f1_score_2d(gt_corners, dt_corners, thresholds):
18
+ distances = scipy.spatial.distance.cdist(gt_corners, dt_corners)
19
+ return eval_junctions(distances, thresholds=thresholds)
20
+
21
+
22
+ def eval_junctions(distances, thresholds=5):
23
+ thresholds = thresholds if isinstance(thresholds, tuple) or isinstance(
24
+ thresholds, list) else list([thresholds])
25
+
26
+ num_gts, num_preds = distances.shape
27
+
28
+ # filter the matches between ceiling-wall and floor-wall junctions
29
+ mask = np.zeros_like(distances, dtype=bool)
30
+ mask[:num_gts//2, :num_preds//2] = True
31
+ mask[num_gts//2:, num_preds//2:] = True
32
+ distances[~mask] = np.inf
33
+
34
+ # F-measure under different thresholds
35
+ Fs = []
36
+ Ps = []
37
+ Rs = []
38
+ for threshold in thresholds:
39
+ distances_temp = distances.copy()
40
+
41
+ # filter the mis-matched pairs
42
+ distances_temp[distances_temp > threshold] = np.inf
43
+
44
+ # remain the rows and columns that contain non-inf elements
45
+ distances_temp = distances_temp[:, np.any(np.isfinite(distances_temp), axis=0)]
46
+
47
+ if np.prod(distances_temp.shape) == 0:
48
+ Fs.append(0)
49
+ Ps.append(0)
50
+ Rs.append(0)
51
+ continue
52
+
53
+ distances_temp = distances_temp[np.any(np.isfinite(distances_temp), axis=1), :]
54
+
55
+ # solve the bipartite graph matching problem
56
+ row_ind, col_ind = linear_sum_assignment_with_inf(distances_temp)
57
+ true_positive = np.sum(np.isfinite(distances_temp[row_ind, col_ind]))
58
+
59
+ # compute precision and recall
60
+ precision = true_positive / num_preds
61
+ recall = true_positive / num_gts
62
+
63
+ # compute F measure
64
+ Fs.append(2 * precision * recall / (precision + recall))
65
+ Ps.append(precision)
66
+ Rs.append(recall)
67
+
68
+ return Fs, Ps, Rs
69
+
70
+
71
+ def linear_sum_assignment_with_inf(cost_matrix):
72
+ """
73
+ Deal with linear_sum_assignment with inf according to
74
+ https://github.com/scipy/scipy/issues/6900#issuecomment-451735634
75
+ """
76
+ cost_matrix = np.copy(cost_matrix)
77
+ cost_matrix[np.isinf(cost_matrix)] = MAX_DISTANCE
78
+ return linear_sum_assignment(cost_matrix)
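A minimal usage sketch of f1_score_2d above (dummy pixel corners; assumes the module is importable as evaluation.f1_score): each corner array is expected to hold the floor junctions in its first half and the ceiling junctions in its second half, matching how calc_f1_score in evaluation/accuracy.py concatenates them, so cross matches between the two halves are masked out.

```python
import numpy as np
from evaluation.f1_score import f1_score_2d

# 2 floor + 2 ceiling junctions for ground truth and prediction (pixel coordinates)
gt = np.array([[100, 400], [600, 410], [100, 120], [600, 110]], dtype=float)
dt = np.array([[103, 398], [596, 414], [104, 118], [607, 112]], dtype=float)

Fs, Ps, Rs = f1_score_2d(gt, dt, thresholds=[5, 10])
print(Fs, Ps, Rs)  # F1 / precision / recall at each pixel threshold
```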
evaluation/iou.py ADDED
@@ -0,0 +1,148 @@
1
+ """
2
+ @date: 2021/6/29
3
+ @description:
4
+ The methods with the "_floorplan" suffix are kept only for comparison; they reproduce the rasterized IoU calculation used in LED2-Net.
5
+ However, the rasterized floorplan is affected by show_radius: setting it too large lowers the accuracy of the result,
6
+ while setting it too small causes boundaries beyond the drawn radius to be lost from the calculation.
7
+ """
8
+ import numpy as np
9
+ from shapely.geometry import Polygon
10
+
11
+
12
+ def calc_inter_area(dt_xz, gt_xz):
13
+ """
14
+ :param dt_xz: Prediction boundaries can also be corners, format: [[x1, z1], [x2, z2], ...]
15
+ :param gt_xz: Ground truth boundaries can also be corners, format: [[x1, z1], [x2, z2], ...]
16
+ :return:
17
+ """
18
+ dt_polygon = Polygon(dt_xz)
19
+ gt_polygon = Polygon(gt_xz)
20
+
21
+ dt_area = dt_polygon.area
22
+ gt_area = gt_polygon.area
23
+ inter_area = dt_polygon.intersection(gt_polygon).area
24
+ return dt_area, gt_area, inter_area
25
+
26
+
27
+ def calc_IoU_2D(dt_xz, gt_xz):
28
+ """
29
+ :param dt_xz: Prediction boundaries can also be corners, format: [[x1, z1], [x2, z2], ...]
30
+ :param gt_xz: Ground truth boundaries can also be corners, format: [[x1, z1], [x2, z2], ...]
31
+ :return:
32
+ """
33
+ dt_area, gt_area, inter_area = calc_inter_area(dt_xz, gt_xz)
34
+ iou_2d = inter_area / (gt_area + dt_area - inter_area)
35
+ return iou_2d
36
+
37
+
38
+ def calc_IoU_3D(dt_xz, gt_xz, dt_height, gt_height):
39
+ """
40
+ :param dt_xz: Prediction boundaries can also be corners, format: [[x1, z1], [x2, z2], ...]
41
+ :param gt_xz: Ground truth boundaries can also be corners, format: [[x1, z1], [x2, z2], ...]
42
+ :param dt_height:
43
+ :param gt_height:
44
+ :return:
45
+ """
46
+ dt_area, gt_area, inter_area = calc_inter_area(dt_xz, gt_xz)
47
+ dt_volume = dt_area * dt_height
48
+ gt_volume = gt_area * gt_height
49
+ inter_volume = inter_area * min(dt_height, gt_height)
50
+ iou_3d = inter_volume / (dt_volume + gt_volume - inter_volume)
51
+ return iou_3d
52
+
53
+
54
+ def calc_IoU(dt_xz, gt_xz, dt_height, gt_height):
55
+ """
56
+ :param dt_xz: Prediction boundaries can also be corners, format: [[x1, z1], [x2, z2], ...]
57
+ :param gt_xz: Ground truth boundaries can also be corners, format: [[x1, z1], [x2, z2], ...]
58
+ :param dt_height:
59
+ :param gt_height:
60
+ :return:
61
+ """
62
+ dt_area, gt_area, inter_area = calc_inter_area(dt_xz, gt_xz)
63
+ iou_2d = inter_area / (gt_area + dt_area - inter_area)
64
+
65
+ dt_volume = dt_area * dt_height
66
+ gt_volume = gt_area * gt_height
67
+ inter_volume = inter_area * min(dt_height, gt_height)
68
+ iou_3d = inter_volume / (dt_volume + gt_volume - inter_volume)
69
+
70
+ return iou_2d, iou_3d
71
+
72
+
73
+ def calc_Iou_height(dt_height, gt_height):
74
+ return min(dt_height, gt_height) / max(dt_height, gt_height)
75
+
76
+
77
+ # the following is for testing only
78
+ def calc_inter_area_floorplan(dt_floorplan, gt_floorplan):
79
+ intersect = np.sum(np.logical_and(dt_floorplan, gt_floorplan))
80
+ dt_area = np.sum(dt_floorplan)
81
+ gt_area = np.sum(gt_floorplan)
82
+ return dt_area, gt_area, intersect
83
+
84
+
85
+ def calc_IoU_2D_floorplan(dt_floorplan, gt_floorplan):
86
+ dt_area, gt_area, inter_area = calc_inter_area_floorplan(dt_floorplan, gt_floorplan)
87
+ iou_2d = inter_area / (gt_area + dt_area - inter_area)
88
+ return iou_2d
89
+
90
+
91
+ def calc_IoU_3D_floorplan(dt_floorplan, gt_floorplan, dt_height, gt_height):
92
+ dt_area, gt_area, inter_area = calc_inter_area_floorplan(dt_floorplan, gt_floorplan)
93
+ dt_volume = dt_area * dt_height
94
+ gt_volume = gt_area * gt_height
95
+ inter_volume = inter_area * min(dt_height, gt_height)
96
+ iou_3d = inter_volume / (dt_volume + gt_volume - inter_volume)
97
+ return iou_3d
98
+
99
+
100
+ def calc_IoU_floorplan(dt_floorplan, gt_floorplan, dt_height, gt_height):
101
+ dt_area, gt_area, inter_area = calc_inter_area_floorplan(dt_floorplan, gt_floorplan)
102
+ iou_2d = inter_area / (gt_area + dt_area - inter_area)
103
+
104
+ dt_volume = dt_area * dt_height
105
+ gt_volume = gt_area * gt_height
106
+ inter_volume = inter_area * min(dt_height, gt_height)
107
+ iou_3d = inter_volume / (dt_volume + gt_volume - inter_volume)
108
+ return iou_2d, iou_3d
109
+
110
+
111
+ if __name__ == '__main__':
112
+ from visualization.floorplan import draw_floorplan, draw_iou_floorplan
113
+ from visualization.boundary import draw_boundaries, corners2boundaries
114
+ from utils.conversion import uv2xyz
115
+ from utils.height import height2ratio
116
+
117
+ # dummy data
118
+ dt_floor_corners = np.array([[0.2, 0.7],
119
+ [0.4, 0.7],
120
+ [0.6, 0.7],
121
+ [0.8, 0.7]])
122
+ dt_height = 2.8
123
+
124
+ gt_floor_corners = np.array([[0.3, 0.7],
125
+ [0.5, 0.7],
126
+ [0.7, 0.7],
127
+ [0.9, 0.7]])
128
+ gt_height = 3.2
129
+
130
+ dt_xz = uv2xyz(dt_floor_corners)[..., ::2]
131
+ gt_xz = uv2xyz(gt_floor_corners)[..., ::2]
132
+
133
+ dt_floorplan = draw_floorplan(dt_xz, show=False, show_radius=1)
134
+ gt_floorplan = draw_floorplan(gt_xz, show=False, show_radius=1)
135
+ # dt_floorplan = draw_floorplan(dt_xz, show=False, show_radius=2)
136
+ # gt_floorplan = draw_floorplan(gt_xz, show=False, show_radius=2)
137
+
138
+ iou_2d, iou_3d = calc_IoU_floorplan(dt_floorplan, gt_floorplan, dt_height, gt_height)
139
+ print('use floor plan image:', iou_2d, iou_3d)
140
+
141
+ iou_2d, iou_3d = calc_IoU(dt_xz, gt_xz, dt_height, gt_height)
142
+ print('use floor plan polygon:', iou_2d, iou_3d)
143
+
144
+ draw_iou_floorplan(dt_xz, gt_xz, show=True, iou_2d=iou_2d, iou_3d=iou_3d)
145
+ pano_bd = draw_boundaries(np.zeros([512, 1024, 3]), corners_list=[dt_floor_corners],
146
+ boundary_color=[0, 0, 1], ratio=height2ratio(dt_height), draw_corners=False)
147
+ pano_bd = draw_boundaries(pano_bd, corners_list=[gt_floor_corners],
148
+ boundary_color=[0, 1, 0], ratio=height2ratio(gt_height), show=True, draw_corners=False)
inference.py ADDED
@@ -0,0 +1,261 @@
1
+ """
2
+ @Date: 2021/09/19
3
+ @description:
4
+ """
5
+ import json
6
+ import os
7
+ import argparse
8
+ import cv2
9
+ import numpy as np
10
+ import torch
11
+ import matplotlib.pyplot as plt
12
+ import glob
13
+
14
+ from tqdm import tqdm
15
+ from PIL import Image
16
+ from config.defaults import merge_from_file, get_config
17
+ from dataset.mp3d_dataset import MP3DDataset
18
+ from dataset.zind_dataset import ZindDataset
19
+ from models.build import build_model
20
+ from loss import GradLoss
21
+ from postprocessing.post_process import post_process
22
+ from preprocessing.pano_lsd_align import panoEdgeDetection, rotatePanorama
23
+ from utils.boundary import corners2boundaries, layout2depth
24
+ from utils.conversion import depth2xyz
25
+ from utils.logger import get_logger
26
+ from utils.misc import tensor2np_d, tensor2np
27
+ from evaluation.accuracy import show_grad
28
+ from models.lgt_net import LGT_Net
29
+ from utils.writer import xyz2json
30
+ from visualization.boundary import draw_boundaries
31
+ from visualization.floorplan import draw_floorplan, draw_iou_floorplan
32
+ from visualization.obj3d import create_3d_obj
33
+
34
+
35
+ def parse_option():
36
+ parser = argparse.ArgumentParser(description='Panorama Layout Transformer training and evaluation script')
37
+ parser.add_argument('--img_glob',
38
+ type=str,
39
+ required=True,
40
+ help='image glob path')
41
+
42
+ parser.add_argument('--cfg',
43
+ type=str,
44
+ required=True,
45
+ metavar='FILE',
46
+ help='path of config file')
47
+
48
+ parser.add_argument('--post_processing',
49
+ type=str,
50
+ default='manhattan',
51
+ choices=['manhattan', 'atalanta', 'original'],
52
+ help='post-processing type')
53
+
54
+ parser.add_argument('--output_dir',
55
+ type=str,
56
+ default='src/output',
57
+ help='path of output')
58
+
59
+ parser.add_argument('--visualize_3d', action='store_true',
60
+ help='visualize_3d')
61
+
62
+ parser.add_argument('--output_3d', action='store_true',
63
+ help='output_3d')
64
+
65
+ parser.add_argument('--device',
66
+ type=str,
67
+ default='cuda',
68
+ help='device')
69
+
70
+ args = parser.parse_args()
71
+ args.mode = 'test'
72
+
73
+ print("arguments:")
74
+ for arg in vars(args):
75
+ print(arg, ":", getattr(args, arg))
76
+ print("-" * 50)
77
+ return args
78
+
79
+
80
+ def visualize_2d(img, dt, show_depth=True, show_floorplan=True, show=False, save_path=None):
81
+ dt_np = tensor2np_d(dt)
82
+ dt_depth = dt_np['depth'][0]
83
+ dt_xyz = depth2xyz(np.abs(dt_depth))
84
+ dt_ratio = dt_np['ratio'][0][0]
85
+ dt_boundaries = corners2boundaries(dt_ratio, corners_xyz=dt_xyz, step=None, visible=False, length=img.shape[1])
86
+ vis_img = draw_boundaries(img, boundary_list=dt_boundaries, boundary_color=[0, 1, 0])
87
+
88
+ if 'processed_xyz' in dt:
89
+ dt_boundaries = corners2boundaries(dt_ratio, corners_xyz=dt['processed_xyz'][0], step=None, visible=False,
90
+ length=img.shape[1])
91
+ vis_img = draw_boundaries(vis_img, boundary_list=dt_boundaries, boundary_color=[1, 0, 0])
92
+
93
+ if show_depth:
94
+ dt_grad_img = show_depth_normal_grad(dt)
95
+ grad_h = dt_grad_img.shape[0]
96
+ vis_merge = [
97
+ vis_img[0:-grad_h, :, :],
98
+ dt_grad_img,
99
+ ]
100
+ vis_img = np.concatenate(vis_merge, axis=0)
101
+ # vis_img = dt_grad_img.transpose(1, 2, 0)[100:]
102
+
103
+ if show_floorplan:
104
+ if 'processed_xyz' in dt:
105
+ floorplan = draw_iou_floorplan(dt['processed_xyz'][0][..., ::2], dt_xyz[..., ::2],
106
+ dt_board_color=[1, 0, 0, 1], gt_board_color=[0, 1, 0, 1])
107
+ else:
108
+ floorplan = show_alpha_floorplan(dt_xyz, border_color=[0, 1, 0, 1])
109
+
110
+ vis_img = np.concatenate([vis_img, floorplan[:, 60:-60, :]], axis=1)
111
+ if show:
112
+ plt.imshow(vis_img)
113
+ plt.show()
114
+ if save_path:
115
+ result = Image.fromarray((vis_img * 255).astype(np.uint8))
116
+ result.save(save_path)
117
+ return vis_img
118
+
119
+
120
+ def preprocess(img_ori, q_error=0.7, refine_iter=3, vp_cache_path=None):
121
+ # Align images with VP
122
+ if vp_cache_path is not None and os.path.exists(vp_cache_path):
123
+ with open(vp_cache_path) as f:
124
+ vp = [[float(v) for v in line.rstrip().split(' ')] for line in f.readlines()]
125
+ vp = np.array(vp)
126
+ else:
127
+ # VP detection and line segment extraction
128
+ _, vp, _, _, _, _, _ = panoEdgeDetection(img_ori,
129
+ qError=q_error,
130
+ refineIter=refine_iter)
131
+ i_img = rotatePanorama(img_ori, vp[2::-1])
132
+
133
+ if vp_cache_path is not None:
134
+ with open(vp_cache_path, 'w') as f:
135
+ for i in range(3):
136
+ f.write('%.6f %.6f %.6f\n' % (vp[i, 0], vp[i, 1], vp[i, 2]))
137
+
138
+ return i_img, vp
139
+
140
+
141
+ def show_depth_normal_grad(dt):
142
+ grad_conv = GradLoss().to(dt['depth'].device).grad_conv
143
+ dt_grad_img = show_grad(dt['depth'][0], grad_conv, 50)
144
+ dt_grad_img = cv2.resize(dt_grad_img, (1024, 60), interpolation=cv2.INTER_NEAREST)
145
+ return dt_grad_img
146
+
147
+
148
+ def show_alpha_floorplan(dt_xyz, side_l=512, border_color=None):
149
+ if border_color is None:
150
+ border_color = [1, 0, 0, 1]
151
+ fill_color = [0.2, 0.2, 0.2, 0.2]
152
+ dt_floorplan = draw_floorplan(xz=dt_xyz[..., ::2], fill_color=fill_color,
153
+ border_color=border_color, side_l=side_l, show=False, center_color=[1, 0, 0, 1])
154
+ dt_floorplan = Image.fromarray((dt_floorplan * 255).astype(np.uint8), mode='RGBA')
155
+ back = np.zeros([side_l, side_l, len(fill_color)], dtype=float)
156
+ back[..., :] = [0.8, 0.8, 0.8, 1]
157
+ back = Image.fromarray((back * 255).astype(np.uint8), mode='RGBA')
158
+ iou_floorplan = Image.alpha_composite(back, dt_floorplan).convert("RGB")
159
+ dt_floorplan = np.array(iou_floorplan) / 255.0
160
+ return dt_floorplan
161
+
162
+
163
+ def save_pred_json(xyz, ratio, save_path):
164
+ # xyz[..., -1] = -xyz[..., -1]
165
+ json_data = xyz2json(xyz, ratio)
166
+ with open(save_path, 'w') as f:
167
+ f.write(json.dumps(json_data, indent=4) + '\n')
168
+ return json_data
169
+
170
+
171
+ def inference():
172
+ if len(img_paths) == 0:
173
+ logger.error('No images found')
174
+ return
175
+
176
+ bar = tqdm(img_paths, ncols=100)
177
+ for img_path in bar:
178
+ if not os.path.isfile(img_path):
179
+ logger.error(f'{img_path} is not a file')
180
+ continue
181
+ name = os.path.basename(img_path).split('.')[0]
182
+ bar.set_description(name)
183
+ img = np.array(Image.open(img_path).resize((1024, 512), Image.Resampling.BICUBIC))[..., :3]
184
+ if args.post_processing is not None and 'manhattan' in args.post_processing:
185
+ bar.set_description("Preprocessing")
186
+ img, vp = preprocess(img, vp_cache_path=os.path.join(args.output_dir, f"{name}_vp.txt"))
187
+
188
+ img = (img / 255.0).astype(np.float32)
189
+ run_one_inference(img, model, args, name, logger=logger)
190
+
191
+
192
+ def inference_dataset(dataset):
193
+ bar = tqdm(dataset, ncols=100)
194
+ for data in bar:
195
+ bar.set_description(data['id'])
196
+ run_one_inference(data['image'].transpose(1, 2, 0), model, args, name=data['id'], logger=logger)
197
+
198
+
199
+ @torch.no_grad()
200
+ def run_one_inference(img, model, args, name, logger, show=True, show_depth=True,
201
+ show_floorplan=True, mesh_format='.gltf', mesh_resolution=512):
202
+ model.eval()
203
+ logger.info("model inference...")
204
+ dt = model(torch.from_numpy(img.transpose(2, 0, 1)[None]).to(args.device))
205
+ if args.post_processing != 'original':
206
+ logger.info(f"post-processing, type:{args.post_processing}...")
207
+ dt['processed_xyz'] = post_process(tensor2np(dt['depth']), type_name=args.post_processing)
208
+
209
+ visualize_2d(img, dt,
210
+ show_depth=show_depth,
211
+ show_floorplan=show_floorplan,
212
+ show=show,
213
+ save_path=os.path.join(args.output_dir, f"{name}_pred.png"))
214
+ output_xyz = dt['processed_xyz'][0] if 'processed_xyz' in dt else depth2xyz(tensor2np(dt['depth'][0]))
215
+
216
+ logger.info(f"saving predicted layout json...")
217
+ json_data = save_pred_json(output_xyz, tensor2np(dt['ratio'][0])[0],
218
+ save_path=os.path.join(args.output_dir, f"{name}_pred.json"))
219
+ # if args.visualize_3d:
220
+ # from visualization.visualizer.visualizer import visualize_3d
221
+ # visualize_3d(json_data, (img * 255).astype(np.uint8))
222
+
223
+ if args.visualize_3d or args.output_3d:
224
+ dt_boundaries = corners2boundaries(tensor2np(dt['ratio'][0])[0], corners_xyz=output_xyz, step=None,
225
+ length=mesh_resolution if 'processed_xyz' in dt else None,
226
+ visible=True if 'processed_xyz' in dt else False)
227
+ dt_layout_depth = layout2depth(dt_boundaries, show=False)
228
+
229
+ logger.info(f"creating 3d mesh ...")
230
+ create_3d_obj(cv2.resize(img, dt_layout_depth.shape[::-1]), dt_layout_depth,
231
+ save_path=os.path.join(args.output_dir, f"{name}_3d{mesh_format}") if args.output_3d else None,
232
+ mesh=True, show=args.visualize_3d)
233
+
234
+
235
+ if __name__ == '__main__':
236
+ logger = get_logger()
237
+ args = parse_option()
238
+ config = get_config(args)
239
+
240
+ if ('cuda' in args.device or 'cuda' in config.TRAIN.DEVICE) and not torch.cuda.is_available():
241
+ logger.info(f'The {args.device} is not available, will use cpu ...')
242
+ config.defrost()
243
+ args.device = "cpu"
244
+ config.TRAIN.DEVICE = "cpu"
245
+ config.freeze()
246
+
247
+ model, _, _, _ = build_model(config, logger)
248
+ os.makedirs(args.output_dir, exist_ok=True)
249
+ img_paths = sorted(glob.glob(args.img_glob))
250
+
251
+ inference()
252
+
253
+ # dataset = MP3DDataset(root_dir='./src/dataset/mp3d', mode='test', split_list=[
254
+ # ['7y3sRwLe3Va', '155fac2d50764bf09feb6c8f33e8fb76'],
255
+ # ['e9zR4mvMWw7', 'c904c55a5d0e420bbd6e4e030b9fe5b4'],
256
+ # ])
257
+ # dataset = ZindDataset(root_dir='./src/dataset/zind', mode='test', split_list=[
258
+ # '1169_pano_21',
259
+ # '0583_pano_59',
260
+ # ], vp_align=True)
261
+ # inference_dataset(dataset)
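Usage note (an illustrative invocation, not part of the commit; the config path and image glob below are placeholders, only the flags come from the argparse definition above): `python inference.py --cfg src/config/mp3d.yaml --img_glob 'src/demo/*.png' --post_processing manhattan --output_dir src/output --output_3d`. With --post_processing manhattan the script first runs the vanishing-point alignment in preprocess(), caches the detected VPs next to the outputs, and then writes <name>_pred.png, <name>_pred.json and optionally a 3D mesh per image.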
loss/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ """
2
+ @date: 2021/7/19
3
+ @description:
4
+ """
5
+
6
+ from torch.nn import L1Loss
7
+ from loss.led_loss import LEDLoss
8
+ from loss.grad_loss import GradLoss
9
+ from loss.boundary_loss import BoundaryLoss
10
+ from loss.object_loss import ObjectLoss, HeatmapLoss
loss/boundary_loss.py ADDED
@@ -0,0 +1,51 @@
1
+ """
2
+ @Date: 2021/08/12
3
+ @description: For HorizonNet, using latitudes to calculate loss.
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+ from utils.conversion import depth2xyz, xyz2lonlat
8
+
9
+
10
+ class BoundaryLoss(nn.Module):
11
+ def __init__(self):
12
+ super().__init__()
13
+ self.loss = nn.L1Loss()
14
+
15
+ def forward(self, gt, dt):
16
+ gt_floor_xyz = depth2xyz(gt['depth'])
17
+ gt_ceil_xyz = gt_floor_xyz.clone()
18
+ gt_ceil_xyz[..., 1] = -gt['ratio']
19
+
20
+ gt_floor_boundary = xyz2lonlat(gt_floor_xyz)[..., -1:]
21
+ gt_ceil_boundary = xyz2lonlat(gt_ceil_xyz)[..., -1:]
22
+
23
+ gt_boundary = torch.cat([gt_floor_boundary, gt_ceil_boundary], dim=-1).permute(0, 2, 1)
24
+ dt_boundary = dt['boundary']
25
+
26
+ loss = self.loss(gt_boundary, dt_boundary)
27
+ return loss
28
+
29
+
30
+ if __name__ == '__main__':
31
+ import numpy as np
32
+ from dataset.mp3d_dataset import MP3DDataset
33
+
34
+ mp3d_dataset = MP3DDataset(root_dir='../src/dataset/mp3d', mode='train')
35
+ gt = mp3d_dataset.__getitem__(0)
36
+
37
+ gt['depth'] = torch.from_numpy(gt['depth'][np.newaxis]) # batch size is 1
38
+ gt['ratio'] = torch.from_numpy(gt['ratio'][np.newaxis]) # batch size is 1
39
+
40
+ dummy_dt = {
41
+ 'depth': gt['depth'].clone(),
42
+ 'boundary': torch.cat([
43
+ xyz2lonlat(depth2xyz(gt['depth']))[..., -1:],
44
+ xyz2lonlat(depth2xyz(gt['depth'], plan_y=-gt['ratio']))[..., -1:]
45
+ ], dim=-1).permute(0, 2, 1)
46
+ }
47
+ # dummy_dt['boundary'][:, :, :20] /= 1.2 # some different
48
+
49
+ boundary_loss = BoundaryLoss()
50
+ loss = boundary_loss(gt, dummy_dt)
51
+ print(loss)
loss/grad_loss.py ADDED
@@ -0,0 +1,57 @@
1
+ """
2
+ @Date: 2021/08/12
3
+ @description:
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import numpy as np
9
+
10
+ from visualization.grad import get_all
11
+
12
+
13
+ class GradLoss(nn.Module):
14
+ def __init__(self):
15
+ super().__init__()
16
+ self.loss = nn.L1Loss()
17
+ self.cos = nn.CosineSimilarity(dim=-1, eps=0)
18
+
19
+ self.grad_conv = nn.Conv1d(1, 1, kernel_size=3, stride=1, padding=0, bias=False, padding_mode='circular')
20
+ self.grad_conv.weight = nn.Parameter(torch.tensor([[[1, 0, -1]]]).float())
21
+ self.grad_conv.weight.requires_grad = False
22
+
23
+ def forward(self, gt, dt):
24
+ gt_direction, _, gt_angle_grad = get_all(gt['depth'], self.grad_conv)
25
+ dt_direction, _, dt_angle_grad = get_all(dt['depth'], self.grad_conv)
26
+
27
+ normal_loss = (1 - self.cos(gt_direction, dt_direction)).mean()
28
+ grad_loss = self.loss(gt_angle_grad, dt_angle_grad)
29
+ return [normal_loss, grad_loss]
30
+
31
+
32
+ if __name__ == '__main__':
33
+ from dataset.mp3d_dataset import MP3DDataset
34
+ from utils.boundary import depth2boundaries
35
+ from utils.conversion import uv2xyz
36
+ from visualization.boundary import draw_boundaries
37
+ from visualization.floorplan import draw_floorplan
38
+
39
+ def show_boundary(image, depth, ratio):
40
+ boundary_list = depth2boundaries(ratio, depth, step=None)
41
+ draw_boundaries(image.transpose(1, 2, 0), boundary_list=boundary_list, show=True)
42
+ draw_floorplan(uv2xyz(boundary_list[0])[..., ::2], show=True, center_color=0.8)
43
+
44
+ mp3d_dataset = MP3DDataset(root_dir='../src/dataset/mp3d', mode='train', patch_num=256)
45
+ gt = mp3d_dataset.__getitem__(1)
46
+ gt['depth'] = torch.from_numpy(gt['depth'][np.newaxis]) # batch size is 1
47
+ dummy_dt = {
48
+ 'depth': gt['depth'].clone(),
49
+ }
50
+ # dummy_dt['depth'][..., 20] *= 3 # some different
51
+
52
+ # show_boundary(gt['image'], gt['depth'][0].numpy(), gt['ratio'])
53
+ # show_boundary(gt['image'], dummy_dt['depth'][0].numpy(), gt['ratio'])
54
+
55
+ grad_loss = GradLoss()
56
+ loss = grad_loss(gt, dummy_dt)
57
+ print(loss)
loss/led_loss.py ADDED
@@ -0,0 +1,47 @@
1
+ """
2
+ @Date: 2021/08/12
3
+ @description:
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
+ class LEDLoss(nn.Module):
10
+ def __init__(self):
11
+ super().__init__()
12
+ self.loss = nn.L1Loss()
13
+
14
+ def forward(self, gt, dt):
15
+ camera_height = 1.6
16
+ gt_depth = gt['depth'] * camera_height
17
+
18
+ dt_ceil_depth = dt['ceil_depth'] * camera_height * gt['ratio']
19
+ dt_floor_depth = dt['depth'] * camera_height
20
+
21
+ ceil_loss = self.loss(gt_depth, dt_ceil_depth)
22
+ floor_loss = self.loss(gt_depth, dt_floor_depth)
23
+
24
+ loss = floor_loss + ceil_loss
25
+
26
+ return loss
27
+
28
+
29
+ if __name__ == '__main__':
30
+ import numpy as np
31
+ from dataset.mp3d_dataset import MP3DDataset
32
+
33
+ mp3d_dataset = MP3DDataset(root_dir='../src/dataset/mp3d', mode='train')
34
+ gt = mp3d_dataset.__getitem__(0)
35
+
36
+ gt['depth'] = torch.from_numpy(gt['depth'][np.newaxis]) # batch size is 1
37
+ gt['ratio'] = torch.from_numpy(gt['ratio'][np.newaxis]) # batch size is 1
38
+
39
+ dummy_dt = {
40
+ 'depth': gt['depth'].clone(),
41
+ 'ceil_depth': gt['depth'] / gt['ratio']
42
+ }
43
+ # dummy_dt['depth'][..., :20] *= 3 # some different
44
+
45
+ led_loss = LEDLoss()
46
+ loss = led_loss(gt, dummy_dt)
47
+ print(loss)
loss/object_loss.py ADDED
@@ -0,0 +1,42 @@
1
+ """
2
+ @Date: 2021/08/12
3
+ @description:
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+ from loss.grad_loss import GradLoss
8
+
9
+
10
+ class ObjectLoss(nn.Module):
11
+ def __init__(self):
12
+ super().__init__()
13
+ self.heat_map_loss = HeatmapLoss(reduction='mean') # FocalLoss(reduction='mean')
14
+ self.l1_loss = nn.SmoothL1Loss()
15
+
16
+ def forward(self, gt, dt):
17
+ # TODO::
18
+ return 0
19
+
20
+
21
+ class HeatmapLoss(nn.Module):
22
+ def __init__(self, weight=None, alpha=2, beta=4, reduction='mean'):
23
+ super(HeatmapLoss, self).__init__()
24
+ self.alpha = alpha
25
+ self.beta = beta
26
+ self.reduction = reduction
27
+
28
+ def forward(self, targets, inputs):
29
+ center_id = (targets == 1.0).float()
30
+ other_id = (targets != 1.0).float()
31
+ center_loss = -center_id * (1.0 - inputs) ** self.alpha * torch.log(inputs + 1e-14)
32
+ other_loss = -other_id * (1 - targets) ** self.beta * inputs ** self.alpha * torch.log(1.0 - inputs + 1e-14)
33
+ loss = center_loss + other_loss
34
+
35
+ batch_size = loss.size(0)
36
+ if self.reduction == 'mean':
37
+ loss = torch.sum(loss) / batch_size
38
+
39
+ if self.reduction == 'sum':
40
+ loss = torch.sum(loss) / batch_size
41
+
42
+ return loss
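A minimal usage sketch of HeatmapLoss above (dummy tensors; the Gaussian-like target values here are made up): targets are expected in [0, 1] with exactly 1.0 at corner positions, predictions are probabilities in (0, 1), and the argument order is (targets, inputs) as in the forward signature.

```python
import torch
from loss.object_loss import HeatmapLoss

torch.manual_seed(0)
criterion = HeatmapLoss(reduction='mean')

# batch of 2 one-dimensional corner heat maps of length 256
targets = torch.zeros(2, 256)
targets[:, 100] = 1.0                      # corner centres
targets[:, 99] = targets[:, 101] = 0.6     # soft falloff around each centre

inputs = torch.rand(2, 256).clamp(1e-4, 1 - 1e-4)  # fake predicted probabilities
loss = criterion(targets, inputs)
print(loss.item())
```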
main.py ADDED
@@ -0,0 +1,401 @@
1
+ """
2
+ @Date: 2021/07/17
3
+ @description:
4
+ """
5
+ import sys
6
+ import os
7
+ import shutil
8
+ import argparse
9
+ import numpy as np
10
+ import json
11
+ import torch
12
+ import torch.nn.parallel
13
+ import torch.optim
14
+ import torch.multiprocessing as mp
15
+ import torch.utils.data
16
+ import torch.utils.data.distributed
17
+ import torch.cuda
18
+
19
+ from PIL import Image
20
+ from tqdm import tqdm
21
+ from torch.utils.tensorboard import SummaryWriter
22
+ from config.defaults import get_config, get_rank_config
23
+ from models.other.criterion import calc_criterion
24
+ from models.build import build_model
25
+ from models.other.init_env import init_env
26
+ from utils.logger import build_logger
27
+ from utils.misc import tensor2np_d, tensor2np
28
+ from dataset.build import build_loader
29
+ from evaluation.accuracy import calc_accuracy, show_heat_map, calc_ce, calc_pe, calc_rmse_delta_1, \
30
+ show_depth_normal_grad, calc_f1_score
31
+ from postprocessing.post_process import post_process
32
+
33
+ try:
34
+ from apex import amp
35
+ except ImportError:
36
+ amp = None
37
+
38
+
39
+ def parse_option():
40
+ debug = True if sys.gettrace() else False
41
+ parser = argparse.ArgumentParser(description='Panorama Layout Transformer training and evaluation script')
42
+ parser.add_argument('--cfg',
43
+ type=str,
44
+ metavar='FILE',
45
+ help='path to config file')
46
+
47
+ parser.add_argument('--mode',
48
+ type=str,
49
+ default='train',
50
+ choices=['train', 'val', 'test'],
51
+ help='train/val/test mode')
52
+
53
+ parser.add_argument('--val_name',
54
+ type=str,
55
+ choices=['val', 'test'],
56
+ help='val name')
57
+
58
+ parser.add_argument('--bs', type=int,
59
+ help='batch size')
60
+
61
+ parser.add_argument('--save_eval', action='store_true',
62
+ help='save eval result')
63
+
64
+ parser.add_argument('--post_processing', type=str,
65
+ choices=['manhattan', 'atalanta', 'manhattan_old'],
66
+ help='type of post-processing')
67
+
68
+ parser.add_argument('--need_cpe', action='store_true',
69
+ help='need to evaluate corner error and pixel error')
70
+
71
+ parser.add_argument('--need_f1', action='store_true',
72
+ help='need to evaluate f1-score of corners')
73
+
74
+ parser.add_argument('--need_rmse', action='store_true',
75
+ help='need to evaluate root mean squared error and delta error')
76
+
77
+ parser.add_argument('--force_cube', action='store_true',
78
+ help='force cube shape when eval')
79
+
80
+ parser.add_argument('--wall_num', type=int,
81
+ help='wall number')
82
+
83
+ args = parser.parse_args()
84
+ args.debug = debug
85
+ print("arguments:")
86
+ for arg in vars(args):
87
+ print(arg, ":", getattr(args, arg))
88
+ print("-" * 50)
89
+ return args
90
+
91
+
92
+ def main():
93
+ args = parse_option()
94
+ config = get_config(args)
95
+
96
+ if config.TRAIN.SCRATCH and os.path.exists(config.CKPT.DIR) and config.MODE == 'train':
97
+ print(f"Train from scratch, delete checkpoint dir: {config.CKPT.DIR}")
98
+ f = [int(f.split('_')[-1].split('.')[0]) for f in os.listdir(config.CKPT.DIR) if 'pkl' in f]
99
+ if len(f) > 0:
100
+ last_epoch = np.array(f).max()
101
+ if last_epoch > 10:
102
+ c = input(f"delete it (last_epoch: {last_epoch})?(Y/N)\n")
103
+ if c != 'y' and c != 'Y':
104
+ exit(0)
105
+
106
+ shutil.rmtree(config.CKPT.DIR, ignore_errors=True)
107
+
108
+ os.makedirs(config.CKPT.DIR, exist_ok=True)
109
+ os.makedirs(config.CKPT.RESULT_DIR, exist_ok=True)
110
+ os.makedirs(config.LOGGER.DIR, exist_ok=True)
111
+
112
+ if ':' in config.TRAIN.DEVICE:
113
+ nprocs = len(config.TRAIN.DEVICE.split(':')[-1].split(','))
114
+ if 'cuda' in config.TRAIN.DEVICE:
115
+ if not torch.cuda.is_available():
116
+ print(f"CUDA is not available (config is: {config.TRAIN.DEVICE}), will use CPU ...")
117
+ config.defrost()
118
+ config.TRAIN.DEVICE = "cpu"
119
+ config.freeze()
120
+ nprocs = 1
121
+
122
+ if config.MODE == 'train':
123
+ with open(os.path.join(config.CKPT.DIR, "config.yaml"), "w") as f:
124
+ f.write(config.dump(allow_unicode=True))
125
+
126
+ if config.TRAIN.DEVICE == 'cpu' or nprocs < 2:
127
+ print(f"Use single process, device:{config.TRAIN.DEVICE}")
128
+ main_worker(0, config, 1)
129
+ else:
130
+ print(f"Use {nprocs} processes ...")
131
+ mp.spawn(main_worker, nprocs=nprocs, args=(config, nprocs), join=True)
132
+
133
+
134
+ def main_worker(local_rank, cfg, world_size):
135
+ config = get_rank_config(cfg, local_rank, world_size)
136
+ logger = build_logger(config)
137
+ writer = SummaryWriter(config.CKPT.DIR)
138
+ logger.info(f"Comment: {config.COMMENT}")
139
+ cur_pid = os.getpid()
140
+ logger.info(f"Current process id: {cur_pid}")
141
+ torch.hub._hub_dir = config.CKPT.PYTORCH
142
+ logger.info(f"Pytorch hub dir: {torch.hub._hub_dir}")
143
+ init_env(config.SEED, config.TRAIN.DETERMINISTIC, config.DATA.NUM_WORKERS)
144
+
145
+ model, optimizer, criterion, scheduler = build_model(config, logger)
146
+ train_data_loader, val_data_loader = build_loader(config, logger)
147
+
148
+ if 'cuda' in config.TRAIN.DEVICE:
149
+ torch.cuda.set_device(config.TRAIN.DEVICE)
150
+
151
+ if config.MODE == 'train':
152
+ train(model, train_data_loader, val_data_loader, optimizer, criterion, config, logger, writer, scheduler)
153
+ else:
154
+ iou_results, other_results = val_an_epoch(model, val_data_loader,
155
+ criterion, config, logger, writer=None,
156
+ epoch=config.TRAIN.START_EPOCH)
157
+ results = dict(iou_results, **other_results)
158
+ if config.SAVE_EVAL:
159
+ save_path = os.path.join(config.CKPT.RESULT_DIR, f"result.json")
160
+ with open(save_path, 'w+') as f:
161
+ json.dump(results, f, indent=4)
162
+
163
+
164
+ def save(model, optimizer, epoch, iou_d, logger, writer, config):
165
+ model.save(optimizer, epoch, accuracy=iou_d['full_3d'], logger=logger, acc_d=iou_d, config=config)
166
+ for k in model.acc_d:
167
+ writer.add_scalar(f"BestACC/{k}", model.acc_d[k]['acc'], epoch)
168
+
169
+
170
+ def train(model, train_data_loader, val_data_loader, optimizer, criterion, config, logger, writer, scheduler):
171
+ for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
172
+ logger.info("=" * 200)
173
+ train_an_epoch(model, train_data_loader, optimizer, criterion, config, logger, writer, epoch)
174
+ epoch_iou_d, _ = val_an_epoch(model, val_data_loader, criterion, config, logger, writer, epoch)
175
+
176
+ if config.LOCAL_RANK == 0:
177
+ ddp = config.WORLD_SIZE > 1
178
+ save(model.module if ddp else model, optimizer, epoch, epoch_iou_d, logger, writer, config)
179
+
180
+ if scheduler is not None:
181
+ if scheduler.min_lr is not None and optimizer.param_groups[0]['lr'] <= scheduler.min_lr:
182
+ continue
183
+ scheduler.step()
184
+ writer.close()
185
+
186
+
187
+ def train_an_epoch(model, train_data_loader, optimizer, criterion, config, logger, writer, epoch=0):
188
+ logger.info(f'Start Train Epoch {epoch}/{config.TRAIN.EPOCHS - 1}')
189
+ model.train()
190
+
191
+ if len(config.MODEL.FINE_TUNE) > 0:
192
+ model.feature_extractor.eval()
193
+
194
+ optimizer.zero_grad()
195
+
196
+ data_len = len(train_data_loader)
197
+ start_i = data_len * epoch * config.WORLD_SIZE
198
+ bar = enumerate(train_data_loader)
199
+ if config.LOCAL_RANK == 0 and config.SHOW_BAR:
200
+ bar = tqdm(bar, total=data_len, ncols=200)
201
+
202
+ device = config.TRAIN.DEVICE
203
+ epoch_loss_d = {}
204
+ for i, gt in bar:
205
+ imgs = gt['image'].to(device, non_blocking=True)
206
+ gt['depth'] = gt['depth'].to(device, non_blocking=True)
207
+ gt['ratio'] = gt['ratio'].to(device, non_blocking=True)
208
+ if 'corner_heat_map' in gt:
209
+ gt['corner_heat_map'] = gt['corner_heat_map'].to(device, non_blocking=True)
210
+ if config.AMP_OPT_LEVEL != "O0" and 'cuda' in device:
211
+ imgs = imgs.type(torch.float16)
212
+ gt['depth'] = gt['depth'].type(torch.float16)
213
+ gt['ratio'] = gt['ratio'].type(torch.float16)
214
+ dt = model(imgs)
215
+ loss, batch_loss_d, epoch_loss_d = calc_criterion(criterion, gt, dt, epoch_loss_d)
216
+ if config.LOCAL_RANK == 0 and config.SHOW_BAR:
217
+ bar.set_postfix(batch_loss_d)
218
+
219
+ optimizer.zero_grad()
220
+ if config.AMP_OPT_LEVEL != "O0" and 'cuda' in device:
221
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
222
+ scaled_loss.backward()
223
+ else:
224
+ loss.backward()
225
+ optimizer.step()
226
+
227
+ global_step = start_i + i * config.WORLD_SIZE + config.LOCAL_RANK
228
+ for key, val in batch_loss_d.items():
229
+ writer.add_scalar(f'TrainBatchLoss/{key}', val, global_step)
230
+
231
+ if config.LOCAL_RANK != 0:
232
+ return
233
+
234
+ epoch_loss_d = dict(zip(epoch_loss_d.keys(), [np.array(epoch_loss_d[k]).mean() for k in epoch_loss_d.keys()]))
235
+ s = 'TrainEpochLoss: '
236
+ for key, val in epoch_loss_d.items():
237
+ writer.add_scalar(f'TrainEpochLoss/{key}', val, epoch)
238
+ s += f" {key}={val}"
239
+ logger.info(s)
240
+ writer.add_scalar('LearningRate', optimizer.param_groups[0]['lr'], epoch)
241
+ logger.info(f"LearningRate: {optimizer.param_groups[0]['lr']}")
242
+
243
+
244
+ @torch.no_grad()
245
+ def val_an_epoch(model, val_data_loader, criterion, config, logger, writer, epoch=0):
246
+ model.eval()
247
+ logger.info(f'Start Validate Epoch {epoch}/{config.TRAIN.EPOCHS - 1}')
248
+ data_len = len(val_data_loader)
249
+ start_i = data_len * epoch * config.WORLD_SIZE
250
+ bar = enumerate(val_data_loader)
251
+ if config.LOCAL_RANK == 0 and config.SHOW_BAR:
252
+ bar = tqdm(bar, total=data_len, ncols=200)
253
+ device = config.TRAIN.DEVICE
254
+ epoch_loss_d = {}
255
+ epoch_iou_d = {
256
+ 'visible_2d': [],
257
+ 'visible_3d': [],
258
+ 'full_2d': [],
259
+ 'full_3d': [],
260
+ 'height': []
261
+ }
262
+
263
+ epoch_other_d = {
264
+ 'ce': [],
265
+ 'pe': [],
266
+ 'f1': [],
267
+ 'precision': [],
268
+ 'recall': [],
269
+ 'rmse': [],
270
+ 'delta_1': []
271
+ }
272
+
273
+ show_index = np.random.randint(0, data_len)
274
+ for i, gt in bar:
275
+ imgs = gt['image'].to(device, non_blocking=True)
276
+ gt['depth'] = gt['depth'].to(device, non_blocking=True)
277
+ gt['ratio'] = gt['ratio'].to(device, non_blocking=True)
278
+ if 'corner_heat_map' in gt:
279
+ gt['corner_heat_map'] = gt['corner_heat_map'].to(device, non_blocking=True)
280
+ dt = model(imgs)
281
+
282
+ vis_w = config.TRAIN.VIS_WEIGHT
283
+ visualization = False # (config.LOCAL_RANK == 0 and i == show_index) or config.SAVE_EVAL
284
+
285
+ loss, batch_loss_d, epoch_loss_d = calc_criterion(criterion, gt, dt, epoch_loss_d)
286
+
287
+ if config.EVAL.POST_PROCESSING is not None:
288
+ depth = tensor2np(dt['depth'])
289
+ dt['processed_xyz'] = post_process(depth, type_name=config.EVAL.POST_PROCESSING,
290
+ need_cube=config.EVAL.FORCE_CUBE)
291
+
292
+ if config.EVAL.FORCE_CUBE and config.EVAL.NEED_CPE:
293
+ ce = calc_ce(tensor2np_d(dt), tensor2np_d(gt))
294
+ pe = calc_pe(tensor2np_d(dt), tensor2np_d(gt))
295
+
296
+ epoch_other_d['ce'].append(ce)
297
+ epoch_other_d['pe'].append(pe)
298
+
299
+ if config.EVAL.NEED_F1:
300
+ f1, precision, recall = calc_f1_score(tensor2np_d(dt), tensor2np_d(gt))
301
+ epoch_other_d['f1'].append(f1)
302
+ epoch_other_d['precision'].append(precision)
303
+ epoch_other_d['recall'].append(recall)
304
+
305
+ if config.EVAL.NEED_RMSE:
306
+ rmse, delta_1 = calc_rmse_delta_1(tensor2np_d(dt), tensor2np_d(gt))
307
+ epoch_other_d['rmse'].append(rmse)
308
+ epoch_other_d['delta_1'].append(delta_1)
309
+
310
+ visb_iou, full_iou, iou_height, pano_bds, full_iou_2ds = calc_accuracy(tensor2np_d(dt), tensor2np_d(gt),
311
+ visualization, h=vis_w // 2)
312
+ epoch_iou_d['visible_2d'].append(visb_iou[0])
313
+ epoch_iou_d['visible_3d'].append(visb_iou[1])
314
+ epoch_iou_d['full_2d'].append(full_iou[0])
315
+ epoch_iou_d['full_3d'].append(full_iou[1])
316
+ epoch_iou_d['height'].append(iou_height)
317
+
318
+ if config.LOCAL_RANK == 0 and config.SHOW_BAR:
319
+ bar.set_postfix(batch_loss_d)
320
+
321
+ global_step = start_i + i * config.WORLD_SIZE + config.LOCAL_RANK
322
+
323
+ if writer:
324
+ for key, val in batch_loss_d.items():
325
+ writer.add_scalar(f'ValBatchLoss/{key}', val, global_step)
326
+
327
+ if not visualization:
328
+ continue
329
+
330
+ gt_grad_imgs, dt_grad_imgs = show_depth_normal_grad(dt, gt, device, vis_w)
331
+
332
+ dt_heat_map_imgs = None
333
+ gt_heat_map_imgs = None
334
+ if 'corner_heat_map' in gt:
335
+ dt_heat_map_imgs, gt_heat_map_imgs = show_heat_map(dt, gt, vis_w)
336
+
337
+ if config.TRAIN.VIS_MERGE or config.SAVE_EVAL:
338
+ imgs = []
339
+ for j in range(len(pano_bds)):
340
+ # floorplan = np.concatenate([visb_iou[2][j], full_iou[2][j]], axis=-1)
341
+ floorplan = full_iou[2][j]
342
+ margin_w = int(floorplan.shape[-1] * (60/512))
343
+ floorplan = floorplan[:, :, margin_w:-margin_w]
344
+
345
+ grad_h = dt_grad_imgs[0].shape[1]
346
+ vis_merge = [
347
+ gt_grad_imgs[j],
348
+ pano_bds[j][:, grad_h:-grad_h],
349
+ dt_grad_imgs[j]
350
+ ]
351
+ if 'corner_heat_map' in gt:
352
+ vis_merge = [dt_heat_map_imgs[j], gt_heat_map_imgs[j]] + vis_merge
353
+ img = np.concatenate(vis_merge, axis=-2)
354
+
355
+ img = np.concatenate([img, ], axis=-1)
356
+ # img = gt_grad_imgs[j]
357
+ imgs.append(img)
358
+ if writer:
359
+ writer.add_images('VIS/Merge', np.array(imgs), global_step)
360
+
361
+ if config.SAVE_EVAL:
362
+ for k in range(len(imgs)):
363
+ img = imgs[k] * 255.0
364
+ save_path = os.path.join(config.CKPT.RESULT_DIR, f"{gt['id'][k]}_{full_iou_2ds[k]:.5f}.png")
365
+ Image.fromarray(img.transpose(1, 2, 0).astype(np.uint8)).save(save_path)
366
+
367
+ elif writer:
368
+ writer.add_images('IoU/Visible_Floorplan', visb_iou[2], global_step)
369
+ writer.add_images('IoU/Full_Floorplan', full_iou[2], global_step)
370
+ writer.add_images('IoU/Boundary', pano_bds, global_step)
371
+ writer.add_images('Grad/gt', gt_grad_imgs, global_step)
372
+ writer.add_images('Grad/dt', dt_grad_imgs, global_step)
373
+
374
+ if config.LOCAL_RANK != 0:
375
+ return
376
+
377
+ epoch_loss_d = dict(zip(epoch_loss_d.keys(), [np.array(epoch_loss_d[k]).mean() for k in epoch_loss_d.keys()]))
378
+ s = 'ValEpochLoss: '
379
+ for key, val in epoch_loss_d.items():
380
+ if writer:
381
+ writer.add_scalar(f'ValEpochLoss/{key}', val, epoch)
382
+ s += f" {key}={val}"
383
+ logger.info(s)
384
+
385
+ epoch_iou_d = dict(zip(epoch_iou_d.keys(), [np.array(epoch_iou_d[k]).mean() for k in epoch_iou_d.keys()]))
386
+ s = 'ValEpochIoU: '
387
+ for key, val in epoch_iou_d.items():
388
+ if writer:
389
+ writer.add_scalar(f'ValEpochIoU/{key}', val, epoch)
390
+ s += f" {key}={val}"
391
+ logger.info(s)
392
+ epoch_other_d = dict(zip(epoch_other_d.keys(),
393
+ [np.array(epoch_other_d[k]).mean() if len(epoch_other_d[k]) > 0 else 0 for k in
394
+ epoch_other_d.keys()]))
395
+
396
+ logger.info(f'other acc: {epoch_other_d}')
397
+ return epoch_iou_d, epoch_other_d
398
+
399
+
400
+ if __name__ == '__main__':
401
+ main()
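
In both train_an_epoch and val_an_epoch the TensorBoard step is computed as start_i + i * WORLD_SIZE + LOCAL_RANK, which interleaves the batches of all DDP ranks into a single monotonic counter per epoch. A tiny illustration of that interleaving (the numbers are made up, not from a real run):

# Interleaving of global_step for two hypothetical DDP ranks over 4 batches.
data_len, world_size, epoch = 4, 2, 1
start_i = data_len * epoch * world_size          # 8, as in train_an_epoch/val_an_epoch

for local_rank in range(world_size):
    steps = [start_i + i * world_size + local_rank for i in range(data_len)]
    print(f"rank {local_rank}: {steps}")
# rank 0: [8, 10, 12, 14]
# rank 1: [9, 11, 13, 15]
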
models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from models.lgt_net import LGT_Net
models/base_model.py ADDED
@@ -0,0 +1,150 @@
1
+ """
2
+ @Date: 2021/07/17
3
+ @description:
4
+ """
5
+ import os
6
+ import torch
7
+ import torch.nn as nn
8
+ import datetime
9
+
10
+
11
+ class BaseModule(nn.Module):
12
+ def __init__(self, ckpt_dir=None):
13
+ super().__init__()
14
+
15
+ self.ckpt_dir = ckpt_dir
16
+
17
+ if ckpt_dir:
18
+ if not os.path.exists(ckpt_dir):
19
+ os.makedirs(ckpt_dir)
20
+ else:
21
+ self.model_lst = [x for x in sorted(os.listdir(self.ckpt_dir)) if x.endswith('.pkl')]
22
+
23
+ self.last_model_path = None
24
+ self.best_model_path = None
25
+ self.best_accuracy = -float('inf')
26
+ self.acc_d = {}
27
+
28
+ def show_parameter_number(self, logger):
29
+ total = sum(p.numel() for p in self.parameters())
30
+ trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
31
+ logger.info('{} parameter total:{:,}, trainable:{:,}'.format(self._get_name(), total, trainable))
32
+
33
+ def load(self, device, logger, optimizer=None, best=False):
34
+ if len(self.model_lst) == 0:
35
+ logger.info('*'*50)
36
+ logger.info("Empty model folder! Using initial weights")
37
+ logger.info('*'*50)
38
+ return 0
39
+
40
+ last_model_lst = list(filter(lambda n: '_last_' in n, self.model_lst))
41
+ best_model_lst = list(filter(lambda n: '_best_' in n, self.model_lst))
42
+
43
+ if len(last_model_lst) == 0 and len(best_model_lst) == 0:
44
+ logger.info('*'*50)
45
+ ckpt_path = os.path.join(self.ckpt_dir, self.model_lst[0])
46
+ logger.info(f"Load: {ckpt_path}")
47
+ checkpoint = torch.load(ckpt_path, map_location=torch.device(device))
48
+ self.load_state_dict(checkpoint, strict=False)
49
+ logger.info('*'*50)
50
+ return 0
51
+
52
+ checkpoint = None
53
+ if len(last_model_lst) > 0:
54
+ self.last_model_path = os.path.join(self.ckpt_dir, last_model_lst[-1])
55
+ checkpoint = torch.load(self.last_model_path, map_location=torch.device(device))
56
+ self.best_accuracy = checkpoint['accuracy']
57
+ self.acc_d = checkpoint['acc_d']
58
+
59
+ if len(best_model_lst) > 0:
60
+ self.best_model_path = os.path.join(self.ckpt_dir, best_model_lst[-1])
61
+ best_checkpoint = torch.load(self.best_model_path, map_location=torch.device(device))
62
+ self.best_accuracy = best_checkpoint['accuracy']
63
+ self.acc_d = best_checkpoint['acc_d']
64
+ if best:
65
+ checkpoint = best_checkpoint
66
+
67
+ for k in self.acc_d:
68
+ if isinstance(self.acc_d[k], float):
69
+ self.acc_d[k] = {
70
+ 'acc': self.acc_d[k],
71
+ 'epoch': checkpoint['epoch']
72
+ }
73
+
74
+ if checkpoint is None:
75
+ logger.error("Invalid checkpoint")
76
+ return
77
+
78
+ self.load_state_dict(checkpoint['net'], strict=False)
79
+ if optimizer and not best: # when loading the best checkpoint, a new optimizer is used instead (e.g. switching from Adam to SGD)
80
+ logger.info('Load optimizer')
81
+ optimizer.load_state_dict(checkpoint['optimizer'])
82
+ for state in optimizer.state.values():
83
+ for k, v in state.items():
84
+ if torch.is_tensor(v):
85
+ state[k] = v.to(device)
86
+
87
+ logger.info('*'*50)
88
+ if best:
89
+ logger.info(f"Load best: {self.best_model_path}")
90
+ else:
91
+ logger.info(f"Load last: {self.last_model_path}")
92
+
93
+ logger.info(f"Best accuracy: {self.best_accuracy}")
94
+ logger.info(f"Last epoch: {checkpoint['epoch'] + 1}")
95
+ logger.info('*'*50)
96
+ return checkpoint['epoch'] + 1
97
+
98
+ def update_acc(self, acc_d, epoch, logger):
99
+ logger.info("-" * 100)
100
+ for k in acc_d:
101
+ if k not in self.acc_d.keys() or acc_d[k] > self.acc_d[k]['acc']:
102
+ self.acc_d[k] = {
103
+ 'acc': acc_d[k],
104
+ 'epoch': epoch
105
+ }
106
+ logger.info(f"Update ACC: {k} {self.acc_d[k]['acc']:.4f}({self.acc_d[k]['epoch']}-{epoch})")
107
+ logger.info("-" * 100)
108
+
109
+ def save(self, optim, epoch, accuracy, logger, replace=True, acc_d=None, config=None):
110
+ """
111
+
112
+ :param config:
113
+ :param optim:
114
+ :param epoch:
115
+ :param accuracy:
116
+ :param logger:
117
+ :param replace:
118
+ :param acc_d: other evaluation metrics, e.g. visible_2/3d, full_2/3d, rmse...
119
+ :return:
120
+ """
121
+ if acc_d:
122
+ self.update_acc(acc_d, epoch, logger)
123
+ name = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S_last_{:.4f}_{}'.format(accuracy, epoch))
124
+ name = f"model_{name}.pkl"
125
+ checkpoint = {
126
+ 'net': self.state_dict(),
127
+ 'optimizer': optim.state_dict(),
128
+ 'epoch': epoch,
129
+ 'accuracy': accuracy,
130
+ 'acc_d': acc_d
131
+ }
132
+ # FIXME: remove the always-true condition
133
+ if (True or config.MODEL.SAVE_LAST) and epoch % config.TRAIN.SAVE_FREQ == 0:
134
+ if replace and self.last_model_path and os.path.exists(self.last_model_path):
135
+ os.remove(self.last_model_path)
136
+ self.last_model_path = os.path.join(self.ckpt_dir, name)
137
+ torch.save(checkpoint, self.last_model_path)
138
+ logger.info(f"Saved last model: {self.last_model_path}")
139
+
140
+ if accuracy > self.best_accuracy:
141
+ self.best_accuracy = accuracy
142
+ # FIXME: remove the always-true condition
143
+ if True or config.MODEL.SAVE_BEST:
144
+ if replace and self.best_model_path and os.path.exists(self.best_model_path):
145
+ os.remove(self.best_model_path)
146
+ self.best_model_path = os.path.join(self.ckpt_dir, name.replace('last', 'best'))
147
+ torch.save(checkpoint, self.best_model_path)
148
+ logger.info("#" * 100)
149
+ logger.info(f"Saved best model: {self.best_model_path}")
150
+ logger.info("#" * 100)
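
BaseModule picks checkpoints purely by filename: save() prefixes the name with a zero-padded timestamp, so a lexicographic sort is also chronological, and load() takes the last element of the '_last_' / '_best_' sublists. A short sketch of that convention in isolation (the file names below are fabricated for illustration):

import datetime

accuracy, epoch = 0.8312, 57
stamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S_last_{:.4f}_{}'.format(accuracy, epoch))
name = f"model_{stamp}.pkl"                      # e.g. model_2021-08-12-10-00-00_last_0.8312_57.pkl

model_lst = sorted([
    'model_2021-07-30-09-00-00_last_0.8000_40.pkl',   # fabricated older checkpoints
    'model_2021-07-31-09-00-00_best_0.8200_50.pkl',
    name,
])
last_model = [n for n in model_lst if '_last_' in n][-1]   # newest "last" checkpoint
best_model = [n for n in model_lst if '_best_' in n][-1]   # newest "best" checkpoint
print(last_model, best_model)
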
models/build.py ADDED
@@ -0,0 +1,81 @@
1
+ """
2
+ @Date: 2021/07/18
3
+ @description:
4
+ """
5
+ import os
6
+ import models
7
+ import torch.distributed as dist
8
+ import torch
9
+
10
+ from torch.nn import init
11
+ from torch.optim import lr_scheduler
12
+ from utils.time_watch import TimeWatch
13
+ from models.other.optimizer import build_optimizer
14
+ from models.other.criterion import build_criterion
15
+
16
+
17
+ def build_model(config, logger):
18
+ name = config.MODEL.NAME
19
+ w = TimeWatch(f"Build model: {name}", logger)
20
+
21
+ ddp = config.WORLD_SIZE > 1
22
+ if ddp:
23
+ logger.info("Use DDP")
24
+ dist.init_process_group("nccl", init_method='tcp://127.0.0.1:23456', rank=config.LOCAL_RANK,
25
+ world_size=config.WORLD_SIZE)
26
+
27
+ device = config.TRAIN.DEVICE
28
+ logger.info(f"Creating model: {name} to device:{device}, args:{config.MODEL.ARGS[0]}")
29
+
30
+ net = getattr(models, name)
31
+ ckpt_dir = os.path.abspath(os.path.join(config.CKPT.DIR, os.pardir)) if config.DEBUG else config.CKPT.DIR
32
+ if len(config.MODEL.ARGS) != 0:
33
+ model = net(ckpt_dir=ckpt_dir, **config.MODEL.ARGS[0])
34
+ else:
35
+ model = net(ckpt_dir=ckpt_dir)
36
+ logger.info(f'model dropout: {model.dropout_d}')
37
+ model = model.to(device)
38
+ optimizer = None
39
+ scheduler = None
40
+
41
+ if config.MODE == 'train':
42
+ optimizer = build_optimizer(config, model, logger)
43
+
44
+ config.defrost()
45
+ config.TRAIN.START_EPOCH = model.load(device, logger, optimizer, best=config.MODE != 'train' or not config.TRAIN.RESUME_LAST)
46
+ config.freeze()
47
+
48
+ if config.MODE == 'train' and len(config.MODEL.FINE_TUNE) > 0:
49
+ for param in model.parameters():
50
+ param.requires_grad = False
51
+ for layer in config.MODEL.FINE_TUNE:
52
+ logger.info(f'Fine-tune: {layer}')
53
+ getattr(model, layer).requires_grad_(requires_grad=True)
54
+ getattr(model, layer).reset_parameters()
55
+
56
+ model.show_parameter_number(logger)
57
+
58
+ if config.MODE == 'train':
59
+ if len(config.TRAIN.LR_SCHEDULER.NAME) > 0:
60
+ if 'last_epoch' not in config.TRAIN.LR_SCHEDULER.ARGS[0].keys():
61
+ config.TRAIN.LR_SCHEDULER.ARGS[0]['last_epoch'] = config.TRAIN.START_EPOCH - 1
62
+
63
+ scheduler = getattr(lr_scheduler, config.TRAIN.LR_SCHEDULER.NAME)(optimizer=optimizer,
64
+ **config.TRAIN.LR_SCHEDULER.ARGS[0])
65
+ logger.info(f"Use scheduler: name:{config.TRAIN.LR_SCHEDULER.NAME} args: {config.TRAIN.LR_SCHEDULER.ARGS[0]}")
66
+ logger.info(f"Current scheduler last lr: {scheduler.get_last_lr()}")
67
+ else:
68
+ scheduler = None
69
+
70
+ if config.AMP_OPT_LEVEL != "O0" and 'cuda' in device:
71
+ import apex
72
+ logger.info(f"use amp:{config.AMP_OPT_LEVEL}")
73
+ model, optimizer = apex.amp.initialize(model, optimizer, opt_level=config.AMP_OPT_LEVEL, verbosity=0)
74
+ if ddp:
75
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.TRAIN.DEVICE],
76
+ broadcast_buffers=True) # use rank:0 bn
77
+
78
+ criterion = build_criterion(config, logger)
79
+ if optimizer is not None:
80
+ logger.info(f"Final lr: {optimizer.param_groups[0]['lr']}")
81
+ return model, optimizer, criterion, scheduler
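
When config.MODEL.FINE_TUNE is non-empty, build_model first freezes every parameter and then re-enables (and re-initializes) only the named sub-modules. The same pattern on a toy module, with hypothetical layer names standing in for the real config values:

import torch.nn as nn

model = nn.ModuleDict({'feature_extractor': nn.Linear(8, 8),
                       'linear_depth_output': nn.Linear(8, 1)})
fine_tune = ['linear_depth_output']              # stands in for config.MODEL.FINE_TUNE

for p in model.parameters():
    p.requires_grad = False                      # freeze everything
for layer in fine_tune:
    getattr(model, layer).requires_grad_(requires_grad=True)   # unfreeze the named head
    getattr(model, layer).reset_parameters()                   # re-initialize it, as build_model does

print([n for n, p in model.named_parameters() if p.requires_grad])
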
models/lgt_net.py ADDED
@@ -0,0 +1,213 @@
1
+ import torch.nn
2
+ import torch
3
+ import torch.nn as nn
4
+ import models.modules as modules
5
+ import numpy as np
6
+
7
+ from models.base_model import BaseModule
8
+ from models.modules.horizon_net_feature_extractor import HorizonNetFeatureExtractor
9
+ from models.modules.patch_feature_extractor import PatchFeatureExtractor
10
+ from utils.conversion import uv2depth, get_u, lonlat2depth, get_lon, lonlat2uv
11
+ from utils.height import calc_ceil_ratio
12
+ from utils.misc import tensor2np
13
+
14
+
15
+ class LGT_Net(BaseModule):
16
+ def __init__(self, ckpt_dir=None, backbone='resnet50', dropout=0.0, output_name='LGT',
17
+ decoder_name='Transformer', win_size=8, depth=6,
18
+ ape=None, rpe=None, corner_heat_map=False, rpe_pos=1):
19
+ super().__init__(ckpt_dir)
20
+
21
+ self.patch_num = 256
22
+ self.patch_dim = 1024
23
+ self.decoder_name = decoder_name
24
+ self.output_name = output_name
25
+ self.corner_heat_map = corner_heat_map
26
+ self.dropout_d = dropout
27
+
28
+ if backbone == 'patch':
29
+ self.feature_extractor = PatchFeatureExtractor(patch_num=self.patch_num, input_shape=[3, 512, 1024])
30
+ else:
31
+ # feature extractor
32
+ self.feature_extractor = HorizonNetFeatureExtractor(backbone)
33
+
34
+ if 'Transformer' in self.decoder_name:
35
+ # transformer encoder
36
+ transformer_dim = self.patch_dim
37
+ transformer_layers = depth
38
+ transformer_heads = 8
39
+ transformer_head_dim = transformer_dim // transformer_heads
40
+ transformer_ff_dim = 2048
41
+ rpe = None if rpe == 'None' else rpe
42
+ self.transformer = getattr(modules, decoder_name)(dim=transformer_dim, depth=transformer_layers,
43
+ heads=transformer_heads, dim_head=transformer_head_dim,
44
+ mlp_dim=transformer_ff_dim, win_size=win_size,
45
+ dropout=self.dropout_d, patch_num=self.patch_num,
46
+ ape=ape, rpe=rpe, rpe_pos=rpe_pos)
47
+ elif self.decoder_name == 'LSTM':
48
+ self.bi_rnn = nn.LSTM(input_size=self.feature_extractor.c_last,
49
+ hidden_size=self.patch_dim // 2,
50
+ num_layers=2,
51
+ dropout=self.dropout_d,
52
+ batch_first=False,
53
+ bidirectional=True)
54
+ self.drop_out = nn.Dropout(self.dropout_d)
55
+ else:
56
+ raise NotImplementedError("Only support *Transformer and LSTM")
57
+
58
+ if self.output_name == 'LGT':
59
+ # omnidirectional-geometry aware output
60
+ self.linear_depth_output = nn.Linear(in_features=self.patch_dim, out_features=1)
61
+ self.linear_ratio = nn.Linear(in_features=self.patch_dim, out_features=1)
62
+ self.linear_ratio_output = nn.Linear(in_features=self.patch_num, out_features=1)
63
+ elif self.output_name == 'LED' or self.output_name == 'Horizon':
64
+ # horizon-depth or latitude output
65
+ self.linear = nn.Linear(in_features=self.patch_dim, out_features=2)
66
+ else:
67
+ raise NotImplementedError("Unknown output")
68
+
69
+ if self.corner_heat_map:
70
+ # corners heat map output
71
+ self.linear_corner_heat_map_output = nn.Linear(in_features=self.patch_dim, out_features=1)
72
+
73
+ self.name = f"{self.decoder_name}_{self.output_name}_Net"
74
+
75
+ def lgt_output(self, x):
76
+ """
77
+ :param x: [ b, 256(patch_num), 1024(d)]
78
+ :return: {
79
+ 'depth': [b, 256(patch_num & d)]
80
+ 'ratio': [b, 1(d)]
81
+ }
82
+ """
83
+ depth = self.linear_depth_output(x) # [b, 256(patch_num), 1(d)]
84
+ depth = depth.view(-1, self.patch_num) # [b, 256(patch_num & d)]
85
+
86
+ # ratio represent room height
87
+ ratio = self.linear_ratio(x) # [b, 256(patch_num), 1(d)]
88
+ ratio = ratio.view(-1, self.patch_num) # [b, 256(patch_num & d)]
89
+ ratio = self.linear_ratio_output(ratio) # [b, 1(d)]
90
+ output = {
91
+ 'depth': depth,
92
+ 'ratio': ratio
93
+ }
94
+ return output
95
+
96
+ def led_output(self, x):
97
+ """
98
+ :param x: [ b, 256(patch_num), 1024(d)]
99
+ :return: {
100
+ 'depth': [b, 256(patch_num)]
101
+ 'ceil_depth': [b, 256(patch_num)]
102
+ 'ratio': [b, 1(d)]
103
+ }
104
+ """
105
+ bon = self.linear(x) # [b, 256(patch_num), 2(d)]
106
+ bon = bon.permute(0, 2, 1) # [b, 2(d), 256(patch_num)]
107
+ bon = torch.sigmoid(bon)
108
+
109
+ ceil_v = bon[:, 0, :] * -0.5 + 0.5 # [b, 256(patch_num)]
110
+ floor_v = bon[:, 1, :] * 0.5 + 0.5 # [b, 256(patch_num)]
111
+ u = get_u(w=self.patch_num, is_np=False, b=ceil_v.shape[0]).to(ceil_v.device)
112
+ ceil_boundary = torch.stack((u, ceil_v), axis=-1) # [b, 256(patch_num), 2]
113
+ floor_boundary = torch.stack((u, floor_v), axis=-1) # [b, 256(patch_num), 2]
114
+ output = {
115
+ 'depth': uv2depth(floor_boundary), # [b, 256(patch_num)]
116
+ 'ceil_depth': uv2depth(ceil_boundary), # [b, 256(patch_num)]
117
+ }
118
+ # print(output['depth'].mean())
119
+ if not self.training:
120
+ # [b, 1(d)]
121
+ output['ratio'] = calc_ceil_ratio([tensor2np(ceil_boundary), tensor2np(floor_boundary)], mode='lsq').reshape(-1, 1)
122
+ return output
123
+
124
+ def horizon_output(self, x):
125
+ """
126
+ :param x: [ b, 256(patch_num), 1024(d)]
127
+ :return: {
128
+ 'floor_boundary': [b, 256(patch_num)]
129
+ 'ceil_boundary': [b, 256(patch_num)]
130
+ }
131
+ """
132
+ bon = self.linear(x) # [b, 256(patch_num), 2(d)]
133
+ bon = bon.permute(0, 2, 1) # [b, 2(d), 256(patch_num)]
134
+
135
+ output = {
136
+ 'boundary': bon
137
+ }
138
+ if not self.training:
139
+ lon = get_lon(w=self.patch_num, is_np=False, b=bon.shape[0]).to(bon.device)
140
+ floor_lat = torch.clip(bon[:, 0, :], 1e-4, np.pi / 2)
141
+ ceil_lat = torch.clip(bon[:, 1, :], -np.pi / 2, -1e-4)
142
+ floor_lonlat = torch.stack((lon, floor_lat), axis=-1) # [b, 256(patch_num), 2]
143
+ ceil_lonlat = torch.stack((lon, ceil_lat), axis=-1) # [b, 256(patch_num), 2]
144
+ output['depth'] = lonlat2depth(floor_lonlat)
145
+ output['ratio'] = calc_ceil_ratio([tensor2np(lonlat2uv(ceil_lonlat)),
146
+ tensor2np(lonlat2uv(floor_lonlat))], mode='mean').reshape(-1, 1)
147
+ return output
148
+
149
+ def forward(self, x):
150
+ """
151
+ :param x: [b, 3(d), 512(h), 1024(w)]
152
+ :return: {
153
+ 'depth': [b, 256(patch_num & d)]
154
+ 'ratio': [b, 1(d)]
155
+ }
156
+ """
157
+
158
+ # feature extractor
159
+ x = self.feature_extractor(x) # [b 1024(d) 256(w)]
160
+
161
+ if 'Transformer' in self.decoder_name:
162
+ # transformer decoder
163
+ x = x.permute(0, 2, 1) # [b 256(patch_num) 1024(d)]
164
+ x = self.transformer(x) # [b 256(patch_num) 1024(d)]
165
+ elif self.decoder_name == 'LSTM':
166
+ # lstm decoder
167
+ x = x.permute(2, 0, 1) # [256(patch_num), b, 1024(d)]
168
+ self.bi_rnn.flatten_parameters()
169
+ x, _ = self.bi_rnn(x) # [256(patch_num & seq_len), b, 1024(d)]
170
+ x = x.permute(1, 0, 2) # [b, 256(patch_num), 1024(d)]
171
+ x = self.drop_out(x)
172
+
173
+ output = None
174
+ if self.output_name == 'LGT':
175
+ # lgt output
176
+ output = self.lgt_output(x)
177
+
178
+ elif self.output_name == 'LED':
179
+ # led output
180
+ output = self.led_output(x)
181
+
182
+ elif self.output_name == 'Horizon':
183
+ # horizon output
184
+ output = self.horizon_output(x)
185
+
186
+ if self.corner_heat_map:
187
+ corner_heat_map = self.linear_corner_heat_map_output(x) # [b, 256(patch_num), 1]
188
+ corner_heat_map = corner_heat_map.view(-1, self.patch_num)
189
+ corner_heat_map = torch.sigmoid(corner_heat_map)
190
+ output['corner_heat_map'] = corner_heat_map
191
+
192
+ return output
193
+
194
+
195
+ if __name__ == '__main__':
196
+ from PIL import Image
197
+ import numpy as np
198
+ from models.other.init_env import init_env
199
+
200
+ init_env(0, deterministic=True)
201
+
202
+ net = LGT_Net()
203
+
204
+ total = sum(p.numel() for p in net.parameters())
205
+ trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
206
+ print('parameter total:{:,}, trainable:{:,}'.format(total, trainable))
207
+
208
+ img = np.array(Image.open("../src/demo.png")).transpose((2, 0, 1))
209
+ input = torch.Tensor([img]) # 1 3 512 1024
210
+ output = net(input)
211
+
212
+ print(output['depth'].shape) # 1 256
213
+ print(output['ratio'].shape) # 1 1
models/modules/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ """
2
+ @Date: 2021/09/01
3
+ @description:
4
+ """
5
+
6
+ from models.modules.swin_transformer import Swin_Transformer
7
+ from models.modules.swg_transformer import SWG_Transformer
8
+ from models.modules.transformer import Transformer
models/modules/conv_transformer.py ADDED
@@ -0,0 +1,128 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from torch import nn, einsum
5
+ from einops import rearrange
6
+
7
+
8
+ class PreNorm(nn.Module):
9
+ def __init__(self, dim, fn):
10
+ super().__init__()
11
+ self.norm = nn.LayerNorm(dim)
12
+ self.fn = fn
13
+
14
+ def forward(self, x, **kwargs):
15
+ return self.fn(self.norm(x), **kwargs)
16
+
17
+
18
+ class GELU(nn.Module):
19
+ def forward(self, input):
20
+ return F.gelu(input)
21
+
22
+
23
+ class Attend(nn.Module):
24
+
25
+ def __init__(self, dim=None):
26
+ super().__init__()
27
+ self.dim = dim
28
+
29
+ def forward(self, input):
30
+ return F.softmax(input, dim=self.dim, dtype=input.dtype)
31
+
32
+
33
+ class FeedForward(nn.Module):
34
+ def __init__(self, dim, hidden_dim, dropout=0.):
35
+ super().__init__()
36
+ self.net = nn.Sequential(
37
+ nn.Linear(dim, hidden_dim),
38
+ GELU(),
39
+ nn.Dropout(dropout),
40
+ nn.Linear(hidden_dim, dim),
41
+ nn.Dropout(dropout)
42
+ )
43
+
44
+ def forward(self, x):
45
+ return self.net(x)
46
+
47
+
48
+ class Attention(nn.Module):
49
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
50
+ super().__init__()
51
+ inner_dim = dim_head * heads
52
+ project_out = not (heads == 1 and dim_head == dim)
53
+
54
+ self.heads = heads
55
+ self.scale = dim_head ** -0.5
56
+
57
+ self.attend = Attend(dim=-1)
58
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
59
+
60
+ self.to_out = nn.Sequential(
61
+ nn.Linear(inner_dim, dim),
62
+ nn.Dropout(dropout)
63
+ ) if project_out else nn.Identity()
64
+
65
+ def forward(self, x):
66
+ b, n, _, h = *x.shape, self.heads
67
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
68
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)
69
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
70
+ attn = self.attend(dots)
71
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
72
+ out = rearrange(out, 'b h n d -> b n (h d)')
73
+ return self.to_out(out)
74
+
75
+
76
+ class Conv(nn.Module):
77
+ def __init__(self, dim, dropout=0.):
78
+ super().__init__()
79
+ self.dim = dim
80
+ self.net = nn.Sequential(
81
+ nn.Conv1d(dim, dim, kernel_size=3, stride=1, padding=0),
82
+ nn.Dropout(dropout)
83
+ )
84
+
85
+ def forward(self, x):
86
+ x = x.transpose(1, 2)
87
+ x = torch.cat([x[..., -1:], x, x[..., :1]], dim=-1)
88
+ x = self.net(x)
89
+ return x.transpose(1, 2)
90
+
91
+
92
+ class ConvTransformer(nn.Module):
93
+ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
94
+ super().__init__()
95
+ self.layers = nn.ModuleList([])
96
+ for _ in range(depth):
97
+ self.layers.append(nn.ModuleList([
98
+ PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
99
+ PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)),
100
+ PreNorm(dim, Conv(dim, dropout=dropout))
101
+ ]))
102
+
103
+ def forward(self, x):
104
+ for attn, ff, cov in self.layers:
105
+ x = attn(x) + x
106
+ x = ff(x) + x
107
+ x = cov(x) + x
108
+ return x
109
+
110
+
111
+ if __name__ == '__main__':
112
+ token_dim = 1024
113
+ toke_len = 256
114
+
115
+ transformer = ConvTransformer(dim=token_dim,
116
+ depth=6,
117
+ heads=16,
118
+ dim_head=64,
119
+ mlp_dim=2048,
120
+ dropout=0.1)
121
+
122
+ total = sum(p.numel() for p in transformer.parameters())
123
+ trainable = sum(p.numel() for p in transformer.parameters() if p.requires_grad)
124
+ print('parameter total:{:,}, trainable:{:,}'.format(total, trainable))
125
+
126
+ input = torch.randn(1, toke_len, token_dim)
127
+ output = transformer(input)
128
+ print(output.shape)
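
The Conv block above wraps the token sequence circularly (last token prepended, first appended) before a kernel-3 Conv1d with no padding, so the panorama columns keep their length and the left/right image seam stays connected. A quick shape check of that trick on its own, with small made-up dimensions:

import torch
import torch.nn as nn

dim, n = 16, 10
x = torch.randn(1, n, dim)                           # [b, tokens, dim]
conv = nn.Conv1d(dim, dim, kernel_size=3, stride=1, padding=0)

t = x.transpose(1, 2)                                # [b, dim, tokens]
t = torch.cat([t[..., -1:], t, t[..., :1]], dim=-1)  # circular pad to length n + 2
out = conv(t).transpose(1, 2)                        # valid conv brings it back to n
print(out.shape)                                     # torch.Size([1, 10, 16])
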
models/modules/horizon_net_feature_extractor.py ADDED
@@ -0,0 +1,267 @@
1
+ """
2
+ @author:
3
+ @Date: 2021/07/17
4
+ @description: Use the feature extractor proposed by HorizonNet
5
+ """
6
+
7
+ import numpy as np
8
+ import math
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torchvision.models as models
13
+ import functools
14
+ from models.base_model import BaseModule
15
+
16
+ ENCODER_RESNET = [
17
+ 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
18
+ 'resnext50_32x4d', 'resnext101_32x8d'
19
+ ]
20
+ ENCODER_DENSENET = [
21
+ 'densenet121', 'densenet169', 'densenet161', 'densenet201'
22
+ ]
23
+
24
+
25
+ def lr_pad(x, padding=1):
26
+ ''' Pad left/right-most to each other instead of zero padding '''
27
+ return torch.cat([x[..., -padding:], x, x[..., :padding]], dim=3)
28
+
29
+
30
+ class LR_PAD(nn.Module):
31
+ ''' Pad left/right-most to each other instead of zero padding '''
32
+
33
+ def __init__(self, padding=1):
34
+ super(LR_PAD, self).__init__()
35
+ self.padding = padding
36
+
37
+ def forward(self, x):
38
+ return lr_pad(x, self.padding)
39
+
40
+
41
+ def wrap_lr_pad(net):
42
+ for name, m in net.named_modules():
43
+ if not isinstance(m, nn.Conv2d):
44
+ continue
45
+ if m.padding[1] == 0:
46
+ continue
47
+ w_pad = int(m.padding[1])
48
+ m.padding = (m.padding[0], 0) # set width padding to 0; LR_PAD plus the now-valid width convolution keeps the feature width unchanged
49
+ names = name.split('.')
50
+
51
+ root = functools.reduce(lambda o, i: getattr(o, i), [net] + names[:-1])
52
+ setattr(
53
+ root, names[-1],
54
+ nn.Sequential(LR_PAD(w_pad), m)
55
+ )
56
+
57
+
58
+ '''
59
+ Encoder
60
+ '''
61
+
62
+
63
+ class Resnet(nn.Module):
64
+ def __init__(self, backbone='resnet50', pretrained=True):
65
+ super(Resnet, self).__init__()
66
+ assert backbone in ENCODER_RESNET
67
+ self.encoder = getattr(models, backbone)(pretrained=pretrained)
68
+ del self.encoder.fc, self.encoder.avgpool
69
+
70
+ def forward(self, x):
71
+ features = []
72
+ x = self.encoder.conv1(x)
73
+ x = self.encoder.bn1(x)
74
+ x = self.encoder.relu(x)
75
+ x = self.encoder.maxpool(x)
76
+
77
+ x = self.encoder.layer1(x)
78
+ features.append(x) # 1/4
79
+ x = self.encoder.layer2(x)
80
+ features.append(x) # 1/8
81
+ x = self.encoder.layer3(x)
82
+ features.append(x) # 1/16
83
+ x = self.encoder.layer4(x)
84
+ features.append(x) # 1/32
85
+ return features
86
+
87
+ def list_blocks(self):
88
+ lst = [m for m in self.encoder.children()]
89
+ block0 = lst[:4]
90
+ block1 = lst[4:5]
91
+ block2 = lst[5:6]
92
+ block3 = lst[6:7]
93
+ block4 = lst[7:8]
94
+ return block0, block1, block2, block3, block4
95
+
96
+
97
+ class Densenet(nn.Module):
98
+ def __init__(self, backbone='densenet169', pretrained=True):
99
+ super(Densenet, self).__init__()
100
+ assert backbone in ENCODER_DENSENET
101
+ self.encoder = getattr(models, backbone)(pretrained=pretrained)
102
+ self.final_relu = nn.ReLU(inplace=True)
103
+ del self.encoder.classifier
104
+
105
+ def forward(self, x):
106
+ lst = []
107
+ for m in self.encoder.features.children():
108
+ x = m(x)
109
+ lst.append(x)
110
+ features = [lst[4], lst[6], lst[8], self.final_relu(lst[11])]
111
+ return features
112
+
113
+ def list_blocks(self):
114
+ lst = [m for m in self.encoder.features.children()]
115
+ block0 = lst[:4]
116
+ block1 = lst[4:6]
117
+ block2 = lst[6:8]
118
+ block3 = lst[8:10]
119
+ block4 = lst[10:]
120
+ return block0, block1, block2, block3, block4
121
+
122
+
123
+ '''
124
+ Decoder
125
+ '''
126
+
127
+
128
+ class ConvCompressH(nn.Module):
129
+ ''' Reduce feature height by factor of two '''
130
+
131
+ def __init__(self, in_c, out_c, ks=3):
132
+ super(ConvCompressH, self).__init__()
133
+ assert ks % 2 == 1
134
+ self.layers = nn.Sequential(
135
+ nn.Conv2d(in_c, out_c, kernel_size=ks, stride=(2, 1), padding=ks // 2),
136
+ nn.BatchNorm2d(out_c),
137
+ nn.ReLU(inplace=True),
138
+ )
139
+
140
+ def forward(self, x):
141
+ return self.layers(x)
142
+
143
+
144
+ class GlobalHeightConv(nn.Module):
145
+ def __init__(self, in_c, out_c):
146
+ super(GlobalHeightConv, self).__init__()
147
+ self.layer = nn.Sequential(
148
+ ConvCompressH(in_c, in_c // 2),
149
+ ConvCompressH(in_c // 2, in_c // 2),
150
+ ConvCompressH(in_c // 2, in_c // 4),
151
+ ConvCompressH(in_c // 4, out_c),
152
+ )
153
+
154
+ def forward(self, x, out_w):
155
+ x = self.layer(x)
156
+
157
+ factor = out_w // x.shape[3]
158
+ x = torch.cat([x[..., -1:], x, x[..., :1]], 3) # pad left and right first (wrap mode), then interpolate
159
+ d_type = x.dtype
160
+ x = F.interpolate(x, size=(x.shape[2], out_w + 2 * factor), mode='bilinear', align_corners=False)
161
+ # if x.dtype != d_type:
162
+ # x = x.type(d_type)
163
+ x = x[..., factor:-factor]
164
+ return x
165
+
166
+
167
+ class GlobalHeightStage(nn.Module):
168
+ def __init__(self, c1, c2, c3, c4, out_scale=8):
169
+ ''' Process 4 blocks from encoder to single multiscale features '''
170
+ super(GlobalHeightStage, self).__init__()
171
+ self.cs = c1, c2, c3, c4
172
+ self.out_scale = out_scale
173
+ self.ghc_lst = nn.ModuleList([
174
+ GlobalHeightConv(c1, c1 // out_scale),
175
+ GlobalHeightConv(c2, c2 // out_scale),
176
+ GlobalHeightConv(c3, c3 // out_scale),
177
+ GlobalHeightConv(c4, c4 // out_scale),
178
+ ])
179
+
180
+ def forward(self, conv_list, out_w):
181
+ assert len(conv_list) == 4
182
+ bs = conv_list[0].shape[0]
183
+ feature = torch.cat([
184
+ f(x, out_w).reshape(bs, -1, out_w)
185
+ for f, x, out_c in zip(self.ghc_lst, conv_list, self.cs)
186
+ ], dim=1)
187
+ # conv_list:
188
+ # 0 [b, 256(d), 128(h), 256(w)] ->(4*{conv3*3 step2*1} : d/8 h/16)-> [b 32(d) 8(h) 256(w)]
189
+ # 1 [b, 512(d), 64(h), 128(w)] ->(4*{conv3*3 step2*1} : d/8 h/16)-> [b 64(d) 4(h) 128(w)]
190
+ # 2 [b, 1024(d), 32(h), 64(w)] ->(4*{conv3*3 step2*1} : d/8 h/16)-> [b 128(d) 2(h) 64(w)]
191
+ # 3 [b, 2048(d), 16(h), 32(w)] ->(4*{conv3*3 step2*1} : d/8 h/16)-> [b 256(d) 1(h) 32(w)]
192
+ # 0 ->(upsample to w=256)-> [b 32(d) 8(h) 256(w)] ->(reshape to h=1)-> [b 256(d) 1(h) 256(w)]
193
+ # 1 ->(upsample to w=256)-> [b 64(d) 4(h) 256(w)] ->(reshape to h=1)-> [b 256(d) 1(h) 256(w)]
194
+ # 2 ->(upsample to w=256)-> [b 128(d) 2(h) 256(w)] ->(reshape to h=1)-> [b 256(d) 1(h) 256(w)]
195
+ # 3 ->(upsample to w=256)-> [b 256(d) 1(h) 256(w)] ->(reshape to h=1)-> [b 256(d) 1(h) 256(w)]
196
+ # 0 --\
197
+ # 1 -- \
198
+ # ---- cat [b 1024(d) 1(h) 256(w)]
199
+ # 2 -- /
200
+ # 3 --/
201
+ return feature # [b 1024(d) 256(w)]
202
+
203
+
204
+ class HorizonNetFeatureExtractor(nn.Module):
205
+ x_mean = torch.FloatTensor(np.array([0.485, 0.456, 0.406])[None, :, None, None])
206
+ x_std = torch.FloatTensor(np.array([0.229, 0.224, 0.225])[None, :, None, None])
207
+
208
+ def __init__(self, backbone='resnet50'):
209
+ super(HorizonNetFeatureExtractor, self).__init__()
210
+ self.out_scale = 8
211
+ self.step_cols = 4
212
+
213
+ # Encoder
214
+ if backbone.startswith('res'):
215
+ self.feature_extractor = Resnet(backbone, pretrained=True)
216
+ elif backbone.startswith('dense'):
217
+ self.feature_extractor = Densenet(backbone, pretrained=True)
218
+ else:
219
+ raise NotImplementedError()
220
+
221
+ # Infer the number of channels from each block of the encoder
222
+ with torch.no_grad():
223
+ dummy = torch.zeros(1, 3, 512, 1024)
224
+ c1, c2, c3, c4 = [b.shape[1] for b in self.feature_extractor(dummy)]
225
+ self.c_last = (c1 * 8 + c2 * 4 + c3 * 2 + c4 * 1) // self.out_scale
226
+
227
+ # Convert features from 4 blocks of the encoder into B x C x 1 x W'
228
+ self.reduce_height_module = GlobalHeightStage(c1, c2, c3, c4, self.out_scale)
229
+ self.x_mean.requires_grad = False
230
+ self.x_std.requires_grad = False
231
+ wrap_lr_pad(self)
232
+
233
+ def _prepare_x(self, x):
234
+ x = x.clone()
235
+ if self.x_mean.device != x.device:
236
+ self.x_mean = self.x_mean.to(x.device)
237
+ self.x_std = self.x_std.to(x.device)
238
+ x[:, :3] = (x[:, :3] - self.x_mean) / self.x_std
239
+
240
+ return x
241
+
242
+ def forward(self, x):
243
+ # x [b 3 512 1024]
244
+ x = self._prepare_x(x) # [b 3 512 1024]
245
+ conv_list = self.feature_extractor(x)
246
+ # conv_list:
247
+ # 0 [b, 256(d), 128(h), 256(w)]
248
+ # 1 [b, 512(d), 64(h), 128(w)]
249
+ # 2 [b, 1024(d), 32(h), 64(w)]
250
+ # 3 [b, 2048(d), 16(h), 32(w)]
251
+ x = self.reduce_height_module(conv_list, x.shape[3] // self.step_cols) # [b 1024(d) 1(h) 256(w)]
252
+ # After reduce_height_module, h becomes 1, the information is compressed to d,
253
+ # and w contains different resolutions
254
+ # 0 [b, 256(d), 128(h), 256(w)] -> [b, 256/8(d) * 128/16(h') = 256(d), 1(h) 256(w)]
255
+ # 1 [b, 512(d), 64(h), 128(w)] -> [b, 512/8(d) * 64/16(h') = 256(d), 1(h) 256(w)]
256
+ # 2 [b, 1024(d), 32(h), 64(w)] -> [b, 1024/8(d) * 32/16(h') = 256(d), 1(h) 256(w)]
257
+ # 3 [b, 2048(d), 16(h), 32(w)] -> [b, 2048/8(d) * 16/16(h') = 256(d), 1(h) 256(w)]
258
+ return x # [b 1024(d) 1(h) 256(w)]
259
+
260
+
261
+ if __name__ == '__main__':
262
+ from PIL import Image
263
+ extractor = HorizonNetFeatureExtractor()
264
+ img = np.array(Image.open("../../src/demo.png")).transpose((2, 0, 1))
265
+ input = torch.Tensor([img]) # 1 3 512 1024
266
+ feature = extractor(input)
267
+ print(feature.shape) # 1, 1024, 256 | 1024 = (out_c_0*h_0 +... + out_c_3*h_3) = 256 * 4
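
wrap_lr_pad rewrites every width-padded Conv2d as LR_PAD followed by the same convolution with zero width padding, which amounts to circular padding along the panorama width. A small check that lr_pad matches PyTorch's built-in circular padding and preserves the output width (toy sizes, not the real feature maps):

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 16)
wrapped = torch.cat([x[..., -1:], x, x[..., :1]], dim=3)    # lr_pad(x, 1) from above
circular = F.pad(x, (1, 1, 0, 0), mode='circular')           # built-in circular width padding
print(torch.equal(wrapped, circular))                        # True

conv = nn.Conv2d(3, 4, kernel_size=3, padding=(1, 0))        # what wrap_lr_pad leaves behind
print(conv(wrapped).shape)                                   # torch.Size([1, 4, 8, 16]): width preserved
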
models/modules/patch_feature_extractor.py ADDED
@@ -0,0 +1,57 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ from einops.layers.torch import Rearrange
5
+
6
+
7
+ class PatchFeatureExtractor(nn.Module):
8
+ x_mean = torch.FloatTensor(np.array([0.485, 0.456, 0.406])[None, :, None, None])
9
+ x_std = torch.FloatTensor(np.array([0.229, 0.224, 0.225])[None, :, None, None])
10
+
11
+ def __init__(self, patch_num=256, input_shape=None):
12
+ super(PatchFeatureExtractor, self).__init__()
13
+
14
+ if input_shape is None:
15
+ input_shape = [3, 512, 1024]
16
+ self.patch_dim = 1024
17
+ self.patch_num = patch_num
18
+
19
+ img_channel = input_shape[0]
20
+ img_h = input_shape[1]
21
+ img_w = input_shape[2]
22
+
23
+ p_h, p_w = img_h, img_w // self.patch_num
24
+ p_dim = p_h * p_w * img_channel
25
+
26
+ self.patch_embedding = nn.Sequential(
27
+ Rearrange('b c h (p_n p_w) -> b p_n (h p_w c)', p_w=p_w),
28
+ nn.Linear(p_dim, self.patch_dim)
29
+ )
30
+
31
+ self.x_mean.requires_grad = False
32
+ self.x_std.requires_grad = False
33
+
34
+ def _prepare_x(self, x):
35
+ x = x.clone()
36
+ if self.x_mean.device != x.device:
37
+ self.x_mean = self.x_mean.to(x.device)
38
+ self.x_std = self.x_std.to(x.device)
39
+ x[:, :3] = (x[:, :3] - self.x_mean) / self.x_std
40
+
41
+ return x
42
+
43
+ def forward(self, x):
44
+ # x [b 3 512 1024]
45
+ x = self._prepare_x(x) # [b 3 512 1024]
46
+ x = self.patch_embedding(x) # [b 256(patch_num) 1024(d)]
47
+ x = x.permute(0, 2, 1) # [b 1024(d) 256(patch_num)]
48
+ return x
49
+
50
+
51
+ if __name__ == '__main__':
52
+ from PIL import Image
53
+ extractor = PatchFeatureExtractor()
54
+ img = np.array(Image.open("../../src/demo.png")).transpose((2, 0, 1))
55
+ input = torch.Tensor([img]) # 1 3 512 1024
56
+ feature = extractor(input)
57
+ print(feature.shape) # 1, 1024, 256
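
The Rearrange pattern 'b c h (p_n p_w) -> b p_n (h p_w c)' slices the panorama into patch_num vertical strips of width p_w and flattens each strip into one token; with the default 3x512x1024 input and 256 patches that is a 4-pixel-wide strip of 6144 values before the linear projection to 1024. A quick shape check using the same einops layer:

import torch
from einops.layers.torch import Rearrange

c, h, w, patch_num = 3, 512, 1024, 256
p_w = w // patch_num                               # 4 columns per vertical strip
to_tokens = Rearrange('b c h (p_n p_w) -> b p_n (h p_w c)', p_w=p_w)

x = torch.randn(1, c, h, w)
print(to_tokens(x).shape)                          # torch.Size([1, 256, 6144]) == h * p_w * c
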
models/modules/swg_transformer.py ADDED
@@ -0,0 +1,49 @@
1
+ from models.modules.transformer_modules import *
2
+
3
+
4
+ class SWG_Transformer(nn.Module):
5
+ def __init__(self, dim, depth, heads, win_size, dim_head, mlp_dim,
6
+ dropout=0., patch_num=None, ape=None, rpe=None, rpe_pos=1):
7
+ super().__init__()
8
+ self.absolute_pos_embed = None if patch_num is None or ape is None else AbsolutePosition(dim, dropout,
9
+ patch_num, ape)
10
+ self.pos_dropout = nn.Dropout(dropout)
11
+ self.layers = nn.ModuleList([])
12
+ for i in range(depth):
13
+ if i % 2 == 0:
14
+ attention = WinAttention(dim, win_size=win_size, shift=0 if (i % 3 == 0) else win_size // 2,
15
+ heads=heads, dim_head=dim_head, dropout=dropout, rpe=rpe, rpe_pos=rpe_pos)
16
+ else:
17
+ attention = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout,
18
+ patch_num=patch_num, rpe=rpe, rpe_pos=rpe_pos)
19
+
20
+ self.layers.append(nn.ModuleList([
21
+ PreNorm(dim, attention),
22
+ PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)),
23
+ ]))
24
+
25
+ def forward(self, x):
26
+ if self.absolute_pos_embed is not None:
27
+ x = self.absolute_pos_embed(x)
28
+ x = self.pos_dropout(x)
29
+ for attn, ff in self.layers:
30
+ x = attn(x) + x
31
+ x = ff(x) + x
32
+ return x
33
+
34
+
35
+ if __name__ == '__main__':
36
+ token_dim = 1024
37
+ toke_len = 256
38
+
39
+ transformer = SWG_Transformer(dim=token_dim,
40
+ depth=6,
41
+ heads=16,
42
+ win_size=8,
43
+ dim_head=64,
44
+ mlp_dim=2048,
45
+ dropout=0.1)
46
+
47
+ input = torch.randn(1, toke_len, token_dim)
48
+ output = transformer(input)
49
+ print(output.shape)
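
SWG_Transformer alternates window attention on even layers with global attention on odd layers, and shifts the windows only on even layers whose index is not a multiple of 3. Printing the schedule that the layer loop above produces for the default depth of 6:

# Print the attention schedule produced by the SWG_Transformer layer loop above.
depth, win_size = 6, 8
for i in range(depth):
    if i % 2 == 0:
        shift = 0 if i % 3 == 0 else win_size // 2
        kind = f"window attention (shift={shift})"
    else:
        kind = "global attention"
    print(f"layer {i}: {kind}")
# layer 0: window attention (shift=0)
# layer 1: global attention
# layer 2: window attention (shift=4)
# layer 3: global attention
# layer 4: window attention (shift=4)
# layer 5: global attention
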
models/modules/swin_transformer.py ADDED
@@ -0,0 +1,43 @@
1
+ from models.modules.transformer_modules import *
2
+
3
+
4
+ class Swin_Transformer(nn.Module):
5
+ def __init__(self, dim, depth, heads, win_size, dim_head, mlp_dim,
6
+ dropout=0., patch_num=None, ape=None, rpe=None, rpe_pos=1):
7
+ super().__init__()
8
+ self.absolute_pos_embed = None if patch_num is None or ape is None else AbsolutePosition(dim, dropout,
9
+ patch_num, ape)
10
+ self.pos_dropout = nn.Dropout(dropout)
11
+ self.layers = nn.ModuleList([])
12
+ for i in range(depth):
13
+ self.layers.append(nn.ModuleList([
14
+ PreNorm(dim, WinAttention(dim, win_size=win_size, shift=0 if (i % 2 == 0) else win_size // 2,
15
+ heads=heads, dim_head=dim_head, dropout=dropout, rpe=rpe, rpe_pos=rpe_pos)),
16
+ PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)),
17
+ ]))
18
+
19
+ def forward(self, x):
20
+ if self.absolute_pos_embed is not None:
21
+ x = self.absolute_pos_embed(x)
22
+ x = self.pos_dropout(x)
23
+ for attn, ff in self.layers:
24
+ x = attn(x) + x
25
+ x = ff(x) + x
26
+ return x
27
+
28
+
29
+ if __name__ == '__main__':
30
+ token_dim = 1024
31
+ toke_len = 256
32
+
33
+ transformer = Swin_Transformer(dim=token_dim,
34
+ depth=6,
35
+ heads=16,
36
+ win_size=8,
37
+ dim_head=64,
38
+ mlp_dim=2048,
39
+ dropout=0.1)
40
+
41
+ input = torch.randn(1, toke_len, token_dim)
42
+ output = transformer(input)
43
+ print(output.shape)
models/modules/transformer.py ADDED
@@ -0,0 +1,44 @@
1
+ from models.modules.transformer_modules import *
2
+
3
+
4
+ class Transformer(nn.Module):
5
+ def __init__(self, dim, depth, heads, win_size, dim_head, mlp_dim,
6
+ dropout=0., patch_num=None, ape=None, rpe=None, rpe_pos=1):
7
+ super().__init__()
8
+
9
+ self.absolute_pos_embed = None if patch_num is None or ape is None else AbsolutePosition(dim, dropout,
10
+ patch_num, ape)
11
+ self.pos_dropout = nn.Dropout(dropout)
12
+ self.layers = nn.ModuleList([])
13
+ for _ in range(depth):
14
+ self.layers.append(nn.ModuleList([
15
+ PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout, patch_num=patch_num,
16
+ rpe=rpe, rpe_pos=rpe_pos)),
17
+ PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
18
+ ]))
19
+
20
+ def forward(self, x):
21
+ if self.absolute_pos_embed is not None:
22
+ x = self.absolute_pos_embed(x)
23
+ x = self.pos_dropout(x)
24
+ for attn, ff in self.layers:
25
+ x = attn(x) + x
26
+ x = ff(x) + x
27
+ return x
28
+
29
+
30
+ if __name__ == '__main__':
31
+ token_dim = 1024
32
+ toke_len = 256
33
+
34
+ transformer = Transformer(dim=token_dim, depth=6, heads=16,
35
+ dim_head=64, mlp_dim=2048, dropout=0.1,
36
+ patch_num=256, ape='lr_parameter', rpe='lr_parameter_mirror')
37
+
38
+ total = sum(p.numel() for p in transformer.parameters())
39
+ trainable = sum(p.numel() for p in transformer.parameters() if p.requires_grad)
40
+ print('parameter total:{:,}, trainable:{:,}'.format(total, trainable))
41
+
42
+ input = torch.randn(1, toke_len, token_dim)
43
+ output = transformer(input)
44
+ print(output.shape)
models/modules/transformer_modules.py ADDED
@@ -0,0 +1,250 @@
1
+ """
2
+ @Date: 2021/09/01
3
+ @description:
4
+ """
5
+ import warnings
6
+ import math
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ from torch import nn, einsum
11
+ from einops import rearrange
12
+
13
+
14
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
15
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
16
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
17
+ def norm_cdf(x):
18
+ # Computes standard normal cumulative distribution function
19
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
20
+
21
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
22
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
23
+ "The distribution of values may be incorrect.",
24
+ stacklevel=2)
25
+
26
+ with torch.no_grad():
27
+ # Values are generated by using a truncated uniform distribution and
28
+ # then using the inverse CDF for the normal distribution.
29
+ # Get upper and lower cdf values
30
+ l = norm_cdf((a - mean) / std)
31
+ u = norm_cdf((b - mean) / std)
32
+
33
+ # Uniformly fill tensor with values from [l, u], then translate to
34
+ # [2l-1, 2u-1].
35
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
36
+
37
+ # Use inverse cdf transform for normal distribution to get truncated
38
+ # standard normal
39
+ tensor.erfinv_()
40
+
41
+ # Transform to proper mean, std
42
+ tensor.mul_(std * math.sqrt(2.))
43
+ tensor.add_(mean)
44
+
45
+ # Clamp to ensure it's in the proper range
46
+ tensor.clamp_(min=a, max=b)
47
+ return tensor
48
+
49
+
50
+ class PreNorm(nn.Module):
51
+ def __init__(self, dim, fn):
52
+ super().__init__()
53
+ self.norm = nn.LayerNorm(dim)
54
+ self.fn = fn
55
+
56
+ def forward(self, x, **kwargs):
57
+ return self.fn(self.norm(x), **kwargs)
58
+
59
+
60
+ # compatibility pytorch < 1.4
61
+ class GELU(nn.Module):
62
+ def forward(self, input):
63
+ return F.gelu(input)
64
+
65
+
66
+ class Attend(nn.Module):
67
+
68
+ def __init__(self, dim=None):
69
+ super().__init__()
70
+ self.dim = dim
71
+
72
+ def forward(self, input):
73
+ return F.softmax(input, dim=self.dim, dtype=input.dtype)
74
+
75
+
76
+ class FeedForward(nn.Module):
77
+ def __init__(self, dim, hidden_dim, dropout=0.):
78
+ super().__init__()
79
+ self.net = nn.Sequential(
80
+ nn.Linear(dim, hidden_dim),
81
+ GELU(),
82
+ nn.Dropout(dropout),
83
+ nn.Linear(hidden_dim, dim),
84
+ nn.Dropout(dropout)
85
+ )
86
+
87
+ def forward(self, x):
88
+ return self.net(x)
89
+
90
+
91
+ class RelativePosition(nn.Module):
92
+ def __init__(self, heads, patch_num=None, rpe=None):
93
+ super().__init__()
94
+ self.rpe = rpe
95
+ self.heads = heads
96
+ self.patch_num = patch_num
97
+
98
+ if rpe == 'lr_parameter':
99
+ # -255 ~ 0 ~ 255 all count : patch * 2 - 1
100
+ count = patch_num * 2 - 1
101
+ self.rpe_table = nn.Parameter(torch.Tensor(count, heads))
102
+ nn.init.xavier_uniform_(self.rpe_table)
103
+ elif rpe == 'lr_parameter_mirror':
104
+ # 0 ~ 127 128 ~ 1 all count : patch_num // 2 + 1
105
+ count = patch_num // 2 + 1
106
+ self.rpe_table = nn.Parameter(torch.Tensor(count, heads))
107
+ nn.init.xavier_uniform_(self.rpe_table)
108
+ elif rpe == 'lr_parameter_half':
109
+ # -127 ~ 0 ~ 128 all count: patch_num
110
+ count = patch_num
111
+ self.rpe_table = nn.Parameter(torch.Tensor(count, heads))
112
+ nn.init.xavier_uniform_(self.rpe_table)
113
+ elif rpe == 'fix_angle':
114
+ # 0 ~ 127 128 ~ 1 all count : patch_num // 2 + 1
115
+ count = patch_num // 2 + 1
116
+ # we think that closer proximity should have stronger relationships
117
+ rpe_table = (torch.arange(count, 0, -1) / count)[..., None].repeat(1, heads)
118
+ self.register_buffer('rpe_table', rpe_table)
119
+
120
+ def get_relative_pos_embed(self):
121
+ range_vec = torch.arange(self.patch_num)
122
+ distance_mat = range_vec[None, :] - range_vec[:, None]
123
+ if self.rpe == 'lr_parameter':
124
+ # -255 ~ 0 ~ 255 -> 0 ~ 255 ~ 255 + 255
125
+ distance_mat += self.patch_num - 1 # remove negative
126
+ return self.rpe_table[distance_mat].permute(2, 0, 1)[None]
127
+ elif self.rpe == 'lr_parameter_mirror' or self.rpe == 'fix_angle':
128
+ distance_mat[distance_mat < 0] = -distance_mat[distance_mat < 0] # mirror
129
+ distance_mat[distance_mat > self.patch_num // 2] = self.patch_num - distance_mat[
130
+ distance_mat > self.patch_num // 2] # remove repeat
131
+ return self.rpe_table[distance_mat].permute(2, 0, 1)[None]
132
+ elif self.rpe == 'lr_parameter_half':
133
+ distance_mat[distance_mat > self.patch_num // 2] = distance_mat[
134
+ distance_mat > self.patch_num // 2] - self.patch_num  # remove repeats > 128, e.g. 129 -> -127
135
+ distance_mat[distance_mat < -self.patch_num // 2 + 1] = distance_mat[
136
+ distance_mat < -self.patch_num // 2 + 1] + self.patch_num  # remove repeats < -127, e.g. -128 -> 128
137
+ # -127 ~ 0 ~ 128 -> 0 ~ 0 ~ 127 + 127 + 128
138
+ distance_mat += self.patch_num//2 - 1 # remove negative
139
+ return self.rpe_table[distance_mat].permute(2, 0, 1)[None]
140
+
141
+ def forward(self, attn):
142
+ return attn + self.get_relative_pos_embed()
143
+
144
+
145
+ class Attention(nn.Module):
146
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0., patch_num=None, rpe=None, rpe_pos=1):
147
+ """
148
+ :param dim: input feature dimension
149
+ :param heads: number of attention heads
150
+ :param dim_head: dimension of each attention head
151
+ :param dropout: dropout rate
152
+ :param patch_num: number of patches (sequence length), required when rpe is used
153
+ :param rpe: relative position embedding
154
+ """
155
+ super().__init__()
156
+
157
+ self.relative_pos_embed = None if patch_num is None or rpe is None else RelativePosition(heads, patch_num, rpe)
158
+ inner_dim = dim_head * heads
159
+ project_out = not (heads == 1 and dim_head == dim)
160
+
161
+ self.heads = heads
162
+ self.scale = dim_head ** -0.5
163
+ self.rpe_pos = rpe_pos
164
+
165
+ self.attend = Attend(dim=-1)
166
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
167
+
168
+ self.to_out = nn.Sequential(
169
+ nn.Linear(inner_dim, dim),
170
+ nn.Dropout(dropout)
171
+ ) if project_out else nn.Identity()
172
+
173
+ def forward(self, x):
174
+ b, n, _, h = *x.shape, self.heads
175
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
176
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)
177
+
178
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
179
+
180
+ if self.rpe_pos == 0:
181
+ if self.relative_pos_embed is not None:
182
+ dots = self.relative_pos_embed(dots)
183
+
184
+ attn = self.attend(dots)
185
+
186
+ if self.rpe_pos == 1:
187
+ if self.relative_pos_embed is not None:
188
+ attn = self.relative_pos_embed(attn)
189
+
190
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
191
+ out = rearrange(out, 'b h n d -> b n (h d)')
192
+ return self.to_out(out)
193
+
194
+
195
+ class AbsolutePosition(nn.Module):
196
+ def __init__(self, dim, dropout=0., patch_num=None, ape=None):
197
+ super().__init__()
198
+ self.ape = ape
199
+
200
+ if ape == 'lr_parameter':
201
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, patch_num, dim))
202
+ trunc_normal_(self.absolute_pos_embed, std=.02)
203
+
204
+ elif ape == 'fix_angle':
205
+ angle = torch.arange(0, patch_num, dtype=torch.float) / patch_num * (math.pi * 2)
206
+ self.absolute_pos_embed = torch.sin(angle)[..., None].repeat(1, dim)[None]
207
+
208
+ def forward(self, x):
209
+ return x + self.absolute_pos_embed
210
+
211
+
212
+ class WinAttention(nn.Module):
213
+ def __init__(self, dim, win_size=8, shift=0, heads=8, dim_head=64, dropout=0., rpe=None, rpe_pos=1):
214
+ super().__init__()
215
+
216
+ self.win_size = win_size
217
+ self.shift = shift
218
+ self.attend = Attention(dim, heads=heads, dim_head=dim_head,
219
+ dropout=dropout, patch_num=win_size, rpe=None if rpe is None else 'lr_parameter',
220
+ rpe_pos=rpe_pos)
221
+
222
+ def forward(self, x):
223
+ b = x.shape[0]
224
+ if self.shift != 0:
225
+ x = torch.roll(x, shifts=self.shift, dims=-2)
226
+ x = rearrange(x, 'b (m w) d -> (b m) w d', w=self.win_size) # split windows
227
+
228
+ out = self.attend(x)
229
+
230
+ out = rearrange(out, '(b m) w d -> b (m w) d ', b=b) # recover windows
231
+ if self.shift != 0:
232
+ out = torch.roll(out, shifts=-self.shift, dims=-2)
233
+
234
+ return out
235
+
236
+
237
+ class Conv(nn.Module):
238
+ def __init__(self, dim, dropout=0.):
239
+ super().__init__()
240
+ self.dim = dim
241
+ self.net = nn.Sequential(
242
+ nn.Conv1d(dim, dim, kernel_size=3, stride=1, padding=0),
243
+ nn.Dropout(dropout)
244
+ )
245
+
246
+ def forward(self, x):
247
+ x = x.transpose(1, 2)
248
+ x = torch.cat([x[..., -1:], x, x[..., :1]], dim=-1)
249
+ x = self.net(x)
250
+ return x.transpose(1, 2)
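For reference, a minimal usage sketch of the attention blocks defined above (assuming this file is models/modules/transformer_modules.py from the changed-file list); the batch size, patch count and embedding width are illustrative assumptions, not values taken from this commit's configs:

import torch
from models.modules.transformer_modules import Attention, WinAttention

x = torch.randn(2, 256, 128)  # (batch, patch_num, dim), illustrative sizes

# global attention with a learned relative position bias over 256 patches
attn = Attention(dim=128, heads=4, dim_head=32, patch_num=256, rpe='lr_parameter')
# shifted window attention over windows of 8 patches
win = WinAttention(dim=128, win_size=8, shift=4, heads=4, dim_head=32)

print(attn(x).shape)  # torch.Size([2, 256, 128])
print(win(x).shape)   # torch.Size([2, 256, 128])

Both modules preserve the (batch, patch_num, dim) shape, so they can be stacked freely by the transformer variants under models/modules/.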
models/other/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ """
2
+ @Date: 2021/07/18
3
+ @description:
4
+ """
models/other/criterion.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ @date: 2021/7/19
3
+ @description:
4
+ """
5
+ import torch
6
+ import loss
7
+
8
+ from utils.misc import tensor2np
9
+
10
+
11
+ def build_criterion(config, logger):
12
+ criterion = {}
13
+ device = config.TRAIN.DEVICE
14
+
15
+ for k in config.TRAIN.CRITERION.keys():
16
+ sc = config.TRAIN.CRITERION[k]
17
+ if sc.WEIGHT is None or float(sc.WEIGHT) == 0:
18
+ continue
19
+ criterion[sc.NAME] = {
20
+ 'loss': getattr(loss, sc.LOSS)(),
21
+ 'weight': float(sc.WEIGHT),
22
+ 'sub_weights': sc.WEIGHTS,
23
+ 'need_all': sc.NEED_ALL
24
+ }
25
+
26
+ criterion[sc.NAME]['loss'] = criterion[sc.NAME]['loss'].to(device)
27
+ if config.AMP_OPT_LEVEL != "O0" and 'cuda' in device:
28
+ criterion[sc.NAME]['loss'] = criterion[sc.NAME]['loss'].type(torch.float16)
29
+
30
+ # logger.info(f"Build criterion:{sc.WEIGHT}_{sc.NAME}_{sc.LOSS}_{sc.WEIGHTS}")
31
+ return criterion
32
+
33
+
34
+ def calc_criterion(criterion, gt, dt, epoch_loss_d):
35
+ loss = None
36
+ postfix_d = {}
37
+ for k in criterion.keys():
38
+ if criterion[k]['need_all']:
39
+ single_loss = criterion[k]['loss'](gt, dt)
40
+ ws_loss = None
41
+ for i, sub_weight in enumerate(criterion[k]['sub_weights']):
42
+ if sub_weight == 0:
43
+ continue
44
+ if ws_loss is None:
45
+ ws_loss = single_loss[i] * sub_weight
46
+ else:
47
+ ws_loss = ws_loss + single_loss[i] * sub_weight
48
+ single_loss = ws_loss if ws_loss is not None else single_loss
49
+ else:
50
+ assert k in gt.keys(), "ground label is None:" + k
51
+ assert k in dt.keys(), "detection key is None:" + k
52
+ if k == 'ratio' and gt[k].shape[-1] != dt[k].shape[-1]:
53
+ gt[k] = gt[k].repeat(1, dt[k].shape[-1])
54
+ single_loss = criterion[k]['loss'](gt[k], dt[k])
55
+
56
+ postfix_d[k] = tensor2np(single_loss)
57
+ if k not in epoch_loss_d.keys():
58
+ epoch_loss_d[k] = []
59
+ epoch_loss_d[k].append(postfix_d[k])
60
+
61
+ single_loss = single_loss * criterion[k]['weight']
62
+ if loss is None:
63
+ loss = single_loss
64
+ else:
65
+ loss = loss + single_loss
66
+
67
+ k = 'loss'
68
+ postfix_d[k] = tensor2np(loss)
69
+ if k not in epoch_loss_d.keys():
70
+ epoch_loss_d[k] = []
71
+ epoch_loss_d[k].append(postfix_d[k])
72
+ return loss, postfix_d, epoch_loss_d
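A rough sketch of how calc_criterion above could sit in a training step; the criterion dict is hand-built here in the same format that build_criterion produces, and the 'depth' key, tensor shapes and loss module are illustrative assumptions rather than this project's actual criteria:

import torch
from torch import nn
from models.other.criterion import calc_criterion

criterion = {
    'depth': {'loss': nn.L1Loss(), 'weight': 1.0, 'sub_weights': None, 'need_all': False},
}
gt = {'depth': torch.rand(4, 256)}                      # ground-truth dict (illustrative)
dt = {'depth': torch.rand(4, 256, requires_grad=True)}  # stand-in for the network output
epoch_loss_d = {}

loss, postfix_d, epoch_loss_d = calc_criterion(criterion, gt, dt, epoch_loss_d)
loss.backward()  # the returned loss is the weighted sum over all active criteria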
models/other/init_env.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ @Date: 2021/08/15
3
+ @description:
4
+ """
5
+ import random
6
+ import torch
7
+ import torch.backends.cudnn as cudnn
8
+ import numpy as np
9
+ import os
10
+ import cv2
11
+
12
+
13
+ def init_env(seed, deterministic=False, loader_work_num=0):
14
+ # Fix seed
15
+ # Python & NumPy
16
+ np.random.seed(seed)
17
+ random.seed(seed)
18
+ os.environ['PYTHONHASHSEED'] = str(seed)
19
+
20
+ # PyTorch
21
+ torch.manual_seed(seed)  # set the random seed for CPU
22
+ if torch.cuda.is_available():
23
+ torch.cuda.manual_seed(seed)  # set the random seed for the current GPU
24
+ torch.cuda.manual_seed_all(seed)  # set the random seed for all GPUs
25
+
26
+ # cuDNN
27
+ if deterministic:
28
+ # reproducibility
29
+ torch.backends.cudnn.benchmark = False
30
+ torch.backends.cudnn.deterministic = True  # with this flag set to True, cuDNN always selects the same (default) convolution algorithm
31
+ else:
32
+ cudnn.benchmark = True  # set to True when the input size/type of the network changes little between iterations
33
+ torch.backends.cudnn.deterministic = False
34
+
35
+ # Using multiple threads in Opencv can cause deadlocks
36
+ if loader_work_num != 0:
37
+ cv2.setNumThreads(0)
models/other/optimizer.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ @Date: 2021/07/18
3
+ @description:
4
+ """
5
+ from torch import optim as optim
6
+
7
+
8
+ def build_optimizer(config, model, logger):
9
+ name = config.TRAIN.OPTIMIZER.NAME.lower()
10
+
11
+ optimizer = None
12
+ if name == 'sgd':
13
+ optimizer = optim.SGD(model.parameters(), momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True,
14
+ lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
15
+ elif name == 'adamw':
16
+ optimizer = optim.AdamW(model.parameters(), eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
17
+ lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
18
+ elif name == 'adam':
19
+ optimizer = optim.Adam(model.parameters(), eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
20
+ lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
21
+
22
+ logger.info(f"Build optimizer: {name}, lr:{config.TRAIN.BASE_LR}")
23
+
24
+ return optimizer
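build_optimizer only reads a handful of attributes from the config, so it can be tried in isolation; the SimpleNamespace objects below are a stand-in for the project's real config node, and every value is an illustrative assumption:

import logging
from types import SimpleNamespace
from torch import nn
from models.other.optimizer import build_optimizer

opt_cfg = SimpleNamespace(NAME='adamw', EPS=1e-8, BETAS=(0.9, 0.999), MOMENTUM=0.9)
config = SimpleNamespace(TRAIN=SimpleNamespace(BASE_LR=1e-4, WEIGHT_DECAY=0.05, OPTIMIZER=opt_cfg))

model = nn.Linear(8, 8)  # stand-in for the real network
optimizer = build_optimizer(config, model, logging.getLogger(__name__))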
models/other/scheduler.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ @Date: 2021/09/14
3
+ @description:
4
+ """
5
+
6
+
7
+ class WarmupScheduler:
8
+ def __init__(self, optimizer, lr_pow, init_lr, warmup_lr, warmup_step, max_step, **kwargs):
9
+ self.lr_pow = lr_pow
10
+ self.init_lr = init_lr
11
+ self.running_lr = init_lr
12
+ self.warmup_lr = warmup_lr
13
+ self.warmup_step = warmup_step
14
+ self.max_step = max_step
15
+ self.optimizer = optimizer
16
+
17
+ def step_update(self, cur_step):
18
+ if cur_step < self.warmup_step:
19
+ frac = cur_step / self.warmup_step
20
+ step = self.warmup_lr - self.init_lr
21
+ self.running_lr = self.init_lr + step * frac
22
+ else:
23
+ frac = (float(cur_step) - self.warmup_step) / (self.max_step - self.warmup_step)
24
+ scale_running_lr = max((1. - frac), 0.) ** self.lr_pow
25
+ self.running_lr = self.warmup_lr * scale_running_lr
26
+
27
+ if self.optimizer is not None:
28
+ for param_group in self.optimizer.param_groups:
29
+ param_group['lr'] = self.running_lr
30
+
31
+
32
+ if __name__ == '__main__':
33
+ import matplotlib.pyplot as plt
34
+
35
+ scheduler = WarmupScheduler(optimizer=None,
36
+ lr_pow=4,
37
+ init_lr=0.0000003,
38
+ warmup_lr=0.00003,
39
+ warmup_step=10000,
40
+ max_step=100000)
41
+
42
+ x = []
43
+ y = []
44
+ for i in range(100000):
45
+ if i == 10000-1:
46
+ print()
47
+ scheduler.step_update(i)
48
+ x.append(i)
49
+ y.append(scheduler.running_lr)
50
+ plt.plot(x, y, linewidth=1)
51
+ plt.show()
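The __main__ block above only plots the schedule with optimizer=None; when an optimizer is attached, step_update writes the current learning rate into every param group before each update. A minimal sketch with illustrative step counts and learning rates:

from torch import nn, optim
from models.other.scheduler import WarmupScheduler

model = nn.Linear(8, 8)
optimizer = optim.Adam(model.parameters(), lr=3e-7)
scheduler = WarmupScheduler(optimizer=optimizer, lr_pow=4,
                            init_lr=3e-7, warmup_lr=3e-5,
                            warmup_step=1000, max_step=10000)

for step in range(10000):
    scheduler.step_update(step)  # linear warm-up, then polynomial decay of param_group['lr']
    # ... forward pass / loss.backward() / optimizer.step() / optimizer.zero_grad() ...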
postprocessing/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ """
2
+ @Date: 2021/10/06
3
+ @description:
4
+ """
postprocessing/dula/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ """
2
+ @Date: 2021/10/06
3
+ @description:
4
+ """
postprocessing/dula/layout.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ @Date: 2021/10/06
3
+ @description: Use the approach proposed by DuLa-Net
4
+ """
5
+ import cv2
6
+ import numpy as np
7
+ import math
8
+ import matplotlib.pyplot as plt
9
+
10
+ from visualization.floorplan import draw_floorplan
11
+
12
+
13
+ def merge_near(lst, diag):
14
+ group = [[0, ]]
15
+ for i in range(1, len(lst)):
16
+ if lst[i][1] == 0 and lst[i][0] - np.mean(group[-1]) < diag * 0.02:
17
+ group[-1].append(lst[i][0])
18
+ else:
19
+ group.append([lst[i][0], ])
20
+ if len(group) == 1:
21
+ group = [lst[0][0], lst[-1][0]]
22
+ else:
23
+ group = [int(np.mean(x)) for x in group]
24
+ return group
25
+
26
+
27
+ def fit_layout(floor_xz, need_cube=False, show=False, block_eps=0.2):
28
+ show_radius = np.linalg.norm(floor_xz, axis=-1).max()
29
+ side_l = 512
30
+ floorplan = draw_floorplan(xz=floor_xz, show_radius=show_radius, show=show, scale=1, side_l=side_l).astype(np.uint8)
31
+ center = np.array([side_l / 2, side_l / 2])
32
+ polys = cv2.findContours(floorplan, 1, 2)
33
+ if isinstance(polys, tuple):
34
+ if len(polys) == 3:
35
+ # opencv 3
36
+ polys = list(polys[1])
37
+ else:
38
+ polys = list(polys[0])
39
+ polys.sort(key=lambda x: cv2.contourArea(x), reverse=True)
40
+ poly = polys[0]
41
+ sub_x, sub_y, w, h = cv2.boundingRect(poly)
42
+ floorplan_sub = floorplan[sub_y:sub_y + h, sub_x:sub_x + w]
43
+ sub_center = center - np.array([sub_x, sub_y])
44
+ polys = cv2.findContours(floorplan_sub, 1, 2)
45
+ if isinstance(polys, tuple):
46
+ if len(polys) == 3:
47
+ polys = polys[1]
48
+ else:
49
+ polys = polys[0]
50
+ poly = polys[0]
51
+ epsilon = 0.005 * cv2.arcLength(poly, True)
52
+ poly = cv2.approxPolyDP(poly, epsilon, True)
53
+
54
+ x_lst = [[0, 0], ]
55
+ y_lst = [[0, 0], ]
56
+
57
+ ans = np.zeros((floorplan_sub.shape[0], floorplan_sub.shape[1]))
58
+
59
+ for i in range(len(poly)):
60
+ p1 = poly[i][0]
61
+ p2 = poly[(i + 1) % len(poly)][0]
62
+ # We added occlusion detection
63
+ cp1 = p1 - sub_center
64
+ cp2 = p2 - sub_center
65
+ p12 = p2 - p1
66
+ l1 = np.linalg.norm(cp1)
67
+ l2 = np.linalg.norm(cp2)
68
+ l3 = np.linalg.norm(p12)
69
+ # occlusion check: treat this edge as occluded when it is nearly radial from the camera center
70
+ is_block1 = abs(np.cross(cp1/l1, cp2/l2)) < block_eps
71
+ is_block2 = abs(np.cross(cp2/l2, p12/l3)) < block_eps*2
72
+ is_block = is_block1 and is_block2
73
+
74
+ if (p2[0] - p1[0]) == 0:
75
+ slope = 10
76
+ else:
77
+ slope = abs((p2[1] - p1[1]) / (p2[0] - p1[0]))
78
+
79
+ if is_block:
80
+ s = p1[1] if l1 < l2 else p2[1]
81
+ y_lst.append([s, 1])
82
+ s = p1[0] if l1 < l2 else p2[0]
83
+ x_lst.append([s, 1])
84
+
85
+ left = p1[0] if p1[0] < p2[0] else p2[0]
86
+ right = p1[0] if p1[0] > p2[0] else p2[0]
87
+ top = p1[1] if p1[1] < p2[1] else p2[1]
88
+ bottom = p1[1] if p1[1] > p2[1] else p2[1]
89
+ sample = floorplan_sub[top:bottom, left:right]
90
+ score = 0 if sample.size == 0 else sample.mean()
91
+ if score >= 0.3:
92
+ ans[top:bottom, left:right] = 1
93
+
94
+ else:
95
+ if slope <= 1:
96
+ s = int((p1[1] + p2[1]) / 2)
97
+ y_lst.append([s, 0])
98
+ elif slope > 1:
99
+ s = int((p1[0] + p2[0]) / 2)
100
+ x_lst.append([s, 0])
101
+
102
+ debug_show = False
103
+ if debug_show:
104
+ plt.figure(dpi=300)
105
+ plt.axis('off')
106
+ a = cv2.drawMarker(floorplan_sub.copy()*0.5, tuple([floorplan_sub.shape[1] // 2, floorplan_sub.shape[0] // 2]), [1], markerType=0, markerSize=10, thickness=2)
107
+ plt.imshow(cv2.drawContours(a, [poly], 0, 1, 1))
108
+ plt.savefig('src/1.png', bbox_inches='tight', transparent=True, pad_inches=0)
109
+ plt.show()
110
+
111
+ plt.figure(dpi=300)
112
+ plt.axis('off')
113
+ a = cv2.drawMarker(ans.copy()*0.5, tuple([floorplan_sub.shape[1] // 2, floorplan_sub.shape[0] // 2]), [1], markerType=0, markerSize=10, thickness=2)
114
+ plt.imshow(cv2.drawContours(a, [poly], 0, 1, 1))
115
+ # plt.show()
116
+ plt.savefig('src/2.png', bbox_inches='tight', transparent=True, pad_inches=0)
117
+ plt.show()
118
+
119
+ x_lst.append([floorplan_sub.shape[1], 0])
120
+ y_lst.append([floorplan_sub.shape[0], 0])
121
+ x_lst.sort(key=lambda x: x[0])
122
+ y_lst.sort(key=lambda x: x[0])
123
+
124
+ diag = math.sqrt(math.pow(floorplan_sub.shape[1], 2) + math.pow(floorplan_sub.shape[0], 2))
125
+ x_lst = merge_near(x_lst, diag)
126
+ y_lst = merge_near(y_lst, diag)
127
+ if need_cube and len(x_lst) > 2:
128
+ x_lst = [x_lst[0], x_lst[-1]]
129
+ if need_cube and len(y_lst) > 2:
130
+ y_lst = [y_lst[0], y_lst[-1]]
131
+
132
+ for i in range(len(x_lst) - 1):
133
+ for j in range(len(y_lst) - 1):
134
+ sample = floorplan_sub[y_lst[j]:y_lst[j + 1], x_lst[i]:x_lst[i + 1]]
135
+ score = 0 if sample.size == 0 else sample.mean()
136
+ if score >= 0.3:
137
+ ans[y_lst[j]:y_lst[j + 1], x_lst[i]:x_lst[i + 1]] = 1
138
+
139
+ if debug_show:
140
+ plt.figure(dpi=300)
141
+ plt.axis('off')
142
+ a = cv2.drawMarker(ans.copy() * 0.5, tuple([floorplan_sub.shape[1] // 2, floorplan_sub.shape[0] // 2]), [1],
143
+ markerType=0, markerSize=10, thickness=2)
144
+ plt.imshow(cv2.drawContours(a, [poly], 0, 1, 1))
145
+ # plt.show()
146
+ plt.savefig('src/3.png', bbox_inches='tight', transparent=True, pad_inches=0)
147
+ plt.show()
148
+
149
+ pred = np.uint8(ans)
150
+ pred_polys = cv2.findContours(pred, 1, 3)
151
+ if isinstance(pred_polys, tuple):
152
+ if len(pred_polys) == 3:
153
+ pred_polys = list(pred_polys[1])  # OpenCV 3: (image, contours, hierarchy)
154
+ else:
155
+ pred_polys = list(pred_polys[0])  # OpenCV 4: (contours, hierarchy); list() so sort() below works
156
+
157
+ pred_polys.sort(key=lambda x: cv2.contourArea(x), reverse=True)
158
+ pred_polys = pred_polys[0]
159
+
160
+ if debug_show:
161
+ plt.figure(dpi=300)
162
+ plt.axis('off')
163
+ a = cv2.drawMarker(ans.copy() * 0.5, tuple([floorplan_sub.shape[1] // 2, floorplan_sub.shape[0] // 2]), [1],
164
+ markerType=0, markerSize=10, thickness=2)
165
+ a = cv2.drawContours(a, [poly], 0, 0.8, 1)
166
+ a = cv2.drawContours(a, [pred_polys], 0, 1, 1)
167
+ plt.imshow(a)
168
+ # plt.show()
169
+ plt.savefig('src/4.png', bbox_inches='tight', transparent=True, pad_inches=0)
170
+ plt.show()
171
+
172
+ polygon = [(p[0][1], p[0][0]) for p in pred_polys[::-1]]
173
+
174
+ v = np.array([p[0] + sub_y for p in polygon])
175
+ u = np.array([p[1] + sub_x for p in polygon])
176
+ # side_l
177
+ # v<-----------|o
178
+ # | | |
179
+ # | ----|----z | side_l
180
+ # | | |
181
+ # | x \|/
182
+ # |------------u
183
+ side_l = floorplan.shape[0]
184
+ pred_xz = np.concatenate((u[:, np.newaxis] - side_l // 2, side_l // 2 - v[:, np.newaxis]), axis=1)
185
+
186
+ pred_xz = pred_xz * show_radius / (side_l // 2)
187
+ if show:
188
+ draw_floorplan(pred_xz, show_radius=show_radius, show=show)
189
+
190
+ show_process = False
191
+ if show_process:
192
+ img = np.zeros((floorplan_sub.shape[0], floorplan_sub.shape[1], 3))
193
+ for x in x_lst:
194
+ cv2.line(img, (x, 0), (x, floorplan_sub.shape[0]), (0, 255, 0), 1)
195
+ for y in y_lst:
196
+ cv2.line(img, (0, y), (floorplan_sub.shape[1], y), (255, 0, 0), 1)
197
+
198
+ fig = plt.figure()
199
+ plt.axis('off')
200
+ ax1 = fig.add_subplot(2, 2, 1)
201
+ ax1.imshow(floorplan)
202
+ ax3 = fig.add_subplot(2, 2, 2)
203
+ ax3.imshow(floorplan_sub)
204
+ ax4 = fig.add_subplot(2, 2, 3)
205
+ ax4.imshow(img)
206
+ ax5 = fig.add_subplot(2, 2, 4)
207
+ ax5.imshow(ans)
208
+ plt.show()
209
+
210
+ return pred_xz
211
+
212
+
213
+ if __name__ == '__main__':
214
+ from utils.conversion import uv2xyz
215
+
216
+ pano_img = np.zeros([512, 1024, 3])
217
+ corners = np.array([[0.1, 0.7],
218
+ [0.4, 0.7],
219
+ [0.3, 0.6],
220
+ [0.6, 0.6],
221
+ [0.8, 0.7]])
222
+ xz = uv2xyz(corners)[..., ::2]
223
+ draw_floorplan(xz, show=True, marker_color=None, center_color=0.8)
224
+
225
+ xz = fit_layout(xz)
226
+ draw_floorplan(xz, show=True, marker_color=None, center_color=0.8)