|
""" |
|
Code adapted from TokenCut: https://github.com/YangtaoWANG95/TokenCut |
|
""" |
|
|
|
import PIL.Image as Image |
|
import numpy as np |
|
from scipy import ndimage |
|
from scipy.sparse import diags, csr_matrix |
|
from scipy.sparse.linalg import cg |
|
|
|
# Colour-space constants used by rgb2yuv / yuv2rgb.
# NOTE(review): the coefficients look like the full-range JPEG/JFIF YCbCr
# transform -- confirm against the upstream TokenCut / bilateral-solver code.

# Rows are the Y, U and V weights applied to an (R, G, B) triple.
RGB_TO_YUV = np.array(
    [
        [0.299, 0.587, 0.114],
        [-0.168736, -0.331264, 0.5],
        [0.5, -0.418688, -0.081312],
    ]
)

# Rows are the R, G and B weights applied to an offset-free (Y, U, V) triple.
YUV_TO_RGB = np.array(
    [
        [1.0, 0.0, 1.402],
        [1.0, -0.34414, -0.71414],
        [1.0, 1.772, 0.0],
    ]
)

# Per-channel offset, shaped (1, 1, 3) so it broadcasts over an image array.
YUV_OFFSET = np.array([[[0.0, 128.0, 128.0]]])

# Base of the positional hash BilateralGrid uses to flatten grid coordinates.
MAX_VAL = 255.0
|
|
|
|
|
def rgb2yuv(im):
    """Convert an RGB image array to YUV.

    The colour channels are expected on axis 2 of ``im``; they are mixed
    with ``RGB_TO_YUV`` and then shifted by ``YUV_OFFSET``.
    """
    mixed = np.tensordot(im, RGB_TO_YUV, axes=([2], [1]))
    return mixed + YUV_OFFSET
|
|
|
|
|
def yuv2rgb(im):
    """Convert a YUV image array back to RGB.

    ``im`` is cast to float, ``YUV_OFFSET`` is removed, and the channels on
    axis 2 are mixed with ``YUV_TO_RGB``. The result is not clipped or
    rounded.
    """
    centered = im.astype(float) - YUV_OFFSET
    return np.tensordot(centered, YUV_TO_RGB, axes=([2], [1]))
|
|
|
|
|
def get_valid_idx(valid, candidates):
    """Locate the entries of ``candidates`` that occur in ``valid``.

    ``valid`` is searched with ``np.searchsorted``, so it must be sorted in
    ascending order.

    Returns:
        A pair ``(valid_idx, locs)``: ``valid_idx`` holds the positions in
        ``candidates`` whose value exists in ``valid``, and ``locs`` holds
        the matching positions in ``valid``.
    """
    # Insertion points may equal len(valid) for values past the end, so
    # clamp them before using them as indices.
    positions = np.searchsorted(valid, candidates).clip(0, len(valid) - 1)

    # An insertion point is a real hit only if the value stored there matches.
    is_hit = valid[positions] == candidates
    hits = np.flatnonzero(is_hit)
    return hits, positions[hits]
