Spaces:
Build error
Build error
import os | |
# Install the pinned detectron2 commit when the package cannot be imported.
try:
    import detectron2
except ImportError:
    import subprocess
    import sys

    # Run pip through the interpreter that is executing this script so the
    # package lands in the same environment, and raise if the install fails
    # instead of continuing to the detectron2 imports further down.
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "git+https://github.com/facebookresearch/detectron2.git@d1e04565d3bec8719335b88be9e9b961bf3ec464",
        ]
    )
import gradio as gr | |
import codecs | |
import cv2 | |
from io import BytesIO | |
import numpy as np | |
from PIL import Image | |
import requests | |
import torch | |
# CLIPSeg imports | |
from torchvision import transforms | |
from clipseg.models.clipseg import CLIPDensePredT | |
# MaskFormer imports | |
from detectron2.config import get_cfg | |
from detectron2.projects.deeplab import add_deeplab_config | |
from detectron2.engine.defaults import DefaultPredictor | |
from mask_former.mask_former_model import MaskFormer | |
from mask_former.config import add_mask_former_config | |
from mask_former.data.datasets.register_coco_stuff_10k import COCO_CATEGORIES | |
import sys | |
sys.path.append("shape-guided-diffusion") | |
from shape_guided_diffusion import shape_guided_diffusion, init_models, init_safety_checker, check_image | |
# Run on the GPU when torch reports one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Preset and hyperparameter set shown at start-up and restored by reset().
chosen_preset = "dog"
chosen_hparam = "default"

# Height in pixels used for the image widgets and the CSS min-height rules.
image_height = 512

# ==========================
# User Options Init
# ==========================
# Ways of supplying the mask. Other code refers to these entries by position.
radio_options = [
    "Upload mask",
    "Draw mask above",
    "Infer mask with MaskFormer",
    # "Infer mask with CLIPSeg"
]

# Each preset is [image, mask, source prompt, edit prompt].
presets = {
    "custom": [None, None, None, None],
    "dog": [
        Image.open("assets/dog.png"),
        Image.open("assets/dog_mask.png"),
        "dog",
        "dog wearing a floral jacket",
    ],
    "truck": [
        Image.open("assets/truck.png"),
        Image.open("assets/truck_mask.png"),
        "truck",
        "lego truck",
    ],
}

# Each entry is [filter NSFW, run DDIM inversion, cross-attention mask,
# self-attention mask, guidance scale, cross-attention schedule, noise mixing],
# matching the order of the "Advanced Settings" components.
hparams_presets = {
    "default": [True, True, True, True, 3.5, 2.5, 0],
    "halve the runtime": [True, False, True, True, 3.5, 2.5, 0],
    "make non-deterministic": [True, True, True, True, 3.5, 2.5, 0.5],
    "make more text aligned": [True, True, True, True, 7.5, 2.5, 0],
}
# ========================== | |
# CLIPSeg Init | |
# ========================== | |
# clip_seg_model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64) | |
# clip_seg_model.eval() | |
# clip_seg_model.load_state_dict(torch.load('clipseg/weights/rd64-uni.pth'), strict=False) | |
# clip_seg_model.to(device) | |
# clip_seg_transform = transforms.Compose([ | |
# transforms.ToTensor(), | |
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), | |
# transforms.Resize((512, 512)), | |
# ]) | |
# ========================== | |
# MaskFormer Init | |
# ========================== | |
def setup_cfg(cfg_path):
    """Build a frozen config from the config file at ``cfg_path``.

    The DeepLab and MaskFormer config extensions are registered on a fresh
    config before the file is merged, and the result is frozen before it is
    returned.

    Args:
        cfg_path: Path of the config file to merge.

    Returns:
        The frozen config object.
    """
    config = get_cfg()
    for register_extension in (add_deeplab_config, add_mask_former_config):
        register_extension(config)
    config.merge_from_file(cfg_path)
    config.freeze()
    return config
