Spaces:
Runtime error
Runtime error
import numpy as np | |
from functools import reduce | |
def sparse_bilateral_filtering(
    depth, image, config, HR=False, mask=None, gsHR=True, edge_id=None, num_iter=None, num_gs_iter=None, spdb=False
):
    """Repeatedly detect depth discontinuities and filter the depth map around them.

    Every iteration stores the depth map it starts from together with a copy
    of ``image`` whose discontinuity pixels are painted 0, builds a
    discontinuity map with ``vis_depth_discontinuity`` and passes it to
    ``bilateral_filter`` to obtain the depth map for the next iteration.

    Args:
        depth: depth array; pixels where it equals 0 are always marked as
            discontinuities.
        image: array indexed like ``depth`` on its leading axes; only copies
            of it are modified.
        config: mapping. ``filter_size`` is read here (an int, or a list with
            one window size per iteration); the mapping is also forwarded to
            ``vis_depth_discontinuity`` and ``bilateral_filter``.
        HR: forwarded unchanged to ``bilateral_filter``.
        mask: optional array; where it equals 0 the discontinuity map is
            cleared. Also forwarded to both helpers.
        gsHR, edge_id, num_gs_iter, spdb: accepted for interface
            compatibility; not used by this function.
        num_iter: number of iterations. When ``None`` it defaults to the
            length of ``config["filter_size"]`` if that is a list.

    Returns:
        ``(save_images, save_depths)``: two lists with one entry per
        iteration. ``save_depths[k]`` is the depth map *entering* iteration
        ``k``, so the result of the last filtering pass is not part of it.

    Raises:
        TypeError: if ``num_iter`` is ``None`` and ``config["filter_size"]``
            is not a list, since the iteration count cannot be inferred.
    """
    if num_iter is None:
        sizes = config["filter_size"]
        if not isinstance(sizes, list):
            raise TypeError("num_iter must be given when config['filter_size'] is not a list")
        num_iter = len(sizes)

    save_images = []
    save_depths = []
    vis_depth = depth.copy()
    for i in range(num_iter):
        filter_size = config["filter_size"]
        window_size = filter_size[i] if isinstance(filter_size, list) else filter_size

        vis_image = image.copy()
        save_images.append(vis_image)
        save_depths.append(vis_depth)

        u_over, b_over, l_over, r_over = vis_depth_discontinuity(vis_depth, config, mask=mask)
        discontinuity_map = (u_over + b_over + l_over + r_over).clip(0.0, 1.0)
        # The stored image is edited in place: a pixel flagged in any of the
        # four directions is blacked out.
        vis_image[discontinuity_map > 0] = 0

        # Zero depth is forced to count as a discontinuity, but masked-out
        # pixels never do (the mask is applied last and therefore wins).
        discontinuity_map[depth == 0] = 1
        if mask is not None:
            discontinuity_map[mask == 0] = 0

        vis_depth = bilateral_filter(
            vis_depth, config, discontinuity_map=discontinuity_map, HR=HR, mask=mask, window_size=window_size
        )
    return save_images, save_depths
