3D_Photo_Inpainting / bilateral_filtering.py
Saini
init
0b9f920
raw history blame
No virus
9.48 kB
import numpy as np
from functools import reduce
def sparse_bilateral_filtering(
    depth, image, config, HR=False, mask=None, gsHR=True, edge_id=None, num_iter=None, num_gs_iter=None, spdb=False
):
    """
    Repeatedly detect depth discontinuities and filter the depth map around them.

    Each pass (1) marks the pixels where `vis_depth_discontinuity` reports a
    jump, (2) blacks those pixels out in a fresh copy of `image`, and (3) runs
    `bilateral_filter` on the current depth with the resulting discontinuity map.

    Args:
        depth: input depth map; pixels equal to 0 are always treated as
            discontinuities. The array itself is not modified (a copy is filtered).
        image: colour image matching `depth`; copied on every pass.
            Assumed to be H x W x 3 because edge pixels are set to [0, 0, 0]
            -- TODO confirm with callers.
        config: settings mapping.
            - filter_size: window size for every pass, or a list holding one
              window size per pass (indexed by the pass number).
            It is also forwarded unchanged to `vis_depth_discontinuity` and
            `bilateral_filter`, which read their own keys from it.
        HR: forwarded to `bilateral_filter`.
        mask: optional validity mask; discontinuities where `mask == 0` are
            cleared, and the mask is forwarded to both helpers.
        gsHR, edge_id, num_gs_iter, spdb: accepted for interface compatibility;
            not used by this function.
        num_iter: number of filtering passes. Required despite the `None` default.

    Returns:
        (save_images, save_depths): two lists with one entry per pass.
        `save_images[i]` is the image with the discontinuities found in pass
        `i` painted black. `save_depths[i]` is the depth map *entering* pass
        `i`, so `save_depths[0]` is a copy of the input and the output of the
        final `bilateral_filter` call is not part of the returned lists.

    Raises:
        TypeError: if `num_iter` is left as `None`.
    """
    if num_iter is None:
        # range(None) would fail anyway; fail early with a message that names the argument.
        raise TypeError("num_iter must be an integer number of filtering passes, got None")

    save_images = []
    save_depths = []
    vis_depth = depth.copy()
    for i in range(num_iter):
        filter_size = config["filter_size"]
        window_size = filter_size[i] if isinstance(filter_size, list) else filter_size

        vis_image = image.copy()
        # The same array object is painted below, so the saved image shows the edges.
        save_images.append(vis_image)
        save_depths.append(vis_depth)

        u_over, b_over, l_over, r_over = vis_depth_discontinuity(vis_depth, config, mask=mask)
        for over in (u_over, b_over, l_over, r_over):
            vis_image[over > 0] = np.array([0, 0, 0])

        discontinuity_map = (u_over + b_over + l_over + r_over).clip(0.0, 1.0)
        # Zero-valued pixels of the *input* depth are always flagged ...
        discontinuity_map[depth == 0] = 1
        if mask is not None:
            # ... unless the mask excludes them.
            discontinuity_map[mask == 0] = 0

        vis_depth = bilateral_filter(
            vis_depth, config, discontinuity_map=discontinuity_map, HR=HR, mask=mask, window_size=window_size
        )
    return save_images, save_depths
