Inpaint4Drag / utils / ui_utils.py
LuJingyi-John
Add title and project link to interface
6fce8cc
import spaces
import os
import pickle
from time import perf_counter
import tempfile
import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from diffusers import AutoPipelineForInpainting, AutoencoderTiny, LCMScheduler
from utils.drag import bi_warp
# Names exported by ``from utils.ui_utils import *``.
__all__ = [
    'clear_all',
    'resize',
    'visualize_user_drag',
    'preview_out_image',
    'inpaint',
    'add_point',
    'undo_point',
    'clear_point',
]
# Cached inpainting pipeline. It starts as None and is filled in on first use
# by get_pipeline(), so importing this module does not load any model.
pipe = None
# UI functions
def clear_all(length):
    """Produce the values that reset the UI.

    Returns a 6-tuple: the same empty ``length`` x ``length`` image update
    repeated for the three image slots, then a new empty list (presumably the
    drag points), the integer 5 (presumably the default inpaint kernel size;
    confirm against the Gradio outputs wiring) and None.
    """
    blank = gr.Image(value=None, height=length, width=length)
    return (blank, blank, blank, [], 5, None)
def resize(canvas, gen_length, canvas_length):
    """Resize the canvas image for generation while keeping its aspect ratio.

    The longer edge of the resized image becomes ``gen_length``. The shorter
    edge is rounded to a multiple of 8, because later steps require both
    edges to be divisible by 8. The displayed component is sized so that its
    longer edge is ``canvas_length``.

    Args:
        canvas: Canvas value accepted by ``process_canvas`` (dict, array or None).
        gen_length: Longer edge, in pixels, of the image used for generation.
        canvas_length: Longer edge, in pixels, of the displayed component.

    Returns:
        A 3-tuple holding the same ``gr.Image`` update three times. Its value
        is None (at a square ``canvas_length`` size) when there is no image.
    """
    def _empty():
        return (gr.Image(value=None, width=canvas_length, height=canvas_length),) * 3

    # ``is None`` instead of ``not canvas``: truth-testing a numpy array raises.
    if canvas is None:
        return _empty()
    image, _ = process_canvas(canvas)
    if image is None or image.ndim < 2 or image.size == 0:
        return _empty()

    height, width = image.shape[:2]
    aspect_ratio = width / height
    if aspect_ratio >= 1:
        # Landscape or square: width is the longer edge.
        new_dims = (gen_length, max(8, round(gen_length / aspect_ratio / 8) * 8))
        canvas_dims = (canvas_length, max(1, round(canvas_length / aspect_ratio)))
    else:
        new_dims = (max(8, round(gen_length * aspect_ratio / 8) * 8), gen_length)
        canvas_dims = (max(1, round(canvas_length * aspect_ratio)), canvas_length)

    # cv2.resize takes the target size as (width, height).
    resized = cv2.resize(image, new_dims)
    return (gr.Image(value=resized, width=canvas_dims[0], height=canvas_dims[1]),) * 3
def _to_rgb(image):
    """Return *image* as three channels, dropping alpha and replicating gray.

    Shapes other than (H, W), (H, W, 1) or (H, W, 4) are returned without
    channel changes, so callers can reject them.
    """
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[:, :, 0]
    if image.ndim == 2:
        image = np.repeat(image[:, :, np.newaxis], 3, axis=2)
    elif image.ndim == 3 and image.shape[2] == 4:
        image = image[:, :, :3]
    # Slicing off alpha leaves a strided view. Downstream OpenCV calls that
    # write into the image need contiguous memory, so normalize the layout.
    return np.ascontiguousarray(image) if image.ndim >= 2 else image


def _binary_mask(layer):
    """Return a uint8 0/1 mask: nonzero pixels of a 2-D array, or of channel 0 of a 3-D one."""
    plane = layer[:, :, 0] if layer.ndim == 3 else layer
    return (plane > 0).astype(np.uint8)