def vis_depth_discontinuity(depth, config, vis_diff=False, label=False, mask=None):
    """Flag interior pixels that differ from their up/bottom/left/right neighbour.

    config:
    - depth_threshold (only read when ``label`` is False)

    With ``label == False`` the reciprocal of ``depth`` is differenced against
    each of the four neighbours and a pixel is flagged when the absolute
    difference exceeds ``config['depth_threshold']``. Otherwise ``depth`` is
    used as-is, neighbours are multiplied instead of subtracted, and a pixel
    is flagged when the product is non-zero.

    When ``mask`` is given, each per-direction value is multiplied by the
    product of the mask at the two pixels involved.

    Returns a list ``[u_over, b_over, l_over, r_over]`` of float32 0/1 arrays
    with the shape of ``depth`` (the outer 1-pixel frame is always 0). With
    ``vis_diff`` true, a second list holding the underlying per-direction
    values, padded the same way, is returned alongside it.
    """

    def four_neighbour_terms(grid, combine):
        # Order is up, bottom, left, right; each result is cropped so that all
        # four line up on the interior (frame-less) pixel grid.
        above, below = grid[:-1, :], grid[1:, :]
        west, east = grid[:, :-1], grid[:, 1:]
        return [
            combine(below, above)[:-1, 1:-1],
            combine(above, below)[1:, 1:-1],
            combine(east, west)[1:-1, :-1],
            combine(west, east)[1:-1, 1:],
        ]

    def product(first, second):
        return first * second

    def difference(first, second):
        return first - second

    # Deliberately `== False` rather than `not label`, so the accepted values
    # of `label` stay exactly as before.
    use_disparity = label == False
    if use_disparity:
        diffs = four_neighbour_terms(1. / depth, difference)
    else:
        diffs = four_neighbour_terms(depth, product)

    if mask is not None:
        pair_masks = four_neighbour_terms(mask, product)
        diffs = [diff * pair_mask for diff, pair_mask in zip(diffs, pair_masks)]

    threshold = config['depth_threshold'] if use_disparity else 0
    overs = [(np.abs(diff) > threshold).astype(np.float32) for diff in diffs]

    # Restore the input shape by surrounding everything with a zero frame.
    overs = [np.pad(over, 1, mode='constant') for over in overs]
    diffs = [np.pad(diff, 1, mode='constant') for diff in diffs]

    if vis_diff:
        return overs, diffs
    return overs
def _weighted_median(values, weights):
    """Return the entry of ``values`` where the normalised cumulative weight first exceeds 0.5.

    ``values`` and ``weights`` must have the same shape and ``weights`` must
    not sum to zero; entries are visited in ascending order of ``values``.
    """
    flat_values = values.ravel()
    order = flat_values.argsort()
    normalised = weights / weights.sum()
    cumulative = np.cumsum(normalised.ravel()[order])
    return flat_values[order][np.digitize(0.5, cumulative)]


