#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : patch_match.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 01/09/2020
#
# Distributed under terms of the MIT license.

"""ctypes bindings for the PatchMatch inpainting C library (libpatchmatch).

On non-Windows platforms the library is rebuilt with ``make`` when this module
is imported; on Windows prebuilt DLLs are downloaded next to this file if they
are missing.
"""

import ctypes
import hashlib
import os
import os.path as osp
import shutil
import subprocess
import tempfile
from pathlib import Path  # noqa: F401  (kept so the module namespace is unchanged)
from typing import Optional, Union
from urllib.request import Request, urlopen

import numpy as np
from PIL import Image
from tqdm import tqdm

__all__ = ['set_random_seed', 'set_verbose', 'inpaint', 'inpaint_regularity']

# Directory containing this file: `make` runs here and the shared libraries are
# loaded from here. abspath() avoids an empty dirname for a relative __file__.
_LIB_DIR = osp.dirname(osp.abspath(__file__))

if os.name != "nt":
    # Build the C extension from source at import time (Windows uses prebuilt
    # DLLs instead, see the loading code further below).
    print('Compiling and loading c extensions from "{}".'.format(osp.realpath(_LIB_DIR)))
    # Same as the shell command "make clean && make", without spawning a shell.
    subprocess.check_call(['make', 'clean'], cwd=_LIB_DIR)
    subprocess.check_call(['make'], cwd=_LIB_DIR)


class CShapeT(ctypes.Structure):
    """ctypes mirror of the C shape struct: width, height and channel count."""

    _fields_ = [
        ('width', ctypes.c_int),
        ('height', ctypes.c_int),
        ('channels', ctypes.c_int),
    ]


class CMatT(ctypes.Structure):
    """ctypes mirror of the matrix struct exchanged with libpatchmatch.

    ``dtype`` is an integer code: the values of ``dtype_np_to_pymat`` and the
    indices into ``dtype_pymat_to_ctypes``.
    """

    _fields_ = [
        ('data_ptr', ctypes.c_void_p),
        ('shape', CShapeT),
        ('dtype', ctypes.c_int),
    ]


def download_url_to_file(url, dst, hash_prefix=None, progress=True):
    r"""Download object at the given URL to a local path.

    Args:
        url (string): URL of the object to download
        dst (string): Full path where object will be saved, e.g. ``/tmp/temporary_file``
        hash_prefix (string, optional): If not None, the SHA256 of the downloaded file
            should start with ``hash_prefix``; otherwise a ``RuntimeError`` is raised
            and ``dst`` is left untouched. Default: None
        progress (bool, optional): whether or not to display a progress bar to stderr
            Default: True

    Raises:
        urllib.error.URLError: on network / HTTP failures (an ``OSError`` subclass).
        RuntimeError: if ``hash_prefix`` is given and does not match.

    Adapted from
    https://pytorch.org/docs/stable/_modules/torch/hub.html#load_state_dict_from_url
    """
    file_size = None
    req = Request(url)
    # The response object is a context manager; closing it releases the socket.
    with urlopen(req) as u:
        meta = u.info()
        if hasattr(meta, 'getheaders'):  # legacy header API
            content_length = meta.getheaders("Content-Length")
        else:
            content_length = meta.get_all("Content-Length")
        if content_length:
            file_size = int(content_length[0])

        # We deliberately save it in a temp file and move it after
        # download is complete. This prevents a local working checkpoint
        # being overridden by a broken download.
        dst = os.path.expanduser(dst)
        dst_dir = os.path.dirname(dst)
        f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
        try:
            sha256 = hashlib.sha256() if hash_prefix is not None else None
            with tqdm(total=file_size, disable=not progress,
                      unit='B', unit_scale=True, unit_divisor=1024) as pbar:
                while True:
                    buffer = u.read(8192)
                    if not buffer:
                        break
                    f.write(buffer)
                    if sha256 is not None:
                        sha256.update(buffer)
                    pbar.update(len(buffer))

            f.close()
            if sha256 is not None:
                digest = sha256.hexdigest()
                if digest[:len(hash_prefix)] != hash_prefix:
                    raise RuntimeError('invalid hash value (expected "{}", got "{}")'
                                       .format(hash_prefix, digest))
            shutil.move(f.name, dst)
        finally:
            # Clean up the temp file if the download or the move failed.
            f.close()
            if os.path.exists(f.name):
                os.remove(f.name)


_RELEASE_URL = "https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/"


