# layout-guidance / app.py
import gradio as gr
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, LMSDiscreteScheduler
from my_model import unet_2d_condition
import json
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from functools import partial
import math
from utils import compute_ca_loss
from gradio import processing_utils
from typing import Optional
import spaces
import warnings
import sys
sys.tracebacklimit = 0
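# gr.Blocks subclass that carries a few extra metadata fields (thumbnail, url, creator)
# and injects them into the config served to the frontend.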
class Blocks(gr.Blocks):
def __init__(
self,
theme: str = "default",
analytics_enabled: Optional[bool] = None,
mode: str = "blocks",
title: str = "Gradio",
css: Optional[str] = None,
**kwargs,
):
self.extra_configs = {
'thumbnail': kwargs.pop('thumbnail', ''),
'url': kwargs.pop('url', 'https://gradio.app/'),
'creator': kwargs.pop('creator', '@teamGradio'),
}
        super().__init__(theme=theme, analytics_enabled=analytics_enabled, mode=mode, title=title, css=css, **kwargs)
warnings.filterwarnings("ignore")
def get_config_file(self):
config = super(Blocks, self).get_config_file()
for k, v in self.extra_configs.items():
config[k] = v
return config
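# Draw the recorded bounding boxes and their text labels onto `img`;
# a blank white 512x512 canvas is created when no image is given.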
def draw_box(boxes=[], texts=[], img=None):
if len(boxes) == 0 and img is None:
return None
if img is None:
img = Image.new('RGB', (512, 512), (255, 255, 255))
colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
draw = ImageDraw.Draw(img)
font = ImageFont.truetype("DejaVuSansMono.ttf", size=18)
print(boxes)
for bid, box in enumerate(boxes):
draw.rectangle([box[0], box[1], box[2], box[3]], outline=colors[bid % len(colors)], width=4)
anno_text = texts[bid]
draw.rectangle(
[box[0], box[3] - int(font.size * 1.2), box[0] + int((len(anno_text) + 0.8) * font.size * 0.6), box[3]],
outline=colors[bid % len(colors)], fill=colors[bid % len(colors)], width=4)
draw.text([box[0] + int(font.size * 0.2), box[3] - int(font.size * 1.2)], anno_text, font=font,
fill=(255, 255, 255))
return img
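# Tile a list of images into a grid with at most two columns.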
def get_concat(ims):
if len(ims) == 1:
n_col = 1
else:
n_col = 2
n_row = math.ceil(len(ims) / 2)
dst = Image.new('RGB', (ims[0].width * n_col, ims[0].height * n_row), color="white")
for i, im in enumerate(ims):
row_id = i // n_col
col_id = i % n_col
dst.paste(im, (im.width * col_id, im.height * row_id))
return dst
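# Turn a sketch into a binary mask: every non-zero pixel becomes 255.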
def binarize(x):
return (x != 0).astype('uint8') * 255
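# Center-crop / fill / mask helpers operating on HxW(xC) numpy arrays.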
def sized_center_crop(img, cropx, cropy):
y, x = img.shape[:2]
startx = x // 2 - (cropx // 2)
starty = y // 2 - (cropy // 2)
return img[starty:starty + cropy, startx:startx + cropx]
def sized_center_fill(img, fill, cropx, cropy):
y, x = img.shape[:2]
startx = x // 2 - (cropx // 2)
starty = y // 2 - (cropy // 2)
img[starty:starty + cropy, startx:startx + cropx] = fill
return img
def sized_center_mask(img, cropx, cropy):
y, x = img.shape[:2]
startx = x // 2 - (cropx // 2)
starty = y // 2 - (cropy // 2)
center_region = img[starty:starty + cropy, startx:startx + cropx].copy()
img = (img * 0.2).astype('uint8')
img[starty:starty + cropy, startx:startx + cropx] = center_region
return img
def center_crop(img, HW=None, tgt_size=(512, 512)):
if HW is None:
H, W = img.shape[:2]
HW = min(H, W)
img = sized_center_crop(img, HW, HW)
img = Image.fromarray(img)
img = img.resize(tgt_size)
return np.array(img)
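# Sketch-pad change handler: diff the current mask against the previously stored one,
# turn a newly drawn region into a bounding box, record it in `state`, and return the
# annotated preview image.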
def draw(input, grounding_texts, new_image_trigger, state):
if type(input) == dict:
# import pdb; pdb.set_trace()
# image = input['composite']
mask = input['composite']
else:
mask = input
if mask.ndim == 3:
mask = 255 - mask[..., 0]
image_scale = 1.0
mask = binarize(mask)
if type(mask) != np.ndarray:
mask = np.array(mask)
if mask.sum() == 0:
state = {}
image = None
if 'boxes' not in state:
state['boxes'] = []
if 'masks' not in state or len(state['masks']) == 0:
state['masks'] = []
last_mask = np.zeros_like(mask)
else:
last_mask = state['masks'][-1]
if type(mask) == np.ndarray and mask.size > 1:
diff_mask = mask - last_mask
else:
diff_mask = np.zeros([])
if diff_mask.sum() > 0:
x1x2 = np.where(diff_mask.max(0) != 0)[0]
y1y2 = np.where(diff_mask.max(1) != 0)[0]
y1, y2 = y1y2.min(), y1y2.max()
x1, x2 = x1x2.min(), x1x2.max()
if (x2 - x1 > 5) and (y2 - y1 > 5):
state['masks'].append(mask.copy())
state['boxes'].append((x1, y1, x2, y2))
grounding_texts = [x.strip() for x in grounding_texts.split(';')]
grounding_texts = [x for x in grounding_texts if len(x) > 0]
if len(grounding_texts) < len(state['boxes']):
grounding_texts += [f'Obj. {bid + 1}' for bid in range(len(grounding_texts), len(state['boxes']))]
box_image = draw_box(state['boxes'], grounding_texts, image)
return [box_image, new_image_trigger, image_scale, state]
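# Reset the sketch pad, preview image, generated image, and state.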
def clear(sketch_pad_trigger, batch_size, state, switch_task=False):
sketch_pad_trigger = sketch_pad_trigger + 1
blank_samples = batch_size % 2 if batch_size > 1 else 0
out_images = [None]
# state = {}
return [None, sketch_pad_trigger, None, 1.0] + out_images + [{}]
def main():
css = """
#img2img_image, #img2img_image > .fixed-height, #img2img_image > .fixed-height > div, #img2img_image > .fixed-height > div > img
{
height: var(--height) !important;
max-height: var(--height) !important;
min-height: var(--height) !important;
}
#paper-info a {
color:#008AD7;
text-decoration: none;
}
#paper-info a:hover {
cursor: pointer;
text-decoration: none;
}
.tooltip {
color: #555;
position: relative;
display: inline-block;
cursor: pointer;
}
.tooltip .tooltiptext {
visibility: hidden;
width: 400px;
background-color: #555;
color: #fff;
text-align: center;
padding: 5px;
border-radius: 5px;
position: absolute;
z-index: 1; /* Set z-index to 1 */
left: 10px;
top: 100%;
opacity: 0;
transition: opacity 0.3s;
}
.tooltip:hover .tooltiptext {
visibility: visible;
opacity: 1;
z-index: 9999; /* Set a high z-index value when hovering */
}
"""
rescale_js = """
function(x) {
const root = document.querySelector('gradio-app').shadowRoot || document.querySelector('gradio-app');
let image_scale = parseFloat(root.querySelector('#image_scale input').value) || 1.0;
const image_width = root.querySelector('#img2img_image').clientWidth;
const target_height = parseInt(image_width * image_scale);
document.body.style.setProperty('--height', `${target_height}px`);
root.querySelectorAll('button.justify-center.rounded')[0].style.display='none';
root.querySelectorAll('button.justify-center.rounded')[1].style.display='none';
return x;
}
"""
with open('./conf/unet/config.json') as f:
unet_config = json.load(f)
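    # Load the modified UNet (it also returns the cross-attention maps) with the
    # Stable Diffusion v1.5 weights, plus the matching tokenizer, text encoder, and VAE.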
unet = unet_2d_condition.UNet2DConditionModel(**unet_config).from_pretrained('runwayml/stable-diffusion-v1-5',
subfolder="unet")
tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
unet.to(device)
text_encoder.to(device)
vae.to(device)
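    # Generate button handler: parse the grounding phrases, normalize the drawn boxes to
    # [0, 1] coordinates, locate each phrase's word positions in the prompt, and run inference.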
def generate(unet, vae, tokenizer, text_encoder, language_instruction, grounding_texts, sketch_pad,
loss_threshold, guidance_scale, batch_size, rand_seed, max_step, loss_scale, max_iter,
state):
if 'boxes' not in state:
state['boxes'] = []
boxes = state['boxes']
grounding_texts = [x.strip() for x in grounding_texts.split(';')]
# assert len(boxes) == len(grounding_texts)
if len(boxes) != len(grounding_texts):
if len(boxes) < len(grounding_texts):
raise ValueError("""The number of boxes should be equal to the number of grounding objects.
Number of boxes drawn: {}, number of grounding tokens: {}.
Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(grounding_texts)))
grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))
boxes = (np.asarray(boxes) / 512).tolist()
boxes = [[box] for box in boxes]
grounding_instruction = json.dumps({obj: box for obj, box in zip(grounding_texts, boxes)})
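        # For each grounding phrase, record the positions of its words in the prompt;
        # the +1 offsets for the start-of-text token (this assumes each word maps to a single token).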
language_instruction_list = language_instruction.strip('.').split(' ')
object_positions = []
for obj in grounding_texts:
obj_position = []
for word in obj.split(' '):
obj_first_index = language_instruction_list.index(word) + 1
obj_position.append(obj_first_index)
object_positions.append(obj_position)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen_images = inference(device, unet, vae, tokenizer, text_encoder, language_instruction, boxes,
object_positions, batch_size, loss_scale, loss_threshold, max_iter, max_step, rand_seed,
guidance_scale)
blank_samples = batch_size % 2 if batch_size > 1 else 0
        # Pad the list to four image slots with blank placeholders.
        gen_images = list(gen_images) \
                     + [None for _ in range(blank_samples)] \
                     + [None for _ in range(4 - batch_size - blank_samples)]
return gen_images + [state]
    # Inference model
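    # Layout-guided sampling: alternate backward guidance on the cross-attention maps
    # with standard classifier-free-guidance denoising steps.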
@spaces.GPU(duration=180)
def inference(device, unet, vae, tokenizer, text_encoder, prompt, bboxes, object_positions, batch_size, loss_scale,
loss_threshold, max_iter, max_index_step, rand_seed, guidance_scale):
uncond_input = tokenizer(
[""] * 1, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
)
uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
input_ids = tokenizer(
prompt,
padding="max_length",
truncation=True,
max_length=tokenizer.model_max_length,
return_tensors="pt",
).input_ids[0].unsqueeze(0).to(device)
# text_embeddings = text_encoder(input_ids)[0]
text_embeddings = torch.cat([uncond_embeddings, text_encoder(input_ids)[0]])
# text_embeddings[1, 1, :] = text_embeddings[1, 2, :]
        generator = torch.manual_seed(rand_seed)  # Seed generator to create the initial latent noise
latents = torch.randn(
(batch_size, 4, 64, 64),
generator=generator,
).to(device)
noise_scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
num_train_timesteps=1000)
# generator = torch.Generator("cuda").manual_seed(1024)
noise_scheduler.set_timesteps(51)
latents = latents * noise_scheduler.init_noise_sigma
loss = torch.tensor(10000)
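        # Denoising loop. Within the first `max_index_step` steps, repeatedly nudge the latents
        # along the gradient of the cross-attention loss (backward guidance) before taking the
        # regular scheduler step.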
for index, t in enumerate(noise_scheduler.timesteps):
iteration = 0
while loss.item() / loss_scale > loss_threshold and iteration < max_iter and index < max_index_step:
latents = latents.requires_grad_(True)
# latent_model_input = torch.cat([latents] * 2)
latent_model_input = latents
latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t)
noise_pred, attn_map_integrated_up, attn_map_integrated_mid, attn_map_integrated_down = \
unet(latent_model_input, t, encoder_hidden_states=text_encoder(input_ids)[0])
                # Update latents with gradients from the cross-attention (layout) guidance loss
loss = compute_ca_loss(attn_map_integrated_mid, attn_map_integrated_up, bboxes=bboxes,
object_positions=object_positions) * loss_scale
print(loss.item() / loss_scale)
grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents])[0]
latents = latents - grad_cond * noise_scheduler.sigmas[index] ** 2
iteration += 1
torch.cuda.empty_cache()
torch.cuda.empty_cache()
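            # Regular denoising update with classifier-free guidance.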
with torch.no_grad():
latent_model_input = torch.cat([latents] * 2)
latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t)
noise_pred, attn_map_integrated_up, attn_map_integrated_mid, attn_map_integrated_down = \
unet(latent_model_input, t, encoder_hidden_states=text_embeddings)
noise_pred = noise_pred.sample
# perform classifier-free guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
torch.cuda.empty_cache()
# Decode image
with torch.no_grad():
# print("decode image")
latents = 1 / 0.18215 * latents
image = vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
return pil_images
with Blocks(
css=css,
analytics_enabled=False,
title="Layout-Guidance demo",
) as demo:
description = """<p style="text-align: center; font-weight: bold;">
<span style="font-size: 28px">Layout Guidance</span>
<br>
<span style="font-size: 18px" id="paper-info">
[<a href=" " target="_blank">Project Page</a>]
[<a href=" " target="_blank">Paper</a>]
[<a href=" " target="_blank">GitHub</a>]
</span>
</p>
"""
gr.HTML(description)
with gr.Column():
language_instruction = gr.Textbox(
label="Text Prompt",
)
grounding_instruction = gr.Textbox(
label="Grounding instruction (Separated by semicolon)",
)
sketch_pad_trigger = gr.Number(value=0, visible=False)
sketch_pad_resize_trigger = gr.Number(value=0, visible=False)
init_white_trigger = gr.Number(value=0, visible=False)
image_scale = gr.Number(value=0, elem_id="image_scale", visible=False)
new_image_trigger = gr.Number(value=0, visible=False)
with gr.Row():
sketch_pad = gr.Paint(label="Sketch Pad", container=False, layers=False, scale=1, elem_id="img2img_image", canvas_size=(512,512))
out_imagebox = gr.Image(type="pil", label="Parsed Sketch Pad")
out_gen_1 = gr.Image(type="pil", visible=True, label="Generated Image")
with gr.Row():
clear_btn = gr.Button(value='Clear')
gen_btn = gr.Button(value='Generate')
with gr.Accordion("Advanced Options", open=False):
with gr.Column():
description = """<div class="tooltip">Loss Scale Factor &#9432
<span class="tooltiptext">The scale factor of the backward guidance loss. The larger it is, the better control we get while it sometimes losses fidelity. </span>
</div>
<div class="tooltip">Guidance Scale &#9432
<span class="tooltiptext">The scale factor of classifier-free guidance. </span>
</div>
<div class="tooltip" >Max Iteration per Step &#9432
<span class="tooltiptext">The max iterations of backward guidance in each diffusion inference process.</span>
</div>
<div class="tooltip" >Loss Threshold &#9432
<span class="tooltiptext">The threshold of loss. If the loss computed by cross-attention map is smaller then the threshold, the backward guidance is stopped. </span>
</div>
<div class="tooltip" >Max Step of Backward Guidance &#9432
<span class="tooltiptext">The max steps of backward guidance in diffusion inference process.</span>
</div>
"""
gr.HTML(description)
Loss_scale = gr.Slider(minimum=0, maximum=500, step=5, value=30,label="Loss Scale Factor")
guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Guidance Scale")
batch_size = gr.Slider(minimum=1, maximum=4, step=1, value=1, label="Number of Samples", visible=False)
max_iter = gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Max Iteration per Step")
loss_threshold = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Loss Threshold")
max_step = gr.Slider(minimum=0, maximum=50, step=1, value=10, label="Max Step of Backward Guidance")
rand_seed = gr.Slider(minimum=0, maximum=1000, step=1, value=445, label="Random Seed")
state = gr.State({})
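        # Small helper used by the UI callbacks, mainly to (re)initialize the sketch pad
        # with a white 512x512 canvas.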
class Controller:
def __init__(self):
self.calls = 0
self.tracks = 0
self.resizes = 0
self.scales = 0
def init_white(self, init_white_trigger):
self.calls += 1
return np.ones((512, 512), dtype='uint8') * 255, 1.0, init_white_trigger + 1
def change_n_samples(self, n_samples):
blank_samples = n_samples % 2 if n_samples > 1 else 0
                return [gr.update(visible=True) for _ in range(n_samples + blank_samples)] \
                    + [gr.update(visible=False) for _ in range(4 - n_samples - blank_samples)]
controller = Controller()
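        # Wire up the UI events: initialize the white canvas, re-parse boxes whenever the
        # sketch or the grounding text changes, and hook up the Clear / Generate buttons.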
demo.load(
lambda x: x + 1,
inputs=sketch_pad_trigger,
outputs=sketch_pad_trigger,
queue=False)
sketch_pad.change(
draw,
inputs=[sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
concurrency_limit=1,
queue=False,
)
grounding_instruction.change(
draw,
inputs=[sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
concurrency_limit=1,
queue=False,
)
clear_btn.click(
clear,
            inputs=[sketch_pad_trigger, batch_size, state],
outputs=[sketch_pad, sketch_pad_trigger, out_imagebox, image_scale, out_gen_1, state],
concurrency_limit=1,
queue=False)
sketch_pad_trigger.change(
controller.init_white,
inputs=[init_white_trigger],
outputs=[sketch_pad, image_scale, init_white_trigger],
concurrency_limit=1,
queue=False)
gen_btn.click(
fn=partial(generate, unet, vae, tokenizer, text_encoder),
inputs=[
language_instruction, grounding_instruction, sketch_pad,
loss_threshold, guidance_scale, batch_size, rand_seed,
max_step,
Loss_scale, max_iter,
state,
],
outputs=[out_gen_1, state],
concurrency_limit=1,
queue=True
)
sketch_pad_resize_trigger.change(
None,
None,
sketch_pad_resize_trigger,
js=rescale_js,
concurrency_limit=1,
queue=False)
init_white_trigger.change(
None,
None,
init_white_trigger,
js=rescale_js,
concurrency_limit=1,
queue=False)
with gr.Column():
gr.Examples(
examples=[
[
# "images/input.png",
"A hello kitty toy is playing with a purple ball.",
"hello kitty;ball",
"images/hello_kitty_results.png"
],
],
inputs=[language_instruction, grounding_instruction, out_gen_1],
outputs=None,
fn=None,
cache_examples=False,
)
description = """<p> The source codes of the demo are modified based on the <a href="https://huggingface.co/spaces/gligen/demo/tree/main">GlIGen</a>. Thanks! </p>"""
gr.HTML(description)
demo.queue(api_open=False)
demo.launch(share=False, show_api=False, show_error=True)
if __name__ == '__main__':
main()