def bilateral_filter(depth, config, discontinuity_map=None, HR=False, mask=None, window_size=False):
    """Weighted-median filter a 2-D depth map over square windows.

    config:
    - sigma_s, sigma_r (always read; only used when no discontinuity map is given)
    - filter_size (used when ``window_size`` is not supplied)

    Two modes, selected by ``discontinuity_map``:

    * ``discontinuity_map`` given: only pixels whose window contains at least
      one discontinuity are rewritten. Each neighbour gets weight
      ``1 - discontinuity`` (times ``mask`` when supplied) and the pixel
      becomes the weighted median of its window. Pixels where ``mask`` is 0
      are skipped.
    * ``discontinuity_map`` is None: every pixel becomes the weighted median
      of its window, weighted by a Gaussian on spatial offset (``sigma_s``)
      times a Gaussian on depth difference to the centre (``sigma_r``).
      ``mask`` is not consulted in this mode.

    If all weights in a window are zero the pixel receives the window's
    centre value.

    Args:
        depth: 2-D array. Its outermost 1-pixel frame is replaced by the
            nearest interior value before filtering.
        config: mapping holding the keys listed above.
        discontinuity_map: optional 2-D array shaped like ``depth``; non-zero
            marks a discontinuity. The caller's array is not modified.
        HR: accepted for interface compatibility; not used.
        mask: optional 2-D array shaped like ``depth``.
        window_size: side length of the window; a falsy value selects
            ``config['filter_size']``. It is expected to be odd so that the
            patch grid lines up with the output.

    Returns:
        A new array shaped like ``depth`` with the filtered values.
    """
    sigma_s = config['sigma_s']
    sigma_r = config['sigma_r']
    if not window_size:
        window_size = config['filter_size']
    midpt = window_size // 2
    window = [window_size, window_size]
    unit_step = [1, 1]

    # Rebuild the 1-pixel frame from the interior, then extend by half a
    # window so that every output pixel owns a full patch.
    depth = np.pad(depth[1:-1, 1:-1], ((1, 1), (1, 1)), 'edge')
    pad_depth = np.pad(depth, (midpt, midpt), 'edge')
    output = depth.copy()
    pad_depth_patches = rolling_window(pad_depth, window, unit_step)
    pH, pW = pad_depth_patches.shape[:2]

    if discontinuity_map is None:
        # The spatial kernel is needed on exactly this path, so it is built here.
        ax = np.arange(-midpt, midpt + 1.)
        xx, yy = np.meshgrid(ax, ax)
        spatial_term = np.exp(-(xx**2 + yy**2) / (2. * sigma_s**2))
        for pi in range(pH):
            for pj in range(pW):
                depth_patch = pad_depth_patches[pi, pj]
                patch_midpt = depth_patch[midpt, midpt]
                range_term = np.exp(-(depth_patch - patch_midpt)**2 / (2. * sigma_r**2))
                coef = spatial_term * range_term
                if coef.sum() == 0:
                    output[pi, pj] = patch_midpt
                else:
                    output[pi, pj] = _weighted_median(depth_patch, coef)
        return output

    # Same frame/extension treatment as the depth map, done on a fresh array.
    discontinuity_map = np.pad(discontinuity_map[1:-1, 1:-1], ((1, 1), (1, 1)), 'edge')
    pad_discontinuity_map = np.pad(discontinuity_map, (midpt, midpt), 'edge')
    pad_discontinuity_hole = 1 - pad_discontinuity_map
    pad_discontinuity_patches = rolling_window(pad_discontinuity_map, window, unit_step)
    pad_discontinuity_hole_patches = rolling_window(pad_discontinuity_hole, window, unit_step)
    if mask is not None:
        # Zero padding: positions outside the image never contribute.
        pad_mask = np.pad(mask, (midpt, midpt), 'constant')
        pad_mask_patches = rolling_window(pad_mask, window, unit_step)

    for pi in range(pH):
        for pj in range(pW):
            if mask is not None and mask[pi, pj] == 0:
                continue
            # Windows without any discontinuity keep their original depth.
            if not pad_discontinuity_patches[pi, pj].any():
                continue
            depth_patch = pad_depth_patches[pi, pj]
            patch_midpt = depth_patch[midpt, midpt]
            coef = pad_discontinuity_hole_patches[pi, pj].astype(np.float32)
            if mask is not None:
                coef = coef * pad_mask_patches[pi, pj]
            if coef.max() == 0:
                output[pi, pj] = patch_midpt
            else:
                output[pi, pj] = _weighted_median(depth_patch, coef)
    return output
def rolling_window(a, window, strides):
    """Return a strided view of ``a`` containing every sliding window.

    Args:
        a: N-dimensional NumPy array (any memory layout).
        window: window length along each axis of ``a``.
        strides: step, in elements, between consecutive windows along each
            axis of ``a``.

    Returns:
        A view of shape ``[(a.shape[i] - window[i]) // strides[i] + 1, ...] +
        list(window)``: the leading axes index the window position, the
        trailing axes index inside the window. No data is copied, so windows
        overlap in memory and writing to the result writes into ``a``.

    Raises:
        AssertionError: if ``a.ndim``, ``len(window)`` and ``len(strides)``
            differ.
    """
    assert len(a.shape) == len(window) == len(strides), "'a', 'window', 'strides' dimension mismatch"
    out_shape = [(dim - w) // s + 1 for dim, w, s in zip(a.shape, window, strides)] + list(window)
    # Step between windows is expressed through the array's own byte strides,
    # so transposed or sliced (non C-contiguous) inputs are walked correctly.
    out_strides = [axis_stride * s for axis_stride, s in zip(a.strides, strides)] + list(a.strides)
    return np.lib.stride_tricks.as_strided(a, shape=out_shape, strides=out_strides)