"""
Code adapted from TokenCut: https://github.com/YangtaoWANG95/TokenCut
"""

import inspect

import numpy as np
from scipy import ndimage
from scipy.sparse import diags, csr_matrix
from scipy.sparse.linalg import cg

try:
    import PIL.Image as Image
except ImportError:
    # Pillow is only needed when the reference image is loaded from a path;
    # callers that pass ``img`` directly can work without it.
    Image = None

RGB_TO_YUV = np.array(
    [[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]]
)
YUV_TO_RGB = np.array([[1.0, 0.0, 1.402], [1.0, -0.34414, -0.71414], [1.0, 1.772, 0.0]])
YUV_OFFSET = np.array([0, 128.0, 128.0]).reshape(1, 1, -1)
MAX_VAL = 255.0


def _cg_tolerance_kwarg():
    """Return the name of ``cg``'s relative-tolerance keyword.

    Older SciPy releases call it ``tol``; newer ones call it ``rtol`` and
    reject ``tol``. Inspecting the signature keeps the solver working on both.
    """
    try:
        parameters = inspect.signature(cg).parameters
    except (TypeError, ValueError):
        return "rtol"
    return "rtol" if "rtol" in parameters else "tol"


_CG_TOL_KWARG = _cg_tolerance_kwarg()


def rgb2yuv(im):
    """Convert an (H, W, 3) RGB array to YUV, with chroma offset by 128."""
    return np.tensordot(im, RGB_TO_YUV, ([2], [1])) + YUV_OFFSET


def yuv2rgb(im):
    """Convert an (H, W, 3) YUV array (chroma offset by 128) back to RGB."""
    return np.tensordot(im.astype(float) - YUV_OFFSET, YUV_TO_RGB, ([2], [1]))


def get_valid_idx(valid, candidates):
    """Find which candidates are present in a sorted array, and where.

    Args:
        valid: sorted 1-D array of known values.
        candidates: 1-D array of values to look up in ``valid``.

    Returns:
        A tuple ``(valid_idx, locs)``: ``valid_idx`` holds the positions in
        ``candidates`` whose value occurs in ``valid``, and ``locs`` holds the
        matching positions in ``valid``.
    """
    locs = np.searchsorted(valid, candidates)
    # Handle edge case where the candidate is larger than all valid values
    locs = np.clip(locs, 0, len(valid) - 1)
    # Identify which values are actually present
    valid_idx = np.flatnonzero(valid[locs] == candidates)
    locs = locs[valid_idx]
    return valid_idx, locs