|
|
|
|
|
class BilateralGrid(object):
    """Sparse bilateral-space representation of a reference image.

    Each pixel is quantised to a 5-D cell (x, y, luma, two chroma channels).
    The distinct cells become the grid vertices. ``S`` is the sparse
    vertex-by-pixel matrix that assigns every pixel to its vertex, and
    ``blurs`` holds one vertex-by-vertex neighbour matrix per dimension.
    """

    def __init__(self, im, sigma_spatial=32, sigma_luma=8, sigma_chroma=8):
        """Build the grid for the RGB image ``im`` (channels on axis 2).

        The ``sigma_*`` arguments are the cell sizes used to quantise the
        spatial, luma and chroma coordinates respectively.
        """
        yuv = rgb2yuv(im)
        row_idx, col_idx = np.mgrid[: im.shape[0], : im.shape[1]]

        # Quantise every pixel; the channel order is x, y, luma, chroma x2.
        x_cell = (col_idx / sigma_spatial).astype(int)
        y_cell = (row_idx / sigma_spatial).astype(int)
        luma_cell = (yuv[..., 0] / sigma_luma).astype(int)
        chroma_cell = (yuv[..., 1:] / sigma_chroma).astype(int)
        cells = np.concatenate(
            (
                x_cell[..., np.newaxis],
                y_cell[..., np.newaxis],
                luma_cell[..., np.newaxis],
                chroma_cell,
            ),
            axis=2,
        )
        cells_flat = cells.reshape(-1, cells.shape[-1])
        self.npixels, self.dim = cells_flat.shape

        # Positional weights (powers of MAX_VAL) used to hash a coordinate
        # row into a single number.
        self.hash_vec = MAX_VAL ** np.arange(self.dim)

        self._compute_factorization(cells_flat)

    def _compute_factorization(self, coords_flat):
        """Build the splat matrix ``S`` and the per-dimension blur matrices."""
        pixel_hashes = self._hash_coords(coords_flat)
        vertex_hashes, first_pixel, vertex_of_pixel = np.unique(
            pixel_hashes, return_index=True, return_inverse=True
        )

        # One representative coordinate row per distinct cell.
        vertex_coords = coords_flat[first_pixel]
        self.nvertices = len(vertex_coords)

        # S[v, p] == 1 when pixel p falls into vertex v.
        pixel_ids = np.arange(self.npixels)
        self.S = csr_matrix((np.ones(self.npixels), (vertex_of_pixel, pixel_ids)))

        def adjacency(axis, step):
            # Link every vertex to the vertex (if any) one cell away along axis.
            shift = np.zeros((1, self.dim))
            shift[0, axis] = step
            shifted_hashes = self._hash_coords(vertex_coords + shift)
            rows, cols = get_valid_idx(vertex_hashes, shifted_hashes)
            return csr_matrix(
                (np.ones((len(rows),)), (rows, cols)),
                shape=(self.nvertices, self.nvertices),
            )

        self.blurs = []
        for axis in range(self.dim):
            self.blurs.append(adjacency(axis, -1) + adjacency(axis, 1))

    def _hash_coords(self, coord):
        """Hacky function to turn each coordinate row into a single value."""
        rows = coord.reshape(-1, self.dim)
        return np.dot(rows, self.hash_vec)

    def splat(self, x):
        """Accumulate per-pixel values ``x`` onto the grid vertices."""
        return self.S.dot(x)

    def slice(self, y):
        """Read per-vertex values ``y`` back out at every pixel."""
        return self.S.T.dot(y)

    def blur(self, x):
        """Blur a bilateral-space vector with a 1 2 1 kernel in each dimension."""
        assert x.shape[0] == self.nvertices
        blurred = 2 * self.dim * x
        for neighbours in self.blurs:
            blurred = blurred + neighbours.dot(x)
        return blurred

    def filter(self, x):
        """Apply the bilateral filter to a per-pixel input ``x``."""
        numerator = self.slice(self.blur(self.splat(x)))
        # Filtering an all-ones signal gives the normalising weights.
        denominator = self.slice(self.blur(self.splat(np.ones_like(x))))
        return numerator / denominator
|
|
|
|
|
def bistochastize(grid, maxiter=10):
    """Compute the diagonal matrices that bistochastize a bilateral grid.

    Args:
        grid: object exposing ``npixels``, ``nvertices``, ``splat`` and
            ``blur`` (for example a ``BilateralGrid``).
        maxiter: number of fixed-point iterations to run.

    Returns:
        ``(Dn, Dm)``, two sparse diagonal matrices over the grid vertices.
    """
    # Number of pixels that land on each vertex.
    mass = grid.splat(np.ones(grid.npixels))

    scale = np.ones(grid.nvertices)
    for _ in range(maxiter):
        scale = np.sqrt(scale * mass / grid.blur(scale))

    # Recompute the mass so it matches the final scaling.
    mass = scale * grid.blur(scale)
    return diags(scale, 0), diags(mass, 0)
