# Hugging Face Space status when this source was captured: Runtime error
import io | |
import requests | |
import numpy as np | |
import torch | |
from PIL import Image | |
from skimage.measure import block_reduce | |
from typing import List, Optional | |
from functools import reduce | |
import gradio as gr | |
from transformers import DetrFeatureExtractor, DetrForSegmentation, DetrConfig | |
from transformers.models.detr.feature_extraction_detr import rgb_to_id | |
from diffusers import StableDiffusionInpaintPipeline | |
# This app only runs inference, so autograd is switched off. Bare calls such as
# `torch.no_grad()` or `torch.inference_mode()` merely construct a context-manager
# object and change nothing; the grad mode must be set explicitly instead.
# Grad mode is thread-local: this covers the importing thread only, so work
# running on other threads needs its own `torch.no_grad()` guard.
torch.set_grad_enabled(False)
# Load segmentation models
def load_segmentation_models(model_name: str = 'facebook/detr-resnet-50-panoptic'):
    """Load the DETR panoptic feature extractor, segmentation model and config.

    Args:
        model_name: Hugging Face Hub id (or local path) of the DETR checkpoint.

    Returns:
        ``(feature_extractor, model, cfg)``. ``cfg`` is the model's own config
        object; it carries the ``id2label`` mapping used to name segments.
    """
    feature_extractor = DetrFeatureExtractor.from_pretrained(model_name)
    model = DetrForSegmentation.from_pretrained(model_name)
    # The model already parsed its config during from_pretrained; reuse it
    # instead of loading and parsing the same config a second time.
    return feature_extractor, model, model.config
# Load diffusion pipeline
def pipeline_dtype_for(device) -> torch.dtype:
    """Return the weight dtype to run the diffusion pipeline with on ``device``.

    float16 on CUDA devices; float32 everywhere else, because half-precision
    inference is poorly supported on CPU.

    Args:
        device: A ``torch.device`` or a device string such as ``'cpu'`` or ``'cuda:0'``.
    """
    return torch.float16 if torch.device(device).type == 'cuda' else torch.float32


def load_diffusion_pipeline(
    model_name: str = 'runwayml/stable-diffusion-inpainting',
    torch_dtype: Optional[torch.dtype] = None,
):
    """Load the Stable Diffusion inpainting pipeline.

    Args:
        model_name: Hugging Face Hub id (or local path) of the pipeline.
        torch_dtype: Dtype to load the weights in. When None, float16 is used if
            CUDA is available and float32 otherwise. This matches the device that
            ``get_device()`` picks by default, so a CPU-only host does not end up
            running float16 weights.

    Returns:
        The loaded ``StableDiffusionInpaintPipeline`` (still on the CPU).
    """
    if torch_dtype is None:
        torch_dtype = pipeline_dtype_for('cuda' if torch.cuda.is_available() else 'cpu')
    # The 'fp16' revision holds half-precision weights (smaller download);
    # from_pretrained casts them to `torch_dtype`.
    return StableDiffusionInpaintPipeline.from_pretrained(
        model_name,
        revision='fp16',
        torch_dtype=torch_dtype,
    )
# Device helper
def get_device(try_cuda=True):
    """Pick the torch device to run on.

    Returns the CUDA device when ``try_cuda`` is truthy and CUDA is available,
    otherwise the CPU device.
    """
    if try_cuda and torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')
def min_pool(x: torch.Tensor, kernel_size: int):
    """Sliding-window minimum (stride 1, padding ``(kernel_size - 1) // 2``).

    Computed as the negated max-pool of the negated input. ``max_pool2d`` pads
    implicitly with -inf, so padded border cells never win the window. Odd
    ``kernel_size`` values preserve the spatial size of ``x``.

    Args:
        x: Tensor in a layout accepted by ``torch.nn.functional.max_pool2d``.
        kernel_size: Square window size.
    """
    padding = (kernel_size - 1) // 2
    flipped = torch.neg(x)
    pooled = torch.nn.functional.max_pool2d(flipped, kernel_size, (1, 1), padding=padding)
    return torch.neg(pooled)
