"""Helper utilities: model download/loading and image/mask conversions."""

import base64
import hashlib
import io
import os
import sys
from typing import Dict, List, Optional, Tuple
from urllib.parse import urlparse

try:
    # imghdr was deprecated by PEP 594 and removed from the standard library
    # in Python 3.13. Keep using it where it exists so results are unchanged;
    # get_image_ext() falls back to Pillow when it is missing.
    import imghdr
except ImportError:  # pragma: no cover - depends on the interpreter version
    imghdr = None

import cv2
import numpy as np
import torch
from loguru import logger
from PIL import Image, ImageOps, PngImagePlugin
from torch.hub import download_url_to_file, get_dir

from iopaint.const import MPS_UNSUPPORT_MODELS


def md5sum(filename):
    """Return the hex MD5 digest of the file at ``filename``.

    The file is read in chunks so large checkpoints are never loaded into
    memory at once.
    """
    md5 = hashlib.md5()
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(128 * md5.block_size), b""):
            md5.update(chunk)
    return md5.hexdigest()


def switch_mps_device(model_name, device):
    """Return the CPU device when ``model_name`` is listed as MPS-unsupported.

    Any other combination returns ``device`` unchanged.
    """
    if model_name in MPS_UNSUPPORT_MODELS and str(device) == "mps":
        logger.info(f"{model_name} not support mps, switch to cpu")
        return torch.device("cpu")
    return device


def get_cache_path_by_url(url):
    """Return the path inside the torch hub ``checkpoints`` dir for ``url``.

    The directory is created when missing. The file name is the last path
    component of the URL.
    """
    parts = urlparse(url)
    model_dir = os.path.join(get_dir(), "checkpoints")
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(model_dir, exist_ok=True)
    filename = os.path.basename(parts.path)
    return os.path.join(model_dir, filename)


def get_cache_path_by_local(url):
    """Return ``<cwd>/pretrained-model/big-lama.pt``.

    ``url`` is accepted for signature parity with ``get_cache_path_by_url``
    but is not used.
    """
    root_path = os.getcwd()
    model_path = os.path.join(root_path, "pretrained-model", "big-lama.pt")
    return model_path


def _discard_mismatched_model(model_path, actual_md5, expected_md5):
    """Try to delete a model whose md5 is wrong and log what happened."""
    try:
        os.remove(model_path)
    except OSError:
        logger.error(
            f"Model md5: {actual_md5}, expected md5: {expected_md5}, please delete {model_path} and restart iopaint."
        )
    else:
        logger.error(
            f"Model md5: {actual_md5}, expected md5: {expected_md5}, wrong model deleted. Please restart iopaint. "
            f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
        )


def download_model(url, model_md5: Optional[str] = None):
    """Download ``url`` into the torch hub cache unless it is already there.

    Args:
        url: Remote location of the model file.
        model_md5: Optional expected md5 of the downloaded file. It is only
            checked right after a fresh download.

    Returns:
        The path of the cached file.

    Raises:
        SystemExit: With code -1 when a freshly downloaded file does not match
            ``model_md5`` (the bad file is removed when possible).
    """
    cached_file = get_cache_path_by_url(url)
    # Alternative lookup kept for reference:
    # cached_file = get_cache_path_by_local(url)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        download_url_to_file(url, cached_file, hash_prefix, progress=True)
        if model_md5:
            _md5 = md5sum(cached_file)
            if model_md5 == _md5:
                logger.info(f"Download model success, md5: {_md5}")
            else:
                _discard_mismatched_model(cached_file, _md5, model_md5)
                # sys.exit instead of the site-provided exit(), which is not
                # guaranteed to exist (python -S, frozen applications).
                sys.exit(-1)

    return cached_file


