# app.py by sphinxsolution — commit a5b6fbf ("updated app.py", verified).
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
import gradio as gr
from PIL import Image
import numpy as np
# Target device for the pipeline: the CUDA GPU when torch can see one,
# otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
def load_model(
    controlnet_id="lllyasviel/control_v11p_sd15_inpaint",
    base_model_id="runwayml/stable-diffusion-v1-5",
):
    """Load the Stable Diffusion pipeline wired to the ControlNet inpaint model.

    Args:
        controlnet_id: Hugging Face Hub id (or local path) of the ControlNet
            weights to load with ``ControlNetModel.from_pretrained``.
        base_model_id: Hub id (or local path) of the base Stable Diffusion
            model loaded with ``StableDiffusionControlNetPipeline.from_pretrained``.

    Returns:
        The ``StableDiffusionControlNetPipeline`` moved to the module-level
        ``device``.
    """
    # float16 on CUDA, float32 on CPU. Derived from ``device`` so the weights'
    # dtype always agrees with where the pipeline is placed below.
    dtype = torch.float16 if device == "cuda" else torch.float32
    controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=dtype)
    # NOTE(review): the "runwayml/stable-diffusion-v1-5" repo is reported to
    # have been removed from the Hub; if loading fails, try passing the mirror
    # "stable-diffusion-v1-5/stable-diffusion-v1-5" as base_model_id — confirm.
    pipeline = StableDiffusionControlNetPipeline.from_pretrained(
        base_model_id,
        controlnet=controlnet,
        torch_dtype=dtype,
    )
    return pipeline.to(device)


# Load once at import time so every Gradio request reuses the same weights.
pipe = load_model()
def _painted_mask(layers, size):
    """Return a boolean (H, W) array that is True wherever any layer is painted.

    Args:
        layers: Editor layers (arrays or array-likes). RGBA layers contribute
            their alpha channel, so brush strokes of any colour count; other
            layers contribute their grayscale value.
        size: Target ``(width, height)`` — the size of the base image.
    """
    width, height = size
    painted = np.zeros((height, width), dtype=bool)
    for layer in layers:
        if layer is None:
            continue
        layer_array = np.asarray(layer, dtype=np.uint8)
        if layer_array.ndim == 3 and layer_array.shape[2] == 4:
            strength = Image.fromarray(np.ascontiguousarray(layer_array[:, :, 3]))
        else:
            strength = Image.fromarray(layer_array).convert("L")
        if strength.size != size:
            strength = strength.resize(size, Image.LANCZOS)
        # Same 0.5 threshold (on a 0-1 scale) as the reference inpaint recipe.
        painted |= np.asarray(strength) > 127
    return painted


def _make_inpaint_condition(image, painted):
    """Build the control tensor for the ControlNet v1.1 inpaint model.

    The image is scaled to [0, 1] and every masked pixel is set to -1, which
    marks it as "to be generated" for the inpaint ControlNet.

    Args:
        image: RGB image (PIL image or (H, W, 3) uint8 array).
        painted: Boolean (H, W) mask, True where the image should be replaced.

    Returns:
        A float32 torch tensor of shape (1, 3, H, W).
    """
    condition = np.array(image, dtype=np.float32) / 255.0
    condition[painted] = -1.0
    # (H, W, 3) -> (1, 3, H, W), the layout the pipeline expects for tensors.
    chw = np.ascontiguousarray(condition.transpose(2, 0, 1)[np.newaxis])
    return torch.from_numpy(chw)


def inpaint(image_editor_output, prompt):
    """Inpaint the painted region of the uploaded image using the text prompt.

    Args:
        image_editor_output: The ``gr.ImageEditor`` value, a dict with
            ``"background"`` (the uploaded image), ``"layers"`` (the painted
            strokes) and ``"composite"`` (both flattened together).
        prompt: Text prompt describing what to draw in the masked area.

    Returns:
        The generated PIL image (first image of the pipeline output).

    Raises:
        ValueError: If the editor value is malformed, no image was uploaded,
            or nothing was painted.
    """
    if (
        not isinstance(image_editor_output, dict)
        or "background" not in image_editor_output
        or "layers" not in image_editor_output
    ):
        raise ValueError("Invalid input. Please upload an image and paint a mask.")

    # Prefer the untouched background: the composite has the mask strokes
    # burned into it. Fall back to the composite when there is no background.
    base = image_editor_output["background"]
    if base is None:
        base = image_editor_output.get("composite")
    if base is None:
        raise ValueError("Invalid input. Please upload an image and paint a mask.")

    layers = image_editor_output["layers"] or []
    if not layers:
        raise ValueError("Invalid input. Please upload an image and paint a mask.")

    image = Image.fromarray(np.asarray(base, dtype=np.uint8)).convert("RGB")
    painted = _painted_mask(layers, image.size)
    if not painted.any():
        raise ValueError("Invalid input. Please upload an image and paint a mask.")

    control_image = _make_inpaint_condition(image, painted)
    # StableDiffusionControlNetPipeline receives its ControlNet conditioning
    # through ``image`` (``control_image`` is an argument of the img2img /
    # inpaint ControlNet pipelines, not this one).
    return pipe(prompt=prompt, image=control_image).images[0]
# Gradio UI: an ImageEditor to upload the image and paint the mask, plus a
# prompt box; the generated image is shown as a PIL image.
demo = gr.Interface(
    fn=inpaint,
    inputs=[
        # inpaint() converts the editor's arrays with Image.fromarray and reads
        # the mask from the layers' alpha channel, so pin the output type and
        # colour mode explicitly instead of relying on Gradio's defaults.
        # Any painted stroke (not just white ones) ends up in the alpha mask.
        gr.ImageEditor(
            label="Upload & Paint Image (painted areas will be inpainted)",
            type="numpy",
            image_mode="RGBA",
        ),
        gr.Textbox(label="Prompt for Inpainting"),
    ],
    outputs=gr.Image(type="pil", label="Generated Image"),
    title="ControlNet Inpainting",
    description="Upload an image, paint directly on it to define mask areas, and provide a text prompt.",
)
demo.launch()