Tony Lian
Update: add attention guidance and refactor the code
89f6983
raw history blame
No virus
15.9 kB
import gradio as gr
import numpy as np
import os
import matplotlib.pyplot as plt
from utils.parse import filter_boxes, parse_input_with_negative, show_boxes
from generation import run as run_ours
from baseline import run as run_baseline
import torch
from shared import DEFAULT_SO_NEGATIVE_PROMPT, DEFAULT_OVERALL_NEGATIVE_PROMPT
from examples import stage1_examples, stage2_examples, default_template, simplified_prompt, prompt_placeholder, layout_placeholder
# Probe the GPU once at import time. The result picks lighter UI defaults on
# devices with at most 16 GiB of memory.
cuda_available = torch.cuda.is_available()
print(f"Is CUDA available: {cuda_available}")

# Without CUDA there is no memory constraint to react to.
low_memory = False
if cuda_available:
    _device_index = torch.cuda.current_device()
    gpu_memory = torch.cuda.get_device_properties(_device_index).total_memory
    low_memory = gpu_memory <= 16 * 1024 ** 3
    print(f"CUDA device: {torch.cuda.get_device_name(_device_index)}. With GPU memory: {gpu_memory}. Low memory: {low_memory}")

# Passed to every gr.Examples block below.
cache_examples = True

# Fewer denoising steps by default when GPU memory is tight.
default_num_inference_steps = 50 if not low_memory else 20
def get_lmd_prompt(prompt, template=default_template):
    """Build the text the user pastes into ChatGPT for layout generation.

    Args:
        prompt: Image description typed by the user. An empty string is
            replaced by ``prompt_placeholder``.
        template: Instruction template. An empty string is replaced by
            ``default_template``.

    Returns:
        ``simplified_prompt`` formatted with the chosen template and prompt.
    """
    effective_prompt = prompt_placeholder if prompt == "" else prompt
    effective_template = default_template if template == "" else template
    return simplified_prompt.format(template=effective_template, prompt=effective_prompt)
def get_layout_image(response):
    """Render the layout described by ``response`` into an RGB image array.

    Args:
        response: Layout text (the pasted ChatGPT response). An empty string
            is replaced by ``layout_placeholder``.

    Returns:
        A ``numpy.uint8`` array of shape ``(height, width, 3)`` holding the
        rasterised figure.
    """
    if response == "":
        response = layout_placeholder
    gen_boxes, bg_prompt, neg_prompt = parse_input_with_negative(response, no_input=True)

    fig = plt.figure(figsize=(8, 8))
    try:
        # https://stackoverflow.com/questions/7821518/save-plot-to-numpy-array
        show_boxes(gen_boxes, bg_prompt, neg_prompt)
        # The canvas must be drawn before its pixel buffer can be read.
        fig.canvas.draw()
        # NOTE(review): tostring_rgb() is deprecated in recent matplotlib
        # releases; check the pinned version before upgrading.
        data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        # get_width_height() is (width, height); reverse it for row-major
        # image layout and append the RGB channel axis.
        data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    finally:
        # Close (not just clear) the figure: every call creates a new one, and
        # pyplot keeps each figure alive until it is explicitly closed.
        plt.close(fig)
    return data
def get_layout_image_gallery(response):
    """Return the rendered layout image wrapped in a one-element list.

    The list form matches what the gallery output of the "Visualize Layout"
    button receives.
    """
    layout_image = get_layout_image(response)
    return [layout_image]