def vis_depth_discontinuity(depth, config, vis_diff=False, label=False, mask=None):
    """
    Compare every interior pixel with its upper, bottom, left and right neighbour.

    With `label == False` the input is inverted (`1. / depth`) and the signed
    difference "pixel minus neighbour" is thresholded by its absolute value.
    Otherwise the input is used as-is, the pixel and its neighbour are
    multiplied, and any non-zero product counts as a hit.

    config:
    - depth_threshold: absolute difference above which a pixel is flagged
      (only read when `label == False`)

    mask, when given, is multiplied pairwise in the same way and zeroes out
    every comparison in which either pixel has a zero mask value.

    Returns `[u_over, b_over, l_over, r_over]` as float32 0/1 maps, zero-padded
    by one pixel on every side so they have the input's shape. With
    `vis_diff=True` a second list with the equally padded raw comparison
    values `[u_diff, b_diff, l_diff, r_diff]` is returned as well.
    """
    inner = slice(1, -1)
    head = slice(None, -1)
    tail = slice(1, None)

    def paired(arr):
        # (pixel operand, neighbour operand, interior crop) in the order up, bottom, left, right.
        return (
            (arr[tail, :], arr[head, :], (head, inner)),
            (arr[head, :], arr[tail, :], (tail, inner)),
            (arr[:, tail], arr[:, head], (inner, head)),
            (arr[:, head], arr[:, tail], (inner, tail)),
        )

    # Kept as an equality test on purpose: `not label` would treat None differently.
    thresholded = label == False
    if thresholded:
        disparity = 1. / depth
        diffs = [(pixel - neighbour)[crop] for pixel, neighbour, crop in paired(disparity)]
    else:
        diffs = [(pixel * neighbour)[crop] for pixel, neighbour, crop in paired(depth)]

    if mask is not None:
        diffs = [
            diff * (pixel * neighbour)[crop]
            for diff, (pixel, neighbour, crop) in zip(diffs, paired(mask))
        ]

    limit = config['depth_threshold'] if thresholded else 0
    overs = [
        np.pad((np.abs(diff) > limit).astype(np.float32), 1, mode='constant')
        for diff in diffs
    ]
    diffs = [np.pad(diff, 1, mode='constant') for diff in diffs]

    if vis_diff:
        return overs, diffs
    return overs
def _weighted_median(depth_patch, coef):
    """Return the value of `depth_patch` at which the normalised weights `coef`, accumulated in ascending depth order, pass 0.5."""
    flat_depth = depth_patch.ravel()
    depth_order = flat_depth.argsort()
    coef = coef / coef.sum()
    cum_coef = np.cumsum(coef.ravel()[depth_order])
    ind = np.digitize(0.5, cum_coef)
    return flat_depth[depth_order][ind]


