omer11a committed
Commit 78b6f81
1 Parent(s): 1e9f321

Load model directly to GPU

Files changed (2):
  1. app.py +7 -15
  2. injection_utils.py +25 -3
app.py CHANGED
@@ -11,8 +11,6 @@ from injection_utils import regiter_attention_editor_diffusers
 from bounded_attention import BoundedAttention
 from pytorch_lightning import seed_everything
 
-from functools import partial
-
 MODEL_PATH = "stabilityai/stable-diffusion-xl-base-1.0"
 RESOLUTION = 256
 MIN_SIZE = 0.01
@@ -113,7 +111,6 @@ FOOTNOTE = """
 
 
 def inference(
-    model,
     boxes,
     prompts,
     subject_token_indices,
@@ -134,7 +131,10 @@ def inference(
         raise gr.Error("cuda is not available")
 
     device = torch.device("cuda")
-    model.to(device).half()
+    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+    model = StableDiffusionXLPipeline.from_pretrained(MODEL_PATH, scheduler=scheduler, torch_dtype=torch.float16).to(device)
+    model.unet.set_default_attn_processor()
+    model.enable_sequential_cpu_offload()
 
     seed_everything(seed)
     start_code = torch.randn([len(prompts), 4, 128, 128], device=device)
@@ -159,15 +159,11 @@
     )
 
     register_attention_editor_diffusers(model, editor)
-    images = model(prompts, latents=start_code, guidance_scale=classifier_free_guidance_scale).images
-    unregister_attention_editor_diffusers(model)
-    model.double().to(torch.device("cpu"))
-    return images
+    return model(prompts, latents=start_code, guidance_scale=classifier_free_guidance_scale).images
 
 
 @spaces.GPU(duration=300)
 def generate(
-    model,
     prompt,
     subject_token_indices,
     filter_token_indices,
@@ -197,7 +193,7 @@
     prompts = [prompt.strip(".").strip(",").strip()] * batch_size
 
     images = inference(
-        model, boxes, prompts, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
+        boxes, prompts, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
         final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale, classifier_free_guidance_scale,
         num_iterations, loss_threshold, num_guidance_steps, seed)
 
@@ -253,10 +249,6 @@ def clear(batch_size):
 
 def main():
     nltk.download("averaged_perceptron_tagger")
-    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
-    model = StableDiffusionXLPipeline.from_pretrained(MODEL_PATH, scheduler=scheduler)
-    model.unet.set_default_attn_processor()
-    model.enable_sequential_cpu_offload()
 
     with gr.Blocks(
         css=CSS,
@@ -328,7 +320,7 @@ def main():
     )
 
     generate_image_button.click(
-        fn=partial(generate, model),
+        fn=generate,
         inputs=[
             prompt, subject_token_indices, filter_token_indices, num_tokens,
             init_step_size, final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
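
Taken together, the app.py changes move pipeline construction out of main() and into the CUDA-only inference() path: the SDXL weights are loaded in half precision straight onto the GPU instead of being built on CPU at startup and cast per request, presumably because on a ZeroGPU Space the GPU only exists inside the @spaces.GPU-decorated call. A minimal sketch of that loading pattern, assuming a CUDA machine with torch and diffusers installed (the load_model helper is illustrative, not part of the app):

import torch
from diffusers import DDIMScheduler, StableDiffusionXLPipeline

MODEL_PATH = "stabilityai/stable-diffusion-xl-base-1.0"

def load_model(device: torch.device) -> StableDiffusionXLPipeline:
    # Same DDIM schedule as in the diff above.
    scheduler = DDIMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
    )
    # torch_dtype=torch.float16 loads the weights in half precision up front,
    # so the old per-request model.to(device).half() cast is no longer needed.
    model = StableDiffusionXLPipeline.from_pretrained(
        MODEL_PATH, scheduler=scheduler, torch_dtype=torch.float16
    )
    return model.to(device)

# Hypothetical usage inside a GPU-decorated handler:
# model = load_model(torch.device("cuda"))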
injection_utils.py CHANGED
@@ -53,7 +53,7 @@ class AttentionBase:
         self.cur_att_layer = 0
 
 
-def regiter_attention_editor_diffusers(model, editor: AttentionBase):
+def register_attention_editor_diffusers(model, editor: AttentionBase):
     """
     Register an attention editor to a Diffusers pipeline (adapted from Prompt-to-Prompt)
     """
@@ -89,13 +89,14 @@ def regiter_attention_editor_diffusers(model, editor: AttentionBase):
 
         return forward
 
-    def register_editor(net, count, place_in_unet, prefix=''):
+    def register_editor(net, count, place_in_unet):
         for name, subnet in net.named_children():
             if net.__class__.__name__ == 'Attention':  # spatial Transformer layer
+                net.original_forward = net.forward
                 net.forward = ca_forward(net, place_in_unet)
                 return count + 1
             elif hasattr(net, 'children'):
-                count = register_editor(subnet, count, place_in_unet, prefix=prefix + '\t')
+                count = register_editor(subnet, count, place_in_unet)
         return count
 
     cross_att_count = 0
@@ -110,3 +111,24 @@ def regiter_attention_editor_diffusers(model, editor: AttentionBase):
     editor.num_att_layers = cross_att_count
     editor.model = model
     model.editor = editor
+
+
+def unregister_attention_editor_diffusers(model):
+    def unregister_editor(net):
+        for name, subnet in net.named_children():
+            if net.__class__.__name__ == 'Attention':  # spatial Transformer layer
+                net.forward = net.original_forward
+                net.original_forward = None
+            elif hasattr(net, 'children'):
+                unregister_editor(subnet)
+
+    for net_name, net in model.unet.named_children():
+        if "down" in net_name:
+            unregister_editor(net)
+        elif "mid" in net_name:
+            unregister_editor(net)
+        elif "up" in net_name:
+            unregister_editor(net)
+
+    model.editor.model = None
+    model.editor = None
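
The key addition in injection_utils.py is that register_editor now stashes each attention layer's original_forward before monkey-patching it, which is what lets the new unregister_attention_editor_diffusers restore the pipeline afterwards. A self-contained sketch of that save-and-restore pattern on a toy module (the names patch_forward, unpatch_forward, and make_logging_forward are illustrative, not part of the file):

import torch
import torch.nn as nn

def patch_forward(module: nn.Module, make_forward):
    # Save the original bound method so the patch can be undone later.
    module.original_forward = module.forward
    module.forward = make_forward(module)

def unpatch_forward(module: nn.Module):
    # Restore the saved forward and drop the stale reference, as the diff does.
    module.forward = module.original_forward
    module.original_forward = None

def make_logging_forward(module: nn.Module):
    original = module.forward
    def forward(*args, **kwargs):
        print(f"calling {module.__class__.__name__}")
        return original(*args, **kwargs)
    return forward

layer = nn.Linear(4, 4)
patch_forward(layer, make_logging_forward)
layer(torch.randn(1, 4))   # prints "calling Linear", then runs the original forward
unpatch_forward(layer)
layer(torch.randn(1, 4))   # back to the plain Linear forward

Because the instance attribute shadows the class method, nn.Module.__call__ picks up the patched forward; restoring the saved bound method undoes the patch without rebuilding the model.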