# Hugging Face Space (status: Running)
from io import BytesIO

import gradio as gr
import requests
import torch
from diffusers import AutoencoderKL, LCMScheduler
from PIL import Image
from torchvision import transforms

from controlnet import ControlNetModel
from pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
# Define helper functions
def download_image(url, timeout=30):
    """Download an image over HTTP and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up; passed
            straight to ``requests.get``. Without it a stalled server would
            block the caller indefinitely.

    Returns:
        The downloaded image converted to RGB mode.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection failures or timeouts.
    """
    response = requests.get(url, timeout=timeout)
    # Fail on error responses here instead of letting PIL choke on an
    # HTML error page with a misleading "cannot identify image" message.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
def load_model():
    """Assemble the ControlNet inpainting pipeline and move it to the GPU.

    Loads the ControlNet weights, the fp16-safe SDXL VAE and the base
    pipeline in half precision, then places the pipeline on ``'cuda'``.

    Returns:
        The ready-to-call ``StableDiffusionXLControlNetPipeline``.
    """
    # Call from_pretrained on the class: instantiating ControlNetModel()
    # first builds a full default-configured model only to discard it.
    # NOTE(review): assumes the local ControlNetModel follows the diffusers
    # convention where from_pretrained is a classmethod - confirm.
    controlnet = ControlNetModel.from_pretrained(
        "briaai/DEV-ControlNetInpaintingFast",
        torch_dtype=torch.float16,
    )
    vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix",
        torch_dtype=torch.float16,
    )
    pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
        "briaai/BRIA-2.3",
        controlnet=controlnet.to(dtype=torch.float16),
        torch_dtype=torch.float16,
        vae=vae,
    )
    pipeline.to("cuda")
    return pipeline
# Load the pipeline once at import time so every request reuses the same
# GPU-resident model instead of reloading the weights per call.
pipe = load_model()
# Define the inpainting function
def inpaint(image, mask, prompt="A park bench", num_inference_steps=50):
    """Inpaint the masked region of ``image`` using the module-level pipeline.

    Args:
        image: PIL image to edit. It is resized to 1024x1024 and converted
            to RGB.
        mask: PIL image marking the area to repaint. It is resized to
            1024x1024, converted to single-channel, and binarized at 0.5.
        prompt: Text prompt passed to the pipeline.
        num_inference_steps: Number of denoising steps passed to the pipeline.

    Returns:
        The generated result as a PIL image.

    Raises:
        ValueError: If ``image`` or ``mask`` is missing (e.g. the user
            submitted the form without uploading both files).
    """
    if image is None or mask is None:
        raise ValueError("Both an image and a mask are required.")

    # Process image and mask
    image = image.resize((1024, 1024)).convert("RGB")
    mask = mask.resize((1024, 1024)).convert("L")

    # Transform to batched tensors on the GPU
    to_tensor = transforms.ToTensor()
    image_tensor = to_tensor(image).unsqueeze(0).to("cuda")
    mask_tensor = to_tensor(mask).unsqueeze(0).to("cuda")
    mask_tensor = (mask_tensor > 0.5).float()  # binarize mask

    # Generate image
    # NOTE(review): the keyword names below are kept from the original call;
    # the pipeline is project-local, so confirm they match its __call__.
    with torch.no_grad():
        result = pipe(
            prompt=prompt,
            init_image=image_tensor,
            mask_image=mask_tensor,
            num_inference_steps=num_inference_steps,
        ).images[0]

    # Diffusers-style pipelines return PIL images by default; calling
    # .squeeze() on one raises AttributeError, so only convert when the
    # pipeline actually handed back an array/tensor.
    if isinstance(result, Image.Image):
        return result
    return transforms.ToPILImage()(result.squeeze(0))
# Define the interface
# Two image uploads in, one image out. gr.Image replaces the deprecated
# gr.inputs.Image / gr.outputs.Image namespaces, which were removed in
# Gradio 4; the same component class serves as both input and output.
interface = gr.Interface(
    fn=inpaint,
    inputs=[
        gr.Image(type="pil", label="Original Image"),
        gr.Image(type="pil", label="Mask Image"),
    ],
    outputs=gr.Image(type="pil", label="Inpainted Image"),
    title="Stable Diffusion XL ControlNet Inpainting",
    description="Upload an image and its corresponding mask to inpaint the specified area.",
)

if __name__ == "__main__":
    interface.launch()