"""Gradio demo: segment a photo with a custom UNet and inpaint the masked area.

The script loads a two-class UNet for segmentation and a Stable Diffusion
ControlNet inpainting pipeline at import time, then exposes a Gradio
interface that builds the mask automatically and inpaints it from a prompt.
"""

from diffusers.utils import load_image, make_image_grid
import gradio as gr
import cv2
import numpy as np
from diffusers import (
    StableDiffusionControlNetInpaintPipeline,
    ControlNetModel,
    UniPCMultistepScheduler,
)
import torch
from PIL import Image
from Unet import UNet
from torchvision import transforms

# Run everything on the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Segmentation network: two output classes over a three-channel input.
human_segment = UNet(n_classes=2, img_channels=3)
# NOTE(review): torch.load unpickles the checkpoint; only load weight files
# you trust. Consider passing weights_only=True if the installed torch
# version supports it -- TODO confirm.
human_segment.load_state_dict(
    torch.load("./unet_weights.pth", map_location=device)
)
human_segment.to(device)
human_segment.eval()

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11p_sd15_inpaint", use_safetensors=True
)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    use_safetensors=True,
).to(device)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# Preprocessing applied to every image before it is fed to the UNet.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])


def create_mask(img_path, invert=True):
    """Generate a binary mask using the custom segmentation model.

    Args:
        img_path: Path of the image to segment.
        invert: If True, the mask selects pixels predicted as class 0
            (targeting the background instead of the human); if False it
            selects pixels predicted as any class greater than 0.

    Returns:
        A PIL image built from a uint8 array in which selected pixels are
        255 and all other pixels are 0. Its size is that of the model
        output, not of the original image.
    """
    # Close the file handle as soon as the pixel data has been copied.
    with Image.open(img_path) as source:
        img = source.convert("RGB")

    img_tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        # Take the first (only) item of the batch.
        pred = human_segment(img_tensor)[0]
    # Arg-max over dim 0 of the prediction yields a per-pixel class index.
    pred_class = torch.argmax(pred, dim=0).cpu().numpy()

    selected = (pred_class == 0) if invert else (pred_class > 0)
    mask = selected.astype(np.uint8) * 255
    return Image.fromarray(mask)


def load_and_resize_images(image, target_size=(512, 512)):
    """Load and resize images for inpainting.

    Args:
        image: Either a path string or an already opened PIL image.
        target_size: ``(width, height)`` passed to ``Image.resize``.

    Returns:
        An RGB PIL image resized to ``target_size``.
    """
    if isinstance(image, str):
        with Image.open(image) as source:
            return source.convert("RGB").resize(target_size)
    return image.convert("RGB").resize(target_size)


def make_inpaint_condition(image, image_mask):
    """Prepare the condition image for inpainting.

    Args:
        image: PIL image to inpaint.
        image_mask: PIL mask of the same size; values above 50% intensity
            mark the pixels to be replaced.

    Returns:
        A float32 tensor of shape ``(1, 3, H, W)`` holding the image scaled
        to ``[0, 1]`` with every masked pixel set to ``-1.0``.

    Raises:
        ValueError: If the image and the mask differ in height or width.
    """
    image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
    image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0

    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if image.shape[:2] != image_mask.shape[:2]:
        raise ValueError("image and image_mask must have the same image size")

    # -1.0 lies outside the valid [0, 1] range, so masked pixels stay
    # distinguishable from real image content.
    image[image_mask > 0.5] = -1.0
    # (H, W, 3) -> (1, 3, H, W)
    image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
    return torch.from_numpy(image)


def generate_inpainting(
    init_image,
    mask_image,
    prompt,
    negative_prompt=None,
    num_inference_steps=50,
    eta=1.0,
):
    """Generate the inpainted image.

    Args:
        init_image: PIL image to inpaint.
        mask_image: PIL mask with the same size as ``init_image``.
        prompt: Text describing what should fill the masked area.
        negative_prompt: Optional text describing what to avoid.
        num_inference_steps: Number of denoising steps passed to the
            pipeline (defaults to the previously hard-coded 50).
        eta: ``eta`` value passed to the pipeline (defaults to the
            previously hard-coded 1.0).

    Returns:
        The first image returned by the pipeline.
    """
    control_image = make_inpaint_condition(init_image, mask_image)
    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        eta=eta,
        image=init_image,
        mask_image=mask_image,
        control_image=control_image,
    )
    return result.images[0]


def process_with_auto_mask(
    image_path, prompt, negative_prompt=None, invert_mask=True
):
    """Process input image with automatic mask generation and inpainting.

    Args:
        image_path: Path of the uploaded image.
        prompt: Text describing what should replace the masked area.
        negative_prompt: Optional text describing what to avoid.
        invert_mask: If True, the background will be inpainted instead of
            the human.

    Returns:
        A ``(mask_image, output_image)`` tuple of PIL images, with the mask
        resized to the size of the image that was inpainted.

    Raises:
        gr.Error: If no image was provided.
    """
    # Submitting the form without an upload yields an empty path; report
    # that clearly instead of failing inside Image.open.
    if not image_path:
        raise gr.Error("Please upload an image before submitting.")

    mask_image = create_mask(image_path, invert=invert_mask)
    init_image = load_and_resize_images(image_path)
    # The mask comes out at the segmentation resolution; match the image.
    mask_image = mask_image.resize(init_image.size)
    output_image = generate_inpainting(
        init_image, mask_image, prompt, negative_prompt
    )
    return mask_image, output_image


demo = gr.Interface(
    fn=process_with_auto_mask,
    inputs=[
        gr.Image(type='filepath', label="Original Image"),
        gr.Textbox(
            label="Prompt",
            placeholder="Describe what should replace the masked area...",
        ),
        gr.Textbox(
            label="Negative Prompt",
            placeholder="Elements to avoid in the generated image...",
            value="low quality, bad anatomy, blurry, pixelated",
        ),
    ],
    outputs=[
        gr.Image(label="Generated Mask"),
        gr.Image(label="Inpainted Result"),
    ],
    title="Automatic Mask & Inpainting Tool",
    description="Upload an image, and the system will automatically create a mask and perform inpainting based on your prompt.",
    allow_flagging="never",
)

# Launched at module level, as in the original script, so behaviour on
# import/run is unchanged.
demo.launch(share=True)