def max_pool(x: torch.Tensor, kernel_size: int):
    """Sliding-window maximum (stride 1, padding ``(kernel_size - 1) // 2``).

    Odd ``kernel_size`` values preserve the spatial size of ``x``.

    Args:
        x: Tensor in a layout accepted by ``torch.nn.functional.max_pool2d``.
        kernel_size: Square window size.
    """
    half_window = (kernel_size - 1) // 2
    return torch.nn.functional.max_pool2d(
        x,
        kernel_size,
        stride=(1, 1),
        padding=half_window,
    )
# Apply min-max pooling to clean up mask
def clean_mask(mask, max_kernel: int = 23, min_kernel: int = 5):
    """Erode and then dilate a 2-D mask with square windows.

    A min-pool of size ``min_kernel`` shrinks the mask, dropping regions that
    cannot contain the window. A max-pool of size ``max_kernel`` then grows what
    remains. Both pools use stride 1 and ``(k - 1) // 2`` padding.

    Args:
        mask: 2-D array; non-zero entries are inside the mask.
        max_kernel: Dilation window size.
        min_kernel: Erosion window size.

    Returns:
        A boolean numpy array with the added batch/channel axes squeezed away.
    """
    pool = torch.nn.functional.max_pool2d
    # Add batch and channel axes so max_pool2d sees an (N, C, H, W) tensor.
    as_batch = torch.Tensor(mask[None, None]).float()
    eroded = -pool(-as_batch, min_kernel, (1, 1), padding=(min_kernel - 1) // 2)
    dilated = pool(eroded, max_kernel, (1, 1), padding=(max_kernel - 1) // 2)
    return dilated.bool().squeeze().numpy()
# Instantiate everything once at import time; the Gradio callbacks read the
# segmentation objects and `pipe` as module-level globals.
device = get_device()
feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_models()
pipe = load_diffusion_pipeline().to(device)
# Callback function that runs segmentation and updates CheckboxGroup
def fn_segmentation(image, max_kernel, min_kernel):
    """Run DETR panoptic segmentation on ``image`` and populate the mask UI.

    Args:
        image: PIL image from the input ``gr.Image(type='pil')`` component.
        max_kernel: Unused here; accepted because the click handler passes it.
        min_kernel: Unused here; accepted because the click handler passes it.

    Returns:
        A 4-tuple of:
        - the per-segment uint8 masks (0/255, image-sized);
        - a CheckboxGroup update with one ``"<index>:<label>"`` choice per
          segment and nothing selected;
        - an all-zero mask image;
        - the unmodified input image.

    Raises:
        ValueError: if no input image is present.
    """
    if image is None:
        raise ValueError("Please provide an input image first.")

    # Gradio runs callbacks on worker threads and autograd mode is thread-local,
    # so gradient tracking is disabled here rather than relying on module state.
    with torch.no_grad():
        inputs = feature_extractor(images=image, return_tensors="pt")
        outputs = segmentation_model(**inputs)
        processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0)
        result = feature_extractor.post_process_panoptic(outputs, processed_sizes)[0]

    # Segment ids are encoded as RGB colours in this PNG. Resize with
    # nearest-neighbour so boundary pixels keep a real id instead of an
    # interpolated colour that decodes to a bogus one.
    panoptic_png = Image.open(io.BytesIO(result["png_string"]))
    panoptic_png = panoptic_png.resize((image.width, image.height), resample=Image.NEAREST)
    panoptic_seg_id = rgb_to_id(np.array(panoptic_png, dtype=np.uint8))

    segments = result['segments_info']
    raw_masks = [(panoptic_seg_id == s['id']).astype(np.uint8) * 255 for s in segments]
    # Label each choice with the segment's position in raw_masks, because the
    # mask-update callback parses this prefix and uses it as a list index.
    checkbox_choices = [
        f"{index}:{segmentation_cfg.id2label[s['category_id']]}"
        for index, s in enumerate(segments)
    ]
    # Clear any selection made for a previous segmentation; its indices would
    # not refer to the new masks.
    checkbox_group = gr.CheckboxGroup.update(choices=checkbox_choices, value=[])
    empty_mask = np.zeros((image.height, image.width))
    return raw_masks, checkbox_group, gr.Image.update(value=empty_mask), gr.Image.update(value=image)
# Callback function that updates the displayed mask based on selected checkboxes
def fn_update_mask(
    image: Image.Image,
    masks: List[np.ndarray],
    masks_enabled: List[str],
    max_kernel: int,
    min_kernel: int,
):
    """Combine the selected segment masks, clean them up and black them out.

    Args:
        image: Current input image (PIL), or None if the input was cleared.
        masks: Per-segment masks held in ``mask_storage`` (non-zero = inside the
            segment), or None before masks have been computed.
        masks_enabled: Selected checkbox labels of the form ``"<index>:<label>"``,
            where ``<index>`` is a position in ``masks``.
        max_kernel: Dilation window size passed to ``clean_mask``.
        min_kernel: Erosion window size passed to ``clean_mask``.

    Returns:
        ``(mask, masked_image)``: the cleaned union of the selected masks as a
        uint8 0/255 array, and a copy of ``image`` with those pixels set to 0.
        If no usable masks exist, an all-zero mask and the unmodified image are
        returned. If there is no image, ``(None, None)`` is returned.
    """
    if image is None:
        return None, None
    height, width = image.height, image.width
    # mask_storage is empty until "Compute Masks" runs. It may also still hold
    # the previous image's masks right after a new image is loaded. Indexing
    # into either would fail, so show an empty mask instead.
    if not masks or masks[0].shape[:2] != (height, width):
        return np.zeros((height, width), dtype=np.uint8), image

    combined_mask = np.zeros(masks[0].shape, dtype=bool)
    for choice in masks_enabled or []:
        index = int(choice.split(':')[0])
        # Skip selections left over from a segmentation with more segments.
        if 0 <= index < len(masks):
            combined_mask |= masks[index] > 0
    # Slider values may arrive as floats; the pooling kernels must be ints.
    combined_mask = clean_mask(combined_mask, int(max_kernel), int(min_kernel))

    masked_image = np.array(image)  # np.array copies, leaving `image` untouched
    masked_image[combined_mask] = 0
    return combined_mask.astype(np.uint8) * 255, Image.fromarray(masked_image)
# Callback function that runs diffusion given the current image, mask and prompt.
def fn_diffusion(
    prompt: str,
    masked_image: Image.Image,
    mask: np.ndarray,
    num_diffusion_steps: int,
    guidance_scale: float,
    negative_prompt: Optional[str] = None,
    small_edge: int = 512,
):
    """Inpaint the masked region of ``masked_image`` with Stable Diffusion.

    The image and mask are resized so that the shorter edge is ``small_edge``.
    Both edges are rounded up to a multiple of 8, the size the pipeline
    expects. The result is resized back to the original size.

    Args:
        prompt: Text prompt for the inpainted region.
        masked_image: PIL image with the region to replace blacked out.
        mask: Mask array from the ``gr.Image(type='numpy')`` component; white
            marks the region to inpaint.
        num_diffusion_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance scale.
        negative_prompt: Optional negative prompt; None or "" disables it.
        small_edge: Target length of the shorter image edge before diffusion.

    Returns:
        The inpainted PIL image at the input's original size.

    Raises:
        ValueError: if the masked image or mask has not been computed yet.
    """
    if masked_image is None or mask is None:
        raise ValueError("Compute masks before running diffusion.")
    # Treat both None (the default) and "" (an empty textbox) as "no negative prompt".
    if not negative_prompt:
        negative_prompt = None

    w, h = masked_image.size
    if w > h:
        new_width, new_height = int(w * small_edge / h), small_edge
    else:
        new_width, new_height = small_edge, int(h * small_edge / w)
    # Round up to a multiple of 8, leaving sizes that already are multiples alone.
    new_width = (new_width + 7) // 8 * 8
    new_height = (new_height + 7) // 8 * 8

    mask_image = Image.fromarray(mask).convert("RGB").resize((new_width, new_height))
    init_image = masked_image.convert("RGB").resize((new_width, new_height))

    inpainted_image = pipe(
        height=new_height,
        width=new_width,
        prompt=prompt,
        image=init_image,
        mask_image=mask_image,
        # Slider values may arrive as floats; the step count must be an int.
        num_inference_steps=int(num_diffusion_steps),
        guidance_scale=guidance_scale,
        negative_prompt=negative_prompt,
    ).images[0]
    # Resize back to the original size
    return inpainted_image.resize((w, h))
with gr.Blocks() as demo:
    # NOTE(review): the source this was restored from had lost its indentation;
    # the Row/Column nesting below is a best-effort reconstruction. Confirm the layout.

    # Input image control
    input_image = gr.Image(
        value="http://images.cocodataset.org/val2017/000000039769.jpg",
        type='pil',
        label="Input Image",
    )

    # Combined mask controls
    bt_masks = gr.Button("Compute Masks")
    with gr.Row():
        mask_image = gr.Image(type='numpy', label="Diffusion Mask")
        masked_image = gr.Image(type='pil', label="Masked Image")
    mask_storage = gr.State()

    # Mask editing controls
    with gr.Row():
        max_slider = gr.Slider(minimum=1, maximum=99, value=23, step=2, label="Mask Overflow")
        min_slider = gr.Slider(minimum=1, maximum=99, value=5, step=2, label="Mask Denoising")
    mask_checkboxes = gr.CheckboxGroup(interactive=True, label="Mask Selection")

    # Diffusion controls and output
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                "Two ginger cats lying together on a pink sofa. There are two TV remotes. High definition.",
                label="Prompt",
            )
            negative_prompt = gr.Textbox(label="Negative Prompt")
        with gr.Column():
            steps_slider = gr.Slider(minimum=1, maximum=100, value=50, label="Inference Steps")
            guidance_slider = gr.Slider(minimum=0.0, maximum=50.0, value=7.5, step=0.1, label="Guidance Scale")
    bt_diffusion = gr.Button("Run Diffusion")
    inpainted_image = gr.Image(type='pil', label="Inpainted Image")

    # ---- Event wiring ----
    update_mask_inputs = [input_image, mask_storage, mask_checkboxes, max_slider, min_slider]
    update_mask_outputs = [mask_image, masked_image]

    # A new input image invalidates the previous segment choices.
    input_image.change(lambda: gr.CheckboxGroup.update(choices=[], value=[]), outputs=mask_checkboxes)

    bt_masks.click(
        fn_segmentation,
        inputs=[input_image, max_slider, min_slider],
        outputs=[mask_storage, mask_checkboxes, mask_image, masked_image],
    )

    # Recompute the displayed mask whenever a kernel size or the selection changes.
    # TODO: can we replace this with `mask_image.change`? Not sure if it will actively update.
    for mask_control in (max_slider, min_slider, mask_checkboxes):
        mask_control.change(fn_update_mask, inputs=update_mask_inputs, outputs=update_mask_outputs)

    diffusion_inputs = [
        prompt,
        masked_image,
        mask_image,
        steps_slider,
        guidance_slider,
        negative_prompt,
    ]
    bt_diffusion.click(fn_diffusion, inputs=diffusion_inputs, outputs=inpainted_image)

demo.launch()