# Edit_background / app.py
import gradio as gr
import numpy as np
import torch
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetInpaintPipeline,
    UniPCMultistepScheduler,
)
from PIL import Image
from torchvision import transforms

from Unet import UNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Custom U-Net used for human segmentation (class 0 = background, class 1 = person).
human_segment = UNet(n_classes=2, img_channels=3)
human_segment.load_state_dict(torch.load("./unet_weights.pth", map_location=device))
human_segment.to(device)
human_segment.eval()

# Stable Diffusion v1.5 inpainting conditioned on the ControlNet inpaint model;
# the default scheduler is swapped for UniPCMultistepScheduler, a fast multistep solver.
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11p_sd15_inpaint", use_safetensors=True
)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, use_safetensors=True
).to(device)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
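
# Optional (not in the original app; assumes the `accelerate` package is installed):
# on GPUs with limited VRAM, pipeline components can be offloaded to the CPU between
# forward passes. If enabled, drop the `.to(device)` call above and let the offload
# hook manage device placement.
# pipe.enable_model_cpu_offload()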

# Preprocessing for the segmentation model: resize to 256x256 and convert to a tensor.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])


def create_mask(img_path, invert=True):
    """Generate a binary mask with the custom segmentation model.

    If invert=True, the mask covers the background instead of the human,
    so the background is what gets inpainted.
    """
    img = Image.open(img_path).convert("RGB")
    img_tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        pred = human_segment(img_tensor)[0]
    pred_class = torch.argmax(pred, dim=0).cpu().numpy()
    if invert:
        # Class 0 is background: mask everything except the detected person.
        mask = (pred_class == 0).astype(np.uint8) * 255
    else:
        # Mask the person (any non-background class).
        mask = (pred_class > 0).astype(np.uint8) * 255
    mask = Image.fromarray(mask)
    return mask
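
# Quick sanity check (illustration only; "example.jpg" is a placeholder path, not a
# file shipped with this Space): save both mask variants to disk to verify the
# human/background split before running the full pipeline.
# create_mask("example.jpg", invert=False).save("human_mask.png")
# create_mask("example.jpg", invert=True).save("background_mask.png")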


def load_and_resize_images(image, target_size=(512, 512)):
    """Load an image (file path or PIL.Image) and resize it for inpainting."""
    if isinstance(image, str):
        init_image = Image.open(image).convert("RGB").resize(target_size)
    else:
        init_image = image.convert("RGB").resize(target_size)
    return init_image


def make_inpaint_condition(image, image_mask):
    """Prepare the ControlNet conditioning image for inpainting."""
    image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
    image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
    assert image.shape[:2] == image_mask.shape[:2], "image and mask must be the same size"
    # Pixels to be inpainted are set to -1.0, the convention expected by the
    # control_v11p_sd15_inpaint ControlNet.
    image[image_mask > 0.5] = -1.0
    # HWC -> NCHW float tensor for the pipeline.
    image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
    return torch.from_numpy(image)


def generate_inpainting(init_image, mask_image, prompt, negative_prompt=None):
    """Generate the inpainted image."""
    control_image = make_inpaint_condition(init_image, mask_image)
    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=50,
        eta=1.0,
        image=init_image,
        mask_image=mask_image,
        control_image=control_image,
    ).images[0]
    return output


def process_with_auto_mask(image_path, prompt, negative_prompt=None, invert_mask=True):
    """Segment the input image, build a mask, and inpaint the masked region.

    If invert_mask=True, the background is inpainted instead of the human.
    """
    mask_image = create_mask(image_path, invert=invert_mask)
    init_image = load_and_resize_images(image_path)
    # The segmentation mask is 256x256; match it to the 512x512 init image.
    mask_image = mask_image.resize(init_image.size)
    output_image = generate_inpainting(init_image, mask_image, prompt, negative_prompt)
    return mask_image, output_image
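
# Standalone usage sketch (illustration only; "example.jpg", the prompts, and the
# output filename are placeholders, not part of the original app):
# mask, result = process_with_auto_mask(
#     "example.jpg",
#     prompt="a sunny beach at golden hour",
#     negative_prompt="low quality, blurry",
#     invert_mask=True,
# )
# result.save("inpainted_background.png")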


# invert_mask defaults to True, so the Gradio app inpaints the background and keeps
# the detected person.
demo = gr.Interface(
    fn=process_with_auto_mask,
    inputs=[
        gr.Image(type="filepath", label="Original Image"),
        gr.Textbox(label="Prompt", placeholder="Describe what should replace the masked area..."),
        gr.Textbox(
            label="Negative Prompt",
            placeholder="Elements to avoid in the generated image...",
            value="low quality, bad anatomy, blurry, pixelated",
        ),
    ],
    outputs=[
        gr.Image(label="Generated Mask"),
        gr.Image(label="Inpainted Result"),
    ],
    title="Automatic Mask & Inpainting Tool",
    description="Upload an image; the app automatically masks the background around the detected person and inpaints it based on your prompt.",
    allow_flagging="never",
)

demo.launch(share=True)