# peekaboo-demo / bilateral_solver.py
# Committed by hasibzunair in 1803579 ("add files"); 7.7 kB.
"""
Code adapted from TokenCut: https://github.com/YangtaoWANG95/TokenCut
"""
import PIL.Image as Image
import numpy as np
from scipy import ndimage
from scipy.sparse import diags, csr_matrix
from scipy.sparse.linalg import cg
# Colour-space conversion matrices. Row i of RGB_TO_YUV gives output channel i
# (Y, U, V) as a weighted sum of (R, G, B); the coefficients appear to be the
# JPEG/JFIF full-range YCbCr ones — NOTE(review): confirm if exactness matters.
RGB_TO_YUV = np.array(
    [
        [0.299, 0.587, 0.114],
        [-0.168736, -0.331264, 0.5],
        [0.5, -0.418688, -0.081312],
    ]
)
# Approximate inverse of RGB_TO_YUV (maps centred Y, U, V back to R, G, B).
YUV_TO_RGB = np.array(
    [
        [1.0, 0.0, 1.402],
        [1.0, -0.34414, -0.71414],
        [1.0, 1.772, 0.0],
    ]
)
# Offset added to (Y, U, V); shaped (1, 1, 3) to broadcast over an image's last axis.
YUV_OFFSET = np.array([0.0, 128.0, 128.0])[np.newaxis, np.newaxis, :]
# 255.0, the largest 8-bit channel value.
MAX_VAL = 255.0
def rgb2yuv(im):
    """Convert RGB values on axis 2 of ``im`` to offset YUV.

    Output channel ``i`` is ``sum_k im[..., k] * RGB_TO_YUV[i, k]`` plus
    ``YUV_OFFSET[..., i]`` (0 for Y, 128 for U and V).
    """
    centred_yuv = np.tensordot(im, RGB_TO_YUV, axes=([2], [1]))
    return centred_yuv + YUV_OFFSET
def yuv2rgb(im):
    """Convert offset YUV values on axis 2 of ``im`` back to RGB.

    Removes ``YUV_OFFSET`` (after casting to float) and applies ``YUV_TO_RGB``.
    """
    centred = im.astype(float) - YUV_OFFSET
    return np.tensordot(centred, YUV_TO_RGB, axes=([2], [1]))
def get_valid_idx(valid, candidates):
    """Find which ``candidates`` are present in ``valid`` and where.

    Args:
        valid: array sorted in ascending order (required by ``np.searchsorted``).
        candidates: values to look up.

    Returns:
        ``(valid_idx, locs)`` where ``valid_idx`` lists the positions in
        ``candidates`` whose value occurs in ``valid`` and ``locs`` the matching
        positions in ``valid``, so ``valid[locs] == candidates[valid_idx]``.
        Both are empty when ``valid`` is empty.
    """
    valid = np.asarray(valid)
    candidates = np.asarray(candidates)
    if valid.size == 0:
        # Nothing can match; clipping to len(valid) - 1 == -1 below would
        # otherwise index into an empty array and raise IndexError.
        empty = np.empty(0, dtype=np.intp)
        return empty, empty.copy()
    locs = np.searchsorted(valid, candidates)
    # searchsorted returns len(valid) for candidates larger than every valid
    # value; clip to stay in bounds (the equality test rejects them anyway).
    locs = np.clip(locs, 0, len(valid) - 1)
    # Keep only candidates that are actually present.
    valid_idx = np.flatnonzero(valid[locs] == candidates)
    return valid_idx, locs[valid_idx]