def process_canvas(canvas):
    """Split a Gradio canvas value into an RGB image and a binary mask.

    Accepted inputs:
      * ImageEditor dicts with ``background`` and ``layers``. The mask is the
        union of all ndarray layers; a 3-D layer contributes its first
        channel. NOTE(review): for RGBA layers the alpha channel may be a more
        reliable stroke indicator; confirm against the brush settings.
      * Legacy sketch dicts with ``image`` and ``mask``. The mask comes from
        ``mask`` (first channel if 3-D, used directly if 2-D). A None mask
        gives all zeros.
      * Anything ``np.array`` accepts (ndarray, PIL image). The mask is all zeros.

    The image is always a fresh array, with alpha removed and single-channel
    data expanded to three channels.

    Returns:
        ``(image, mask)``, where ``mask`` is uint8 with values 0/1. It is built
        at the image's height and width unless mask data supplies its own
        shape. Returns ``(None, None)`` when there is no image or the dict
        layout is not recognized.
    """
    if canvas is None:
        return None, None

    if not isinstance(canvas, dict):
        # Direct array-like input; np.array always copies.
        image = _to_rgb(np.array(canvas))
        return image, np.zeros(image.shape[:2], dtype=np.uint8)

    if 'background' in canvas and 'layers' in canvas:
        # Current ImageEditor format.
        if canvas["background"] is None:
            return None, None
        image = _to_rgb(np.array(canvas["background"]))
        mask = np.zeros(image.shape[:2], dtype=np.uint8)
        for layer in canvas["layers"] or []:
            if isinstance(layer, np.ndarray) and layer.ndim >= 2:
                mask = np.logical_or(mask, _binary_mask(layer)).astype(np.uint8)
        return image, mask

    if 'image' in canvas and 'mask' in canvas:
        # Legacy sketch format.
        if canvas["image"] is None:
            return None, None
        image = _to_rgb(np.array(canvas["image"]))
        raw_mask = canvas["mask"]
        if raw_mask is None:
            mask = np.zeros(image.shape[:2], dtype=np.uint8)
        else:
            mask = _binary_mask(np.asarray(raw_mask))
        return image, mask

    # Unrecognized dict layout.
    return None, None
# Point manipulation functions
def add_point(canvas, points, inpaint_ks, evt: gr.SelectData):
    """Append the clicked coordinate to ``points`` and redraw the overlay.

    ``evt.index`` is the position reported by Gradio's select event.
    ``inpaint_ks`` is not used by this function. When there is no canvas,
    ``points`` is left untouched and None is returned.
    """
    if canvas is not None:
        points.append(evt.index)
        return visualize_user_drag(canvas, points)
    return None
def undo_point(canvas, points, inpaint_ks):
    """Remove the newest entry from ``points`` (if any) and redraw the overlay.

    ``inpaint_ks`` is not used by this function. Returns None, without
    touching ``points``, when there is no canvas.
    """
    if canvas is None:
        return None
    # Slice deletion is a no-op on an empty list, so no length check is needed.
    del points[-1:]
    return visualize_user_drag(canvas, points)
def clear_point(canvas, points, inpaint_ks):
    """Empty ``points`` in place and redraw the overlay without any drags.

    ``inpaint_ks`` is not used by this function. Returns None, without
    touching ``points``, when there is no canvas.
    """
    if canvas is None:
        return None
    del points[:]
    return visualize_user_drag(canvas, points)
# Visualization tools
def visualize_user_drag(canvas, points):
    """Render the user's mask and drag arrows on top of the canvas image.

    Pixels where the mask equals 1 are blended 30% toward red, and the
    mask's external contours are outlined in white. ``points`` is read in
    pairs:
      * each point in an odd position (1st, 3rd, ...) is drawn as a red dot;
      * each point in an even position is drawn as a blue dot, with a white
        arrow from the point just before it.

    Args:
        canvas: Canvas value accepted by ``process_canvas``.
        points: Sequence of (x, y) pixel coordinates.

    Returns:
        The annotated uint8 RGB image, or None if there is no canvas/image or
        the image is not three-channel.
    """
    if canvas is None:
        return None
    image, mask = process_canvas(canvas)
    if image is None:
        return None

    if image.dtype != np.uint8:
        # Treat data in [0, 1] as normalized; anything else is cast directly.
        if image.max() <= 1.0:
            image = (image * 255).astype(np.uint8)
        else:
            image = image.astype(np.uint8)
    if image.ndim != 3 or image.shape[2] != 3:
        return None

    highlighted = image.copy()
    selected = mask == 1
    if selected.any():
        highlighted[selected] = [255, 0, 0]  # red tint for masked pixels
    image = cv2.addWeighted(highlighted, 0.3, image, 0.7, 0)

    if (mask > 0).any():
        outlines, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(image, outlines, -1, (255, 255, 255), 2)

    for i, point in enumerate(points):
        if i % 2 == 0:
            cv2.circle(image, tuple(point), 10, (255, 0, 0), -1)  # handle (start)
        else:
            cv2.circle(image, tuple(point), 10, (0, 0, 255), -1)  # target (end)
            cv2.arrowedLine(image, points[i - 1], point, (255, 255, 255), 4, tipLength=0.5)
    return image