|
|
|
|
|
class BilateralSolver(object):
    """Edge-aware smoothing of a per-pixel signal, solved in bilateral space.

    ``params`` must provide the keys ``"lam"`` (smoothness weight),
    ``"A_diag_min"`` (floor for the Jacobi preconditioner diagonal),
    ``"cg_tol"`` and ``"cg_maxiter"`` (conjugate-gradient controls).
    """

    def __init__(self, grid, params):
        self.grid = grid
        self.params = params
        self.Dn, self.Dm = bistochastize(grid)

    def _run_cg(self, A, b, x0, M):
        """Call SciPy's ``cg`` with whichever tolerance keyword it supports.

        SciPy 1.12 introduced ``rtol`` and 1.14 removed ``tol``; older
        releases only know ``tol``. Try the modern spelling first and fall
        back so the solver runs on both.
        """
        common = {"x0": x0, "M": M, "maxiter": self.params["cg_maxiter"]}
        try:
            return cg(A, b, rtol=self.params["cg_tol"], **common)
        except TypeError:
            return cg(A, b, tol=self.params["cg_tol"], **common)

    def solve(self, x, w):
        """Return the smoothed version of ``x`` at every pixel.

        Args:
            x: target values, one row per pixel and one column per channel.
            w: per-pixel confidence, either 1-D or a single-column 2-D array.

        Returns:
            Array with the same layout as ``x`` holding the solved signal.
        """
        # Normalise the confidence to a column vector so it broadcasts
        # against the columns of x.
        if w.ndim == 2:
            assert w.shape[1] == 1
        elif w.ndim == 1:
            w = w.reshape(w.shape[0], 1)

        # Linear system A y = b in bilateral space.
        A_smooth = self.Dm - self.Dn.dot(self.grid.blur(self.Dn))
        w_splat = self.grid.splat(w)
        A_data = diags(w_splat[:, 0], 0)
        A = self.params["lam"] * A_smooth + A_data
        xw = x * w
        b = self.grid.splat(xw)

        # Jacobi preconditioner; the floor keeps the inverse finite.
        A_diag = np.maximum(A.diagonal(), self.params["A_diag_min"])
        M = diags(1 / A_diag, 0)

        # Start from the confidence-weighted mean of each vertex.
        y0 = b / w_splat
        yhat = np.empty_like(y0)
        for d in range(x.shape[-1]):
            # The convergence flag is ignored on purpose: with a small
            # cg_maxiter the solve is best-effort.
            yhat[..., d], _ = self._run_cg(A, b[..., d], y0[..., d], M)

        return self.grid.slice(yhat)
|
|
|
|
|
def bilateral_solver_output(
    img_pth,
    target,
    img=None,
    sigma_spatial=24,
    sigma_luma=4,
    sigma_chroma=4,
    get_all_cc=False
):
    """Refine a 2-D mask with the bilateral solver and pick its components.

    Args:
        img_pth: path of the reference image; only read when ``img`` is None.
        target: 2-D array (H, W) holding the mask to refine.
        img: optional already-loaded reference image (anything ``np.array``
            accepts); takes precedence over ``img_pth``.
        sigma_spatial: spatial cell size of the bilateral grid.
        sigma_luma: luma cell size of the bilateral grid.
        sigma_chroma: chroma cell size of the bilateral grid.
        get_all_cc: when True, keep every connected component except the
            most populous label; otherwise keep only the second most
            populous label.

    Returns:
        ``(output_solver, binary_solver)``: the (H, W) float solver output
        and the selected component mask (bool, or int when ``get_all_cc``).

    Raises:
        ValueError: if the reference image and ``target`` differ in size.
    """
    if img is None:
        # Close the file handle deterministically instead of waiting for GC.
        with Image.open(img_pth) as handle:
            reference = np.array(handle.convert("RGB"))
    else:
        reference = np.array(img)

    h, w = target.shape
    if reference.shape[:2] != (h, w):
        raise ValueError(
            "reference image size {} does not match target size {}".format(
                reference.shape[:2], (h, w)
            )
        )

    confidence = np.full((h, w), 0.999)

    grid_params = {
        "sigma_luma": sigma_luma,
        "sigma_chroma": sigma_chroma,
        "sigma_spatial": sigma_spatial,
    }

    bs_params = {
        "lam": 256,
        "A_diag_min": 1e-5,
        "cg_tol": 1e-5,
        "cg_maxiter": 25,
    }

    grid = BilateralGrid(reference, **grid_params)

    t = target.reshape(-1, 1).astype(np.double)
    c = confidence.reshape(-1, 1).astype(np.double)

    output_solver = BilateralSolver(grid, bs_params).solve(t, c).reshape((h, w))

    binary_solver = ndimage.binary_fill_holes(output_solver > 0.5)
    labeled, nr_objects = ndimage.label(binary_solver)

    # Pixel count per label (label 0 is the background), in a single pass.
    nb_pixel = np.bincount(labeled.ravel(), minlength=nr_objects + 1)
    pixel_order = np.argsort(nb_pixel)

    # NOTE(review): both branches below assume the most populous label is
    # the background. If a foreground component is larger than the
    # background, the background label gets selected instead. Kept as in
    # the upstream code so results stay reproducible -- confirm intent.
    if get_all_cc:
        # Every label except the most populous one.
        kept_labels = pixel_order[::-1][1:]
        binary_solver = np.isin(labeled, kept_labels).astype(int)
    elif nr_objects >= 1:
        # Second most populous label.
        binary_solver = labeled == pixel_order[-2]
    else:
        # No foreground component at all: fall back to the full frame.
        binary_solver = np.ones((h, w), dtype=bool)

    return output_solver, binary_solver