import json
from pathlib import Path
from typing import Dict, Optional

import cv2
import psutil
from PIL import Image
from loguru import logger
from rich.console import Console
from rich.progress import (
    Progress,
    SpinnerColumn,
    TimeElapsedColumn,
    MofNCompleteColumn,
    TextColumn,
    BarColumn,
    TaskProgressColumn,
)

from iopaint.helper import pil_to_bytes
from iopaint.model.utils import torch_gc
from iopaint.model_manager import ModelManager
from iopaint.schema import InpaintRequest


def glob_images(path: Path) -> Dict[str, Path]:
    """Map file stems to the image paths found at *path*.

    Args:
        path: A single file, or a directory to scan (non-recursively).

    Returns:
        ``{stem: path}``. A file is returned as-is, whatever its extension.
        A directory contributes only its png/jpg/jpeg entries (suffix match
        is case-insensitive); when two entries share a stem, the one
        globbed last wins. Any other path (e.g. one that does not exist)
        yields an empty dict so callers can report it as "empty".
    """
    if path.is_file():
        return {path.stem: path}
    if path.is_dir():
        return {
            it.stem: it
            for it in path.glob("*.*")
            if it.suffix.lower() in (".png", ".jpg", ".jpeg")
        }
    # Neither a file nor a directory: return an empty mapping rather than
    # falling through to an implicit None.
    return {}


def batch_inpaint(
    model: str,
    device,
    image: Path,
    mask: Path,
    output: Path,
    config: Optional[Path] = None,
    concat: bool = False,
):
    """Inpaint one image or a folder of images and save the results as PNG.

    Args:
        model: Model name handed to ``ModelManager``.
        device: Device handed to ``ModelManager``.
        image: An image file, or a directory of png/jpg/jpeg images.
        mask: A mask file (then used for every image), or a directory of
            masks that are matched to images by file stem.
        output: Directory receiving one ``<stem>.png`` per processed image.
            It is created (with parents) if it does not exist.
        config: Optional JSON file whose top-level keys are passed as keyword
            arguments to ``InpaintRequest``. A default request is used when
            this is None.
        concat: When True, the saved picture is the input image, the mask and
            the result concatenated horizontally instead of the result alone.

    Raises:
        SystemExit: With code -1 when *image* is a directory but *output* is
            an existing file, or when no images / no masks are found.
    """
    if image.is_dir() and output.is_file():
        logger.error(
            "invalid --output: when image is a directory, output should be a directory"
        )
        raise SystemExit(-1)
    output.mkdir(parents=True, exist_ok=True)

    image_paths = glob_images(image)
    mask_paths = glob_images(mask)
    if not image_paths:
        logger.error("invalid --image: empty image folder")
        raise SystemExit(-1)
    if not mask_paths:
        logger.error("invalid --mask: empty mask folder")
        raise SystemExit(-1)

    if config is None:
        inpaint_request = InpaintRequest()
        logger.info(f"Using default config: {inpaint_request}")
    else:
        with open(config, "r", encoding="utf-8") as f:
            inpaint_request = InpaintRequest(**json.load(f))

    model_manager = ModelManager(name=model, device=device)
    # Fallback mask for images without a same-stem mask; only reached when
    # --mask is a single file (directory masks are matched strictly by stem).
    first_mask = next(iter(mask_paths.values()))
    mask_is_dir = mask.is_dir()

    console = Console()

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        console=console,
        transient=False,
    ) as progress:
        task = progress.add_task("Batch processing...", total=len(image_paths))
        for stem, image_p in image_paths.items():
            if stem not in mask_paths and mask_is_dir:
                progress.log(f"mask for {image_p} not found")
                progress.update(task, advance=1)
                continue
            mask_p = mask_paths.get(stem, first_mask)

            # Only the metadata dict is needed from Pillow; close the file
            # handle right away instead of leaking one per image.
            with Image.open(image_p) as pil_img:
                infos = pil_img.info

            img = cv2.imread(str(image_p))
            mask_img = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
            # cv2.imread signals an unreadable file by returning None rather
            # than raising; skip such inputs instead of crashing mid-batch.
            if img is None or mask_img is None:
                unreadable = image_p if img is None else mask_p
                progress.log(f"failed to read {unreadable}, skipped")
                progress.update(task, advance=1)
                continue
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)

            if mask_img.shape[:2] != img.shape[:2]:
                progress.log(
                    f"resize mask {mask_p.name} to image {image_p.name} size: {img.shape[:2]}"
                )
                mask_img = cv2.resize(
                    mask_img,
                    (img.shape[1], img.shape[0]),
                    interpolation=cv2.INTER_NEAREST,
                )
            # Binarize the mask: values >= 127 become 255, everything else 0.
            mask_img[mask_img >= 127] = 255
            mask_img[mask_img < 127] = 0

            # The model output is treated as BGR here (per the original
            # "bgr" note) and converted to RGB before saving.
            inpaint_result = model_manager(img, mask_img, inpaint_request)
            inpaint_result = cv2.cvtColor(inpaint_result, cv2.COLOR_BGR2RGB)
            if concat:
                mask_img = cv2.cvtColor(mask_img, cv2.COLOR_GRAY2RGB)
                inpaint_result = cv2.hconcat([img, mask_img, inpaint_result])

            img_bytes = pil_to_bytes(Image.fromarray(inpaint_result), "png", 100, infos)
            save_p = output / f"{stem}.png"
            with open(save_p, "wb") as fw:
                fw.write(img_bytes)

            progress.update(task, advance=1)
            # NOTE(review): presumably frees cached torch memory between
            # images -- confirm against iopaint.model.utils.torch_gc.
            torch_gc()

            # Debug aid (disabled): report per-image process memory usage.
            # pid = psutil.Process().pid
            # memory_info = psutil.Process(pid).memory_info()
            # memory_in_mb = memory_info.rss / (1024 * 1024)
            # print(f"image shape: {img.shape}, process memory usage: {memory_in_mb}MB")