import spaces
import gradio as gr
import torch
import nltk
import numpy as np
from PIL import Image, ImageDraw
from diffusers import DDIMScheduler
from diffusers.models.attention_processor import AttnProcessor2_0
from pipeline_stable_diffusion_xl_opt import StableDiffusionXLPipeline
from injection_utils import register_attention_editor_diffusers
from bounded_attention import BoundedAttention
from pytorch_lightning import seed_everything
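# Constants: model locations, layout-canvas settings, and example prompts.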
REMOTE_MODEL_PATH = "stabilityai/stable-diffusion-xl-base-1.0"
LOCAL_MODEL_PATH = "./model"
RESOLUTION = 256
MIN_SIZE = 0.01
WHITE = 255
COLORS = ["red", "blue", "green", "orange", "purple", "turquoise", "olive"]
PROMPT1 = "a ginger kitten and a gray puppy in a yard"
SUBJECT_SUB_PROMPTS1 = "ginger kitten;gray puppy"
SUBJECT_TOKEN_INDICES1 = "2,3;6,7"
FILTER_TOKEN_INDICES1 = "1,4,5,8,9"
NUM_TOKENS1 = "10"
PROMPT2 = "3 D Pixar animation of a cute unicorn and a pink hedgehog and a nerdy owl traveling in a magical forest"
PROMPT3 = "science fiction movie poster with an astronaut and a robot and a green alien and a spaceship"
PROMPT4 = "a realistic photo of a highway with a semi trailer and a concrete mixer and a helicopter"
PROMPT5 = "a golden retriever and a german shepherd and a boston terrier and an english bulldog and a border collie in a pool"
EXAMPLE_BOXES = {
PROMPT1: [
[0.15, 0.2, 0.45, 0.9],
[0.55, 0.25, 0.85, 0.95],
],
PROMPT2 : [
[0.35, 0.4, 0.65, 0.9],
[0, 0.6, 0.3, 0.9],
[0.7, 0.55, 1, 0.85]
],
PROMPT3: [
[0.4, 0.45, 0.6, 0.95],
[0.2, 0.3, 0.4, 0.85],
[0.6, 0.3, 0.8, 0.85],
[0.1, 0, 0.9, 0.3]
],
PROMPT4: [
[0.05, 0.5, 0.45, 0.85],
[0.55, 0.6, 0.95, 0.85],
[0.3, 0.2, 0.7, 0.45],
],
PROMPT5: [
[0, 0.5, 0.2, 0.8],
[0.2, 0.2, 0.4, 0.5],
[0.4, 0.5, 0.6, 0.8],
[0.6, 0.2, 0.8, 0.5],
[0.8, 0.5, 1, 0.8]
],
}
CSS = """
#paper-info a {
color:#008AD7;
text-decoration: none;
}
#paper-info a:hover {
cursor: pointer;
text-decoration: none;
}
.tooltip {
color: #555;
position: relative;
display: inline-block;
cursor: pointer;
}
.tooltip .tooltiptext {
visibility: hidden;
width: 400px;
background-color: #555;
color: #fff;
text-align: center;
padding: 5px;
border-radius: 5px;
position: absolute;
z-index: 1; /* Set z-index to 1 */
left: 10px;
top: 100%;
opacity: 0;
transition: opacity 0.3s;
}
.tooltip:hover .tooltiptext {
visibility: visible;
opacity: 1;
z-index: 9999; /* Set a high z-index value when hovering */
}
"""
DESCRIPTION = """
<p style="text-align: center; font-weight: bold;">
<span style="font-size: 28px">Bounded Attention</span>
<br>
<span style="font-size: 18px" id="paper-info">
[<a href="https://omer11a.github.io/bounded-attention/" target="_blank">Project Page</a>]
[<a href="https://arxiv.org/abs/2403.16990" target="_blank">Paper</a>]
[<a href="https://github.com/omer11a/bounded-attention" target="_blank">GitHub</a>]
</span>
</p>
"""
COPY_LINK = """
<a href="https://huggingface.co/spaces/omer11a/bounded-attention?duplicate=true">
<img src="https://bit.ly/3gLdBN6" alt="Duplicate Space">
</a>
Duplicate this Space to generate more samples without waiting in the queue.
<br>
To get better results, increase the number of guidance steps to 15.
"""
ADVANCED_OPTION_DESCRIPTION = """
<div class="tooltip" >Number of guidance steps &#9432
<span class="tooltiptext">The number of timesteps in which to perform guidance. Recommended value is 15, but increasing this will also increases the runtime.</span>
</div>
<div class="tooltip">Batch size &#9432
<span class="tooltiptext">The number of images to generate.</span>
</div>
<div class="tooltip">Initial step size &#9432
<span class="tooltiptext">The initial step size of the linear step size scheduler when performing guidance.</span>
</div>
<div class="tooltip">Final step size &#9432
<span class="tooltiptext">The final step size of the linear step size scheduler when performing guidance.</span>
</div>
<div class="tooltip">First refinement step &#9432
<span class="tooltiptext">The timestep from which subject mask refinement is performed.</span>
</div>
<div class="tooltip">Number of self-attention clusters per subject &#9432
<span class="tooltiptext">The number of clusters computed when clustering the self-attention maps (#clusters = #subject x #clusters_per_subject). Changing this value might improve semantics (adherence to the prompt), especially when the subjects exceed their bounding boxes.</span>
</div>
<div class="tooltip">Cross-attention loss scale factor &#9432
<span class="tooltiptext">The scale factor of the cross-attention loss term. Increasing it will improve semantic control (adherence to the prompt), but may reduce image quality.</span>
</div>
<div class="tooltip">Self-attention loss scale factor &#9432
<span class="tooltiptext">The scale factor of the self-attention loss term. Increasing it will improve layout control (adherence to the bounding boxes), but may reduce image quality.</span>
</div>
<div class="tooltip" >Number of Gradient Descent iterations per timestep &#9432
<span class="tooltiptext">The number of Gradient Descent iterations for each timestep when performing guidance.</span>
</div>
<div class="tooltip" >Loss Threshold &#9432
<span class="tooltiptext">If the loss is below the threshold, Gradient Descent stops for that timestep. </span>
</div>
<div class="tooltip">Classifier-free guidance scale &#9432
<span class="tooltiptext">The scale factor of classifier-free guidance.</span>
</div>
"""
FOOTNOTE = """
<p>The source code of this demo is based on the <a href="https://huggingface.co/spaces/gligen/demo/tree/main">GLIGEN demo</a>.</p>
"""
def inference(
boxes,
prompts,
subject_sub_prompts,
subject_token_indices,
filter_token_indices,
num_tokens,
init_step_size,
final_step_size,
first_refinement_step,
num_clusters_per_subject,
cross_loss_scale,
self_loss_scale,
classifier_free_guidance_scale,
num_iterations,
loss_threshold,
num_guidance_steps,
seed,
):
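    # Runs Bounded Attention guidance on SDXL and returns the generated PIL images.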
if not torch.cuda.is_available():
raise gr.Error("cuda is not available")
device = torch.device("cuda")
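    # Load the SDXL pipeline with a DDIM scheduler from the local cache and enable sequential CPU offloading.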
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
model = StableDiffusionXLPipeline.from_pretrained(LOCAL_MODEL_PATH, scheduler=scheduler, torch_dtype=torch.float16, device_map="auto")
model.to(device)
model.unet.set_attn_processor(AttnProcessor2_0())
model.enable_sequential_cpu_offload()
seed_everything(seed)
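    # Initial latent noise: SDXL-base produces 1024x1024 images from 4x128x128 latents.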
start_code = torch.randn([len(prompts), 4, 128, 128], device=device)
eos_token_index = None if num_tokens is None else num_tokens + 1
editor = BoundedAttention(
boxes,
prompts,
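        # Attention-layer indices used by the editor (the same range, 70-81, is passed for both cross- and self-attention).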
list(range(70, 82)),
list(range(70, 82)),
subject_sub_prompts=subject_sub_prompts,
subject_token_indices=subject_token_indices,
filter_token_indices=filter_token_indices,
eos_token_index=eos_token_index,
cross_loss_coef=cross_loss_scale,
self_loss_coef=self_loss_scale,
max_guidance_iter=num_guidance_steps,
max_guidance_iter_per_step=num_iterations,
start_step_size=init_step_size,
end_step_size=final_step_size,
loss_stopping_value=loss_threshold,
min_clustering_step=first_refinement_step,
num_clusters_per_box=num_clusters_per_subject,
max_resolution=32,
)
register_attention_editor_diffusers(model, editor)
return model(prompts, latents=start_code, guidance_scale=classifier_free_guidance_scale).images
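# On ZeroGPU Spaces, this decorator allocates a GPU for each call (here for up to 340 seconds).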
@spaces.GPU(duration=340)
def generate(
prompt,
subject_sub_prompts,
subject_token_indices,
filter_token_indices,
num_tokens,
init_step_size,
final_step_size,
first_refinement_step,
num_clusters_per_subject,
cross_loss_scale,
self_loss_scale,
classifier_free_guidance_scale,
batch_size,
num_iterations,
loss_threshold,
num_guidance_steps,
seed,
boxes,
):
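    # Validates and parses the UI inputs, then calls inference to generate the images.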
num_subjects = 0
subject_sub_prompts = convert_sub_prompts(subject_sub_prompts)
subject_token_indices = convert_token_indices(subject_token_indices, nested=True)
if subject_sub_prompts is not None:
num_subjects = len(subject_sub_prompts)
if subject_token_indices is not None:
num_subjects = len(subject_token_indices)
if len(boxes) != num_subjects:
raise gr.Error("""
The number of boxes should be equal to the number of subjects.
Number of boxes drawn: {}, number of subjects: {}.
""".format(len(boxes), num_subjects))
filter_token_indices = convert_token_indices(filter_token_indices) if len(filter_token_indices.strip()) > 0 else None
num_tokens = int(num_tokens) if len(num_tokens.strip()) > 0 else None
prompts = [prompt.strip(".").strip(",").strip()] * batch_size
images = inference(
boxes, prompts, subject_sub_prompts, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
final_step_size, first_refinement_step, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
classifier_free_guidance_scale, num_iterations, loss_threshold, num_guidance_steps, seed)
return images
def convert_sub_prompts(sub_prompts):
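    # Splits a semicolon-separated string of sub-prompts into a list; returns None if the string is empty.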
sub_prompts = sub_prompts.strip()
if len(sub_prompts) == 0:
return None
return [sub_prompt.strip() for sub_prompt in sub_prompts.split(";")]
def convert_token_indices(token_indices, nested=False):
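    # Parses comma-separated token indices; with nested=True, semicolons separate the groups of different subjects.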
token_indices = token_indices.strip()
if len(token_indices) == 0:
return None
if nested:
return [convert_token_indices(indices, nested=False) for indices in token_indices.split(";")]
return [int(index.strip()) for index in token_indices.split(",") if len(index.strip()) > 0]
def draw(sketchpad):
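    # Converts each sketchpad layer into a normalized bounding box and renders a preview image of the layout.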
boxes = []
for i, layer in enumerate(sketchpad["layers"]):
non_zeros = layer.nonzero()
x1 = x2 = y1 = y2 = 0
if len(non_zeros[0]) > 0:
x1x2 = non_zeros[1] / layer.shape[1]
y1y2 = non_zeros[0] / layer.shape[0]
x1 = x1x2.min()
x2 = x1x2.max()
y1 = y1y2.min()
y2 = y1y2.max()
if (x2 - x1 < MIN_SIZE) or (y2 - y1 < MIN_SIZE):
raise gr.Error(f"Box in layer {i} is too small")
boxes.append((x1, y1, x2, y2))
print(f"Drawn boxes: {boxes}")
layout_image = draw_boxes(boxes)
return [boxes, layout_image]
def draw_boxes(boxes, is_sketch=False):
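    # Renders the normalized boxes as colored rectangles on a RESOLUTION x RESOLUTION white canvas.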
if len(boxes) == 0:
return None
boxes = np.array(boxes) * RESOLUTION
image = Image.new("RGB", (RESOLUTION, RESOLUTION), (WHITE, WHITE, WHITE))
drawing = ImageDraw.Draw(image)
for i, box in enumerate(boxes.astype(int).tolist()):
color = "black" if is_sketch else COLORS[i % len(COLORS)]
drawing.rectangle(box, outline=color, width=4)
return image
def clear(batch_size):
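    # Resets the box state, sketchpad, layout preview, and output gallery (batch_size is unused).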
return [[], None, None, None]
def build_example_layout(prompt, *args):
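    # Loads the predefined boxes for an example prompt and builds the matching sketchpad state and layout preview.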
boxes = EXAMPLE_BOXES[prompt]
print(f"Loaded boxes: {boxes}")
composite = draw_boxes(boxes, is_sketch=True)
sketchpad = {"background": None, "layers": [], "composite": composite}
layout_image = draw_boxes(boxes)
return boxes, sketchpad, layout_image
def main():
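    # Downloads NLTK data, caches the SDXL weights locally, and builds and launches the Gradio demo.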
nltk.download("averaged_perceptron_tagger")
model = StableDiffusionXLPipeline.from_pretrained(REMOTE_MODEL_PATH)
model.save_pretrained(LOCAL_MODEL_PATH)
del model
with gr.Blocks(
css=CSS,
title="Bounded Attention demo",
) as demo:
gr.HTML(DESCRIPTION)
gr.HTML(COPY_LINK)
with gr.Column():
gr.HTML("Scroll down to see examples of the required input format.")
prompt = gr.Textbox(
label="Text prompt",
placeholder=PROMPT1,
)
subject_sub_prompts = gr.Textbox(
label="Sub-prompts for each subject (separate with semicolons)",
placeholder=SUBJECT_SUB_PROMPTS1,
)
with gr.Accordion("Precise inputs", open=False):
subject_token_indices = gr.Textbox(
label="Optional: The token indices of each subject (separate indices for the same subject with commas, and for different subjects with semicolons)",
placeholder=SUBJECT_TOKEN_INDICES1,
)
filter_token_indices = gr.Textbox(
label="Optional: The token indices to filter, i.e. conjunctions, numbers, postional relations, etc. (if left empty, this will be automatically inferred)",
placeholder=FILTER_TOKEN_INDICES1,
)
num_tokens = gr.Textbox(
label="Optional: The number of tokens in the prompt (We use this to verify your input, as sometimes rare words are split into more than one token)",
placeholder=NUM_TOKENS1,
)
with gr.Row():
sketchpad = gr.Sketchpad(label="Sketch Pad (draw each bounding box in a different layer)")
layout_image = gr.Image(type="pil", label="Bounding Boxes", interactive=False)
with gr.Row():
generate_layout_button = gr.Button(value="Generate layout")
generate_image_button = gr.Button(value="Generate image")
clear_button = gr.Button(value="Clear")
with gr.Row():
out_images = gr.Gallery(type="pil", label="Generated Images", interactive=False)
with gr.Accordion("Advanced Options", open=False):
with gr.Column():
gr.HTML(ADVANCED_OPTION_DESCRIPTION)
batch_size = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Number of samples (limited to one sample on current space)")
num_guidance_steps = gr.Slider(minimum=5, maximum=20, step=1, value=8, label="Number of timesteps to perform guidance")
init_step_size = gr.Slider(minimum=0, maximum=50, step=0.5, value=30, label="Initial step size")
final_step_size = gr.Slider(minimum=0, maximum=20, step=0.5, value=15, label="Final step size")
first_refinement_step = gr.Slider(minimum=0, maximum=50, step=1, value=15, label="The timestep from which to start refining the subject masks")
num_clusters_per_subject = gr.Slider(minimum=0, maximum=5, step=0.5, value=3, label="Number of clusters per subject")
cross_loss_scale = gr.Slider(minimum=0, maximum=2, step=0.1, value=1, label="Cross-attention loss scale factor")
self_loss_scale = gr.Slider(minimum=0, maximum=2, step=0.1, value=1, label="Self-attention loss scale factor")
num_iterations = gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Number of Gradient Descent iterations")
loss_threshold = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Loss threshold")
classifier_free_guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Classifier-free guidance Scale")
seed = gr.Slider(minimum=0, maximum=1000, step=1, value=445, label="Random Seed")
boxes = gr.State([])
clear_button.click(
clear,
inputs=[batch_size],
outputs=[boxes, sketchpad, layout_image, out_images],
queue=False,
)
generate_layout_button.click(
draw,
inputs=[sketchpad],
outputs=[boxes, layout_image],
queue=False,
)
generate_image_button.click(
fn=generate,
inputs=[
prompt, subject_sub_prompts, subject_token_indices, filter_token_indices, num_tokens,
init_step_size, final_step_size, first_refinement_step, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
classifier_free_guidance_scale, batch_size, num_iterations, loss_threshold, num_guidance_steps,
seed,
boxes,
],
outputs=[out_images],
queue=True,
)
with gr.Column():
gr.Examples(
examples=[
[
PROMPT1, SUBJECT_SUB_PROMPTS1, SUBJECT_TOKEN_INDICES1, FILTER_TOKEN_INDICES1, NUM_TOKENS1,
15, 10, 15, 3, 1, 1,
7.5, 1, 5, 0.2, 8,
12,
],
[
PROMPT2, "cute unicorn;pink hedgehog;nerdy owl", "7,8,17;11,12,17;15,16,17", "5,6,9,10,13,14,18,19", "21",
25, 18, 15, 3, 1, 1,
7.5, 1, 5, 0.2, 8,
286,
],
[
PROMPT3, "astronaut;robot;green alien;spaceship", "7;10;13,14;17", "5,6,8,9,11,12,15,16", "17",
18, 12, 15, 3, 1, 1,
7.5, 1, 5, 0.2, 8,
216,
],
[
PROMPT4, "semi trailer;concrete mixer;helicopter", "9,10;13,14;17", "1,4,5,7,8,11,12,15,16", "17",
25, 18, 15, 3, 1, 1,
7.5, 1, 5, 0.2, 8,
82,
],
[
PROMPT5, "golden retriever;german shepherd;boston terrier;english bulldog;border collie", "2,3;6,7;10,11;14,15;18,19", "1,4,5,8,9,12,13,16,17,20,21", "22",
18, 12, 15, 3, 1, 1,
7.5, 1, 5, 0.2, 8,
152,
],
],
fn=build_example_layout,
inputs=[
prompt, subject_sub_prompts, subject_token_indices, filter_token_indices, num_tokens,
init_step_size, final_step_size, first_refinement_step, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
classifier_free_guidance_scale, batch_size, num_iterations, loss_threshold, num_guidance_steps,
seed,
],
outputs=[boxes, sketchpad, layout_image],
run_on_click=True,
)
gr.HTML(FOOTNOTE)
demo.launch(show_api=False, show_error=True)
if __name__ == "__main__":
main()