import numpy as np
from functools import reduce


def sparse_bilateral_filtering(
    depth, image, config, HR=False, mask=None, gsHR=True, edge_id=None,
    num_iter=None, num_gs_iter=None, spdb=False
):
    """Repeatedly detect depth discontinuities and filter the depth around them.

    Each iteration records a copy of ``image`` (with discontinuity pixels
    painted black) together with the depth map that was fed into that
    iteration, then replaces the working depth with the output of
    ``bilateral_filter``.

    config:
    - filter_size: an int, or a list indexed by the iteration number
    - depth_threshold: read by ``vis_depth_discontinuity``
    - sigma_s, sigma_r: read by ``bilateral_filter``

    Args:
        depth: 2-D depth array; it is copied, never modified.
        image: array copied once per iteration for visualisation
            (assumes H x W x 3 so a 3-vector can be assigned per pixel --
            TODO confirm against callers).
        config: mapping with the keys listed above.
        HR: forwarded unchanged to ``bilateral_filter``.
        mask: optional array; pixels where ``mask == 0`` are removed from the
            discontinuity map and skipped by the filter.
        gsHR, edge_id, num_gs_iter, spdb: accepted for interface
            compatibility; not used by this function.
        num_iter: number of iterations. Required in practice: the default
            ``None`` makes ``range`` raise ``TypeError``.

    Returns:
        ``(save_images, save_depths)``, two lists of length ``num_iter``.

    NOTE(review): the depth produced by the final ``bilateral_filter`` call is
    never appended, so ``save_depths[-1]`` is the *input* of the last
    iteration. Kept as-is because callers may rely on the list length --
    confirm whether this is intended.
    """
    save_images = []
    save_depths = []
    vis_depth = depth.copy()

    for i in range(num_iter):
        filter_size = config["filter_size"]
        window_size = filter_size[i] if isinstance(filter_size, list) else filter_size

        # The same array object is stored and then painted below, so the
        # saved image shows the discontinuities of this iteration.
        vis_image = image.copy()
        save_images.append(vis_image)
        save_depths.append(vis_depth)

        u_over, b_over, l_over, r_over = vis_depth_discontinuity(vis_depth, config, mask=mask)
        for over in (u_over, b_over, l_over, r_over):
            vis_image[over > 0] = np.array([0, 0, 0])

        discontinuity_map = (u_over + b_over + l_over + r_over).clip(0.0, 1.0)
        # Zero depth in the *original* input is always treated as a discontinuity.
        discontinuity_map[depth == 0] = 1
        if mask is not None:
            discontinuity_map[mask == 0] = 0

        vis_depth = bilateral_filter(
            vis_depth, config, discontinuity_map=discontinuity_map, HR=HR,
            mask=mask, window_size=window_size
        )

    return save_images, save_depths


def _directional_pairs(arr, combine):
    """Combine every pixel with its up/bottom/left/right neighbour.

    Each result is cropped so that all four have shape ``(H - 2, W - 2)``.
    """
    up = combine(arr[1:, :], arr[:-1, :])[:-1, 1:-1]
    bottom = combine(arr[:-1, :], arr[1:, :])[1:, 1:-1]
    left = combine(arr[:, 1:], arr[:, :-1])[1:-1, :-1]
    right = combine(arr[:, :-1], arr[:, 1:])[1:-1, 1:]
    return [up, bottom, left, right]


def vis_depth_discontinuity(depth, config, vis_diff=False, label=False, mask=None):
    """Flag pixels whose value jumps relative to a 4-connected neighbour.

    config:
    - depth_threshold: only read when ``label`` equals ``False``

    With ``label == False`` the input is inverted (``1. / depth``) and a pixel
    is flagged when the absolute difference to the neighbour exceeds
    ``config['depth_threshold']``. Otherwise the input is used as-is, the
    neighbours are *multiplied*, and a pixel is flagged when the product is
    non-zero.

    Args:
        depth: 2-D array.
        config: mapping, see above.
        vis_diff: when true, also return the padded raw per-direction values.
        label: selects the mode described above.
        mask: optional 2-D array; each per-direction value is multiplied by
            the product of the two mask entries involved.

    Returns:
        ``[u_over, b_over, l_over, r_over]`` as float32 arrays zero-padded
        back to the input shape, or that list plus
        ``[u_diff, b_diff, l_diff, r_diff]`` when ``vis_diff`` is true.
    """
    # `== False` is deliberate: 0 behaves like False, while None selects the
    # label branch, exactly as before.
    if label == False:  # noqa: E712
        diffs = _directional_pairs(1. / depth, lambda a, b: a - b)
        threshold = config['depth_threshold']
    else:
        diffs = _directional_pairs(depth, lambda a, b: a * b)
        threshold = 0

    if mask is not None:
        pair_masks = _directional_pairs(mask, lambda a, b: a * b)
        diffs = [diff * pair_mask for diff, pair_mask in zip(diffs, pair_masks)]

    overs = [(np.abs(diff) > threshold).astype(np.float32) for diff in diffs]

    # Restore the one-pixel border that the neighbour slicing removed.
    overs = [np.pad(over, 1, mode='constant') for over in overs]
    diffs = [np.pad(diff, 1, mode='constant') for diff in diffs]

    if vis_diff:
        return overs, diffs
    return overs


