# Spaces: Running on Zero
import spaces | |
import os | |
import pickle | |
from time import perf_counter | |
import tempfile | |
import cv2 | |
import gradio as gr | |
import numpy as np | |
import torch | |
from PIL import Image | |
from diffusers import AutoPipelineForInpainting, AutoencoderTiny, LCMScheduler | |
from utils.drag import bi_warp | |
# Names exported by `from <this module> import *`.
__all__ = [
    'clear_all',
    'resize',
    'visualize_user_drag',
    'preview_out_image',
    'inpaint',
    'add_point',
    'undo_point',
    'clear_point',
]

# Shared inpainting pipeline; stays None until get_pipeline() first builds it.
pipe = None
# UI functions | |
def clear_all(length):
    """Reset the UI to its empty state.

    Args:
        length: Side length in pixels applied to both the height and the width
            of each emptied image component.

    Returns:
        A 6-tuple: the same empty square ``gr.Image`` update three times, then
        an empty list, the integer ``5`` and ``None``.
    """
    empty_slot = gr.Image(value=None, height=length, width=length)
    return (empty_slot, empty_slot, empty_slot, [], 5, None)
def resize(canvas, gen_length, canvas_length):
    """Resize the canvas image so its longer edge is ``gen_length`` pixels.

    The shorter edge keeps the aspect ratio and is rounded to a multiple of 8
    (never below 8). The display size of the returned components is derived
    from ``canvas_length`` the same way, without the multiple-of-8 rounding.

    Args:
        canvas: Canvas value accepted by ``process_canvas`` (dict, ndarray,
            or None).
        gen_length: Target length in pixels of the longer image edge.
        canvas_length: Display length in pixels of the longer component edge.

    Returns:
        Three identical ``gr.Image`` updates holding the resized image, or
        three empty square components when the canvas carries no image.
    """
    def _empty():
        return (gr.Image(value=None, width=canvas_length, height=canvas_length),) * 3

    # `not canvas` raises ValueError ("truth value ... is ambiguous") for an
    # ndarray canvas, so only None is rejected here; process_canvas reports
    # every other image-less input by returning a None image.
    if canvas is None:
        return _empty()
    image, _ = process_canvas(canvas)
    if image is None:
        return _empty()

    height, width = image.shape[:2]
    aspect_ratio = width / height
    if aspect_ratio >= 1:  # landscape or square: width is the longer edge
        new_dims = (gen_length, max(8, round(gen_length / aspect_ratio / 8) * 8))
        canvas_dims = (canvas_length, max(1, round(canvas_length / aspect_ratio)))
    else:
        new_dims = (max(8, round(gen_length * aspect_ratio / 8) * 8), gen_length)
        canvas_dims = (max(1, round(canvas_length * aspect_ratio)), canvas_length)

    # cv2.resize expects the target size as (width, height).
    resized = cv2.resize(image, new_dims)
    return (gr.Image(value=resized, width=canvas_dims[0], height=canvas_dims[1]),) * 3
def _to_rgb(image):
    """Return ``image`` as a C-contiguous 3-channel array when it is 2-D, 1-, 3- or 4-channel."""
    if image.ndim == 3 and image.shape[2] == 4:
        image = image[:, :, :3]  # drop the alpha channel
    elif image.ndim == 3 and image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)
    elif image.ndim == 2:
        # Same result as cv2.COLOR_GRAY2RGB: the gray plane copied into R, G, B.
        image = np.repeat(image[:, :, np.newaxis], 3, axis=2)
    # Slicing off alpha leaves a strided view; OpenCV drawing functions
    # (cv2.circle, cv2.drawContours, ...) refuse non-contiguous buffers.
    return np.ascontiguousarray(image)