def bilateral_filter(depth, config, discontinuity_map=None, HR=False, mask=None, window_size=False):
    """
    Weighted-median filter of a depth map, in one of two modes.

    With a `discontinuity_map`: only pixels whose window contains at least one
    flagged pixel are rewritten. The new value is the median of the window
    pixels that are *not* flagged (and, if `mask` is given, not masked out),
    all with equal weight. Pixels where `mask == 0` are left untouched.

    Without a `discontinuity_map`: every pixel is replaced by the weighted
    median of its window, weighted by the product of a spatial Gaussian
    (`sigma_s`) and a range Gaussian on the depth difference to the window
    centre (`sigma_r`). `mask` is not consulted in this mode.

    In both modes a window whose weights are all zero keeps its centre value.

    Args:
        depth: 2-D depth map. Its outermost one-pixel border is replaced by
            the neighbouring interior values before filtering.
        config: mapping providing `sigma_s`, `sigma_r` and, when `window_size`
            is not given, `filter_size`. Both sigmas are read in every mode.
        discontinuity_map: optional map with the shape of `depth`; non-zero
            entries mark discontinuities. Its border is replaced like `depth`'s.
        HR: accepted for interface compatibility; not used.
        mask: optional validity mask with the shape of `depth`.
        window_size: side length of the square window. A falsy value (the
            default `False`) selects `config['filter_size']`. Should be odd:
            both the padding and the centre index are `window_size // 2`.

    Returns:
        A new array with the shape of `depth` holding the filtered values.
    """
    sigma_s = config['sigma_s']
    sigma_r = config['sigma_r']
    if not window_size:
        window_size = config['filter_size']
    midpt = window_size // 2
    window = [window_size, window_size]

    # padding: overwrite the one-pixel border, then extend by half a window
    depth = np.pad(depth[1:-1, 1:-1], ((1, 1), (1, 1)), 'edge')
    pad_depth = np.pad(depth, (midpt, midpt), 'edge')

    # filtering
    output = depth.copy()
    pad_depth_patches = rolling_window(pad_depth, window, [1, 1])
    pH, pW = pad_depth_patches.shape[:2]

    if discontinuity_map is not None:
        discontinuity_map = np.pad(discontinuity_map[1:-1, 1:-1], ((1, 1), (1, 1)), 'edge')
        pad_discontinuity_map = np.pad(discontinuity_map, (midpt, midpt), 'edge')
        pad_discontinuity_hole = 1 - pad_discontinuity_map
        pad_discontinuity_patches = rolling_window(pad_discontinuity_map, window, [1, 1])
        pad_discontinuity_hole_patches = rolling_window(pad_discontinuity_hole, window, [1, 1])
        pad_mask_patches = None
        if mask is not None:
            # Outside the image the mask is zero, so padded pixels never contribute.
            pad_mask = np.pad(mask, (midpt, midpt), 'constant')
            pad_mask_patches = rolling_window(pad_mask, window, [1, 1])

        for pi in range(pH):
            for pj in range(pW):
                if mask is not None and mask[pi, pj] == 0:
                    continue
                # Windows without any discontinuity keep their original value.
                if not pad_discontinuity_patches[pi, pj].any():
                    continue
                depth_patch = pad_depth_patches[pi, pj]
                coef = pad_discontinuity_hole_patches[pi, pj].astype(np.float32)
                if pad_mask_patches is not None:
                    coef = coef * pad_mask_patches[pi, pj]
                if coef.max() == 0:
                    output[pi, pj] = depth_patch[midpt, midpt]
                    continue
                output[pi, pj] = _weighted_median(depth_patch, coef)
    else:
        # The spatial weights are needed here and only here. Previously they
        # were built only when a discontinuity map was supplied, so this branch
        # failed on an unassigned local.
        ax = np.arange(-midpt, midpt + 1.)
        xx, yy = np.meshgrid(ax, ax)
        spatial_term = np.exp(-(xx**2 + yy**2) / (2. * sigma_s**2))

        for pi in range(pH):
            for pj in range(pW):
                depth_patch = pad_depth_patches[pi, pj]
                patch_midpt = depth_patch[midpt, midpt]
                range_term = np.exp(-(depth_patch - patch_midpt)**2 / (2. * sigma_r**2))
                coef = spatial_term * range_term
                if coef.sum() == 0:
                    output[pi, pj] = patch_midpt
                    continue
                output[pi, pj] = _weighted_median(depth_patch, coef)
    return output
def rolling_window(a, window, strides):
    """
    Return a strided view of `a` that enumerates every sliding window.

    Args:
        a: N-dimensional NumPy array.
        window: window length along each axis of `a`.
        strides: step, in elements, between consecutive windows along each axis.

    Returns:
        A view (no data is copied) of shape
        `[(a.shape[i] - window[i]) // strides[i] + 1 for each axis] + list(window)`,
        so `result[i, j]` is `a[i*s0 : i*s0 + w0, j*s1 : j*s1 + w1]` in the 2-D case.
        Windows overlap in memory; writing through the view changes `a`.

    Raises:
        AssertionError: if `a`, `window` and `strides` disagree on the number of dimensions.
    """
    assert len(a.shape) == len(window) == len(strides), "'a', 'window', 'strides' dimension mismatch"
    out_shape = [(dim - w) // s + 1 for dim, w, s in zip(a.shape, window, strides)] + list(window)
    # Step between windows = element step of the source array times the requested stride.
    # Using a.strides (rather than shape * itemsize) keeps this correct for
    # non-contiguous inputs such as transposed or sliced arrays.
    out_strides = [byte_step * s for byte_step, s in zip(a.strides, strides)] + list(a.strides)
    return np.lib.stride_tricks.as_strided(a, shape=out_shape, strides=out_strides)