def preview_out_image(canvas, points, inpaint_ks, gen_length=512):
    """Warp the canvas along the user's drags and build the inpainting mask.

    Args:
        canvas: Canvas value accepted by ``process_canvas``.
        points: Alternating handle/target (x, y) points, passed to ``bi_warp``.
        inpaint_ks: Passed through to ``bi_warp``.
        gen_length: Expected longer edge of the canvas after ``resize``.
            Defaults to 512, the value this check previously hard-coded.

    Returns:
        ``(preview, mask)``:
          * ``preview`` is the warped image, with a 10-px grid drawn over
            regions still to be inpainted;
          * ``mask`` is a uint8 0/255 array.
        On success both are set. With no canvas, both are None. In every other
        early exit (too few points, wrong size, non-RGB image, or a failed
        warp, which also shows a Gradio warning) the un-warped image is
        returned with a None mask.
    """
    if canvas is None:
        return None, None
    image, mask = process_canvas(canvas)
    if image is None:
        return None, None

    if image.dtype != np.uint8:
        image = (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.astype(np.uint8)
    if image.ndim != 3 or image.shape[2] != 3:
        return image, None
    if len(points) < 2:
        return image, None

    # The warp expects resize() output: both edges divisible by 8 and the
    # longer edge equal to gen_length.
    dims = (*mask.shape[:2], *image.shape[:2])
    shapes_valid = all(d % 8 == 0 for d in dims)
    size_valid = max(image.shape[:2]) == gen_length and max(mask.shape[:2]) == gen_length
    if not (shapes_valid and size_valid):
        gr.Warning('Click Resize Image Button first.')
        return image, None

    try:
        handle_pts, target_pts, inpaint_mask = bi_warp(mask, points, inpaint_ks)
        # Write into a copy so a failure below still returns the untouched image.
        warped = image.copy()
        warped[target_pts[:, 1], target_pts[:, 0]] = image[handle_pts[:, 1], handle_pts[:, 0]]

        # White background with black lines every 10 px, shown wherever
        # inpainting is still required.
        grid = np.full(mask.shape, 255, dtype=np.uint8)
        grid[::10] = 0
        grid[:, ::10] = 0
        preview = np.where(inpaint_mask[..., np.newaxis] == 1, grid[..., np.newaxis], warped)
        return preview, (inpaint_mask * 255).astype(np.uint8)
    except Exception as e:
        # UI boundary: surface the error to the user instead of crashing.
        gr.Warning(f"Preview failed: {str(e)}")
        return image, None
# Inpaint tools
@spaces.GPU
def setup_pipeline(device='cuda', model_version='v1-5'):
    """Build an LCM-LoRA Stable Diffusion inpainting pipeline.

    Steps:
      1. Load the base inpainting checkpoint for ``model_version``.
      2. Swap in ``LCMScheduler`` and fuse the matching LCM-LoRA weights.
      3. Replace the VAE with the corresponding ``AutoencoderTiny``.
      4. Move the pipeline to ``device``.
      5. Cache the empty-prompt embeddings on the pipeline.

    Args:
        device: Requested torch device. Replaced by ``'cpu'`` when CUDA is
            unavailable.
        model_version: ``'v1-5'`` or ``'xl'``.

    Returns:
        The pipeline, with ``cached_prompt_embeds`` attached. For ``'xl'``,
        ``cached_pooled_prompt_embeds`` is attached as well.

    Raises:
        KeyError: If ``model_version`` is not a known configuration.
    """
    MODEL_CONFIGS = {
        'v1-5': ('runwayml/stable-diffusion-inpainting', 'latent-consistency/lcm-lora-sdv1-5', 'madebyollin/taesd'),
        'xl': ('diffusers/stable-diffusion-xl-1.0-inpainting-0.1', 'latent-consistency/lcm-lora-sdxl', 'madebyollin/taesdxl'),
    }
    model_id, lora_id, vae_id = MODEL_CONFIGS[model_version]

    if not torch.cuda.is_available():
        device = 'cpu'
    # Pick precision from the device actually used, not from whether CUDA
    # merely exists. An explicit device='cpu' on a GPU machine must still
    # load float32, because many CPU kernels lack float16 support.
    on_gpu = torch.device(device).type == 'cuda'
    torch_dtype = torch.float16 if on_gpu else torch.float32
    variant = "fp16" if on_gpu else None

    gr.Info('Loading inpainting pipeline...')
    pipe = AutoPipelineForInpainting.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        variant=variant,
        safety_checker=None,
    )
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    pipe.load_lora_weights(lora_id)
    pipe.fuse_lora()
    pipe.vae = AutoencoderTiny.from_pretrained(vae_id, torch_dtype=torch_dtype)
    pipe = pipe.to(device)

    # Encode the empty prompt once so every inpaint call can reuse it.
    if model_version == 'v1-5':
        pipe.cached_prompt_embeds = pipe.encode_prompt(
            '', device=device, num_images_per_prompt=1,
            do_classifier_free_guidance=False)[0]
    else:
        # SDXL's encode_prompt returns a 4-tuple. Entries 0 and 2 are taken
        # as the prompt and pooled embeddings (per the diffusers SDXL API).
        pipe.cached_prompt_embeds, pipe.cached_pooled_prompt_embeds = pipe.encode_prompt(
            '', device=device, num_images_per_prompt=1,
            do_classifier_free_guidance=False)[0::2]
    return pipe
