omer11a commited on
Commit
4b19f84
1 Parent(s): 137c79d

Move model to GPU only inside generate function

Browse files
Files changed (2) hide show
  1. app.py +11 -8
  2. pipeline_stable_diffusion_xl_opt.py +2 -0
app.py CHANGED
@@ -21,7 +21,6 @@ COLORS = ["red", "blue", "green", "orange", "purple", "turquoise", "olive"]
21
 
22
 
23
  def inference(
24
- device,
25
  model,
26
  boxes,
27
  prompts,
@@ -39,6 +38,12 @@ def inference(
39
  num_guidance_steps,
40
  seed,
41
  ):
 
 
 
 
 
 
42
  seed_everything(seed)
43
  start_code = torch.randn([len(prompts), 4, 128, 128], device=device)
44
  eos_token_index = num_tokens + 1
@@ -85,7 +90,6 @@ def inference(
85
 
86
  @spaces.GPU
87
  def generate(
88
- device,
89
  model,
90
  prompt,
91
  subject_token_indices,
@@ -107,8 +111,8 @@ def generate(
107
  subject_token_indices = convert_token_indices(subject_token_indices, nested=True)
108
  if len(boxes) != len(subject_token_indices):
109
  raise gr.Error("""
110
- The number of boxes should be equal to the number of subject token indices.
111
- Number of boxes drawn: {}, number of grounding tokens: {}.
112
  """.format(len(boxes), len(subject_token_indices)))
113
 
114
  filter_token_indices = convert_token_indices(filter_token_indices) if len(filter_token_indices.strip()) > 0 else None
@@ -116,7 +120,7 @@ def generate(
116
  prompts = [prompt.strip('.').strip(',').strip()] * batch_size
117
 
118
  images = inference(
119
- device, model, boxes, prompts, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
120
  final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale, classifier_free_guidance_scale,
121
  num_iterations, loss_threshold, num_guidance_steps, seed)
122
 
@@ -210,10 +214,9 @@ def main():
210
  }
211
  """
212
 
213
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
214
  model_path = "stabilityai/stable-diffusion-xl-base-1.0"
215
  scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
216
- model = StableDiffusionXLPipeline.from_pretrained(model_path, scheduler=scheduler, device=device)
217
 
218
  nltk.download('averaged_perceptron_tagger')
219
 
@@ -325,7 +328,7 @@ def main():
325
  )
326
 
327
  generate_image_button.click(
328
- fn=partial(generate, device, model),
329
  inputs=[
330
  prompt, subject_token_indices, filter_token_indices, num_tokens,
331
  init_step_size, final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
 
21
 
22
 
23
  def inference(
 
24
  model,
25
  boxes,
26
  prompts,
 
38
  num_guidance_steps,
39
  seed,
40
  ):
41
+ if not torch.cuda.is_available():
42
+ raise gr.Error("cuda is not available")
43
+
44
+ device = torch.device("cuda")
45
+ model = model.to(device=device, dtype=torch.float16)
46
+
47
  seed_everything(seed)
48
  start_code = torch.randn([len(prompts), 4, 128, 128], device=device)
49
  eos_token_index = num_tokens + 1
 
90
 
91
  @spaces.GPU
92
  def generate(
 
93
  model,
94
  prompt,
95
  subject_token_indices,
 
111
  subject_token_indices = convert_token_indices(subject_token_indices, nested=True)
112
  if len(boxes) != len(subject_token_indices):
113
  raise gr.Error("""
114
+ The number of boxes should be equal to the number of subjects.
115
+ Number of boxes drawn: {}, number of subjects: {}.
116
  """.format(len(boxes), len(subject_token_indices)))
117
 
118
  filter_token_indices = convert_token_indices(filter_token_indices) if len(filter_token_indices.strip()) > 0 else None
 
120
  prompts = [prompt.strip('.').strip(',').strip()] * batch_size
121
 
122
  images = inference(
123
+ model, boxes, prompts, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
124
  final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale, classifier_free_guidance_scale,
125
  num_iterations, loss_threshold, num_guidance_steps, seed)
126
 
 
214
  }
215
  """
216
 
 
217
  model_path = "stabilityai/stable-diffusion-xl-base-1.0"
218
  scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
219
+ model = StableDiffusionXLPipeline.from_pretrained(model_path, scheduler=scheduler)
220
 
221
  nltk.download('averaged_perceptron_tagger')
222
 
 
328
  )
329
 
330
  generate_image_button.click(
331
+ fn=partial(generate, model),
332
  inputs=[
333
  prompt, subject_token_indices, filter_token_indices, num_tokens,
334
  init_step_size, final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
pipeline_stable_diffusion_xl_opt.py CHANGED
@@ -831,6 +831,8 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
831
  num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
832
  timesteps = timesteps[:num_inference_steps]
833
 
 
 
834
  with self.progress_bar(total=num_inference_steps) as progress_bar:
835
  for i, t in enumerate(timesteps):
836
  latents = self.update_loss(latents, i, t, prompt_embeds, cross_attention_kwargs, add_text_embeds, add_time_ids)
 
831
  num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
832
  timesteps = timesteps[:num_inference_steps]
833
 
834
+ latents = latents.half()
835
+ prompt_embeds = prompt_embeds.half()
836
  with self.progress_bar(total=num_inference_steps) as progress_bar:
837
  for i, t in enumerate(timesteps):
838
  latents = self.update_loss(latents, i, t, prompt_embeds, cross_attention_kwargs, add_text_embeds, add_time_ids)