import gradio as gr import sys import torch from omegaconf import OmegaConf from PIL import Image from diffusers import StableDiffusionInpaintPipeline from model.clip_away import CLIPAway import cv2 import numpy as np import argparse # Parse command line arguments parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, default="config/inference_config.yaml", help="Path to the config file") parser.add_argument("--share", action="store_true", help="Share the interface if provided") args = parser.parse_args() # Load configuration and models config = OmegaConf.load(args.config) sd_pipeline = StableDiffusionInpaintPipeline.from_pretrained( "runwayml/stable-diffusion-inpainting", safety_checker=None, torch_dtype=torch.float32 ) clipaway = CLIPAway( sd_pipe=sd_pipeline, image_encoder_path=config.image_encoder_path, ip_ckpt=config.ip_adapter_ckpt_path, alpha_clip_path=config.alpha_clip_ckpt_pth, config=config, alpha_clip_id=config.alpha_clip_id, device=config.device, num_tokens=4 ) def dilate_mask(mask, kernel_size=5, iterations=5): mask = mask.convert("L") kernel = np.ones((kernel_size, kernel_size), np.uint8) mask = cv2.dilate(np.array(mask), kernel, iterations=iterations) return Image.fromarray(mask) def combine_masks(uploaded_mask, sketched_mask): if uploaded_mask is not None: return uploaded_mask elif sketched_mask is not None: return sketched_mask else: raise ValueError("Please provide a mask") def remove_obj(image, uploaded_mask, seed): image_pil, sketched_mask = image["image"], image["mask"] mask = dilate_mask(combine_masks(uploaded_mask, sketched_mask)) seed = int(seed) latents = torch.randn((1, 4, 64, 64), generator=torch.Generator().manual_seed(seed)).to("cuda") final_image = clipaway.generate( prompt=[""], scale=1, seed=seed, pil_image=[image_pil], alpha=[mask], strength=1, latents=latents )[0] return final_image # Define example data examples = [ ["assets/gradio_examples/images/1.jpg", "assets/gradio_examples/masks/1.png", 42], ["assets/gradio_examples/images/2.jpg", "assets/gradio_examples/masks/2.png", 42], ["assets/gradio_examples/images/3.jpg", "assets/gradio_examples/masks/3.png", 464], ["assets/gradio_examples/images/4.jpg", "assets/gradio_examples/masks/4.png", 2024], ] # Define the Gradio interface with gr.Blocks() as demo: gr.Markdown("