def get_ours_image(
    response,
    overall_prompt_override="",
    seed=0,
    num_inference_steps=250,
    dpm_scheduler=True,
    use_autocast=False,
    fg_seed_start=20,
    fg_blending_ratio=0.1,
    frozen_step_ratio=0.5,
    attn_guidance_step_ratio=0.6,
    gligen_scheduled_sampling_beta=0.4,
    attn_guidance_scale=20,
    use_ref_ca=True,
    so_negative_prompt=DEFAULT_SO_NEGATIVE_PROMPT,
    overall_negative_prompt=DEFAULT_OVERALL_NEGATIVE_PROMPT,
    show_so_imgs=False,
    scale_boxes=False,
):
    """Generate an image from a layout with the layout-grounded pipeline.

    Args:
        response: Layout text; an empty string falls back to
            ``layout_placeholder``.
        overall_prompt_override: Forwarded unchanged to ``run_ours``.
        seed: Forwarded to ``run_ours`` as ``bg_seed``.
        num_inference_steps: Number of denoising steps.
        dpm_scheduler: Selects scheduler key ``"dpm_scheduler"`` when truthy,
            ``"scheduler"`` otherwise.
        use_autocast, fg_seed_start, fg_blending_ratio, frozen_step_ratio,
        use_ref_ca, so_negative_prompt, overall_negative_prompt: Forwarded
            unchanged to ``run_ours``.
        attn_guidance_step_ratio: Fraction of ``num_inference_steps`` passed
            (truncated to int) as ``overall_max_index_step``.
        gligen_scheduled_sampling_beta: Used for both the single-object and
            overall GLIGEN beta arguments.
        attn_guidance_scale: Used for both ``loss_scale`` and
            ``overall_loss_scale``.
        show_so_imgs: When truthy, append the single-object images to the
            returned list.
        scale_boxes: Forwarded to ``filter_boxes``.

    Returns:
        A list whose first element is the generated image, optionally followed
        by the single-object images converted with ``np.asarray``.
    """
    layout_text = layout_placeholder if response == "" else response
    parsed_boxes, bg_prompt, neg_prompt = parse_input_with_negative(layout_text, no_input=True)
    gen_boxes = filter_boxes(parsed_boxes, scale_boxes=scale_boxes)

    # 'prompt' is unused
    spec = dict(
        prompt='',
        gen_boxes=gen_boxes,
        bg_prompt=bg_prompt,
        extra_neg_prompt=neg_prompt,
    )

    scheduler_key = "dpm_scheduler" if dpm_scheduler else "scheduler"
    overall_max_index_step = int(attn_guidance_step_ratio * num_inference_steps)

    # max_index_step is pinned to 0 here; presumably that turns off attention
    # guidance in the per-box stage -- confirm against generation.run.
    image_np, so_img_list = run_ours(
        spec,
        bg_seed=seed,
        overall_prompt_override=overall_prompt_override,
        fg_seed_start=fg_seed_start,
        fg_blending_ratio=fg_blending_ratio,
        frozen_step_ratio=frozen_step_ratio,
        use_autocast=use_autocast,
        so_gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
        overall_gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
        num_inference_steps=num_inference_steps,
        scheduler_key=scheduler_key,
        use_ref_ca=use_ref_ca,
        so_negative_prompt=so_negative_prompt,
        overall_negative_prompt=overall_negative_prompt,
        loss_scale=attn_guidance_scale,
        max_index_step=0,
        overall_loss_scale=attn_guidance_scale,
        overall_max_index_step=overall_max_index_step,
    )

    images = [image_np]
    if show_so_imgs:
        for so_img in so_img_list:
            images.append(np.asarray(so_img))

    if cuda_available:
        # Report the peak for this request, then reset so the next request
        # starts from a clean counter.
        peak_gib = torch.cuda.max_memory_allocated() / 1024 ** 3
        print(f"Max GPU memory allocated: {peak_gib:.2f} GB")
        torch.cuda.reset_max_memory_allocated()

    return images
def get_baseline_image(prompt, seed=0):
    """Generate an image with the baseline pipeline for comparison.

    Args:
        prompt: Text prompt; an empty string falls back to
            ``prompt_placeholder``.
        seed: Forwarded to ``run_baseline`` as ``bg_seed``.

    Returns:
        A one-element list containing the generated image.
    """
    effective_prompt = prompt_placeholder if prompt == "" else prompt
    # The baseline always uses the DPM scheduler with a fixed 20 steps.
    image_np = run_baseline(
        effective_prompt,
        bg_seed=seed,
        scheduler_key="dpm_scheduler",
        num_inference_steps=20,
    )
    return [image_np]
# Badge linking to the "duplicate this Space" action; embedded in tip 5 below.
duplicate_html = '<a style="display:inline-block" href="https://huggingface.co/spaces/longlian/llm-grounded-diffusion?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a>'

