from diffusers import DDIMScheduler
from pipeline_stable_diffusion_xl_opt import StableDiffusionXLPipeline
from injection_utils import regiter_attention_editor_diffusers
from bounded_attention import BoundedAttention
from pytorch_lightning import seed_everything
import spaces
import gradio as gr
import torch
import numpy as np
from PIL import Image, ImageDraw
from functools import partial

RESOLUTION = 256
MIN_SIZE = 0.01
WHITE = 255
COLORS = ["red", "blue", "green", "orange", "purple", "turquoise", "olive"]


def inference(
    device,
    model,
    boxes,
    prompts,
    subject_token_indices,
    filter_token_indices,
    num_tokens,
    init_step_size,
    final_step_size,
    num_clusters_per_subject,
    cross_loss_scale,
    self_loss_scale,
    classifier_free_guidance_scale,
    num_iterations,
    loss_threshold,
    num_guidance_steps,
    seed,
):
    seed_everything(seed)

    # Initial latent noise: SDXL at 1024x1024 uses a 4 x 128 x 128 latent per image.
    start_code = torch.randn([len(prompts), 4, 128, 128], device=device)

    # The EOS token comes right after the prompt tokens; keep it None when no token count was given.
    eos_token_index = num_tokens + 1 if num_tokens is not None else None

    editor = BoundedAttention(
        boxes,
        prompts,
        subject_token_indices,
        list(range(70, 82)),  # attention layer indices on which bounded attention is applied
        list(range(70, 82)),
        eos_token_index=eos_token_index,
        cross_loss_coef=cross_loss_scale,
        self_loss_coef=self_loss_scale,
        filter_token_indices=filter_token_indices,
        max_guidance_iter=num_guidance_steps,
        max_guidance_iter_per_step=num_iterations,
        start_step_size=init_step_size,
        end_step_size=final_step_size,
        loss_stopping_value=loss_threshold,
        num_clusters_per_box=num_clusters_per_subject,
        debug=False,
    )

    # Hook the editor into the pipeline's attention layers
    # (the function name is spelled this way in injection_utils).
    regiter_attention_editor_diffusers(model, editor)

    return model(prompts, latents=start_code, guidance_scale=classifier_free_guidance_scale).images


@spaces.GPU
def generate(
    device,
    model,
    prompt,
    subject_token_indices,
    filter_token_indices,
    num_tokens,
    init_step_size,
    final_step_size,
    num_clusters_per_subject,
    cross_loss_scale,
    self_loss_scale,
    classifier_free_guidance_scale,
    batch_size,
    num_iterations,
    loss_threshold,
    num_guidance_steps,
    seed,
    boxes,
):
    subject_token_indices = convert_token_indices(subject_token_indices, nested=True)
    if len(boxes) != len(subject_token_indices):
        raise gr.Error("""
            The number of boxes should be equal to the number of subject token indices.
            Number of boxes drawn: {}, number of grounding tokens: {}.
        """.format(len(boxes), len(subject_token_indices)))

    filter_token_indices = convert_token_indices(filter_token_indices) if len(filter_token_indices.strip()) > 0 else None
    num_tokens = int(num_tokens) if len(num_tokens.strip()) > 0 else None
    prompts = [prompt.strip('.').strip(',').strip()] * batch_size

    images = inference(
        device, model, boxes, prompts, subject_token_indices, filter_token_indices, num_tokens,
        init_step_size, final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
        classifier_free_guidance_scale, num_iterations, loss_threshold, num_guidance_steps, seed)

    return images


def convert_token_indices(token_indices, nested=False):
    # "1,2;5,6" -> [[1, 2], [5, 6]] when nested, "1,2" -> [1, 2] otherwise.
    if nested:
        return [convert_token_indices(indices, nested=False) for indices in token_indices.split(';')]

    return [int(index.strip()) for index in token_indices.split(',') if len(index.strip()) > 0]


def draw(sketchpad):
    # Convert each sketchpad layer into a bounding box normalized to [0, 1].
    boxes = []
    for i, layer in enumerate(sketchpad['layers']):
        mask = (layer != 0)
        if mask.sum() == 0:
            raise gr.Error(f'Box in layer {i} is too small')

        x1x2 = np.where(mask.max(0) != 0)[0] / RESOLUTION
        y1y2 = np.where(mask.max(1) != 0)[0] / RESOLUTION
        y1, y2 = y1y2.min(), y1y2.max()
        x1, x2 = x1x2.min(), x1x2.max()

        if (x2 - x1 < MIN_SIZE) or (y2 - y1 < MIN_SIZE):
            raise gr.Error(f'Box in layer {i} is too small')

        boxes.append((x1, y1, x2, y2))

    layout_image = draw_boxes(boxes)
    return [boxes, layout_image]


def draw_boxes(boxes):
    if len(boxes) == 0:
        return None

    boxes = np.array(boxes) * RESOLUTION
    image = Image.new('RGB', (RESOLUTION, RESOLUTION), (WHITE, WHITE, WHITE))
    drawing = ImageDraw.Draw(image)
    for i, box in enumerate(boxes.astype(int).tolist()):
        drawing.rectangle(box, outline=COLORS[i % len(COLORS)], width=4)

    return image


def clear(batch_size):
    return [[], None, None, None]


def main():
    css = """
        #paper-info a {
            color: #008AD7;
            text-decoration: none;
        }
        #paper-info a:hover {
            cursor: pointer;
            text-decoration: none;
        }
        .tooltip {
            color: #555;
            position: relative;
            display: inline-block;
            cursor: pointer;
        }
        .tooltip .tooltiptext {
            visibility: hidden;
            width: 400px;
            background-color: #555;
            color: #fff;
            text-align: center;
            padding: 5px;
            border-radius: 5px;
            position: absolute;
            z-index: 1;  /* Set z-index to 1 */
            left: 10px;
            top: 100%;
            opacity: 0;
            transition: opacity 0.3s;
        }
        .tooltip:hover .tooltiptext {
            visibility: visible;
            opacity: 1;
            z-index: 9999;  /* Set a high z-index value when hovering */
        }
    """

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model_path = "stabilityai/stable-diffusion-xl-base-1.0"
    scheduler = DDIMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
        clip_sample=False, set_alpha_to_one=False)
    model = StableDiffusionXLPipeline.from_pretrained(model_path, scheduler=scheduler, torch_dtype=torch.float16).to(device)
    model.unet.set_default_attn_processor()
    model.enable_xformers_memory_efficient_attention()
    model.enable_sequential_cpu_offload()

    with gr.Blocks(
        css=css,
        title="Bounded Attention demo",
    ) as demo:
        description = """
            <div id="paper-info">
                <p>Bounded Attention</p>
                <p>[Project Page] [Paper] [GitHub]</p>
                <p>The source code of this demo is based on the GLIGEN demo.</p>
            </div>
""" gr.HTML(description) demo.queue(max_size=50) demo.launch(show_api=False, show_error=True) if __name__ == '__main__': main()