lemonaddie
commited on
Commit
•
2e23827
1
Parent(s):
87f795e
Upload 11 files
Browse files- utils/batch_size.py +63 -0
- utils/colormap.py +45 -0
- utils/common.py +42 -0
- utils/dataset_configuration.py +81 -0
- utils/de_normalized.py +33 -0
- utils/depth2normal.py +186 -0
- utils/depth_ensemble.py +115 -0
- utils/image_util.py +83 -0
- utils/normal_ensemble.py +22 -0
- utils/seed_all.py +33 -0
- utils/surface_normal.py +213 -0
utils/batch_size.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# A reimplemented version in public environments by Xiao Fu and Mu Hu
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import math
|
5 |
+
|
6 |
+
|
7 |
+
# Search table for suggested max. inference batch size
|
8 |
+
bs_search_table = [
|
9 |
+
# tested on A100-PCIE-80GB
|
10 |
+
{"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
|
11 |
+
{"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
|
12 |
+
# tested on A100-PCIE-40GB
|
13 |
+
{"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
|
14 |
+
{"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
|
15 |
+
{"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
|
16 |
+
{"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
|
17 |
+
# tested on RTX3090, RTX4090
|
18 |
+
{"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
|
19 |
+
{"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
|
20 |
+
{"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
|
21 |
+
{"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
|
22 |
+
{"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
|
23 |
+
{"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
|
24 |
+
# tested on GTX1080Ti
|
25 |
+
{"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
|
26 |
+
{"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
|
27 |
+
{"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
|
28 |
+
{"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
|
29 |
+
{"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
|
30 |
+
]
|
31 |
+
|
32 |
+
|
33 |
+
def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
|
34 |
+
"""
|
35 |
+
Automatically search for suitable operating batch size.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
ensemble_size (`int`):
|
39 |
+
Number of predictions to be ensembled.
|
40 |
+
input_res (`int`):
|
41 |
+
Operating resolution of the input image.
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
`int`: Operating batch size.
|
45 |
+
"""
|
46 |
+
if not torch.cuda.is_available():
|
47 |
+
return 1
|
48 |
+
|
49 |
+
total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
|
50 |
+
filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
|
51 |
+
for settings in sorted(
|
52 |
+
filtered_bs_search_table,
|
53 |
+
key=lambda k: (k["res"], -k["total_vram"]),
|
54 |
+
):
|
55 |
+
if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
|
56 |
+
bs = settings["bs"]
|
57 |
+
if bs > ensemble_size:
|
58 |
+
bs = ensemble_size
|
59 |
+
elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
|
60 |
+
bs = math.ceil(ensemble_size / 2)
|
61 |
+
return bs
|
62 |
+
|
63 |
+
return 1
|
utils/colormap.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# A reimplemented version in public environments by Xiao Fu and Mu Hu
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import cv2
|
5 |
+
|
6 |
+
def kitti_colormap(disparity, maxval=-1):
|
7 |
+
"""
|
8 |
+
A utility function to reproduce KITTI fake colormap
|
9 |
+
Arguments:
|
10 |
+
- disparity: numpy float32 array of dimension HxW
|
11 |
+
- maxval: maximum disparity value for normalization (if equal to -1, the maximum value in disparity will be used)
|
12 |
+
|
13 |
+
Returns a numpy uint8 array of shape HxWx3.
|
14 |
+
"""
|
15 |
+
if maxval < 0:
|
16 |
+
maxval = np.max(disparity)
|
17 |
+
|
18 |
+
colormap = np.asarray([[0,0,0,114],[0,0,1,185],[1,0,0,114],[1,0,1,174],[0,1,0,114],[0,1,1,185],[1,1,0,114],[1,1,1,0]])
|
19 |
+
weights = np.asarray([8.771929824561404,5.405405405405405,8.771929824561404,5.747126436781609,8.771929824561404,5.405405405405405,8.771929824561404,0])
|
20 |
+
cumsum = np.asarray([0,0.114,0.299,0.413,0.587,0.701,0.8859999999999999,0.9999999999999999])
|
21 |
+
|
22 |
+
colored_disp = np.zeros([disparity.shape[0], disparity.shape[1], 3])
|
23 |
+
values = np.expand_dims(np.minimum(np.maximum(disparity/maxval, 0.), 1.), -1)
|
24 |
+
bins = np.repeat(np.repeat(np.expand_dims(np.expand_dims(cumsum,axis=0),axis=0), disparity.shape[1], axis=1), disparity.shape[0], axis=0)
|
25 |
+
diffs = np.where((np.repeat(values, 8, axis=-1) - bins) > 0, -1000, (np.repeat(values, 8, axis=-1) - bins))
|
26 |
+
index = np.argmax(diffs, axis=-1)-1
|
27 |
+
|
28 |
+
w = 1-(values[:,:,0]-cumsum[index])*np.asarray(weights)[index]
|
29 |
+
|
30 |
+
|
31 |
+
colored_disp[:,:,2] = (w*colormap[index][:,:,0] + (1.-w)*colormap[index+1][:,:,0])
|
32 |
+
colored_disp[:,:,1] = (w*colormap[index][:,:,1] + (1.-w)*colormap[index+1][:,:,1])
|
33 |
+
colored_disp[:,:,0] = (w*colormap[index][:,:,2] + (1.-w)*colormap[index+1][:,:,2])
|
34 |
+
|
35 |
+
return (colored_disp*np.expand_dims((disparity>0),-1)*255).astype(np.uint8)
|
36 |
+
|
37 |
+
def read_16bit_gt(path):
|
38 |
+
"""
|
39 |
+
A utility function to read KITTI 16bit gt
|
40 |
+
Arguments:
|
41 |
+
- path: filepath
|
42 |
+
Returns a numpy float32 array of shape HxW.
|
43 |
+
"""
|
44 |
+
gt = cv2.imread(path,-1).astype(np.float32)/256.
|
45 |
+
return gt
|
utils/common.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# A reimplemented version in public environments by Xiao Fu and Mu Hu
|
2 |
+
|
3 |
+
import json
|
4 |
+
import yaml
|
5 |
+
import logging
|
6 |
+
import os
|
7 |
+
import numpy as np
|
8 |
+
import sys
|
9 |
+
|
10 |
+
def load_loss_scheme(loss_config):
|
11 |
+
with open(loss_config, 'r') as f:
|
12 |
+
loss_json = yaml.safe_load(f)
|
13 |
+
return loss_json
|
14 |
+
|
15 |
+
|
16 |
+
DEBUG =0
|
17 |
+
logger = logging.getLogger()
|
18 |
+
|
19 |
+
|
20 |
+
if DEBUG:
|
21 |
+
#coloredlogs.install(level='DEBUG')
|
22 |
+
logger.setLevel(logging.DEBUG)
|
23 |
+
else:
|
24 |
+
#coloredlogs.install(level='INFO')
|
25 |
+
logger.setLevel(logging.INFO)
|
26 |
+
|
27 |
+
|
28 |
+
strhdlr = logging.StreamHandler()
|
29 |
+
logger.addHandler(strhdlr)
|
30 |
+
formatter = logging.Formatter('%(asctime)s [%(filename)s:%(lineno)d] %(levelname)s %(message)s')
|
31 |
+
strhdlr.setFormatter(formatter)
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
def count_parameters(model):
|
36 |
+
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
37 |
+
|
38 |
+
def check_path(path):
|
39 |
+
if not os.path.exists(path):
|
40 |
+
os.makedirs(path, exist_ok=True)
|
41 |
+
|
42 |
+
|
utils/dataset_configuration.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# A reimplemented version in public environments by Xiao Fu and Mu Hu
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import numpy as np
|
7 |
+
import sys
|
8 |
+
sys.path.append("..")
|
9 |
+
|
10 |
+
from dataloader.mix_loader import MixDataset
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
+
from dataloader import transforms
|
13 |
+
import os
|
14 |
+
|
15 |
+
|
16 |
+
# Get Dataset Here
|
17 |
+
def prepare_dataset(data_dir=None,
|
18 |
+
batch_size=1,
|
19 |
+
test_batch=1,
|
20 |
+
datathread=4,
|
21 |
+
logger=None):
|
22 |
+
|
23 |
+
# set the config parameters
|
24 |
+
dataset_config_dict = dict()
|
25 |
+
|
26 |
+
train_dataset = MixDataset(data_dir=data_dir)
|
27 |
+
|
28 |
+
img_height, img_width = train_dataset.get_img_size()
|
29 |
+
|
30 |
+
datathread = datathread
|
31 |
+
if os.environ.get('datathread') is not None:
|
32 |
+
datathread = int(os.environ.get('datathread'))
|
33 |
+
|
34 |
+
if logger is not None:
|
35 |
+
logger.info("Use %d processes to load data..." % datathread)
|
36 |
+
|
37 |
+
train_loader = DataLoader(train_dataset, batch_size = batch_size, \
|
38 |
+
shuffle = True, num_workers = datathread, \
|
39 |
+
pin_memory = True)
|
40 |
+
|
41 |
+
num_batches_per_epoch = len(train_loader)
|
42 |
+
|
43 |
+
dataset_config_dict['num_batches_per_epoch'] = num_batches_per_epoch
|
44 |
+
dataset_config_dict['img_size'] = (img_height,img_width)
|
45 |
+
|
46 |
+
return train_loader, dataset_config_dict
|
47 |
+
|
48 |
+
def depth_scale_shift_normalization(depth):
|
49 |
+
|
50 |
+
bsz = depth.shape[0]
|
51 |
+
|
52 |
+
depth_ = depth[:,0,:,:].reshape(bsz,-1).cpu().numpy()
|
53 |
+
min_value = torch.from_numpy(np.percentile(a=depth_,q=2,axis=1)).to(depth)[...,None,None,None]
|
54 |
+
max_value = torch.from_numpy(np.percentile(a=depth_,q=98,axis=1)).to(depth)[...,None,None,None]
|
55 |
+
|
56 |
+
normalized_depth = ((depth - min_value)/(max_value-min_value+1e-5) - 0.5) * 2
|
57 |
+
normalized_depth = torch.clip(normalized_depth, -1., 1.)
|
58 |
+
|
59 |
+
return normalized_depth
|
60 |
+
|
61 |
+
|
62 |
+
|
63 |
+
def resize_max_res_tensor(input_tensor, mode, recom_resolution=768):
|
64 |
+
assert input_tensor.shape[1]==3
|
65 |
+
original_H, original_W = input_tensor.shape[2:]
|
66 |
+
downscale_factor = min(recom_resolution/original_H, recom_resolution/original_W)
|
67 |
+
|
68 |
+
if mode == 'normal':
|
69 |
+
resized_input_tensor = F.interpolate(input_tensor,
|
70 |
+
scale_factor=downscale_factor,
|
71 |
+
mode='nearest')
|
72 |
+
else:
|
73 |
+
resized_input_tensor = F.interpolate(input_tensor,
|
74 |
+
scale_factor=downscale_factor,
|
75 |
+
mode='bilinear',
|
76 |
+
align_corners=False)
|
77 |
+
|
78 |
+
if mode == 'depth':
|
79 |
+
return resized_input_tensor / downscale_factor
|
80 |
+
else:
|
81 |
+
return resized_input_tensor
|
utils/de_normalized.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# A reimplemented version in public environments by Xiao Fu and Mu Hu
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from scipy.optimize import least_squares
|
5 |
+
import torch
|
6 |
+
|
7 |
+
def align_scale_shift(pred, target, clip_max):
|
8 |
+
mask = (target > 0) & (target < clip_max)
|
9 |
+
if mask.sum() > 10:
|
10 |
+
target_mask = target[mask]
|
11 |
+
pred_mask = pred[mask]
|
12 |
+
scale, shift = np.polyfit(pred_mask, target_mask, deg=1)
|
13 |
+
return scale, shift
|
14 |
+
else:
|
15 |
+
return 1, 0
|
16 |
+
|
17 |
+
def align_scale(pred: torch.tensor, target: torch.tensor):
|
18 |
+
mask = target > 0
|
19 |
+
if torch.sum(mask) > 10:
|
20 |
+
scale = torch.median(target[mask]) / (torch.median(pred[mask]) + 1e-8)
|
21 |
+
else:
|
22 |
+
scale = 1
|
23 |
+
pred_scale = pred * scale
|
24 |
+
return pred_scale, scale
|
25 |
+
|
26 |
+
def align_shift(pred: torch.tensor, target: torch.tensor):
|
27 |
+
mask = target > 0
|
28 |
+
if torch.sum(mask) > 10:
|
29 |
+
shift = torch.median(target[mask]) - (torch.median(pred[mask]) + 1e-8)
|
30 |
+
else:
|
31 |
+
shift = 0
|
32 |
+
pred_shift = pred + shift
|
33 |
+
return pred_shift, shift
|
utils/depth2normal.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# A reimplemented version in public environments by Xiao Fu and Mu Hu
|
2 |
+
|
3 |
+
import pickle
|
4 |
+
import os
|
5 |
+
import h5py
|
6 |
+
import numpy as np
|
7 |
+
import cv2
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import glob
|
11 |
+
|
12 |
+
|
13 |
+
def init_image_coor(height, width):
|
14 |
+
x_row = np.arange(0, width)
|
15 |
+
x = np.tile(x_row, (height, 1))
|
16 |
+
x = x[np.newaxis, :, :]
|
17 |
+
x = x.astype(np.float32)
|
18 |
+
x = torch.from_numpy(x.copy()).cuda()
|
19 |
+
u_u0 = x - width/2.0
|
20 |
+
|
21 |
+
y_col = np.arange(0, height) # y_col = np.arange(0, height)
|
22 |
+
y = np.tile(y_col, (width, 1)).T
|
23 |
+
y = y[np.newaxis, :, :]
|
24 |
+
y = y.astype(np.float32)
|
25 |
+
y = torch.from_numpy(y.copy()).cuda()
|
26 |
+
v_v0 = y - height/2.0
|
27 |
+
return u_u0, v_v0
|
28 |
+
|
29 |
+
|
30 |
+
def depth_to_xyz(depth, focal_length):
|
31 |
+
b, c, h, w = depth.shape
|
32 |
+
u_u0, v_v0 = init_image_coor(h, w)
|
33 |
+
x = u_u0 * depth / focal_length[0]
|
34 |
+
y = v_v0 * depth / focal_length[1]
|
35 |
+
z = depth
|
36 |
+
pw = torch.cat([x, y, z], 1).permute(0, 2, 3, 1) # [b, h, w, c]
|
37 |
+
return pw
|
38 |
+
|
39 |
+
|
40 |
+
def get_surface_normal(xyz, patch_size=5):
|
41 |
+
# xyz: [1, h, w, 3]
|
42 |
+
x, y, z = torch.unbind(xyz, dim=3)
|
43 |
+
x = torch.unsqueeze(x, 0)
|
44 |
+
y = torch.unsqueeze(y, 0)
|
45 |
+
z = torch.unsqueeze(z, 0)
|
46 |
+
|
47 |
+
xx = x * x
|
48 |
+
yy = y * y
|
49 |
+
zz = z * z
|
50 |
+
xy = x * y
|
51 |
+
xz = x * z
|
52 |
+
yz = y * z
|
53 |
+
patch_weight = torch.ones((1, 1, patch_size, patch_size), requires_grad=False).cuda()
|
54 |
+
xx_patch = nn.functional.conv2d(xx, weight=patch_weight, padding=int(patch_size / 2))
|
55 |
+
yy_patch = nn.functional.conv2d(yy, weight=patch_weight, padding=int(patch_size / 2))
|
56 |
+
zz_patch = nn.functional.conv2d(zz, weight=patch_weight, padding=int(patch_size / 2))
|
57 |
+
xy_patch = nn.functional.conv2d(xy, weight=patch_weight, padding=int(patch_size / 2))
|
58 |
+
xz_patch = nn.functional.conv2d(xz, weight=patch_weight, padding=int(patch_size / 2))
|
59 |
+
yz_patch = nn.functional.conv2d(yz, weight=patch_weight, padding=int(patch_size / 2))
|
60 |
+
ATA = torch.stack([xx_patch, xy_patch, xz_patch, xy_patch, yy_patch, yz_patch, xz_patch, yz_patch, zz_patch],
|
61 |
+
dim=4)
|
62 |
+
ATA = torch.squeeze(ATA)
|
63 |
+
ATA = torch.reshape(ATA, (ATA.size(0), ATA.size(1), 3, 3))
|
64 |
+
eps_identity = 1e-6 * torch.eye(3, device=ATA.device, dtype=ATA.dtype)[None, None, :, :].repeat([ATA.size(0), ATA.size(1), 1, 1])
|
65 |
+
ATA = ATA + eps_identity
|
66 |
+
x_patch = nn.functional.conv2d(x, weight=patch_weight, padding=int(patch_size / 2))
|
67 |
+
y_patch = nn.functional.conv2d(y, weight=patch_weight, padding=int(patch_size / 2))
|
68 |
+
z_patch = nn.functional.conv2d(z, weight=patch_weight, padding=int(patch_size / 2))
|
69 |
+
AT1 = torch.stack([x_patch, y_patch, z_patch], dim=4)
|
70 |
+
AT1 = torch.squeeze(AT1)
|
71 |
+
AT1 = torch.unsqueeze(AT1, 3)
|
72 |
+
|
73 |
+
patch_num = 4
|
74 |
+
patch_x = int(AT1.size(1) / patch_num)
|
75 |
+
patch_y = int(AT1.size(0) / patch_num)
|
76 |
+
n_img = torch.randn(AT1.shape).cuda()
|
77 |
+
overlap = patch_size // 2 + 1
|
78 |
+
for x in range(int(patch_num)):
|
79 |
+
for y in range(int(patch_num)):
|
80 |
+
left_flg = 0 if x == 0 else 1
|
81 |
+
right_flg = 0 if x == patch_num -1 else 1
|
82 |
+
top_flg = 0 if y == 0 else 1
|
83 |
+
btm_flg = 0 if y == patch_num - 1 else 1
|
84 |
+
at1 = AT1[y * patch_y - top_flg * overlap:(y + 1) * patch_y + btm_flg * overlap,
|
85 |
+
x * patch_x - left_flg * overlap:(x + 1) * patch_x + right_flg * overlap]
|
86 |
+
ata = ATA[y * patch_y - top_flg * overlap:(y + 1) * patch_y + btm_flg * overlap,
|
87 |
+
x * patch_x - left_flg * overlap:(x + 1) * patch_x + right_flg * overlap]
|
88 |
+
# n_img_tmp, _ = torch.solve(at1, ata)
|
89 |
+
n_img_tmp = torch.linalg.solve(ata, at1)
|
90 |
+
|
91 |
+
n_img_tmp_select = n_img_tmp[top_flg * overlap:patch_y + top_flg * overlap, left_flg * overlap:patch_x + left_flg * overlap, :, :]
|
92 |
+
n_img[y * patch_y:y * patch_y + patch_y, x * patch_x:x * patch_x + patch_x, :, :] = n_img_tmp_select
|
93 |
+
|
94 |
+
n_img_L2 = torch.sqrt(torch.sum(n_img ** 2, dim=2, keepdim=True))
|
95 |
+
n_img_norm = n_img / n_img_L2
|
96 |
+
|
97 |
+
# re-orient normals consistently
|
98 |
+
orient_mask = torch.sum(torch.squeeze(n_img_norm) * torch.squeeze(xyz), dim=2) > 0
|
99 |
+
n_img_norm[orient_mask] *= -1
|
100 |
+
return n_img_norm
|
101 |
+
|
102 |
+
def get_surface_normalv2(xyz, patch_size=5):
|
103 |
+
"""
|
104 |
+
xyz: xyz coordinates
|
105 |
+
patch: [p1, p2, p3,
|
106 |
+
p4, p5, p6,
|
107 |
+
p7, p8, p9]
|
108 |
+
surface_normal = [(p9-p1) x (p3-p7)] + [(p6-p4) - (p8-p2)]
|
109 |
+
return: normal [h, w, 3, b]
|
110 |
+
"""
|
111 |
+
b, h, w, c = xyz.shape
|
112 |
+
half_patch = patch_size // 2
|
113 |
+
xyz_pad = torch.zeros((b, h + patch_size - 1, w + patch_size - 1, c), dtype=xyz.dtype, device=xyz.device)
|
114 |
+
xyz_pad[:, half_patch:-half_patch, half_patch:-half_patch, :] = xyz
|
115 |
+
|
116 |
+
# xyz_left_top = xyz_pad[:, :h, :w, :] # p1
|
117 |
+
# xyz_right_bottom = xyz_pad[:, -h:, -w:, :]# p9
|
118 |
+
# xyz_left_bottom = xyz_pad[:, -h:, :w, :] # p7
|
119 |
+
# xyz_right_top = xyz_pad[:, :h, -w:, :] # p3
|
120 |
+
# xyz_cross1 = xyz_left_top - xyz_right_bottom # p1p9
|
121 |
+
# xyz_cross2 = xyz_left_bottom - xyz_right_top # p7p3
|
122 |
+
|
123 |
+
xyz_left = xyz_pad[:, half_patch:half_patch + h, :w, :] # p4
|
124 |
+
xyz_right = xyz_pad[:, half_patch:half_patch + h, -w:, :] # p6
|
125 |
+
xyz_top = xyz_pad[:, :h, half_patch:half_patch + w, :] # p2
|
126 |
+
xyz_bottom = xyz_pad[:, -h:, half_patch:half_patch + w, :] # p8
|
127 |
+
xyz_horizon = xyz_left - xyz_right # p4p6
|
128 |
+
xyz_vertical = xyz_top - xyz_bottom # p2p8
|
129 |
+
|
130 |
+
xyz_left_in = xyz_pad[:, half_patch:half_patch + h, 1:w+1, :] # p4
|
131 |
+
xyz_right_in = xyz_pad[:, half_patch:half_patch + h, patch_size-1:patch_size-1+w, :] # p6
|
132 |
+
xyz_top_in = xyz_pad[:, 1:h+1, half_patch:half_patch + w, :] # p2
|
133 |
+
xyz_bottom_in = xyz_pad[:, patch_size-1:patch_size-1+h, half_patch:half_patch + w, :] # p8
|
134 |
+
xyz_horizon_in = xyz_left_in - xyz_right_in # p4p6
|
135 |
+
xyz_vertical_in = xyz_top_in - xyz_bottom_in # p2p8
|
136 |
+
|
137 |
+
n_img_1 = torch.cross(xyz_horizon_in, xyz_vertical_in, dim=3)
|
138 |
+
n_img_2 = torch.cross(xyz_horizon, xyz_vertical, dim=3)
|
139 |
+
|
140 |
+
# re-orient normals consistently
|
141 |
+
orient_mask = torch.sum(n_img_1 * xyz, dim=3) > 0
|
142 |
+
n_img_1[orient_mask] *= -1
|
143 |
+
orient_mask = torch.sum(n_img_2 * xyz, dim=3) > 0
|
144 |
+
n_img_2[orient_mask] *= -1
|
145 |
+
|
146 |
+
n_img1_L2 = torch.sqrt(torch.sum(n_img_1 ** 2, dim=3, keepdim=True))
|
147 |
+
n_img1_norm = n_img_1 / (n_img1_L2 + 1e-8)
|
148 |
+
|
149 |
+
n_img2_L2 = torch.sqrt(torch.sum(n_img_2 ** 2, dim=3, keepdim=True))
|
150 |
+
n_img2_norm = n_img_2 / (n_img2_L2 + 1e-8)
|
151 |
+
|
152 |
+
# average 2 norms
|
153 |
+
n_img_aver = n_img1_norm + n_img2_norm
|
154 |
+
n_img_aver_L2 = torch.sqrt(torch.sum(n_img_aver ** 2, dim=3, keepdim=True))
|
155 |
+
n_img_aver_norm = n_img_aver / (n_img_aver_L2 + 1e-8)
|
156 |
+
# re-orient normals consistently
|
157 |
+
orient_mask = torch.sum(n_img_aver_norm * xyz, dim=3) > 0
|
158 |
+
n_img_aver_norm[orient_mask] *= -1
|
159 |
+
n_img_aver_norm_out = n_img_aver_norm.permute((1, 2, 3, 0)) # [h, w, c, b]
|
160 |
+
|
161 |
+
# a = torch.sum(n_img1_norm_out*n_img2_norm_out, dim=2).cpu().numpy().squeeze()
|
162 |
+
# plt.imshow(np.abs(a), cmap='rainbow')
|
163 |
+
# plt.show()
|
164 |
+
return n_img_aver_norm_out#n_img1_norm.permute((1, 2, 3, 0))
|
165 |
+
|
166 |
+
def surface_normal_from_depth(depth, focal_length, valid_mask=None):
|
167 |
+
# para depth: depth map, [b, c, h, w]
|
168 |
+
b, c, h, w = depth.shape
|
169 |
+
focal_length = focal_length[:, None, None, None]
|
170 |
+
depth_filter = nn.functional.avg_pool2d(depth, kernel_size=3, stride=1, padding=1)
|
171 |
+
#depth_filter = nn.functional.avg_pool2d(depth_filter, kernel_size=3, stride=1, padding=1)
|
172 |
+
xyz = depth_to_xyz(depth_filter, focal_length)
|
173 |
+
sn_batch = []
|
174 |
+
for i in range(b):
|
175 |
+
xyz_i = xyz[i, :][None, :, :, :]
|
176 |
+
#normal = get_surface_normalv2(xyz_i)
|
177 |
+
normal = get_surface_normal(xyz_i)
|
178 |
+
sn_batch.append(normal)
|
179 |
+
sn_batch = torch.cat(sn_batch, dim=3).permute((3, 2, 0, 1)) # [b, c, h, w]
|
180 |
+
|
181 |
+
if valid_mask != None:
|
182 |
+
mask_invalid = (~valid_mask).repeat(1, 3, 1, 1)
|
183 |
+
sn_batch[mask_invalid] = 0.0
|
184 |
+
|
185 |
+
return sn_batch
|
186 |
+
|
utils/depth_ensemble.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# A reimplemented version in public environments by Xiao Fu and Mu Hu
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from scipy.optimize import minimize
|
7 |
+
|
8 |
+
def inter_distances(tensors: torch.Tensor):
|
9 |
+
"""
|
10 |
+
To calculate the distance between each two depth maps.
|
11 |
+
"""
|
12 |
+
distances = []
|
13 |
+
for i, j in torch.combinations(torch.arange(tensors.shape[0])):
|
14 |
+
arr1 = tensors[i : i + 1]
|
15 |
+
arr2 = tensors[j : j + 1]
|
16 |
+
distances.append(arr1 - arr2)
|
17 |
+
dist = torch.concat(distances, dim=0)
|
18 |
+
return dist
|
19 |
+
|
20 |
+
|
21 |
+
def ensemble_depths(input_images:torch.Tensor,
|
22 |
+
regularizer_strength: float =0.02,
|
23 |
+
max_iter: int =2,
|
24 |
+
tol:float =1e-3,
|
25 |
+
reduction: str='median',
|
26 |
+
max_res: int=None):
|
27 |
+
"""
|
28 |
+
To ensemble multiple affine-invariant depth images (up to scale and shift),
|
29 |
+
by aligning estimating the scale and shift
|
30 |
+
"""
|
31 |
+
|
32 |
+
device = input_images.device
|
33 |
+
dtype = input_images.dtype
|
34 |
+
np_dtype = np.float32
|
35 |
+
|
36 |
+
|
37 |
+
original_input = input_images.clone()
|
38 |
+
n_img = input_images.shape[0]
|
39 |
+
ori_shape = input_images.shape
|
40 |
+
|
41 |
+
if max_res is not None:
|
42 |
+
scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:]))
|
43 |
+
if scale_factor < 1:
|
44 |
+
downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest")
|
45 |
+
input_images = downscaler(torch.from_numpy(input_images)).numpy()
|
46 |
+
|
47 |
+
# init guess
|
48 |
+
_min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1) # get the min value of each possible depth
|
49 |
+
_max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1) # get the max value of each possible depth
|
50 |
+
s_init = 1.0 / (_max - _min).reshape((-1, 1, 1)) #(10,1,1) : re-scale'f scale
|
51 |
+
t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1)) #(10,1,1)
|
52 |
+
|
53 |
+
x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype) #(20,)
|
54 |
+
|
55 |
+
input_images = input_images.to(device)
|
56 |
+
|
57 |
+
# objective function
|
58 |
+
def closure(x):
|
59 |
+
l = len(x)
|
60 |
+
s = x[: int(l / 2)]
|
61 |
+
t = x[int(l / 2) :]
|
62 |
+
s = torch.from_numpy(s).to(dtype=dtype).to(device)
|
63 |
+
t = torch.from_numpy(t).to(dtype=dtype).to(device)
|
64 |
+
|
65 |
+
transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1))
|
66 |
+
dists = inter_distances(transformed_arrays)
|
67 |
+
sqrt_dist = torch.sqrt(torch.mean(dists**2))
|
68 |
+
|
69 |
+
if "mean" == reduction:
|
70 |
+
pred = torch.mean(transformed_arrays, dim=0)
|
71 |
+
elif "median" == reduction:
|
72 |
+
pred = torch.median(transformed_arrays, dim=0).values
|
73 |
+
else:
|
74 |
+
raise ValueError
|
75 |
+
|
76 |
+
near_err = torch.sqrt((0 - torch.min(pred)) ** 2)
|
77 |
+
far_err = torch.sqrt((1 - torch.max(pred)) ** 2)
|
78 |
+
|
79 |
+
err = sqrt_dist + (near_err + far_err) * regularizer_strength
|
80 |
+
err = err.detach().cpu().numpy().astype(np_dtype)
|
81 |
+
return err
|
82 |
+
|
83 |
+
res = minimize(
|
84 |
+
closure, x, method="BFGS", tol=tol, options={"maxiter": max_iter, "disp": False}
|
85 |
+
)
|
86 |
+
x = res.x
|
87 |
+
l = len(x)
|
88 |
+
s = x[: int(l / 2)]
|
89 |
+
t = x[int(l / 2) :]
|
90 |
+
|
91 |
+
# Prediction
|
92 |
+
s = torch.from_numpy(s).to(dtype=dtype).to(device)
|
93 |
+
t = torch.from_numpy(t).to(dtype=dtype).to(device)
|
94 |
+
transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1) #[10,H,W]
|
95 |
+
|
96 |
+
|
97 |
+
if "mean" == reduction:
|
98 |
+
aligned_images = torch.mean(transformed_arrays, dim=0)
|
99 |
+
std = torch.std(transformed_arrays, dim=0)
|
100 |
+
uncertainty = std
|
101 |
+
|
102 |
+
elif "median" == reduction:
|
103 |
+
aligned_images = torch.median(transformed_arrays, dim=0).values
|
104 |
+
# MAD (median absolute deviation) as uncertainty indicator
|
105 |
+
abs_dev = torch.abs(transformed_arrays - aligned_images)
|
106 |
+
mad = torch.median(abs_dev, dim=0).values
|
107 |
+
uncertainty = mad
|
108 |
+
|
109 |
+
# Scale and shift to [0, 1]
|
110 |
+
_min = torch.min(aligned_images)
|
111 |
+
_max = torch.max(aligned_images)
|
112 |
+
aligned_images = (aligned_images - _min) / (_max - _min)
|
113 |
+
uncertainty /= _max - _min
|
114 |
+
|
115 |
+
return aligned_images, uncertainty
|
utils/image_util.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# A reimplemented version in public environments by Xiao Fu and Mu Hu
|
2 |
+
|
3 |
+
import matplotlib
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
|
12 |
+
"""
|
13 |
+
Resize image to limit maximum edge length while keeping aspect ratio.
|
14 |
+
Args:
|
15 |
+
img (`Image.Image`):
|
16 |
+
Image to be resized.
|
17 |
+
max_edge_resolution (`int`):
|
18 |
+
Maximum edge length (pixel).
|
19 |
+
Returns:
|
20 |
+
`Image.Image`: Resized image.
|
21 |
+
"""
|
22 |
+
|
23 |
+
original_width, original_height = img.size
|
24 |
+
|
25 |
+
downscale_factor = min(
|
26 |
+
max_edge_resolution / original_width, max_edge_resolution / original_height
|
27 |
+
)
|
28 |
+
|
29 |
+
new_width = int(original_width * downscale_factor)
|
30 |
+
new_height = int(original_height * downscale_factor)
|
31 |
+
|
32 |
+
resized_img = img.resize((new_width, new_height))
|
33 |
+
return resized_img
|
34 |
+
|
35 |
+
|
36 |
+
def colorize_depth_maps(
|
37 |
+
depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
|
38 |
+
):
|
39 |
+
"""
|
40 |
+
Colorize depth maps.
|
41 |
+
"""
|
42 |
+
assert len(depth_map.shape) >= 2, "Invalid dimension"
|
43 |
+
|
44 |
+
if isinstance(depth_map, torch.Tensor):
|
45 |
+
depth = depth_map.detach().clone().squeeze().numpy()
|
46 |
+
elif isinstance(depth_map, np.ndarray):
|
47 |
+
depth = depth_map.copy().squeeze()
|
48 |
+
# reshape to [ (B,) H, W ]
|
49 |
+
if depth.ndim < 3:
|
50 |
+
depth = depth[np.newaxis, :, :]
|
51 |
+
|
52 |
+
# colorize
|
53 |
+
cm = matplotlib.colormaps[cmap]
|
54 |
+
depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
|
55 |
+
img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1
|
56 |
+
img_colored_np = np.rollaxis(img_colored_np, 3, 1)
|
57 |
+
|
58 |
+
if valid_mask is not None:
|
59 |
+
if isinstance(depth_map, torch.Tensor):
|
60 |
+
valid_mask = valid_mask.detach().numpy()
|
61 |
+
valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W]
|
62 |
+
if valid_mask.ndim < 3:
|
63 |
+
valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
|
64 |
+
else:
|
65 |
+
valid_mask = valid_mask[:, np.newaxis, :, :]
|
66 |
+
valid_mask = np.repeat(valid_mask, 3, axis=1)
|
67 |
+
img_colored_np[~valid_mask] = 0
|
68 |
+
|
69 |
+
if isinstance(depth_map, torch.Tensor):
|
70 |
+
img_colored = torch.from_numpy(img_colored_np).float()
|
71 |
+
elif isinstance(depth_map, np.ndarray):
|
72 |
+
img_colored = img_colored_np
|
73 |
+
|
74 |
+
return img_colored
|
75 |
+
|
76 |
+
|
77 |
+
def chw2hwc(chw):
|
78 |
+
assert 3 == len(chw.shape)
|
79 |
+
if isinstance(chw, torch.Tensor):
|
80 |
+
hwc = torch.permute(chw, (1, 2, 0))
|
81 |
+
elif isinstance(chw, np.ndarray):
|
82 |
+
hwc = np.moveaxis(chw, 0, -1)
|
83 |
+
return hwc
|
utils/normal_ensemble.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# A reimplemented version in public environments by Xiao Fu and Mu Hu
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
def ensemble_normals(input_images:torch.Tensor):
|
7 |
+
normal_preds = input_images
|
8 |
+
|
9 |
+
bsz, d, h, w = normal_preds.shape
|
10 |
+
normal_preds = normal_preds / (torch.norm(normal_preds, p=2, dim=1).unsqueeze(1)+1e-5)
|
11 |
+
|
12 |
+
phi = torch.atan2(normal_preds[:,1,:,:], normal_preds[:,0,:,:]).mean(dim=0)
|
13 |
+
theta = torch.atan2(torch.norm(normal_preds[:,:2,:,:], p=2, dim=1), normal_preds[:,2,:,:]).mean(dim=0)
|
14 |
+
normal_pred = torch.zeros((d,h,w)).to(normal_preds)
|
15 |
+
normal_pred[0,:,:] = torch.sin(theta) * torch.cos(phi)
|
16 |
+
normal_pred[1,:,:] = torch.sin(theta) * torch.sin(phi)
|
17 |
+
normal_pred[2,:,:] = torch.cos(theta)
|
18 |
+
|
19 |
+
angle_error = torch.acos(torch.cosine_similarity(normal_pred[None], normal_preds, dim=1))
|
20 |
+
normal_idx = torch.argmin(angle_error.reshape(bsz,-1).sum(-1))
|
21 |
+
|
22 |
+
return normal_preds[normal_idx]
|
utils/seed_all.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# --------------------------------------------------------------------------
|
15 |
+
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
16 |
+
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
|
17 |
+
# More information about the method can be found at https://marigoldmonodepth.github.io
|
18 |
+
# --------------------------------------------------------------------------
|
19 |
+
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
import random
|
23 |
+
import torch
|
24 |
+
|
25 |
+
|
26 |
+
def seed_all(seed: int = 0):
|
27 |
+
"""
|
28 |
+
Set random seeds of all components.
|
29 |
+
"""
|
30 |
+
random.seed(seed)
|
31 |
+
np.random.seed(seed)
|
32 |
+
torch.manual_seed(seed)
|
33 |
+
torch.cuda.manual_seed_all(seed)
|
utils/surface_normal.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# A reimplemented version in public environments by Xiao Fu and Mu Hu
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
|
8 |
+
def init_image_coor(height, width):
|
9 |
+
x_row = np.arange(0, width)
|
10 |
+
x = np.tile(x_row, (height, 1))
|
11 |
+
x = x[np.newaxis, :, :]
|
12 |
+
x = x.astype(np.float32)
|
13 |
+
x = torch.from_numpy(x.copy()).cuda()
|
14 |
+
u_u0 = x - width/2.0
|
15 |
+
|
16 |
+
y_col = np.arange(0, height) # y_col = np.arange(0, height)
|
17 |
+
y = np.tile(y_col, (width, 1)).T
|
18 |
+
y = y[np.newaxis, :, :]
|
19 |
+
y = y.astype(np.float32)
|
20 |
+
y = torch.from_numpy(y.copy()).cuda()
|
21 |
+
v_v0 = y - height/2.0
|
22 |
+
return u_u0, v_v0
|
23 |
+
|
24 |
+
|
25 |
+
def depth_to_xyz(depth, focal_length):
|
26 |
+
b, c, h, w = depth.shape
|
27 |
+
u_u0, v_v0 = init_image_coor(h, w)
|
28 |
+
x = u_u0 * depth / focal_length
|
29 |
+
y = v_v0 * depth / focal_length
|
30 |
+
z = depth
|
31 |
+
pw = torch.cat([x, y, z], 1).permute(0, 2, 3, 1) # [b, h, w, c]
|
32 |
+
return pw
|
33 |
+
|
34 |
+
|
35 |
+
def get_surface_normal(xyz, patch_size=3):
|
36 |
+
# xyz: [1, h, w, 3]
|
37 |
+
x, y, z = torch.unbind(xyz, dim=3)
|
38 |
+
x = torch.unsqueeze(x, 0)
|
39 |
+
y = torch.unsqueeze(y, 0)
|
40 |
+
z = torch.unsqueeze(z, 0)
|
41 |
+
|
42 |
+
xx = x * x
|
43 |
+
yy = y * y
|
44 |
+
zz = z * z
|
45 |
+
xy = x * y
|
46 |
+
xz = x * z
|
47 |
+
yz = y * z
|
48 |
+
patch_weight = torch.ones((1, 1, patch_size, patch_size), requires_grad=False).cuda()
|
49 |
+
xx_patch = nn.functional.conv2d(xx, weight=patch_weight, padding=int(patch_size / 2))
|
50 |
+
yy_patch = nn.functional.conv2d(yy, weight=patch_weight, padding=int(patch_size / 2))
|
51 |
+
zz_patch = nn.functional.conv2d(zz, weight=patch_weight, padding=int(patch_size / 2))
|
52 |
+
xy_patch = nn.functional.conv2d(xy, weight=patch_weight, padding=int(patch_size / 2))
|
53 |
+
xz_patch = nn.functional.conv2d(xz, weight=patch_weight, padding=int(patch_size / 2))
|
54 |
+
yz_patch = nn.functional.conv2d(yz, weight=patch_weight, padding=int(patch_size / 2))
|
55 |
+
ATA = torch.stack([xx_patch, xy_patch, xz_patch, xy_patch, yy_patch, yz_patch, xz_patch, yz_patch, zz_patch],
|
56 |
+
dim=4)
|
57 |
+
ATA = torch.squeeze(ATA)
|
58 |
+
ATA = torch.reshape(ATA, (ATA.size(0), ATA.size(1), 3, 3))
|
59 |
+
eps_identity = 1e-6 * torch.eye(3, device=ATA.device, dtype=ATA.dtype)[None, None, :, :].repeat([ATA.size(0), ATA.size(1), 1, 1])
|
60 |
+
ATA = ATA + eps_identity
|
61 |
+
x_patch = nn.functional.conv2d(x, weight=patch_weight, padding=int(patch_size / 2))
|
62 |
+
y_patch = nn.functional.conv2d(y, weight=patch_weight, padding=int(patch_size / 2))
|
63 |
+
z_patch = nn.functional.conv2d(z, weight=patch_weight, padding=int(patch_size / 2))
|
64 |
+
AT1 = torch.stack([x_patch, y_patch, z_patch], dim=4)
|
65 |
+
AT1 = torch.squeeze(AT1)
|
66 |
+
AT1 = torch.unsqueeze(AT1, 3)
|
67 |
+
|
68 |
+
patch_num = 4
|
69 |
+
patch_x = int(AT1.size(1) / patch_num)
|
70 |
+
patch_y = int(AT1.size(0) / patch_num)
|
71 |
+
n_img = torch.randn(AT1.shape).cuda()
|
72 |
+
overlap = patch_size // 2 + 1
|
73 |
+
for x in range(int(patch_num)):
|
74 |
+
for y in range(int(patch_num)):
|
75 |
+
left_flg = 0 if x == 0 else 1
|
76 |
+
right_flg = 0 if x == patch_num -1 else 1
|
77 |
+
top_flg = 0 if y == 0 else 1
|
78 |
+
btm_flg = 0 if y == patch_num - 1 else 1
|
79 |
+
at1 = AT1[y * patch_y - top_flg * overlap:(y + 1) * patch_y + btm_flg * overlap,
|
80 |
+
x * patch_x - left_flg * overlap:(x + 1) * patch_x + right_flg * overlap]
|
81 |
+
ata = ATA[y * patch_y - top_flg * overlap:(y + 1) * patch_y + btm_flg * overlap,
|
82 |
+
x * patch_x - left_flg * overlap:(x + 1) * patch_x + right_flg * overlap]
|
83 |
+
n_img_tmp, _ = torch.solve(at1, ata)
|
84 |
+
|
85 |
+
n_img_tmp_select = n_img_tmp[top_flg * overlap:patch_y + top_flg * overlap, left_flg * overlap:patch_x + left_flg * overlap, :, :]
|
86 |
+
n_img[y * patch_y:y * patch_y + patch_y, x * patch_x:x * patch_x + patch_x, :, :] = n_img_tmp_select
|
87 |
+
|
88 |
+
n_img_L2 = torch.sqrt(torch.sum(n_img ** 2, dim=2, keepdim=True))
|
89 |
+
n_img_norm = n_img / n_img_L2
|
90 |
+
|
91 |
+
# re-orient normals consistently
|
92 |
+
orient_mask = torch.sum(torch.squeeze(n_img_norm) * torch.squeeze(xyz), dim=2) > 0
|
93 |
+
n_img_norm[orient_mask] *= -1
|
94 |
+
return n_img_norm
|
95 |
+
|
96 |
+
def get_surface_normalv2(xyz, patch_size=3):
|
97 |
+
"""
|
98 |
+
xyz: xyz coordinates
|
99 |
+
patch: [p1, p2, p3,
|
100 |
+
p4, p5, p6,
|
101 |
+
p7, p8, p9]
|
102 |
+
surface_normal = [(p9-p1) x (p3-p7)] + [(p6-p4) - (p8-p2)]
|
103 |
+
return: normal [h, w, 3, b]
|
104 |
+
"""
|
105 |
+
b, h, w, c = xyz.shape
|
106 |
+
half_patch = patch_size // 2
|
107 |
+
xyz_pad = torch.zeros((b, h + patch_size - 1, w + patch_size - 1, c), dtype=xyz.dtype, device=xyz.device)
|
108 |
+
xyz_pad[:, half_patch:-half_patch, half_patch:-half_patch, :] = xyz
|
109 |
+
|
110 |
+
# xyz_left_top = xyz_pad[:, :h, :w, :] # p1
|
111 |
+
# xyz_right_bottom = xyz_pad[:, -h:, -w:, :]# p9
|
112 |
+
# xyz_left_bottom = xyz_pad[:, -h:, :w, :] # p7
|
113 |
+
# xyz_right_top = xyz_pad[:, :h, -w:, :] # p3
|
114 |
+
# xyz_cross1 = xyz_left_top - xyz_right_bottom # p1p9
|
115 |
+
# xyz_cross2 = xyz_left_bottom - xyz_right_top # p7p3
|
116 |
+
|
117 |
+
xyz_left = xyz_pad[:, half_patch:half_patch + h, :w, :] # p4
|
118 |
+
xyz_right = xyz_pad[:, half_patch:half_patch + h, -w:, :] # p6
|
119 |
+
xyz_top = xyz_pad[:, :h, half_patch:half_patch + w, :] # p2
|
120 |
+
xyz_bottom = xyz_pad[:, -h:, half_patch:half_patch + w, :] # p8
|
121 |
+
xyz_horizon = xyz_left - xyz_right # p4p6
|
122 |
+
xyz_vertical = xyz_top - xyz_bottom # p2p8
|
123 |
+
|
124 |
+
xyz_left_in = xyz_pad[:, half_patch:half_patch + h, 1:w+1, :] # p4
|
125 |
+
xyz_right_in = xyz_pad[:, half_patch:half_patch + h, patch_size-1:patch_size-1+w, :] # p6
|
126 |
+
xyz_top_in = xyz_pad[:, 1:h+1, half_patch:half_patch + w, :] # p2
|
127 |
+
xyz_bottom_in = xyz_pad[:, patch_size-1:patch_size-1+h, half_patch:half_patch + w, :] # p8
|
128 |
+
xyz_horizon_in = xyz_left_in - xyz_right_in # p4p6
|
129 |
+
xyz_vertical_in = xyz_top_in - xyz_bottom_in # p2p8
|
130 |
+
|
131 |
+
n_img_1 = torch.cross(xyz_horizon_in, xyz_vertical_in, dim=3)
|
132 |
+
n_img_2 = torch.cross(xyz_horizon, xyz_vertical, dim=3)
|
133 |
+
|
134 |
+
# re-orient normals consistently
|
135 |
+
orient_mask = torch.sum(n_img_1 * xyz, dim=3) > 0
|
136 |
+
n_img_1[orient_mask] *= -1
|
137 |
+
orient_mask = torch.sum(n_img_2 * xyz, dim=3) > 0
|
138 |
+
n_img_2[orient_mask] *= -1
|
139 |
+
|
140 |
+
n_img1_L2 = torch.sqrt(torch.sum(n_img_1 ** 2, dim=3, keepdim=True))
|
141 |
+
n_img1_norm = n_img_1 / (n_img1_L2 + 1e-8)
|
142 |
+
|
143 |
+
n_img2_L2 = torch.sqrt(torch.sum(n_img_2 ** 2, dim=3, keepdim=True))
|
144 |
+
n_img2_norm = n_img_2 / (n_img2_L2 + 1e-8)
|
145 |
+
|
146 |
+
# average 2 norms
|
147 |
+
n_img_aver = n_img1_norm + n_img2_norm
|
148 |
+
n_img_aver_L2 = torch.sqrt(torch.sum(n_img_aver ** 2, dim=3, keepdim=True))
|
149 |
+
n_img_aver_norm = n_img_aver / (n_img_aver_L2 + 1e-8)
|
150 |
+
# re-orient normals consistently
|
151 |
+
orient_mask = torch.sum(n_img_aver_norm * xyz, dim=3) > 0
|
152 |
+
n_img_aver_norm[orient_mask] *= -1
|
153 |
+
n_img_aver_norm_out = n_img_aver_norm.permute((1, 2, 3, 0)) # [h, w, c, b]
|
154 |
+
|
155 |
+
# a = torch.sum(n_img1_norm_out*n_img2_norm_out, dim=2).cpu().numpy().squeeze()
|
156 |
+
# plt.imshow(np.abs(a), cmap='rainbow')
|
157 |
+
# plt.show()
|
158 |
+
return n_img_aver_norm_out#n_img1_norm.permute((1, 2, 3, 0))
|
159 |
+
|
160 |
+
def surface_normal_from_depth(depth, focal_length, valid_mask=None):
|
161 |
+
# para depth: depth map, [b, c, h, w]
|
162 |
+
b, c, h, w = depth.shape
|
163 |
+
focal_length = focal_length[:, None, None, None]
|
164 |
+
depth_filter = nn.functional.avg_pool2d(depth, kernel_size=3, stride=1, padding=1)
|
165 |
+
depth_filter = nn.functional.avg_pool2d(depth_filter, kernel_size=3, stride=1, padding=1)
|
166 |
+
xyz = depth_to_xyz(depth_filter, focal_length)
|
167 |
+
sn_batch = []
|
168 |
+
for i in range(b):
|
169 |
+
xyz_i = xyz[i, :][None, :, :, :]
|
170 |
+
normal = get_surface_normalv2(xyz_i)
|
171 |
+
sn_batch.append(normal)
|
172 |
+
sn_batch = torch.cat(sn_batch, dim=3).permute((3, 2, 0, 1)) # [b, c, h, w]
|
173 |
+
mask_invalid = (~valid_mask).repeat(1, 3, 1, 1)
|
174 |
+
sn_batch[mask_invalid] = 0.0
|
175 |
+
|
176 |
+
return sn_batch
|
177 |
+
|
178 |
+
|
179 |
+
def vis_normal(normal):
|
180 |
+
"""
|
181 |
+
Visualize surface normal. Transfer surface normal value from [-1, 1] to [0, 255]
|
182 |
+
@para normal: surface normal, [h, w, 3], numpy.array
|
183 |
+
"""
|
184 |
+
n_img_L2 = np.sqrt(np.sum(normal ** 2, axis=2, keepdims=True))
|
185 |
+
n_img_norm = normal / (n_img_L2 + 1e-8)
|
186 |
+
normal_vis = n_img_norm * 127
|
187 |
+
normal_vis += 128
|
188 |
+
normal_vis = normal_vis.astype(np.uint8)
|
189 |
+
return normal_vis
|
190 |
+
|
191 |
+
def vis_normal2(normals):
|
192 |
+
'''
|
193 |
+
Montage of normal maps. Vectors are unit length and backfaces thresholded.
|
194 |
+
'''
|
195 |
+
x = normals[:, :, 0] # horizontal; pos right
|
196 |
+
y = normals[:, :, 1] # depth; pos far
|
197 |
+
z = normals[:, :, 2] # vertical; pos up
|
198 |
+
backfacing = (z > 0)
|
199 |
+
norm = np.sqrt(np.sum(normals**2, axis=2))
|
200 |
+
zero = (norm < 1e-5)
|
201 |
+
x += 1.0; x *= 0.5
|
202 |
+
y += 1.0; y *= 0.5
|
203 |
+
z = np.abs(z)
|
204 |
+
x[zero] = 0.0
|
205 |
+
y[zero] = 0.0
|
206 |
+
z[zero] = 0.0
|
207 |
+
normals[:, :, 0] = x # horizontal; pos right
|
208 |
+
normals[:, :, 1] = y # depth; pos far
|
209 |
+
normals[:, :, 2] = z # vertical; pos up
|
210 |
+
return normals
|
211 |
+
|
212 |
+
if __name__ == '__main__':
|
213 |
+
import cv2, os
|