# Hugging Face Space demo app (Gradio + diffusers).
# (Removed non-Python residue from the HF web page scrape: commit hashes,
# line-number gutter, and "Build error" banners that broke the file.)
import time
import spaces
import gradio as gr
import torch
import diffusers
from utils import patch_attention_proc, remove_patch
import math
import numpy as np
from PIL import Image
from threading import Semaphore
# Globals
# CSS applied to the Gradio Blocks UI: centers the page title.
css = """
h1 {
text-align: center;
display: block;
}
"""
# Pipeline
# Loaded once at import time onto the GPU in half precision.
# NOTE(review): assumes a CUDA device is available — this raises at import
# otherwise; on HF Spaces the @spaces.GPU decorator provides the GPU context.
pipe = diffusers.StableDiffusionPipeline.from_pretrained("Lykon/DreamShaper").to("cuda", torch.float16)
# Swap in an Euler scheduler configured from the original scheduler's config.
pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config)
# Safety checker disabled (demo app); pipeline returns images unfiltered.
pipe.safety_checker = None
semaphore = Semaphore() # for preventing collisions of two simultaneous button presses
@spaces.GPU
def generate_baseline(prompt, seed, steps, height_width, negative_prompt, guidance_scale, method):
    """Run the unpatched (baseline) pipeline once and time it.

    Args:
        prompt: text prompt for the diffusion pipeline.
        seed: int seed for torch's global RNG (reproducibility).
        steps: number of inference steps.
        height_width: square image size (height == width).
        negative_prompt: negative text prompt.
        guidance_scale: classifier-free guidance scale.
        method: unused for the baseline; kept so both buttons share one
            input signature ([prompt, ..., method] in the Gradio wiring).

    Returns:
        (PIL.Image, str): the generated image and a runtime summary string.
    """
    # try/finally guarantees the semaphore is released even if the pipeline
    # raises (e.g. CUDA OOM); otherwise both buttons would deadlock forever.
    semaphore.acquire()
    try:
        # NOTE: the original code built a token_merge_args dict here, but it
        # was never used — the baseline explicitly removes any patch below.
        torch.manual_seed(seed)
        start_time_base = time.time()
        # Ensure no token-merging patch from a previous "Generate" run is active.
        remove_patch(pipe)
        base_img = pipe(prompt,
                        num_inference_steps=steps, height=height_width, width=height_width,
                        negative_prompt=negative_prompt,
                        guidance_scale=guidance_scale).images[0]
        end_time_base = time.time()
        result = f"Baseline Runtime: {end_time_base-start_time_base:.2f} sec"
    finally:
        semaphore.release()
    return base_img, result
@spaces.GPU
def generate_merged(prompt, seed, steps, height_width, negative_prompt, guidance_scale, method):
    """Patch the UNet with token merging (ToDo or ToMe), generate, and time it.

    Args:
        prompt: text prompt for the diffusion pipeline.
        seed: int seed for torch's global RNG (reproducibility).
        steps: number of inference steps.
        height_width: square image size; selects the merge aggressiveness.
        negative_prompt: negative text prompt.
        guidance_scale: classifier-free guidance scale.
        method: "todo" (downsample keys/values) or anything else for
            ToMe-style similarity merging of all tokens.

    Returns:
        (PIL.Image, str): the generated image and a runtime summary string.
    """
    semaphore.acquire()
    try:
        merge_method = "downsample" if method == "todo" else "similarity"
        merge_tokens = "keys/values" if method == "todo" else "all"
        # Defaults first: the original left downsample_factor_level_2 /
        # ratio_level_2 undefined (NameError) for any unexpected height_width.
        downsample_factor = 2
        ratio = 0.38
        downsample_factor_level_2 = 1
        ratio_level_2 = 0.0
        # Larger images tolerate (and need) more aggressive merging.
        if height_width == 1024:
            downsample_factor = 2
            ratio = 0.75
        elif height_width == 1536:
            downsample_factor = 3
            ratio = 0.89
        elif height_width == 2048:
            downsample_factor = 4
            ratio = 0.9375
        token_merge_args = {"ratio": ratio,
                            "merge_tokens": merge_tokens,
                            "merge_method": merge_method,
                            "downsample_method": "nearest",
                            "downsample_factor": downsample_factor,
                            "timestep_threshold_switch": 0.0,
                            "timestep_threshold_stop": 0.0,
                            "downsample_factor_level_2": downsample_factor_level_2,
                            "ratio_level_2": ratio_level_2
                            }
        # Install the merging attention processors on the UNet for this run.
        patch_attention_proc(pipe.unet, token_merge_args=token_merge_args)
        torch.manual_seed(seed)
        start_time_merge = time.time()
        merged_img = pipe(prompt,
                          num_inference_steps=steps, height=height_width, width=height_width,
                          negative_prompt=negative_prompt,
                          guidance_scale=guidance_scale).images[0]
        end_time_merge = time.time()
        result = f"{'ToDo' if method == 'todo' else 'ToMe'} Runtime: {end_time_merge-start_time_merge:.2f} sec"
    finally:
        # Always release, even on failure, so the other button stays usable.
        semaphore.release()
    return merged_img, result
# Gradio UI: two side-by-side columns — baseline generation vs. token-merged
# generation — sharing one set of prompt/seed/size controls.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# ToDo: Token Downsampling for Efficient Generation of High-Resolution Images")
    prompt = gr.Textbox(interactive=True, label="prompt")
    negative_prompt = gr.Textbox(interactive=True, label="negative_prompt")
    with gr.Row():
        method = gr.Dropdown(["todo", "tome"], value="todo", label="method", info="Choose Your Desired Method (Default: todo)")
        height_width = gr.Dropdown([1024, 1536, 2048], value=1024, label="height/width", info="Choose Your Desired Height/Width (Default: 1024)")
    with gr.Row():
        guidance_scale = gr.Number(label="guidance_scale", value=7.5, precision=1)
        steps = gr.Number(label="steps", value=20, precision=0)
        seed = gr.Number(label="seed", value=1, precision=0)
    with gr.Row():
        with gr.Column():
            base_result = gr.Textbox(label="Baseline Runtime")
            # was f"baseline_image": f-string with no placeholders
            base_image = gr.Image(label="baseline_image", type="pil", interactive=False)
            # Distinct button variables (both were previously named `gen`).
            gen_baseline = gr.Button("Generate Baseline")
            gen_baseline.click(generate_baseline, inputs=[prompt, seed, steps, height_width, negative_prompt,
                                                          guidance_scale, method], outputs=[base_image, base_result])
        with gr.Column():
            output_result = gr.Textbox(label="Runtime")
            output_image = gr.Image(label="image", type="pil", interactive=False)
            gen_merged = gr.Button("Generate")
            gen_merged.click(generate_merged, inputs=[prompt, seed, steps, height_width, negative_prompt,
                                                      guidance_scale, method], outputs=[output_image, output_result])
# (Removed the stray trailing "|" that made the original line a syntax error.)
demo.launch(share=True)