zhigangjiang committed
Commit: 88b0dcb
Parent(s): 46e6683
Commit message: no message

Note: this view is limited to 50 files because the commit contains too many changes.
- .gitignore +4 -0
- LICENSE +21 -0
- Post-Porcessing.md +35 -0
- app.py +139 -0
- config/__init__.py +4 -0
- config/defaults.py +289 -0
- convert_ckpt.py +61 -0
- dataset/__init__.py +0 -0
- dataset/build.py +115 -0
- dataset/communal/__init__.py +4 -0
- dataset/communal/base_dataset.py +127 -0
- dataset/communal/data_augmentation.py +279 -0
- dataset/communal/read.py +214 -0
- dataset/mp3d_dataset.py +110 -0
- dataset/pano_s2d3d_dataset.py +107 -0
- dataset/pano_s2d3d_mix_dataset.py +91 -0
- dataset/zind_dataset.py +138 -0
- evaluation/__init__.py +4 -0
- evaluation/accuracy.py +249 -0
- evaluation/analyse_layout_type.py +83 -0
- evaluation/eval_visible_iou.py +56 -0
- evaluation/f1_score.py +78 -0
- evaluation/iou.py +148 -0
- inference.py +261 -0
- loss/__init__.py +10 -0
- loss/boundary_loss.py +51 -0
- loss/grad_loss.py +57 -0
- loss/led_loss.py +47 -0
- loss/object_loss.py +42 -0
- main.py +401 -0
- models/__init__.py +1 -0
- models/base_model.py +150 -0
- models/build.py +81 -0
- models/lgt_net.py +213 -0
- models/modules/__init__.py +8 -0
- models/modules/conv_transformer.py +128 -0
- models/modules/horizon_net_feature_extractor.py +267 -0
- models/modules/patch_feature_extractor.py +57 -0
- models/modules/swg_transformer.py +49 -0
- models/modules/swin_transformer.py +43 -0
- models/modules/transformer.py +44 -0
- models/modules/transformer_modules.py +250 -0
- models/other/__init__.py +4 -0
- models/other/criterion.py +72 -0
- models/other/init_env.py +37 -0
- models/other/optimizer.py +24 -0
- models/other/scheduler.py +51 -0
- postprocessing/__init__.py +4 -0
- postprocessing/dula/__init__.py +4 -0
- postprocessing/dula/layout.py +226 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
+checkpoints
+src/output
+visualization/visualizer
+flagged
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 ZhiGang Jiang
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
Post-Porcessing.md
ADDED
@@ -0,0 +1,35 @@
+# Post-Processing
+## Steps
+
+1. Simplify the polygon with the [DP algorithm](https://en.wikipedia.org/wiki/Ramer%E2%80%93Douglas%E2%80%93Peucker_algorithm)
+
+![img.png](src/fig/post_processing/img_0.png)
+
+2. Detect occlusion, calculating a box filled with 1
+
+![img.png](src/fig/post_processing/img_1.png)
+
+3. Fill in a reasonable sampling section
+
+![img.png](src/fig/post_processing/img_2.png)
+
+4. Output the processed polygon
+
+![img.png](src/fig/post_processing/img_3.png)
+
+## Performance
+It works; here is a performance comparison on the MatterportLayout dataset:
+
+| Method | 2D IoU(%) | 3D IoU(%) | RMSE | $\mathbf{\delta_{1}}$ |
+|--|--|--|--|--|
+| without post-proc | 83.52 | 81.11 | 0.204 | 0.951 |
+| original post-proc | 83.12 | 80.71 | 0.230 | 0.936 |
+| optimized post-proc | 83.48 | 81.08 | 0.214 | 0.940 |
+
+original:
+
+![img.png](src/fig/post_processing/original.png)
+
+optimized:
+
+![img.png](src/fig/post_processing/optimized.png)
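Step 1 of the file above leans on Ramer-Douglas-Peucker simplification. A minimal self-contained sketch of the same idea using OpenCV's cv2.approxPolyDP (the contour points and the epsilon value are illustrative, not the ones the repo uses):

import numpy as np
import cv2

# A made-up noisy floorplan contour, shaped (N, 1, 2) int32 as OpenCV expects.
contour = np.array([[[0, 0]], [[50, 1]], [[100, 0]],
                    [[100, 100]], [[0, 100]]], dtype=np.int32)

# RDP: drop vertices whose distance to the chord between their
# neighbors is below epsilon (in pixels here).
epsilon = 2.0
simplified = cv2.approxPolyDP(contour, epsilon, closed=True)
print(simplified.reshape(-1, 2))  # the near-collinear vertex [50, 1] should be dropped

A larger epsilon removes more vertices; the repo's optimized post-processing then handles occlusion on the simplified polygon.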
app.py
ADDED
@@ -0,0 +1,139 @@
+'''
+@author: Zhigang Jiang
+@time: 2022/05/23
+@description:
+'''
+
+import gradio as gr
+import numpy as np
+import os
+import torch
+
+from PIL import Image
+
+from utils.logger import get_logger
+from config.defaults import get_config
+from inference import preprocess, run_one_inference
+from models.build import build_model
+from argparse import Namespace
+import gdown
+
+
+def down_ckpt(model_cfg, ckpt_dir):
+    model_ids = [
+        ['src/config/mp3d.yaml', '1o97oAmd-yEP5bQrM0eAWFPLq27FjUDbh'],
+        ['src/config/zind.yaml', '1PzBj-dfDfH_vevgSkRe5kczW0GVl_43I'],
+        ['src/config/pano.yaml', '1JoeqcPbm_XBPOi6O9GjjWi3_rtyPZS8m'],
+        ['src/config/s2d3d.yaml', '1PfJzcxzUsbwwMal7yTkBClIFgn8IdEzI'],
+        ['src/config/ablation_study/full.yaml', '1U16TxUkvZlRwJNaJnq9nAUap-BhCVIha']
+    ]
+
+    for model_id in model_ids:
+        if model_id[0] != model_cfg:
+            continue
+        path = os.path.join(ckpt_dir, 'best.pkl')
+        if not os.path.exists(path):
+            logger.info(f"Downloading {model_id}")
+            os.makedirs(ckpt_dir, exist_ok=True)
+            gdown.download(f"https://drive.google.com/uc?id={model_id[1]}", path, False)
+
+
+def greet(img_path, pre_processing, weight_name, post_processing, visualization, mesh_format, mesh_resolution):
+    args.pre_processing = pre_processing
+    args.post_processing = post_processing
+    if weight_name == 'mp3d':
+        model = mp3d_model
+    elif weight_name == 'zind':
+        model = zind_model
+    else:
+        logger.error("unknown pre-trained weight name")
+        raise NotImplementedError
+
+    img_name = os.path.basename(img_path).split('.')[0]
+    img = np.array(Image.open(img_path).resize((1024, 512), Image.Resampling.BICUBIC))[..., :3]
+
+    vp_cache_path = 'src/demo/default_vp.txt'
+    if args.pre_processing:
+        vp_cache_path = os.path.join('src/output', f'{img_name}_vp.txt')
+        logger.info("pre-processing ...")
+        img, vp = preprocess(img, vp_cache_path=vp_cache_path)
+
+    img = (img / 255.0).astype(np.float32)
+    run_one_inference(img, model, args, img_name,
+                      logger=logger, show=False,
+                      show_depth='depth-normal-gradient' in visualization,
+                      show_floorplan='2d-floorplan' in visualization,
+                      mesh_format=mesh_format, mesh_resolution=int(mesh_resolution))
+
+    return [os.path.join(args.output_dir, f"{img_name}_pred.png"),
+            os.path.join(args.output_dir, f"{img_name}_3d{mesh_format}"),
+            os.path.join(args.output_dir, f"{img_name}_3d{mesh_format}"),
+            vp_cache_path,
+            os.path.join(args.output_dir, f"{img_name}_pred.json")]
+
+
+def get_model(args):
+    config = get_config(args)
+    down_ckpt(args.cfg, config.CKPT.DIR)
+    if ('cuda' in args.device or 'cuda' in config.TRAIN.DEVICE) and not torch.cuda.is_available():
+        logger.info(f'The {args.device} is not available, will use cpu ...')
+        config.defrost()
+        args.device = "cpu"
+        config.TRAIN.DEVICE = "cpu"
+        config.freeze()
+    model, _, _, _ = build_model(config, logger)
+    return model
+
+
+if __name__ == '__main__':
+    logger = get_logger()
+    args = Namespace(device='cuda', output_dir='src/output', visualize_3d=False, output_3d=True)
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    args.cfg = 'src/config/mp3d.yaml'
+    mp3d_model = get_model(args)
+
+    args.cfg = 'src/config/zind.yaml'
+    zind_model = get_model(args)
+
+    description = "This is a demo of the project " \
+                  "<a href='https://github.com/zhigangjiang/LGT-Net' target='_blank'>LGT-Net</a>. " \
+                  "It uses the Geometry-Aware Transformer Network to predict the 3D room layout of an RGB panorama."
+
+    demo = gr.Interface(fn=greet,
+                        inputs=[gr.Image(type='filepath', label='input rgb panorama', value='src/demo/pano_demo1.png'),
+                                gr.Checkbox(label='pre-processing', value=True),
+                                gr.Radio(['mp3d', 'zind'],
+                                         label='pre-trained weight',
+                                         value='mp3d'),
+                                gr.Radio(['manhattan', 'atalanta', 'original'],
+                                         label='post-processing method',
+                                         value='manhattan'),
+                                gr.CheckboxGroup(['depth-normal-gradient', '2d-floorplan'],
+                                                 label='2d-visualization',
+                                                 value=['depth-normal-gradient', '2d-floorplan']),
+                                gr.Radio(['.gltf', '.obj', '.glb'],
+                                         label='output format of 3d mesh',
+                                         value='.gltf'),
+                                gr.Radio(['128', '256', '512', '1024'],
+                                         label='output resolution of 3d mesh',
+                                         value='256'),
+                                ],
+                        outputs=[gr.Image(label='predicted result 2d-visualization', type='filepath'),
+                                 gr.Model3D(label='3d mesh reconstruction', clear_color=[1.0, 1.0, 1.0, 1.0]),
+                                 gr.File(label='3d mesh file'),
+                                 gr.File(label='vanishing point information'),
+                                 gr.File(label='layout json')],
+                        examples=[
+                            ['src/demo/pano_demo1.png', True, 'mp3d', 'manhattan', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
+                            ['src/demo/mp3d_demo1.png', False, 'mp3d', 'manhattan', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
+                            ['src/demo/mp3d_demo2.png', False, 'mp3d', 'manhattan', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
+                            ['src/demo/mp3d_demo3.png', False, 'mp3d', 'manhattan', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
+                            ['src/demo/zind_demo1.png', True, 'zind', 'manhattan', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
+                            ['src/demo/zind_demo2.png', False, 'zind', 'atalanta', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
+                            ['src/demo/zind_demo3.png', True, 'zind', 'manhattan', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
+                            ['src/demo/other_demo1.png', False, 'mp3d', 'manhattan', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
+                            ['src/demo/other_demo2.png', True, 'mp3d', 'manhattan', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
+                        ], title='LGT-Net', allow_flagging="never", cache_examples=False, description=description)
+
+    demo.launch(debug=True, enable_queue=False)
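app.py follows the standard gr.Interface pattern: input components map positionally to the arguments of fn, and fn's return values map to the output components. A minimal self-contained sketch of that pattern (the echo function and labels are placeholders, not part of the repo):

import gradio as gr

def echo(text):
    # Trivial stand-in for `greet`: one input component -> one argument,
    # one return value -> one output component.
    return text

demo = gr.Interface(fn=echo, inputs=gr.Textbox(label='input'), outputs=gr.Textbox(label='output'))
demo.launch()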
config/__init__.py
ADDED
@@ -0,0 +1,4 @@
+"""
+@Date: 2021/07/17
+@description:
+"""
config/defaults.py
ADDED
@@ -0,0 +1,289 @@
+"""
+@Date: 2021/07/17
+@description:
+"""
+import os
+import logging
+from yacs.config import CfgNode as CN
+
+_C = CN()
+_C.DEBUG = False
+_C.MODE = 'train'
+_C.VAL_NAME = 'val'
+_C.TAG = 'default'
+_C.COMMENT = 'add some comments to help you understand'
+_C.SHOW_BAR = True
+_C.SAVE_EVAL = False
+_C.MODEL = CN()
+_C.MODEL.NAME = 'model_name'
+_C.MODEL.SAVE_BEST = True
+_C.MODEL.SAVE_LAST = True
+_C.MODEL.ARGS = []
+_C.MODEL.FINE_TUNE = []
+
+# -----------------------------------------------------------------------------
+# Training settings
+# -----------------------------------------------------------------------------
+_C.TRAIN = CN()
+_C.TRAIN.SCRATCH = False
+_C.TRAIN.START_EPOCH = 0
+_C.TRAIN.EPOCHS = 300
+_C.TRAIN.DETERMINISTIC = False
+_C.TRAIN.SAVE_FREQ = 5
+
+_C.TRAIN.BASE_LR = 5e-4
+
+_C.TRAIN.WARMUP_EPOCHS = 20
+_C.TRAIN.WEIGHT_DECAY = 0
+_C.TRAIN.WARMUP_LR = 5e-7
+_C.TRAIN.MIN_LR = 5e-6
+# Clip gradient norm
+_C.TRAIN.CLIP_GRAD = 5.0
+# Auto resume from latest checkpoint
+_C.TRAIN.RESUME_LAST = True
+# Gradient accumulation steps
+# could be overwritten by command line argument
+_C.TRAIN.ACCUMULATION_STEPS = 0
+# Whether to use gradient checkpointing to save memory
+# could be overwritten by command line argument
+_C.TRAIN.USE_CHECKPOINT = False
+# 'cpu' or 'cuda:0, 1, 2, 3' or 'cuda'
+_C.TRAIN.DEVICE = 'cuda'
+
+# LR scheduler
+_C.TRAIN.LR_SCHEDULER = CN()
+_C.TRAIN.LR_SCHEDULER.NAME = ''
+_C.TRAIN.LR_SCHEDULER.ARGS = []
+
+
+# Optimizer
+_C.TRAIN.OPTIMIZER = CN()
+_C.TRAIN.OPTIMIZER.NAME = 'adam'
+# Optimizer epsilon
+_C.TRAIN.OPTIMIZER.EPS = 1e-8
+# Optimizer betas
+_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
+# SGD momentum
+_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
+
+# Criterion
+_C.TRAIN.CRITERION = CN()
+# Boundary loss (Horizon-Net)
+_C.TRAIN.CRITERION.BOUNDARY = CN()
+_C.TRAIN.CRITERION.BOUNDARY.NAME = 'boundary'
+_C.TRAIN.CRITERION.BOUNDARY.LOSS = 'BoundaryLoss'
+_C.TRAIN.CRITERION.BOUNDARY.WEIGHT = 0.0
+_C.TRAIN.CRITERION.BOUNDARY.WEIGHTS = []
+_C.TRAIN.CRITERION.BOUNDARY.NEED_ALL = True
+# Up and down depth loss (LED2-Net)
+_C.TRAIN.CRITERION.LEDDepth = CN()
+_C.TRAIN.CRITERION.LEDDepth.NAME = 'led_depth'
+_C.TRAIN.CRITERION.LEDDepth.LOSS = 'LEDLoss'
+_C.TRAIN.CRITERION.LEDDepth.WEIGHT = 0.0
+_C.TRAIN.CRITERION.LEDDepth.WEIGHTS = []
+_C.TRAIN.CRITERION.LEDDepth.NEED_ALL = True
+# Depth loss
+_C.TRAIN.CRITERION.DEPTH = CN()
+_C.TRAIN.CRITERION.DEPTH.NAME = 'depth'
+_C.TRAIN.CRITERION.DEPTH.LOSS = 'L1Loss'
+_C.TRAIN.CRITERION.DEPTH.WEIGHT = 0.0
+_C.TRAIN.CRITERION.DEPTH.WEIGHTS = []
+_C.TRAIN.CRITERION.DEPTH.NEED_ALL = False
+# Ratio (room height) loss
+_C.TRAIN.CRITERION.RATIO = CN()
+_C.TRAIN.CRITERION.RATIO.NAME = 'ratio'
+_C.TRAIN.CRITERION.RATIO.LOSS = 'L1Loss'
+_C.TRAIN.CRITERION.RATIO.WEIGHT = 0.0
+_C.TRAIN.CRITERION.RATIO.WEIGHTS = []
+_C.TRAIN.CRITERION.RATIO.NEED_ALL = False
+# Grad (normal) loss
+_C.TRAIN.CRITERION.GRAD = CN()
+_C.TRAIN.CRITERION.GRAD.NAME = 'grad'
+_C.TRAIN.CRITERION.GRAD.LOSS = 'GradLoss'
+_C.TRAIN.CRITERION.GRAD.WEIGHT = 0.0
+_C.TRAIN.CRITERION.GRAD.WEIGHTS = [1.0, 1.0]
+_C.TRAIN.CRITERION.GRAD.NEED_ALL = True
+# Object loss
+_C.TRAIN.CRITERION.OBJECT = CN()
+_C.TRAIN.CRITERION.OBJECT.NAME = 'object'
+_C.TRAIN.CRITERION.OBJECT.LOSS = 'ObjectLoss'
+_C.TRAIN.CRITERION.OBJECT.WEIGHT = 0.0
+_C.TRAIN.CRITERION.OBJECT.WEIGHTS = []
+_C.TRAIN.CRITERION.OBJECT.NEED_ALL = True
+# Heatmap loss
+_C.TRAIN.CRITERION.CHM = CN()
+_C.TRAIN.CRITERION.CHM.NAME = 'corner_heat_map'
+_C.TRAIN.CRITERION.CHM.LOSS = 'HeatmapLoss'
+_C.TRAIN.CRITERION.CHM.WEIGHT = 0.0
+_C.TRAIN.CRITERION.CHM.WEIGHTS = []
+_C.TRAIN.CRITERION.CHM.NEED_ALL = False
+
+_C.TRAIN.VIS_MERGE = True
+_C.TRAIN.VIS_WEIGHT = 1024
+# -----------------------------------------------------------------------------
+# Output settings
+# -----------------------------------------------------------------------------
+_C.CKPT = CN()
+_C.CKPT.PYTORCH = './'
+_C.CKPT.ROOT = "./checkpoints"
+_C.CKPT.DIR = os.path.join(_C.CKPT.ROOT, _C.MODEL.NAME, _C.TAG)
+_C.CKPT.RESULT_DIR = os.path.join(_C.CKPT.DIR, 'results', _C.MODE)
+
+_C.LOGGER = CN()
+_C.LOGGER.DIR = os.path.join(_C.CKPT.DIR, "logs")
+_C.LOGGER.LEVEL = logging.DEBUG
+
+# -----------------------------------------------------------------------------
+# Misc
+# -----------------------------------------------------------------------------
+# Mixed precision opt level; if O0, no amp is used ('O0', 'O1', 'O2'). Please confirm your device supports FP16 (half precision).
+# overwritten by command line argument
+_C.AMP_OPT_LEVEL = 'O1'
+# Path to output folder, overwritten by command line argument
+_C.OUTPUT = ''
+# Tag of experiment, overwritten by command line argument
+_C.TAG = 'default'
+# Frequency to save checkpoint
+_C.SAVE_FREQ = 1
+# Frequency to log info
+_C.PRINT_FREQ = 10
+# Fixed random seed
+_C.SEED = 0
+# Perform evaluation only, overwritten by command line argument
+_C.EVAL_MODE = False
+# Test throughput only, overwritten by command line argument
+_C.THROUGHPUT_MODE = False
+
+# -----------------------------------------------------------------------------
+# FIX
+# -----------------------------------------------------------------------------
+_C.LOCAL_RANK = 0
+_C.WORLD_SIZE = 0
+
+# -----------------------------------------------------------------------------
+# Data settings
+# -----------------------------------------------------------------------------
+_C.DATA = CN()
+# Sub dataset of pano_s2d3d
+_C.DATA.SUBSET = None
+# Dataset name
+_C.DATA.DATASET = 'mp3d'
+# Path to dataset, could be overwritten by command line argument
+_C.DATA.DIR = ''
+# Max wall number
+_C.DATA.WALL_NUM = 0  # all
+# Panorama image size
+_C.DATA.SHAPE = [512, 1024]
+# Real camera height
+_C.DATA.CAMERA_HEIGHT = 1.6
+# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
+_C.DATA.PIN_MEMORY = True
+# Debug use, fast test of model performance
+_C.DATA.FOR_TEST_INDEX = None
+
+# Batch size for a single GPU, could be overwritten by command line argument
+_C.DATA.BATCH_SIZE = 8
+# Number of data loading threads
+_C.DATA.NUM_WORKERS = 8
+
+# Training augmentation
+_C.DATA.AUG = CN()
+# Flip the panorama horizontally
+_C.DATA.AUG.FLIP = True
+# Pano Stretch Data Augmentation by HorizonNet
+_C.DATA.AUG.STRETCH = True
+# Rotate the panorama horizontally
+_C.DATA.AUG.ROTATE = True
+# Gamma adjusting
+_C.DATA.AUG.GAMMA = True
+
+_C.DATA.KEYS = []
+
+
+_C.EVAL = CN()
+_C.EVAL.POST_PROCESSING = None
+_C.EVAL.NEED_CPE = False
+_C.EVAL.NEED_F1 = False
+_C.EVAL.NEED_RMSE = False
+_C.EVAL.FORCE_CUBE = False
+
+
+def merge_from_file(cfg_path):
+    config = _C.clone()
+    config.merge_from_file(cfg_path)
+    return config
+
+
+def get_config(args=None):
+    config = _C.clone()
+    if args:
+        if 'cfg' in args and args.cfg:
+            config.merge_from_file(args.cfg)
+
+        if 'mode' in args and args.mode:
+            config.MODE = args.mode
+
+        if 'debug' in args and args.debug:
+            config.DEBUG = args.debug
+
+        if 'hidden_bar' in args and args.hidden_bar:
+            config.SHOW_BAR = False
+
+        if 'bs' in args and args.bs:
+            config.DATA.BATCH_SIZE = args.bs
+
+        if 'save_eval' in args and args.save_eval:
+            config.SAVE_EVAL = True
+
+        if 'val_name' in args and args.val_name:
+            config.VAL_NAME = args.val_name
+
+        if 'post_processing' in args and args.post_processing:
+            config.EVAL.POST_PROCESSING = args.post_processing
+
+        if 'need_cpe' in args and args.need_cpe:
+            config.EVAL.NEED_CPE = args.need_cpe
+
+        if 'need_f1' in args and args.need_f1:
+            config.EVAL.NEED_F1 = args.need_f1
+
+        if 'need_rmse' in args and args.need_rmse:
+            config.EVAL.NEED_RMSE = args.need_rmse
+
+        if 'force_cube' in args and args.force_cube:
+            config.EVAL.FORCE_CUBE = args.force_cube
+
+        if 'wall_num' in args and args.wall_num:
+            config.DATA.WALL_NUM = args.wall_num
+
+    args = config.MODEL.ARGS[0]
+    config.CKPT.DIR = os.path.join(config.CKPT.ROOT, f"{args['decoder_name']}_{args['output_name']}_Net",
+                                   config.TAG, 'debug' if config.DEBUG else '')
+    config.CKPT.RESULT_DIR = os.path.join(config.CKPT.DIR, 'results', config.MODE)
+    config.LOGGER.DIR = os.path.join(config.CKPT.DIR, "logs")
+
+    core_number = os.popen("grep 'physical id' /proc/cpuinfo | sort | uniq | wc -l").read()
+
+    try:
+        config.DATA.NUM_WORKERS = int(core_number) * 2
+        print(f"System core number: {config.DATA.NUM_WORKERS}")
+    except ValueError:
+        print(f"Can't get system core number, will use config: {config.DATA.NUM_WORKERS}")
+    config.freeze()
+    return config
+
+
+def get_rank_config(cfg, local_rank, world_size):
+    local_rank = 0 if local_rank is None else local_rank
+    config = cfg.clone()
+    config.defrost()
+    if world_size > 1:
+        ids = config.TRAIN.DEVICE.split(':')[-1].split(',') if ':' in config.TRAIN.DEVICE else range(world_size)
+        config.TRAIN.DEVICE = f'cuda:{ids[local_rank]}'
+
+    config.LOCAL_RANK = local_rank
+    config.WORLD_SIZE = world_size
+    config.SEED = config.SEED + local_rank
+
+    config.freeze()
+    return config
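The module above is built on yacs' clone/merge/freeze pattern: a frozen tree of defaults that each experiment clones and overrides from a YAML file. A minimal self-contained sketch of that pattern (the file name 'my_experiment.yaml' is hypothetical):

from yacs.config import CfgNode as CN

_C = CN()
_C.TRAIN = CN()
_C.TRAIN.BASE_LR = 5e-4  # default value

def load(cfg_path=None):
    config = _C.clone()                   # never mutate the shared defaults
    if cfg_path:
        config.merge_from_file(cfg_path)  # YAML keys override the defaults
    config.freeze()                       # further mutation requires defrost()
    return config

config = load()  # or load('my_experiment.yaml')
print(config.TRAIN.BASE_LR)

Freezing is why get_rank_config above must call defrost() before assigning per-rank fields.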
convert_ckpt.py
ADDED
@@ -0,0 +1,61 @@
+"""
+@date: 2021/11/22
+@description: Convert a training ckpt into an inference ckpt
+"""
+import argparse
+import os
+
+import torch
+
+from config.defaults import merge_from_file
+
+
+def parse_option():
+    parser = argparse.ArgumentParser(description='Convert a training ckpt into an inference ckpt')
+    parser.add_argument('--cfg',
+                        type=str,
+                        required=True,
+                        metavar='FILE',
+                        help='path of config file')
+
+    parser.add_argument('--output_path',
+                        type=str,
+                        help='path of output ckpt')
+
+    args = parser.parse_args()
+
+    print("arguments:")
+    for arg in vars(args):
+        print(arg, ":", getattr(args, arg))
+    print("-" * 50)
+    return args
+
+
+def convert_ckpt():
+    args = parse_option()
+    config = merge_from_file(args.cfg)
+    ck_dir = os.path.join("checkpoints", f"{config.MODEL.ARGS[0]['decoder_name']}_{config.MODEL.ARGS[0]['output_name']}_Net",
+                          config.TAG)
+    print(f"Processing {ck_dir}")
+    model_paths = [name for name in os.listdir(ck_dir) if '_best_' in name]
+    if len(model_paths) == 0:
+        print("Could not find best ckpt")
+        return
+    model_path = os.path.join(ck_dir, model_paths[0])
+    print(f"Loading {model_path}")
+    checkpoint = torch.load(model_path, map_location=torch.device('cuda:0'))
+    net = checkpoint['net']
+    if args.output_path is None:
+        output_path = os.path.join(ck_dir, 'best.pkl')
+    else:
+        output_path = args.output_path
+    print(f"Save on: {output_path}")
+    os.makedirs(os.path.dirname(output_path), exist_ok=True)
+    torch.save(net, output_path)
+
+
+if __name__ == '__main__':
+    convert_ckpt()
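A minimal sketch of consuming the converted checkpoint; the path below is hypothetical (it just follows the "{decoder_name}_{output_name}_Net/{tag}" layout computed above), and best.pkl contains whatever checkpoint['net'] held, without the optimizer state:

import torch

ckpt_path = 'checkpoints/my_decoder_my_output_Net/default/best.pkl'  # hypothetical path
net = torch.load(ckpt_path, map_location='cpu')  # inference ckpt: just the net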
dataset/__init__.py
ADDED
File without changes
dataset/build.py
ADDED
@@ -0,0 +1,115 @@
+"""
+@Date: 2021/07/18
+@description:
+"""
+import numpy as np
+import torch.utils.data
+from dataset.mp3d_dataset import MP3DDataset
+from dataset.pano_s2d3d_dataset import PanoS2D3DDataset
+from dataset.pano_s2d3d_mix_dataset import PanoS2D3DMixDataset
+from dataset.zind_dataset import ZindDataset
+
+
+def build_loader(config, logger):
+    name = config.DATA.DATASET
+    ddp = config.WORLD_SIZE > 1
+    train_dataset = None
+    train_data_loader = None
+    if config.MODE == 'train':
+        train_dataset = build_dataset(mode='train', config=config, logger=logger)
+
+    val_dataset = build_dataset(mode=config.VAL_NAME if config.MODE != 'test' else 'test', config=config, logger=logger)
+
+    train_sampler = None
+    val_sampler = None
+    if ddp:
+        if train_dataset:
+            train_sampler = torch.utils.data.DistributedSampler(train_dataset, shuffle=True)
+        val_sampler = torch.utils.data.DistributedSampler(val_dataset, shuffle=False)
+
+    batch_size = config.DATA.BATCH_SIZE
+    num_workers = 0 if config.DEBUG else config.DATA.NUM_WORKERS
+    pin_memory = config.DATA.PIN_MEMORY
+    if train_dataset:
+        logger.info(f'Train data loader batch size: {batch_size}')
+        train_data_loader = torch.utils.data.DataLoader(
+            train_dataset, sampler=train_sampler,
+            batch_size=batch_size,
+            shuffle=train_sampler is None,  # sampler and shuffle are mutually exclusive in DataLoader
+            num_workers=num_workers,
+            pin_memory=pin_memory,
+            drop_last=True,
+        )
+    # Use the largest divisor of len(val_dataset) that is <= batch_size, so every val sample fits evenly
+    batch_size = batch_size - (len(val_dataset) % np.arange(batch_size, 0, -1)).tolist().index(0)
+    logger.info(f'Val data loader batch size: {batch_size}')
+    val_data_loader = torch.utils.data.DataLoader(
+        val_dataset, sampler=val_sampler,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=num_workers,
+        pin_memory=pin_memory,
+        drop_last=False
+    )
+    logger.info(f'Build data loader: num_workers:{num_workers} pin_memory:{pin_memory}')
+    return train_data_loader, val_data_loader
+
+
+def build_dataset(mode, config, logger):
+    name = config.DATA.DATASET
+    if name == 'mp3d':
+        dataset = MP3DDataset(
+            root_dir=config.DATA.DIR,
+            mode=mode,
+            shape=config.DATA.SHAPE,
+            max_wall_num=config.DATA.WALL_NUM,
+            aug=config.DATA.AUG if mode == 'train' else None,
+            camera_height=config.DATA.CAMERA_HEIGHT,
+            logger=logger,
+            for_test_index=config.DATA.FOR_TEST_INDEX,
+            keys=config.DATA.KEYS
+        )
+    elif name == 'pano_s2d3d':
+        dataset = PanoS2D3DDataset(
+            root_dir=config.DATA.DIR,
+            mode=mode,
+            shape=config.DATA.SHAPE,
+            max_wall_num=config.DATA.WALL_NUM,
+            aug=config.DATA.AUG if mode == 'train' else None,
+            camera_height=config.DATA.CAMERA_HEIGHT,
+            logger=logger,
+            for_test_index=config.DATA.FOR_TEST_INDEX,
+            subset=config.DATA.SUBSET,
+            keys=config.DATA.KEYS
+        )
+    elif name == 'pano_s2d3d_mix':
+        dataset = PanoS2D3DMixDataset(
+            root_dir=config.DATA.DIR,
+            mode=mode,
+            shape=config.DATA.SHAPE,
+            max_wall_num=config.DATA.WALL_NUM,
+            aug=config.DATA.AUG if mode == 'train' else None,
+            camera_height=config.DATA.CAMERA_HEIGHT,
+            logger=logger,
+            for_test_index=config.DATA.FOR_TEST_INDEX,
+            subset=config.DATA.SUBSET,
+            keys=config.DATA.KEYS
+        )
+    elif name == 'zind':
+        dataset = ZindDataset(
+            root_dir=config.DATA.DIR,
+            mode=mode,
+            shape=config.DATA.SHAPE,
+            max_wall_num=config.DATA.WALL_NUM,
+            aug=config.DATA.AUG if mode == 'train' else None,
+            camera_height=config.DATA.CAMERA_HEIGHT,
+            logger=logger,
+            for_test_index=config.DATA.FOR_TEST_INDEX,
+            is_simple=True,
+            is_ceiling_flat=False,
+            keys=config.DATA.KEYS,
+            vp_align=config.EVAL.POST_PROCESSING is not None and 'manhattan' in config.EVAL.POST_PROCESSING
+        )
+    else:
+        raise NotImplementedError(f"Unknown dataset: {name}")
+
+    return dataset
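The one-line val batch-size adjustment above picks the largest divisor of len(val_dataset) that is at most the configured batch size, so validation never drops a sample. A quick standalone illustration (values are made up):

import numpy as np

batch_size, n_val = 8, 36
# remainders of n_val modulo batch_size, batch_size - 1, ..., 1
remainders = (n_val % np.arange(batch_size, 0, -1)).tolist()
adjusted = batch_size - remainders.index(0)
print(adjusted)  # 6: the largest b <= 8 with 36 % b == 0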
dataset/communal/__init__.py
ADDED
@@ -0,0 +1,4 @@
+"""
+@Date: 2021/09/22
+@description:
+"""
dataset/communal/base_dataset.py
ADDED
@@ -0,0 +1,127 @@
+"""
+@Date: 2021/07/26
+@description:
+"""
+import numpy as np
+import torch
+
+from utils.boundary import corners2boundary, visibility_corners, get_heat_map
+from utils.conversion import xyz2depth, uv2xyz, uv2pixel
+from dataset.communal.data_augmentation import PanoDataAugmentation
+
+
+class BaseDataset(torch.utils.data.Dataset):
+    def __init__(self, mode, shape=None, max_wall_num=999, aug=None, camera_height=1.6, patch_num=256, keys=None):
+        if keys is None or len(keys) == 0:
+            keys = ['image', 'depth', 'ratio', 'id', 'corners']
+        if shape is None:
+            shape = [512, 1024]
+
+        assert mode == 'train' or mode == 'val' or mode == 'test' or mode is None, 'unknown mode!'
+        self.mode = mode
+        self.keys = keys
+        self.shape = shape
+        self.pano_aug = None if aug is None or mode == 'val' else PanoDataAugmentation(aug)
+        self.camera_height = camera_height
+        self.max_wall_num = max_wall_num
+        self.patch_num = patch_num
+        self.data = None
+
+    def __len__(self):
+        return len(self.data)
+
+    @staticmethod
+    def get_depth(corners, plan_y=1, length=256, visible=True):
+        visible_floor_boundary = corners2boundary(corners, length=length, visible=visible)
+        # The horizon-depth relative to plan_y
+        visible_depth = xyz2depth(uv2xyz(visible_floor_boundary, plan_y), plan_y)
+        return visible_depth
+
+    def process_data(self, label, image, patch_num):
+        """
+        :param label: dict with 'corners', 'ratio', 'id' (and optionally 'objects')
+        :param image: panorama image, shape [h, w, c]
+        :param patch_num: number of horizontal samples for depth/heat-map outputs
+        :return: dict with the requested keys
+        """
+        corners = label['corners']
+        if self.pano_aug is not None:
+            corners, image = self.pano_aug.execute_aug(corners, image if 'image' in self.keys else None)
+        eps = 1e-3
+        corners[:, 1] = np.clip(corners[:, 1], 0.5 + eps, 1 - eps)
+
+        output = {}
+        if 'image' in self.keys:
+            image = image.transpose(2, 0, 1)
+            output['image'] = image
+
+        visible_corners = None
+        if 'corner_class' in self.keys or 'depth' in self.keys:
+            visible_corners = visibility_corners(corners)
+
+        if 'depth' in self.keys:
+            depth = self.get_depth(visible_corners, length=patch_num, visible=False)
+            assert len(depth) == patch_num, f"{label['id']}, {len(depth)}, {self.pano_aug.parameters}, {corners}"
+            output['depth'] = depth
+
+        if 'ratio' in self.keys:
+            # Why use ratio? Because when floor_height = y_plan = 1, we only need to predict ceil_height (ratio).
+            output['ratio'] = label['ratio']
+
+        if 'id' in self.keys:
+            output['id'] = label['id']
+
+        if 'corners' in self.keys:
+            # all corners for evaluating Full_IoU
+            assert len(label['corners']) <= 32, f"len(label['corners']): {len(label['corners'])}"
+            output['corners'] = np.zeros((32, 2), dtype=np.float32)
+            output['corners'][:len(label['corners'])] = label['corners']
+
+        if 'corner_heat_map' in self.keys:
+            output['corner_heat_map'] = get_heat_map(visible_corners[..., 0])
+
+        if 'object' in self.keys and 'objects' in label:
+            output['object_heat_map'] = np.zeros((3, patch_num), dtype=np.float32)
+            output['object_size'] = np.zeros((3, patch_num), dtype=np.float32)  # width, height, bottom_height
+            for i, type in enumerate(label['objects']):
+                if len(label['objects'][type]) == 0:
+                    continue
+
+                u_s = []
+                for obj in label['objects'][type]:
+                    center_u = obj['center_u']
+                    u_s.append(center_u)
+                    center_pixel_u = uv2pixel(np.array([center_u]), w=patch_num, axis=0)[0]
+                    output['object_size'][0, center_pixel_u] = obj['width_u']
+                    output['object_size'][1, center_pixel_u] = obj['height_v']
+                    output['object_size'][2, center_pixel_u] = obj['boundary_v']
+                output['object_heat_map'][i] = get_heat_map(np.array(u_s))
+
+        return output
+
+
+if __name__ == '__main__':
+    from dataset.communal.read import read_image, read_label
+    from visualization.boundary import draw_boundaries
+    from utils.boundary import depth2boundaries
+    from tqdm import trange
+
+    # np.random.seed(0)
+    dataset = BaseDataset(mode=None)  # mode is a required argument; None skips the train/val/test logic
+    dataset.pano_aug = PanoDataAugmentation(aug={
+        'STRETCH': True,
+        'ROTATE': True,
+        'FLIP': True,
+    })
+    # pano_img = read_image("../src/demo.png")
+    # label = read_label("../src/demo.json")
+    pano_img_path = "../../src/dataset/mp3d/image/yqstnuAEVhm_6589ad7a5a0444b59adbf501c0f0fe53.png"
+    label_path = "../../src/dataset/mp3d/label/yqstnuAEVhm_6589ad7a5a0444b59adbf501c0f0fe53.json"
+    pano_img = read_image(pano_img_path)
+    label = read_label(label_path)
+
+    # batch test
+    for i in trange(1):
+        output = dataset.process_data(label, pano_img, 256)
+        boundary_list = depth2boundaries(output['ratio'], output['depth'], step=None)
+        draw_boundaries(output['image'].transpose(1, 2, 0), boundary_list=boundary_list, show=True)
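The fixed (32, 2) corner buffer above exists so rooms with different wall counts can be collated into one batch tensor. A quick standalone illustration (the uv values are made up):

import numpy as np

corners = np.array([[0.10, 0.70], [0.45, 0.65], [0.80, 0.72]], dtype=np.float32)  # 3 made-up uv corners
padded = np.zeros((32, 2), dtype=np.float32)
padded[:len(corners)] = corners  # rows beyond the real corners stay (0, 0)

batch = np.stack([padded, padded])  # fixed shape makes stacking trivial
print(batch.shape)  # (2, 32, 2)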
dataset/communal/data_augmentation.py
ADDED
@@ -0,0 +1,279 @@
+"""
+@Date: 2021/07/27
+@description:
+"""
+import numpy as np
+import cv2
+import functools
+
+from utils.conversion import pixel2lonlat, lonlat2pixel, uv2lonlat, lonlat2uv, pixel2uv
+
+
+@functools.lru_cache()
+def prepare_stretch(w, h):
+    lon = pixel2lonlat(np.array(range(w)), w=w, axis=0)
+    lat = pixel2lonlat(np.array(range(h)), h=h, axis=1)
+    sin_lon = np.sin(lon)
+    cos_lon = np.cos(lon)
+    tan_lat = np.tan(lat)
+    return sin_lon, cos_lon, tan_lat
+
+
+def pano_stretch_image(pano_img, kx, ky, kz):
+    """
+    Note that this is the inverse mapping, which refers to Equation 3 in the HorizonNet paper (the coordinate
+    system in the paper is different from here; xz needs to be swapped)
+    :param pano_img: a panorama image, shape must be [h,w,c]
+    :param kx: stretching along the left-right direction
+    :param ky: stretching along the up-down direction
+    :param kz: stretching along the front-back direction
+    :return:
+    """
+    w = pano_img.shape[1]
+    h = pano_img.shape[0]
+
+    sin_lon, cos_lon, tan_lat = prepare_stretch(w, h)
+
+    n_lon = np.arctan2(sin_lon * kz / kx, cos_lon)
+    n_lat = np.arctan(tan_lat[..., None] * np.sin(n_lon) / sin_lon * kx / ky)
+    n_pu = lonlat2pixel(n_lon, w=w, axis=0, need_round=False)
+    n_pv = lonlat2pixel(n_lat, h=h, axis=1, need_round=False)
+
+    pixel_map = np.empty((h, w, 2), dtype=np.float32)
+    pixel_map[..., 0] = n_pu
+    pixel_map[..., 1] = n_pv
+    map1 = pixel_map[..., 0]
+    map2 = pixel_map[..., 1]
+    # using wrap mode because the panorama is continuous across its left and right edges
+    new_img = cv2.remap(pano_img, map1, map2, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_WRAP)
+    return new_img
+
+
+def pano_stretch_conner(corners, kx, ky, kz):
+    """
+    :param corners:
+    :param kx: stretching along the left-right direction
+    :param ky: stretching along the up-down direction
+    :param kz: stretching along the front-back direction
+    :return:
+    """
+
+    lonlat = uv2lonlat(corners)
+    sin_lon = np.sin(lonlat[..., 0:1])
+    cos_lon = np.cos(lonlat[..., 0:1])
+    tan_lat = np.tan(lonlat[..., 1:2])
+
+    n_lon = np.arctan2(sin_lon * kx / kz, cos_lon)
+
+    a = np.bitwise_or(corners[..., 0] == 0.5, corners[..., 0] == 1)
+    b = np.bitwise_not(a)
+    w = np.zeros_like(n_lon)
+    w[b] = np.sin(n_lon[b]) / sin_lon[b]
+    w[a] = kx / kz
+
+    n_lat = np.arctan(tan_lat * w / kx * ky)
+
+    lst = [n_lon, n_lat]
+    lonlat = np.concatenate(lst, axis=-1)
+    new_corners = lonlat2uv(lonlat)
+    return new_corners
+
+
+def pano_stretch(pano_img, corners, kx, ky, kz):
+    """
+    :param pano_img: a panorama image, shape must be [h,w,c]
+    :param corners:
+    :param kx: stretching along the left-right direction
+    :param ky: stretching along the up-down direction
+    :param kz: stretching along the front-back direction
+    :return:
+    """
+    new_img = pano_stretch_image(pano_img, kx, ky, kz)
+    new_corners = pano_stretch_conner(corners, kx, ky, kz)
+    return new_img, new_corners
+
+
+class PanoDataAugmentation:
+    def __init__(self, aug):
+        self.aug = aug
+        self.parameters = {}
+
+    def need_aug(self, name):
+        return name in self.aug and self.aug[name]
+
+    def execute_space_aug(self, corners, image):
+        if image is None:
+            return corners, image  # nothing to augment spatially without an image
+
+        if self.aug is None:
+            return corners, image
+        w = image.shape[1]
+        h = image.shape[0]
+
+        if self.need_aug('STRETCH'):
+            kx = np.random.uniform(1, 2)
+            kx = 1 / kx if np.random.randint(2) == 0 else kx
+            # we found that the ky transform may cause IoU to drop (HorizonNet also only applies x and z transforms)
+            # ky = np.random.uniform(1, 2)
+            # ky = 1 / ky if np.random.randint(2) == 0 else ky
+            ky = 1
+            kz = np.random.uniform(1, 2)
+            kz = 1 / kz if np.random.randint(2) == 0 else kz
+            image, corners = pano_stretch(image, corners, kx, ky, kz)
+            self.parameters['STRETCH'] = {'kx': kx, 'ky': ky, 'kz': kz}
+        else:
+            self.parameters['STRETCH'] = None
+
+        if self.need_aug('ROTATE'):
+            d_pu = np.random.randint(w)
+            image = np.roll(image, d_pu, axis=1)
+            corners[..., 0] = (corners[..., 0] + pixel2uv(np.array([d_pu]), w, h)) % pixel2uv(np.array([w]), w, h)
+            self.parameters['ROTATE'] = d_pu
+        else:
+            self.parameters['ROTATE'] = None
+
+        if self.need_aug('FLIP') and np.random.randint(2) == 0:
+            image = np.flip(image, axis=1).copy()
+            corners[..., 0] = pixel2uv(np.array([w]), w, h) - corners[..., 0]
+            corners = corners[::-1]
+            self.parameters['FLIP'] = True
+        else:
+            self.parameters['FLIP'] = None
+
+        return corners, image
+
+    def execute_visual_aug(self, image):
+        if self.need_aug('GAMMA'):
+            p = np.random.uniform(1, 2)
+            if np.random.randint(2) == 0:
+                p = 1 / p
+            image = image ** p
+            self.parameters['GAMMA'] = p
+        else:
+            self.parameters['GAMMA'] = None
+
+        # The following visual augmentation methods are only implemented but not tested
+        if self.need_aug('HUE') or self.need_aug('SATURATION'):
+            image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
+
+            if self.need_aug('HUE') and np.random.randint(2) == 0:
+                p = np.random.uniform(-0.1, 0.1)
+                image[..., 0] = np.mod(image[..., 0] + p * 180, 180)
+                self.parameters['HUE'] = p
+            else:
+                self.parameters['HUE'] = None
+
+            if self.need_aug('SATURATION') and np.random.randint(2) == 0:
+                p = np.random.uniform(0.5, 1.5)
+                image[..., 1] = np.clip(image[..., 1] * p, 0, 1)
+                self.parameters['SATURATION'] = p
+            else:
+                self.parameters['SATURATION'] = None
+
+            image = cv2.cvtColor(image, cv2.COLOR_HSV2RGB)
+
+        if self.need_aug('CONTRAST') and np.random.randint(2) == 0:
+            p = np.random.uniform(0.9, 1.1)
+            mean = image.mean(axis=0).mean(axis=0)
+            image = (image - mean) * p + mean
+            image = np.clip(image, 0, 1)
+            self.parameters['CONTRAST'] = p
+        else:
+            self.parameters['CONTRAST'] = None
+
+        return image
+
+    def execute_aug(self, corners, image):
+        corners, image = self.execute_space_aug(corners, image)
+        if image is not None:
+            image = self.execute_visual_aug(image)
+        return corners, image
+
+
+if __name__ == '__main__1':
+    from tqdm import trange
+    from visualization.floorplan import draw_floorplan
+    from dataset.communal.read import read_image, read_label
+    from utils.time_watch import TimeWatch
+    from utils.conversion import uv2xyz
+    from utils.boundary import corners2boundary
+
+    np.random.seed(123)
+    pano_img_path = "../../src/dataset/mp3d/image/TbHJrupSAjP_f320ae084f3a447da3e8ab11dd5f9320.png"
+    label_path = "../../src/dataset/mp3d/label/TbHJrupSAjP_f320ae084f3a447da3e8ab11dd5f9320.json"
+    pano_img = read_image(pano_img_path)
+    label = read_label(label_path)
+
+    corners = label['corners']
+    ratio = label['ratio']
+
+    pano_aug = PanoDataAugmentation(aug={
+        'STRETCH': True,
+        'ROTATE': True,
+        'FLIP': True,
+        'GAMMA': True,
+        # 'HUE': True,
+        # 'SATURATION': True,
+        # 'CONTRAST': True
+    })
+
+    # draw_floorplan(corners, show=True, marker_color=0.5, center_color=0.8, plan_y=1.6, show_radius=8)
+    # draw_boundaries(pano_img, corners_list=[corners], show=True, length=1024, ratio=ratio)
+
+    w = TimeWatch("test")
+    for i in trange(50000):
+        new_corners, new_pano_img = pano_aug.execute_aug(corners.copy(), pano_img.copy())
+        # draw_floorplan(uv2xyz(new_corners, plan_y=1.6)[..., ::2], show=True, marker_color=0.5, center_color=0.8,
+        #                show_radius=8)
+        # draw_boundaries(new_pano_img, corners_list=[new_corners], show=True, length=1024, ratio=ratio)
+
+
+if __name__ == '__main__':
+    from utils.boundary import corners2boundary
+    from visualization.floorplan import draw_floorplan
+    from utils.boundary import visibility_corners
+
+    corners = np.array([[0.7664539, 0.7416811],
+                        [0.06641078, 0.6521386],
+                        [0.30997428, 0.57855356],
+                        [0.383300784, 0.58726823],
+                        [0.383300775, 0.8005296],
+                        [0.5062902, 0.74822706]])
+    corners = visibility_corners(corners)
+    print(corners)
+    # draw_floorplan(uv2xyz(corners, plan_y=1.6)[..., ::2], show=True, marker_color=0.5, center_color=0.8,
+    #                show_radius=8)
+    visible_floor_boundary = corners2boundary(corners, length=256, visible=True)
+    # visible_depth = xyz2depth(uv2xyz(visible_floor_boundary, 1), 1)
+    print(len(visible_floor_boundary))
+
+
+if __name__ == '__main__0':
+    from visualization.floorplan import draw_floorplan
+
+    from dataset.communal.read import read_image, read_label
+    from utils.time_watch import TimeWatch
+    from utils.conversion import uv2xyz
+
+    # np.random.seed(1234)
+    pano_img_path = "../../src/dataset/mp3d/image/VVfe2KiqLaN_35b41dcbfcf84f96878f6ca28c70e5af.png"
+    label_path = "../../src/dataset/mp3d/label/VVfe2KiqLaN_35b41dcbfcf84f96878f6ca28c70e5af.json"
+    pano_img = read_image(pano_img_path)
+    label = read_label(label_path)
+
+    corners = label['corners']
+    ratio = label['ratio']
+
+    # draw_floorplan(corners, show=True, marker_color=0.5, center_color=0.8, plan_y=1.6, show_radius=8)
+
+    w = TimeWatch()
+    for i in range(5):
+        kx = np.random.uniform(1, 2)
+        kx = 1 / kx if np.random.randint(2) == 0 else kx
+        ky = np.random.uniform(1, 2)
+        ky = 1 / ky if np.random.randint(2) == 0 else ky
+        kz = np.random.uniform(1, 2)
+        kz = 1 / kz if np.random.randint(2) == 0 else kz
+        new_corners = pano_stretch_conner(corners.copy(), kx, ky, kz)
+        draw_floorplan(uv2xyz(new_corners, plan_y=1.6)[..., ::2], show=True, marker_color=0.5, center_color=0.8,
+                       show_radius=8)
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
@Date: 2021/07/28
|
3 |
+
@description:
|
4 |
+
"""
|
5 |
+
import os
|
6 |
+
import numpy as np
|
7 |
+
import cv2
|
8 |
+
import json
|
9 |
+
from PIL import Image
|
10 |
+
from utils.conversion import xyz2uv, pixel2uv
|
11 |
+
from utils.height import calc_ceil_ratio
|
12 |
+
|
13 |
+
|
14 |
+
def read_image(image_path, shape=None):
|
15 |
+
if shape is None:
|
16 |
+
shape = [512, 1024]
|
17 |
+
img = np.array(Image.open(image_path)).astype(np.float32) / 255
|
18 |
+
if img.shape[0] != shape[0] or img.shape[1] != shape[1]:
|
19 |
+
img = cv2.resize(img, dsize=tuple(shape[::-1]), interpolation=cv2.INTER_AREA)
|
20 |
+
|
21 |
+
return np.array(img)
|
22 |
+
|
23 |
+
|
24 |
+
def read_label(label_path, data_type='MP3D'):
|
25 |
+
|
26 |
+
if data_type == 'MP3D':
|
27 |
+
with open(label_path, 'r') as f:
|
28 |
+
label = json.load(f)
|
29 |
+
point_idx = [one['pointsIdx'][0] for one in label['layoutWalls']['walls']]
|
30 |
+
camera_height = label['cameraHeight']
|
31 |
+
room_height = label['layoutHeight']
|
32 |
+
camera_ceiling_height = room_height - camera_height
|
33 |
+
ratio = camera_ceiling_height / camera_height
|
34 |
+
|
35 |
+
xyz = [one['xyz'] for one in label['layoutPoints']['points']]
|
36 |
+
assert len(xyz) == len(point_idx), "len(xyz) != len(point_idx)"
|
37 |
+
xyz = [xyz[i] for i in point_idx]
|
38 |
+
xyz = np.asarray(xyz, dtype=np.float32)
|
39 |
+
xyz[:, 2] *= -1
|
40 |
+
xyz[:, 1] = camera_height
|
41 |
+
corners = xyz2uv(xyz)
|
42 |
+
elif data_type == 'Pano_S2D3D':
|
43 |
+
with open(label_path, 'r') as f:
|
44 |
+
lines = [line for line in f.readlines() if
|
45 |
+
len([c for c in line.split(' ') if c[0].isnumeric()]) > 1]
|
46 |
+
|
47 |
+
corners_list = np.array([line.strip().split() for line in lines], np.float32)
|
48 |
+
uv_list = pixel2uv(corners_list)
|
49 |
+
ceil_uv = uv_list[::2]
|
50 |
+
floor_uv = uv_list[1::2]
|
51 |
+
ratio = calc_ceil_ratio([ceil_uv, floor_uv], mode='mean')
|
52 |
+
corners = floor_uv
|
53 |
+
else:
|
54 |
+
return None
|
55 |
+
|
56 |
+
output = {
|
57 |
+
'ratio': np.array([ratio], dtype=np.float32),
|
58 |
+
'corners': corners,
|
59 |
+
        'id': os.path.basename(label_path).split('.')[0]
    }
    return output


def move_not_simple_image(data_dir, simple_panos):
    import shutil
    for house_index in os.listdir(data_dir):
        house_path = os.path.join(data_dir, house_index)
        if not os.path.isdir(house_path) or house_index == 'visualization':
            continue

        floor_plan_path = os.path.join(house_path, 'floor_plans')
        if os.path.exists(floor_plan_path):
            print(f'move:{floor_plan_path}')
            dst_floor_plan_path = floor_plan_path.replace('zind', 'zind2')
            os.makedirs(dst_floor_plan_path, exist_ok=True)
            shutil.move(floor_plan_path, dst_floor_plan_path)

        panos_path = os.path.join(house_path, 'panos')
        for pano in os.listdir(panos_path):
            pano_path = os.path.join(panos_path, pano)
            pano_index = '_'.join(pano.split('.')[0].split('_')[-2:])
            if f'{house_index}_{pano_index}' not in simple_panos and os.path.exists(pano_path):
                print(f'move:{pano_path}')
                dst_pano_path = pano_path.replace('zind', 'zind2')
                os.makedirs(os.path.dirname(dst_pano_path), exist_ok=True)
                shutil.move(pano_path, dst_pano_path)


def read_zind(partition_path, simplicity_path, data_dir, mode, is_simple=True,
              layout_type='layout_raw', is_ceiling_flat=False, plan_y=1):
    with open(simplicity_path, 'r') as f:
        simple_tag = json.load(f)

    simple_panos = {}
    for k in simple_tag.keys():
        if not simple_tag[k]:
            continue
        split = k.split('_')
        house_index = split[0]
        pano_index = '_'.join(split[-2:])
        simple_panos[f'{house_index}_{pano_index}'] = True

    # move_not_simple_image(data_dir, simple_panos)

    pano_list = []
    with open(partition_path, 'r') as f1:
        house_list = json.load(f1)[mode]

    for house_index in house_list:
        with open(os.path.join(data_dir, house_index, "zind_data.json"), 'r') as f2:
            data = json.load(f2)

        panos = []
        merger = data['merger']
        for floor in merger.values():
            for complete_room in floor.values():
                for partial_room in complete_room.values():
                    for pano_index in partial_room:
                        pano = partial_room[pano_index]
                        pano['index'] = pano_index
                        panos.append(pano)

        for pano in panos:
            if layout_type not in pano:
                continue
            pano_index = pano['index']

            if is_simple and f'{house_index}_{pano_index}' not in simple_panos.keys():
                continue

            if is_ceiling_flat and not pano['is_ceiling_flat']:
                continue

            layout = pano[layout_type]
            # corners
            corner_xz = np.array(layout['vertices'])
            corner_xz[..., 0] = -corner_xz[..., 0]
            corner_xyz = np.insert(corner_xz, 1, pano['camera_height'], axis=1)
            corners = xyz2uv(corner_xyz).astype(np.float32)

            # ratio
            ratio = np.array([(pano['ceiling_height'] - pano['camera_height']) / pano['camera_height']], dtype=np.float32)

            # Our future work: detecting windows, doors and openings
            objects = {
                'windows': [],
                'doors': [],
                'openings': [],
            }
            for label_index, wdo_type in enumerate(["windows", "doors", "openings"]):
                if wdo_type not in layout:
                    continue

                wdo_vertices = np.array(layout[wdo_type])
                if len(wdo_vertices) == 0:
                    continue

                assert len(wdo_vertices) % 3 == 0

                for i in range(0, len(wdo_vertices), 3):
                    # In the ZInD dataset the camera height is 1, and the default camera height in our
                    # code is also 1, so the xyz coordinates here can be used directly.
                    # Since we take the opposite z-axis, the order of left and right is swapped.

                    left_bottom_xyz = np.array(
                        [-wdo_vertices[i + 1][0], -wdo_vertices[i + 2][0], wdo_vertices[i + 1][1]])
                    right_bottom_xyz = np.array(
                        [-wdo_vertices[i][0], -wdo_vertices[i + 2][0], wdo_vertices[i][1]])
                    center_bottom_xyz = (left_bottom_xyz + right_bottom_xyz) / 2

                    center_top_xyz = center_bottom_xyz.copy()
                    center_top_xyz[1] = -wdo_vertices[i + 2][1]

                    center_boundary_xyz = center_bottom_xyz.copy()
                    center_boundary_xyz[1] = plan_y

                    uv = xyz2uv(np.array([left_bottom_xyz, right_bottom_xyz,
                                          center_bottom_xyz, center_top_xyz,
                                          center_boundary_xyz]))

                    left_bottom_uv = uv[0]
                    right_bottom_uv = uv[1]
                    width_u = abs(right_bottom_uv[0] - left_bottom_uv[0])
                    width_u = 1 - width_u if width_u > 0.5 else width_u
                    assert width_u > 0, width_u

                    center_bottom_uv = uv[2]
                    center_top_uv = uv[3]
                    height_v = center_bottom_uv[1] - center_top_uv[1]

                    if height_v < 0:
                        continue

                    center_boundary_uv = uv[4]
                    boundary_v = center_boundary_uv[1] - center_bottom_uv[1] if wdo_type == 'windows' else 0
                    boundary_v = 0 if boundary_v < 0 else boundary_v

                    center_u = center_bottom_uv[0]

                    objects[wdo_type].append({
                        'width_u': width_u,
                        'height_v': height_v,
                        'boundary_v': boundary_v,
                        'center_u': center_u
                    })

            pano_list.append({
                'img_path': os.path.join(data_dir, house_index, pano['image_path']),
                'corners': corners,
                'objects': objects,
                'ratio': ratio,
                'id': f'{house_index}_{pano_index}',
                'is_inside': pano['is_inside']
            })
    return pano_list
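For reference, a minimal sketch of how read_zind above can be driven (illustrative only; the dataset root below is a placeholder and assumes the stock ZInD layout with zind_partition.json and room_shape_simplicity_labels.json at the root):

    data_dir = 'src/dataset/zind'  # placeholder path, not part of the repository
    pano_list = read_zind(partition_path=f'{data_dir}/zind_partition.json',
                          simplicity_path=f'{data_dir}/room_shape_simplicity_labels.json',
                          data_dir=data_dir, mode='train')
    for pano in pano_list[:3]:
        # each entry carries uv corners, the ceiling/camera height ratio and window/door/opening boxes
        print(pano['id'], pano['corners'].shape, pano['ratio'], len(pano['objects']['windows']))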
dataset/mp3d_dataset.py
ADDED
@@ -0,0 +1,110 @@
"""
@date: 2021/6/25
@description:
"""
import os
import json

from dataset.communal.read import read_image, read_label
from dataset.communal.base_dataset import BaseDataset
from utils.logger import get_logger


class MP3DDataset(BaseDataset):
    def __init__(self, root_dir, mode, shape=None, max_wall_num=0, aug=None, camera_height=1.6, logger=None,
                 split_list=None, patch_num=256, keys=None, for_test_index=None):
        super().__init__(mode, shape, max_wall_num, aug, camera_height, patch_num, keys)

        if logger is None:
            logger = get_logger()
        self.root_dir = root_dir

        split_dir = os.path.join(root_dir, 'split')
        label_dir = os.path.join(root_dir, 'label')
        img_dir = os.path.join(root_dir, 'image')

        if split_list is None:
            with open(os.path.join(split_dir, f"{mode}.txt"), 'r') as f:
                split_list = [x.rstrip().split() for x in f]

        split_list.sort()
        if for_test_index is not None:
            split_list = split_list[:for_test_index]

        self.data = []
        invalid_num = 0
        for name in split_list:
            name = "_".join(name)
            img_path = os.path.join(img_dir, f"{name}.png")
            label_path = os.path.join(label_dir, f"{name}.json")

            if not os.path.exists(img_path):
                logger.warning(f"{img_path} not exists")
                invalid_num += 1
                continue
            if not os.path.exists(label_path):
                logger.warning(f"{label_path} not exists")
                invalid_num += 1
                continue

            with open(label_path, 'r') as f:
                label = json.load(f)

            if self.max_wall_num >= 10:
                if label['layoutWalls']['num'] < self.max_wall_num:
                    invalid_num += 1
                    continue
            elif self.max_wall_num != 0 and label['layoutWalls']['num'] != self.max_wall_num:
                invalid_num += 1
                continue

            # print(label['layoutWalls']['num'])
            self.data.append([img_path, label_path])

        logger.info(
            f"Build dataset mode: {self.mode} max_wall_num: {self.max_wall_num} valid: {len(self.data)} invalid: {invalid_num}")

    def __getitem__(self, idx):
        rgb_path, label_path = self.data[idx]
        label = read_label(label_path, data_type='MP3D')
        image = read_image(rgb_path, self.shape)
        output = self.process_data(label, image, self.patch_num)
        return output


if __name__ == "__main__":
    import numpy as np
    from PIL import Image

    from tqdm import tqdm
    from visualization.boundary import draw_boundaries
    from visualization.floorplan import draw_floorplan
    from utils.boundary import depth2boundaries
    from utils.conversion import uv2xyz

    modes = ['test', 'val']
    for i in range(1):
        for mode in modes:
            print(mode)
            mp3d_dataset = MP3DDataset(root_dir='../src/dataset/mp3d', mode=mode, aug={
                'STRETCH': True,
                'ROTATE': True,
                'FLIP': True,
                'GAMMA': True
            })
            save_dir = f'../src/dataset/mp3d/visualization/{mode}'
            if not os.path.isdir(save_dir):
                os.makedirs(save_dir)

            bar = tqdm(mp3d_dataset, ncols=100)
            for data in bar:
                bar.set_description(f"Processing {data['id']}")
                boundary_list = depth2boundaries(data['ratio'], data['depth'], step=None)
                pano_img = draw_boundaries(data['image'].transpose(1, 2, 0), boundary_list=boundary_list, show=True)
                Image.fromarray((pano_img * 255).astype(np.uint8)).save(
                    os.path.join(save_dir, f"{data['id']}_boundary.png"))

                floorplan = draw_floorplan(uv2xyz(boundary_list[0])[..., ::2], show=True,
                                           marker_color=None, center_color=0.8, show_radius=None)
                Image.fromarray((floorplan.squeeze() * 255).astype(np.uint8)).save(
                    os.path.join(save_dir, f"{data['id']}_floorplan.png"))
dataset/pano_s2d3d_dataset.py
ADDED
@@ -0,0 +1,107 @@
"""
@date: 2021/6/16
@description:
"""
import math
import os
import numpy as np

from dataset.communal.read import read_image, read_label
from dataset.communal.base_dataset import BaseDataset
from utils.logger import get_logger


class PanoS2D3DDataset(BaseDataset):
    def __init__(self, root_dir, mode, shape=None, max_wall_num=0, aug=None, camera_height=1.6, logger=None,
                 split_list=None, patch_num=256, keys=None, for_test_index=None, subset=None):
        super().__init__(mode, shape, max_wall_num, aug, camera_height, patch_num, keys)

        if logger is None:
            logger = get_logger()
        self.root_dir = root_dir

        if mode is None:
            return
        label_dir = os.path.join(root_dir, 'valid' if mode == 'val' else mode, 'label_cor')
        img_dir = os.path.join(root_dir, 'valid' if mode == 'val' else mode, 'img')

        if split_list is None:
            split_list = [name.split('.')[0] for name in os.listdir(label_dir) if
                          not name.startswith('.') and name.endswith('txt')]

        split_list.sort()

        assert subset == 'pano' or subset == 's2d3d' or subset is None, 'error subset'
        if subset == 'pano':
            split_list = [name for name in split_list if 'pano_' in name]
            logger.info("Use PanoContext Dataset")
        elif subset == 's2d3d':
            split_list = [name for name in split_list if 'camera_' in name]
            logger.info("Use Stanford2D3D Dataset")

        if for_test_index is not None:
            split_list = split_list[:for_test_index]

        self.data = []
        invalid_num = 0
        for name in split_list:
            img_path = os.path.join(img_dir, f"{name}.png")
            label_path = os.path.join(label_dir, f"{name}.txt")

            if not os.path.exists(img_path):
                logger.warning(f"{img_path} not exists")
                invalid_num += 1
                continue
            if not os.path.exists(label_path):
                logger.warning(f"{label_path} not exists")
                invalid_num += 1
                continue

            with open(label_path, 'r') as f:
                lines = [line for line in f.readlines() if
                         len([c for c in line.split(' ') if c[0].isnumeric()]) > 1]
                if len(lines) % 2 != 0:
                    invalid_num += 1
                    continue
            self.data.append([img_path, label_path])

        logger.info(
            f"Build dataset mode: {self.mode} valid: {len(self.data)} invalid: {invalid_num}")

    def __getitem__(self, idx):
        rgb_path, label_path = self.data[idx]
        label = read_label(label_path, data_type='Pano_S2D3D')
        image = read_image(rgb_path, self.shape)
        output = self.process_data(label, image, self.patch_num)
        return output


if __name__ == '__main__':
    # imports needed by the visualization block below (mirrors mp3d_dataset.py)
    from PIL import Image
    from tqdm import tqdm
    from visualization.boundary import draw_boundaries
    from visualization.floorplan import draw_floorplan
    from utils.boundary import depth2boundaries
    from utils.conversion import uv2xyz

    modes = ['test', 'val', 'train']
    for i in range(1):
        for mode in modes:
            print(mode)
            mp3d_dataset = PanoS2D3DDataset(root_dir='../src/dataset/pano_s2d3d', mode=mode, aug={
                # 'STRETCH': True,
                # 'ROTATE': True,
                # 'FLIP': True,
                # 'GAMMA': True
            })
            continue
            save_dir = f'../src/dataset/pano_s2d3d/visualization/{mode}'
            if not os.path.isdir(save_dir):
                os.makedirs(save_dir)

            bar = tqdm(mp3d_dataset, ncols=100)
            for data in bar:
                bar.set_description(f"Processing {data['id']}")
                boundary_list = depth2boundaries(data['ratio'], data['depth'], step=None)
                pano_img = draw_boundaries(data['image'].transpose(1, 2, 0), boundary_list=boundary_list, show=False)
                Image.fromarray((pano_img * 255).astype(np.uint8)).save(
                    os.path.join(save_dir, f"{data['id']}_boundary.png"))

                floorplan = draw_floorplan(uv2xyz(boundary_list[0])[..., ::2], show=False,
                                           marker_color=None, center_color=0.8, show_radius=None)
                Image.fromarray((floorplan.squeeze() * 255).astype(np.uint8)).save(
                    os.path.join(save_dir, f"{data['id']}_floorplan.png"))
dataset/pano_s2d3d_mix_dataset.py
ADDED
@@ -0,0 +1,91 @@
"""
@date: 2021/6/16
@description:
"""

import os

from dataset.pano_s2d3d_dataset import PanoS2D3DDataset
from utils.logger import get_logger


class PanoS2D3DMixDataset(PanoS2D3DDataset):
    def __init__(self, root_dir, mode, shape=None, max_wall_num=0, aug=None, camera_height=1.6, logger=None,
                 split_list=None, patch_num=256, keys=None, for_test_index=None, subset=None):
        assert subset == 's2d3d' or subset == 'pano', 'error subset'
        super().__init__(root_dir, None, shape, max_wall_num, aug, camera_height, logger,
                         split_list, patch_num, keys, None, subset)
        if logger is None:
            logger = get_logger()
        self.mode = mode
        if mode == 'train':
            if subset == 'pano':
                s2d3d_train_data = PanoS2D3DDataset(root_dir, 'train', shape, max_wall_num, aug, camera_height, logger,
                                                    split_list, patch_num, keys, None, 's2d3d').data
                s2d3d_val_data = PanoS2D3DDataset(root_dir, 'val', shape, max_wall_num, aug, camera_height, logger,
                                                  split_list, patch_num, keys, None, 's2d3d').data
                s2d3d_test_data = PanoS2D3DDataset(root_dir, 'test', shape, max_wall_num, aug, camera_height, logger,
                                                   split_list, patch_num, keys, None, 's2d3d').data
                s2d3d_all_data = s2d3d_train_data + s2d3d_val_data + s2d3d_test_data

                pano_train_data = PanoS2D3DDataset(root_dir, 'train', shape, max_wall_num, aug, camera_height, logger,
                                                   split_list, patch_num, keys, None, 'pano').data
                self.data = s2d3d_all_data + pano_train_data
            elif subset == 's2d3d':
                pano_train_data = PanoS2D3DDataset(root_dir, 'train', shape, max_wall_num, aug, camera_height, logger,
                                                   split_list, patch_num, keys, None, 'pano').data
                pano_val_data = PanoS2D3DDataset(root_dir, 'val', shape, max_wall_num, aug, camera_height, logger,
                                                 split_list, patch_num, keys, None, 'pano').data
                pano_test_data = PanoS2D3DDataset(root_dir, 'test', shape, max_wall_num, aug, camera_height, logger,
                                                  split_list, patch_num, keys, None, 'pano').data
                pano_all_data = pano_train_data + pano_val_data + pano_test_data

                s2d3d_train_data = PanoS2D3DDataset(root_dir, 'train', shape, max_wall_num, aug, camera_height, logger,
                                                    split_list, patch_num, keys, None, 's2d3d').data
                self.data = pano_all_data + s2d3d_train_data
        else:
            self.data = PanoS2D3DDataset(root_dir, mode, shape, max_wall_num, aug, camera_height, logger,
                                         split_list, patch_num, keys, None, subset).data

        if for_test_index is not None:
            self.data = self.data[:for_test_index]
        logger.info(f"Build dataset mode: {self.mode} valid: {len(self.data)}")


if __name__ == '__main__':
    import numpy as np
    from PIL import Image

    from tqdm import tqdm
    from visualization.boundary import draw_boundaries
    from visualization.floorplan import draw_floorplan
    from utils.boundary import depth2boundaries
    from utils.conversion import uv2xyz

    modes = ['test', 'val', 'train']
    for i in range(1):
        for mode in modes:
            print(mode)
            mp3d_dataset = PanoS2D3DMixDataset(root_dir='../src/dataset/pano_s2d3d', mode=mode, aug={
                # 'STRETCH': True,
                # 'ROTATE': True,
                # 'FLIP': True,
                # 'GAMMA': True
            }, subset='pano')
            continue
            save_dir = f'../src/dataset/pano_s2d3d/visualization1/{mode}'
            if not os.path.isdir(save_dir):
                os.makedirs(save_dir)

            bar = tqdm(mp3d_dataset, ncols=100)
            for data in bar:
                bar.set_description(f"Processing {data['id']}")
                boundary_list = depth2boundaries(data['ratio'], data['depth'], step=None)
                pano_img = draw_boundaries(data['image'].transpose(1, 2, 0), boundary_list=boundary_list, show=False)
                Image.fromarray((pano_img * 255).astype(np.uint8)).save(
                    os.path.join(save_dir, f"{data['id']}_boundary.png"))

                floorplan = draw_floorplan(uv2xyz(boundary_list[0])[..., ::2], show=False,
                                           marker_color=None, center_color=0.8, show_radius=None)
                Image.fromarray((floorplan.squeeze() * 255).astype(np.uint8)).save(
                    os.path.join(save_dir, f"{data['id']}_floorplan.png"))
dataset/zind_dataset.py
ADDED
@@ -0,0 +1,138 @@
"""
@Date: 2021/09/22
@description:
"""
import os
import json
import math
import numpy as np

from dataset.communal.read import read_image, read_label, read_zind
from dataset.communal.base_dataset import BaseDataset
from utils.logger import get_logger
from preprocessing.filter import filter_center, filter_boundary, filter_self_intersection
from utils.boundary import calc_rotation


class ZindDataset(BaseDataset):
    def __init__(self, root_dir, mode, shape=None, max_wall_num=0, aug=None, camera_height=1.6, logger=None,
                 split_list=None, patch_num=256, keys=None, for_test_index=None,
                 is_simple=True, is_ceiling_flat=False, vp_align=False):
        # if keys is None:
        #     keys = ['image', 'depth', 'ratio', 'id', 'corners', 'corner_heat_map', 'object']
        super().__init__(mode, shape, max_wall_num, aug, camera_height, patch_num, keys)
        if logger is None:
            logger = get_logger()
        self.root_dir = root_dir
        self.vp_align = vp_align

        data_dir = os.path.join(root_dir)
        img_dir = os.path.join(root_dir, 'image')

        pano_list = read_zind(partition_path=os.path.join(data_dir, "zind_partition.json"),
                              simplicity_path=os.path.join(data_dir, "room_shape_simplicity_labels.json"),
                              data_dir=data_dir, mode=mode, is_simple=is_simple, is_ceiling_flat=is_ceiling_flat)

        if for_test_index is not None:
            pano_list = pano_list[:for_test_index]
        if split_list:
            pano_list = [pano for pano in pano_list if pano['id'] in split_list]
        self.data = []
        invalid_num = 0
        for pano in pano_list:
            if not os.path.exists(pano['img_path']):
                logger.warning(f"{pano['img_path']} not exists")
                invalid_num += 1
                continue

            if not filter_center(pano['corners']):
                # logger.warning(f"{pano['id']} camera center not in layout")
                # invalid_num += 1
                continue

            if self.max_wall_num >= 10:
                if len(pano['corners']) < self.max_wall_num:
                    invalid_num += 1
                    continue
            elif self.max_wall_num != 0 and len(pano['corners']) != self.max_wall_num:
                invalid_num += 1
                continue

            if not filter_boundary(pano['corners']):
                logger.warning(f"{pano['id']} boundary cross")
                invalid_num += 1
                continue

            if not filter_self_intersection(pano['corners']):
                logger.warning(f"{pano['id']} self_intersection")
                invalid_num += 1
                continue

            self.data.append(pano)

        logger.info(
            f"Build dataset mode: {self.mode} max_wall_num: {self.max_wall_num} valid: {len(self.data)} invalid: {invalid_num}")

    def __getitem__(self, idx):
        pano = self.data[idx]
        rgb_path = pano['img_path']
        label = pano
        image = read_image(rgb_path, self.shape)

        if self.vp_align:
            # Equivalent to the vanishing-point alignment step
            rotation = calc_rotation(corners=label['corners'])
            shift = math.modf(rotation / (2 * np.pi) + 1)[0]
            image = np.roll(image, round(shift * self.shape[1]), axis=1)
            label['corners'][:, 0] = np.modf(label['corners'][:, 0] + shift)[0]

        output = self.process_data(label, image, self.patch_num)
        return output


if __name__ == "__main__":
    import numpy as np
    from PIL import Image

    from tqdm import tqdm
    from visualization.boundary import draw_boundaries, draw_object
    from visualization.floorplan import draw_floorplan
    from utils.boundary import depth2boundaries, calc_rotation
    from utils.conversion import uv2xyz
    from models.other.init_env import init_env

    init_env(123)

    modes = ['val']
    for i in range(1):
        for mode in modes:
            print(mode)
            mp3d_dataset = ZindDataset(root_dir='../src/dataset/zind', mode=mode, aug={
                'STRETCH': False,
                'ROTATE': False,
                'FLIP': False,
                'GAMMA': False
            })
            # continue
            # save_dir = f'../src/dataset/zind/visualization/{mode}'
            # if not os.path.isdir(save_dir):
            #     os.makedirs(save_dir)

            bar = tqdm(mp3d_dataset, ncols=100)
            for data in bar:
                # if data['id'] != '1079_pano_18':
                #     continue
                bar.set_description(f"Processing {data['id']}")
                boundary_list = depth2boundaries(data['ratio'], data['depth'], step=None)

                pano_img = draw_boundaries(data['image'].transpose(1, 2, 0), boundary_list=boundary_list, show=True)
                # Image.fromarray((pano_img * 255).astype(np.uint8)).save(
                #     os.path.join(save_dir, f"{data['id']}_boundary.png"))
                # draw_object(pano_img, heat_maps=data['object_heat_map'], depth=data['depth'],
                #             size=data['object_size'], show=True)
                # pass
                #
                floorplan = draw_floorplan(uv2xyz(boundary_list[0])[..., ::2], show=True,
                                           marker_color=None, center_color=0.2)
                # Image.fromarray((floorplan.squeeze() * 255).astype(np.uint8)).save(
                #     os.path.join(save_dir, f"{data['id']}_floorplan.png"))
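The vp_align branch of ZindDataset.__getitem__ reduces to a horizontal roll of the panorama plus the same wrap-around shift on the corner u coordinates. A self-contained sketch of that arithmetic (illustrative only; the rotation value is made up):

    import math
    import numpy as np

    rotation = 0.5                                    # radians, as returned by calc_rotation
    shift = math.modf(rotation / (2 * np.pi) + 1)[0]  # fraction of a full turn, in [0, 1)
    width = 1024                                      # panorama width in pixels
    roll_px = round(shift * width)                    # columns to roll the image by
    u = np.array([0.02, 0.98])                        # corner u coordinates, in [0, 1)
    u_aligned = np.modf(u + shift)[0]                 # corners shifted with wrap-around
    print(roll_px, u_aligned)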
evaluation/__init__.py
ADDED
@@ -0,0 +1,4 @@
"""
@date: 2021/6/29
@description:
"""
evaluation/accuracy.py
ADDED
@@ -0,0 +1,249 @@
"""
@date: 2021/8/4
@description:
"""
import numpy as np
import cv2
import scipy

from evaluation.f1_score import f1_score_2d
from loss import GradLoss
from utils.boundary import corners2boundaries, layout2depth
from utils.conversion import depth2xyz, uv2xyz, get_u, depth2uv, xyz2uv, uv2pixel
from utils.height import calc_ceil_ratio
from evaluation.iou import calc_IoU, calc_Iou_height
from visualization.boundary import draw_boundaries
from visualization.floorplan import draw_iou_floorplan
from visualization.grad import show_grad


def calc_accuracy(dt, gt, visualization=False, h=512):
    visb_iou_2ds = []
    visb_iou_3ds = []
    full_iou_2ds = []
    full_iou_3ds = []
    iou_heights = []

    visb_iou_floodplans = []
    full_iou_floodplans = []
    pano_bds = []

    if 'depth' not in dt.keys():
        dt['depth'] = gt['depth']

    for i in range(len(gt['depth'])):
        # print(i)
        dt_xyz = dt['processed_xyz'][i] if 'processed_xyz' in dt else depth2xyz(np.abs(dt['depth'][i]))
        visb_gt_xyz = depth2xyz(np.abs(gt['depth'][i]))
        corners = gt['corners'][i]
        full_gt_corners = corners[corners[..., 0] + corners[..., 1] != 0]  # take effective corners
        full_gt_xyz = uv2xyz(full_gt_corners)

        dt_xz = dt_xyz[..., ::2]
        visb_gt_xz = visb_gt_xyz[..., ::2]
        full_gt_xz = full_gt_xyz[..., ::2]

        gt_ratio = gt['ratio'][i][0]

        if 'ratio' not in dt.keys():
            if 'boundary' in dt.keys():
                w = len(dt['boundary'][i])
                boundary = np.clip(dt['boundary'][i], 0.0001, 0.4999)
                depth = np.clip(dt['depth'][i], 0.001, 9999)
                dt_ceil_boundary = np.concatenate([get_u(w, is_np=True)[..., None], boundary], axis=-1)
                dt_floor_boundary = depth2uv(depth)
                dt_ratio = calc_ceil_ratio(boundaries=[dt_ceil_boundary, dt_floor_boundary])
            else:
                dt_ratio = gt_ratio
        else:
            dt_ratio = dt['ratio'][i][0]

        visb_iou_2d, visb_iou_3d = calc_IoU(dt_xz, visb_gt_xz, dt_height=1 + dt_ratio, gt_height=1 + gt_ratio)
        full_iou_2d, full_iou_3d = calc_IoU(dt_xz, full_gt_xz, dt_height=1 + dt_ratio, gt_height=1 + gt_ratio)
        iou_height = calc_Iou_height(dt_height=1 + dt_ratio, gt_height=1 + gt_ratio)

        visb_iou_2ds.append(visb_iou_2d)
        visb_iou_3ds.append(visb_iou_3d)
        full_iou_2ds.append(full_iou_2d)
        full_iou_3ds.append(full_iou_3d)
        iou_heights.append(iou_height)

        if visualization:
            pano_img = cv2.resize(gt['image'][i].transpose(1, 2, 0), (h * 2, h))
            # visb_iou_floodplans.append(draw_iou_floorplan(dt_xz, visb_gt_xz, iou_2d=visb_iou_2d, iou_3d=visb_iou_3d, side_l=h))
            # full_iou_floodplans.append(draw_iou_floorplan(dt_xz, full_gt_xz, iou_2d=full_iou_2d, iou_3d=full_iou_3d, side_l=h))
            visb_iou_floodplans.append(draw_iou_floorplan(dt_xz, visb_gt_xz, side_l=h))
            full_iou_floodplans.append(draw_iou_floorplan(dt_xz, full_gt_xz, side_l=h))
            gt_boundaries = corners2boundaries(gt_ratio, corners_xyz=full_gt_xyz, step=None, length=1024, visible=False)
            dt_boundaries = corners2boundaries(dt_ratio, corners_xyz=dt_xyz, step=None, visible=False,
                                               length=1024)  # visb_gt_xyz.shape[0] if dt_xyz.shape[0] != visb_gt_xyz.shape[0] else None

            pano_bd = draw_boundaries(pano_img, boundary_list=gt_boundaries, boundary_color=[0, 0, 1])
            pano_bd = draw_boundaries(pano_bd, boundary_list=dt_boundaries, boundary_color=[0, 1, 0])
            pano_bds.append(pano_bd)

    visb_iou_2d = np.array(visb_iou_2ds).mean()
    visb_iou_3d = np.array(visb_iou_3ds).mean()
    full_iou_2d = np.array(full_iou_2ds).mean()
    full_iou_3d = np.array(full_iou_3ds).mean()
    iou_height = np.array(iou_heights).mean()

    if visualization:
        visb_iou_floodplans = np.array(visb_iou_floodplans).transpose(0, 3, 1, 2)  # NCHW
        full_iou_floodplans = np.array(full_iou_floodplans).transpose(0, 3, 1, 2)  # NCHW
        pano_bds = np.array(pano_bds).transpose(0, 3, 1, 2)
    return [visb_iou_2d, visb_iou_3d, visb_iou_floodplans], \
           [full_iou_2d, full_iou_3d, full_iou_floodplans], iou_height, pano_bds, full_iou_2ds


def calc_ce(dt, gt):
    w = 1024
    h = 512
    ce_s = []
    for i in range(len(gt['corners'])):
        floor_gt_corners = gt['corners'][i]
        # take effective corners
        floor_gt_corners = floor_gt_corners[floor_gt_corners[..., 0] + floor_gt_corners[..., 1] != 0]
        floor_gt_corners = np.roll(floor_gt_corners, -np.argmin(floor_gt_corners[..., 0]), 0)
        gt_ratio = gt['ratio'][i][0]
        ceil_gt_corners = corners2boundaries(gt_ratio, corners_uv=floor_gt_corners, step=None)[1]
        gt_corners = np.concatenate((floor_gt_corners, ceil_gt_corners))
        gt_corners = uv2pixel(gt_corners, w, h)

        floor_dt_corners = xyz2uv(dt['processed_xyz'][i])
        floor_dt_corners = np.roll(floor_dt_corners, -np.argmin(floor_dt_corners[..., 0]), 0)
        dt_ratio = dt['ratio'][i][0]
        ceil_dt_corners = corners2boundaries(dt_ratio, corners_uv=floor_dt_corners, step=None)[1]
        dt_corners = np.concatenate((floor_dt_corners, ceil_dt_corners))
        dt_corners = uv2pixel(dt_corners, w, h)

        mse = np.sqrt(((gt_corners - dt_corners) ** 2).sum(1)).mean()
        ce = 100 * mse / np.sqrt(w ** 2 + h ** 2)
        ce_s.append(ce)

    return np.array(ce_s).mean()


def calc_pe(dt, gt):
    w = 1024
    h = 512
    pe_s = []
    for i in range(len(gt['corners'])):
        floor_gt_corners = gt['corners'][i]
        # take effective corners
        floor_gt_corners = floor_gt_corners[floor_gt_corners[..., 0] + floor_gt_corners[..., 1] != 0]
        floor_gt_corners = np.roll(floor_gt_corners, -np.argmin(floor_gt_corners[..., 0]), 0)
        gt_ratio = gt['ratio'][i][0]
        gt_floor_boundary, gt_ceil_boundary = corners2boundaries(gt_ratio, corners_uv=floor_gt_corners, length=w)
        gt_floor_boundary = uv2pixel(gt_floor_boundary, w, h)
        gt_ceil_boundary = uv2pixel(gt_ceil_boundary, w, h)

        floor_dt_corners = xyz2uv(dt['processed_xyz'][i])
        floor_dt_corners = np.roll(floor_dt_corners, -np.argmin(floor_dt_corners[..., 0]), 0)
        dt_ratio = dt['ratio'][i][0]
        dt_floor_boundary, dt_ceil_boundary = corners2boundaries(dt_ratio, corners_uv=floor_dt_corners, length=w)
        dt_floor_boundary = uv2pixel(dt_floor_boundary, w, h)
        dt_ceil_boundary = uv2pixel(dt_ceil_boundary, w, h)

        gt_surface = np.zeros((h, w), dtype=np.int32)
        gt_surface[gt_ceil_boundary[..., 1], np.arange(w)] = 1
        gt_surface[gt_floor_boundary[..., 1], np.arange(w)] = 1
        gt_surface = np.cumsum(gt_surface, axis=0)

        dt_surface = np.zeros((h, w), dtype=np.int32)
        dt_surface[dt_ceil_boundary[..., 1], np.arange(w)] = 1
        dt_surface[dt_floor_boundary[..., 1], np.arange(w)] = 1
        dt_surface = np.cumsum(dt_surface, axis=0)

        pe = 100 * (dt_surface != gt_surface).sum() / (h * w)
        pe_s.append(pe)
    return np.array(pe_s).mean()


def calc_rmse_delta_1(dt, gt):
    rmse_s = []
    delta_1_s = []
    for i in range(len(gt['depth'])):
        gt_boundaries = corners2boundaries(gt['ratio'][i], corners_xyz=depth2xyz(gt['depth'][i]), step=None,
                                           visible=False)
        dt_xyz = dt['processed_xyz'][i] if 'processed_xyz' in dt else depth2xyz(np.abs(dt['depth'][i]))

        dt_boundaries = corners2boundaries(dt['ratio'][i], corners_xyz=dt_xyz, step=None,
                                           length=256 if 'processed_xyz' in dt else None,
                                           visible=True if 'processed_xyz' in dt else False)
        gt_layout_depth = layout2depth(gt_boundaries, show=False)
        dt_layout_depth = layout2depth(dt_boundaries, show=False)

        rmse = ((gt_layout_depth - dt_layout_depth) ** 2).mean() ** 0.5
        threshold = np.maximum(gt_layout_depth / dt_layout_depth, dt_layout_depth / gt_layout_depth)
        delta_1 = (threshold < 1.25).mean()
        rmse_s.append(rmse)
        delta_1_s.append(delta_1)
    return np.array(rmse_s).mean(), np.array(delta_1_s).mean()


def calc_f1_score(dt, gt, threshold=10):
    w = 1024
    h = 512
    f1_s = []
    precision_s = []
    recall_s = []
    for i in range(len(gt['corners'])):
        floor_gt_corners = gt['corners'][i]
        # take effective corners
        floor_gt_corners = floor_gt_corners[floor_gt_corners[..., 0] + floor_gt_corners[..., 1] != 0]
        floor_gt_corners = np.roll(floor_gt_corners, -np.argmin(floor_gt_corners[..., 0]), 0)
        gt_ratio = gt['ratio'][i][0]
        ceil_gt_corners = corners2boundaries(gt_ratio, corners_uv=floor_gt_corners, step=None)[1]
        gt_corners = np.concatenate((floor_gt_corners, ceil_gt_corners))
        gt_corners = uv2pixel(gt_corners, w, h)

        floor_dt_corners = xyz2uv(dt['processed_xyz'][i])
        floor_dt_corners = np.roll(floor_dt_corners, -np.argmin(floor_dt_corners[..., 0]), 0)
        dt_ratio = dt['ratio'][i][0]
        ceil_dt_corners = corners2boundaries(dt_ratio, corners_uv=floor_dt_corners, step=None)[1]
        dt_corners = np.concatenate((floor_dt_corners, ceil_dt_corners))
        dt_corners = uv2pixel(dt_corners, w, h)

        Fs, Ps, Rs = f1_score_2d(gt_corners, dt_corners, [threshold])
        f1_s.append(Fs[0])
        precision_s.append(Ps[0])
        recall_s.append(Rs[0])

    return np.array(f1_s).mean(), np.array(precision_s).mean(), np.array(recall_s).mean()


def show_heat_map(dt, gt, vis_w=1024):
    dt_heat_map = dt['corner_heat_map'].detach().cpu().numpy()
    gt_heat_map = gt['corner_heat_map'].detach().cpu().numpy()
    dt_heat_map_imgs = []
    gt_heat_map_imgs = []
    for i in range(len(gt['depth'])):
        dt_heat_map_img = dt_heat_map[..., np.newaxis].repeat(3, axis=-1).repeat(20, axis=0)
        gt_heat_map_img = gt_heat_map[..., np.newaxis].repeat(3, axis=-1).repeat(20, axis=0)
        dt_heat_map_imgs.append(cv2.resize(dt_heat_map_img, (vis_w, dt_heat_map_img.shape[0])).transpose(2, 0, 1))
        gt_heat_map_imgs.append(cv2.resize(gt_heat_map_img, (vis_w, dt_heat_map_img.shape[0])).transpose(2, 0, 1))
    return dt_heat_map_imgs, gt_heat_map_imgs


def show_depth_normal_grad(dt, gt, device, vis_w=1024):
    grad_conv = GradLoss().to(device).grad_conv
    gt_grad_imgs = []
    dt_grad_imgs = []

    if 'depth' not in dt.keys():
        dt['depth'] = gt['depth']

    if vis_w == 1024:
        h = 5
    else:
        h = int(vis_w / (12 * 10))

    for i in range(len(gt['depth'])):
        gt_grad_img = show_grad(gt['depth'][i], grad_conv, h)
        dt_grad_img = show_grad(dt['depth'][i], grad_conv, h)
        vis_h = dt_grad_img.shape[0] * (vis_w // dt_grad_img.shape[1])
        gt_grad_imgs.append(cv2.resize(gt_grad_img, (vis_w, vis_h), interpolation=cv2.INTER_NEAREST).transpose(2, 0, 1))
        dt_grad_imgs.append(cv2.resize(dt_grad_img, (vis_w, vis_h), interpolation=cv2.INTER_NEAREST).transpose(2, 0, 1))

    return gt_grad_imgs, dt_grad_imgs
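In LaTeX form, the metrics implemented above (with w = 1024, h = 512, N concatenated floor and ceiling corner pairs, S the cumulative ceiling/floor surface map, and d the layout depth map) are:

    \begin{align}
    \mathrm{CE} &= \frac{100}{\sqrt{w^2 + h^2}} \cdot \frac{1}{N}\sum_{i=1}^{N} \lVert c_i^{gt} - c_i^{dt} \rVert_2 \\
    \mathrm{PE} &= \frac{100}{h w} \sum_{p} \mathbb{1}\!\left[ S^{gt}(p) \neq S^{dt}(p) \right] \\
    \mathrm{RMSE} &= \sqrt{\frac{1}{h w} \sum_{p} \left( d^{gt}(p) - d^{dt}(p) \right)^2} \\
    \delta_1 &= \frac{1}{h w} \sum_{p} \mathbb{1}\!\left[ \max\!\left( \frac{d^{gt}(p)}{d^{dt}(p)}, \frac{d^{dt}(p)}{d^{gt}(p)} \right) < 1.25 \right]
    \end{align}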
evaluation/analyse_layout_type.py
ADDED
@@ -0,0 +1,83 @@
"""
@Date: 2022/01/31
@description:
    ZInd:
    {'test': {'mw': 2789, 'aw': 381}, 'train': {'mw': 21228, 'aw': 3654}, 'val': {'mw': 2647, 'aw': 433}}

"""
import numpy as np
import matplotlib.pyplot as plt
import json

from tqdm import tqdm
from evaluation.iou import calc_IoU_2D
from visualization.floorplan import draw_floorplan
from visualization.boundary import draw_boundaries
from utils.conversion import depth2xyz, uv2xyz


def analyse_layout_type(dataset, show=False):
    bar = tqdm(dataset, total=len(dataset), ncols=100)
    manhattan = 0
    atlanta = 0
    corner_type = {}
    for data in bar:
        bar.set_description(f"Processing {data['id']}")
        corners = data['corners']
        corners = corners[corners[..., 0] + corners[..., 1] != 0]  # take effective corners
        corners_count = str(len(corners)) if len(corners) < 10 else "10"
        if corners_count not in corner_type:
            corner_type[corners_count] = 0
        corner_type[corners_count] += 1

        all_xz = uv2xyz(corners)[..., ::2]

        c = len(all_xz)
        flag = False
        for i in range(c - 1):
            l1 = all_xz[i + 1] - all_xz[i]
            l2 = all_xz[(i + 2) % c] - all_xz[i + 1]
            a = np.linalg.norm(l1) * np.linalg.norm(l2)
            if a == 0:
                continue
            dot = np.dot(l1, l2) / a
            if 0.9 > abs(dot) > 0.1:
                # cos^-1(0.1) = 84.26 deg > angle > cos^-1(0.9) = 25.84 deg, or
                # cos^-1(-0.9) = 154.16 deg > angle > cos^-1(-0.1) = 95.74 deg
                flag = True
                break
        if flag:
            atlanta += 1
        else:
            manhattan += 1

        if flag and show:
            draw_floorplan(all_xz, show=True)
            draw_boundaries(data['image'].transpose(1, 2, 0), [corners], ratio=data['ratio'], show=True)

    corner_type = dict(sorted(corner_type.items(), key=lambda item: int(item[0])))
    return {'manhattan': manhattan, "atlanta": atlanta, "corner_type": corner_type}


def execute_analyse_layout_type(root_dir, dataset, modes=None):
    if modes is None:
        modes = ["train", "val", "test"]

    iou2d_d = {}
    for mode in modes:
        print("mode: {}".format(mode))
        types = analyse_layout_type(dataset(root_dir, mode), show=False)
        iou2d_d[mode] = types
        print(json.dumps(types, indent=4))
    return iou2d_d


if __name__ == '__main__':
    from dataset.zind_dataset import ZindDataset
    from dataset.mp3d_dataset import MP3DDataset

    iou2d_d = execute_analyse_layout_type(root_dir='../src/dataset/mp3d',
                                          dataset=MP3DDataset)
    # iou2d_d = execute_analyse_layout_type(root_dir='../src/dataset/zind',
    #                                       dataset=ZindDataset)
    print(json.dumps(iou2d_d, indent=4))
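A quick numeric check of the normalized dot product test above (illustrative only; the corner coordinates are made up): a right angle gives |dot| = 0 and stays Manhattan-consistent, while a 45-degree wall pair falls inside the (0.1, 0.9) band and flags the layout as Atlanta:

    import numpy as np

    def wall_dot(p0, p1, p2):
        # normalized dot product of two consecutive wall vectors, as in analyse_layout_type
        l1, l2 = p1 - p0, p2 - p1
        return np.dot(l1, l2) / (np.linalg.norm(l1) * np.linalg.norm(l2))

    sq = [np.array(v, dtype=float) for v in [[0, 0], [2, 0], [2, 2]]]
    print(wall_dot(*sq))   # 0.0 -> right angle, Manhattan-consistent

    sk = [np.array(v, dtype=float) for v in [[0, 0], [2, 0], [3, 1]]]
    print(wall_dot(*sk))   # ~0.707 (45 deg), inside (0.1, 0.9) -> counted as Atlanta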
evaluation/eval_visible_iou.py
ADDED
@@ -0,0 +1,56 @@
"""
@Date: 2021/08/02
@description:
    The 2D IoU between the visible boundary and the full boundary. On the MP3D dataset it is
    {'train': 0.9775843958583535, 'test': 0.9828616219607289, 'val': 0.9883810438132491},
    indicating that the best performance achievable with our approach is bounded below a 98.29% 2D IoU on the test split.
"""
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm
from evaluation.iou import calc_IoU_2D
from visualization.floorplan import draw_iou_floorplan
from utils.conversion import depth2xyz, uv2xyz


def eval_dataset_visible_IoU(dataset, show=False):
    bar = tqdm(dataset, total=len(dataset), ncols=100)
    iou2ds = []
    for data in bar:
        bar.set_description(f"Processing {data['id']}")
        corners = data['corners']
        corners = corners[corners[..., 0] + corners[..., 1] != 0]  # take effective corners
        all_xz = uv2xyz(corners)[..., ::2]
        visible_xz = depth2xyz(data['depth'])[..., ::2]
        iou2d = calc_IoU_2D(all_xz, visible_xz)
        iou2ds.append(iou2d)
        if show:
            layout_floorplan = draw_iou_floorplan(all_xz, visible_xz, iou_2d=iou2d)
            plt.imshow(layout_floorplan)
            plt.show()

    mean_iou2d = np.array(iou2ds).mean()
    return mean_iou2d


def execute_eval_dataset_visible_IoU(root_dir, dataset, modes=None):
    if modes is None:
        modes = ["train", "test", "val"]

    iou2d_d = {}
    for mode in modes:
        print("mode: {}".format(mode))
        iou2d = eval_dataset_visible_IoU(dataset(root_dir, mode, patch_num=1024,
                                                 keys=['depth', 'visible_corners', 'corners', 'id']), show=False)
        iou2d_d[mode] = iou2d
    return iou2d_d


if __name__ == '__main__':
    from dataset.mp3d_dataset import MP3DDataset

    iou2d_d = execute_eval_dataset_visible_IoU(root_dir='../src/dataset/mp3d',
                                               dataset=MP3DDataset,
                                               modes=['train', 'test', 'val'])
    print(iou2d_d)
evaluation/f1_score.py
ADDED
@@ -0,0 +1,78 @@
"""
@author: Zhigang Jiang
@time: 2022/01/28
@description:
Holistic 3D Vision Challenge on General Room Layout Estimation Track Evaluation Package
Reference: https://github.com/bertjiazheng/indoor-layout-evaluation
"""

from scipy.optimize import linear_sum_assignment
import numpy as np
import scipy

HEIGHT, WIDTH = 512, 1024
MAX_DISTANCE = np.sqrt(HEIGHT**2 + WIDTH**2)


def f1_score_2d(gt_corners, dt_corners, thresholds):
    distances = scipy.spatial.distance.cdist(gt_corners, dt_corners)
    return eval_junctions(distances, thresholds=thresholds)


def eval_junctions(distances, thresholds=5):
    thresholds = thresholds if isinstance(thresholds, tuple) or isinstance(
        thresholds, list) else list([thresholds])

    num_gts, num_preds = distances.shape

    # filter the matches between ceiling-wall and floor-wall junctions
    mask = np.zeros_like(distances, dtype=bool)
    mask[:num_gts//2, :num_preds//2] = True
    mask[num_gts//2:, num_preds//2:] = True
    distances[~mask] = np.inf

    # F-measure under different thresholds
    Fs = []
    Ps = []
    Rs = []
    for threshold in thresholds:
        distances_temp = distances.copy()

        # filter the mis-matched pairs
        distances_temp[distances_temp > threshold] = np.inf

        # keep only the rows and columns that contain non-inf elements
        distances_temp = distances_temp[:, np.any(np.isfinite(distances_temp), axis=0)]

        if np.prod(distances_temp.shape) == 0:
            Fs.append(0)
            Ps.append(0)
            Rs.append(0)
            continue

        distances_temp = distances_temp[np.any(np.isfinite(distances_temp), axis=1), :]

        # solve the bipartite graph matching problem
        row_ind, col_ind = linear_sum_assignment_with_inf(distances_temp)
        true_positive = np.sum(np.isfinite(distances_temp[row_ind, col_ind]))

        # compute precision and recall
        precision = true_positive / num_preds
        recall = true_positive / num_gts

        # compute F measure
        Fs.append(2 * precision * recall / (precision + recall))
        Ps.append(precision)
        Rs.append(recall)

    return Fs, Ps, Rs


def linear_sum_assignment_with_inf(cost_matrix):
    """
    Deal with linear_sum_assignment with inf according to
    https://github.com/scipy/scipy/issues/6900#issuecomment-451735634
    """
    cost_matrix = np.copy(cost_matrix)
    cost_matrix[np.isinf(cost_matrix)] = MAX_DISTANCE
    return linear_sum_assignment(cost_matrix)
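A toy run of f1_score_2d (illustrative only; the corner coordinates are made up). Following the convention of calc_f1_score in evaluation/accuracy.py, the first half of each array holds floor junctions and the second half ceiling junctions, since eval_junctions only matches within the same half:

    import numpy as np

    gt = np.array([[100, 300], [500, 310], [100, 200], [500, 190]])  # 2 floor + 2 ceiling junctions
    dt = np.array([[104, 303], [560, 310], [104, 196], [560, 190]])  # only the first of each pair is within 10 px
    Fs, Ps, Rs = f1_score_2d(gt, dt, thresholds=[10])
    print(Fs[0], Ps[0], Rs[0])  # 2 of 4 junctions matched -> F1 = P = R = 0.5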
evaluation/iou.py
ADDED
@@ -0,0 +1,148 @@
"""
@date: 2021/6/29
@description:
    The methods with the "_floorplan" suffix are only for comparison; they perform the calculation used in LED2-Net.
    However, the rasterized floorplan is affected by show_radius: setting it too large decreases accuracy,
    while setting it too small makes points beyond the range fall outside the calculation.
"""
import numpy as np
from shapely.geometry import Polygon


def calc_inter_area(dt_xz, gt_xz):
    """
    :param dt_xz: prediction boundaries (or corners), format: [[x1, z1], [x2, z2], ...]
    :param gt_xz: ground-truth boundaries (or corners), format: [[x1, z1], [x2, z2], ...]
    :return:
    """
    dt_polygon = Polygon(dt_xz)
    gt_polygon = Polygon(gt_xz)

    dt_area = dt_polygon.area
    gt_area = gt_polygon.area
    inter_area = dt_polygon.intersection(gt_polygon).area
    return dt_area, gt_area, inter_area


def calc_IoU_2D(dt_xz, gt_xz):
    """
    :param dt_xz: prediction boundaries (or corners), format: [[x1, z1], [x2, z2], ...]
    :param gt_xz: ground-truth boundaries (or corners), format: [[x1, z1], [x2, z2], ...]
    :return:
    """
    dt_area, gt_area, inter_area = calc_inter_area(dt_xz, gt_xz)
    iou_2d = inter_area / (gt_area + dt_area - inter_area)
    return iou_2d


def calc_IoU_3D(dt_xz, gt_xz, dt_height, gt_height):
    """
    :param dt_xz: prediction boundaries (or corners), format: [[x1, z1], [x2, z2], ...]
    :param gt_xz: ground-truth boundaries (or corners), format: [[x1, z1], [x2, z2], ...]
    :param dt_height:
    :param gt_height:
    :return:
    """
    dt_area, gt_area, inter_area = calc_inter_area(dt_xz, gt_xz)
    dt_volume = dt_area * dt_height
    gt_volume = gt_area * gt_height
    inter_volume = inter_area * min(dt_height, gt_height)
    iou_3d = inter_volume / (dt_volume + gt_volume - inter_volume)
    return iou_3d


def calc_IoU(dt_xz, gt_xz, dt_height, gt_height):
    """
    :param dt_xz: prediction boundaries (or corners), format: [[x1, z1], [x2, z2], ...]
    :param gt_xz: ground-truth boundaries (or corners), format: [[x1, z1], [x2, z2], ...]
    :param dt_height:
    :param gt_height:
    :return:
    """
    dt_area, gt_area, inter_area = calc_inter_area(dt_xz, gt_xz)
    iou_2d = inter_area / (gt_area + dt_area - inter_area)

    dt_volume = dt_area * dt_height
    gt_volume = gt_area * gt_height
    inter_volume = inter_area * min(dt_height, gt_height)
    iou_3d = inter_volume / (dt_volume + gt_volume - inter_volume)

    return iou_2d, iou_3d


def calc_Iou_height(dt_height, gt_height):
    return min(dt_height, gt_height) / max(dt_height, gt_height)


# the following is for testing only
def calc_inter_area_floorplan(dt_floorplan, gt_floorplan):
    intersect = np.sum(np.logical_and(dt_floorplan, gt_floorplan))
    dt_area = np.sum(dt_floorplan)
    gt_area = np.sum(gt_floorplan)
    return dt_area, gt_area, intersect


def calc_IoU_2D_floorplan(dt_floorplan, gt_floorplan):
    dt_area, gt_area, inter_area = calc_inter_area_floorplan(dt_floorplan, gt_floorplan)
    iou_2d = inter_area / (gt_area + dt_area - inter_area)
    return iou_2d


def calc_IoU_3D_floorplan(dt_floorplan, gt_floorplan, dt_height, gt_height):
    dt_area, gt_area, inter_area = calc_inter_area_floorplan(dt_floorplan, gt_floorplan)
    dt_volume = dt_area * dt_height
    gt_volume = gt_area * gt_height
    inter_volume = inter_area * min(dt_height, gt_height)
    iou_3d = inter_volume / (dt_volume + gt_volume - inter_volume)
    return iou_3d


def calc_IoU_floorplan(dt_floorplan, gt_floorplan, dt_height, gt_height):
    dt_area, gt_area, inter_area = calc_inter_area_floorplan(dt_floorplan, gt_floorplan)
    iou_2d = inter_area / (gt_area + dt_area - inter_area)

    dt_volume = dt_area * dt_height
    gt_volume = gt_area * gt_height
    inter_volume = inter_area * min(dt_height, gt_height)
    iou_3d = inter_volume / (dt_volume + gt_volume - inter_volume)
    return iou_2d, iou_3d


if __name__ == '__main__':
    from visualization.floorplan import draw_floorplan, draw_iou_floorplan
    from visualization.boundary import draw_boundaries, corners2boundaries
    from utils.conversion import uv2xyz
    from utils.height import height2ratio

    # dummy data
    dt_floor_corners = np.array([[0.2, 0.7],
                                 [0.4, 0.7],
                                 [0.6, 0.7],
                                 [0.8, 0.7]])
    dt_height = 2.8

    gt_floor_corners = np.array([[0.3, 0.7],
                                 [0.5, 0.7],
                                 [0.7, 0.7],
                                 [0.9, 0.7]])
    gt_height = 3.2

    dt_xz = uv2xyz(dt_floor_corners)[..., ::2]
    gt_xz = uv2xyz(gt_floor_corners)[..., ::2]

    dt_floorplan = draw_floorplan(dt_xz, show=False, show_radius=1)
    gt_floorplan = draw_floorplan(gt_xz, show=False, show_radius=1)
    # dt_floorplan = draw_floorplan(dt_xz, show=False, show_radius=2)
    # gt_floorplan = draw_floorplan(gt_xz, show=False, show_radius=2)

    iou_2d, iou_3d = calc_IoU_floorplan(dt_floorplan, gt_floorplan, dt_height, gt_height)
    print('use floor plan image:', iou_2d, iou_3d)

    iou_2d, iou_3d = calc_IoU(dt_xz, gt_xz, dt_height, gt_height)
    print('use floor plan polygon:', iou_2d, iou_3d)

    draw_iou_floorplan(dt_xz, gt_xz, show=True, iou_2d=iou_2d, iou_3d=iou_3d)
    pano_bd = draw_boundaries(np.zeros([512, 1024, 3]), corners_list=[dt_floor_corners],
                              boundary_color=[0, 0, 1], ratio=height2ratio(dt_height), draw_corners=False)
    pano_bd = draw_boundaries(pano_bd, corners_list=[gt_floor_corners],
                              boundary_color=[0, 1, 0], ratio=height2ratio(gt_height), show=True, draw_corners=False)
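Written out, the polygon-based IoU computed above, with A the floor-plan areas, A_inter their intersection, and h = 1 + ratio the layout heights:

    \begin{align}
    \mathrm{IoU}_{2D} &= \frac{A_{inter}}{A_{dt} + A_{gt} - A_{inter}} \\
    \mathrm{IoU}_{3D} &= \frac{A_{inter} \cdot \min(h_{dt}, h_{gt})}
                             {A_{dt}\,h_{dt} + A_{gt}\,h_{gt} - A_{inter} \cdot \min(h_{dt}, h_{gt})} \\
    \mathrm{IoU}_{height} &= \frac{\min(h_{dt}, h_{gt})}{\max(h_{dt}, h_{gt})}
    \end{align}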
inference.py
ADDED
@@ -0,0 +1,261 @@
1 |
+
"""
|
2 |
+
@Date: 2021/09/19
|
3 |
+
@description:
|
4 |
+
"""
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
import argparse
|
8 |
+
import cv2
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
import glob
|
13 |
+
|
14 |
+
from tqdm import tqdm
|
15 |
+
from PIL import Image
|
16 |
+
from config.defaults import merge_from_file, get_config
|
17 |
+
from dataset.mp3d_dataset import MP3DDataset
|
18 |
+
from dataset.zind_dataset import ZindDataset
|
19 |
+
from models.build import build_model
|
20 |
+
from loss import GradLoss
|
21 |
+
from postprocessing.post_process import post_process
|
22 |
+
from preprocessing.pano_lsd_align import panoEdgeDetection, rotatePanorama
|
23 |
+
from utils.boundary import corners2boundaries, layout2depth
|
24 |
+
from utils.conversion import depth2xyz
|
25 |
+
from utils.logger import get_logger
|
26 |
+
from utils.misc import tensor2np_d, tensor2np
|
27 |
+
from evaluation.accuracy import show_grad
|
28 |
+
from models.lgt_net import LGT_Net
|
29 |
+
from utils.writer import xyz2json
|
30 |
+
from visualization.boundary import draw_boundaries
|
31 |
+
from visualization.floorplan import draw_floorplan, draw_iou_floorplan
|
32 |
+
from visualization.obj3d import create_3d_obj
|
33 |
+
|
34 |
+
|
35 |
+
def parse_option():
|
36 |
+
parser = argparse.ArgumentParser(description='Panorama Layout Transformer training and evaluation script')
|
37 |
+
parser.add_argument('--img_glob',
|
38 |
+
type=str,
|
39 |
+
required=True,
|
40 |
+
help='image glob path')
|
41 |
+
|
42 |
+
parser.add_argument('--cfg',
|
43 |
+
type=str,
|
44 |
+
required=True,
|
45 |
+
metavar='FILE',
|
46 |
+
help='path of config file')
|
47 |
+
|
48 |
+
parser.add_argument('--post_processing',
|
49 |
+
type=str,
|
50 |
+
default='manhattan',
|
51 |
+
choices=['manhattan', 'atalanta', 'original'],
|
52 |
+
help='post-processing type')
|
53 |
+
|
54 |
+
parser.add_argument('--output_dir',
|
55 |
+
type=str,
|
56 |
+
default='src/output',
|
57 |
+
help='path of output')
|
58 |
+
|
59 |
+
parser.add_argument('--visualize_3d', action='store_true',
|
60 |
+
help='visualize_3d')
|
61 |
+
|
62 |
+
parser.add_argument('--output_3d', action='store_true',
|
63 |
+
help='output_3d')
|
64 |
+
|
65 |
+
parser.add_argument('--device',
|
66 |
+
type=str,
|
67 |
+
default='cuda',
|
68 |
+
help='device')
|
69 |
+
|
70 |
+
args = parser.parse_args()
|
71 |
+
args.mode = 'test'
|
72 |
+
|
73 |
+
print("arguments:")
|
74 |
+
for arg in vars(args):
|
75 |
+
print(arg, ":", getattr(args, arg))
|
76 |
+
print("-" * 50)
|
77 |
+
return args
|
78 |
+
|
79 |
+
|
80 |
+
def visualize_2d(img, dt, show_depth=True, show_floorplan=True, show=False, save_path=None):
|
81 |
+
dt_np = tensor2np_d(dt)
|
82 |
+
dt_depth = dt_np['depth'][0]
|
83 |
+
dt_xyz = depth2xyz(np.abs(dt_depth))
|
84 |
+
dt_ratio = dt_np['ratio'][0][0]
|
85 |
+
dt_boundaries = corners2boundaries(dt_ratio, corners_xyz=dt_xyz, step=None, visible=False, length=img.shape[1])
|
86 |
+
vis_img = draw_boundaries(img, boundary_list=dt_boundaries, boundary_color=[0, 1, 0])
|
87 |
+
|
88 |
+
if 'processed_xyz' in dt:
|
89 |
+
dt_boundaries = corners2boundaries(dt_ratio, corners_xyz=dt['processed_xyz'][0], step=None, visible=False,
|
90 |
+
length=img.shape[1])
|
91 |
+
vis_img = draw_boundaries(vis_img, boundary_list=dt_boundaries, boundary_color=[1, 0, 0])
|
92 |
+
|
93 |
+
if show_depth:
|
94 |
+
dt_grad_img = show_depth_normal_grad(dt)
|
95 |
+
grad_h = dt_grad_img.shape[0]
|
96 |
+
vis_merge = [
|
97 |
+
vis_img[0:-grad_h, :, :],
|
98 |
+
dt_grad_img,
|
99 |
+
]
|
100 |
+
vis_img = np.concatenate(vis_merge, axis=0)
|
101 |
+
# vis_img = dt_grad_img.transpose(1, 2, 0)[100:]
|
102 |
+
|
103 |
+
if show_floorplan:
|
104 |
+
if 'processed_xyz' in dt:
|
105 |
+
floorplan = draw_iou_floorplan(dt['processed_xyz'][0][..., ::2], dt_xyz[..., ::2],
|
106 |
+
dt_board_color=[1, 0, 0, 1], gt_board_color=[0, 1, 0, 1])
|
107 |
+
else:
|
108 |
+
floorplan = show_alpha_floorplan(dt_xyz, border_color=[0, 1, 0, 1])
|
109 |
+
|
110 |
+
vis_img = np.concatenate([vis_img, floorplan[:, 60:-60, :]], axis=1)
|
111 |
+
if show:
|
112 |
+
plt.imshow(vis_img)
|
113 |
+
plt.show()
|
114 |
+
if save_path:
|
115 |
+
result = Image.fromarray((vis_img * 255).astype(np.uint8))
|
116 |
+
result.save(save_path)
|
117 |
+
return vis_img
|
118 |
+
|
119 |
+
|
120 |
+
def preprocess(img_ori, q_error=0.7, refine_iter=3, vp_cache_path=None):
|
121 |
+
# Align images with VP
|
122 |
+
if os.path.exists(vp_cache_path):
|
123 |
+
with open(vp_cache_path) as f:
|
124 |
+
vp = [[float(v) for v in line.rstrip().split(' ')] for line in f.readlines()]
|
125 |
+
vp = np.array(vp)
|
126 |
+
else:
|
127 |
+
# VP detection and line segment extraction
|
128 |
+
_, vp, _, _, _, _, _ = panoEdgeDetection(img_ori,
|
129 |
+
qError=q_error,
|
130 |
+
refineIter=refine_iter)
|
131 |
+
i_img = rotatePanorama(img_ori, vp[2::-1])
|
132 |
+
|
133 |
+
if vp_cache_path is not None:
|
134 |
+
with open(vp_cache_path, 'w') as f:
|
135 |
+
for i in range(3):
|
136 |
+
f.write('%.6f %.6f %.6f\n' % (vp[i, 0], vp[i, 1], vp[i, 2]))
|
137 |
+
|
138 |
+
return i_img, vp
|
139 |
+
|
140 |
+
|
141 |
+
def show_depth_normal_grad(dt):
|
142 |
+
grad_conv = GradLoss().to(dt['depth'].device).grad_conv
|
143 |
+
dt_grad_img = show_grad(dt['depth'][0], grad_conv, 50)
|
144 |
+
dt_grad_img = cv2.resize(dt_grad_img, (1024, 60), interpolation=cv2.INTER_NEAREST)
|
145 |
+
return dt_grad_img
|
146 |
+
|
147 |
+
|
148 |
+
def show_alpha_floorplan(dt_xyz, side_l=512, border_color=None):
|
149 |
+
if border_color is None:
|
150 |
+
border_color = [1, 0, 0, 1]
|
151 |
+
fill_color = [0.2, 0.2, 0.2, 0.2]
|
152 |
+
dt_floorplan = draw_floorplan(xz=dt_xyz[..., ::2], fill_color=fill_color,
|
153 |
+
border_color=border_color, side_l=side_l, show=False, center_color=[1, 0, 0, 1])
|
154 |
+
dt_floorplan = Image.fromarray((dt_floorplan * 255).astype(np.uint8), mode='RGBA')
|
155 |
+
back = np.zeros([side_l, side_l, len(fill_color)], dtype=np.float)
|
156 |
+
back[..., :] = [0.8, 0.8, 0.8, 1]
|
157 |
+
back = Image.fromarray((back * 255).astype(np.uint8), mode='RGBA')
|
158 |
+
iou_floorplan = Image.alpha_composite(back, dt_floorplan).convert("RGB")
|
159 |
+
dt_floorplan = np.array(iou_floorplan) / 255.0
|
160 |
+
return dt_floorplan
|
161 |
+
|
162 |
+
|
163 |
+
def save_pred_json(xyz, ration, save_path):
|
164 |
+
# xyz[..., -1] = -xyz[..., -1]
|
165 |
+
json_data = xyz2json(xyz, ration)
|
166 |
+
with open(save_path, 'w') as f:
|
167 |
+
f.write(json.dumps(json_data, indent=4) + '\n')
|
168 |
+
return json_data
|
169 |
+
|
170 |
+
|
171 |
+
def inference():
|
172 |
+
if len(img_paths) == 0:
|
173 |
+
logger.error('No images found')
|
174 |
+
return
|
175 |
+
|
176 |
+
bar = tqdm(img_paths, ncols=100)
|
177 |
+
for img_path in bar:
|
178 |
+
if not os.path.isfile(img_path):
|
179 |
+
logger.error(f'The {img_path} not is file')
|
180 |
+
continue
|
181 |
+
name = os.path.basename(img_path).split('.')[0]
|
182 |
+
bar.set_description(name)
|
183 |
+
img = np.array(Image.open(img_path).resize((1024, 512), Image.Resampling.BICUBIC))[..., :3]
|
184 |
+
if args.post_processing is not None and 'manhattan' in args.post_processing:
|
185 |
+
bar.set_description("Preprocessing")
|
186 |
+
img, vp = preprocess(img, vp_cache_path=os.path.join(args.output_dir, f"{name}_vp.txt"))
|
187 |
+
|
188 |
+
img = (img / 255.0).astype(np.float32)
|
189 |
+
run_one_inference(img, model, args, name)
|
190 |
+
|
191 |
+
|
192 |
+
def inference_dataset(dataset):
|
193 |
+
bar = tqdm(dataset, ncols=100)
|
194 |
+
for data in bar:
|
195 |
+
bar.set_description(data['id'])
|
196 |
+
run_one_inference(data['image'].transpose(1, 2, 0), model, args, name=data['id'], logger=logger)
|
197 |
+
|
198 |
+
|
199 |
+
@torch.no_grad()
|
200 |
+
def run_one_inference(img, model, args, name, logger, show=True, show_depth=True,
|
201 |
+
show_floorplan=True, mesh_format='.gltf', mesh_resolution=512):
|
202 |
+
model.eval()
|
203 |
+
logger.info("model inference...")
|
204 |
+
dt = model(torch.from_numpy(img.transpose(2, 0, 1)[None]).to(args.device))
|
205 |
+
if args.post_processing != 'original':
|
206 |
+
logger.info(f"post-processing, type:{args.post_processing}...")
|
207 |
+
dt['processed_xyz'] = post_process(tensor2np(dt['depth']), type_name=args.post_processing)
|
208 |
+
|
209 |
+
visualize_2d(img, dt,
|
210 |
+
show_depth=show_depth,
|
211 |
+
show_floorplan=show_floorplan,
|
212 |
+
show=show,
|
213 |
+
save_path=os.path.join(args.output_dir, f"{name}_pred.png"))
|
214 |
+
output_xyz = dt['processed_xyz'][0] if 'processed_xyz' in dt else depth2xyz(tensor2np(dt['depth'][0]))
|
215 |
+
|
216 |
+
logger.info(f"saving predicted layout json...")
|
217 |
+
json_data = save_pred_json(output_xyz, tensor2np(dt['ratio'][0])[0],
|
218 |
+
save_path=os.path.join(args.output_dir, f"{name}_pred.json"))
|
219 |
+
# if args.visualize_3d:
|
220 |
+
# from visualization.visualizer.visualizer import visualize_3d
|
221 |
+
# visualize_3d(json_data, (img * 255).astype(np.uint8))
|
222 |
+
|
223 |
+
if args.visualize_3d or args.output_3d:
|
224 |
+
dt_boundaries = corners2boundaries(tensor2np(dt['ratio'][0])[0], corners_xyz=output_xyz, step=None,
|
225 |
+
length=mesh_resolution if 'processed_xyz' in dt else None,
|
226 |
+
visible=True if 'processed_xyz' in dt else False)
|
227 |
+
dt_layout_depth = layout2depth(dt_boundaries, show=False)
|
228 |
+
|
229 |
+
logger.info(f"creating 3d mesh ...")
|
230 |
+
create_3d_obj(cv2.resize(img, dt_layout_depth.shape[::-1]), dt_layout_depth,
|
231 |
+
save_path=os.path.join(args.output_dir, f"{name}_3d{mesh_format}") if args.output_3d else None,
|
232 |
+
mesh=True, show=args.visualize_3d)
|
233 |
+
|
234 |
+
|
235 |
+
if __name__ == '__main__':
|
236 |
+
logger = get_logger()
|
237 |
+
args = parse_option()
|
238 |
+
config = get_config(args)
|
239 |
+
|
240 |
+
if ('cuda' in args.device or 'cuda' in config.TRAIN.DEVICE) and not torch.cuda.is_available():
|
241 |
+
logger.info(f'The {args.device} is not available, will use cpu ...')
|
242 |
+
config.defrost()
|
243 |
+
args.device = "cpu"
|
244 |
+
config.TRAIN.DEVICE = "cpu"
|
245 |
+
config.freeze()
|
246 |
+
|
247 |
+
model, _, _, _ = build_model(config, logger)
|
248 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
249 |
+
img_paths = sorted(glob.glob(args.img_glob))
|
250 |
+
|
251 |
+
inference()
|
252 |
+
|
253 |
+
# dataset = MP3DDataset(root_dir='./src/dataset/mp3d', mode='test', split_list=[
|
254 |
+
# ['7y3sRwLe3Va', '155fac2d50764bf09feb6c8f33e8fb76'],
|
255 |
+
# ['e9zR4mvMWw7', 'c904c55a5d0e420bbd6e4e030b9fe5b4'],
|
256 |
+
# ])
|
257 |
+
# dataset = ZindDataset(root_dir='./src/dataset/zind', mode='test', split_list=[
|
258 |
+
# '1169_pano_21',
|
259 |
+
# '0583_pano_59',
|
260 |
+
# ], vp_align=True)
|
261 |
+
# inference_dataset(dataset)
|
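As a usage sketch for the script above (the config path and image glob are assumptions, not part of this commit), inference over a folder of panoramas looks like:

    python inference.py --cfg src/config/mp3d.yaml --img_glob "src/demo/*.png" --post_processing manhattan --output_dir src/output

Adding --visualize_3d or --output_3d additionally builds a textured layout mesh via create_3d_obj.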
loss/__init__.py
ADDED
@@ -0,0 +1,10 @@
+"""
+@date: 2021/7/19
+@description:
+"""
+
+from torch.nn import L1Loss
+from loss.led_loss import LEDLoss
+from loss.grad_loss import GradLoss
+from loss.boundary_loss import BoundaryLoss
+from loss.object_loss import ObjectLoss, HeatmapLoss
loss/boundary_loss.py
ADDED
@@ -0,0 +1,51 @@
+"""
+@Date: 2021/08/12
+@description: For HorizonNet, using latitudes to calculate loss.
+"""
+import torch
+import torch.nn as nn
+from utils.conversion import depth2xyz, xyz2lonlat
+
+
+class BoundaryLoss(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.loss = nn.L1Loss()
+
+    def forward(self, gt, dt):
+        gt_floor_xyz = depth2xyz(gt['depth'])
+        gt_ceil_xyz = gt_floor_xyz.clone()
+        gt_ceil_xyz[..., 1] = -gt['ratio']
+
+        gt_floor_boundary = xyz2lonlat(gt_floor_xyz)[..., -1:]
+        gt_ceil_boundary = xyz2lonlat(gt_ceil_xyz)[..., -1:]
+
+        gt_boundary = torch.cat([gt_floor_boundary, gt_ceil_boundary], dim=-1).permute(0, 2, 1)
+        dt_boundary = dt['boundary']
+
+        loss = self.loss(gt_boundary, dt_boundary)
+        return loss
+
+
+if __name__ == '__main__':
+    import numpy as np
+    from dataset.mp3d_dataset import MP3DDataset
+
+    mp3d_dataset = MP3DDataset(root_dir='../src/dataset/mp3d', mode='train')
+    gt = mp3d_dataset.__getitem__(0)
+
+    gt['depth'] = torch.from_numpy(gt['depth'][np.newaxis])  # batch size is 1
+    gt['ratio'] = torch.from_numpy(gt['ratio'][np.newaxis])  # batch size is 1
+
+    dummy_dt = {
+        'depth': gt['depth'].clone(),
+        'boundary': torch.cat([
+            xyz2lonlat(depth2xyz(gt['depth']))[..., -1:],
+            xyz2lonlat(depth2xyz(gt['depth'], plan_y=-gt['ratio']))[..., -1:]
+        ], dim=-1).permute(0, 2, 1)
+    }
+    # dummy_dt['boundary'][:, :, :20] /= 1.2  # perturb part of the prediction to get a non-zero loss
+
+    boundary_loss = BoundaryLoss()
+    loss = boundary_loss(gt, dummy_dt)
+    print(loss)
loss/grad_loss.py
ADDED
@@ -0,0 +1,57 @@
+"""
+@Date: 2021/08/12
+@description:
+"""
+
+import torch
+import torch.nn as nn
+import numpy as np
+
+from visualization.grad import get_all
+
+
+class GradLoss(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.loss = nn.L1Loss()
+        self.cos = nn.CosineSimilarity(dim=-1, eps=0)
+
+        self.grad_conv = nn.Conv1d(1, 1, kernel_size=3, stride=1, padding=0, bias=False, padding_mode='circular')
+        self.grad_conv.weight = nn.Parameter(torch.tensor([[[1, 0, -1]]]).float())
+        self.grad_conv.weight.requires_grad = False
+
+    def forward(self, gt, dt):
+        gt_direction, _, gt_angle_grad = get_all(gt['depth'], self.grad_conv)
+        dt_direction, _, dt_angle_grad = get_all(dt['depth'], self.grad_conv)
+
+        normal_loss = (1 - self.cos(gt_direction, dt_direction)).mean()
+        grad_loss = self.loss(gt_angle_grad, dt_angle_grad)
+        return [normal_loss, grad_loss]
+
+
+if __name__ == '__main__':
+    from dataset.mp3d_dataset import MP3DDataset
+    from utils.boundary import depth2boundaries
+    from utils.conversion import uv2xyz
+    from visualization.boundary import draw_boundaries
+    from visualization.floorplan import draw_floorplan
+
+    def show_boundary(image, depth, ratio):
+        boundary_list = depth2boundaries(ratio, depth, step=None)
+        draw_boundaries(image.transpose(1, 2, 0), boundary_list=boundary_list, show=True)
+        draw_floorplan(uv2xyz(boundary_list[0])[..., ::2], show=True, center_color=0.8)
+
+    mp3d_dataset = MP3DDataset(root_dir='../src/dataset/mp3d', mode='train', patch_num=256)
+    gt = mp3d_dataset.__getitem__(1)
+    gt['depth'] = torch.from_numpy(gt['depth'][np.newaxis])  # batch size is 1
+    dummy_dt = {
+        'depth': gt['depth'].clone(),
+    }
+    # dummy_dt['depth'][..., 20] *= 3  # perturb one column to get a non-zero loss
+
+    # show_boundary(gt['image'], gt['depth'][0].numpy(), gt['ratio'])
+    # show_boundary(gt['image'], dummy_dt['depth'][0].numpy(), gt['ratio'])
+
+    grad_loss = GradLoss()
+    loss = grad_loss(gt, dummy_dt)
+    print(loss)
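A dataset-free sanity check for GradLoss, as a sketch (it assumes only that visualization.grad.get_all accepts depths shaped [b, 256], the same shape the __main__ demo above passes; values are synthetic):

    import torch
    from loss.grad_loss import GradLoss

    grad_loss = GradLoss()
    gt = {'depth': torch.rand(1, 256) + 1}   # synthetic positive depths
    dt = {'depth': gt['depth'].clone()}      # identical prediction
    normal_loss, angle_grad_loss = grad_loss(gt, dt)
    print(normal_loss, angle_grad_loss)      # both ~0 for identical depths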
loss/led_loss.py
ADDED
@@ -0,0 +1,47 @@
+"""
+@Date: 2021/08/12
+@description:
+"""
+import torch
+import torch.nn as nn
+
+
+class LEDLoss(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.loss = nn.L1Loss()
+
+    def forward(self, gt, dt):
+        camera_height = 1.6  # fixed camera height (metres) used to rescale normalized depths
+        gt_depth = gt['depth'] * camera_height
+
+        dt_ceil_depth = dt['ceil_depth'] * camera_height * gt['ratio']
+        dt_floor_depth = dt['depth'] * camera_height
+
+        ceil_loss = self.loss(gt_depth, dt_ceil_depth)
+        floor_loss = self.loss(gt_depth, dt_floor_depth)
+
+        loss = floor_loss + ceil_loss
+
+        return loss
+
+
+if __name__ == '__main__':
+    import numpy as np
+    from dataset.mp3d_dataset import MP3DDataset
+
+    mp3d_dataset = MP3DDataset(root_dir='../src/dataset/mp3d', mode='train')
+    gt = mp3d_dataset.__getitem__(0)
+
+    gt['depth'] = torch.from_numpy(gt['depth'][np.newaxis])  # batch size is 1
+    gt['ratio'] = torch.from_numpy(gt['ratio'][np.newaxis])  # batch size is 1
+
+    dummy_dt = {
+        'depth': gt['depth'].clone(),
+        'ceil_depth': gt['depth'] / gt['ratio']
+    }
+    # dummy_dt['depth'][..., :20] *= 3  # perturb part of the prediction to get a non-zero loss
+
+    led_loss = LEDLoss()
+    loss = led_loss(gt, dummy_dt)
+    print(loss)
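A dataset-free sanity check for LEDLoss (a sketch with synthetic tensors; shapes [b, 256] for depth and [b, 1] for ratio match the model outputs in this commit):

    import torch
    from loss.led_loss import LEDLoss

    led_loss = LEDLoss()
    gt = {'depth': torch.ones(1, 256), 'ratio': torch.tensor([[0.5]])}
    dt = {'depth': gt['depth'].clone(),
          'ceil_depth': gt['depth'] / gt['ratio']}   # consistent ceiling depth
    print(led_loss(gt, dt))   # tensor(0.), since both terms match exactly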
loss/object_loss.py
ADDED
@@ -0,0 +1,42 @@
+"""
+@Date: 2021/08/12
+@description:
+"""
+import torch
+import torch.nn as nn
+from loss.grad_loss import GradLoss
+
+
+class ObjectLoss(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.heat_map_loss = HeatmapLoss(reduction='mean')  # FocalLoss(reduction='mean')
+        self.l1_loss = nn.SmoothL1Loss()
+
+    def forward(self, gt, dt):
+        # TODO:: object loss is not implemented yet
+        return 0
+
+
+class HeatmapLoss(nn.Module):
+    def __init__(self, weight=None, alpha=2, beta=4, reduction='mean'):
+        super(HeatmapLoss, self).__init__()
+        self.alpha = alpha
+        self.beta = beta
+        self.reduction = reduction
+
+    def forward(self, targets, inputs):
+        center_id = (targets == 1.0).float()
+        other_id = (targets != 1.0).float()
+        center_loss = -center_id * (1.0 - inputs) ** self.alpha * torch.log(inputs + 1e-14)
+        other_loss = -other_id * (1 - targets) ** self.beta * inputs ** self.alpha * torch.log(1.0 - inputs + 1e-14)
+        loss = center_loss + other_loss
+
+        batch_size = loss.size(0)
+        if self.reduction == 'mean':
+            loss = torch.sum(loss) / batch_size
+
+        if self.reduction == 'sum':
+            loss = torch.sum(loss)
+
+        return loss
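HeatmapLoss appears to be a CornerNet-style penalty-reduced focal loss for corner heat maps. A minimal sanity check, as a sketch (the 1-D heat map over 256 columns is an assumption consistent with corner_heat_map in lgt_net.py below):

    import torch
    from loss.object_loss import HeatmapLoss

    heatmap_loss = HeatmapLoss(reduction='mean')
    targets = torch.zeros(2, 256)
    targets[:, 100] = 1.0                 # one ground-truth corner per sample
    inputs = torch.full((2, 256), 0.1)
    inputs[:, 100] = 0.9                  # confident at the corner, low elsewhere
    print(heatmap_loss(targets, inputs))  # small positive scalar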
main.py
ADDED
@@ -0,0 +1,401 @@
+"""
+@Date: 2021/07/17
+@description:
+"""
+import sys
+import os
+import shutil
+import argparse
+import numpy as np
+import json
+import torch
+import torch.nn.parallel
+import torch.optim
+import torch.multiprocessing as mp
+import torch.utils.data
+import torch.utils.data.distributed
+import torch.cuda
+
+from PIL import Image
+from tqdm import tqdm
+from torch.utils.tensorboard import SummaryWriter
+from config.defaults import get_config, get_rank_config
+from models.other.criterion import calc_criterion
+from models.build import build_model
+from models.other.init_env import init_env
+from utils.logger import build_logger
+from utils.misc import tensor2np_d, tensor2np
+from dataset.build import build_loader
+from evaluation.accuracy import calc_accuracy, show_heat_map, calc_ce, calc_pe, calc_rmse_delta_1, \
+    show_depth_normal_grad, calc_f1_score
+from postprocessing.post_process import post_process
+
+try:
+    from apex import amp
+except ImportError:
+    amp = None
+
+
+def parse_option():
+    debug = True if sys.gettrace() else False
+    parser = argparse.ArgumentParser(description='Panorama Layout Transformer training and evaluation script')
+    parser.add_argument('--cfg',
+                        type=str,
+                        metavar='FILE',
+                        help='path to config file')
+
+    parser.add_argument('--mode',
+                        type=str,
+                        default='train',
+                        choices=['train', 'val', 'test'],
+                        help='train/val/test mode')
+
+    parser.add_argument('--val_name',
+                        type=str,
+                        choices=['val', 'test'],
+                        help='val name')
+
+    parser.add_argument('--bs', type=int,
+                        help='batch size')
+
+    parser.add_argument('--save_eval', action='store_true',
+                        help='save eval result')
+
+    parser.add_argument('--post_processing', type=str,
+                        choices=['manhattan', 'atalanta', 'manhattan_old'],
+                        help='type of post-processing')
+
+    parser.add_argument('--need_cpe', action='store_true',
+                        help='need to evaluate corner error and pixel error')
+
+    parser.add_argument('--need_f1', action='store_true',
+                        help='need to evaluate f1-score of corners')
+
+    parser.add_argument('--need_rmse', action='store_true',
+                        help='need to evaluate root mean squared error and delta error')
+
+    parser.add_argument('--force_cube', action='store_true',
+                        help='force cube shape when eval')
+
+    parser.add_argument('--wall_num', type=int,
+                        help='wall number')
+
+    args = parser.parse_args()
+    args.debug = debug
+    print("arguments:")
+    for arg in vars(args):
+        print(arg, ":", getattr(args, arg))
+    print("-" * 50)
+    return args
+
+
+def main():
+    args = parse_option()
+    config = get_config(args)
+
+    if config.TRAIN.SCRATCH and os.path.exists(config.CKPT.DIR) and config.MODE == 'train':
+        print(f"Train from scratch, delete checkpoint dir: {config.CKPT.DIR}")
+        f = [int(f.split('_')[-1].split('.')[0]) for f in os.listdir(config.CKPT.DIR) if 'pkl' in f]
+        if len(f) > 0:
+            last_epoch = np.array(f).max()
+            if last_epoch > 10:
+                c = input(f"delete it (last_epoch: {last_epoch})?(Y/N)\n")
+                if c != 'y' and c != 'Y':
+                    exit(0)
+
+        shutil.rmtree(config.CKPT.DIR, ignore_errors=True)
+
+    os.makedirs(config.CKPT.DIR, exist_ok=True)
+    os.makedirs(config.CKPT.RESULT_DIR, exist_ok=True)
+    os.makedirs(config.LOGGER.DIR, exist_ok=True)
+
+    nprocs = 1  # default to a single process for plain 'cuda' or 'cpu' device strings
+    if ':' in config.TRAIN.DEVICE:
+        nprocs = len(config.TRAIN.DEVICE.split(':')[-1].split(','))
+    if 'cuda' in config.TRAIN.DEVICE:
+        if not torch.cuda.is_available():
+            print(f"Cuda is not available(config is: {config.TRAIN.DEVICE}), will use cpu ...")
+            config.defrost()
+            config.TRAIN.DEVICE = "cpu"
+            config.freeze()
+            nprocs = 1
+
+    if config.MODE == 'train':
+        with open(os.path.join(config.CKPT.DIR, "config.yaml"), "w") as f:
+            f.write(config.dump(allow_unicode=True))
+
+    if config.TRAIN.DEVICE == 'cpu' or nprocs < 2:
+        print(f"Use single process, device:{config.TRAIN.DEVICE}")
+        main_worker(0, config, 1)
+    else:
+        print(f"Use {nprocs} processes ...")
+        mp.spawn(main_worker, nprocs=nprocs, args=(config, nprocs), join=True)
+
+
+def main_worker(local_rank, cfg, world_size):
+    config = get_rank_config(cfg, local_rank, world_size)
+    logger = build_logger(config)
+    writer = SummaryWriter(config.CKPT.DIR)
+    logger.info(f"Comment: {config.COMMENT}")
+    cur_pid = os.getpid()
+    logger.info(f"Current process id: {cur_pid}")
+    torch.hub._hub_dir = config.CKPT.PYTORCH
+    logger.info(f"Pytorch hub dir: {torch.hub._hub_dir}")
+    init_env(config.SEED, config.TRAIN.DETERMINISTIC, config.DATA.NUM_WORKERS)
+
+    model, optimizer, criterion, scheduler = build_model(config, logger)
+    train_data_loader, val_data_loader = build_loader(config, logger)
+
+    if 'cuda' in config.TRAIN.DEVICE:
+        torch.cuda.set_device(config.TRAIN.DEVICE)
+
+    if config.MODE == 'train':
+        train(model, train_data_loader, val_data_loader, optimizer, criterion, config, logger, writer, scheduler)
+    else:
+        iou_results, other_results = val_an_epoch(model, val_data_loader,
+                                                  criterion, config, logger, writer=None,
+                                                  epoch=config.TRAIN.START_EPOCH)
+        results = dict(iou_results, **other_results)
+        if config.SAVE_EVAL:
+            save_path = os.path.join(config.CKPT.RESULT_DIR, "result.json")
+            with open(save_path, 'w+') as f:
+                json.dump(results, f, indent=4)
+
+
+def save(model, optimizer, epoch, iou_d, logger, writer, config):
+    model.save(optimizer, epoch, accuracy=iou_d['full_3d'], logger=logger, acc_d=iou_d, config=config)
+    for k in model.acc_d:
+        writer.add_scalar(f"BestACC/{k}", model.acc_d[k]['acc'], epoch)
+
+
+def train(model, train_data_loader, val_data_loader, optimizer, criterion, config, logger, writer, scheduler):
+    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
+        logger.info("=" * 200)
+        train_an_epoch(model, train_data_loader, optimizer, criterion, config, logger, writer, epoch)
+        epoch_iou_d, _ = val_an_epoch(model, val_data_loader, criterion, config, logger, writer, epoch)
+
+        if config.LOCAL_RANK == 0:
+            ddp = config.WORLD_SIZE > 1
+            save(model.module if ddp else model, optimizer, epoch, epoch_iou_d, logger, writer, config)
+
+        if scheduler is not None:
+            if scheduler.min_lr is not None and optimizer.param_groups[0]['lr'] <= scheduler.min_lr:
+                continue
+            scheduler.step()
+    writer.close()
+
+
+def train_an_epoch(model, train_data_loader, optimizer, criterion, config, logger, writer, epoch=0):
+    logger.info(f'Start Train Epoch {epoch}/{config.TRAIN.EPOCHS - 1}')
+    model.train()
+
+    if len(config.MODEL.FINE_TUNE) > 0:
+        model.feature_extractor.eval()
+
+    optimizer.zero_grad()
+
+    data_len = len(train_data_loader)
+    start_i = data_len * epoch * config.WORLD_SIZE
+    bar = enumerate(train_data_loader)
+    if config.LOCAL_RANK == 0 and config.SHOW_BAR:
+        bar = tqdm(bar, total=data_len, ncols=200)
+
+    device = config.TRAIN.DEVICE
+    epoch_loss_d = {}
+    for i, gt in bar:
+        imgs = gt['image'].to(device, non_blocking=True)
+        gt['depth'] = gt['depth'].to(device, non_blocking=True)
+        gt['ratio'] = gt['ratio'].to(device, non_blocking=True)
+        if 'corner_heat_map' in gt:
+            gt['corner_heat_map'] = gt['corner_heat_map'].to(device, non_blocking=True)
+        if config.AMP_OPT_LEVEL != "O0" and 'cuda' in device:
+            imgs = imgs.type(torch.float16)
+            gt['depth'] = gt['depth'].type(torch.float16)
+            gt['ratio'] = gt['ratio'].type(torch.float16)
+        dt = model(imgs)
+        loss, batch_loss_d, epoch_loss_d = calc_criterion(criterion, gt, dt, epoch_loss_d)
+        if config.LOCAL_RANK == 0 and config.SHOW_BAR:
+            bar.set_postfix(batch_loss_d)
+
+        optimizer.zero_grad()
+        if config.AMP_OPT_LEVEL != "O0" and 'cuda' in device:
+            with amp.scale_loss(loss, optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            loss.backward()
+        optimizer.step()
+
+        global_step = start_i + i * config.WORLD_SIZE + config.LOCAL_RANK
+        for key, val in batch_loss_d.items():
+            writer.add_scalar(f'TrainBatchLoss/{key}', val, global_step)
+
+    if config.LOCAL_RANK != 0:
+        return
+
+    epoch_loss_d = dict(zip(epoch_loss_d.keys(), [np.array(epoch_loss_d[k]).mean() for k in epoch_loss_d.keys()]))
+    s = 'TrainEpochLoss: '
+    for key, val in epoch_loss_d.items():
+        writer.add_scalar(f'TrainEpochLoss/{key}', val, epoch)
+        s += f" {key}={val}"
+    logger.info(s)
+    writer.add_scalar('LearningRate', optimizer.param_groups[0]['lr'], epoch)
+    logger.info(f"LearningRate: {optimizer.param_groups[0]['lr']}")
+
+
+@torch.no_grad()
+def val_an_epoch(model, val_data_loader, criterion, config, logger, writer, epoch=0):
+    model.eval()
+    logger.info(f'Start Validate Epoch {epoch}/{config.TRAIN.EPOCHS - 1}')
+    data_len = len(val_data_loader)
+    start_i = data_len * epoch * config.WORLD_SIZE
+    bar = enumerate(val_data_loader)
+    if config.LOCAL_RANK == 0 and config.SHOW_BAR:
+        bar = tqdm(bar, total=data_len, ncols=200)
+    device = config.TRAIN.DEVICE
+    epoch_loss_d = {}
+    epoch_iou_d = {
+        'visible_2d': [],
+        'visible_3d': [],
+        'full_2d': [],
+        'full_3d': [],
+        'height': []
+    }
+
+    epoch_other_d = {
+        'ce': [],
+        'pe': [],
+        'f1': [],
+        'precision': [],
+        'recall': [],
+        'rmse': [],
+        'delta_1': []
+    }
+
+    show_index = np.random.randint(0, data_len)
+    for i, gt in bar:
+        imgs = gt['image'].to(device, non_blocking=True)
+        gt['depth'] = gt['depth'].to(device, non_blocking=True)
+        gt['ratio'] = gt['ratio'].to(device, non_blocking=True)
+        if 'corner_heat_map' in gt:
+            gt['corner_heat_map'] = gt['corner_heat_map'].to(device, non_blocking=True)
+        dt = model(imgs)
+
+        vis_w = config.TRAIN.VIS_WEIGHT
+        visualization = False  # (config.LOCAL_RANK == 0 and i == show_index) or config.SAVE_EVAL
+
+        loss, batch_loss_d, epoch_loss_d = calc_criterion(criterion, gt, dt, epoch_loss_d)
+
+        if config.EVAL.POST_PROCESSING is not None:
+            depth = tensor2np(dt['depth'])
+            dt['processed_xyz'] = post_process(depth, type_name=config.EVAL.POST_PROCESSING,
+                                               need_cube=config.EVAL.FORCE_CUBE)
+
+        if config.EVAL.FORCE_CUBE and config.EVAL.NEED_CPE:
+            ce = calc_ce(tensor2np_d(dt), tensor2np_d(gt))
+            pe = calc_pe(tensor2np_d(dt), tensor2np_d(gt))
+
+            epoch_other_d['ce'].append(ce)
+            epoch_other_d['pe'].append(pe)
+
+        if config.EVAL.NEED_F1:
+            f1, precision, recall = calc_f1_score(tensor2np_d(dt), tensor2np_d(gt))
+            epoch_other_d['f1'].append(f1)
+            epoch_other_d['precision'].append(precision)
+            epoch_other_d['recall'].append(recall)
+
+        if config.EVAL.NEED_RMSE:
+            rmse, delta_1 = calc_rmse_delta_1(tensor2np_d(dt), tensor2np_d(gt))
+            epoch_other_d['rmse'].append(rmse)
+            epoch_other_d['delta_1'].append(delta_1)
+
+        visb_iou, full_iou, iou_height, pano_bds, full_iou_2ds = calc_accuracy(tensor2np_d(dt), tensor2np_d(gt),
+                                                                               visualization, h=vis_w // 2)
+        epoch_iou_d['visible_2d'].append(visb_iou[0])
+        epoch_iou_d['visible_3d'].append(visb_iou[1])
+        epoch_iou_d['full_2d'].append(full_iou[0])
+        epoch_iou_d['full_3d'].append(full_iou[1])
+        epoch_iou_d['height'].append(iou_height)
+
+        if config.LOCAL_RANK == 0 and config.SHOW_BAR:
+            bar.set_postfix(batch_loss_d)
+
+        global_step = start_i + i * config.WORLD_SIZE + config.LOCAL_RANK
+
+        if writer:
+            for key, val in batch_loss_d.items():
+                writer.add_scalar(f'ValBatchLoss/{key}', val, global_step)
+
+        if not visualization:
+            continue
+
+        gt_grad_imgs, dt_grad_imgs = show_depth_normal_grad(dt, gt, device, vis_w)
+
+        dt_heat_map_imgs = None
+        gt_heat_map_imgs = None
+        if 'corner_heat_map' in gt:
+            dt_heat_map_imgs, gt_heat_map_imgs = show_heat_map(dt, gt, vis_w)
+
+        if config.TRAIN.VIS_MERGE or config.SAVE_EVAL:
+            imgs = []
+            for j in range(len(pano_bds)):
+                # floorplan = np.concatenate([visb_iou[2][j], full_iou[2][j]], axis=-1)
+                floorplan = full_iou[2][j]
+                margin_w = int(floorplan.shape[-1] * (60 / 512))
+                floorplan = floorplan[:, :, margin_w:-margin_w]
+
+                grad_h = dt_grad_imgs[0].shape[1]
+                vis_merge = [
+                    gt_grad_imgs[j],
+                    pano_bds[j][:, grad_h:-grad_h],
+                    dt_grad_imgs[j]
+                ]
+                if 'corner_heat_map' in gt:
+                    vis_merge = [dt_heat_map_imgs[j], gt_heat_map_imgs[j]] + vis_merge
+                img = np.concatenate(vis_merge, axis=-2)
+
+                img = np.concatenate([img, ], axis=-1)
+                # img = gt_grad_imgs[j]
+                imgs.append(img)
+            if writer:
+                writer.add_images('VIS/Merge', np.array(imgs), global_step)
+
+            if config.SAVE_EVAL:
+                for k in range(len(imgs)):
+                    img = imgs[k] * 255.0
+                    save_path = os.path.join(config.CKPT.RESULT_DIR, f"{gt['id'][k]}_{full_iou_2ds[k]:.5f}.png")
+                    Image.fromarray(img.transpose(1, 2, 0).astype(np.uint8)).save(save_path)
+
+        elif writer:
+            writer.add_images('IoU/Visible_Floorplan', visb_iou[2], global_step)
+            writer.add_images('IoU/Full_Floorplan', full_iou[2], global_step)
+            writer.add_images('IoU/Boundary', pano_bds, global_step)
+            writer.add_images('Grad/gt', gt_grad_imgs, global_step)
+            writer.add_images('Grad/dt', dt_grad_imgs, global_step)
+
+    if config.LOCAL_RANK != 0:
+        return
+
+    epoch_loss_d = dict(zip(epoch_loss_d.keys(), [np.array(epoch_loss_d[k]).mean() for k in epoch_loss_d.keys()]))
+    s = 'ValEpochLoss: '
+    for key, val in epoch_loss_d.items():
+        if writer:
+            writer.add_scalar(f'ValEpochLoss/{key}', val, epoch)
+        s += f" {key}={val}"
+    logger.info(s)
+
+    epoch_iou_d = dict(zip(epoch_iou_d.keys(), [np.array(epoch_iou_d[k]).mean() for k in epoch_iou_d.keys()]))
+    s = 'ValEpochIoU: '
+    for key, val in epoch_iou_d.items():
+        if writer:
+            writer.add_scalar(f'ValEpochIoU/{key}', val, epoch)
+        s += f" {key}={val}"
+    logger.info(s)
+    epoch_other_d = dict(zip(epoch_other_d.keys(),
+                             [np.array(epoch_other_d[k]).mean() if len(epoch_other_d[k]) > 0 else 0 for k in
+                              epoch_other_d.keys()]))
+
+    logger.info(f'other acc: {epoch_other_d}')
+    return epoch_iou_d, epoch_other_d
+
+
+if __name__ == '__main__':
+    main()
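Usage sketch for main.py (the config path is an assumption, not part of this commit): training runs as

    python main.py --cfg src/config/mp3d.yaml --mode train

and evaluation as

    python main.py --cfg src/config/mp3d.yaml --mode test --need_f1 --save_eval

With a multi-GPU device string such as cuda:0,1 in the config, main() spawns one main_worker per GPU via torch.multiprocessing.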
models/__init__.py
ADDED
@@ -0,0 +1 @@
+from models.lgt_net import LGT_Net
models/base_model.py
ADDED
@@ -0,0 +1,150 @@
+"""
+@Date: 2021/07/17
+@description:
+"""
+import os
+import torch
+import torch.nn as nn
+import datetime
+
+
+class BaseModule(nn.Module):
+    def __init__(self, ckpt_dir=None):
+        super().__init__()
+
+        self.ckpt_dir = ckpt_dir
+        self.model_lst = []
+
+        if ckpt_dir:
+            if not os.path.exists(ckpt_dir):
+                os.makedirs(ckpt_dir)
+            else:
+                self.model_lst = [x for x in sorted(os.listdir(self.ckpt_dir)) if x.endswith('.pkl')]
+
+        self.last_model_path = None
+        self.best_model_path = None
+        self.best_accuracy = -float('inf')
+        self.acc_d = {}
+
+    def show_parameter_number(self, logger):
+        total = sum(p.numel() for p in self.parameters())
+        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
+        logger.info('{} parameter total:{:,}, trainable:{:,}'.format(self._get_name(), total, trainable))
+
+    def load(self, device, logger, optimizer=None, best=False):
+        if len(self.model_lst) == 0:
+            logger.info('*' * 50)
+            logger.info("Empty model folder! Using initial weights")
+            logger.info('*' * 50)
+            return 0
+
+        last_model_lst = list(filter(lambda n: '_last_' in n, self.model_lst))
+        best_model_lst = list(filter(lambda n: '_best_' in n, self.model_lst))
+
+        if len(last_model_lst) == 0 and len(best_model_lst) == 0:
+            logger.info('*' * 50)
+            ckpt_path = os.path.join(self.ckpt_dir, self.model_lst[0])
+            logger.info(f"Load: {ckpt_path}")
+            checkpoint = torch.load(ckpt_path, map_location=torch.device(device))
+            self.load_state_dict(checkpoint, strict=False)
+            logger.info('*' * 50)
+            return 0
+
+        checkpoint = None
+        if len(last_model_lst) > 0:
+            self.last_model_path = os.path.join(self.ckpt_dir, last_model_lst[-1])
+            checkpoint = torch.load(self.last_model_path, map_location=torch.device(device))
+            self.best_accuracy = checkpoint['accuracy']
+            self.acc_d = checkpoint['acc_d']
+
+        if len(best_model_lst) > 0:
+            self.best_model_path = os.path.join(self.ckpt_dir, best_model_lst[-1])
+            best_checkpoint = torch.load(self.best_model_path, map_location=torch.device(device))
+            self.best_accuracy = best_checkpoint['accuracy']
+            self.acc_d = best_checkpoint['acc_d']
+            if best:
+                checkpoint = best_checkpoint
+
+        if checkpoint is None:
+            logger.error("Invalid checkpoint")
+            return
+
+        for k in self.acc_d:
+            if isinstance(self.acc_d[k], float):
+                self.acc_d[k] = {
+                    'acc': self.acc_d[k],
+                    'epoch': checkpoint['epoch']
+                }
+
+        self.load_state_dict(checkpoint['net'], strict=False)
+        if optimizer and not best:  # when resuming from the best checkpoint, a new optimizer may be used (e.g. switching from Adam to SGD)
+            logger.info('Load optimizer')
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            for state in optimizer.state.values():
+                for k, v in state.items():
+                    if torch.is_tensor(v):
+                        state[k] = v.to(device)
+
+        logger.info('*' * 50)
+        if best:
+            logger.info(f"Load best: {self.best_model_path}")
+        else:
+            logger.info(f"Load last: {self.last_model_path}")
+
+        logger.info(f"Best accuracy: {self.best_accuracy}")
+        logger.info(f"Last epoch: {checkpoint['epoch'] + 1}")
+        logger.info('*' * 50)
+        return checkpoint['epoch'] + 1
+
+    def update_acc(self, acc_d, epoch, logger):
+        logger.info("-" * 100)
+        for k in acc_d:
+            if k not in self.acc_d.keys() or acc_d[k] > self.acc_d[k]['acc']:
+                self.acc_d[k] = {
+                    'acc': acc_d[k],
+                    'epoch': epoch
+                }
+            logger.info(f"Update ACC: {k} {self.acc_d[k]['acc']:.4f}({self.acc_d[k]['epoch']}-{epoch})")
+        logger.info("-" * 100)
+
+    def save(self, optim, epoch, accuracy, logger, replace=True, acc_d=None, config=None):
+        """
+
+        :param config:
+        :param optim:
+        :param epoch:
+        :param accuracy:
+        :param logger:
+        :param replace:
+        :param acc_d: other evaluation metrics, e.g. visible_2/3d, full_2/3d, rmse...
+        :return:
+        """
+        if acc_d:
+            self.update_acc(acc_d, epoch, logger)
+        name = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S_last_{:.4f}_{}'.format(accuracy, epoch))
+        name = f"model_{name}.pkl"
+        checkpoint = {
+            'net': self.state_dict(),
+            'optimizer': optim.state_dict(),
+            'epoch': epoch,
+            'accuracy': accuracy,
+            'acc_d': acc_d
+        }
+        # FIXME:: delete always true
+        if (True or config.MODEL.SAVE_LAST) and epoch % config.TRAIN.SAVE_FREQ == 0:
+            if replace and self.last_model_path and os.path.exists(self.last_model_path):
+                os.remove(self.last_model_path)
+            self.last_model_path = os.path.join(self.ckpt_dir, name)
+            torch.save(checkpoint, self.last_model_path)
+            logger.info(f"Saved last model: {self.last_model_path}")
+
+        if accuracy > self.best_accuracy:
+            self.best_accuracy = accuracy
+            # FIXME:: delete always true
+            if True or config.MODEL.SAVE_BEST:
+                if replace and self.best_model_path and os.path.exists(self.best_model_path):
+                    os.remove(self.best_model_path)
+                self.best_model_path = os.path.join(self.ckpt_dir, name.replace('last', 'best'))
+                torch.save(checkpoint, self.best_model_path)
+                logger.info("#" * 100)
+                logger.info(f"Saved best model: {self.best_model_path}")
+                logger.info("#" * 100)
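A minimal sketch of how BaseModule tracks per-metric bests (no checkpoint directory needed; metric names and values are illustrative):

    import logging
    from models.base_model import BaseModule

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    m = BaseModule()                                          # ckpt_dir=None: nothing touches disk
    m.update_acc({'full_3d': 0.81}, epoch=3, logger=logger)
    m.update_acc({'full_3d': 0.79}, epoch=4, logger=logger)   # lower score: stored best is kept
    print(m.acc_d)   # {'full_3d': {'acc': 0.81, 'epoch': 3}}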
models/build.py
ADDED
@@ -0,0 +1,81 @@
+"""
+@Date: 2021/07/18
+@description:
+"""
+import os
+import models
+import torch.distributed as dist
+import torch
+
+from torch.nn import init
+from torch.optim import lr_scheduler
+from utils.time_watch import TimeWatch
+from models.other.optimizer import build_optimizer
+from models.other.criterion import build_criterion
+
+
+def build_model(config, logger):
+    name = config.MODEL.NAME
+    w = TimeWatch(f"Build model: {name}", logger)
+
+    ddp = config.WORLD_SIZE > 1
+    if ddp:
+        logger.info("use ddp")
+        dist.init_process_group("nccl", init_method='tcp://127.0.0.1:23456', rank=config.LOCAL_RANK,
+                                world_size=config.WORLD_SIZE)
+
+    device = config.TRAIN.DEVICE
+    logger.info(f"Creating model: {name} to device:{device}, args:{config.MODEL.ARGS[0]}")
+
+    net = getattr(models, name)
+    ckpt_dir = os.path.abspath(os.path.join(config.CKPT.DIR, os.pardir)) if config.DEBUG else config.CKPT.DIR
+    if len(config.MODEL.ARGS) != 0:
+        model = net(ckpt_dir=ckpt_dir, **config.MODEL.ARGS[0])
+    else:
+        model = net(ckpt_dir=ckpt_dir)
+    logger.info(f'model dropout: {model.dropout_d}')
+    model = model.to(device)
+    optimizer = None
+    scheduler = None
+
+    if config.MODE == 'train':
+        optimizer = build_optimizer(config, model, logger)
+
+    config.defrost()
+    config.TRAIN.START_EPOCH = model.load(device, logger, optimizer, best=config.MODE != 'train' or not config.TRAIN.RESUME_LAST)
+    config.freeze()
+
+    if config.MODE == 'train' and len(config.MODEL.FINE_TUNE) > 0:
+        for param in model.parameters():
+            param.requires_grad = False
+        for layer in config.MODEL.FINE_TUNE:
+            logger.info(f'Fine-tune: {layer}')
+            getattr(model, layer).requires_grad_(requires_grad=True)
+            getattr(model, layer).reset_parameters()
+
+    model.show_parameter_number(logger)
+
+    if config.MODE == 'train':
+        if len(config.TRAIN.LR_SCHEDULER.NAME) > 0:
+            if 'last_epoch' not in config.TRAIN.LR_SCHEDULER.ARGS[0].keys():
+                config.TRAIN.LR_SCHEDULER.ARGS[0]['last_epoch'] = config.TRAIN.START_EPOCH - 1
+
+            scheduler = getattr(lr_scheduler, config.TRAIN.LR_SCHEDULER.NAME)(optimizer=optimizer,
+                                                                              **config.TRAIN.LR_SCHEDULER.ARGS[0])
+            logger.info(f"Use scheduler: name:{config.TRAIN.LR_SCHEDULER.NAME} args: {config.TRAIN.LR_SCHEDULER.ARGS[0]}")
+            logger.info(f"Current scheduler last lr: {scheduler.get_last_lr()}")
+        else:
+            scheduler = None
+
+    if config.AMP_OPT_LEVEL != "O0" and 'cuda' in device:
+        import apex
+        logger.info(f"use amp:{config.AMP_OPT_LEVEL}")
+        model, optimizer = apex.amp.initialize(model, optimizer, opt_level=config.AMP_OPT_LEVEL, verbosity=0)
+    if ddp:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.TRAIN.DEVICE],
+                                                          broadcast_buffers=True)  # use rank:0 bn
+
+    criterion = build_criterion(config, logger)
+    if optimizer is not None:
+        logger.info(f"Final lr: {optimizer.param_groups[0]['lr']}")
+    return model, optimizer, criterion, scheduler
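build_model resolves the architecture by name: getattr(models, config.MODEL.NAME) looks the class up among the exports of models/__init__.py, so with this commit only MODEL.NAME = 'LGT_Net' can resolve, and MODEL.ARGS[0] is unpacked into its constructor. An illustrative config fragment (values are assumptions, not from this commit) would be ARGS: [{'backbone': 'resnet50', 'decoder_name': 'SWG_Transformer', 'output_name': 'LGT'}], matching the keyword arguments LGT_Net.__init__ accepts below.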
models/lgt_net.py
ADDED
@@ -0,0 +1,213 @@
+import torch
+import torch.nn as nn
+import models.modules as modules
+import numpy as np
+
+from models.base_model import BaseModule
+from models.modules.horizon_net_feature_extractor import HorizonNetFeatureExtractor
+from models.modules.patch_feature_extractor import PatchFeatureExtractor
+from utils.conversion import uv2depth, get_u, lonlat2depth, get_lon, lonlat2uv
+from utils.height import calc_ceil_ratio
+from utils.misc import tensor2np
+
+
+class LGT_Net(BaseModule):
+    def __init__(self, ckpt_dir=None, backbone='resnet50', dropout=0.0, output_name='LGT',
+                 decoder_name='Transformer', win_size=8, depth=6,
+                 ape=None, rpe=None, corner_heat_map=False, rpe_pos=1):
+        super().__init__(ckpt_dir)
+
+        self.patch_num = 256
+        self.patch_dim = 1024
+        self.decoder_name = decoder_name
+        self.output_name = output_name
+        self.corner_heat_map = corner_heat_map
+        self.dropout_d = dropout
+
+        if backbone == 'patch':
+            self.feature_extractor = PatchFeatureExtractor(patch_num=self.patch_num, input_shape=[3, 512, 1024])
+        else:
+            # feature extractor
+            self.feature_extractor = HorizonNetFeatureExtractor(backbone)
+
+        if 'Transformer' in self.decoder_name:
+            # transformer encoder
+            transformer_dim = self.patch_dim
+            transformer_layers = depth
+            transformer_heads = 8
+            transformer_head_dim = transformer_dim // transformer_heads
+            transformer_ff_dim = 2048
+            rpe = None if rpe == 'None' else rpe
+            self.transformer = getattr(modules, decoder_name)(dim=transformer_dim, depth=transformer_layers,
+                                                              heads=transformer_heads, dim_head=transformer_head_dim,
+                                                              mlp_dim=transformer_ff_dim, win_size=win_size,
+                                                              dropout=self.dropout_d, patch_num=self.patch_num,
+                                                              ape=ape, rpe=rpe, rpe_pos=rpe_pos)
+        elif self.decoder_name == 'LSTM':
+            self.bi_rnn = nn.LSTM(input_size=self.feature_extractor.c_last,
+                                  hidden_size=self.patch_dim // 2,
+                                  num_layers=2,
+                                  dropout=self.dropout_d,
+                                  batch_first=False,
+                                  bidirectional=True)
+            self.drop_out = nn.Dropout(self.dropout_d)
+        else:
+            raise NotImplementedError("Only support *Transformer and LSTM")
+
+        if self.output_name == 'LGT':
+            # omnidirectional-geometry aware output
+            self.linear_depth_output = nn.Linear(in_features=self.patch_dim, out_features=1)
+            self.linear_ratio = nn.Linear(in_features=self.patch_dim, out_features=1)
+            self.linear_ratio_output = nn.Linear(in_features=self.patch_num, out_features=1)
+        elif self.output_name == 'LED' or self.output_name == 'Horizon':
+            # horizon-depth or latitude output
+            self.linear = nn.Linear(in_features=self.patch_dim, out_features=2)
+        else:
+            raise NotImplementedError("Unknown output")
+
+        if self.corner_heat_map:
+            # corners heat map output
+            self.linear_corner_heat_map_output = nn.Linear(in_features=self.patch_dim, out_features=1)
+
+        self.name = f"{self.decoder_name}_{self.output_name}_Net"
+
+    def lgt_output(self, x):
+        """
+        :param x: [ b, 256(patch_num), 1024(d)]
+        :return: {
+            'depth': [b, 256(patch_num & d)]
+            'ratio': [b, 1(d)]
+        }
+        """
+        depth = self.linear_depth_output(x)  # [b, 256(patch_num), 1(d)]
+        depth = depth.view(-1, self.patch_num)  # [b, 256(patch_num & d)]
+
+        # ratio represents room height
+        ratio = self.linear_ratio(x)  # [b, 256(patch_num), 1(d)]
+        ratio = ratio.view(-1, self.patch_num)  # [b, 256(patch_num & d)]
+        ratio = self.linear_ratio_output(ratio)  # [b, 1(d)]
+        output = {
+            'depth': depth,
+            'ratio': ratio
+        }
+        return output
+
+    def led_output(self, x):
+        """
+        :param x: [ b, 256(patch_num), 1024(d)]
+        :return: {
+            'depth': [b, 256(patch_num)]
+            'ceil_depth': [b, 256(patch_num)]
+            'ratio': [b, 1(d)]
+        }
+        """
+        bon = self.linear(x)  # [b, 256(patch_num), 2(d)]
+        bon = bon.permute(0, 2, 1)  # [b, 2(d), 256(patch_num)]
+        bon = torch.sigmoid(bon)
+
+        ceil_v = bon[:, 0, :] * -0.5 + 0.5  # [b, 256(patch_num)]
+        floor_v = bon[:, 1, :] * 0.5 + 0.5  # [b, 256(patch_num)]
+        u = get_u(w=self.patch_num, is_np=False, b=ceil_v.shape[0]).to(ceil_v.device)
+        ceil_boundary = torch.stack((u, ceil_v), dim=-1)  # [b, 256(patch_num), 2]
+        floor_boundary = torch.stack((u, floor_v), dim=-1)  # [b, 256(patch_num), 2]
+        output = {
+            'depth': uv2depth(floor_boundary),  # [b, 256(patch_num)]
+            'ceil_depth': uv2depth(ceil_boundary),  # [b, 256(patch_num)]
+        }
+        # print(output['depth'].mean())
+        if not self.training:
+            # [b, 1(d)]
+            output['ratio'] = calc_ceil_ratio([tensor2np(ceil_boundary), tensor2np(floor_boundary)], mode='lsq').reshape(-1, 1)
+        return output
+
+    def horizon_output(self, x):
+        """
+        :param x: [ b, 256(patch_num), 1024(d)]
+        :return: {
+            'floor_boundary': [b, 256(patch_num)]
+            'ceil_boundary': [b, 256(patch_num)]
+        }
+        """
+        bon = self.linear(x)  # [b, 256(patch_num), 2(d)]
+        bon = bon.permute(0, 2, 1)  # [b, 2(d), 256(patch_num)]
+
+        output = {
+            'boundary': bon
+        }
+        if not self.training:
+            lon = get_lon(w=self.patch_num, is_np=False, b=bon.shape[0]).to(bon.device)
+            floor_lat = torch.clip(bon[:, 0, :], 1e-4, np.pi / 2)
+            ceil_lat = torch.clip(bon[:, 1, :], -np.pi / 2, -1e-4)
+            floor_lonlat = torch.stack((lon, floor_lat), dim=-1)  # [b, 256(patch_num), 2]
+            ceil_lonlat = torch.stack((lon, ceil_lat), dim=-1)  # [b, 256(patch_num), 2]
+            output['depth'] = lonlat2depth(floor_lonlat)
+            output['ratio'] = calc_ceil_ratio([tensor2np(lonlat2uv(ceil_lonlat)),
+                                               tensor2np(lonlat2uv(floor_lonlat))], mode='mean').reshape(-1, 1)
+        return output
+
+    def forward(self, x):
+        """
+        :param x: [b, 3(d), 512(h), 1024(w)]
+        :return: {
+            'depth': [b, 256(patch_num & d)]
+            'ratio': [b, 1(d)]
+        }
+        """
+
+        # feature extractor
+        x = self.feature_extractor(x)  # [b 1024(d) 256(w)]
+
+        if 'Transformer' in self.decoder_name:
+            # transformer decoder
+            x = x.permute(0, 2, 1)  # [b 256(patch_num) 1024(d)]
+            x = self.transformer(x)  # [b 256(patch_num) 1024(d)]
+        elif self.decoder_name == 'LSTM':
+            # lstm decoder
+            x = x.permute(2, 0, 1)  # [256(patch_num), b, 1024(d)]
+            self.bi_rnn.flatten_parameters()
+            x, _ = self.bi_rnn(x)  # [256(patch_num & seq_len), b, 1024(d)]
+            x = x.permute(1, 0, 2)  # [b, 256(patch_num), 1024(d)]
+            x = self.drop_out(x)
+
+        output = None
+        if self.output_name == 'LGT':
+            # lgt output
+            output = self.lgt_output(x)
+
+        elif self.output_name == 'LED':
+            # led output
+            output = self.led_output(x)
+
+        elif self.output_name == 'Horizon':
+            # horizon output
+            output = self.horizon_output(x)
+
+        if self.corner_heat_map:
+            corner_heat_map = self.linear_corner_heat_map_output(x)  # [b, 256(patch_num), 1]
+            corner_heat_map = corner_heat_map.view(-1, self.patch_num)
+            corner_heat_map = torch.sigmoid(corner_heat_map)
+            output['corner_heat_map'] = corner_heat_map
+
+        return output
+
+
+if __name__ == '__main__':
+    from PIL import Image
+    import numpy as np
+    from models.other.init_env import init_env
+
+    init_env(0, deterministic=True)
+
+    net = LGT_Net()
+
+    total = sum(p.numel() for p in net.parameters())
+    trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
+    print('parameter total:{:,}, trainable:{:,}'.format(total, trainable))
+
+    img = np.array(Image.open("../src/demo.png")).transpose((2, 0, 1))
+    input = torch.Tensor([img])  # 1 3 512 1024
+    output = net(input)
+
+    print(output['depth'].shape)  # 1 256
+    print(output['ratio'].shape)  # 1 1
models/modules/__init__.py
ADDED
@@ -0,0 +1,8 @@
+"""
+@Date: 2021/09/01
+@description:
+"""
+
+from models.modules.swin_transformer import Swin_Transformer
+from models.modules.swg_transformer import SWG_Transformer
+from models.modules.transformer import Transformer
models/modules/conv_transformer.py
ADDED
@@ -0,0 +1,128 @@
import torch
import torch.nn.functional as F

from torch import nn, einsum
from einops import rearrange


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class GELU(nn.Module):
    def forward(self, input):
        return F.gelu(input)


class Attend(nn.Module):

    def __init__(self, dim=None):
        super().__init__()
        self.dim = dim

    def forward(self, input):
        return F.softmax(input, dim=self.dim, dtype=input.dtype)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = Attend(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = self.attend(dots)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class Conv(nn.Module):
    def __init__(self, dim, dropout=0.):
        super().__init__()
        self.dim = dim
        self.net = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size=3, stride=1, padding=0),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = x.transpose(1, 2)
        x = torch.cat([x[..., -1:], x, x[..., :1]], dim=-1)  # circular padding: wrap first/last token
        x = self.net(x)
        return x.transpose(1, 2)


class ConvTransformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)),
                PreNorm(dim, Conv(dim, dropout=dropout))
            ]))

    def forward(self, x):
        for attn, ff, conv in self.layers:
            x = attn(x) + x
            x = ff(x) + x
            x = conv(x) + x
        return x


if __name__ == '__main__':
    token_dim = 1024
    token_len = 256

    transformer = ConvTransformer(dim=token_dim,
                                  depth=6,
                                  heads=16,
                                  dim_head=64,
                                  mlp_dim=2048,
                                  dropout=0.1)

    total = sum(p.numel() for p in transformer.parameters())
    trainable = sum(p.numel() for p in transformer.parameters() if p.requires_grad)
    print('parameter total:{:,}, trainable:{:,}'.format(total, trainable))

    input = torch.randn(1, token_len, token_dim)
    output = transformer(input)
    print(output.shape)

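Note: the Conv block above pads the first and last tokens onto opposite ends before a kernel-3 valid convolution, so the token sequence is treated as circular (matching the horizontal wrap-around of a panorama) and its length is preserved. A minimal sketch (not part of the diff) checking the shape:

import torch
from models.modules.conv_transformer import Conv

conv = Conv(dim=16)
tokens = torch.randn(2, 256, 16)   # [batch, token_len, dim]
out = conv(tokens)
print(out.shape)                   # torch.Size([2, 256, 16]) -- length unchanged
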
models/modules/horizon_net_feature_extractor.py
ADDED
@@ -0,0 +1,267 @@
"""
@author:
@Date: 2021/07/17
@description: Use the feature extractor proposed by HorizonNet
"""

import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import functools
from models.base_model import BaseModule

ENCODER_RESNET = [
    'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
    'resnext50_32x4d', 'resnext101_32x8d'
]
ENCODER_DENSENET = [
    'densenet121', 'densenet169', 'densenet161', 'densenet201'
]


def lr_pad(x, padding=1):
    ''' Pad left/right-most to each other instead of zero padding '''
    return torch.cat([x[..., -padding:], x, x[..., :padding]], dim=3)


class LR_PAD(nn.Module):
    ''' Pad left/right-most to each other instead of zero padding '''

    def __init__(self, padding=1):
        super(LR_PAD, self).__init__()
        self.padding = padding

    def forward(self, x):
        return lr_pad(x, self.padding)


def wrap_lr_pad(net):
    for name, m in net.named_modules():
        if not isinstance(m, nn.Conv2d):
            continue
        if m.padding[1] == 0:
            continue
        w_pad = int(m.padding[1])
        m.padding = (m.padding[0], 0)  # set horizontal padding to 0; LR_PAD followed by valid padding keeps the width unchanged
        names = name.split('.')

        root = functools.reduce(lambda o, i: getattr(o, i), [net] + names[:-1])
        setattr(
            root, names[-1],
            nn.Sequential(LR_PAD(w_pad), m)
        )


'''
Encoder
'''


class Resnet(nn.Module):
    def __init__(self, backbone='resnet50', pretrained=True):
        super(Resnet, self).__init__()
        assert backbone in ENCODER_RESNET
        self.encoder = getattr(models, backbone)(pretrained=pretrained)
        del self.encoder.fc, self.encoder.avgpool

    def forward(self, x):
        features = []
        x = self.encoder.conv1(x)
        x = self.encoder.bn1(x)
        x = self.encoder.relu(x)
        x = self.encoder.maxpool(x)

        x = self.encoder.layer1(x)
        features.append(x)  # 1/4
        x = self.encoder.layer2(x)
        features.append(x)  # 1/8
        x = self.encoder.layer3(x)
        features.append(x)  # 1/16
        x = self.encoder.layer4(x)
        features.append(x)  # 1/32
        return features

    def list_blocks(self):
        lst = [m for m in self.encoder.children()]
        block0 = lst[:4]
        block1 = lst[4:5]
        block2 = lst[5:6]
        block3 = lst[6:7]
        block4 = lst[7:8]
        return block0, block1, block2, block3, block4


class Densenet(nn.Module):
    def __init__(self, backbone='densenet169', pretrained=True):
        super(Densenet, self).__init__()
        assert backbone in ENCODER_DENSENET
        self.encoder = getattr(models, backbone)(pretrained=pretrained)
        self.final_relu = nn.ReLU(inplace=True)
        del self.encoder.classifier

    def forward(self, x):
        lst = []
        for m in self.encoder.features.children():
            x = m(x)
            lst.append(x)
        features = [lst[4], lst[6], lst[8], self.final_relu(lst[11])]
        return features

    def list_blocks(self):
        lst = [m for m in self.encoder.features.children()]
        block0 = lst[:4]
        block1 = lst[4:6]
        block2 = lst[6:8]
        block3 = lst[8:10]
        block4 = lst[10:]
        return block0, block1, block2, block3, block4


'''
Decoder
'''


class ConvCompressH(nn.Module):
    ''' Reduce feature height by factor of two '''

    def __init__(self, in_c, out_c, ks=3):
        super(ConvCompressH, self).__init__()
        assert ks % 2 == 1
        self.layers = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=ks, stride=(2, 1), padding=ks // 2),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.layers(x)


class GlobalHeightConv(nn.Module):
    def __init__(self, in_c, out_c):
        super(GlobalHeightConv, self).__init__()
        self.layer = nn.Sequential(
            ConvCompressH(in_c, in_c // 2),
            ConvCompressH(in_c // 2, in_c // 2),
            ConvCompressH(in_c // 2, in_c // 4),
            ConvCompressH(in_c // 4, out_c),
        )

    def forward(self, x, out_w):
        x = self.layer(x)

        factor = out_w // x.shape[3]
        x = torch.cat([x[..., -1:], x, x[..., :1]], 3)  # pad the left/right ends with each other first (wrap mode), then interpolate
        d_type = x.dtype
        x = F.interpolate(x, size=(x.shape[2], out_w + 2 * factor), mode='bilinear', align_corners=False)
        # if x.dtype != d_type:
        #     x = x.type(d_type)
        x = x[..., factor:-factor]
        return x


class GlobalHeightStage(nn.Module):
    def __init__(self, c1, c2, c3, c4, out_scale=8):
        ''' Process 4 blocks from encoder to single multiscale features '''
        super(GlobalHeightStage, self).__init__()
        self.cs = c1, c2, c3, c4
        self.out_scale = out_scale
        self.ghc_lst = nn.ModuleList([
            GlobalHeightConv(c1, c1 // out_scale),
            GlobalHeightConv(c2, c2 // out_scale),
            GlobalHeightConv(c3, c3 // out_scale),
            GlobalHeightConv(c4, c4 // out_scale),
        ])

    def forward(self, conv_list, out_w):
        assert len(conv_list) == 4
        bs = conv_list[0].shape[0]
        feature = torch.cat([
            f(x, out_w).reshape(bs, -1, out_w)
            for f, x, out_c in zip(self.ghc_lst, conv_list, self.cs)
        ], dim=1)
        # conv_list:
        # 0 [b, 256(d), 128(h), 256(w)] ->(4*{conv3*3, stride (2,1)} : d/8, h/16)-> [b 32(d) 8(h) 256(w)]
        # 1 [b, 512(d), 64(h), 128(w)]  ->(4*{conv3*3, stride (2,1)} : d/8, h/16)-> [b 64(d) 4(h) 128(w)]
        # 2 [b, 1024(d), 32(h), 64(w)]  ->(4*{conv3*3, stride (2,1)} : d/8, h/16)-> [b 128(d) 2(h) 64(w)]
        # 3 [b, 2048(d), 16(h), 32(w)]  ->(4*{conv3*3, stride (2,1)} : d/8, h/16)-> [b 256(d) 1(h) 32(w)]
        # 0 ->(upsample to w=256)-> [b 32(d) 8(h) 256(w)]  ->(reshape to h=1)-> [b 256(d) 1(h) 256(w)]
        # 1 ->(upsample to w=256)-> [b 64(d) 4(h) 256(w)]  ->(reshape to h=1)-> [b 256(d) 1(h) 256(w)]
        # 2 ->(upsample to w=256)-> [b 128(d) 2(h) 256(w)] ->(reshape to h=1)-> [b 256(d) 1(h) 256(w)]
        # 3 ->(upsample to w=256)-> [b 256(d) 1(h) 256(w)] ->(reshape to h=1)-> [b 256(d) 1(h) 256(w)]
        # 0 --\
        # 1 -- \
        #       ---- cat [b 1024(d) 1(h) 256(w)]
        # 2 -- /
        # 3 --/
        return feature  # [b 1024(d) 256(w)]


class HorizonNetFeatureExtractor(nn.Module):
    x_mean = torch.FloatTensor(np.array([0.485, 0.456, 0.406])[None, :, None, None])
    x_std = torch.FloatTensor(np.array([0.229, 0.224, 0.225])[None, :, None, None])

    def __init__(self, backbone='resnet50'):
        super(HorizonNetFeatureExtractor, self).__init__()
        self.out_scale = 8
        self.step_cols = 4

        # Encoder
        if backbone.startswith('res'):
            self.feature_extractor = Resnet(backbone, pretrained=True)
        elif backbone.startswith('dense'):
            self.feature_extractor = Densenet(backbone, pretrained=True)
        else:
            raise NotImplementedError()

        # Infer the number of channels from each block of the encoder
        with torch.no_grad():
            dummy = torch.zeros(1, 3, 512, 1024)
            c1, c2, c3, c4 = [b.shape[1] for b in self.feature_extractor(dummy)]
            self.c_last = (c1 * 8 + c2 * 4 + c3 * 2 + c4 * 1) // self.out_scale

        # Convert features from 4 blocks of the encoder into B x C x 1 x W'
        self.reduce_height_module = GlobalHeightStage(c1, c2, c3, c4, self.out_scale)
        self.x_mean.requires_grad = False
        self.x_std.requires_grad = False
        wrap_lr_pad(self)

    def _prepare_x(self, x):
        x = x.clone()
        if self.x_mean.device != x.device:
            self.x_mean = self.x_mean.to(x.device)
            self.x_std = self.x_std.to(x.device)
        x[:, :3] = (x[:, :3] - self.x_mean) / self.x_std

        return x

    def forward(self, x):
        # x [b 3 512 1024]
        x = self._prepare_x(x)  # [b 3 512 1024]
        conv_list = self.feature_extractor(x)
        # conv_list:
        # 0 [b, 256(d), 128(h), 256(w)]
        # 1 [b, 512(d), 64(h), 128(w)]
        # 2 [b, 1024(d), 32(h), 64(w)]
        # 3 [b, 2048(d), 16(h), 32(w)]
        x = self.reduce_height_module(conv_list, x.shape[3] // self.step_cols)  # [b 1024(d) 1(h) 256(w)]
        # After reduce_height_module, h becomes 1: the height information is compressed into d,
        # and w merges the different resolutions
        # 0 [b, 256(d), 128(h), 256(w)] -> [b, 256/8(d) * 128/16(h') = 256(d), 1(h) 256(w)]
        # 1 [b, 512(d), 64(h), 128(w)]  -> [b, 512/8(d) * 64/16(h') = 256(d), 1(h) 256(w)]
        # 2 [b, 1024(d), 32(h), 64(w)]  -> [b, 1024/8(d) * 32/16(h') = 256(d), 1(h) 256(w)]
        # 3 [b, 2048(d), 16(h), 32(w)]  -> [b, 2048/8(d) * 16/16(h') = 256(d), 1(h) 256(w)]
        return x  # [b 1024(d) 1(h) 256(w)]


if __name__ == '__main__':
    from PIL import Image
    extractor = HorizonNetFeatureExtractor()
    img = np.array(Image.open("../../src/demo.png")).transpose((2, 0, 1))
    input = torch.Tensor([img])  # 1 3 512 1024
    feature = extractor(input)
    print(feature.shape)  # 1, 1024, 256 | 1024 = (out_c_0*h_0 + ... + out_c_3*h_3) = 256 * 4

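A worked check of the channel arithmetic above (values taken from the resnet50 shape comments): each encoder block contributes 256 channels after height compression, giving the 1024-d column tokens.

c1, c2, c3, c4, out_scale = 256, 512, 1024, 2048, 8
c_last = (c1 * 8 + c2 * 4 + c3 * 2 + c4 * 1) // out_scale
print(c_last)  # 1024 -- matches the [b 1024(d) 256(w)] feature shape
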
models/modules/patch_feature_extractor.py
ADDED
@@ -0,0 +1,57 @@
import numpy as np
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange


class PatchFeatureExtractor(nn.Module):
    x_mean = torch.FloatTensor(np.array([0.485, 0.456, 0.406])[None, :, None, None])
    x_std = torch.FloatTensor(np.array([0.229, 0.224, 0.225])[None, :, None, None])

    def __init__(self, patch_num=256, input_shape=None):
        super(PatchFeatureExtractor, self).__init__()

        if input_shape is None:
            input_shape = [3, 512, 1024]
        self.patch_dim = 1024
        self.patch_num = patch_num

        img_channel = input_shape[0]
        img_h = input_shape[1]
        img_w = input_shape[2]

        p_h, p_w = img_h, img_w // self.patch_num
        p_dim = p_h * p_w * img_channel

        self.patch_embedding = nn.Sequential(
            Rearrange('b c h (p_n p_w) -> b p_n (h p_w c)', p_w=p_w),
            nn.Linear(p_dim, self.patch_dim)
        )

        self.x_mean.requires_grad = False
        self.x_std.requires_grad = False

    def _prepare_x(self, x):
        x = x.clone()
        if self.x_mean.device != x.device:
            self.x_mean = self.x_mean.to(x.device)
            self.x_std = self.x_std.to(x.device)
        x[:, :3] = (x[:, :3] - self.x_mean) / self.x_std

        return x

    def forward(self, x):
        # x [b 3 512 1024]
        x = self._prepare_x(x)  # [b 3 512 1024]
        x = self.patch_embedding(x)  # [b 256(patch_num) 1024(d)]
        x = x.permute(0, 2, 1)  # [b 1024(d) 256(patch_num)]
        return x


if __name__ == '__main__':
    from PIL import Image
    extractor = PatchFeatureExtractor()
    img = np.array(Image.open("../../src/demo.png")).transpose((2, 0, 1))
    input = torch.Tensor([img])  # 1 3 512 1024
    feature = extractor(input)
    print(feature.shape)  # 1, 1024, 256

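The patch geometry implied by the defaults above, written out: each of the 256 patches is a full-height, 4-pixel-wide column, flattened and projected to 1024 dims.

img_c, img_h, img_w, patch_num = 3, 512, 1024, 256
p_w = img_w // patch_num        # 4-pixel-wide column
p_dim = img_h * p_w * img_c     # 512 * 4 * 3 = 6144 inputs to the Linear
print(p_w, p_dim)               # 4 6144 -> nn.Linear(6144, 1024)
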
models/modules/swg_transformer.py
ADDED
@@ -0,0 +1,49 @@
from models.modules.transformer_modules import *


class SWG_Transformer(nn.Module):
    def __init__(self, dim, depth, heads, win_size, dim_head, mlp_dim,
                 dropout=0., patch_num=None, ape=None, rpe=None, rpe_pos=1):
        super().__init__()
        self.absolute_pos_embed = None if patch_num is None or ape is None else AbsolutePosition(dim, dropout,
                                                                                                 patch_num, ape)
        self.pos_dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList([])
        for i in range(depth):
            if i % 2 == 0:
                attention = WinAttention(dim, win_size=win_size, shift=0 if (i % 3 == 0) else win_size // 2,
                                         heads=heads, dim_head=dim_head, dropout=dropout, rpe=rpe, rpe_pos=rpe_pos)
            else:
                attention = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout,
                                      patch_num=patch_num, rpe=rpe, rpe_pos=rpe_pos)

            self.layers.append(nn.ModuleList([
                PreNorm(dim, attention),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)),
            ]))

    def forward(self, x):
        if self.absolute_pos_embed is not None:
            x = self.absolute_pos_embed(x)
        x = self.pos_dropout(x)
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x


if __name__ == '__main__':
    token_dim = 1024
    token_len = 256

    transformer = SWG_Transformer(dim=token_dim,
                                  depth=6,
                                  heads=16,
                                  win_size=8,
                                  dim_head=64,
                                  mlp_dim=2048,
                                  dropout=0.1)

    input = torch.randn(1, token_len, token_dim)
    output = transformer(input)
    print(output.shape)

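The loop above alternates (shifted-)window and global attention; a small sketch of the resulting layer pattern for the depth=6, win_size=8 demo:

depth, win_size = 6, 8
for i in range(depth):
    if i % 2 == 0:
        print(i, 'window attention, shift =', 0 if i % 3 == 0 else win_size // 2)
    else:
        print(i, 'global attention')
# 0 window(shift 0), 1 global, 2 window(shift 4), 3 global, 4 window(shift 4), 5 global
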
models/modules/swin_transformer.py
ADDED
@@ -0,0 +1,43 @@
from models.modules.transformer_modules import *


class Swin_Transformer(nn.Module):
    def __init__(self, dim, depth, heads, win_size, dim_head, mlp_dim,
                 dropout=0., patch_num=None, ape=None, rpe=None, rpe_pos=1):
        super().__init__()
        self.absolute_pos_embed = None if patch_num is None or ape is None else AbsolutePosition(dim, dropout,
                                                                                                 patch_num, ape)
        self.pos_dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList([])
        for i in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, WinAttention(dim, win_size=win_size, shift=0 if (i % 2 == 0) else win_size // 2,
                                          heads=heads, dim_head=dim_head, dropout=dropout, rpe=rpe, rpe_pos=rpe_pos)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)),
            ]))

    def forward(self, x):
        if self.absolute_pos_embed is not None:
            x = self.absolute_pos_embed(x)
        x = self.pos_dropout(x)
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x


if __name__ == '__main__':
    token_dim = 1024
    token_len = 256

    transformer = Swin_Transformer(dim=token_dim,
                                   depth=6,
                                   heads=16,
                                   win_size=8,
                                   dim_head=64,
                                   mlp_dim=2048,
                                   dropout=0.1)

    input = torch.randn(1, token_len, token_dim)
    output = transformer(input)
    print(output.shape)

models/modules/transformer.py
ADDED
@@ -0,0 +1,44 @@
from models.modules.transformer_modules import *


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, win_size, dim_head, mlp_dim,
                 dropout=0., patch_num=None, ape=None, rpe=None, rpe_pos=1):
        super().__init__()

        self.absolute_pos_embed = None if patch_num is None or ape is None else AbsolutePosition(dim, dropout,
                                                                                                 patch_num, ape)
        self.pos_dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout, patch_num=patch_num,
                                       rpe=rpe, rpe_pos=rpe_pos)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))

    def forward(self, x):
        if self.absolute_pos_embed is not None:
            x = self.absolute_pos_embed(x)
        x = self.pos_dropout(x)
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x


if __name__ == '__main__':
    token_dim = 1024
    token_len = 256

    # win_size is not used by the plain Transformer; it is only part of the shared signature
    transformer = Transformer(dim=token_dim, depth=6, heads=16, win_size=8,
                              dim_head=64, mlp_dim=2048, dropout=0.1,
                              patch_num=256, ape='lr_parameter', rpe='lr_parameter_mirror')

    total = sum(p.numel() for p in transformer.parameters())
    trainable = sum(p.numel() for p in transformer.parameters() if p.requires_grad)
    print('parameter total:{:,}, trainable:{:,}'.format(total, trainable))

    input = torch.randn(1, token_len, token_dim)
    output = transformer(input)
    print(output.shape)

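For reference, the position-embedding options consumed here are defined in models/modules/transformer_modules.py (next file): ape is one of 'lr_parameter' or 'fix_angle', and rpe is one of 'lr_parameter', 'lr_parameter_mirror', 'lr_parameter_half', or 'fix_angle' (see AbsolutePosition and RelativePosition below).
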
models/modules/transformer_modules.py
ADDED
@@ -0,0 +1,250 @@
"""
@Date: 2021/09/01
@description:
"""
import warnings
import math
import torch
import torch.nn.functional as F

from torch import nn, einsum
from einops import rearrange


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


# compatibility pytorch < 1.4
class GELU(nn.Module):
    def forward(self, input):
        return F.gelu(input)


class Attend(nn.Module):

    def __init__(self, dim=None):
        super().__init__()
        self.dim = dim

    def forward(self, input):
        return F.softmax(input, dim=self.dim, dtype=input.dtype)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class RelativePosition(nn.Module):
    def __init__(self, heads, patch_num=None, rpe=None):
        super().__init__()
        self.rpe = rpe
        self.heads = heads
        self.patch_num = patch_num

        if rpe == 'lr_parameter':
            # -255 ~ 0 ~ 255, count: patch_num * 2 - 1
            count = patch_num * 2 - 1
            self.rpe_table = nn.Parameter(torch.Tensor(count, heads))
            nn.init.xavier_uniform_(self.rpe_table)
        elif rpe == 'lr_parameter_mirror':
            # 0 ~ 127, 128 ~ 1, count: patch_num // 2 + 1
            count = patch_num // 2 + 1
            self.rpe_table = nn.Parameter(torch.Tensor(count, heads))
            nn.init.xavier_uniform_(self.rpe_table)
        elif rpe == 'lr_parameter_half':
            # -127 ~ 0 ~ 128, count: patch_num
            count = patch_num
            self.rpe_table = nn.Parameter(torch.Tensor(count, heads))
            nn.init.xavier_uniform_(self.rpe_table)
        elif rpe == 'fix_angle':
            # 0 ~ 127, 128 ~ 1, count: patch_num // 2 + 1
            count = patch_num // 2 + 1
            # we think that closer proximity should have stronger relationships
            rpe_table = (torch.arange(count, 0, -1) / count)[..., None].repeat(1, heads)
            self.register_buffer('rpe_table', rpe_table)

    def get_relative_pos_embed(self):
        range_vec = torch.arange(self.patch_num)
        distance_mat = range_vec[None, :] - range_vec[:, None]
        if self.rpe == 'lr_parameter':
            # -255 ~ 0 ~ 255 -> 0 ~ 255 ~ 510
            distance_mat += self.patch_num - 1  # shift to remove negatives
            return self.rpe_table[distance_mat].permute(2, 0, 1)[None]
        elif self.rpe == 'lr_parameter_mirror' or self.rpe == 'fix_angle':
            distance_mat[distance_mat < 0] = -distance_mat[distance_mat < 0]  # mirror
            distance_mat[distance_mat > self.patch_num // 2] = self.patch_num - distance_mat[
                distance_mat > self.patch_num // 2]  # fold repeats
            return self.rpe_table[distance_mat].permute(2, 0, 1)[None]
        elif self.rpe == 'lr_parameter_half':
            distance_mat[distance_mat > self.patch_num // 2] = distance_mat[
                distance_mat > self.patch_num // 2] - self.patch_num  # fold repeats > 128, e.g. 129 -> -127
            distance_mat[distance_mat < -self.patch_num // 2 + 1] = distance_mat[
                distance_mat < -self.patch_num // 2 + 1] + self.patch_num  # fold repeats < -127, e.g. -128 -> 128
            # -127 ~ 0 ~ 128 -> 0 ~ 127 ~ 255
            distance_mat += self.patch_num // 2 - 1  # shift to remove negatives
            return self.rpe_table[distance_mat].permute(2, 0, 1)[None]

    def forward(self, attn):
        return attn + self.get_relative_pos_embed()


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0., patch_num=None, rpe=None, rpe_pos=1):
        """
        :param dim:
        :param heads:
        :param dim_head:
        :param dropout:
        :param patch_num:
        :param rpe: relative position embedding
        """
        super().__init__()

        self.relative_pos_embed = None if patch_num is None or rpe is None else RelativePosition(heads, patch_num, rpe)
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.rpe_pos = rpe_pos

        self.attend = Attend(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        if self.rpe_pos == 0:
            if self.relative_pos_embed is not None:
                dots = self.relative_pos_embed(dots)

        attn = self.attend(dots)

        if self.rpe_pos == 1:
            if self.relative_pos_embed is not None:
                attn = self.relative_pos_embed(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class AbsolutePosition(nn.Module):
    def __init__(self, dim, dropout=0., patch_num=None, ape=None):
        super().__init__()
        self.ape = ape

        if ape == 'lr_parameter':
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, patch_num, dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        elif ape == 'fix_angle':
            angle = torch.arange(0, patch_num, dtype=torch.float) / patch_num * (math.pi * 2)
            self.absolute_pos_embed = torch.sin(angle)[..., None].repeat(1, dim)[None]

    def forward(self, x):
        return x + self.absolute_pos_embed


class WinAttention(nn.Module):
    def __init__(self, dim, win_size=8, shift=0, heads=8, dim_head=64, dropout=0., rpe=None, rpe_pos=1):
        super().__init__()

        self.win_size = win_size
        self.shift = shift
        self.attend = Attention(dim, heads=heads, dim_head=dim_head,
                                dropout=dropout, patch_num=win_size, rpe=None if rpe is None else 'lr_parameter',
                                rpe_pos=rpe_pos)

    def forward(self, x):
        b = x.shape[0]
        if self.shift != 0:
            x = torch.roll(x, shifts=self.shift, dims=-2)
        x = rearrange(x, 'b (m w) d -> (b m) w d', w=self.win_size)  # split windows

        out = self.attend(x)

        out = rearrange(out, '(b m) w d -> b (m w) d', b=b)  # recover windows
        if self.shift != 0:
            out = torch.roll(out, shifts=-self.shift, dims=-2)

        return out


class Conv(nn.Module):
    def __init__(self, dim, dropout=0.):
        super().__init__()
        self.dim = dim
        self.net = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size=3, stride=1, padding=0),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = x.transpose(1, 2)
        x = torch.cat([x[..., -1:], x, x[..., :1]], dim=-1)
        x = self.net(x)
        return x.transpose(1, 2)

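A small sketch of the circular ("mirror") distance folding used by the 'lr_parameter_mirror' and 'fix_angle' tables above, with patch_num=8 standing in for the 256 used in practice:

import torch
patch_num = 8
range_vec = torch.arange(patch_num)
distance_mat = range_vec[None, :] - range_vec[:, None]
distance_mat[distance_mat < 0] = -distance_mat[distance_mat < 0]          # mirror
distance_mat[distance_mat > patch_num // 2] = patch_num - distance_mat[
    distance_mat > patch_num // 2]                                        # fold the far half back
print(distance_mat[0])  # tensor([0, 1, 2, 3, 4, 3, 2, 1]) -- distance on a circle
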
models/other/__init__.py
ADDED
@@ -0,0 +1,4 @@
"""
@Date: 2021/07/18
@description:
"""

models/other/criterion.py
ADDED
@@ -0,0 +1,72 @@
"""
@date: 2021/7/19
@description:
"""
import torch
import loss

from utils.misc import tensor2np


def build_criterion(config, logger):
    criterion = {}
    device = config.TRAIN.DEVICE

    for k in config.TRAIN.CRITERION.keys():
        sc = config.TRAIN.CRITERION[k]
        if sc.WEIGHT is None or float(sc.WEIGHT) == 0:
            continue
        criterion[sc.NAME] = {
            'loss': getattr(loss, sc.LOSS)(),
            'weight': float(sc.WEIGHT),
            'sub_weights': sc.WEIGHTS,
            'need_all': sc.NEED_ALL
        }

        criterion[sc.NAME]['loss'] = criterion[sc.NAME]['loss'].to(device)
        if config.AMP_OPT_LEVEL != "O0" and 'cuda' in device:
            criterion[sc.NAME]['loss'] = criterion[sc.NAME]['loss'].type(torch.float16)

        # logger.info(f"Build criterion:{sc.WEIGHT}_{sc.NAME}_{sc.LOSS}_{sc.WEIGHTS}")
    return criterion


def calc_criterion(criterion, gt, dt, epoch_loss_d):
    loss = None
    postfix_d = {}
    for k in criterion.keys():
        if criterion[k]['need_all']:
            single_loss = criterion[k]['loss'](gt, dt)
            ws_loss = None
            for i, sub_weight in enumerate(criterion[k]['sub_weights']):
                if sub_weight == 0:
                    continue
                if ws_loss is None:
                    ws_loss = single_loss[i] * sub_weight
                else:
                    ws_loss = ws_loss + single_loss[i] * sub_weight
            single_loss = ws_loss if ws_loss is not None else single_loss
        else:
            assert k in gt.keys(), "ground truth label is None: " + k
            assert k in dt.keys(), "detection key is None: " + k
            if k == 'ratio' and gt[k].shape[-1] != dt[k].shape[-1]:
                gt[k] = gt[k].repeat(1, dt[k].shape[-1])
            single_loss = criterion[k]['loss'](gt[k], dt[k])

        postfix_d[k] = tensor2np(single_loss)
        if k not in epoch_loss_d.keys():
            epoch_loss_d[k] = []
        epoch_loss_d[k].append(postfix_d[k])

        single_loss = single_loss * criterion[k]['weight']
        if loss is None:
            loss = single_loss
        else:
            loss = loss + single_loss

    k = 'loss'
    postfix_d[k] = tensor2np(loss)
    if k not in epoch_loss_d.keys():
        epoch_loss_d[k] = []
    epoch_loss_d[k].append(postfix_d[k])
    return loss, postfix_d, epoch_loss_d

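A sketch of the per-criterion config fields that build_criterion reads (field meanings inferred from the code above; the concrete values live in config/defaults.py):

# config.TRAIN.CRITERION.<key>:
#   NAME      key into the criterion dict, and into gt/dt when NEED_ALL is False
#   LOSS      class name looked up in the `loss` package via getattr(loss, sc.LOSS)
#   WEIGHT    overall term weight; None or 0 disables the term
#   WEIGHTS   per-component weights, applied only when NEED_ALL is True
#   NEED_ALL  if True, the loss receives the full gt/dt dicts instead of gt[NAME]/dt[NAME]
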
models/other/init_env.py
ADDED
@@ -0,0 +1,37 @@
"""
@Date: 2021/08/15
@description:
"""
import random
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import os
import cv2


def init_env(seed, deterministic=False, loader_work_num=0):
    # Fix seed
    # Python & NumPy
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    # PyTorch
    torch.manual_seed(seed)  # set the random seed for the CPU
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)  # set the random seed for the current GPU
        torch.cuda.manual_seed_all(seed)  # set the random seed for all GPUs

    # cuDNN
    if deterministic:
        # reproducibility
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True  # with this flag set, the selected convolution algorithm is deterministic (the default algorithm)
    else:
        cudnn.benchmark = True  # set to True when input sizes/types vary little between iterations
        torch.backends.cudnn.deterministic = False

    # Using multiple threads in OpenCV can cause deadlocks
    if loader_work_num != 0:
        cv2.setNumThreads(0)

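Typical call (a sketch): fix all seeds before building datasets and models; deterministic=True trades cuDNN speed for exact run-to-run repeatability.

from models.other.init_env import init_env

init_env(seed=123, deterministic=True, loader_work_num=4)
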
models/other/optimizer.py
ADDED
@@ -0,0 +1,24 @@
"""
@Date: 2021/07/18
@description:
"""
from torch import optim as optim


def build_optimizer(config, model, logger):
    name = config.TRAIN.OPTIMIZER.NAME.lower()

    optimizer = None
    if name == 'sgd':
        optimizer = optim.SGD(model.parameters(), momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True,
                              lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
    elif name == 'adamw':
        optimizer = optim.AdamW(model.parameters(), eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
                                lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
    elif name == 'adam':
        optimizer = optim.Adam(model.parameters(), eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
                               lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)

    logger.info(f"Build optimizer: {name}, lr:{config.TRAIN.BASE_LR}")

    return optimizer

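Note that build_optimizer returns None for any NAME other than 'sgd', 'adamw', or 'adam'; the config fields it reads are (names taken from the code above):

# config.TRAIN.BASE_LR            base learning rate for all three optimizers
# config.TRAIN.WEIGHT_DECAY       weight decay
# config.TRAIN.OPTIMIZER.NAME     'sgd' | 'adamw' | 'adam' (case-insensitive)
# config.TRAIN.OPTIMIZER.MOMENTUM used by sgd (with nesterov=True)
# config.TRAIN.OPTIMIZER.EPS      used by adamw/adam
# config.TRAIN.OPTIMIZER.BETAS    used by adamw/adam
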
models/other/scheduler.py
ADDED
@@ -0,0 +1,51 @@
"""
@Date: 2021/09/14
@description:
"""


class WarmupScheduler:
    def __init__(self, optimizer, lr_pow, init_lr, warmup_lr, warmup_step, max_step, **kwargs):
        self.lr_pow = lr_pow
        self.init_lr = init_lr
        self.running_lr = init_lr
        self.warmup_lr = warmup_lr
        self.warmup_step = warmup_step
        self.max_step = max_step
        self.optimizer = optimizer

    def step_update(self, cur_step):
        if cur_step < self.warmup_step:
            frac = cur_step / self.warmup_step
            step = self.warmup_lr - self.init_lr
            self.running_lr = self.init_lr + step * frac
        else:
            frac = (float(cur_step) - self.warmup_step) / (self.max_step - self.warmup_step)
            scale_running_lr = max((1. - frac), 0.) ** self.lr_pow
            self.running_lr = self.warmup_lr * scale_running_lr

        if self.optimizer is not None:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.running_lr


if __name__ == '__main__':
    import matplotlib.pyplot as plt

    scheduler = WarmupScheduler(optimizer=None,
                                lr_pow=4,
                                init_lr=0.0000003,
                                warmup_lr=0.00003,
                                warmup_step=10000,
                                max_step=100000)

    x = []
    y = []
    for i in range(100000):
        scheduler.step_update(i)
        x.append(i)
        y.append(scheduler.running_lr)
    plt.plot(x, y, linewidth=1)
    plt.show()

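The schedule above, restated: linear warmup from init_lr to warmup_lr over warmup_step steps, then polynomial decay of power lr_pow down to 0 at max_step.

# cur_step <  warmup_step: lr = init_lr + (warmup_lr - init_lr) * cur_step / warmup_step
# cur_step >= warmup_step: lr = warmup_lr * (1 - (cur_step - warmup_step) / (max_step - warmup_step)) ** lr_pow
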
postprocessing/__init__.py
ADDED
@@ -0,0 +1,4 @@
"""
@Date: 2021/10/06
@description:
"""

postprocessing/dula/__init__.py
ADDED
@@ -0,0 +1,4 @@
"""
@Date: 2021/10/06
@description:
"""

postprocessing/dula/layout.py
ADDED
@@ -0,0 +1,226 @@
"""
@Date: 2021/10/06
@description: Use the approach proposed by DuLa-Net
"""
import cv2
import numpy as np
import math
import matplotlib.pyplot as plt

from visualization.floorplan import draw_floorplan


def merge_near(lst, diag):
    group = [[0, ]]
    for i in range(1, len(lst)):
        if lst[i][1] == 0 and lst[i][0] - np.mean(group[-1]) < diag * 0.02:
            group[-1].append(lst[i][0])
        else:
            group.append([lst[i][0], ])
    if len(group) == 1:
        group = [lst[0][0], lst[-1][0]]
    else:
        group = [int(np.mean(x)) for x in group]
    return group


def fit_layout(floor_xz, need_cube=False, show=False, block_eps=0.2):
    show_radius = np.linalg.norm(floor_xz, axis=-1).max()
    side_l = 512
    floorplan = draw_floorplan(xz=floor_xz, show_radius=show_radius, show=show, scale=1, side_l=side_l).astype(np.uint8)
    center = np.array([side_l / 2, side_l / 2])
    polys = cv2.findContours(floorplan, 1, 2)
    if isinstance(polys, tuple):
        if len(polys) == 3:
            # opencv 3
            polys = list(polys[1])
        else:
            polys = list(polys[0])
    polys.sort(key=lambda x: cv2.contourArea(x), reverse=True)
    poly = polys[0]
    sub_x, sub_y, w, h = cv2.boundingRect(poly)
    floorplan_sub = floorplan[sub_y:sub_y + h, sub_x:sub_x + w]
    sub_center = center - np.array([sub_x, sub_y])
    polys = cv2.findContours(floorplan_sub, 1, 2)
    if isinstance(polys, tuple):
        if len(polys) == 3:
            polys = polys[1]
        else:
            polys = polys[0]
    poly = polys[0]
    epsilon = 0.005 * cv2.arcLength(poly, True)
    poly = cv2.approxPolyDP(poly, epsilon, True)

    x_lst = [[0, 0], ]
    y_lst = [[0, 0], ]

    ans = np.zeros((floorplan_sub.shape[0], floorplan_sub.shape[1]))

    for i in range(len(poly)):
        p1 = poly[i][0]
        p2 = poly[(i + 1) % len(poly)][0]
        cp1 = p1 - sub_center
        cp2 = p2 - sub_center
        p12 = p2 - p1
        l1 = np.linalg.norm(cp1)
        l2 = np.linalg.norm(cp2)
        l3 = np.linalg.norm(p12)
        # We added occlusion detection
        is_block1 = abs(np.cross(cp1 / l1, cp2 / l2)) < block_eps
        is_block2 = abs(np.cross(cp2 / l2, p12 / l3)) < block_eps * 2
        is_block = is_block1 and is_block2

        if (p2[0] - p1[0]) == 0:
            slope = 10
        else:
            slope = abs((p2[1] - p1[1]) / (p2[0] - p1[0]))

        if is_block:
            s = p1[1] if l1 < l2 else p2[1]
            y_lst.append([s, 1])
            s = p1[0] if l1 < l2 else p2[0]
            x_lst.append([s, 1])

            left = p1[0] if p1[0] < p2[0] else p2[0]
            right = p1[0] if p1[0] > p2[0] else p2[0]
            top = p1[1] if p1[1] < p2[1] else p2[1]
            bottom = p1[1] if p1[1] > p2[1] else p2[1]
            sample = floorplan_sub[top:bottom, left:right]
            score = 0 if sample.size == 0 else sample.mean()
            if score >= 0.3:
                ans[top:bottom, left:right] = 1

        else:
            if slope <= 1:
                s = int((p1[1] + p2[1]) / 2)
                y_lst.append([s, 0])
            elif slope > 1:
                s = int((p1[0] + p2[0]) / 2)
                x_lst.append([s, 0])

    debug_show = False
    if debug_show:
        plt.figure(dpi=300)
        plt.axis('off')
        a = cv2.drawMarker(floorplan_sub.copy() * 0.5, tuple([floorplan_sub.shape[1] // 2, floorplan_sub.shape[0] // 2]), [1], markerType=0, markerSize=10, thickness=2)
        plt.imshow(cv2.drawContours(a, [poly], 0, 1, 1))
        plt.savefig('src/1.png', bbox_inches='tight', transparent=True, pad_inches=0)
        plt.show()

        plt.figure(dpi=300)
        plt.axis('off')
        a = cv2.drawMarker(ans.copy() * 0.5, tuple([floorplan_sub.shape[1] // 2, floorplan_sub.shape[0] // 2]), [1], markerType=0, markerSize=10, thickness=2)
        plt.imshow(cv2.drawContours(a, [poly], 0, 1, 1))
        # plt.show()
        plt.savefig('src/2.png', bbox_inches='tight', transparent=True, pad_inches=0)
        plt.show()

    x_lst.append([floorplan_sub.shape[1], 0])
    y_lst.append([floorplan_sub.shape[0], 0])
    x_lst.sort(key=lambda x: x[0])
    y_lst.sort(key=lambda x: x[0])

    diag = math.sqrt(math.pow(floorplan_sub.shape[1], 2) + math.pow(floorplan_sub.shape[0], 2))
    x_lst = merge_near(x_lst, diag)
    y_lst = merge_near(y_lst, diag)
    if need_cube and len(x_lst) > 2:
        x_lst = [x_lst[0], x_lst[-1]]
    if need_cube and len(y_lst) > 2:
        y_lst = [y_lst[0], y_lst[-1]]

    for i in range(len(x_lst) - 1):
        for j in range(len(y_lst) - 1):
            sample = floorplan_sub[y_lst[j]:y_lst[j + 1], x_lst[i]:x_lst[i + 1]]
            score = 0 if sample.size == 0 else sample.mean()
            if score >= 0.3:
                ans[y_lst[j]:y_lst[j + 1], x_lst[i]:x_lst[i + 1]] = 1

    if debug_show:
        plt.figure(dpi=300)
        plt.axis('off')
        a = cv2.drawMarker(ans.copy() * 0.5, tuple([floorplan_sub.shape[1] // 2, floorplan_sub.shape[0] // 2]), [1],
                           markerType=0, markerSize=10, thickness=2)
        plt.imshow(cv2.drawContours(a, [poly], 0, 1, 1))
        # plt.show()
        plt.savefig('src/3.png', bbox_inches='tight', transparent=True, pad_inches=0)
        plt.show()

    pred = np.uint8(ans)
    pred_polys = cv2.findContours(pred, 1, 3)
    if isinstance(pred_polys, tuple):
        if len(pred_polys) == 3:
            pred_polys = list(pred_polys[1])
        else:
            pred_polys = list(pred_polys[0])

    pred_polys.sort(key=lambda x: cv2.contourArea(x), reverse=True)
    pred_polys = pred_polys[0]

    if debug_show:
        plt.figure(dpi=300)
        plt.axis('off')
        a = cv2.drawMarker(ans.copy() * 0.5, tuple([floorplan_sub.shape[1] // 2, floorplan_sub.shape[0] // 2]), [1],
                           markerType=0, markerSize=10, thickness=2)
        a = cv2.drawContours(a, [poly], 0, 0.8, 1)
        a = cv2.drawContours(a, [pred_polys], 0, 1, 1)
        plt.imshow(a)
        # plt.show()
        plt.savefig('src/4.png', bbox_inches='tight', transparent=True, pad_inches=0)
        plt.show()

    polygon = [(p[0][1], p[0][0]) for p in pred_polys[::-1]]

    v = np.array([p[0] + sub_y for p in polygon])
    u = np.array([p[1] + sub_x for p in polygon])
    #       side_l
    # v<-----------|o
    # |      |     |
    # |  ----|----z| side_l
    # |      |     |
    # |      x    \|/
    # |------------u
    side_l = floorplan.shape[0]
    pred_xz = np.concatenate((u[:, np.newaxis] - side_l // 2, side_l // 2 - v[:, np.newaxis]), axis=1)

    pred_xz = pred_xz * show_radius / (side_l // 2)
    if show:
        draw_floorplan(pred_xz, show_radius=show_radius, show=show)

    show_process = False
    if show_process:
        img = np.zeros((floorplan_sub.shape[0], floorplan_sub.shape[1], 3))
        for x in x_lst:
            cv2.line(img, (x, 0), (x, floorplan_sub.shape[0]), (0, 255, 0), 1)
        for y in y_lst:
            cv2.line(img, (0, y), (floorplan_sub.shape[1], y), (255, 0, 0), 1)

        fig = plt.figure()
        plt.axis('off')
        ax1 = fig.add_subplot(2, 2, 1)
        ax1.imshow(floorplan)
        ax3 = fig.add_subplot(2, 2, 2)
        ax3.imshow(floorplan_sub)
        ax4 = fig.add_subplot(2, 2, 3)
        ax4.imshow(img)
        ax5 = fig.add_subplot(2, 2, 4)
        ax5.imshow(ans)
        plt.show()

    return pred_xz


if __name__ == '__main__':
    from utils.conversion import uv2xyz

    pano_img = np.zeros([512, 1024, 3])
    corners = np.array([[0.1, 0.7],
                        [0.4, 0.7],
                        [0.3, 0.6],
                        [0.6, 0.6],
                        [0.8, 0.7]])
    xz = uv2xyz(corners)[..., ::2]
    draw_floorplan(xz, show=True, marker_color=None, center_color=0.8)

    xz = fit_layout(xz)
    draw_floorplan(xz, show=True, marker_color=None, center_color=0.8)

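merge_near in isolation (a small worked example): non-occlusion candidates ([s, 0] entries) within 2% of the floorplan diagonal of the running group mean are merged into a single axis line.

lst = [[0, 0], [100, 0], [103, 0], [300, 0], [512, 0]]
print(merge_near(lst, diag=724))  # [0, 101, 300, 512] -- 100 and 103 collapse to their mean
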