class BilateralGrid(object):
    """Sparse bilateral-space grid built from an RGB reference image.

    Each pixel is assigned a 5-D integer lattice coordinate
    ``(x, y, luma, chroma_u, chroma_v)``: the pixel position and its YUV values,
    each divided by a bandwidth and truncated. Pixels with the same coordinate
    share a vertex. The grid provides splat (pixels -> vertices), blur (on
    vertices) and slice (vertices -> pixels) operators.

    Args:
        im: reference image indexed as ``im[row, col, channel]`` with RGB on
            the last axis (presumably 0-255 values — confirm for new callers).
        sigma_spatial: spatial bandwidth, in pixels.
        sigma_luma: bandwidth of the Y channel.
        sigma_chroma: bandwidth of the U and V channels.

    Attributes:
        npixels: number of pixels.
        dim: number of lattice dimensions (5).
        nvertices: number of occupied lattice vertices.
        S: ``(nvertices, npixels)`` sparse splat matrix.
        blurs: per-dimension sparse ``(nvertices, nvertices)`` neighbour matrices.
        hash_vec: int64 mixed-radix weights used to hash lattice coordinates.

    Raises:
        ValueError: if the lattice is too large to hash exactly in int64.
    """

    def __init__(self, im, sigma_spatial=32, sigma_luma=8, sigma_chroma=8):
        im_yuv = rgb2yuv(im)
        # Compute 5-dimensional XYLUV bilateral-space coordinates
        Iy, Ix = np.mgrid[: im.shape[0], : im.shape[1]]
        x_coords = (Ix / sigma_spatial).astype(int)
        y_coords = (Iy / sigma_spatial).astype(int)
        luma_coords = (im_yuv[..., 0] / sigma_luma).astype(int)
        chroma_coords = (im_yuv[..., 1:] / sigma_chroma).astype(int)
        coords = np.dstack((x_coords, y_coords, luma_coords, chroma_coords))
        coords_flat = coords.reshape(-1, coords.shape[-1]).astype(np.int64)
        self.npixels, self.dim = coords_flat.shape

        # Shift each dimension so its smallest coordinate is 1; the -1
        # neighbours probed when building the blurs are then never negative.
        coords_flat = coords_flat - coords_flat.min(axis=0) + 1
        # One radix per dimension, big enough for the +1 neighbour of the
        # largest coordinate, makes the mixed-radix hash collision-free. A
        # fixed base would alias distinct vertices once a scaled coordinate
        # (or its neighbour) reached the base.
        radix = coords_flat.max(axis=0) + 2
        span = 1
        for r in radix:
            span *= int(r)
        if span > np.iinfo(np.int64).max:
            raise ValueError(
                "bilateral grid too large to hash; increase the sigma values"
            )
        # Dimension 0 (x) is least significant, as with the fixed-base hash.
        self.hash_vec = np.concatenate(([1], np.cumprod(radix[:-1]))).astype(
            np.int64
        )
        # Construct S and B matrix
        self._compute_factorization(coords_flat)

    def _compute_factorization(self, coords_flat):
        """Build the splat matrix ``S`` and the per-dimension ``blurs``."""
        # Hash each coordinate in grid to a unique value
        hashed_coords = self._hash_coords(coords_flat)
        unique_hashes, unique_idx, idx = np.unique(
            hashed_coords, return_index=True, return_inverse=True
        )
        idx = idx.ravel()
        # Identify unique set of vertices
        unique_coords = coords_flat[unique_idx]
        self.nvertices = len(unique_coords)
        # Sparse splat matrix: entry (v, p) is 1 when pixel p maps to vertex v
        self.S = csr_matrix(
            (np.ones(self.npixels), (idx, np.arange(self.npixels))),
            shape=(self.nvertices, self.npixels),
        )
        # Sparse blur matrices for [1 0 1] blurs; the central element is
        # excluded here and added back in blur().
        self.blurs = []
        for d in range(self.dim):
            blur = csr_matrix((self.nvertices, self.nvertices))
            for offset in (-1, 1):
                offset_vec = np.zeros((1, self.dim), dtype=np.int64)
                offset_vec[:, d] = offset
                neighbor_hash = self._hash_coords(unique_coords + offset_vec)
                valid_coord, neighbor_idx = get_valid_idx(
                    unique_hashes, neighbor_hash
                )
                blur = blur + csr_matrix(
                    (np.ones(len(valid_coord)), (valid_coord, neighbor_idx)),
                    shape=(self.nvertices, self.nvertices),
                )
            self.blurs.append(blur)

    def _hash_coords(self, coord):
        """Map integer lattice coordinates to unique int64 keys (mixed radix)."""
        return np.dot(coord.reshape(-1, self.dim), self.hash_vec)

    def splat(self, x):
        """Sum pixel values into their vertices (``S @ x``)."""
        return self.S.dot(x)

    def slice(self, y):
        """Copy vertex values back to the pixels of each vertex (``S.T @ y``)."""
        return self.S.T.dot(y)

    def blur(self, x):
        """Blur a bilateral-space vector with a 1 2 1 kernel in each dimension"""
        assert x.shape[0] == self.nvertices
        out = 2 * self.dim * x
        for blur in self.blurs:
            out = out + blur.dot(x)
        return out

    def filter(self, x):
        """Apply a normalized bilateral filter (splat, blur, slice) to ``x``."""
        return self.slice(self.blur(self.splat(x))) / self.slice(
            self.blur(self.splat(np.ones_like(x)))
        )