def ceil_modulo(x, mod):
    """Round ``x`` up to the nearest multiple of ``mod``."""
    if x % mod == 0:
        return x
    return (x // mod + 1) * mod


def handle_error(model_path, model_md5, e):
    """Log a model loading failure and terminate the process.

    When an expected md5 is given and the file does not match it, the file is
    deleted (if possible) so a later run can download it again. Otherwise the
    original exception ``e`` is logged.

    Raises:
        SystemExit: Always, with code -1.
    """
    _md5 = md5sum(model_path)
    # Only treat the file as corrupt when there is an expected md5 to compare
    # against; without one the comparison is always unequal and the model file
    # would be deleted on any load error.
    if model_md5 and _md5 != model_md5:
        _discard_mismatched_model(model_path, _md5, model_md5)
    else:
        logger.error(
            f"Failed to load model {model_path}, "
            f"please submit an issue at https://github.com/Sanster/lama-cleaner/issues and include a screenshot of the error:\n{e}"
        )
    sys.exit(-1)


def _resolve_model_path(url_or_path, model_md5):
    """Return ``url_or_path`` if it exists locally, else download it."""
    if os.path.exists(url_or_path):
        return url_or_path
    return download_model(url_or_path, model_md5)


def load_jit_model(url_or_path, device, model_md5: str):
    """Load a TorchScript model, move it to ``device`` and set eval mode.

    Args:
        url_or_path: Existing local path, or a URL to download from.
        device: Target device for the loaded module.
        model_md5: Expected md5 of the model file.

    Returns:
        The loaded module in eval mode.

    Raises:
        SystemExit: Via ``handle_error`` when loading fails.
    """
    model_path = _resolve_model_path(url_or_path, model_md5)

    logger.info(f"Loading model from: {model_path}")
    try:
        model = torch.jit.load(model_path, map_location="cpu").to(device)
    except Exception as e:
        # handle_error never returns (it exits the process).
        handle_error(model_path, model_md5, e)
    model.eval()
    return model


def load_model(model: torch.nn.Module, url_or_path, device, model_md5):
    """Load a state dict into ``model``, move it to ``device``, set eval mode.

    Args:
        model: Module instance that receives the weights (strict loading).
        url_or_path: Existing local path, or a URL to download from.
        device: Target device for the module.
        model_md5: Expected md5 of the weights file.

    Returns:
        ``model`` in eval mode.

    Raises:
        SystemExit: Via ``handle_error`` when loading fails.
    """
    model_path = _resolve_model_path(url_or_path, model_md5)

    try:
        logger.info(f"Loading model from: {model_path}")
        # NOTE(review): torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        model.to(device)
    except Exception as e:
        # handle_error never returns (it exits the process).
        handle_error(model_path, model_md5, e)
    model.eval()
    return model


def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
    """Encode an image array with OpenCV into the format named by ``ext``.

    JPEG quality 100 and PNG compression 0 are requested from the encoder.
    """
    data = cv2.imencode(
        f".{ext}",
        image_numpy,
        [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
    )[1]
    return data.tobytes()


def pil_to_bytes(pil_img, ext: str, quality: int = 95, infos=None) -> bytes:
    """Serialize a PIL image to bytes.

    Args:
        pil_img: Image to save.
        ext: Target format name; ``"jpg"`` (any case) is mapped to ``"jpeg"``.
        quality: Quality value forwarded to ``Image.save``.
        infos: Optional mapping of extra keyword arguments for ``Image.save``.
            Entries whose value is ``None`` are dropped. For PNG output a
            ``"parameters"`` entry is additionally written as a text chunk.

    Returns:
        The encoded image bytes.
    """
    # A fresh dict per call: a ``{}`` default would be shared across calls.
    kwargs = {k: v for k, v in (infos or {}).items() if v is not None}
    if ext.lower() == "jpg":
        ext = "jpeg"

    if ext.lower() == "png" and "parameters" in kwargs:
        pnginfo_data = PngImagePlugin.PngInfo()
        pnginfo_data.add_text("parameters", kwargs["parameters"])
        kwargs["pnginfo"] = pnginfo_data

    with io.BytesIO() as output:
        pil_img.save(output, format=ext, quality=quality, **kwargs)
        return output.getvalue()


def pil_to_bytes_single(pil_img, ext: str, quality: int = 95, infos=None) -> bytes:
    """Serialize a PIL image to bytes; same contract as ``pil_to_bytes``."""
    return pil_to_bytes(pil_img, ext, quality=quality, infos=infos)


def _exif_transpose_or_original(image):
    """Apply EXIF orientation if possible, else return ``image`` untouched."""
    try:
        return ImageOps.exif_transpose(image)
    except Exception as err:
        # Best effort: an unreadable EXIF block must not fail the request.
        logger.debug(f"exif_transpose failed, using image as-is: {err}")
        return image


def _pil_to_numpy(image, gray: bool):
    """Convert a PIL image to ``(np_img, alpha_channel)``.

    ``alpha_channel`` is only set for non-gray RGBA input, where it holds the
    last channel and ``np_img`` is the RGB conversion.
    """
    alpha_channel = None
    if gray:
        np_img = np.array(image.convert("L"))
    elif image.mode == "RGBA":
        np_img = np.array(image)
        alpha_channel = np_img[:, :, -1]
        np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
    else:
        np_img = np.array(image.convert("RGB"))
    return np_img, alpha_channel


def load_img(img_bytes, gray: bool = False, return_info: bool = False):
    """Decode image bytes into a numpy array.

    Args:
        img_bytes: Encoded image data.
        gray: Convert to single-channel ("L") when true, otherwise RGB.
        return_info: Also return the PIL ``info`` dict of the decoded image.

    Returns:
        ``(np_img, alpha_channel)`` or ``(np_img, alpha_channel, infos)`` when
        ``return_info`` is true. ``alpha_channel`` is ``None`` unless the image
        is RGBA and ``gray`` is false.
    """
    image = Image.open(io.BytesIO(img_bytes))

    # NOTE(review): infos is captured before exif_transpose here, whereas
    # decode_base64_to_image captures it afterwards. Order kept as-is to
    # preserve what callers receive; confirm which one is intended.
    infos = image.info if return_info else None

    image = _exif_transpose_or_original(image)
    np_img, alpha_channel = _pil_to_numpy(image, gray)

    if return_info:
        return np_img, alpha_channel, infos
    return np_img, alpha_channel


def norm_img(np_img):
    """Convert an HW or HWC image to float32 CHW scaled by 1/255."""
    if len(np_img.shape) == 2:
        np_img = np_img[:, :, np.newaxis]
    np_img = np.transpose(np_img, (2, 0, 1))
    np_img = np_img.astype("float32") / 255
    return np_img


def resize_max_size(
    np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
) -> np.ndarray:
    """Shrink the image so its longer side equals ``size_limit``.

    Images whose longer side is already within the limit are returned
    unchanged. The aspect ratio is kept (rounded to whole pixels).
    """
    h, w = np_img.shape[:2]
    longer = max(h, w)
    if longer <= size_limit:
        return np_img

    ratio = size_limit / longer
    # Clamp to 1 so an extreme aspect ratio cannot round the short side to 0,
    # which cv2.resize rejects.
    new_w = max(1, int(w * ratio + 0.5))
    new_h = max(1, int(h * ratio + 0.5))
    return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)


def pad_img_to_modulo(
    img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None
):
    """Pad an image on the bottom/right so its sides are multiples of ``mod``.

    Args:
        img: [H, W, C] array; a 2-D array gets a trailing channel axis added.
        mod: Both output sides are rounded up to a multiple of this value.
        square: Whether to pad the output to a square (both sides equal to
            the larger one).
        min_size: Optional lower bound for both output sides; must itself be
            a multiple of ``mod``.

    Returns:
        The padded array (symmetric padding).
    """
    if len(img.shape) == 2:
        img = img[:, :, np.newaxis]
    height, width = img.shape[:2]
    out_height = ceil_modulo(height, mod)
    out_width = ceil_modulo(width, mod)

    if min_size is not None:
        assert min_size % mod == 0
        out_width = max(min_size, out_width)
        out_height = max(min_size, out_height)

    if square:
        max_size = max(out_height, out_width)
        out_height = max_size
        out_width = max_size

    return np.pad(
        img,
        ((0, out_height - height), (0, out_width - width), (0, 0)),
        mode="symmetric",
    )


def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
    """Return bounding boxes of the external contours found in ``mask``.

    Args:
        mask: (h, w, 1) 0~255

    Returns:
        One ``[x1, y1, x2, y2]`` integer array per contour, clipped to the
        mask size.
    """
    height, width = mask.shape[:2]
    _, thresh = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    boxes = []
    for cnt in contours:
        x, y, w, h = cv2.boundingRect(cnt)
        box = np.array([x, y, x + w, y + h]).astype(int)

        box[::2] = np.clip(box[::2], 0, width)
        box[1::2] = np.clip(box[1::2], 0, height)
        boxes.append(box)

    return boxes


def only_keep_largest_contour(mask: np.ndarray) -> np.ndarray:
    """Return a mask containing only the largest external contour, filled.

    Args:
        mask: (h, w) 0~255

    Returns:
        A new mask with the largest contour drawn filled with 255, or the
        input ``mask`` itself when no contour with a positive area exists.
    """
    _, thresh = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    areas = [cv2.contourArea(cnt) for cnt in contours]
    if not areas or max(areas) <= 0:
        return mask

    # argmax picks the first maximum, like a strict ">" scan would.
    max_index = int(np.argmax(areas))
    new_mask = np.zeros_like(mask)
    return cv2.drawContours(new_mask, contours, max_index, 255, -1)


def is_mac():
    """Return True when running on macOS (``sys.platform == "darwin"``)."""
    return sys.platform == "darwin"


def _sniff_image_format(img_bytes):
    """Return Pillow's lower-cased format name for the bytes, or ``None``."""
    try:
        with Image.open(io.BytesIO(img_bytes)) as image:
            fmt = image.format
    except Exception:
        return None
    return fmt.lower() if fmt else None


def get_image_ext(img_bytes):
    """Guess the image type of ``img_bytes``; defaults to ``"jpeg"``.

    Uses ``imghdr`` when the interpreter still ships it and falls back to
    Pillow's format detection otherwise.
    """
    if imghdr is not None:
        kind = imghdr.what("", img_bytes)
    else:
        kind = _sniff_image_format(img_bytes)
    return "jpeg" if kind is None else kind


def decode_base64_to_image(
    encoding: str, gray=False
) -> Tuple[np.ndarray, Optional[np.ndarray], Dict]:
    """Decode a base64 string (optionally a data URI) into a numpy image.

    Args:
        encoding: Raw base64 text, or a ``data:image/...`` /
            ``data:application/octet-stream;base64,`` data URI.
        gray: Convert to single-channel ("L") when true, otherwise RGB.

    Returns:
        ``(np_img, alpha_channel, infos)`` where ``alpha_channel`` is ``None``
        unless the image is RGBA and ``gray`` is false, and ``infos`` is the
        PIL ``info`` dict read after EXIF transposition.
    """
    if encoding.startswith(
        ("data:image/", "data:application/octet-stream;base64,")
    ):
        # The payload is everything after the first comma; this also copes
        # with data URIs that carry extra ";key=value" parameters.
        encoding = encoding.split(",", 1)[1]
    image = Image.open(io.BytesIO(base64.b64decode(encoding)))

    image = _exif_transpose_or_original(image)
    # exif_transpose will remove exif rotate info, so image.info must be read
    # after exif_transpose.
    infos = image.info

    np_img, alpha_channel = _pil_to_numpy(image, gray)
    return np_img, alpha_channel, infos


def encode_pil_to_base64(image: Image.Image, quality: int, infos: Dict) -> bytes:
    """Encode a PIL image as PNG and return the base64 of those bytes."""
    img_bytes = pil_to_bytes(
        image,
        "png",
        quality=quality,
        infos=infos,
    )
    return base64.b64encode(img_bytes)


def concat_alpha_channel(rgb_np_img, alpha_channel) -> np.ndarray:
    """Append ``alpha_channel`` to ``rgb_np_img`` as the last channel.

    The alpha channel is resized first when its height/width differ from the
    image. When ``alpha_channel`` is ``None`` the image is returned unchanged.
    """
    if alpha_channel is not None:
        if alpha_channel.shape[:2] != rgb_np_img.shape[:2]:
            alpha_channel = cv2.resize(
                alpha_channel, dsize=(rgb_np_img.shape[1], rgb_np_img.shape[0])
            )
        rgb_np_img = np.concatenate(
            (rgb_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
        )
    return rgb_np_img


def _render_frontend_mask(gray_mask):
    """Colour every pixel of ``gray_mask`` above 128; others stay zero.

    Returns an (h, w, 4) uint8 array after a BGRA->RGBA channel swap.
    """
    # frontend brush color "ffcc00bb"
    res_mask = np.zeros((gray_mask.shape[0], gray_mask.shape[1], 4), dtype=np.uint8)
    res_mask[gray_mask > 128] = [255, 203, 0, int(255 * 0.73)]
    return cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)


def adjust_mask(mask: np.ndarray, kernel_size: int, operate):
    """Binarize a mask, transform it, and render it in the brush colour.

    Args:
        mask: Gray mask; values >= 127 count as set. The input array is not
            modified.
        kernel_size: Radius-like size; the elliptical kernel used for
            expand/shrink is ``2 * kernel_size + 1`` wide.
        operate: ``"reverse"`` inverts the mask, ``"expand"`` dilates it, any
            other value erodes it.

    Returns:
        An (h, w, 4) uint8 array as produced for the frontend.
    """
    # Binarize into a new array so the caller's mask is left untouched.
    mask = np.where(mask >= 127, 255, 0).astype(mask.dtype)

    if operate == "reverse":
        mask = 255 - mask
    else:
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, (2 * kernel_size + 1, 2 * kernel_size + 1)
        )
        morph = cv2.dilate if operate == "expand" else cv2.erode
        mask = morph(mask, kernel, iterations=1)

    return _render_frontend_mask(mask)


def gen_frontend_mask(bgr_or_gray_mask):
    """Dilate a mask with a 9x9 kernel and render it in the brush colour.

    A 3-channel input is converted from BGR to gray first.

    Returns:
        An (h, w, 4) uint8 array as produced for the frontend.
    """
    if len(bgr_or_gray_mask.shape) == 3 and bgr_or_gray_mask.shape[2] != 1:
        bgr_or_gray_mask = cv2.cvtColor(bgr_or_gray_mask, cv2.COLOR_BGR2GRAY)

    # TODO: how to set kernel size?
    kernel_size = 9
    bgr_or_gray_mask = cv2.dilate(
        bgr_or_gray_mask,
        np.ones((kernel_size, kernel_size), np.uint8),
        iterations=1,
    )
    return _render_frontend_mask(bgr_or_gray_mask)