# Header markup rendered at the top of the demo. This is an f-string, so the
# literal braces of the CSS rule are doubled. The markup fixes relative to the
# previous text: the "Tips" paragraph and the <style> element are now closed
# with their own closing tags, and the heading emoji is restored.
html = f"""<h1>LLM-grounded Diffusion: Enhancing Prompt Understanding of Text-to-Image Diffusion Models with Large Language Models</h1>
<h2>LLM + Stable Diffusion => better prompt understanding in text2image generation 🤩</h2>
<h2><a href='https://llm-grounded-diffusion.github.io/'>Project Page</a> | <a href='https://bair.berkeley.edu/blog/2023/05/23/lmd/'>5-minute Blog Post</a> | <a href='https://arxiv.org/pdf/2305.13655.pdf'>ArXiv Paper</a> | <a href='https://github.com/TonyLianLong/LLM-groundedDiffusion'>Github</a> | <a href='https://llm-grounded-diffusion.github.io/#citation'>Cite our work</a> if our ideas inspire you.</h2>
<p><b>Tips:</b></p>
<p>1. If ChatGPT doesn't generate layout, add/remove the trailing space (added by default) and/or use GPT-4.</p>
<p>2. You can perform multi-round specification by giving ChatGPT follow-up requests (e.g., make the objects bigger or move the objects).</p>
<p>3. You can also try prompts in Simplified Chinese. You need to leave "prompt for overall image" empty in this case. If you want to try prompts in another language, translate the first line of last example to your language.</p>
<p>4. The diffusion model only runs 20 steps by default in this demo. You can make it run more steps to get higher quality images (or tweak frozen steps/guidance steps for better guidance and coherence).</p>
<p>5. Duplicate this space and add GPU or clone the space and run locally to skip the queue and run our model faster. (<b>Currently we are using a T4 GPU on this space, which is quite slow, and you can add a A10G to make it 5x faster</b>) {duplicate_html}</p>
<br/>
<p>Implementation note (updated): In this demo, we provide a few modes: faster generation by disabling attention/per-box guidance. The standard version describes what is implemented for the paper. You can set GLIGEN guidance steps ratio to 0 to disable GLIGEN and use only the original SD weights.</p>
<style>.btn {{flex-grow: unset !important;}} </style>
"""
def preset_change(preset):
    """Translate a guidance preset name into updates for the advanced controls.

    Args:
        preset: One of the four labels offered by the preset radio button.

    Returns:
        A 5-tuple of ``gr.update`` results, in this order:
        frozen_step_ratio, attn_guidance_step_ratio, attn_guidance_scale,
        use_ref_ca, so_negative_prompt.

    Raises:
        gr.Error: If ``preset`` is not a known preset name.
    """
    # Each row lists the gr.update keyword arguments for, in order:
    # frozen_step_ratio, attn_guidance_step_ratio, attn_guidance_scale, use_ref_ca, so_negative_prompt
    update_kwargs_by_preset = {
        "Standard": (
            {"value": 0.5, "interactive": True},
            {"value": 0.6, "interactive": True},
            {"interactive": True},
            {"value": True, "interactive": True},
            {"interactive": True},
        ),
        "Faster (disable attention guidance)": (
            {"value": 0.5, "interactive": True},
            {"value": 0, "interactive": False},
            {"interactive": False},
            {"value": True, "interactive": True},
            {"interactive": True},
        ),
        "Faster (disable per-box guidance)": (
            {"value": 0, "interactive": False},
            {"value": 0.6, "interactive": True},
            {"interactive": True},
            {"value": False, "interactive": False},
            {"interactive": False},
        ),
        # NOTE(review): so_negative_prompt stays interactive here although the
        # other per-box-disabled preset locks it -- confirm this is intended.
        "Fastest (disable both)": (
            {"value": 0, "interactive": False},
            {"value": 0, "interactive": False},
            {"interactive": False},
            {"value": False, "interactive": False},
            {"interactive": True},
        ),
    }

    update_kwargs = update_kwargs_by_preset.get(preset)
    if update_kwargs is None:
        raise gr.Error(f"Unknown preset {preset}")
    return tuple(gr.update(**kwargs) for kwargs in update_kwargs)