def bistochastize(grid, maxiter=10):
    """Return diagonal matrices ``(Dn, Dm)`` that bistochastize ``grid``.

    Starting from a vector of ones, the vertex scaling ``n`` is refined
    ``maxiter`` times with ``n <- sqrt(n * m / blur(n))``, where ``m`` is the
    number of pixels splatted onto each vertex. ``m`` is then recomputed as
    ``n * blur(n)``, so the returned pair is consistent however many
    iterations were run.
    """
    pixel_mass = grid.splat(np.ones(grid.npixels))
    scale = np.ones(grid.nvertices)
    for _ in range(maxiter):
        blurred = grid.blur(scale)
        scale = np.sqrt(scale * pixel_mass / blurred)
    vertex_mass = scale * grid.blur(scale)
    return diags(scale, 0), diags(vertex_mass, 0)
class BilateralSolver(object):
    """Bilateral solver that smooths per-pixel targets on a bilateral grid.

    Args:
        grid: grid providing ``splat``, ``slice`` and ``blur`` plus whatever
            ``bistochastize`` reads (normally a ``BilateralGrid``).
        params: dict with keys ``"lam"`` (smoothness weight), ``"A_diag_min"``
            (floor on the Jacobi preconditioner diagonal), ``"cg_tol"`` and
            ``"cg_maxiter"`` (conjugate-gradient stopping criteria).
    """

    def __init__(self, grid, params):
        self.grid = grid
        self.params = params
        self.Dn, self.Dm = bistochastize(grid)

    @staticmethod
    def _cg_tol_keyword():
        """Return the relative-tolerance keyword accepted by the installed ``cg``.

        SciPy 1.12 renamed ``tol`` to ``rtol``, and SciPy 1.14 removed ``tol``.
        """
        import inspect

        try:
            parameters = inspect.signature(cg).parameters
        except (TypeError, ValueError):
            return "tol"
        return "rtol" if "rtol" in parameters else "tol"

    def solve(self, x, w):
        """Solve for a smoothed version of ``x`` weighted by confidence ``w``.

        Args:
            x: per-pixel targets, shape ``(npixels, d)`` or ``(npixels,)``.
            w: per-pixel confidence, shape ``(npixels, 1)`` or ``(npixels,)``.

        Returns:
            Smoothed values with the same shape as ``x``.
        """
        # Check that w is a vector or a nx1 matrix
        if w.ndim == 2:
            assert w.shape[1] == 1
        elif w.ndim == 1:
            w = w.reshape(w.shape[0], 1)
        # A 1-D x would otherwise broadcast against the (n, 1) w into (n, n).
        squeeze = x.ndim == 1
        if squeeze:
            x = x.reshape(-1, 1)

        A_smooth = self.Dm - self.Dn.dot(self.grid.blur(self.Dn))
        w_splat = self.grid.splat(w)
        A_data = diags(w_splat[:, 0], 0)
        A = self.params["lam"] * A_smooth + A_data
        xw = x * w
        b = self.grid.splat(xw)
        # Use simple Jacobi preconditioner
        A_diag = np.maximum(A.diagonal(), self.params["A_diag_min"])
        M = diags(1 / A_diag, 0)
        # Flat initialization: weighted mean of x per vertex; vertices with no
        # weight start at 0 instead of NaN.
        y0 = np.divide(
            b, w_splat, out=np.zeros_like(b, dtype=float), where=w_splat != 0
        )
        tol_kwarg = {self._cg_tol_keyword(): self.params["cg_tol"]}
        yhat = np.empty_like(y0)
        for d in range(x.shape[-1]):
            yhat[..., d], _info = cg(
                A,
                b[..., d],
                x0=y0[..., d],
                M=M,
                maxiter=self.params["cg_maxiter"],
                **tol_kwarg,
            )
        xhat = self.grid.slice(yhat)
        return xhat.ravel() if squeeze else xhat
