Move model to GPU only inside generate function
Files changed:
- app.py (+11 -8)
- pipeline_stable_diffusion_xl_opt.py (+2 -0)
app.py CHANGED
@@ -21,7 +21,6 @@ COLORS = ["red", "blue", "green", "orange", "purple", "turquoise", "olive"]
 
 
 def inference(
-        device,
         model,
         boxes,
         prompts,
@@ -39,6 +38,12 @@ def inference(
         num_guidance_steps,
         seed,
 ):
+    if not torch.cuda.is_available():
+        raise gr.Error("cuda is not available")
+
+    device = torch.device("cuda")
+    model = model.to(device=device, dtype=torch.float16)
+
     seed_everything(seed)
     start_code = torch.randn([len(prompts), 4, 128, 128], device=device)
     eos_token_index = num_tokens + 1
@@ -85,7 +90,6 @@ def inference(
 
 @spaces.GPU
 def generate(
-        device,
         model,
         prompt,
         subject_token_indices,
@@ -107,8 +111,8 @@ def generate(
     subject_token_indices = convert_token_indices(subject_token_indices, nested=True)
     if len(boxes) != len(subject_token_indices):
         raise gr.Error("""
-            The number of boxes should be equal to the number of
-            Number of boxes drawn: {}, number of
+            The number of boxes should be equal to the number of subjects.
+            Number of boxes drawn: {}, number of subjects: {}.
         """.format(len(boxes), len(subject_token_indices)))
 
     filter_token_indices = convert_token_indices(filter_token_indices) if len(filter_token_indices.strip()) > 0 else None
@@ -116,7 +120,7 @@ def generate(
     prompts = [prompt.strip('.').strip(',').strip()] * batch_size
 
     images = inference(
-        device, model, boxes, prompts, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
+        model, boxes, prompts, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
         final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale, classifier_free_guidance_scale,
         num_iterations, loss_threshold, num_guidance_steps, seed)
@@ -210,10 +214,9 @@ def main():
     }
     """
 
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     model_path = "stabilityai/stable-diffusion-xl-base-1.0"
     scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
-    model = StableDiffusionXLPipeline.from_pretrained(model_path, scheduler=scheduler).to(device)
+    model = StableDiffusionXLPipeline.from_pretrained(model_path, scheduler=scheduler)
 
     nltk.download('averaged_perceptron_tagger')
@@ -325,7 +328,7 @@ def main():
     )
 
     generate_image_button.click(
-        fn=partial(generate, device, model),
+        fn=partial(generate, model),
         inputs=[
             prompt, subject_token_indices, filter_token_indices, num_tokens,
             init_step_size, final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
pipeline_stable_diffusion_xl_opt.py CHANGED
@@ -831,6 +831,8 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin
     num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
     timesteps = timesteps[:num_inference_steps]
 
+    latents = latents.half()
+    prompt_embeds = prompt_embeds.half()
     with self.progress_bar(total=num_inference_steps) as progress_bar:
         for i, t in enumerate(timesteps):
             latents = self.update_loss(latents, i, t, prompt_embeds, cross_attention_kwargs, add_text_embeds, add_time_ids)
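The two added .half() casts are needed because the weights are now converted to float16 inside inference, while tensors built upstream of the denoising loop can still be float32 (the latents, for instance, come from torch.randn, which defaults to float32). PyTorch raises a RuntimeError rather than silently promoting dtypes when a float32 tensor meets float16 weights. A small sketch of the failure and the fix, with a plain Linear layer standing in for the UNet:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"  # fp16 matmul support on CPU varies by PyTorch version
layer = torch.nn.Linear(8, 8).to(device=device, dtype=torch.float16)  # fp16 weights, as after the .to() in inference
x = torch.randn(1, 8, device=device)  # fp32 input, like latents created with torch.randn

try:
    layer(x)  # fp32 input vs fp16 weights: RuntimeError (dtype mismatch)
except RuntimeError as err:
    print(err)  # e.g. "expected ... Half but found Float"

print(layer(x.half()).dtype)  # casting the inputs first, as the diff does -> torch.float16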