Spaces:
Running
Running
from diffusers.utils import load_image, make_image_grid | |
import gradio as gr | |
import cv2 | |
import numpy as np | |
from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler | |
import torch | |
from PIL import Image | |
from Unet import UNet | |
from torchvision import transforms | |
# Use the GPU when CUDA is available; otherwise everything runs on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Custom segmentation network separating the person from the background
# (2 output classes, 3-channel RGB input).
human_segment = UNet(n_classes=2, img_channels=3)
# NOTE(review): torch.load unpickles the checkpoint. That is only acceptable
# because unet_weights.pth ships with the app; if it is a plain state dict,
# consider passing weights_only=True (torch >= 1.13).
human_segment.load_state_dict(torch.load("./unet_weights.pth", map_location=device))
human_segment.to(device)
human_segment.eval()  # evaluation mode: the model is only used for inference

# Inpainting ControlNet for Stable Diffusion 1.5.
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint", use_safetensors=True)
# The original "runwayml/stable-diffusion-v1-5" repository was removed from the
# Hugging Face Hub, so a fresh start could no longer download it. This is the
# maintained mirror of the same SD 1.5 weights.
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, use_safetensors=True
).to(device)
# Swap the default scheduler for UniPC, reusing the pipeline's scheduler config.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# Preprocessing for the segmentation model: fixed 256x256 input, converted to a
# float tensor (ToTensor scales uint8 pixels to [0, 1]).
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])
def create_mask(img_path, invert=True):
    """Generate a binary mask using the custom segmentation model.

    Args:
        img_path: Path to the input image. An already-loaded PIL image is
            also accepted (anything that is not a ``str`` is treated as an
            image object), matching ``load_and_resize_images``.
        invert: If True (default), the mask selects pixels the model assigns
            to class 0, so the background is targeted instead of the human.
            If False, it selects every non-zero class (the human).

    Returns:
        A single-channel PIL image with 255 where inpainting should happen
        and 0 elsewhere. Its size is the segmentation model's output size
        (presumably 256x256, the size ``transform`` resizes to; callers
        resize it as needed).

    Raises:
        ValueError: If no image is given (``img_path`` is None).
    """
    if img_path is None:
        raise ValueError("No image provided to create_mask.")
    if isinstance(img_path, str):
        # The context manager releases the file handle. convert() returns a
        # new in-memory image, so it remains valid after the file is closed.
        with Image.open(img_path) as src:
            img = src.convert("RGB")
    else:
        img = img_path.convert("RGB")

    img_tensor = transform(img).unsqueeze(0).to(device)  # add batch dimension
    with torch.no_grad():
        logits = human_segment(img_tensor)[0]  # first (only) item of the batch
        pred_class = torch.argmax(logits, dim=0).cpu().numpy()  # per-pixel class id

    selected = (pred_class == 0) if invert else (pred_class > 0)
    return Image.fromarray(selected.astype(np.uint8) * 255)
def load_and_resize_images(image, target_size=(512, 512)):
    """Return ``image`` as an RGB picture resized to ``target_size``.

    ``image`` may be a filesystem path (a ``str``, opened with PIL) or an
    already-loaded image object. The aspect ratio is not preserved.
    """
    source = Image.open(image) if isinstance(image, str) else image
    return source.convert("RGB").resize(target_size)
def make_inpaint_condition(image, image_mask):
    """Prepare the ControlNet condition image for inpainting.

    The RGB image is scaled to [0, 1]. Every pixel whose mask value (also
    scaled to [0, 1]) is above 0.5 is set to -1.0. Per the diffusers
    ControlNet-inpaint example, this is how masked pixels are marked for
    the inpaint ControlNet.

    Args:
        image: Image object with ``convert("RGB")`` (a PIL image here).
        image_mask: Image object with ``convert("L")``, same height/width.

    Returns:
        A float32 torch tensor of shape (1, 3, H, W).

    Raises:
        ValueError: If the image and mask heights/widths differ.
    """
    # np.array copies, so the in-place write below never touches the source.
    pixels = np.array(image.convert("RGB")).astype(np.float32) / 255.0
    mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
    # Explicit check instead of `assert`: asserts are stripped under `python -O`.
    if pixels.shape[:2] != mask.shape[:2]:
        raise ValueError(
            f"image and mask sizes differ: {pixels.shape[:2]} vs {mask.shape[:2]}"
        )
    pixels[mask > 0.5] = -1.0
    # HWC -> NCHW with a leading batch dimension of 1.
    pixels = np.expand_dims(pixels, 0).transpose(0, 3, 1, 2)
    return torch.from_numpy(pixels)
