AlekseyCalvin committed on
Commit
8d1dcb4
1 Parent(s): d2c75c4

Upload pipeline.py

Files changed (1)
  1. pipeline.py +878 -0
pipeline.py ADDED
@@ -0,0 +1,878 @@
import torch
import numpy as np
import html
import inspect
import re
import urllib.parse as ul
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast, CLIPTextModelWithProjection
from diffusers import FlowMatchEulerDiscreteScheduler, AutoPipelineForImage2Image, FluxPipeline, FluxTransformer2DModel
from diffusers import StableDiffusion3Pipeline, AutoencoderKL, DiffusionPipeline, ImagePipelineOutput
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, SD3LoraLoaderMixin
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_xla_available,
    logging,
    BACKENDS_MAPPING,
    is_bs4_available,
    is_ftfy_available,
    deprecate,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from typing import Any, Callable, Dict, List, Optional, Union
from PIL import Image
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps, FluxTransformer2DModel

if is_bs4_available():
    from bs4 import BeautifulSoup

if is_ftfy_available():
    import ftfy

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Constants for shift calculation
BASE_SEQ_LEN = 256
MAX_SEQ_LEN = 4096
BASE_SHIFT = 0.5
MAX_SHIFT = 1.2

# Helper functions
def calculate_timestep_shift(image_seq_len: int) -> float:
    """Calculates the timestep shift (mu) based on the image sequence length."""
    m = (MAX_SHIFT - BASE_SHIFT) / (MAX_SEQ_LEN - BASE_SEQ_LEN)
    b = BASE_SHIFT - m * BASE_SEQ_LEN
    mu = image_seq_len * m + b
    return mu
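
# Worked example for calculate_timestep_shift: the shift is a linear interpolation between the
# anchors above, so image_seq_len = 256 gives mu = 0.5 (BASE_SHIFT), image_seq_len = 4096 gives
# mu = 1.2 (MAX_SHIFT), and the midpoint image_seq_len = 2176 gives mu = 0.85.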

def prepare_timesteps(
    scheduler: FlowMatchEulerDiscreteScheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    mu: Optional[float] = None,
) -> (torch.Tensor, int):
    """Prepares the timesteps for the diffusion process."""
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")

    if timesteps is not None:
        scheduler.set_timesteps(timesteps=timesteps, device=device)
    elif sigmas is not None:
        # Forward `mu` here as well so the dynamic shift is applied when custom sigmas are used
        # (FlowMatchEulerDiscreteScheduler requires `mu` when `use_dynamic_shifting` is enabled).
        scheduler.set_timesteps(sigmas=sigmas, device=device, mu=mu)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, mu=mu)

    timesteps = scheduler.timesteps
    num_inference_steps = len(timesteps)
    return timesteps, num_inference_steps

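# A minimal usage sketch for prepare_timesteps (the `scheduler` name is illustrative and assumed
# to be a FlowMatchEulerDiscreteScheduler instance, mirroring how generate_images calls this
# helper below):
#
#   sigmas = np.linspace(1.0, 1 / 8, 8)
#   timesteps, num_steps = prepare_timesteps(scheduler, 8, "cuda", sigmas=sigmas, mu=1.2)
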
# FLUX pipeline function
class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        text_encoder_2: T5EncoderModel,
        tokenizer_2: T5TokenizerFast,
        transformer: FluxTransformer2DModel,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            transformer=transformer,
            scheduler=scheduler,
        )
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
        )
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
        )
        self.default_sample_size = 64

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=True,
            return_overflowing_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]

        dtype = self.text_encoder_2.dtype
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape

        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        return prompt_embeds
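
    # _get_t5_prompt_embeds returns a tensor of shape
    # (batch_size * num_images_per_prompt, max_sequence_length, T5 hidden size),
    # i.e. the per-token sequence embeddings passed to the transformer as `encoder_hidden_states`.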

    def _get_clip_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
    ):
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_max_length,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=False,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer_max_length} tokens: {removed_text}"
            )
        prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)

        # Use pooled output of CLIPTextModel
        prompt_embeds = prompt_embeds.pooler_output
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

        return prompt_embeds
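
    # _get_clip_prompt_embeds returns only the pooled CLIP embedding, shaped
    # (batch_size * num_images_per_prompt, CLIP hidden size); it is passed to the transformer
    # as `pooled_projections`.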

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
    ):
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)
            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # CLIP supplies the pooled projection, T5 supplies the sequence embeddings.
            pooled_prompt_embeds = self._get_clip_prompt_embeds(
                prompt=prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
            )
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt_2,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt_2 = negative_prompt_2 or negative_prompt

            # normalize str to list
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
            negative_prompt_2 = (
                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
            )

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            # Negative embeddings are built exactly like the positive ones:
            # a pooled CLIP projection plus T5 sequence embeddings.
            negative_pooled_prompt_embeds = self._get_clip_prompt_embeds(
                prompt=negative_prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
            )
            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt_2,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        if self.text_encoder is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder_2, lora_scale)

        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

        # Returned in the order expected by generate_images:
        # (prompt_embeds, pooled_prompt_embeds, negative_prompt_embeds, negative_pooled_prompt_embeds, text_ids)
        return prompt_embeds, pooled_prompt_embeds, negative_prompt_embeds, negative_pooled_prompt_embeds, text_ids

    def check_inputs(
        self,
        prompt,
        prompt_2,
        height,
        width,
        negative_prompt=None,
        negative_prompt_2=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        pooled_prompt_embeds=None,
        negative_pooled_prompt_embeds=None,
        max_sequence_length=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt_2 is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
            )
        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )
        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )
        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
            raise ValueError(
                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed."
            )

        if max_sequence_length is not None and max_sequence_length > 512:
            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")

    @staticmethod
    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

        latent_image_ids = latent_image_ids.reshape(
            latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )

        return latent_image_ids.to(device=device, dtype=dtype)

    @staticmethod
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

        return latents
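
    # _pack_latents folds each 2x2 patch of the latent grid into the channel dimension:
    # (batch, channels, height, width) -> (batch, (height // 2) * (width // 2), channels * 4).
    # _unpack_latents below reverses this transformation before VAE decoding.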

    @staticmethod
    def _unpack_latents(latents, height, width, vae_scale_factor):
        batch_size, num_patches, channels = latents.shape

        height = height // vae_scale_factor
        width = width // vae_scale_factor

        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)

        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)

        return latents

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        height = 2 * (int(height) // self.vae_scale_factor)
        width = 2 * (int(width) // self.vae_scale_factor)

        shape = (batch_size, num_channels_latents, height, width)

        if latents is not None:
            latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
            return latents.to(device=device, dtype=dtype), latent_image_ids

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)

        latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)

        return latents, latent_image_ids
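
    # Example: for a 1024x1024 image with vae_scale_factor = 16, the latent grid prepared here is
    # 128x128, which packs into (128 // 2) * (128 // 2) = 4096 tokens; that value is the
    # image_seq_len later passed to calculate_timestep_shift.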

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def joint_attention_kwargs(self):
        return self._joint_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @torch.inference_mode()
    def generate_images(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_inference_steps: int = 8,
        timesteps: List[int] = None,
        eta: Optional[float] = 0.0,
        guidance_scale: float = 3.5,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        max_sequence_length: int = 300,
        **kwargs,
    ):
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )

        # 3. Encode the prompts (CLIP pooled projection + T5 sequence embeddings, and their negatives)
        (
            prompt_embeds,
            pooled_prompt_embeds,
            negative_prompt_embeds,
            negative_pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

        # 4. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4
        latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        image_seq_len = latents.shape[1]
        mu = calculate_timestep_shift(image_seq_len)
        timesteps, num_inference_steps = prepare_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            timesteps,
            sigmas,
            mu=mu,
        )
        self._num_timesteps = len(timesteps)
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # With classifier-free guidance the unconditional and conditional branches are run
                # in a single batched forward pass and split afterwards.
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)

                if self.transformer.config.guidance_embeds:
                    guidance = torch.tensor([guidance_scale], device=device)
                    guidance = guidance.expand(latent_model_input.shape[0])
                else:
                    guidance = None

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        # Final image: decode, release any offloaded modules, then return
        image = self._decode_latents_to_image(latents, height, width, output_type)
        self.maybe_free_model_hooks()
        torch.cuda.empty_cache()
        return image

    @torch.no_grad()
    @torch.inference_mode()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_inference_steps: int = 8,
        timesteps: List[int] = None,
        eta: Optional[float] = 0.0,
        guidance_scale: float = 3.5,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        max_sequence_length: int = 300,
        **kwargs,
    ):
        # `__call__` shares its signature and behavior with `generate_images`; delegate so the
        # denoising loop is maintained in a single place.
        return self.generate_images(
            prompt=prompt,
            prompt_2=prompt_2,
            height=height,
            width=width,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            num_inference_steps=num_inference_steps,
            timesteps=timesteps,
            eta=eta,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            generator=generator,
            latents=latents,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            output_type=output_type,
            return_dict=return_dict,
            joint_attention_kwargs=joint_attention_kwargs,
            max_sequence_length=max_sequence_length,
            **kwargs,
        )

    def _decode_latents_to_image(self, latents, height, width, output_type, vae=None):
        """Decodes the given latents into an image."""
        vae = vae or self.vae
        latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
        image = vae.decode(latents, return_dict=False)[0]
        return self.image_processor.postprocess(image, output_type=output_type)[0]
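

# Example usage (a minimal sketch; it assumes the standard FLUX.1 checkpoint layout on the
# Hugging Face Hub, whose component names match this pipeline's __init__ signature):
#
#   pipe = FluxWithCFGPipeline.from_pretrained(
#       "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
#   ).to("cuda")
#   image = pipe(
#       prompt="a watercolor lighthouse at dawn",
#       negative_prompt="blurry, low quality",
#       num_inference_steps=8,
#       guidance_scale=3.5,
#       height=1024,
#       width=1024,
#   )
#   image.save("lighthouse.png")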