# Fix ClipSeg error (commit 428226a)
import os

# detectron2 ships no PyPI wheel for this pinned revision; on a fresh host
# (e.g. a new Hugging Face Space) install it from GitHub when the import fails.
try:
    import detectron2
except ImportError:  # was a bare `except:`, which also swallows SystemExit/KeyboardInterrupt
    os.system('pip install git+https://github.com/facebookresearch/detectron2.git@d1e04565d3bec8719335b88be9e9b961bf3ec464')
import gradio as gr
import codecs
import cv2
from io import BytesIO
import numpy as np
from PIL import Image
import requests
import torch
# CLIPSeg imports
from torchvision import transforms
from clipseg.models.clipseg import CLIPDensePredT
# MaskFormer imports
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.engine.defaults import DefaultPredictor
from mask_former.mask_former_model import MaskFormer
from mask_former.config import add_mask_former_config
from mask_former.data.datasets.register_coco_stuff_10k import COCO_CATEGORIES
import sys
sys.path.append("shape-guided-diffusion")
from shape_guided_diffusion import shape_guided_diffusion, init_models, init_safety_checker, check_image
# Runtime configuration: prefer GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"
chosen_preset = "dog"  # key into `presets` used for the initial UI state
chosen_hparam = "default"  # key into `hparams_presets` used for the initial UI state
image_height = 512  # pixel height applied to the Gradio image widgets (and CSS below)
# ==========================
# User Options Init
# ==========================
# Mask-source choices shown in the UI radio (the CLIPSeg option is currently disabled).
radio_options = [
    "Upload mask",
    "Draw mask above",
    "Infer mask with MaskFormer",
    # "Infer mask with CLIPSeg"
]
# Preset name -> [image, mask, source prompt, edit prompt].
presets = {
    "custom": [None, None, None, None],
    "dog": [Image.open(f"assets/dog.png"), Image.open(f"assets/dog_mask.png"), "dog", "dog wearing a floral jacket"],
    "truck": [Image.open(f"assets/truck.png"), Image.open(f"assets/truck_mask.png"), "truck", "lego truck"],
}
# Preset name -> [run_safety_checker, run_inversion, run_cross_attn_mask,
# run_self_attn_mask, guidance_scale, cross_attn_sched, noise_mixing];
# order matches the hyperparameter inputs of `predict`.
hparams_presets = {
    "default": [True, True, True, True, 3.5, 2.5, 0],
    "halve the runtime": [True, False, True, True, 3.5, 2.5, 0],
    "make non-deterministic": [True, True, True, True, 3.5, 2.5, 0.5],
    "make more text aligned": [True, True, True, True, 7.5, 2.5, 0]
}
# ==========================
# CLIPSeg Init
# ==========================
# clip_seg_model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
# clip_seg_model.eval()
# clip_seg_model.load_state_dict(torch.load('clipseg/weights/rd64-uni.pth'), strict=False)
# clip_seg_model.to(device)
# clip_seg_transform = transforms.Compose([
# transforms.ToTensor(),
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# transforms.Resize((512, 512)),
# ])
# ==========================
# MaskFormer Init
# ==========================
def setup_cfg(cfg_path):
    """Build a frozen detectron2 config from *cfg_path* with the DeepLab and
    MaskFormer config extensions registered before merging the YAML file."""
    config = get_cfg()
    # Register the extra config keys; merge_from_file rejects unknown keys otherwise.
    for register_extension in (add_deeplab_config, add_mask_former_config):
        register_extension(config)
    config.merge_from_file(cfg_path)
    config.freeze()
    return config
