|
import base64 |
|
import io |
|
import json |
|
from pathlib import Path |
|
from typing import Dict, Optional |
|
|
|
import cv2 |
|
import psutil |
|
from PIL import Image |
|
from loguru import logger |
|
from rich.console import Console |
|
from rich.progress import ( |
|
Progress, |
|
SpinnerColumn, |
|
TimeElapsedColumn, |
|
MofNCompleteColumn, |
|
TextColumn, |
|
BarColumn, |
|
TaskProgressColumn, |
|
) |
|
|
|
from iopaint.helper import pil_to_bytes_single |
|
from iopaint.model.utils import torch_gc |
|
from iopaint.model_manager import ModelManager |
|
from iopaint.schema import InpaintRequest |
|
import numpy as np |
|
|
|
|
|
def glob_images(path: Path) -> Dict[str, Path]:
    """Collect PNG/JPEG image files located at ``path``, keyed by file stem.

    Args:
        path: A single file, or a directory that is scanned non-recursively.

    Returns:
        A mapping from file stem to path.

        * If ``path`` is a file, it is returned as the only entry; its
          extension is not checked.
        * If ``path`` is a directory, its direct children whose suffix is
          ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive) are returned.
          When two files share a stem, the one visited last wins.
        * Otherwise (e.g. the path does not exist) the mapping is empty.
    """
    if path.is_file():
        return {path.stem: path}

    if not path.is_dir():
        # Return an empty mapping rather than None so callers can always
        # treat the result as a dict (len(), iteration, key lookup).
        return {}

    image_suffixes = {".png", ".jpg", ".jpeg"}
    return {
        entry.stem: entry
        for entry in path.glob("*.*")
        if entry.suffix.lower() in image_suffixes
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def batch_inpaint(
    model: str,
    device,
    input_base64: str,
    mask_base64: str,
    config_base64: Optional[str] = None,
    concat: bool = False,
):
    """Inpaint one base64-encoded image and return the result as base64 PNG.

    Args:
        model: Model name handed to ``ModelManager``.
        device: Device handed to ``ModelManager``.
        input_base64: Base64 text of an encoded image file. It is decoded
            with OpenCV using ``cv2.IMREAD_COLOR``.
        mask_base64: Base64 text of an encoded mask image. It is decoded as
            grayscale, resized to the input's width/height with
            nearest-neighbour interpolation when the sizes differ, and then
            binarised (values >= 127 become 255, all others 0).
        config_base64: Optional base64 text of a JSON object whose members
            are passed as keyword arguments to ``InpaintRequest``. When
            ``None``, ``InpaintRequest()`` is used.
        concat: When true, the output is the input, the mask and the result
            concatenated horizontally instead of the result alone.

    Returns:
        The PNG-encoded output image as base64 text.

    Raises:
        ValueError: If the input or mask payload is empty or cannot be
            decoded as an image.
    """
    if config_base64 is None:
        inpaint_request = InpaintRequest()
    else:
        config_json = base64.b64decode(config_base64)
        inpaint_request = InpaintRequest(**json.loads(config_json))

    # Decode and validate both payloads before constructing the model so that
    # bad input fails fast with a clear error.
    # NOTE(review): cv2.imdecode yields BGR channel order while
    # Image.fromarray below reads 3-channel arrays as RGB, and no conversion
    # happens in this function — confirm which channel order ModelManager
    # expects and returns.
    input_image = _decode_base64_image(input_base64, cv2.IMREAD_COLOR, "input image")
    mask_image = _decode_base64_image(mask_base64, cv2.IMREAD_GRAYSCALE, "mask image")

    model_manager = ModelManager(name=model, device=device)

    if mask_image.shape[:2] != input_image.shape[:2]:
        # cv2.resize takes (width, height), i.e. the reverse of shape[:2].
        mask_image = cv2.resize(
            mask_image,
            (input_image.shape[1], input_image.shape[0]),
            interpolation=cv2.INTER_NEAREST,
        )

    # Binarise the mask to exactly 0 / 255.
    mask_image[mask_image >= 127] = 255
    mask_image[mask_image < 127] = 0

    inpaint_result = model_manager(input_image, mask_image, inpaint_request)

    if concat:
        # hconcat needs matching channel counts, so expand the mask to 3 channels.
        mask_image = cv2.cvtColor(mask_image, cv2.COLOR_GRAY2RGB)
        inpaint_result = cv2.hconcat([input_image, mask_image, inpaint_result])

    with io.BytesIO() as output_buffer:
        Image.fromarray(inpaint_result).save(output_buffer, format="PNG")
        return base64.b64encode(output_buffer.getvalue()).decode("utf-8")


def _decode_base64_image(payload: str, flags: int, label: str):
    """Decode base64 text of an encoded image file into an OpenCV array.

    Raises ``ValueError`` naming ``label`` when the payload is empty or when
    ``cv2.imdecode`` cannot decode it (it reports that by returning ``None``).
    """
    raw = base64.b64decode(payload)
    image = cv2.imdecode(np.frombuffer(raw, np.uint8), flags) if raw else None
    if image is None:
        raise ValueError(f"{label} could not be decoded from the base64 payload")
    return image
|
|
|
|
|
def batch_inpaint_cv2(
    model: str,
    device,
    input_base: np.ndarray,
    mask_base: np.ndarray,
    config_base64: Optional[str] = None,
    concat: bool = False,
):
    """Inpaint an already-decoded image array and return the result array.

    Args:
        model: Model name handed to ``ModelManager``.
        device: Device handed to ``ModelManager``.
        input_base: Image array; it is passed to the model unchanged.
            Presumably H x W x 3 uint8 — TODO confirm the channel order the
            model expects.
        mask_base: Mask array. It is resized to the input's width/height
            with nearest-neighbour interpolation when the sizes differ, then
            binarised (values >= 127 become 255, all others 0). It must be
            single-channel when ``concat`` is set, because it is converted
            with ``cv2.COLOR_GRAY2RGB``. The caller's array is not modified.
        config_base64: Optional base64 text of a JSON object whose members
            are passed as keyword arguments to ``InpaintRequest``. When
            ``None``, ``InpaintRequest()`` is used.
        concat: When true, the output is the input, the mask and the result
            concatenated horizontally instead of the result alone.

    Returns:
        The value returned by the model manager, or the horizontal
        concatenation described above when ``concat`` is true.
    """
    if config_base64 is None:
        inpaint_request = InpaintRequest()
    else:
        config_json = base64.b64decode(config_base64)
        inpaint_request = InpaintRequest(**json.loads(config_json))

    model_manager = ModelManager(name=model, device=device)

    input_image = input_base

    if mask_base.shape[:2] != input_image.shape[:2]:
        # cv2.resize takes (width, height), i.e. the reverse of shape[:2],
        # and returns a new array.
        mask_image = cv2.resize(
            mask_base,
            (input_image.shape[1], input_image.shape[0]),
            interpolation=cv2.INTER_NEAREST,
        )
    else:
        # Threshold a copy so the caller's mask array is left untouched.
        mask_image = mask_base.copy()

    # Binarise the mask to exactly 0 / 255.
    mask_image[mask_image >= 127] = 255
    mask_image[mask_image < 127] = 0

    inpaint_result = model_manager(input_image, mask_image, inpaint_request)

    if concat:
        # hconcat needs matching channel counts, so expand the mask to 3 channels.
        mask_image = cv2.cvtColor(mask_image, cv2.COLOR_GRAY2RGB)
        inpaint_result = cv2.hconcat([input_image, mask_image, inpaint_result])

    return inpaint_result