class BilateralGrid(object):
    """Sparse bilateral grid built from a reference RGB image.

    Every pixel is mapped to a vertex in a 5-dimensional (x, y, luma, u, v)
    grid. The grid exposes splat (pixels -> vertices), blur (between
    neighbouring vertices) and slice (vertices -> pixels) operations.
    """

    def __init__(self, im, sigma_spatial=32, sigma_luma=8, sigma_chroma=8):
        """Build the grid for ``im``.

        Args:
            im: (H, W, 3) RGB image array.
            sigma_spatial: cell size along the x and y axes, in pixels.
            sigma_luma: cell size along the luma axis.
            sigma_chroma: cell size along each chroma axis.
        """
        im_yuv = rgb2yuv(im)
        # Compute 5-dimensional XYLUV bilateral-space coordinates
        Iy, Ix = np.mgrid[: im.shape[0], : im.shape[1]]
        x_coords = (Ix / sigma_spatial).astype(int)
        y_coords = (Iy / sigma_spatial).astype(int)
        luma_coords = (im_yuv[..., 0] / sigma_luma).astype(int)
        chroma_coords = (im_yuv[..., 1:] / sigma_chroma).astype(int)
        coords = np.dstack((x_coords, y_coords, luma_coords, chroma_coords))
        coords_flat = coords.reshape(-1, coords.shape[-1])
        self.npixels, self.dim = coords_flat.shape
        # Hacky "hash vector" for coordinates,
        # Requires all scaled coordinates be < MAX_VAL
        self.hash_vec = MAX_VAL ** np.arange(self.dim)
        # Construct S and B matrix
        self._compute_factorization(coords_flat)

    def _compute_factorization(self, coords_flat):
        """Build the splat matrix ``S`` and the per-dimension blur matrices."""
        # Hash each coordinate in grid to a unique value
        hashed_coords = self._hash_coords(coords_flat)
        unique_hashes, unique_idx, pixel_to_vertex = np.unique(
            hashed_coords, return_index=True, return_inverse=True
        )
        # Identify unique set of vertices
        unique_coords = coords_flat[unique_idx]
        self.nvertices = len(unique_coords)
        # Construct sparse splat matrix that maps from pixels to vertices
        self.S = csr_matrix(
            (np.ones(self.npixels), (pixel_to_vertex, np.arange(self.npixels)))
        )
        # Construct sparse blur matrices.
        # Note that these represent [1 0 1] blurs, excluding the central element
        self.blurs = []
        for d in range(self.dim):
            blur = csr_matrix((self.nvertices, self.nvertices))
            for offset in (-1, 1):
                offset_vec = np.zeros((1, self.dim))
                offset_vec[:, d] = offset
                # NOTE: the hash is only unique while every (shifted)
                # coordinate stays within [0, MAX_VAL).
                neighbor_hash = self._hash_coords(unique_coords + offset_vec)
                vertex_idx, neighbor_idx = get_valid_idx(unique_hashes, neighbor_hash)
                blur = blur + csr_matrix(
                    (np.ones((len(vertex_idx),)), (vertex_idx, neighbor_idx)),
                    shape=(self.nvertices, self.nvertices),
                )
            self.blurs.append(blur)

    def _hash_coords(self, coord):
        """Hacky function to turn a coordinate into a unique value"""
        return np.dot(coord.reshape(-1, self.dim), self.hash_vec)

    def splat(self, x):
        """Accumulate per-pixel values ``x`` onto the grid vertices."""
        return self.S.dot(x)

    def slice(self, y):
        """Read per-vertex values ``y`` back out at every pixel."""
        return self.S.T.dot(y)

    def blur(self, x):
        """Blur a bilateral-space vector with a 1 2 1 kernel in each dimension"""
        assert x.shape[0] == self.nvertices
        out = 2 * self.dim * x
        for blur in self.blurs:
            out = out + blur.dot(x)
        return out

    def filter(self, x):
        """Apply bilateral filter to an input x"""
        numerator = self.slice(self.blur(self.splat(x)))
        normalizer = self.slice(self.blur(self.splat(np.ones_like(x))))
        return numerator / normalizer


def bistochastize(grid, maxiter=10):
    """Compute diagonal matrices to bistochastize a bilateral grid

    Args:
        grid: a ``BilateralGrid``.
        maxiter: number of fixed-point iterations to run.

    Returns:
        A tuple ``(Dn, Dm)`` of sparse diagonal matrices.
    """
    m = grid.splat(np.ones(grid.npixels))
    n = np.ones(grid.nvertices)
    for _ in range(maxiter):
        n = np.sqrt(n * m / grid.blur(n))
    # Correct m to satisfy the assumption of bistochastization regardless
    # of how many iterations have been run.
    m = n * grid.blur(n)
    Dm = diags(m, 0)
    Dn = diags(n, 0)
    return Dn, Dm


