Spaces:
Running
on
Zero
Running
on
Zero
Load model directly to GPU
Browse files- app.py +7 -15
- injection_utils.py +25 -3
app.py
CHANGED
@@ -11,8 +11,6 @@ from injection_utils import regiter_attention_editor_diffusers
|
|
11 |
from bounded_attention import BoundedAttention
|
12 |
from pytorch_lightning import seed_everything
|
13 |
|
14 |
-
from functools import partial
|
15 |
-
|
16 |
MODEL_PATH = "stabilityai/stable-diffusion-xl-base-1.0"
|
17 |
RESOLUTION = 256
|
18 |
MIN_SIZE = 0.01
|
@@ -113,7 +111,6 @@ FOOTNOTE = """
|
|
113 |
|
114 |
|
115 |
def inference(
|
116 |
-
model,
|
117 |
boxes,
|
118 |
prompts,
|
119 |
subject_token_indices,
|
@@ -134,7 +131,10 @@ def inference(
|
|
134 |
raise gr.Error("cuda is not available")
|
135 |
|
136 |
device = torch.device("cuda")
|
137 |
-
|
|
|
|
|
|
|
138 |
|
139 |
seed_everything(seed)
|
140 |
start_code = torch.randn([len(prompts), 4, 128, 128], device=device)
|
@@ -159,15 +159,11 @@ def inference(
|
|
159 |
)
|
160 |
|
161 |
register_attention_editor_diffusers(model, editor)
|
162 |
-
|
163 |
-
unregister_attention_editor_diffusers(model)
|
164 |
-
model.double().to(torch.device("cpu"))
|
165 |
-
return images
|
166 |
|
167 |
|
168 |
@spaces.GPU(duration=300)
|
169 |
def generate(
|
170 |
-
model,
|
171 |
prompt,
|
172 |
subject_token_indices,
|
173 |
filter_token_indices,
|
@@ -197,7 +193,7 @@ def generate(
|
|
197 |
prompts = [prompt.strip(".").strip(",").strip()] * batch_size
|
198 |
|
199 |
images = inference(
|
200 |
-
|
201 |
final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale, classifier_free_guidance_scale,
|
202 |
num_iterations, loss_threshold, num_guidance_steps, seed)
|
203 |
|
@@ -253,10 +249,6 @@ def clear(batch_size):
|
|
253 |
|
254 |
def main():
|
255 |
nltk.download("averaged_perceptron_tagger")
|
256 |
-
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
|
257 |
-
model = StableDiffusionXLPipeline.from_pretrained(MODEL_PATH, scheduler=scheduler)
|
258 |
-
model.unet.set_default_attn_processor()
|
259 |
-
model.enable_sequential_cpu_offload()
|
260 |
|
261 |
with gr.Blocks(
|
262 |
css=CSS,
|
@@ -328,7 +320,7 @@ def main():
|
|
328 |
)
|
329 |
|
330 |
generate_image_button.click(
|
331 |
-
fn=
|
332 |
inputs=[
|
333 |
prompt, subject_token_indices, filter_token_indices, num_tokens,
|
334 |
init_step_size, final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
|
|
|
11 |
from bounded_attention import BoundedAttention
|
12 |
from pytorch_lightning import seed_everything
|
13 |
|
|
|
|
|
14 |
MODEL_PATH = "stabilityai/stable-diffusion-xl-base-1.0"
|
15 |
RESOLUTION = 256
|
16 |
MIN_SIZE = 0.01
|
|
|
111 |
|
112 |
|
113 |
def inference(
|
|
|
114 |
boxes,
|
115 |
prompts,
|
116 |
subject_token_indices,
|
|
|
131 |
raise gr.Error("cuda is not available")
|
132 |
|
133 |
device = torch.device("cuda")
|
134 |
+
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
|
135 |
+
model = StableDiffusionXLPipeline.from_pretrained(MODEL_PATH, scheduler=scheduler, torch_dtype=torch.float16).to(device)
|
136 |
+
model.unet.set_default_attn_processor()
|
137 |
+
model.enable_sequential_cpu_offload()
|
138 |
|
139 |
seed_everything(seed)
|
140 |
start_code = torch.randn([len(prompts), 4, 128, 128], device=device)
|
|
|
159 |
)
|
160 |
|
161 |
register_attention_editor_diffusers(model, editor)
|
162 |
+
return model(prompts, latents=start_code, guidance_scale=classifier_free_guidance_scale).images
|
|
|
|
|
|
|
163 |
|
164 |
|
165 |
@spaces.GPU(duration=300)
|
166 |
def generate(
|
|
|
167 |
prompt,
|
168 |
subject_token_indices,
|
169 |
filter_token_indices,
|
|
|
193 |
prompts = [prompt.strip(".").strip(",").strip()] * batch_size
|
194 |
|
195 |
images = inference(
|
196 |
+
boxes, prompts, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
|
197 |
final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale, classifier_free_guidance_scale,
|
198 |
num_iterations, loss_threshold, num_guidance_steps, seed)
|
199 |
|
|
|
249 |
|
250 |
def main():
|
251 |
nltk.download("averaged_perceptron_tagger")
|
|
|
|
|
|
|
|
|
252 |
|
253 |
with gr.Blocks(
|
254 |
css=CSS,
|
|
|
320 |
)
|
321 |
|
322 |
generate_image_button.click(
|
323 |
+
fn=generate,
|
324 |
inputs=[
|
325 |
prompt, subject_token_indices, filter_token_indices, num_tokens,
|
326 |
init_step_size, final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
|
injection_utils.py
CHANGED
@@ -53,7 +53,7 @@ class AttentionBase:
|
|
53 |
self.cur_att_layer = 0
|
54 |
|
55 |
|
56 |
-
def regiter_attention_editor_diffusers(model, editor: AttentionBase):
|
57 |
"""
|
58 |
Register an attention editor to Diffuser Pipeline, refer from [Prompt-to-Prompt]
|
59 |
"""
|
@@ -89,13 +89,14 @@ def regiter_attention_editor_diffusers(model, editor: AttentionBase):
|
|
89 |
|
90 |
return forward
|
91 |
|
92 |
-
def register_editor(net, count, place_in_unet
|
93 |
for name, subnet in net.named_children():
|
94 |
if net.__class__.__name__ == 'Attention': # spatial Transformer layer
|
|
|
95 |
net.forward = ca_forward(net, place_in_unet)
|
96 |
return count + 1
|
97 |
elif hasattr(net, 'children'):
|
98 |
-
count = register_editor(subnet, count, place_in_unet
|
99 |
return count
|
100 |
|
101 |
cross_att_count = 0
|
@@ -110,3 +111,24 @@ def regiter_attention_editor_diffusers(model, editor: AttentionBase):
|
|
110 |
editor.num_att_layers = cross_att_count
|
111 |
editor.model = model
|
112 |
model.editor = editor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
self.cur_att_layer = 0
|
54 |
|
55 |
|
56 |
+
def register_attention_editor_diffusers(model, editor: AttentionBase):
|
57 |
"""
|
58 |
Register an attention editor to Diffuser Pipeline, refer from [Prompt-to-Prompt]
|
59 |
"""
|
|
|
89 |
|
90 |
return forward
|
91 |
|
92 |
+
def register_editor(net, count, place_in_unet):
|
93 |
for name, subnet in net.named_children():
|
94 |
if net.__class__.__name__ == 'Attention': # spatial Transformer layer
|
95 |
+
net.original_forward = net.forward
|
96 |
net.forward = ca_forward(net, place_in_unet)
|
97 |
return count + 1
|
98 |
elif hasattr(net, 'children'):
|
99 |
+
count = register_editor(subnet, count, place_in_unet)
|
100 |
return count
|
101 |
|
102 |
cross_att_count = 0
|
|
|
111 |
editor.num_att_layers = cross_att_count
|
112 |
editor.model = model
|
113 |
model.editor = editor
|
114 |
+
|
115 |
+
|
116 |
+
def unregister_attention_editor_diffusers(model):
    """
    Undo the attention-editor registration on a Diffuser Pipeline.

    Walks the "down", "mid" and "up" children of ``model.unet`` and, for every
    module whose class is named ``Attention``, puts back the ``forward`` that
    was saved in ``original_forward`` when the editor was registered. The
    editor <-> model back-references are then cleared.

    Calling this on a model that was never registered (or calling it twice)
    is a no-op for the attention modules.
    """
    def unregister_editor(net):
        if net.__class__.__name__ == 'Attention':  # spatial Transformer layer
            # Restore exactly once per module. The check must not live inside
            # the loop over children: a second pass would copy the already
            # cleared ``original_forward`` (None) into ``forward``.
            saved_forward = getattr(net, 'original_forward', None)
            if saved_forward is not None:
                net.forward = saved_forward
                net.original_forward = None
            return
        for _, subnet in net.named_children():
            unregister_editor(subnet)

    for net_name, net in model.unet.named_children():
        # Same three UNet sections that are visited on unregistration in the
        # original code; other top-level children are left untouched.
        if "down" in net_name or "mid" in net_name or "up" in net_name:
            unregister_editor(net)

    # The editor is only reachable through the model here; there is no
    # ``editor`` argument, so look it up instead of using an unbound name.
    editor = getattr(model, 'editor', None)
    if editor is not None:
        editor.model = None
    model.editor = None
|