def _ensure_windows_dependency(filename):
    """Download the prebuilt *filename* into ``_LIB_DIR`` if it is missing.

    Download failures are printed rather than raised so the manual-download
    hint below is always shown; loading the library afterwards still fails
    with an ``OSError`` if the file is absent.
    """
    path = osp.join(_LIB_DIR, filename)
    if not os.path.exists(path):
        try:
            download_url_to_file(url=_RELEASE_URL + filename, dst=path)
        except OSError as err:  # URLError / HTTPError are OSError subclasses
            print('Failed to download "{}": {}'.format(_RELEASE_URL + filename, err))
    if not os.path.exists(path):
        print("[Dependency Missing] Please download {}{} and put it into the PyPatchMatch folder"
              .format(_RELEASE_URL, filename))


if os.name != "nt":
    PMLIB = ctypes.CDLL(osp.join(_LIB_DIR, 'libpatchmatch.so'))
else:
    _ensure_windows_dependency('libpatchmatch.dll')
    _ensure_windows_dependency('opencv_world460.dll')
    PMLIB = ctypes.CDLL(osp.join(_LIB_DIR, 'libpatchmatch.dll'))

# Declare C signatures so ctypes marshals arguments / results correctly.
PMLIB.PM_set_random_seed.argtypes = [ctypes.c_uint]
PMLIB.PM_set_verbose.argtypes = [ctypes.c_int]
PMLIB.PM_free_pymat.argtypes = [CMatT]
PMLIB.PM_inpaint.argtypes = [CMatT, CMatT, ctypes.c_int]
PMLIB.PM_inpaint.restype = CMatT
PMLIB.PM_inpaint_regularity.argtypes = [CMatT, CMatT, CMatT, ctypes.c_int, ctypes.c_float]
PMLIB.PM_inpaint_regularity.restype = CMatT
PMLIB.PM_inpaint2.argtypes = [CMatT, CMatT, CMatT, ctypes.c_int]
PMLIB.PM_inpaint2.restype = CMatT
PMLIB.PM_inpaint2_regularity.argtypes = [CMatT, CMatT, CMatT, CMatT, ctypes.c_int, ctypes.c_float]
PMLIB.PM_inpaint2_regularity.restype = CMatT


def set_random_seed(seed: int):
    """Set the random seed used by the C library."""
    PMLIB.PM_set_random_seed(ctypes.c_uint(seed))


def set_verbose(verbose: bool):
    """Enable or disable verbose output of the C library."""
    PMLIB.PM_set_verbose(ctypes.c_int(verbose))


def _canonize_image_array(image):
    """Return *image* as a C-contiguous ``(H, W, 3)`` uint8 ndarray."""
    if isinstance(image, Image.Image):
        image = np.array(image)
    image = np.ascontiguousarray(image)
    assert image.ndim == 3 and image.shape[2] == 3 and image.dtype == 'uint8'
    return image


def _default_mask(image):
    """Build the ``(H, W, 1)`` uint8 mask marking pure-white pixels with 1."""
    mask = (image == (255, 255, 255)).all(axis=2, keepdims=True).astype('uint8')
    return np.ascontiguousarray(mask)


def _pymat_to_np_and_free(ret_pymat):
    """Copy a C result matrix into numpy, always releasing the C buffer."""
    try:
        return pymat_to_np(ret_pymat)
    finally:
        PMLIB.PM_free_pymat(ret_pymat)


def inpaint(
    image: Union[np.ndarray, Image.Image],
    mask: Optional[Union[np.ndarray, Image.Image]] = None,
    *,
    global_mask: Optional[Union[np.ndarray, Image.Image]] = None,
    patch_size: int = 15
) -> np.ndarray:
    """
    PatchMatch based inpainting proposed in:

        PatchMatch : A Randomized Correspondence Algorithm for Structural Image Editing
        C.Barnes, E.Shechtman, A.Finkelstein and Dan B.Goldman
        SIGGRAPH 2009

    Args:
        image (Union[np.ndarray, Image.Image]): the input image, should be 3-channel RGB/BGR uint8.
        mask (Union[np.array, Image.Image], optional): the mask of the hole(s) to be filled,
            should be 1-channel uint8 (or bool), shaped (H, W) or (H, W, 1).
            If not provided (None), the algorithm will treat all purely white pixels
            as the holes (255, 255, 255).
        global_mask (Union[np.array, Image.Image], optional): the target mask of the output image.
        patch_size (int): the patch size for the inpainting algorithm.

    Return:
        result (np.ndarray): the repaired image, of the same size as the input image.
    """
    image = _canonize_image_array(image)
    mask = _default_mask(image) if mask is None else _canonize_mask_array(mask)

    if global_mask is None:
        ret_pymat = PMLIB.PM_inpaint(np_to_pymat(image), np_to_pymat(mask), ctypes.c_int(patch_size))
    else:
        global_mask = _canonize_mask_array(global_mask)
        ret_pymat = PMLIB.PM_inpaint2(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(global_mask),
                                      ctypes.c_int(patch_size))

    return _pymat_to_np_and_free(ret_pymat)


