AP123 committed on
Commit
e8f9bdd
1 Parent(s): 0850a5c

Update app.py

Files changed (1)
  1. app.py +269 -49
app.py CHANGED
@@ -15,55 +15,275 @@ pipe.to("cuda")
 
  @torch.no_grad()
  def call(
- pipe, prompt, prompt2, height, width, num_inference_steps, denoising_end,
- guidance_scale, guidance_scale2, negative_prompt, negative_prompt2,
- num_images_per_prompt, eta, generator, latents, prompt_embeds,
- negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds,
- output_type, return_dict, callback, callback_steps, cross_attention_kwargs,
- guidance_rescale, original_size, crops_coords_top_left, target_size,
- negative_original_size, negative_crops_coords_top_left, negative_target_size):
- height = height or pipe.default_sample_size * pipe.vae_scale_factor
- width = width or pipe.default_sample_size * pipe.vae_scale_factor
- original_size = original_size or (height, width)
- target_size = target_size or (height, width)
- pipe.check_inputs(prompt, None, height, width, callback_steps, negative_prompt, None, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds)
- batch_size = 1 if isinstance(prompt, str) else len(prompt) if isinstance(prompt, list) else prompt_embeds.shape[0]
- device = pipe._execution_device
- do_classifier_free_guidance = guidance_scale > 1.0
- text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs else None
- prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = pipe.encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt)
- prompt2_embeds, negative_prompt2_embeds, pooled_prompt2_embeds, negative_pooled_prompt2_embeds = pipe.encode_prompt(prompt2, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt2)
- pipe.scheduler.set_timesteps(num_inference_steps, device=device)
- timesteps = pipe.scheduler.timesteps
- num_channels_latents = pipe.unet.config.in_channels
- latents = pipe.prepare_latents(batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents)
- extra_step_kwargs = pipe.prepare_extra_step_kwargs(generator, eta)
- add_text_embeds, add_text2_embeds = pooled_prompt_embeds, pooled_prompt2_embeds
- add_time_ids = pipe._get_add_time_ids(original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype)
- add_time2_ids = pipe._get_add_time_ids(original_size, crops_coords_top_left, target_size, dtype=prompt2_embeds.dtype)
- negative_add_time_ids = pipe._get_add_time_ids(negative_original_size, negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype) if negative_original_size and negative_target_size else add_time_ids
- if do_classifier_free_guidance:
- prompt_embeds, add_text_embeds, add_time_ids = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0), torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0), torch.cat([negative_add_time_ids, add_time_ids], dim=0)
- prompt2_embeds, add_text2_embeds, add_time2_ids = torch.cat([negative_prompt2_embeds, prompt2_embeds], dim=0), torch.cat([negative_pooled_prompt2_embeds, add_text2_embeds], dim=0), torch.cat([negative_add_time_ids, add_time2_ids], dim=0)
- prompt_embeds, add_text_embeds, add_time_ids = prompt_embeds.to(device), add_text_embeds.to(device), add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
- prompt2_embeds, add_text2_embeds, add_time2_ids = prompt2_embeds.to(device), add_text2_embeds.to(device), add_time2_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
- num_warmup_steps = max(len(timesteps) - num_inference_steps * pipe.scheduler.order, 0)
- if denoising_end and isinstance(denoising_end, float) and 0 < denoising_end < 1:
- discrete_timestep_cutoff = int(round(pipe.scheduler.config.num_train_timesteps - (denoising_end * pipe.scheduler.config.num_train_timesteps)))
- num_inference_steps = len([ts for ts in timesteps if ts >= discrete_timestep_cutoff])
- timesteps = timesteps[:num_inference_steps]
- with pipe.progress_bar(total=num_inference_steps) as progress_bar:
- for i, t in enumerate(timesteps):
- if i % 2 == 0:
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
- latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
- noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs={"text_embeds": add_text_embeds, "time_ids": add_time_ids})[0]
- if do_classifier_free_guidance:
- noise_pred = noise_pred.chunk(2)[0] + guidance_scale * (noise_pred.chunk(2)[1] - noise_pred.chunk(2)[0])
- else:
- latent_model_input2 = torch.cat([latents.flip(2)] * 2) if do_classifier_free_guidance else latents
- latent_model_input2 = pipe.scheduler.scale_model_input(latent_model_input2, t)
- noise_pred2 = pipe.unet(latent_model_input2, t)
 
  def simple_call(prompt1, prompt2, guidance_scale1, guidance_scale2, negative_prompt1, negative_prompt2):
  generator = [torch.Generator(device="cuda").manual_seed(5)]
 
  @torch.no_grad()
  def call(
+ pipe,
+ prompt: Union[str, List[str]] = None,
+ prompt2: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ guidance_scale2: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ ):
+ # 0. Default height and width to unet
+ height = height or pipe.default_sample_size * pipe.vae_scale_factor
+ width = width or pipe.default_sample_size * pipe.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ pipe.check_inputs(
+ prompt,
+ None,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ None,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = pipe._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = pipe.encode_prompt(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ (
+ prompt2_embeds,
+ negative_prompt2_embeds,
+ pooled_prompt2_embeds,
+ negative_pooled_prompt2_embeds,
+ ) = pipe.encode_prompt(
+ prompt=prompt2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt2,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ pipe.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ timesteps = pipe.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = pipe.unet.config.in_channels
+ latents = pipe.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = pipe.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ add_text2_embeds = pooled_prompt2_embeds
+
+ add_time_ids = pipe._get_add_time_ids(
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
+ )
+ add_time2_ids = pipe._get_add_time_ids(
+ original_size, crops_coords_top_left, target_size, dtype=prompt2_embeds.dtype
+ )
+
+ if negative_original_size is not None and negative_target_size is not None:
+ negative_add_time_ids = pipe._get_add_time_ids(
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ )
+ else:
+ negative_add_time_ids = add_time_ids
+ negative_add_time2_ids = add_time2_ids
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+ prompt2_embeds = torch.cat([negative_prompt2_embeds, prompt2_embeds], dim=0)
+ add_text2_embeds = torch.cat([negative_pooled_prompt2_embeds, add_text2_embeds], dim=0)
+ add_time2_ids = torch.cat([negative_add_time2_ids, add_time2_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ prompt2_embeds = prompt2_embeds.to(device)
+ add_text2_embeds = add_text2_embeds.to(device)
+ add_time2_ids = add_time2_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * pipe.scheduler.order, 0)
+
+ # 7.1 Apply denoising_end
+ if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
+ discrete_timestep_cutoff = int(
+ round(
+ pipe.scheduler.config.num_train_timesteps
+ - (denoising_end * pipe.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ with pipe.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if i % 2 == 0:
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ noise_pred = pipe.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ else:
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input2 = torch.cat([latents.flip(2)] * 2) if do_classifier_free_guidance else latents
+ latent_model_input2 = pipe.scheduler.scale_model_input(latent_model_input2, t)
+
+ # predict the noise residual
+ added_cond2_kwargs = {"text_embeds": add_text2_embeds, "time_ids": add_time2_ids}
+ noise_pred2 = pipe.unet(
+ latent_model_input2,
+ t,
+ encoder_hidden_states=prompt2_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond2_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred2_uncond, noise_pred2_text = noise_pred2.chunk(2)
+ noise_pred2 = noise_pred2_uncond + guidance_scale2 * (noise_pred2_text - noise_pred2_uncond)
+
+ noise_pred = noise_pred if i % 2 == 0 else noise_pred2.flip(2)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipe.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = pipe.vae.dtype == torch.float16 and pipe.vae.config.force_upcast
+
+ if needs_upcasting:
+ pipe.upcast_vae()
+ latents = latents.to(next(iter(pipe.vae.post_quant_conv.parameters())).dtype)
+
+ image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ pipe.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+
+ if not output_type == "latent":
+ # apply watermark if available
+ if pipe.watermark is not None:
+ image = pipe.watermark.apply_watermark(image)
+
+ image = pipe.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ pipe.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
 
  def simple_call(prompt1, prompt2, guidance_scale1, guidance_scale2, negative_prompt1, negative_prompt2):
  generator = [torch.Generator(device="cuda").manual_seed(5)]
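
For reference, a minimal sketch of how the two-prompt call() helper added in this commit might be invoked once pipe is loaded; the prompt strings and guidance values below are illustrative assumptions, not part of the diff:

# Hypothetical usage sketch: assumes the SDXL `pipe` created earlier in app.py is in scope,
# and that call() is imported or defined in the same module. Prompt text and guidance
# values are illustrative only.
import torch

generator = [torch.Generator(device="cuda").manual_seed(5)]
result = call(
    pipe,
    prompt="a photo of a red fox, sharp focus",         # drives the even denoising steps
    prompt2="a watercolor painting of a red fox",       # drives the odd steps on flipped latents
    guidance_scale=7.5,
    guidance_scale2=7.5,
    negative_prompt="blurry, low quality",
    negative_prompt2="blurry, low quality",
    num_inference_steps=50,
    generator=generator,
)
image = result.images[0]  # StableDiffusionXLPipelineOutput with PIL images by default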