class BilateralSolver(object):
    """Solver that smooths a target in bilateral space with weighted data terms."""

    def __init__(self, grid, params):
        """Store the grid and parameters and precompute the bistochastization.

        Args:
            grid: a ``BilateralGrid``.
            params: dict with keys ``"lam"``, ``"A_diag_min"``, ``"cg_tol"``
                and ``"cg_maxiter"``.
        """
        self.grid = grid
        self.params = params
        self.Dn, self.Dm = bistochastize(grid)

    def solve(self, x, w):
        """Solve for the smoothed per-pixel output.

        Args:
            x: (npixels, D) array of target values.
            w: per-pixel confidence, either (npixels,) or (npixels, 1).

        Returns:
            An (npixels, D) array with the solved values.
        """
        # Check that w is a vector or a nx1 matrix
        if w.ndim == 2:
            assert w.shape[1] == 1
        elif w.ndim == 1:
            w = w.reshape(w.shape[0], 1)
        A_smooth = self.Dm - self.Dn.dot(self.grid.blur(self.Dn))
        w_splat = self.grid.splat(w)
        A_data = diags(w_splat[:, 0], 0)
        A = self.params["lam"] * A_smooth + A_data
        xw = x * w
        b = self.grid.splat(xw)
        # Use simple Jacobi preconditioner
        A_diag = np.maximum(A.diagonal(), self.params["A_diag_min"])
        M = diags(1 / A_diag, 0)
        # Flat initialization
        y0 = b / w_splat
        yhat = np.empty_like(y0)
        # The tolerance keyword name depends on the installed SciPy version.
        tolerance = {_CG_TOL_KWARG: self.params["cg_tol"]}
        for d in range(x.shape[-1]):
            solution, _info = cg(
                A,
                b[..., d],
                x0=y0[..., d],
                M=M,
                maxiter=self.params["cg_maxiter"],
                **tolerance,
            )
            yhat[..., d] = solution
        xhat = self.grid.slice(yhat)
        return xhat


def bilateral_solver_output(
    img_pth,
    target,
    img=None,
    sigma_spatial=24,
    sigma_luma=4,
    sigma_chroma=4,
    get_all_cc=False,
):
    """Refine a 2-D target map with the bilateral solver.

    Args:
        img_pth: path of the reference image; only read when ``img`` is None.
        target: (H, W) array to refine.
        img: optional reference image (anything ``np.array`` accepts); takes
            precedence over ``img_pth``.
        sigma_spatial: spatial bandwidth of the bilateral grid.
        sigma_luma: brightness bandwidth of the bilateral grid.
        sigma_chroma: color bandwidth of the bilateral grid.
        get_all_cc: if True, return the union of all connected components
            except the largest one instead of only the second largest.

    Returns:
        A tuple ``(output_solver, binary_solver)``: the soft (H, W) solver
        output and the mask derived from it.

    Raises:
        ImportError: if ``img`` is None and Pillow is not installed.
    """
    if img is None:
        if Image is None:
            raise ImportError("Pillow is required to load the image from img_pth")
        with Image.open(img_pth) as pil_image:
            reference = np.array(pil_image.convert("RGB"))
    else:
        reference = np.array(img)

    h, w = target.shape
    confidence = np.ones((h, w)) * 0.999

    grid_params = {
        "sigma_luma": sigma_luma,  # Brightness bandwidth
        "sigma_chroma": sigma_chroma,  # Color bandwidth
        "sigma_spatial": sigma_spatial,  # Spatial bandwidth
    }

    bs_params = {
        "lam": 256,  # The strength of the smoothness parameter
        "A_diag_min": 1e-5,  # Clamp the diagonal of A in the Jacobi preconditioner.
        "cg_tol": 1e-5,  # The tolerance on the convergence in PCG
        "cg_maxiter": 25,  # The number of PCG iterations
    }

    grid = BilateralGrid(reference, **grid_params)

    t = target.reshape(-1, 1).astype(np.double)
    c = confidence.reshape(-1, 1).astype(np.double)

    # output solver, which is a soft value
    output_solver = BilateralSolver(grid, bs_params).solve(t, c).reshape((h, w))

    binary_solver = ndimage.binary_fill_holes(output_solver > 0.5)
    labeled, nr_objects = ndimage.label(binary_solver)

    # Label 0 is the background, so it is counted alongside the components.
    nb_pixel = [np.sum(labeled == i) for i in range(nr_objects + 1)]
    pixel_order = np.argsort(nb_pixel)

    if get_all_cc:
        # Remove known background
        pixel_descending_order = pixel_order[::-1]
        # Get all CC except the biggest one, which is assumed to be background
        binary_solver = (
            (labeled[None, :, :] == pixel_descending_order[1:, None, None])
            .astype(int)
            .sum(0)
        )
    elif len(pixel_order) >= 2:
        binary_solver = labeled == pixel_order[-2]
    else:
        # No connected component was found: fall back to a full mask.
        binary_solver = np.ones((h, w), dtype=bool)

    return output_solver, binary_solver