# Three-tab Gradio demo: (1) build the ChatGPT prompt, (2) turn a layout into
# an image, (3) plain Stable Diffusion baseline.
# NOTE(review): the nesting of the layout containers below follows the order
# in which the widgets are declared; confirm widget placement in the rendered
# UI.
with gr.Blocks(
    title="LLM-grounded Diffusion: Enhancing Prompt Understanding of Text-to-Image Diffusion Models with Large Language Models"
) as g:
    gr.HTML(html)

    # ---- Stage 1: text prompt -> text to paste into ChatGPT ----
    with gr.Tab("Stage 1. Image Prompt to ChatGPT"):
        with gr.Row():
            with gr.Column(scale=1):
                prompt = gr.Textbox(lines=2, label="Prompt for Layout Generation", placeholder=prompt_placeholder)
                generate_btn = gr.Button("Generate Prompt", variant='primary', elem_classes="btn")
                with gr.Accordion("Advanced options", open=False):
                    template = gr.Textbox(lines=10, label="Custom Template", placeholder="Customized Template", value=default_template)
            with gr.Column(scale=1):
                output = gr.Textbox(label="Paste this into ChatGPT (GPT-4 preferred; on Mac, click text and press Command+A and Command+C to copy all)", show_copy_button=True)
                gr.HTML("<a href='https://chat.openai.com' target='_blank'>Click here to open ChatGPT</a>")
        generate_btn.click(fn=get_lmd_prompt, inputs=[prompt, template], outputs=output, api_name="get_lmd_prompt")

        # Examples supply only the prompt, so get_lmd_prompt runs with its
        # default template.
        gr.Examples(
            examples=stage1_examples,
            inputs=[prompt],
            outputs=[output],
            fn=get_lmd_prompt,
            cache_examples=cache_examples,
            label="example_stage1"
        )

    # ---- Stage 2: ChatGPT layout response -> generated image ----
    with gr.Tab("Stage 2 (New). Layout to Image generation"):
        with gr.Row():
            with gr.Column(scale=1):
                overall_prompt_override = gr.Textbox(lines=2, label="Prompt for the overall image (optional but recommended)", placeholder="You can put your input prompt for layout generation here, helpful if your scene cannot be represented by background prompt and boxes only, e.g., with object interactions. If left empty: background prompt with [objects].", value="")
                response = gr.Textbox(lines=8, label="Paste ChatGPT response here (no original caption needed here)", placeholder=layout_placeholder)
                num_inference_steps = gr.Slider(1, 100 if low_memory else 250, value=default_num_inference_steps, step=1, label="Number of denoising steps (set to >=50 for higher generation quality)")
                # Default to a faster preset when the GPU is low on memory.
                preset = gr.Radio(label="Guidance: apply less control for faster generation", choices=["Standard", "Faster (disable attention guidance)", "Faster (disable per-box guidance)", "Fastest (disable both)"], value="Faster (disable attention guidance)" if low_memory else "Standard")
                seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")
                with gr.Accordion("Advanced options (play around for better generation)", open=False):
                    with gr.Tab("Guidance"):
                        frozen_step_ratio = gr.Slider(0, 1, value=0.5, step=0.1, label="Foreground frozen steps ratio (higher: stronger attribute binding; lower: higher coherence")
                        gligen_scheduled_sampling_beta = gr.Slider(0, 1, value=0.4, step=0.1, label="GLIGEN guidance steps ratio (the beta value, higher: stronger GLIGEN guidance)")
                        attn_guidance_step_ratio = gr.Slider(0, 1, value=0.6, step=0.01, label="Attention guidance steps ratio (higher: stronger attention guidance; lower: faster and higher coherence")
                        attn_guidance_scale = gr.Slider(0, 50, value=20, step=0.5, label="Attention guidance scale: 0 means no attention guidance.")
                        use_ref_ca = gr.Checkbox(label="Using per-box attention to guide reference attention", show_label=False, value=True)
                    with gr.Tab("Generation"):
                        dpm_scheduler = gr.Checkbox(label="Use DPM scheduler (unchecked: DDIM scheduler, may have better coherence, recommend >=50 inference steps)", show_label=False, value=True)
                        # The conditional suffix must be parenthesised: a
                        # conditional expression binds more loosely than `+`,
                        # so without the parentheses the whole label became ""
                        # whenever low_memory was False.
                        use_autocast = gr.Checkbox(label="Use FP16 Mixed Precision (faster but with slightly lower quality)" + (" [enabled due to low GPU memory]" if low_memory else ""), show_label=False, value=True, interactive=not low_memory)
                        fg_seed_start = gr.Slider(0, 10000, value=20, step=1, label="Seed for foreground variation")
                        fg_blending_ratio = gr.Slider(0, 1, value=0.1, step=0.01, label="Variations added to foreground for single object generation (0: no variation, 1: max variation)")
                        scale_boxes = gr.Checkbox(label="Scale bounding boxes to just fit the scene", show_label=False, value=False)
                        so_negative_prompt = gr.Textbox(lines=1, label="Negative prompt for single object generation", value=DEFAULT_SO_NEGATIVE_PROMPT)
                        overall_negative_prompt = gr.Textbox(lines=1, label="Negative prompt for overall generation", value=DEFAULT_OVERALL_NEGATIVE_PROMPT)
                        show_so_imgs = gr.Checkbox(label="Show annotated single object generations", show_label=False, value=False)
                visualize_btn = gr.Button("Visualize Layout", elem_classes="btn")
                generate_btn = gr.Button("Generate Image from Layout", variant='primary', elem_classes="btn")
            with gr.Column(scale=1):
                gallery = gr.Gallery(
                    label="Generated image", show_label=False, elem_id="gallery", columns=[1], rows=[1], object_fit="contain", preview=True
                )
        # The output order here must match the tuple returned by preset_change.
        preset.change(preset_change, [preset], [frozen_step_ratio, attn_guidance_step_ratio, attn_guidance_scale, use_ref_ca, so_negative_prompt])
        # Mirror the stage-1 prompt into the overall-prompt box on the client
        # side (no Python callback).
        prompt.change(None, [prompt], overall_prompt_override, _js="(x) => x")
        visualize_btn.click(fn=get_layout_image_gallery, inputs=response, outputs=gallery, api_name="visualize-layout")
        # The input order here must match the positional parameters of
        # get_ours_image.
        generate_btn.click(fn=get_ours_image, inputs=[response, overall_prompt_override, seed, num_inference_steps, dpm_scheduler, use_autocast, fg_seed_start, fg_blending_ratio, frozen_step_ratio, attn_guidance_step_ratio, gligen_scheduled_sampling_beta, attn_guidance_scale, use_ref_ca, so_negative_prompt, overall_negative_prompt, show_so_imgs, scale_boxes], outputs=gallery, api_name="layout-to-image")

        # Examples pass only three inputs, so every other get_ours_image
        # parameter takes its signature default.
        gr.Examples(
            examples=stage2_examples,
            inputs=[response, overall_prompt_override, seed],
            outputs=[gallery],
            fn=get_ours_image,
            cache_examples=cache_examples,
            label="example_ours"
        )

    # ---- Baseline: plain Stable Diffusion for comparison ----
    # seed, generate_btn and gallery are rebound for this tab; the stage-2
    # handlers above were registered on the earlier components.
    with gr.Tab("Baseline: Stable Diffusion"):
        with gr.Row():
            with gr.Column(scale=1):
                sd_prompt = gr.Textbox(lines=2, label="Prompt for baseline SD", placeholder=prompt_placeholder)
                seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")
                generate_btn = gr.Button("Generate", elem_classes="btn")
            with gr.Column(scale=1):
                gallery = gr.Gallery(
                    label="Generated image", show_label=False, elem_id="gallery2", columns=[1], rows=[1], object_fit="contain", preview=True
                )
        generate_btn.click(fn=get_baseline_image, inputs=[sd_prompt, seed], outputs=gallery, api_name="baseline")

        gr.Examples(
            examples=stage1_examples,
            inputs=[sd_prompt],
            outputs=[gallery],
            fn=get_baseline_image,
            cache_examples=cache_examples,
            label="example_sd"
        )

g.launch()