# Build the MaskFormer semantic-segmentation predictor from the bundled config.
maskformer_cfg = setup_cfg("mask_former/configs/maskformer_R50_bs32_60k.yaml")
maskformer_model = DefaultPredictor(maskformer_cfg)
# ==========================
# Stable Diffusion Init
# ==========================
# NOTE(review): falls back to `True` when API_TOKEN is unset — presumably so
# `use_auth_token=True` picks up a locally cached Hugging Face token; confirm.
huggingface_access_token = os.environ.get("API_TOKEN") or True
unet, vae, clip, clip_tokenizer = init_models(huggingface_access_token, device=device)
feature_extractor, safety_checker = init_safety_checker(device=device)
seed = 98374234
# Fix: `torch.cuda.manual_seed(seed)` returns None (so the original `generator`
# was never a Generator) and assumes a CUDA build, even though `device` may be
# "cpu". A device-matched Generator seeds both `noise` and later sampling.
generator = torch.Generator(device=device).manual_seed(seed)
noise = torch.randn(
    (1, unet.in_channels, 512 // 8, 512 // 8),
    device=device,
    generator=generator
)
# ==========================
# Helper Functions
# ==========================
def download_image(url):
    """Fetch *url* over HTTP and return its body decoded as an RGB PIL image."""
    payload = requests.get(url).content
    return Image.open(BytesIO(payload)).convert("RGB")
def reset():
    """Return the component values that restore the default preset, default
    hyperparameters, and the 'Upload mask' radio selection (order matches the
    `reset_btn.click` outputs list)."""
    return (
        [chosen_preset, chosen_hparam, radio_options[0]]
        + presets[chosen_preset]
        + hparams_presets[chosen_hparam]
    )
def preset_mask_upload(dropdown, radio):
    """Return the preset's stored mask when 'Upload mask' is selected for a
    non-custom preset; otherwise clear the mask widget with None."""
    use_stored_mask = radio == radio_options[0] and dropdown != "custom"
    return presets[dropdown][1] if use_stored_mask else None
def preset_dropdown(dropdown):
    """Return updated values for [radio, image, mask, source prompt, edit prompt]
    when a preset is chosen; the radio is reset to 'Upload mask'."""
    image, mask, src_prompt, edit_prompt = presets[dropdown]
    return [radio_options[0], image, mask, src_prompt, edit_prompt]
def preset_hparams_dropdown(hparams_dropdown):
    """Look up the hyperparameter values for the selected preset name
    (order matches the `hparams_dropdown.change` outputs list)."""
    preset_values = hparams_presets[hparams_dropdown]
    return preset_values
def activate_mask_upload(radio):
    """Make the mask-upload widget interactive only while 'Upload mask' is selected."""
    return gr.update(interactive=(radio == radio_options[0]))
# def infer_clip_seg(img, source_prompt, threshold=100):
# img = clip_seg_transform(img).unsqueeze(0)
# with torch.no_grad():
# preds = clip_seg_model(img, [source_prompt])[0]
# mask_preds = torch.sigmoid(preds[0][0]).detach().cpu().numpy()
# mask_preds = mask_preds * 255.0
# mask_preds = np.where(mask_preds > threshold, 255, 0)
# mask_preds = np.uint8(mask_preds)
# mask_preds = Image.fromarray(mask_preds, "L")
# return mask_preds
def infer_maskformer(img, source_prompt):
    """Segment *img* with MaskFormer and return a binary mode-"L" PIL mask
    (255 where the pixel's argmax class is the COCO-Stuff category named by
    *source_prompt*, 0 elsewhere).

    Raises gr.Error when *source_prompt* is not a COCO-Stuff category name.
    """
    category_names = [cat["name"] for cat in COCO_CATEGORIES]
    if source_prompt not in category_names:
        raise gr.Error("When using MaskFormer, source prompt must be a category in the COCO-Stuff dataset.")
    category_idx = category_names.index(source_prompt)
    with torch.no_grad():
        sem_seg = maskformer_model(np.array(img))["sem_seg"]
    # Per-pixel class labels from the semantic-segmentation logits.
    labels = sem_seg.detach().cpu().numpy().argmax(axis=0)
    binary = np.where(labels == category_idx, 255, 0).astype(np.uint8)
    return Image.fromarray(binary, "L")
def get_mask(radio, image_upload, mask_upload, source_prompt, threshold=100):
    """Resolve the editing mask for the selected mask-source radio option and
    return it as a 512x512 RGB PIL image.

    radio: one of `radio_options` (upload / draw / MaskFormer).
    image_upload: sketch-tool dict with "image" and "mask" entries.
    mask_upload: PIL image from the mask widget (may be None).
    source_prompt: COCO-Stuff category name, used by the MaskFormer branch.
    threshold: unused here; kept for interface compatibility (it belonged to
        the disabled CLIPSeg branch).

    Raises gr.Error when no mask could be resolved.
    """
    # Fix: initialize so an unmatched radio value (e.g. None, or the disabled
    # CLIPSeg option) raises the intended gr.Error instead of UnboundLocalError.
    mask = None
    if radio == radio_options[0]:
        mask = mask_upload
    elif radio == radio_options[1]:
        mask = image_upload["mask"]
    elif radio == radio_options[2]:
        mask = infer_maskformer(image_upload["image"], source_prompt)
    # elif radio == radio_options[3]:
    #     mask = infer_clip_seg(image_upload["image"], source_prompt)
    if mask is None:
        raise gr.Error("Missing input mask. Try running Update Mask again.")
    mask = mask.convert("RGB")
    mask = mask.resize((512, 512))
    return mask
def predict(
    image_upload,
    mask_upload,
    source_prompt,
    edit_prompt,
    run_safety_checker,
    run_inversion,
    run_cross_attn_mask,
    run_self_attn_mask,
    guidance_scale,
    cross_attn_sched,
    noise_mixing
):
    """Run shape-guided diffusion editing and return the edited PIL image.

    Raises gr.Error when the prompts are inconsistent, an input is missing,
    or (when enabled) the safety checker flags the result.
    """
    # The source prompt must appear inside the edit prompt for the inside/outside
    # prompt split below to be consistent.
    if source_prompt not in edit_prompt:
        raise gr.Error("Source prompt must be a substring of edit prompt.")
    if not image_upload:
        raise gr.Error("Missing input image.")
    if not mask_upload:
        raise gr.Error("Missing input mask.")
    # image_upload comes from a sketch-tool gr.Image: a dict holding "image" (and "mask").
    init_image = image_upload["image"]
    init_image = init_image.convert("RGB")
    init_image = init_image.resize((512, 512))
    # NOTE(review): autocast is hard-coded to "cuda" even though `device` may be
    # "cpu" — confirm behavior on CPU-only hosts.
    with torch.autocast("cuda"):
        edit_image = shape_guided_diffusion(
            unet,
            vae,
            clip_tokenizer,
            clip,
            init_image=init_image,
            mask_image=mask_upload,
            # Prompt params
            prompt_inversion_inside=source_prompt,
            prompt_inversion_outside="background",
            prompt_inside=edit_prompt,
            prompt_outside="background",
            # Generation params
            guidance_scale=guidance_scale,
            # Inside-Outside Attention params
            run_cross_attention_mask=run_cross_attn_mask,
            run_self_attention_mask=run_self_attn_mask,
            self_attn_schedule=1.0,
            cross_attn_schedule=cross_attn_sched,
            # DDIM Inversion params
            run_inversion=run_inversion,
            noise_mixing=noise_mixing,
            # Random seed params
            noise=noise,
            generator=generator,
        )
    if run_safety_checker:
        edit_image, has_nsfw_concept = check_image(feature_extractor, safety_checker, edit_image)
        if has_nsfw_concept:
            raise gr.Error("Safety checker has filtered potentially NSFW result.")
    return edit_image
# ==========================
# User Interface
# ==========================
# Custom CSS: sizes the image widgets via `image_height` and styles the footer.
css = f"""
.container {{max-width: 1150px; margin: auto; padding-top: 1.5rem;}}
#image_upload{{min-height: {image_height}px;}}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{{min-height: {image_height}px;}}
#mask_radio .gr-form{{background:transparent; border: none;}}
#mask_upload{{min-height: {image_height}px;}}
#mask_upload [data-testid="image"], #mask_upload [data-testid="image"] > div{{min-height: {image_height}px;}}
#mask_btn {{width: 100%; margin: 10px 0;}}
.footer {{margin-bottom: 45px; margin-top: 35px; text-align: center; border-bottom: 1px solid #e5e5e5;}}
.footer>p {{font-size: .8rem; display: inline-block; padding: 0 10px; transform: translateY(10px); background: white;}}
.dark .footer {{border-color: #303030;}}
.dark .footer>p {{background: #0b0f19;}}
.acknowledgments h4{{margin: 1.25em 0 .25em 0; font-weight: bold;font-size: 115%;}}
#image_upload .touch-none{{display: flex;}}
"""
# Default preset unpacked into the initial component values.
chosen_image, chosen_mask, chosen_source_prompt, chosen_edit_prompt = presets[chosen_preset]
image_blocks = gr.Blocks(css=css)
with image_blocks as demo:
    # Components
    gr.HTML(codecs.open("html/header.html", "r").read())
    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column():
            dropdown = gr.Dropdown(list(presets.keys()), value=chosen_preset, label="Preset Image, Prompt, Mask")
            image_upload = gr.Image(value=chosen_image, source="upload", tool="sketch", elem_id="image_upload", type="pil", label="Image").style(height=image_height)
            source_prompt = gr.Textbox(value=chosen_source_prompt, label="Source Prompt")
            edit_prompt = gr.Textbox(value=chosen_edit_prompt, label="Edit Prompt")
            with gr.Box(elem_id="mask_radio").style(border=False):
                radio = gr.Radio(radio_options, value=radio_options[0], show_label=False).style(container=False)
            mask_btn = gr.Button("Update Mask", elem_id="mask_btn")
            mask_upload = gr.Image(value=chosen_mask, source="upload", tool="editor", elem_id="mask_upload", type="pil", label="Mask").style(height=image_height)
            with gr.Accordion("Advanced Settings", open=False):
                hparams_dropdown = gr.Dropdown(list(hparams_presets.keys()), value=chosen_hparam, label="Preset Hyperparameters")
                # Initial values come from the chosen hparams preset; the index
                # order matches the hparams_presets value lists.
                run_safety_checker = gr.Checkbox(
                    label="Filter NSFW results", value=hparams_presets[chosen_hparam][0]
                )
                run_inversion = gr.Checkbox(
                    label="Run DDIM inversion", value=hparams_presets[chosen_hparam][1]
                )
                run_cross_attn_mask = gr.Checkbox(
                    label="Apply Inside-Outside Attention to cross-attention layers", value=hparams_presets[chosen_hparam][2]
                )
                run_self_attn_mask = gr.Checkbox(
                    label="Apply Inside-Outside Attention to self-attention layers", value=hparams_presets[chosen_hparam][3]
                )
                guidance_scale = gr.Slider(
                    label="Guidance scale", minimum=0, maximum=15, step=0.1, value=hparams_presets[chosen_hparam][4]
                )
                cross_attn_sched = gr.Slider(
                    label="Up-weight factor of new tokens in edit prompt", minimum=0.0, maximum=10, step=0.5, value=hparams_presets[chosen_hparam][5]
                )
                noise_mixing = gr.Slider(
                    label="Level of random noise to interpolate with initial latent", minimum=0.0, maximum=1.0, step=0.1, value=hparams_presets[chosen_hparam][6]
                )
            with gr.Row():
                reset_btn = gr.Button("Reset", variant="secondary")
                run_btn = gr.Button("Run", variant="primary")
        # Right column: edited result.
        with gr.Column():
            result = gr.Image(label="Result").style(height=image_height)
    gr.HTML(codecs.open("html/footer.html", "r").read())
    # ==========================
    # Event Listeners
    # ==========================
    # Mask selection interactions
    radio.change(fn=activate_mask_upload, inputs=radio, outputs=mask_upload, show_progress=False)
    radio.change(fn=preset_mask_upload, inputs=[dropdown, radio], outputs=mask_upload, show_progress=False)
    mask_btn.click(fn=get_mask, inputs=[radio, image_upload, mask_upload, source_prompt], outputs=mask_upload, show_progress=True)
    # Preset interactions
    dropdown.change(fn=preset_dropdown, inputs=dropdown, outputs=[radio, image_upload, mask_upload, source_prompt, edit_prompt], show_progress=False)
    hparams_dropdown.change(
        fn=preset_hparams_dropdown,
        inputs=hparams_dropdown,
        outputs=[
            run_safety_checker,
            run_inversion,
            run_cross_attn_mask,
            run_self_attn_mask,
            guidance_scale,
            cross_attn_sched,
            noise_mixing
        ],
        show_progress=False
    )
    # Global interactions
    reset_btn.click(
        fn=reset,
        outputs=[
            dropdown,
            hparams_dropdown,
            radio,
            image_upload,
            mask_upload,
            source_prompt,
            edit_prompt,
            run_safety_checker,
            run_inversion,
            run_cross_attn_mask,
            run_self_attn_mask,
            guidance_scale,
            cross_attn_sched,
            noise_mixing
        ],
        show_progress=True
    )
    run_btn.click(
        fn=predict,
        inputs=[
            image_upload,
            mask_upload,
            source_prompt,
            edit_prompt,
            run_safety_checker,
            run_inversion,
            run_cross_attn_mask,
            run_self_attn_mask,
            guidance_scale,
            cross_attn_sched,
            noise_mixing
        ],
        outputs=result,
        show_progress=True
    )
# Demo Launch, server_port=7067
demo.launch()