def inpaint_regularity(
    image: Union[np.ndarray, Image.Image],
    mask: Optional[Union[np.ndarray, Image.Image]],
    ijmap: np.ndarray,
    *,
    global_mask: Optional[Union[np.ndarray, Image.Image]] = None,
    patch_size: int = 15,
    guide_weight: float = 0.25
) -> np.ndarray:
    """PatchMatch inpainting guided by an additional ``ijmap``.

    Args:
        image: 3-channel uint8 input image (ndarray or PIL image).
        mask: 1-channel uint8/bool hole mask, or None to treat pure-white pixels as holes.
        ijmap: float32 ndarray with 3 channels, passed to the C library as the guide map
            (presumably the same height/width as ``image`` -- TODO confirm).
        global_mask: optional target mask of the output image.
        patch_size: the patch size for the inpainting algorithm.
        guide_weight: weight of the ``ijmap`` guide, forwarded to the C library.

    Return:
        the repaired image as a new ndarray.
    """
    image = _canonize_image_array(image)
    assert isinstance(ijmap, np.ndarray) and ijmap.ndim == 3 and ijmap.shape[2] == 3 and ijmap.dtype == 'float32'
    ijmap = np.ascontiguousarray(ijmap)
    mask = _default_mask(image) if mask is None else _canonize_mask_array(mask)

    if global_mask is None:
        ret_pymat = PMLIB.PM_inpaint_regularity(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(ijmap),
                                                ctypes.c_int(patch_size), ctypes.c_float(guide_weight))
    else:
        global_mask = _canonize_mask_array(global_mask)
        ret_pymat = PMLIB.PM_inpaint2_regularity(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(global_mask),
                                                 np_to_pymat(ijmap), ctypes.c_int(patch_size),
                                                 ctypes.c_float(guide_weight))

    return _pymat_to_np_and_free(ret_pymat)


def _canonize_mask_array(mask):
    """Return *mask* as a C-contiguous ``(H, W, 1)`` uint8 ndarray.

    Accepts PIL images and 2-D or 3-D arrays. Boolean masks (e.g. from PIL mode
    "1" images) are converted to 0/1 uint8, matching the auto-generated mask.
    """
    if isinstance(mask, Image.Image):
        mask = np.array(mask)
    if mask.dtype == np.bool_:
        mask = mask.astype('uint8')
    if mask.ndim == 2 and mask.dtype == 'uint8':
        mask = mask[..., np.newaxis]
    assert mask.ndim == 3 and mask.shape[2] == 1 and mask.dtype == 'uint8'
    return np.ascontiguousarray(mask)


# Index = C dtype code; must stay in sync with dtype_np_to_pymat.
dtype_pymat_to_ctypes = [
    ctypes.c_uint8,
    ctypes.c_int8,
    ctypes.c_uint16,
    ctypes.c_int16,
    ctypes.c_int32,
    ctypes.c_float,
    ctypes.c_double,
]


dtype_np_to_pymat = {
    'uint8': 0,
    'int8': 1,
    'uint16': 2,
    'int16': 3,
    'int32': 4,
    'float32': 5,
    'float64': 6,
}


def np_to_pymat(npmat):
    """Wrap a C-contiguous ``(H, W, C)`` ndarray as a ``CMatT`` without copying.

    The caller must keep *npmat* alive while the C side uses the pointer.
    """
    assert npmat.ndim == 3
    return CMatT(
        ctypes.cast(npmat.ctypes.data, ctypes.c_void_p),
        CShapeT(npmat.shape[1], npmat.shape[0], npmat.shape[2]),
        dtype_np_to_pymat[str(npmat.dtype)]
    )


def pymat_to_np(pymat):
    """Copy a ``CMatT`` into a new ``(height, width, channels)`` ndarray.

    The result owns its memory, so the C buffer may be freed afterwards.
    """
    ctype = dtype_pymat_to_ctypes[pymat.dtype]
    shape = (pymat.shape.height, pymat.shape.width, pymat.shape.channels)
    view = np.ctypeslib.as_array(ctypes.cast(pymat.data_ptr, ctypes.POINTER(ctype)), shape)
    return view.copy()