def _weighted_median(depth_patch, coef):
    """Return the patch value at which the normalised cumulative weight passes 0.5."""
    flat_depth = depth_patch.ravel()
    order = flat_depth.argsort()
    coef = coef / (coef.sum())
    cum_coef = np.cumsum(coef.ravel()[order])
    # Clamp defensively in case rounding keeps the total weight at or below 0.5.
    ind = min(int(np.digitize(0.5, cum_coef)), cum_coef.size - 1)
    return flat_depth[order][ind]


def bilateral_filter(depth, config, discontinuity_map=None, HR=False, mask=None, window_size=False):
    """Weighted-median filter of a depth map over square windows.

    config:
    - sigma_s, sigma_r: always read
    - filter_size: read when ``window_size`` is falsy

    Two modes:

    * ``discontinuity_map`` given: only pixels whose window contains a
      discontinuity are touched. The weights are 1 on non-discontinuity
      pixels (multiplied by the zero-padded ``mask`` if given) and 0
      elsewhere; a window with no positive weight keeps its centre value.
    * ``discontinuity_map`` is None: every pixel is replaced by the weighted
      median using Gaussian spatial and range weights. ``mask`` is not used
      in this mode. This mode used to fail because the spatial kernel was
      only built in the other mode.

    Args:
        depth: 2-D array. Its outermost ring is replaced by the neighbouring
            interior values before filtering.
        config: mapping, see above.
        discontinuity_map: optional 2-D array with the same shape as ``depth``.
        HR: accepted for interface compatibility; unused.
        mask: optional 2-D array; pixels with ``mask == 0`` are skipped in
            the discontinuity mode.
        window_size: side length of the window. The index arithmetic assumes
            an odd value.

    Returns:
        The filtered depth, same shape as ``depth``.
    """
    sigma_s = config['sigma_s']
    sigma_r = config['sigma_r']
    if not window_size:
        window_size = config['filter_size']
    midpt = window_size // 2
    window = [window_size, window_size]
    step = [1, 1]

    # Replace the outer ring by its neighbours, then pad for the window.
    depth = np.pad(depth[1:-1, 1:-1], ((1, 1), (1, 1)), 'edge')
    pad_depth = np.pad(depth, (midpt, midpt), 'edge')
    output = depth.copy()
    pad_depth_patches = rolling_window(pad_depth, window, step)
    pH, pW = pad_depth_patches.shape[:2]

    if discontinuity_map is None:
        ax = np.arange(-midpt, midpt + 1.)
        xx, yy = np.meshgrid(ax, ax)
        spatial_term = np.exp(-(xx**2 + yy**2) / (2. * sigma_s**2))
        for pi in range(pH):
            for pj in range(pW):
                depth_patch = pad_depth_patches[pi, pj]
                patch_midpt = depth_patch[midpt, midpt]
                range_term = np.exp(-(depth_patch - patch_midpt)**2 / (2. * sigma_r**2))
                coef = spatial_term * range_term
                if coef.sum() == 0:
                    output[pi, pj] = patch_midpt
                    continue
                output[pi, pj] = _weighted_median(depth_patch, coef)
        return output

    discontinuity_map = np.pad(discontinuity_map[1:-1, 1:-1], ((1, 1), (1, 1)), 'edge')
    pad_discontinuity_map = np.pad(discontinuity_map, (midpt, midpt), 'edge')
    pad_discontinuity_hole = 1 - pad_discontinuity_map
    pad_discontinuity_patches = rolling_window(pad_discontinuity_map, window, step)
    pad_discontinuity_hole_patches = rolling_window(pad_discontinuity_hole, window, step)
    if mask is not None:
        pad_mask = np.pad(mask, (midpt, midpt), 'constant')
        pad_mask_patches = rolling_window(pad_mask, window, step)

    for pi in range(pH):
        for pj in range(pW):
            if mask is not None and mask[pi, pj] == 0:
                continue
            # Windows without any discontinuity are left untouched.
            if not pad_discontinuity_patches[pi, pj].any():
                continue
            depth_patch = pad_depth_patches[pi, pj]
            coef = pad_discontinuity_hole_patches[pi, pj].astype(np.float32)
            if mask is not None:
                coef = coef * pad_mask_patches[pi, pj]
            if coef.max() == 0:
                # Nothing usable in the window: keep the centre value.
                output[pi, pj] = depth_patch[midpt, midpt]
                continue
            output[pi, pj] = _weighted_median(depth_patch, coef)

    return output


def rolling_window(a, window, strides):
    """Return a strided view of all ``window``-shaped patches of ``a``.

    Args:
        a: N-dimensional array.
        window: patch size along each axis (length N).
        strides: step between consecutive patches along each axis, counted
            in elements (length N).

    Returns:
        A view of shape ``[(a.shape[i] - window[i]) // strides[i] + 1, ...]
        + list(window)``. It shares memory with ``a`` and is built with
        ``as_strided``, so writing through it is unsafe.

    Raises:
        AssertionError: if the three arguments disagree on dimensionality.
    """
    assert len(a.shape) == len(window) == len(strides), \
        "'a', 'window', 'strides' dimension mismatch"
    shape = [(dim - w) // s + 1 for dim, w, s in zip(a.shape, window, strides)] + list(window)
    # Scale the array's own byte strides instead of rebuilding them from the
    # shape, so non-contiguous inputs (transposes, slices) are handled too.
    view_strides = [byte_step * s for byte_step, s in zip(a.strides, strides)] + list(a.strides)
    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=view_strides)