def _layer_mask(layer):
    """Return an (H, W) uint8 0/1 mask of pixels where any channel of ``layer`` is non-zero."""
    if layer.ndim == 3:
        # Any channel (including alpha) counts, so strokes of any brush colour
        # are detected, not only those with a non-zero red component.
        return np.any(layer > 0, axis=2).astype(np.uint8)
    return (layer > 0).astype(np.uint8)


def _sketch_mask(mask):
    """Return an (H, W) uint8 0/1 mask from a legacy sketch mask (2-D, or first channel of 3-D)."""
    mask = np.asarray(mask)
    plane = mask[:, :, 0] if mask.ndim == 3 else mask
    return (plane > 0).astype(np.uint8)


def process_canvas(canvas):
    """Extract the RGB image and a binary mask from a Gradio canvas value.

    Accepted inputs:
      * ``None`` -> ``(None, None)``.
      * dict with ``"background"`` and ``"layers"``: the mask is the union of
        the non-zero pixels of every ndarray layer.
      * dict with ``"image"`` and ``"mask"``: the mask is the non-zero pixels of
        the mask's first channel (or of the mask itself when it is 2-D); a
        ``None`` mask gives an all-zero mask.
      * anything else that ``np.array`` accepts (ndarray, PIL image): the mask
        is all zeros.
    A dict with neither key pair, or whose image entry is ``None``, yields
    ``(None, None)``.

    Returns:
        ``(image, mask)``: ``image`` is a fresh, C-contiguous copy (3-channel
        for 2-D, 1-, 3- or 4-channel input) that never aliases the canvas data;
        ``mask`` is a 2-D uint8 array of 0/1 values.
    """
    if canvas is None:
        return None, None

    if not isinstance(canvas, dict):
        # np.array copies, so callers may modify the result freely.
        image = _to_rgb(np.array(canvas))
        return image, np.zeros(image.shape[:2], dtype=np.uint8)

    if 'background' in canvas and 'layers' in canvas:
        if canvas['background'] is None:
            return None, None
        image = _to_rgb(np.array(canvas['background']))
        mask = np.zeros(image.shape[:2], dtype=np.uint8)
        for layer in canvas['layers'] or []:
            if isinstance(layer, np.ndarray) and layer.ndim >= 2:
                mask |= _layer_mask(layer)
        return image, mask

    if 'image' in canvas and 'mask' in canvas:
        if canvas['image'] is None:
            return None, None
        image = _to_rgb(np.array(canvas['image']))
        if canvas['mask'] is None:
            mask = np.zeros(image.shape[:2], dtype=np.uint8)
        else:
            mask = _sketch_mask(canvas['mask'])
        return image, mask

    return None, None
# Point manipulation functions | |
def add_point(canvas, points, inpaint_ks, evt: gr.SelectData):
    """Record the clicked position as a new control point and redraw.

    Args:
        canvas: Current canvas value; when ``None`` nothing is recorded.
        points: Mutable list of control points; appended to in place.
        inpaint_ks: Not used here (the sibling point callbacks take it too).
        evt: Gradio select event; its ``index`` is stored as the new point.

    Returns:
        The annotated image from ``visualize_user_drag``, or ``None`` when
        there is no canvas.
    """
    if canvas is not None:
        points.append(evt.index)
        return visualize_user_drag(canvas, points)
    return None
def undo_point(canvas, points, inpaint_ks):
    """Drop the most recently added control point and redraw.

    Args:
        canvas: Current canvas value; when ``None`` the points are left as-is.
        points: Mutable list of control points; shortened in place.
        inpaint_ks: Not used here.

    Returns:
        The annotated image from ``visualize_user_drag``, or ``None`` when
        there is no canvas.
    """
    if canvas is None:
        return None
    if points:
        del points[-1]
    return visualize_user_drag(canvas, points)
def clear_point(canvas, points, inpaint_ks):
    """Remove every control point and redraw.

    Args:
        canvas: Current canvas value; when ``None`` the points are left as-is.
        points: Mutable list of control points; emptied in place.
        inpaint_ks: Not used here.

    Returns:
        The annotated image from ``visualize_user_drag``, or ``None`` when
        there is no canvas.
    """
    if canvas is None:
        return None
    del points[:]
    return visualize_user_drag(canvas, points)