def bilateral_solver_output(
    img_pth,
    target,
    img=None,
    sigma_spatial=24,
    sigma_luma=4,
    sigma_chroma=4,
    get_all_cc=False,
):
    """Refine a coarse 2-D map with the bilateral solver guided by an RGB image.

    Args:
        img_pth: path of the reference image; only opened when ``img`` is None.
        target: 2-D ``(h, w)`` array of coarse values to refine.
        img: optional reference image (anything ``np.array`` accepts, e.g. a
            PIL image); takes precedence over ``img_pth``.
        sigma_spatial, sigma_luma, sigma_chroma: bilateral-grid bandwidths.
        get_all_cc: if True, keep every connected component except the
            largest label; otherwise keep only the second-largest label.

    Returns:
        ``(output_solver, binary_solver)``. ``output_solver`` is the soft
        ``(h, w)`` solver output. ``binary_solver`` is the selected mask: an int
        0/1 array when ``get_all_cc`` is True, otherwise a bool array (all True
        when no component is found).

    Raises:
        ValueError: if the reference image and ``target`` differ in size.
    """
    if img is None:
        reference = np.array(Image.open(img_pth).convert("RGB"))
    else:
        reference = np.array(img)
    h, w = target.shape
    if reference.shape[:2] != (h, w):
        raise ValueError(
            f"reference image size {reference.shape[:2]} does not match "
            f"target size {(h, w)}"
        )
    confidence = np.full((h, w), 0.999)
    grid_params = {
        "sigma_luma": sigma_luma,  # Brightness bandwidth
        "sigma_chroma": sigma_chroma,  # Color bandwidth
        "sigma_spatial": sigma_spatial,  # Spatial bandwidth
    }
    bs_params = {
        "lam": 256,  # The strength of the smoothness parameter
        "A_diag_min": 1e-5,  # Clamp the diagonal of the A diagonal in the Jacobi preconditioner.
        "cg_tol": 1e-5,  # The tolerance on the convergence in PCG
        "cg_maxiter": 25,  # The number of PCG iterations
    }
    grid = BilateralGrid(reference, **grid_params)
    t = target.reshape(-1, 1).astype(np.double)
    c = confidence.reshape(-1, 1).astype(np.double)
    # Soft solver output, thresholded below into connected components.
    output_solver = BilateralSolver(grid, bs_params).solve(t, c).reshape((h, w))
    binary_solver = ndimage.binary_fill_holes(output_solver > 0.5)
    labeled, nr_objects = ndimage.label(binary_solver)
    # Pixel count per label, including label 0 (the below-threshold region).
    nb_pixel = np.bincount(labeled.ravel(), minlength=nr_objects + 1)
    pixel_order = np.argsort(nb_pixel)
    if get_all_cc:
        # Drop the largest label (assumed to be the background) and keep all
        # other components; try changing this if the assumption fails.
        binary_solver = np.isin(labeled, pixel_order[::-1][1:]).astype(int)
    elif nr_objects >= 1:
        # NOTE(review): assumes the largest label is the background (label 0).
        # If a foreground component is larger than the background, this picks
        # label 0 instead; confirm the heuristic fits the data.
        binary_solver = labeled == pixel_order[-2]
    else:
        # No foreground component: fall back to the whole image.
        binary_solver = np.ones((h, w), dtype=bool)
    return output_solver, binary_solver