def generate_inpainting(init_image, mask_image, prompt, negative_prompt=None, *,
                        num_inference_steps=50, eta=1.0):
    """Generate the inpainted image with the ControlNet inpainting pipeline.

    Args:
        init_image: RGB image to inpaint (a PIL image in this app).
        mask_image: Mask of the same size; pixels above half intensity are
            the ones to regenerate.
        prompt: Text describing what should fill the masked area.
        negative_prompt: Optional text describing what to avoid.
        num_inference_steps: Number of denoising steps (default 50, the
            previously hard-coded value). More steps are slower.
        eta: Forwarded to the pipeline (default 1.0, as before).
            NOTE(review): diffusers only passes eta to schedulers whose
            step() accepts it (e.g. DDIM). With the UniPC scheduler set at
            load time it is presumably ignored; verify before relying on it.

    Returns:
        The first image of the pipeline output.
    """
    control_image = make_inpaint_condition(init_image, mask_image)
    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        eta=eta,
        image=init_image,
        mask_image=mask_image,
        control_image=control_image,
    ).images[0]
    return output
def process_with_auto_mask(image_path, prompt, negative_prompt=None, invert_mask=True):
    """Process an input image with automatic mask generation and inpainting.

    Args:
        image_path: Path of the uploaded image (the Gradio input uses
            ``type='filepath'``), or None when nothing was uploaded.
        prompt: Text describing what should replace the masked area.
        negative_prompt: Optional text describing elements to avoid.
        invert_mask: If True (default), the background is inpainted and the
            human is kept; if False, the human region is inpainted instead.

    Returns:
        Tuple ``(mask_image, output_image)``: the binary mask resized to the
        inpainting resolution, and the generated image.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of a stack trace.
    """
    if image_path is None:
        raise gr.Error("Please upload an image first.")
    mask_image = create_mask(image_path, invert=invert_mask)
    init_image = load_and_resize_images(image_path)
    # NEAREST keeps the upscaled mask strictly 0/255. Pillow's default resize
    # filter for single-channel images interpolates and would leave grey
    # pixels along the mask edges.
    mask_image = mask_image.resize(init_image.size, resample=Image.NEAREST)
    output_image = generate_inpainting(init_image, mask_image, prompt, negative_prompt)
    return mask_image, output_image
# Gradio UI: one image plus prompt/negative-prompt inputs. Outputs are the
# automatically generated mask and the inpainted result.
demo = gr.Interface(
    fn=process_with_auto_mask,
    inputs=[
        gr.Image(type='filepath', label="Original Image"),
        gr.Textbox(label="Prompt", placeholder="Describe what should replace the masked area..."),
        gr.Textbox(label="Negative Prompt", placeholder="Elements to avoid in the generated image...",
                   value="low quality, bad anatomy, blurry, pixelated"),
    ],
    outputs=[
        gr.Image(label="Generated Mask"),
        gr.Image(label="Inpainted Result"),
    ],
    title="Automatic Mask & Inpainting Tool",
    description="Upload an image, and the system will automatically create a mask and perform inpainting based on your prompt.",
    # NOTE(review): newer Gradio releases deprecate allow_flagging in favour of
    # flagging_mode; kept as-is for compatibility with the installed version.
    allow_flagging="never",
)

# Launch only when executed as a script. Importing this module (e.g. from
# tests or another app) no longer starts a server or opens a public share link.
if __name__ == "__main__":
    demo.launch(share=True)