def get_pipeline():
    """Return the shared inpainting pipeline, building it on the first call.

    The pipeline is stored in the module-level ``pipe``. setup_pipeline()
    already attaches ``cached_prompt_embeds`` for the empty prompt on the
    requested device, so the prompt is not encoded a second time here.
    """
    global pipe
    if pipe is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        pipe = setup_pipeline(device=device, model_version='v1-5')
    return pipe
@spaces.GPU
def inpaint(image, inpaint_mask):
    """Regenerate the masked pixels of ``image`` with the shared pipeline.

    Args:
        image: RGB image array, e.g. the warped preview. An RGBA image has its
            alpha channel dropped.
        inpaint_mask: Mask array, presumably uint8 0/255 as produced by
            ``preview_out_image``. Pixels >= 128 are regenerated. For a
            3-channel mask, only the first channel is used.

    Returns:
        None if ``image`` is None, and ``image`` unchanged if
        ``inpaint_mask`` is None. Otherwise a uint8 image at the mask's size,
        rounded to multiples of 8 when needed. Masked pixels come from the
        diffusion output and all other pixels from the input. If the mask
        selects nothing, or the pipeline raises (a Gradio warning is shown),
        the (possibly resized) input is returned instead.
    """
    if image is None:
        return None
    if inpaint_mask is None:
        return image

    image = np.asarray(image)
    inpaint_mask = np.asarray(inpaint_mask)
    if image.ndim == 3 and image.shape[2] == 4:
        image = np.ascontiguousarray(image[:, :, :3])
    if inpaint_mask.ndim == 3:
        # An RGB-encoded mask would otherwise break the broadcast below.
        inpaint_mask = np.ascontiguousarray(inpaint_mask[:, :, 0])

    image_pil = Image.fromarray(image)
    inpaint_mask_pil = Image.fromarray(inpaint_mask)
    width, height = inpaint_mask_pil.size
    if width % 8 != 0 or height % 8 != 0:
        # The pipeline needs both edges divisible by 8; never round to 0.
        width = max(8, round(width / 8) * 8)
        height = max(8, round(height / 8) * 8)
        image_pil = image_pil.resize((width, height))
        image = np.array(image_pil)
        inpaint_mask_pil = inpaint_mask_pil.resize((width, height), Image.NEAREST)
        inpaint_mask = np.array(inpaint_mask_pil)

    # Threshold at half intensity so the composite matches the region the
    # pipeline regenerates. This assumes the pipeline binarizes its mask at
    # 0.5 (diffusers' default for inpainting); confirm for the installed version.
    region = (inpaint_mask >= 128)[..., np.newaxis]
    if not region.any():
        # Nothing to regenerate; skip loading and running the model.
        return image

    pipe = get_pipeline()
    pipe_id = 'xl' if 'xl' in pipe.config._name_or_path else 'v1-5'
    inpaint_strength = 0.99 if pipe_id == 'xl' else 1.0

    common_params = {
        'image': image_pil,
        'mask_image': inpaint_mask_pil,
        'height': height,
        'width': width,
        'guidance_scale': 1.0,
        'num_inference_steps': 8,
        'strength': inpaint_strength,
        'output_type': 'np',
    }
    try:
        embeds = {'prompt_embeds': pipe.cached_prompt_embeds}
        if pipe_id == 'xl':
            embeds['pooled_prompt_embeds'] = pipe.cached_pooled_prompt_embeds
        inpainted = pipe(**embeds, **common_params).images[0]
    except Exception as e:
        # UI boundary: report the failure and fall back to the input.
        gr.Warning(f"Inpainting failed: {str(e)}")
        return image

    # output_type='np' gives floats in [0, 1] (hence the * 255). Round rather
    # than truncate when converting to uint8.
    generated = (inpainted * 255).round().astype(np.uint8)
    return np.where(region, generated, image)