AlekseyCalvin committed
Commit b5ea7bd
1 Parent(s): c3f6e82

Upload pipeline.py

Files changed (1):
  1. pipeline.py +673 -0
pipeline.py ADDED
@@ -0,0 +1,673 @@
import torch
import numpy as np
import html
import inspect
import re
import urllib.parse as ul
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast, CLIPTextModelWithProjection
from diffusers import FlowMatchEulerDiscreteScheduler, AutoPipelineForImage2Image, FluxPipeline, FluxTransformer2DModel
from diffusers import StableDiffusion3Pipeline, AutoencoderKL, DiffusionPipeline, ImagePipelineOutput
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, SD3LoraLoaderMixin
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_xla_available,
    logging,
    BACKENDS_MAPPING,
    deprecate,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from typing import Any, Callable, Dict, List, Optional, Union
from PIL import Image
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps, FluxTransformer2DModel
from diffusers.utils import is_torch_xla_available

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Constants for shift calculation
BASE_SEQ_LEN = 256
MAX_SEQ_LEN = 4096
BASE_SHIFT = 0.5
MAX_SHIFT = 1.16

# Helper functions
def calculate_timestep_shift(image_seq_len: int) -> float:
    """Calculates the timestep shift (mu) based on the image sequence length."""
    m = (MAX_SHIFT - BASE_SHIFT) / (MAX_SEQ_LEN - BASE_SEQ_LEN)
    b = BASE_SHIFT - m * BASE_SEQ_LEN
    mu = image_seq_len * m + b
    return mu

def prepare_timesteps(
    scheduler: FlowMatchEulerDiscreteScheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    mu: Optional[float] = None,
) -> (torch.Tensor, int):
    """Prepares the timesteps for the diffusion process."""
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")

    if timesteps is not None:
        scheduler.set_timesteps(timesteps=timesteps, device=device)
    elif sigmas is not None:
        # pass `mu` along so schedulers configured with dynamic shifting accept custom sigmas
        scheduler.set_timesteps(sigmas=sigmas, device=device, mu=mu)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, mu=mu)

    timesteps = scheduler.timesteps
    num_inference_steps = len(timesteps)
    return timesteps, num_inference_steps

# FLUX pipeline function
class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        text_encoder_2: T5EncoderModel,
        tokenizer_2: T5TokenizerFast,
        transformer: FluxTransformer2DModel,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            transformer=transformer,
            scheduler=scheduler,
        )
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
        )
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
        )
        self.default_sample_size = 64

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=True,
            return_overflowing_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]

        dtype = self.text_encoder_2.dtype
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape

        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        return prompt_embeds

    def _get_clip_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
    ):
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_max_length,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=False,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer_max_length} tokens: {removed_text}"
            )
        prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)

        # Use pooled output of CLIPTextModel (a 2D tensor of shape [batch, dim])
        prompt_embeds = prompt_embeds.pooler_output
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        # duplicate pooled embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

        return prompt_embeds

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 512,
        do_classifier_free_guidance: bool = True,
        device: Optional[torch.device] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        lora_scale: Optional[float] = None,
    ):
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)
            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            pooled_prompt_embeds = self._get_clip_prompt_embeds(
                prompt=prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
            )
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt_2,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt_2 = negative_prompt_2 or negative_prompt

            # normalize str to list
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
            negative_prompt_2 = (
                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
            )

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_pooled_prompt_embeds = self._get_clip_prompt_embeds(
                prompt=negative_prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
            )

            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt_2,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
            )

        if self.text_encoder is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder_2, lora_scale)

        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
        text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
        text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)

        negative_text_ids = None
        if negative_prompt_embeds is not None:
            negative_text_ids = torch.zeros(batch_size, negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
            negative_text_ids = negative_text_ids.repeat(num_images_per_prompt, 1, 1)

        return prompt_embeds, pooled_prompt_embeds, text_ids, negative_prompt_embeds, negative_pooled_prompt_embeds, negative_text_ids

    def check_inputs(
        self,
        prompt,
        prompt_2,
        height,
        width,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        pooled_prompt_embeds=None,
        negative_pooled_prompt_embeds=None,
        max_sequence_length=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt_2 is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
            )
        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
            raise ValueError("Must provide `negative_pooled_prompt_embeds` when specifying `negative_prompt_embeds`.")

        if max_sequence_length is not None and max_sequence_length > 512:
            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")

    @staticmethod
    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
        latent_image_ids = latent_image_ids.reshape(
            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )

        return latent_image_ids.to(device=device, dtype=dtype)

    @staticmethod
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

        return latents

    @staticmethod
    def _unpack_latents(latents, height, width, vae_scale_factor):
        batch_size, num_patches, channels = latents.shape

        height = height // vae_scale_factor
        width = width // vae_scale_factor

        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)

        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)

        return latents

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        height = 2 * (int(height) // self.vae_scale_factor)
        width = 2 * (int(width) // self.vae_scale_factor)

        shape = (batch_size, num_channels_latents, height, width)

        if latents is not None:
            latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
            return latents.to(device=device, dtype=dtype), latent_image_ids

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)

        latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)

        return latents, latent_image_ids

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def joint_attention_kwargs(self):
        return self._joint_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @torch.inference_mode()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_inference_steps: int = 8,
        timesteps: List[int] = None,
        eta: Optional[float] = 0.0,
        guidance_scale: float = 3.5,
        device: Optional[int] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        cfg: Optional[bool] = True,
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        max_sequence_length: int = 512,
        **kwargs,
    ):
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # respect an explicitly passed device, otherwise fall back to wherever the pipeline lives
        device = torch.device(device) if device is not None else self._execution_device

        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
            negative_prompt_embeds,
            negative_pooled_prompt_embeds,
            negative_text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            device=device,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=lora_scale,
        )

        # 3. The positive and negative embeddings are kept separate; classifier-free
        # guidance is applied with two transformer passes in the denoising loop below.

        # 4. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4
        latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        timesteps, num_inference_steps = prepare_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            timesteps,
            sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # broadcast the timestep to the latent batch
                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                if self.transformer.config.guidance_embeds:
                    guidance = torch.tensor([guidance_scale], device=device)
                    guidance = guidance.expand(latents.shape[0])
                else:
                    guidance = None

                # conditional (text) pass
                noise_pred_text = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                if self.do_classifier_free_guidance:
                    # unconditional (negative) pass, then combine via classifier-free guidance
                    noise_pred_uncond = self.transformer(
                        hidden_states=latents,
                        timestep=timestep / 1000,
                        guidance=guidance,
                        pooled_projections=negative_pooled_prompt_embeds,
                        encoder_hidden_states=negative_prompt_embeds,
                        txt_ids=negative_text_ids,
                        img_ids=latent_image_ids,
                        joint_attention_kwargs=self.joint_attention_kwargs,
                        return_dict=False,
                    )[0]
                    noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
                else:
                    noise_pred = noise_pred_text

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        # Final image
        image = self._decode_latents_to_image(latents, height, width, output_type)
        self.maybe_free_model_hooks()
        torch.cuda.empty_cache()
        return image

    def _decode_latents_to_image(self, latents, height, width, output_type, vae=None):
        """Decodes the given latents into an image."""
        vae = vae or self.vae
        latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
        image = vae.decode(latents, return_dict=False)[0]
        return self.image_processor.postprocess(image, output_type=output_type)[0]
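For reference, a minimal usage sketch of the uploaded pipeline follows. It is not part of the commit: it assumes the file above is saved locally as pipeline.py, that the base FLUX.1 weights (here "black-forest-labs/FLUX.1-dev", an assumed checkpoint) are accessible, and that `from_pretrained` on this DiffusionPipeline subclass resolves the standard FLUX components into it.

# Hypothetical usage sketch (not part of the upload).
import torch

from pipeline import FluxWithCFGPipeline  # the file uploaded in this commit

pipe = FluxWithCFGPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed base checkpoint
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

# guidance_scale > 1 enables the true classifier-free guidance path,
# so the negative prompt actually takes effect.
image = pipe(
    prompt="a tintype portrait of a poet reading in a rose garden",
    negative_prompt="blurry, oversaturated, text, watermark",
    num_inference_steps=8,
    guidance_scale=3.5,
    height=1024,
    width=1024,
    generator=torch.Generator("cuda").manual_seed(0),
)
image.save("flux_cfg_sample.png")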