# Visualization tools | |
def visualize_user_drag(canvas, points):
    """Draw the mask overlay and the drag control points on the canvas image.

    Pixels where the mask equals 1 are tinted by blending in (255, 0, 0) at
    30%, and the outline of the non-zero mask is traced in white. Points are
    read in consecutive pairs: each odd-numbered point (1st, 3rd, ...) becomes
    a filled (255, 0, 0) circle, and each even-numbered point a filled
    (0, 0, 255) circle with a white arrow from the point before it.

    Args:
        canvas: Canvas value accepted by ``process_canvas``.
        points: Sequence of ``(x, y)`` pixel coordinates.

    Returns:
        The annotated 3-channel uint8 image, or ``None`` when the canvas has
        no image or the image is not 3-channel.
    """
    if canvas is None:
        return None
    image, mask = process_canvas(canvas)
    if image is None:
        return None

    # Normalise to uint8: images with max <= 1.0 are rescaled, others cast.
    if image.dtype != np.uint8:
        image = (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.astype(np.uint8)
    if image.ndim != 3 or image.shape[2] != 3:
        return None
    # OpenCV draws in place and rejects strided views (e.g. an RGBA array with
    # its alpha channel sliced off), so guarantee a dense buffer up front.
    image = np.ascontiguousarray(image)

    highlighted = mask == 1
    if highlighted.any():
        tinted = image.copy()
        tinted[highlighted] = (255, 0, 0)
        image = cv2.addWeighted(tinted, 0.3, image, 0.7, 0)

    if np.any(mask > 0):
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(image, contours, -1, (255, 255, 255), 2)

    prev_point = None
    for idx, point in enumerate(points, 1):
        # Use the same plain int tuple for every OpenCV call (circle and
        # arrowedLine alike) instead of passing raw list coordinates.
        xy = (int(point[0]), int(point[1]))
        if idx % 2 == 0:
            cv2.circle(image, xy, 10, (0, 0, 255), -1)  # end point
            if prev_point is not None:
                cv2.arrowedLine(image, prev_point, xy, (255, 255, 255), 4, tipLength=0.5)
        else:
            cv2.circle(image, xy, 10, (255, 0, 0), -1)  # start point
        prev_point = xy
    return image
def preview_out_image(canvas, points, inpaint_ks):
    """Preview the drag: warp the masked pixels and mark the area to inpaint.

    Args:
        canvas: Canvas value accepted by ``process_canvas``.
        points: Control points, forwarded to ``bi_warp`` with the mask.
        inpaint_ks: Forwarded to ``bi_warp`` as its third argument.

    Returns:
        ``(preview, mask)``. On success ``preview`` is the image with each
        handle pixel copied onto its target pixel (point columns are x, y) and
        a grid pattern (255 with 0-valued lines every 10 px) painted where
        ``bi_warp``'s mask equals 1; ``mask`` is that mask multiplied by 255 as
        uint8. ``mask`` is ``None`` whenever no warp was produced: non-3-channel
        image, fewer than 2 points, dimensions not prepared by the resize step,
        or an exception during warping. Both are ``None`` without an image.
    """
    if canvas is None:
        return None, None
    image, mask = process_canvas(canvas)
    if image is None:
        return None, None

    if image.dtype != np.uint8:
        if image.max() <= 1.0:
            image = (image * 255).astype(np.uint8)
        else:
            image = image.astype(np.uint8)

    if image.ndim != 3 or image.shape[2] != 3 or len(points) < 2:
        return image, None

    # The warp expects every side to be a multiple of 8 and the longer edge
    # to be exactly 512 pixels, for both the image and the mask.
    divisible_by_8 = not any(dim % 8 for dim in (*mask.shape, *image.shape[:2]))
    longest_is_512 = max(image.shape[:2]) == 512 and max(mask.shape[:2]) == 512
    if not (divisible_by_8 and longest_is_512):
        gr.Warning('Click Resize Image Button first.')
        return image, None

    try:
        handle_pts, target_pts, inpaint_mask = bi_warp(mask, points, inpaint_ks)
        image[target_pts[:, 1], target_pts[:, 0]] = image[handle_pts[:, 1], handle_pts[:, 0]]

        grid = np.ones_like(mask) * 255
        grid[::10] = 0
        grid[:, ::10] = 0
        needs_fill = inpaint_mask[..., np.newaxis] == 1
        image = np.where(needs_fill, grid[..., np.newaxis], image)
        return image, (inpaint_mask * 255).astype(np.uint8)
    except Exception as e:
        gr.Warning(f"Preview failed: {str(e)}")
        return image, None
# Inpaint tools | |
def setup_pipeline(device='cuda', model_version='v1-5'):
    """Build an LCM-accelerated Stable Diffusion inpainting pipeline.

    Loads the base inpainting model for ``model_version``, swaps in an
    ``LCMScheduler``, loads and fuses the matching LCM-LoRA, replaces the VAE
    with an ``AutoencoderTiny``, moves the pipeline to the target device and
    caches the embeddings of the empty prompt on the pipeline object.

    Args:
        device: Requested torch device. Falls back to ``'cpu'`` whenever CUDA
            is unavailable.
        model_version: ``'v1-5'`` or ``'xl'``.

    Returns:
        The pipeline, with ``cached_prompt_embeds`` attached (plus
        ``cached_pooled_prompt_embeds`` for ``'xl'``).

    Raises:
        KeyError: If ``model_version`` is not a known configuration.
    """
    model_configs = {
        'v1-5': ('runwayml/stable-diffusion-inpainting',
                 'latent-consistency/lcm-lora-sdv1-5',
                 'madebyollin/taesd'),
        'xl': ('diffusers/stable-diffusion-xl-1.0-inpainting-0.1',
               'latent-consistency/lcm-lora-sdxl',
               'madebyollin/taesdxl'),
    }
    model_id, lora_id, vae_id = model_configs[model_version]

    if not torch.cuda.is_available():
        device = 'cpu'
    # Pick precision from the device actually used, not from whether CUDA
    # exists: a caller asking for 'cpu' on a CUDA host must not get fp16
    # weights, which are intended for GPU execution.
    if torch.device(device).type == 'cpu':
        torch_dtype, variant = torch.float32, None
    else:
        torch_dtype, variant = torch.float16, 'fp16'

    gr.Info('Loading inpainting pipeline...')
    pipe = AutoPipelineForInpainting.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        variant=variant,
        safety_checker=None,
    )
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    pipe.load_lora_weights(lora_id)
    pipe.fuse_lora()
    pipe.vae = AutoencoderTiny.from_pretrained(vae_id, torch_dtype=torch_dtype)
    pipe = pipe.to(device)

    # Encode the empty prompt once; without classifier-free guidance only the
    # positive embeddings are needed.
    encoded = pipe.encode_prompt(
        '', device=device, num_images_per_prompt=1,
        do_classifier_free_guidance=False)
    if model_version == 'v1-5':
        pipe.cached_prompt_embeds = encoded[0]
    else:
        # SDXL returns (embeds, negative, pooled, negative_pooled).
        pipe.cached_prompt_embeds, pipe.cached_pooled_prompt_embeds = encoded[0::2]
    return pipe
