# This demo needs to be run from the repo folder.
# python demo/fake_gan/run.py
import itertools

import gradio as gr
from PIL import ImageDraw

import DirectedDiffusion

# Each EXAMPLES entry:
#   prompt
#   bounding box(es), fractional "left,right,top,bottom" per region
#   prompt token indices for each region (1-indexed)
#   number of trailing attention maps
#   number of DD steps
#   Gaussian coefficient
#   seed
EXAMPLES = [
    [
        "A painting of a tiger, on the wall in the living room",
        "0.2,0.6,0.0,0.5",
        "1,5",
        5,
        15,
        1.0,
        2094889,
    ],
    [
        "a dog diving into a pool in sunny day",
        "0.0,0.5,0.0,0.5",
        "1,2",
        10,
        20,
        5.0,
        2483964026826,
    ],
    [
        "A red cube above a blue sphere",
        "0.4,0.7,0.0,0.5 0.4,0.7,0.5,1.0",
        "2,3 6,7",
        10,
        20,
        1.0,
        1213698,
    ],
    [
        "The sun shining on a house",
        "0.0,0.5,0.0,0.5",
        "1,2",
        10,
        20,
        1.0,
        2483964026826,
    ],
    [
        "a diver swimming through a school of fish",
        "0.5,1.0,0.0,0.5",
        "1,2",
        10,
        10,
        5.0,
        2483964026826,
    ],
    [
        "A stone castle surrounded by lakes and trees",
        "0.3,0.7,0.0,1.0",
        "1,2,3",
        10,
        5,
        1.0,
        2483964026826,
    ],
    [
        "A dog hiding behind the chair",
        "0.5,0.9,0.0,1.0",
        "1,2",
        10,
        5,
        2.5,
        248396402123,
    ],
    [
        "A dog sitting next to a mirror",
        "0.0,0.5,0.0,1.0 0.5,1.0,0.0,1.0",
        "1,2 6,7",
        20,
        5,
        1.0,
        24839640268232521,
    ],
]

# Load the Stable Diffusion model bundle once at startup.
model_bundle = DirectedDiffusion.AttnEditorUtils.load_all_models(
    model_path_diffusion="CompVis/stable-diffusion-v1-4"
)
# To use a local checkpoint instead:
# model_bundle = DirectedDiffusion.AttnEditorUtils.load_all_models(
#     model_path_diffusion="../DirectedDiffusion/assets/models/stable-diffusion-v1-4"
# )


def directed_diffusion(
    in_prompt,
    in_bb,
    in_token_ids,
    in_slider_trailings,
    in_slider_ddsteps,
    in_slider_gcoef,
    in_seed,
    is_draw_bbox,
):
    # Parse space-separated groups of comma-separated values, e.g.
    # "0.4,0.7,0.0,0.5 0.4,0.7,0.5,1.0" -> [[0.4, 0.7, 0.0, 0.5], [0.4, 0.7, 0.5, 1.0]].
    def str_arg_to_val(arg, f):
        return [[f(b) for b in a.split(",")] for a in arg.split(" ")]

    roi = str_arg_to_val(in_bb, float)
    attn_editor_bundle = {
        "edit_index": str_arg_to_val(in_token_ids, int),
        "roi": roi,
        "num_trailing_attn": [in_slider_trailings] * len(roi),
        "num_affected_steps": in_slider_ddsteps,
        "noise_scale": [in_slider_gcoef] * len(roi),
    }
    img = DirectedDiffusion.Diffusion.stablediffusion(
        model_bundle,
        attn_editor_bundle=attn_editor_bundle,
        guidance_scale=7.5,
        prompt=in_prompt,
        steps=50,
        seed=in_seed,
        is_save_attn=False,
        is_save_recons=False,
    )
    if is_draw_bbox and in_slider_ddsteps > 0:
        for r in roi:
            # Regions are fractional (left, right, top, bottom); PIL's
            # rectangle() expects pixel (left, top, right, bottom).
            left, right, top, bottom = [int(r_ * 512) for r_ in r]
            image_editable = ImageDraw.Draw(img)
            image_editable.rectangle(
                xy=[left, top, right, bottom], outline=(255, 0, 0, 255), width=5
            )
    return img
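# A minimal headless sketch of the same call path (no Gradio), assuming only
# the DirectedDiffusion API already used above; the prompt, region, token
# indices, and output filename are illustrative:
#
#   bundle = {
#       "edit_index": [[1, 2]],          # 1-indexed prompt tokens to steer
#       "roi": [[0.0, 0.5, 0.0, 0.5]],   # fractional left,right,top,bottom
#       "num_trailing_attn": [10],
#       "num_affected_steps": 10,
#       "noise_scale": [1.0],
#   }
#   img = DirectedDiffusion.Diffusion.stablediffusion(
#       model_bundle, attn_editor_bundle=bundle, guidance_scale=7.5,
#       prompt="a dog diving into a pool in sunny day", steps=50,
#       seed=2483964026826, is_save_attn=False, is_save_recons=False,
#   )
#   img.save("out.png")  # hypothetical output path
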
def run_it(
    in_prompt,
    in_bb,
    in_token_ids,
    in_slider_trailings,
    in_slider_ddsteps,
    in_slider_gcoef,
    in_seed,
    is_draw_bbox,
    is_grid_search,
    progress=gr.Progress(),
):
    num_affected_steps = [in_slider_ddsteps]
    noise_scale = [in_slider_gcoef]
    num_trailing_attn = [in_slider_trailings]
    if is_grid_search:
        # Ignore the sliders and sweep a small grid of hyperparameters instead.
        num_affected_steps = [5, 10]
        noise_scale = [1.0, 1.5, 2.5]
        num_trailing_attn = [10, 20, 30, 40]
    param_list = list(
        itertools.product(num_affected_steps, noise_scale, num_trailing_attn)
    )
    results = []
    progress(0, desc="Starting...")
    for ddsteps, gcoef, trailings in progress.tqdm(param_list):
        print("=========== Arguments ============")
        print("Prompt:", in_prompt)
        print("Bounding box:", in_bb)
        print("Token indices:", in_token_ids)
        print("Num trailings:", trailings)
        print("Num DD steps:", ddsteps)
        print("Gaussian coef:", gcoef)
        print("Seed:", in_seed)
        print("==================================")
        img = directed_diffusion(
            in_prompt=in_prompt,
            in_bb=in_bb,
            in_token_ids=in_token_ids,
            in_slider_trailings=trailings,
            in_slider_ddsteps=ddsteps,
            in_slider_gcoef=gcoef,
            in_seed=in_seed,
            is_draw_bbox=is_draw_bbox,
        )
        results.append(
            (
                img,
                "#Trailing:{},#DDSteps:{},GaussianCoef:{}".format(
                    trailings, ddsteps, gcoef
                ),
            )
        )
    return results


with gr.Blocks() as demo:
    gr.Markdown(
        """
        ### Directed Diffusion: Direct Control of Object Placement through Attention Guidance

        **\*Wan-Duo Kurt Ma, \^J. P. Lewis, \^\*W. Bastiaan Kleijn, \^Thomas Leung**

        *\*Victoria University of Wellington, \^Google Research*

        Pin the objects in your prompt wherever you wish! For more information, please check
        out our project page ([link](https://hohonu-vicml.github.io/DirectedDiffusion.Page/)),
        repository ([link](https://github.com/hohonu-vicml/DirectedDiffusion)),
        and paper ([link](https://arxiv.org/abs/2302.13153)).
        """
    )
    with gr.Row(variant="panel"):
        with gr.Column(variant="compact"):
            in_prompt = gr.Textbox(
                label="Enter your prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
            ).style(
                container=False,
            )
            with gr.Row(variant="compact"):
                in_bb = gr.Textbox(
                    label="Bounding box",
                    show_label=True,
                    max_lines=1,
                    placeholder="e.g., 0.1,0.5,0.3,0.6",
                )
                in_token_ids = gr.Textbox(
                    label="Token indices",
                    show_label=True,
                    max_lines=1,
                    placeholder="e.g., 1,2,3",
                )
                in_seed = gr.Number(
                    value=2483964026821236, label="Random seed", interactive=True
                )
            with gr.Row(variant="compact"):
                is_grid_search = gr.Checkbox(
                    value=False,
                    label="Grid search? (If checked, the sliders are ignored)",
                )
                is_draw_bbox = gr.Checkbox(
                    value=True,
                    label="Draw the bounding box?",
                )
            with gr.Row(variant="compact"):
                in_slider_trailings = gr.Slider(
                    minimum=0, maximum=30, value=10, step=1, label="#trailings"
                )
                in_slider_ddsteps = gr.Slider(
                    minimum=0, maximum=30, value=10, step=1, label="#DDSteps"
                )
                in_slider_gcoef = gr.Slider(
                    minimum=0, maximum=10, value=1.0, step=0.1, label="GaussianCoef"
                )
            with gr.Row(variant="compact"):
                btn_run = gr.Button("Generate image").style(full_width=True)
                # btn_clean = gr.Button("Clean Gallery").style(full_width=True)
            gr.Markdown(
                """
                Note:

                1) Click one of the examples below for a quick setup.

                2) If #DDSteps == 0, the SD process runs without DD.

                3) A bounding box is a tuple of four scalars giving the fractional
                   boundaries of an image region: left,right,top,bottom.

                4) The token indices are the 1-indexed word positions in the prompt
                   associated with the edited region.
                """
            )
        with gr.Column(variant="compact"):
            gallery = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery"
            ).style(grid=[2], height="auto")
    args = [
        in_prompt,
        in_bb,
        in_token_ids,
        in_slider_trailings,
        in_slider_ddsteps,
        in_slider_gcoef,
        in_seed,
        is_draw_bbox,
        is_grid_search,
    ]
    btn_run.click(run_it, inputs=args, outputs=gallery)
    # btn_clean.click(clean_gallery, outputs=gallery)
    examples = gr.Examples(
        examples=EXAMPLES,
        inputs=args,
    )

if __name__ == "__main__":
    demo.queue().launch()
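# Grid-search note (illustrative, derived from run_it above): with the
# "Grid search?" box checked, a single click sweeps
# itertools.product([5, 10], [1.0, 1.5, 2.5], [10, 20, 30, 40]),
# i.e. 2 * 3 * 4 = 24 images, each labeled with its hyperparameters.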