# Build the MaskFormer predictor once at import time; infer_maskformer()
# reuses it for every request.
maskformer_cfg = setup_cfg(cfg_path="mask_former/configs/maskformer_R50_bs32_60k.yaml")
maskformer_model = DefaultPredictor(cfg=maskformer_cfg)
# ========================== | |
# Stable Diffusion Init | |
# ========================== | |
# Use the API_TOKEN environment variable when it is set and non-empty;
# otherwise pass True to init_models.
huggingface_access_token = os.getenv("API_TOKEN") or True
unet, vae, clip, clip_tokenizer = init_models(huggingface_access_token, device=device)
feature_extractor, safety_checker = init_safety_checker(device=device)

seed = 98374234
# NOTE(review): torch.cuda.manual_seed() seeds the current GPU's global RNG and
# returns None, so `generator` is None here and in every predict() call. On a
# CPU-only host the seed call is ignored. Confirm whether a torch.Generator was
# intended before changing this.
generator = torch.cuda.manual_seed(seed)

# Fixed initial noise shared by all predict() calls.
latent_shape = (1, unet.in_channels, 512 // 8, 512 // 8)
noise = torch.randn(
    latent_shape,
    device=device,
    generator=torch.cuda.manual_seed(seed),
)
# ========================== | |
# Helper Functions | |
# ========================== | |
def download_image(url, timeout=30):
    """Download the image at ``url`` and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server, passed to ``requests.get``.
            Without a timeout a stalled server would block the caller
            indefinitely.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        requests.RequestException: If the request times out, fails, or the
            server answers with an error status.
    """
    response = requests.get(url, timeout=timeout)
    # Fail on an HTTP error status here rather than handing an error page to
    # the image decoder.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
def reset():
    """Return the start-up value of every component wired to the Reset button.

    The list holds, in order: the start-up preset name, the start-up
    hyperparameter preset name, the first mask option, the fields of the
    start-up preset, and the values of the start-up hyperparameter preset.
    """
    selectors = [chosen_preset, chosen_hparam, radio_options[0]]
    return [*selectors, *presets[chosen_preset], *hparams_presets[chosen_hparam]]
def preset_mask_upload(dropdown, radio):
    """Return the selected preset's mask, or None when it does not apply.

    The preset mask is returned only when the first mask option is selected
    and the preset is not "custom".
    """
    upload_selected = radio == radio_options[0]
    if not upload_selected or dropdown == "custom":
        return None
    return presets[dropdown][1]
def preset_dropdown(dropdown):
    """Return the first mask option followed by the fields of preset ``dropdown``."""
    return [radio_options[0], *presets[dropdown]]
def preset_hparams_dropdown(hparams_dropdown):
    """Return the hyperparameter values stored under the preset name ``hparams_dropdown``."""
    selected_values = hparams_presets[hparams_dropdown]
    return selected_values
def activate_mask_upload(radio):
    """Return a component update that is interactive only for the first mask option."""
    return gr.update(interactive=(radio == radio_options[0]))
# def infer_clip_seg(img, source_prompt, threshold=100): | |
# img = clip_seg_transform(img).unsqueeze(0) | |
# with torch.no_grad(): | |
# preds = clip_seg_model(img, [source_prompt])[0] | |
# mask_preds = torch.sigmoid(preds[0][0]).detach().cpu().numpy() | |
# mask_preds = mask_preds * 255.0 | |
# mask_preds = np.where(mask_preds > threshold, 255, 0) | |
# mask_preds = np.uint8(mask_preds) | |
# mask_preds = Image.fromarray(mask_preds, "L") | |
# return mask_preds | |
def infer_maskformer(img, source_prompt):
    """Segment ``img`` with MaskFormer and return a mask for ``source_prompt``.

    Args:
        img: Image to segment; it is converted with ``np.array``.
        source_prompt: Category name, which must equal the ``name`` of an
            entry in ``COCO_CATEGORIES``.

    Returns:
        A mode "L" PIL image that is 255 where the per-pixel argmax of the
        ``sem_seg`` output equals the category index, and 0 elsewhere.

    Raises:
        gr.Error: If ``source_prompt`` is not a known category name.
    """
    # Map each category name to its position in COCO_CATEGORIES.
    name_to_index = {}
    for position, category in enumerate(COCO_CATEGORIES):
        name_to_index[category["name"]] = position

    category_idx = name_to_index.get(source_prompt)
    if category_idx is None:
        raise gr.Error("When using MaskFormer, source prompt must be a category in the COCO-Stuff dataset.")

    # NOTE(review): np.array() on a PIL image yields RGB channel order; confirm
    # this matches the channel order the predictor expects.
    pixels = np.array(img)
    with torch.no_grad():
        sem_seg = maskformer_model(pixels)["sem_seg"]

    # Reduce over the first axis to get one class index per pixel.
    class_map = sem_seg.detach().cpu().numpy().argmax(axis=0)
    binary_mask = np.uint8(np.where(class_map == category_idx, 255, 0))
    return Image.fromarray(binary_mask, "L")
def get_mask(radio, image_upload, mask_upload, source_prompt, threshold=100):
    """Return the mask chosen by ``radio`` as a 512x512 RGB image.

    Args:
        radio: Selected entry of ``radio_options``.
        image_upload: Dict with "image" and "mask" entries from the image
            component, or None when no image is loaded.
        mask_upload: Mask from the mask component, used for the first option.
        source_prompt: Category name passed to ``infer_maskformer``.
        threshold: Unused; kept so existing callers keep working.

    Returns:
        The mask converted to RGB and resized to 512x512.

    Raises:
        gr.Error: If the selected option needs an image that is missing, or
            if no mask could be obtained.
    """
    # Start from None so an unrecognised option reaches the "missing mask"
    # error below instead of failing on an unbound local.
    mask = None
    if radio == radio_options[0]:
        mask = mask_upload
    elif radio in (radio_options[1], radio_options[2]):
        # Both remaining options read from the image component, so report a
        # missing image explicitly rather than failing on the dict lookup.
        if not image_upload or image_upload["image"] is None:
            raise gr.Error("Missing input image.")
        if radio == radio_options[1]:
            mask = image_upload["mask"]
        else:
            mask = infer_maskformer(image_upload["image"], source_prompt)
    # elif radio == radio_options[3]:
    #     mask = infer_clip_seg(image_upload["image"], source_prompt)

    if mask is None:
        raise gr.Error("Missing input mask. Try running Update Mask again.")
    return mask.convert("RGB").resize((512, 512))
def predict(
    image_upload,
    mask_upload,
    source_prompt,
    edit_prompt,
    run_safety_checker,
    run_inversion,
    run_cross_attn_mask,
    run_self_attn_mask,
    guidance_scale,
    cross_attn_sched,
    noise_mixing
):
    """Run shape-guided diffusion on the uploaded image and return the edit.

    Args:
        image_upload: Dict from the image component; its "image" entry is used.
        mask_upload: Mask image passed through as ``mask_image``.
        source_prompt: Prompt for the masked object; must occur inside
            ``edit_prompt``.
        edit_prompt: Prompt describing the edited object.
        run_safety_checker: Whether to run ``check_image`` on the result.
        run_inversion: Forwarded as ``run_inversion``.
        run_cross_attn_mask: Forwarded as ``run_cross_attention_mask``.
        run_self_attn_mask: Forwarded as ``run_self_attention_mask``.
        guidance_scale: Forwarded as ``guidance_scale``.
        cross_attn_sched: Forwarded as ``cross_attn_schedule``.
        noise_mixing: Forwarded as ``noise_mixing``.

    Returns:
        The edited image, after the safety checker when it is enabled.

    Raises:
        gr.Error: If the prompts are inconsistent, an input is missing, or
            the safety checker flags the result.
    """
    # Input validation, reported to the user in this order.
    if source_prompt not in edit_prompt:
        raise gr.Error("Source prompt must be a substring of edit prompt.")
    if not image_upload:
        raise gr.Error("Missing input image.")
    if not mask_upload:
        raise gr.Error("Missing input mask.")

    init_image = image_upload["image"].convert("RGB").resize((512, 512))

    prompt_options = {
        "prompt_inversion_inside": source_prompt,
        "prompt_inversion_outside": "background",
        "prompt_inside": edit_prompt,
        "prompt_outside": "background",
    }
    attention_options = {
        "run_cross_attention_mask": run_cross_attn_mask,
        "run_self_attention_mask": run_self_attn_mask,
        "self_attn_schedule": 1.0,
        "cross_attn_schedule": cross_attn_sched,
    }
    inversion_options = {
        "run_inversion": run_inversion,
        "noise_mixing": noise_mixing,
    }

    # NOTE(review): the safety check is kept inside the autocast context; the
    # original nesting could not be read from the flattened source - confirm.
    with torch.autocast("cuda"):
        edit_image = shape_guided_diffusion(
            unet,
            vae,
            clip_tokenizer,
            clip,
            init_image=init_image,
            mask_image=mask_upload,
            guidance_scale=guidance_scale,
            # Module-level noise and generator are shared by every call.
            noise=noise,
            generator=generator,
            **prompt_options,
            **attention_options,
            **inversion_options,
        )
        if run_safety_checker:
            edit_image, has_nsfw_concept = check_image(feature_extractor, safety_checker, edit_image)
            if has_nsfw_concept:
                raise gr.Error("Safety checker has filtered potentially NSFW result.")
    return edit_image
# ========================== | |
# User Interface | |
# ========================== | |
# Page stylesheet handed to gr.Blocks. Doubled braces are literal CSS braces;
# {image_height} is filled in by str.format below.
_CSS_TEMPLATE = """
.container {{max-width: 1150px; margin: auto; padding-top: 1.5rem;}}
#image_upload{{min-height: {image_height}px;}}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{{min-height: {image_height}px;}}
#mask_radio .gr-form{{background:transparent; border: none;}}
#mask_upload{{min-height: {image_height}px;}}
#mask_upload [data-testid="image"], #mask_upload [data-testid="image"] > div{{min-height: {image_height}px;}}
#mask_btn {{width: 100%; margin: 10px 0;}}
.footer {{margin-bottom: 45px; margin-top: 35px; text-align: center; border-bottom: 1px solid #e5e5e5;}}
.footer>p {{font-size: .8rem; display: inline-block; padding: 0 10px; transform: translateY(10px); background: white;}}
.dark .footer {{border-color: #303030;}}
.dark .footer>p {{background: #0b0f19;}}
.acknowledgments h4{{margin: 1.25em 0 .25em 0; font-weight: bold;font-size: 115%;}}
#image_upload .touch-none{{display: flex;}}
"""
css = _CSS_TEMPLATE.format(image_height=image_height)
# Fields of the start-up preset, used as the initial component contents.
chosen_image, chosen_mask, chosen_source_prompt, chosen_edit_prompt = presets[chosen_preset]
# Values of the start-up hyperparameter preset, indexed in the same order as
# the "Advanced Settings" components below.
chosen_hparam_values = hparams_presets[chosen_hparam]

# Read the static HTML fragments under context managers so the file handles
# are closed as soon as the text has been loaded.
with codecs.open("html/header.html", "r") as header_file:
    header_html = header_file.read()
with codecs.open("html/footer.html", "r") as footer_file:
    footer_html = footer_file.read()

image_blocks = gr.Blocks(css=css)
with image_blocks as demo:
    # Components
    # NOTE(review): the layout nesting below was reconstructed from a copy of
    # this file whose indentation was lost - confirm against the deployed UI.
    gr.HTML(header_html)
    with gr.Row():
        with gr.Column():
            dropdown = gr.Dropdown(list(presets.keys()), value=chosen_preset, label="Preset Image, Prompt, Mask")
            image_upload = gr.Image(value=chosen_image, source="upload", tool="sketch", elem_id="image_upload", type="pil", label="Image").style(height=image_height)
            source_prompt = gr.Textbox(value=chosen_source_prompt, label="Source Prompt")
            edit_prompt = gr.Textbox(value=chosen_edit_prompt, label="Edit Prompt")
            with gr.Box(elem_id="mask_radio").style(border=False):
                radio = gr.Radio(radio_options, value=radio_options[0], show_label=False).style(container=False)
            mask_btn = gr.Button("Update Mask", elem_id="mask_btn")
            mask_upload = gr.Image(value=chosen_mask, source="upload", tool="editor", elem_id="mask_upload", type="pil", label="Mask").style(height=image_height)
            with gr.Accordion("Advanced Settings", open=False):
                hparams_dropdown = gr.Dropdown(list(hparams_presets.keys()), value=chosen_hparam, label="Preset Hyperparameters")
                run_safety_checker = gr.Checkbox(
                    label="Filter NSFW results", value=chosen_hparam_values[0]
                )
                run_inversion = gr.Checkbox(
                    label="Run DDIM inversion", value=chosen_hparam_values[1]
                )
                run_cross_attn_mask = gr.Checkbox(
                    label="Apply Inside-Outside Attention to cross-attention layers", value=chosen_hparam_values[2]
                )
                run_self_attn_mask = gr.Checkbox(
                    label="Apply Inside-Outside Attention to self-attention layers", value=chosen_hparam_values[3]
                )
                guidance_scale = gr.Slider(
                    label="Guidance scale", minimum=0, maximum=15, step=0.1, value=chosen_hparam_values[4]
                )
                cross_attn_sched = gr.Slider(
                    label="Up-weight factor of new tokens in edit prompt", minimum=0.0, maximum=10, step=0.5, value=chosen_hparam_values[5]
                )
                noise_mixing = gr.Slider(
                    label="Level of random noise to interpolate with initial latent", minimum=0.0, maximum=1.0, step=0.1, value=chosen_hparam_values[6]
                )
            with gr.Row():
                reset_btn = gr.Button("Reset", variant="secondary")
                run_btn = gr.Button("Run", variant="primary")
        with gr.Column():
            result = gr.Image(label="Result").style(height=image_height)
    gr.HTML(footer_html)

    # Component groups shared by several listeners. Their order matches the
    # entries of `presets` / `hparams_presets` and the parameters of predict().
    preset_components = [image_upload, mask_upload, source_prompt, edit_prompt]
    hparam_components = [
        run_safety_checker,
        run_inversion,
        run_cross_attn_mask,
        run_self_attn_mask,
        guidance_scale,
        cross_attn_sched,
        noise_mixing,
    ]

    # ==========================
    # Event Listeners
    # ==========================
    # Mask selection interactions
    radio.change(fn=activate_mask_upload, inputs=radio, outputs=mask_upload, show_progress=False)
    radio.change(fn=preset_mask_upload, inputs=[dropdown, radio], outputs=mask_upload, show_progress=False)
    mask_btn.click(fn=get_mask, inputs=[radio, image_upload, mask_upload, source_prompt], outputs=mask_upload, show_progress=True)

    # Preset interactions
    dropdown.change(fn=preset_dropdown, inputs=dropdown, outputs=[radio, *preset_components], show_progress=False)
    hparams_dropdown.change(
        fn=preset_hparams_dropdown,
        inputs=hparams_dropdown,
        outputs=list(hparam_components),
        show_progress=False
    )

    # Global interactions
    reset_btn.click(
        fn=reset,
        outputs=[dropdown, hparams_dropdown, radio, *preset_components, *hparam_components],
        show_progress=True
    )
    run_btn.click(
        fn=predict,
        inputs=[*preset_components, *hparam_components],
        outputs=result,
        show_progress=True
    )

# Demo Launch, server_port=7067
demo.launch()