def get_pipeline():
    """Return the shared inpainting pipeline, building it on first use.

    The pipeline lives in the module-level ``pipe`` global, so later calls
    reuse it. If building raises, ``pipe`` stays ``None`` and the next call
    tries again.
    """
    global pipe
    if pipe is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # setup_pipeline() already stores the empty-prompt embeddings in
        # pipe.cached_prompt_embeds; encoding them a second time here only
        # repeated the same text-encoder pass.
        pipe = setup_pipeline(device=device, model_version='v1-5')
    return pipe
def inpaint(image, inpaint_mask):
    """Regenerate the masked region of ``image`` with the inpainting pipeline.

    The pipeline runs with the cached empty-prompt embeddings, 8 inference
    steps and guidance scale 1.0. Its output is composited back so that only
    masked pixels change.

    Args:
        image: Image array, (H, W, 3) or (H, W, 4); alpha is dropped. Expected
            to be uint8 (required by ``Image.fromarray`` for RGB data).
        inpaint_mask: Mask array, (H, W) or (H, W, C) (first channel used).
            Pixels >= 128 are regenerated; all others are kept.

    Returns:
        The composited uint8 image at the mask's size rounded to multiples of
        8. ``None`` if ``image`` is ``None``; ``image`` itself if
        ``inpaint_mask`` is ``None``; the prepared (possibly resized) input if
        the pipeline call raises.
    """
    if image is None:
        return None
    if inpaint_mask is None:
        return image

    # NOTE(review): `spaces` is imported at the top of the file but never used.
    # On ZeroGPU hardware GPU work usually has to run inside a
    # `@spaces.GPU`-decorated function; confirm whether this entry point needs it.
    pipe = get_pipeline()
    pipe_id = 'xl' if 'xl' in pipe.config._name_or_path else 'v1-5'
    inpaint_strength = 0.99 if pipe_id == 'xl' else 1.0

    image = np.asarray(image)
    if image.ndim == 3 and image.shape[2] == 4:
        image = image[:, :, :3]
    inpaint_mask = np.asarray(inpaint_mask)
    if inpaint_mask.ndim == 3:
        # A mask rendered as RGB(A) would otherwise break the (H, W, 1)
        # broadcasting used for compositing below.
        inpaint_mask = inpaint_mask[:, :, 0]

    image_pil = Image.fromarray(np.ascontiguousarray(image))
    mask_pil = Image.fromarray(np.ascontiguousarray(inpaint_mask))

    # Run at the mask's size snapped to multiples of 8, and bring the image to
    # that same size so the inputs and the pipeline output line up pixel-wise.
    mask_w, mask_h = mask_pil.size
    width = max(8, round(mask_w / 8) * 8)
    height = max(8, round(mask_h / 8) * 8)
    if mask_pil.size != (width, height):
        mask_pil = mask_pil.resize((width, height), Image.NEAREST)
    if image_pil.size != (width, height):
        image_pil = image_pil.resize((width, height))
    image = np.array(image_pil)
    inpaint_mask = np.array(mask_pil)

    common_params = {
        'image': image_pil,
        'mask_image': mask_pil,
        'height': height,
        'width': width,
        'guidance_scale': 1.0,
        'num_inference_steps': 8,
        'strength': inpaint_strength,
        'output_type': 'np',
    }
    try:
        prompt_params = {'prompt_embeds': pipe.cached_prompt_embeds}
        if pipe_id == 'xl':
            prompt_params['pooled_prompt_embeds'] = pipe.cached_pooled_prompt_embeds
        inpainted = pipe(**prompt_params, **common_params).images[0]
    except Exception as e:
        gr.Warning(f"Inpainting failed: {str(e)}")
        return image

    # diffusers binarises mask_image at 0.5 after scaling to [0, 1], i.e. at
    # pixel value 128; use the same threshold so the composite keeps exactly
    # the pixels the model was asked to regenerate.
    regenerate = (inpaint_mask >= 128)[:, :, np.newaxis]
    # Round (as diffusers' numpy_to_pil does) rather than truncate.
    generated = np.clip(np.rint(inpainted * 255), 0, 255).astype(np.uint8)
    return np.where(regenerate, generated, image)