Yinhong Liu committed
Commit c3c658f · 1 Parent(s): eab8699

sd3 pipeline

requirements.txt CHANGED
@@ -3,4 +3,5 @@ diffusers
  invisible_watermark
  torch
  transformers
- xformers
+ xformers
+ sentencepiece
sid/__init__.py ADDED
@@ -0,0 +1,53 @@
+ from typing import TYPE_CHECKING
+
+ from ...utils import (
+     DIFFUSERS_SLOW_IMPORT,
+     OptionalDependencyNotAvailable,
+     _LazyModule,
+     get_objects_from_module,
+     is_flax_available,
+     is_torch_available,
+     is_transformers_available,
+ )
+
+
+ _dummy_objects = {}
+ _additional_imports = {}
+ _import_structure = {"pipeline_output": ["SiDPipelineOutput"]}
+
+ try:
+     if not (is_transformers_available() and is_torch_available()):
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     from ...utils import dummy_torch_and_transformers_objects  # noqa F403
+
+     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+ else:
+     _import_structure["pipeline_sid_sd3"] = ["SiDSD3Pipeline"]
+     _import_structure["pipeline_sid_flux"] = ["SiDFluxPipeline"]
+     _import_structure["pipeline_sid_sana"] = ["SiDSanaPipeline"]
+
+ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+     try:
+         if not (is_transformers_available() and is_torch_available()):
+             raise OptionalDependencyNotAvailable()
+     except OptionalDependencyNotAvailable:
+         from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
+     else:
+         from .pipeline_sid_sd3 import SiDSD3Pipeline
+         from .pipeline_sid_flux import SiDFluxPipeline
+         from .pipeline_sid_sana import SiDSanaPipeline
+ else:
+     import sys
+
+     sys.modules[__name__] = _LazyModule(
+         __name__,
+         globals()["__file__"],
+         _import_structure,
+         module_spec=__spec__,
+     )
+
+     for name, value in _dummy_objects.items():
+         setattr(sys.modules[__name__], name, value)
+     for name, value in _additional_imports.items():
+         setattr(sys.modules[__name__], name, value)
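The relative imports (`from ...utils import ...`) assume this `sid/` package sits inside diffusers as a pipeline subpackage; under that assumption (illustrative, not part of this commit), the lazy module resolves the pipeline like this:

```py
# Illustrative only: assumes sid/ is vendored at diffusers/pipelines/sid/ so the
# relative `...utils` imports resolve; _LazyModule then loads pipeline_sid_sd3 on first access.
from diffusers.pipelines.sid import SiDSD3Pipeline
```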
sid/pipeline_output.py ADDED
@@ -0,0 +1,21 @@
+ from dataclasses import dataclass
+ from typing import List, Union
+
+ import numpy as np
+ import PIL.Image
+
+ from ...utils import BaseOutput
+
+
+ @dataclass
+ class SiDPipelineOutput(BaseOutput):
+     """
+     Output class for SiD pipelines.
+
+     Args:
+         images (`List[PIL.Image.Image]` or `np.ndarray`)
+             List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+             num_channels)`. PIL images or numpy array represent the denoised images of the diffusion pipeline.
+     """
+
+     images: Union[List[PIL.Image.Image], np.ndarray]
sid/pipeline_sid_flux.py ADDED
File without changes
sid/pipeline_sid_sana.py ADDED
File without changes
sid/pipeline_sid_sd3.py ADDED
@@ -0,0 +1,806 @@
+ # Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import inspect
+ from typing import Any, Callable, Dict, List, Optional, Union
+
+ import torch
+ from transformers import (
+     CLIPTextModelWithProjection,
+     CLIPTokenizer,
+     SiglipImageProcessor,
+     SiglipVisionModel,
+     T5EncoderModel,
+     T5TokenizerFast,
+ )
+
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+ from diffusers.loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
+ from diffusers.models.autoencoders import AutoencoderKL
+ from diffusers.models.transformers import SD3Transformer2DModel
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+ from diffusers.utils import (
+     USE_PEFT_BACKEND,
+     is_torch_xla_available,
+     logging,
+     replace_example_docstring,
+     scale_lora_layers,
+     unscale_lora_layers,
+ )
+ from diffusers.utils.torch_utils import randn_tensor
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+ from .pipeline_output import SiDPipelineOutput
+
+
+ if is_torch_xla_available():
+     import torch_xla.core.xla_model as xm
+
+     XLA_AVAILABLE = True
+ else:
+     XLA_AVAILABLE = False
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+ EXAMPLE_DOC_STRING = """
+     Examples:
+         ```py
+         >>> import torch
+         >>> from diffusers import StableDiffusion3Pipeline
+
+         >>> pipe = StableDiffusion3Pipeline.from_pretrained(
+         ...     "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
+         ... )
+         >>> pipe.to("cuda")
+         >>> prompt = "A cat holding a sign that says hello world"
+         >>> image = pipe(prompt).images[0]
+         >>> image.save("sd3.png")
+         ```
+ """
+
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+ def calculate_shift(
+     image_seq_len,
+     base_seq_len: int = 256,
+     max_seq_len: int = 4096,
+     base_shift: float = 0.5,
+     max_shift: float = 1.15,
+ ):
+     m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+     b = base_shift - m * base_seq_len
+     mu = image_seq_len * m + b
+     return mu
+
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+ def retrieve_timesteps(
+     scheduler,
+     num_inference_steps: Optional[int] = None,
+     device: Optional[Union[str, torch.device]] = None,
+     timesteps: Optional[List[int]] = None,
+     sigmas: Optional[List[float]] = None,
+     **kwargs,
+ ):
+     r"""
+     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+     Args:
+         scheduler (`SchedulerMixin`):
+             The scheduler to get timesteps from.
+         num_inference_steps (`int`):
+             The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+             must be `None`.
+         device (`str` or `torch.device`, *optional*):
+             The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+         timesteps (`List[int]`, *optional*):
+             Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+             `num_inference_steps` and `sigmas` must be `None`.
+         sigmas (`List[float]`, *optional*):
+             Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+             `num_inference_steps` and `timesteps` must be `None`.
+
+     Returns:
+         `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+         second element is the number of inference steps.
+     """
+     if timesteps is not None and sigmas is not None:
+         raise ValueError(
+             "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
+         )
+     if timesteps is not None:
+         accepts_timesteps = "timesteps" in set(
+             inspect.signature(scheduler.set_timesteps).parameters.keys()
+         )
+         if not accepts_timesteps:
+             raise ValueError(
+                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                 f" timestep schedules. Please check whether you are using the correct scheduler."
+             )
+         scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+         num_inference_steps = len(timesteps)
+     elif sigmas is not None:
+         accept_sigmas = "sigmas" in set(
+             inspect.signature(scheduler.set_timesteps).parameters.keys()
+         )
+         if not accept_sigmas:
+             raise ValueError(
+                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                 f" sigmas schedules. Please check whether you are using the correct scheduler."
+             )
+         scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+         num_inference_steps = len(timesteps)
+     else:
+         scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+     return timesteps, num_inference_steps
+
+
+ class SiDSD3Pipeline(
+     DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin
+ ):
+     r"""
+     Args:
+         transformer ([`SD3Transformer2DModel`]):
+             Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+         scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+         vae ([`AutoencoderKL`]):
+             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+         text_encoder ([`CLIPTextModelWithProjection`]):
+             [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+             specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
+             with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
+             as its dimension.
+         text_encoder_2 ([`CLIPTextModelWithProjection`]):
+             [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+             specifically the
+             [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+             variant.
+         text_encoder_3 ([`T5EncoderModel`]):
+             Frozen text-encoder. Stable Diffusion 3 uses
+             [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
+             [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+         tokenizer (`CLIPTokenizer`):
+             Tokenizer of class
+             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+         tokenizer_2 (`CLIPTokenizer`):
+             Second Tokenizer of class
+             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+         tokenizer_3 (`T5TokenizerFast`):
+             Tokenizer of class
+             [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+         image_encoder (`SiglipVisionModel`, *optional*):
+             Pre-trained Vision Model for IP Adapter.
+         feature_extractor (`SiglipImageProcessor`, *optional*):
+             Image processor for IP Adapter.
+     """
+
+     model_cpu_offload_seq = (
+         "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+     )
+     _optional_components = ["image_encoder", "feature_extractor"]
+     _callback_tensor_inputs = ["latents", "prompt_embeds", "pooled_prompt_embeds"]
+
+     def __init__(
+         self,
+         transformer: SD3Transformer2DModel,
+         scheduler: FlowMatchEulerDiscreteScheduler,
+         vae: AutoencoderKL,
+         text_encoder: CLIPTextModelWithProjection,
+         tokenizer: CLIPTokenizer,
+         text_encoder_2: CLIPTextModelWithProjection,
+         tokenizer_2: CLIPTokenizer,
+         text_encoder_3: T5EncoderModel,
+         tokenizer_3: T5TokenizerFast,
+         image_encoder: SiglipVisionModel = None,
+         feature_extractor: SiglipImageProcessor = None,
+     ):
+         super().__init__()
+
+         self.register_modules(
+             vae=vae,
+             text_encoder=text_encoder,
+             text_encoder_2=text_encoder_2,
+             text_encoder_3=text_encoder_3,
+             tokenizer=tokenizer,
+             tokenizer_2=tokenizer_2,
+             tokenizer_3=tokenizer_3,
+             transformer=transformer,
+             scheduler=scheduler,
+             image_encoder=image_encoder,
+             feature_extractor=feature_extractor,
+         )
+         self.vae_scale_factor = (
+             2 ** (len(self.vae.config.block_out_channels) - 1)
+             if getattr(self, "vae", None)
+             else 8
+         )
+         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+         self.tokenizer_max_length = (
+             self.tokenizer.model_max_length
+             if hasattr(self, "tokenizer") and self.tokenizer is not None
+             else 77
+         )
+         self.default_sample_size = (
+             self.transformer.config.sample_size
+             if hasattr(self, "transformer") and self.transformer is not None
+             else 128
+         )
+         self.patch_size = (
+             self.transformer.config.patch_size
+             if hasattr(self, "transformer") and self.transformer is not None
+             else 2
+         )
+
+     def _get_t5_prompt_embeds(
+         self,
+         prompt: Union[str, List[str]] = None,
+         num_images_per_prompt: int = 1,
+         max_sequence_length: int = 256,
+         device: Optional[torch.device] = None,
+         dtype: Optional[torch.dtype] = None,
+     ):
+         device = device or self._execution_device
+         dtype = dtype or self.text_encoder.dtype
+
+         prompt = [prompt] if isinstance(prompt, str) else prompt
+         batch_size = len(prompt)
+
+         if self.text_encoder_3 is None:
+             return torch.zeros(
+                 (
+                     batch_size * num_images_per_prompt,
+                     self.tokenizer_max_length,
+                     self.transformer.config.joint_attention_dim,
+                 ),
+                 device=device,
+                 dtype=dtype,
+             )
+
+         text_inputs = self.tokenizer_3(
+             prompt,
+             padding="max_length",
+             max_length=max_sequence_length,
+             truncation=True,
+             add_special_tokens=True,
+             return_tensors="pt",
+         )
+         text_input_ids = text_inputs.input_ids
+         untruncated_ids = self.tokenizer_3(
+             prompt, padding="longest", return_tensors="pt"
+         ).input_ids
+
+         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+             text_input_ids, untruncated_ids
+         ):
+             removed_text = self.tokenizer_3.batch_decode(
+                 untruncated_ids[:, self.tokenizer_max_length - 1 : -1]
+             )
+             logger.warning(
+                 "The following part of your input was truncated because `max_sequence_length` is set to "
+                 f" {max_sequence_length} tokens: {removed_text}"
+             )
+
+         prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
+
+         dtype = self.text_encoder_3.dtype
+         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+         _, seq_len, _ = prompt_embeds.shape
+
+         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+         prompt_embeds = prompt_embeds.view(
+             batch_size * num_images_per_prompt, seq_len, -1
+         )
+
+         return prompt_embeds
+
+     def _get_clip_prompt_embeds(
+         self,
+         prompt: Union[str, List[str]],
+         num_images_per_prompt: int = 1,
+         device: Optional[torch.device] = None,
+         clip_skip: Optional[int] = None,
+         clip_model_index: int = 0,
+     ):
+         device = device or self._execution_device
+
+         clip_tokenizers = [self.tokenizer, self.tokenizer_2]
+         clip_text_encoders = [self.text_encoder, self.text_encoder_2]
+
+         tokenizer = clip_tokenizers[clip_model_index]
+         text_encoder = clip_text_encoders[clip_model_index]
+
+         prompt = [prompt] if isinstance(prompt, str) else prompt
+         batch_size = len(prompt)
+
+         text_inputs = tokenizer(
+             prompt,
+             padding="max_length",
+             max_length=self.tokenizer_max_length,
+             truncation=True,
+             return_tensors="pt",
+         )
+
+         text_input_ids = text_inputs.input_ids
+         untruncated_ids = tokenizer(
+             prompt, padding="longest", return_tensors="pt"
+         ).input_ids
+         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+             text_input_ids, untruncated_ids
+         ):
+             removed_text = tokenizer.batch_decode(
+                 untruncated_ids[:, self.tokenizer_max_length - 1 : -1]
+             )
+             logger.warning(
+                 "The following part of your input was truncated because CLIP can only handle sequences up to"
+                 f" {self.tokenizer_max_length} tokens: {removed_text}"
+             )
+         prompt_embeds = text_encoder(
+             text_input_ids.to(device), output_hidden_states=True
+         )
+         pooled_prompt_embeds = prompt_embeds[0]
+
+         if clip_skip is None:
+             prompt_embeds = prompt_embeds.hidden_states[-2]
+         else:
+             prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+         prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+         _, seq_len, _ = prompt_embeds.shape
+         # duplicate text embeddings for each generation per prompt, using mps friendly method
+         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+         prompt_embeds = prompt_embeds.view(
+             batch_size * num_images_per_prompt, seq_len, -1
+         )
+
+         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+         pooled_prompt_embeds = pooled_prompt_embeds.view(
+             batch_size * num_images_per_prompt, -1
+         )
+
+         return prompt_embeds, pooled_prompt_embeds
+
+     def encode_prompt(
+         self,
+         prompt: Union[str, List[str]],
+         prompt_2: Union[str, List[str]],
+         prompt_3: Union[str, List[str]],
+         device: Optional[torch.device] = None,
+         num_images_per_prompt: int = 1,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+         clip_skip: Optional[int] = None,
+         max_sequence_length: int = 256,
+     ):
+         r"""
+
+         Args:
+             prompt (`str` or `List[str]`, *optional*):
+                 prompt to be encoded
+             prompt_2 (`str` or `List[str]`, *optional*):
+                 The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                 used in all text-encoders
+             prompt_3 (`str` or `List[str]`, *optional*):
+                 The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
+                 used in all text-encoders
+             device: (`torch.device`):
+                 torch device
+             num_images_per_prompt (`int`):
+                 number of images that should be generated per prompt
+             prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                 provided, text embeddings will be generated from `prompt` input argument.
+             pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                 If not provided, pooled text embeddings will be generated from `prompt` input argument.
+             clip_skip (`int`, *optional*):
+                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                 the output of the pre-final layer will be used for computing the prompt embeddings.
+             max_sequence_length (`int`, defaults to 256):
+                 Maximum sequence length to use with `text_encoder_3` (T5).
+         """
+         device = device or self._execution_device
+
+         prompt = [prompt] if isinstance(prompt, str) else prompt
+         if prompt is not None:
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         if prompt_embeds is None:
+             prompt_2 = prompt_2 or prompt
+             prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+             prompt_3 = prompt_3 or prompt
+             prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
+
+             prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
+                 prompt=prompt,
+                 device=device,
+                 num_images_per_prompt=num_images_per_prompt,
+                 clip_skip=clip_skip,
+                 clip_model_index=0,
+             )
+             prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
+                 prompt=prompt_2,
+                 device=device,
+                 num_images_per_prompt=num_images_per_prompt,
+                 clip_skip=clip_skip,
+                 clip_model_index=1,
+             )
+             clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
+
+             t5_prompt_embed = self._get_t5_prompt_embeds(
+                 prompt=prompt_3,
+                 num_images_per_prompt=num_images_per_prompt,
+                 max_sequence_length=max_sequence_length,
+                 device=device,
+             )
+
+             clip_prompt_embeds = torch.nn.functional.pad(
+                 clip_prompt_embeds,
+                 (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]),
+             )
+
+             prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
+             pooled_prompt_embeds = torch.cat(
+                 [pooled_prompt_embed, pooled_prompt_2_embed], dim=-1
+             )
+
+         return (
+             prompt_embeds,
+             pooled_prompt_embeds,
+         )
+
+     def check_inputs(
+         self,
+         prompt,
+         prompt_2,
+         prompt_3,
+         height,
+         width,
+         negative_prompt=None,
+         negative_prompt_2=None,
+         negative_prompt_3=None,
+         prompt_embeds=None,
+         negative_prompt_embeds=None,
+         pooled_prompt_embeds=None,
+         negative_pooled_prompt_embeds=None,
+         callback_on_step_end_tensor_inputs=None,
+         max_sequence_length=None,
+     ):
+         if (
+             height % (self.vae_scale_factor * self.patch_size) != 0
+             or width % (self.vae_scale_factor * self.patch_size) != 0
+         ):
+             raise ValueError(
+                 f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
+                 f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
+             )
+
+         if callback_on_step_end_tensor_inputs is not None and not all(
+             k in self._callback_tensor_inputs
+             for k in callback_on_step_end_tensor_inputs
+         ):
+             raise ValueError(
+                 f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+             )
+
+         if prompt is not None and prompt_embeds is not None:
+             raise ValueError(
+                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                 " only forward one of the two."
+             )
+         elif prompt_2 is not None and prompt_embeds is not None:
+             raise ValueError(
+                 f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                 " only forward one of the two."
+             )
+         elif prompt_3 is not None and prompt_embeds is not None:
+             raise ValueError(
+ f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
541
+ " only forward one of the two."
542
+ )
543
+ elif prompt is None and prompt_embeds is None:
544
+ raise ValueError(
545
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
546
+ )
547
+ elif prompt is not None and (
548
+ not isinstance(prompt, str) and not isinstance(prompt, list)
549
+ ):
550
+ raise ValueError(
551
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
552
+ )
553
+ elif prompt_2 is not None and (
554
+ not isinstance(prompt_2, str) and not isinstance(prompt_2, list)
555
+ ):
556
+ raise ValueError(
557
+ f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}"
558
+ )
559
+ elif prompt_3 is not None and (
560
+ not isinstance(prompt_3, str) and not isinstance(prompt_3, list)
561
+ ):
562
+ raise ValueError(
563
+ f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}"
564
+ )
565
+
566
+ if negative_prompt is not None and negative_prompt_embeds is not None:
567
+ raise ValueError(
568
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
569
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
570
+ )
571
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
572
+ raise ValueError(
573
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
574
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
575
+ )
576
+ elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
577
+ raise ValueError(
578
+ f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
579
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
580
+ )
581
+
582
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
583
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
584
+ raise ValueError(
585
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
586
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
587
+ f" {negative_prompt_embeds.shape}."
588
+ )
589
+
590
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
591
+ raise ValueError(
592
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
593
+ )
594
+
595
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
596
+ raise ValueError(
597
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
598
+ )
599
+
600
+ if max_sequence_length is not None and max_sequence_length > 512:
601
+ raise ValueError(
602
+ f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}"
603
+ )
604
+
605
+ def prepare_latents(
606
+ self,
607
+ batch_size,
608
+ num_channels_latents,
609
+ height,
610
+ width,
611
+ dtype,
612
+ device,
613
+ generator,
614
+ latents=None,
615
+ ):
616
+ if latents is not None:
617
+ return latents.to(device=device, dtype=dtype)
618
+
619
+ shape = (
620
+ batch_size,
621
+ num_channels_latents,
622
+ int(height) // self.vae_scale_factor,
623
+ int(width) // self.vae_scale_factor,
624
+ )
625
+
626
+ if isinstance(generator, list) and len(generator) != batch_size:
627
+ raise ValueError(
628
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
629
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
630
+ )
631
+
632
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
633
+
634
+ return latents
635
+
636
+ @property
637
+ def guidance_scale(self):
638
+ return self._guidance_scale
639
+
640
+ @property
641
+ def skip_guidance_layers(self):
642
+ return self._skip_guidance_layers
643
+
644
+ @property
645
+ def clip_skip(self):
646
+ return self._clip_skip
647
+
648
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
649
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
650
+ # corresponds to doing no classifier free guidance.
651
+ @property
652
+ def do_classifier_free_guidance(self):
653
+ return self._guidance_scale > 1
654
+
655
+ @property
656
+ def joint_attention_kwargs(self):
657
+ return self._joint_attention_kwargs
658
+
659
+ @property
660
+ def num_timesteps(self):
661
+ return self._num_timesteps
662
+
663
+ @property
664
+ def interrupt(self):
665
+ return self._interrupt
666
+
667
+ # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image
668
+
669
+ def enable_sequential_cpu_offload(self, *args, **kwargs):
670
+ if (
671
+ self.image_encoder is not None
672
+ and "image_encoder" not in self._exclude_from_cpu_offload
673
+ ):
674
+ logger.warning(
675
+ "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
676
+ "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
677
+ "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
678
+ )
679
+
680
+ super().enable_sequential_cpu_offload(*args, **kwargs)
681
+
682
+ @torch.no_grad()
683
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
684
+ def __call__(
685
+ self,
686
+ prompt: Union[str, List[str]] = None,
687
+ prompt_2: Optional[Union[str, List[str]]] = None,
688
+ prompt_3: Optional[Union[str, List[str]]] = None,
689
+ height: Optional[int] = None,
690
+ width: Optional[int] = None,
691
+ num_inference_steps: int = 28,
692
+ guidance_scale: float = 1.0,
693
+ num_images_per_prompt: Optional[int] = 1,
694
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
695
+         latents: Optional[torch.FloatTensor] = None,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+         max_sequence_length: int = 256,
+         use_sd3_shift: bool = False,
+         noise_type: str = 'fresh',  # 'fresh', 'ddim', 'fixed'
+     ):
+         height = height or self.default_sample_size * self.vae_scale_factor
+         width = width or self.default_sample_size * self.vae_scale_factor
+
+         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+         # 1. Check inputs. Raise error if not correct
+         self.check_inputs(
+             prompt,
+             prompt_2,
+             prompt_3,
+             height,
+             width,
+             prompt_embeds=prompt_embeds,
+             pooled_prompt_embeds=pooled_prompt_embeds,
+             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+             max_sequence_length=max_sequence_length,
+         )
+
+         self._guidance_scale = guidance_scale
+         self._interrupt = False
+
+         # 2. Define call parameters
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         device = self._execution_device
+
+         (
+             prompt_embeds,
+             pooled_prompt_embeds,
+         ) = self.encode_prompt(
+             prompt,
+             prompt_2,
+             prompt_3,
+             prompt_embeds=prompt_embeds,
+             pooled_prompt_embeds=pooled_prompt_embeds,
+             device=device,
+             num_images_per_prompt=num_images_per_prompt,
+             max_sequence_length=max_sequence_length,
+         )
+         # 3. Prepare latents
+         num_channels_latents = self.transformer.config.in_channels
+         latents = self.prepare_latents(
+             batch_size * num_images_per_prompt,
+             num_channels_latents,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             latents,
+         )
+
+         # 4. SiD sampling loop
+         # Initialize D_x
+         D_x = torch.zeros_like(latents).to(latents.device)
+         # Keep a copy of the initial noise so the 'fixed' strategy can reuse it at every step
+         initial_latents = latents.clone()
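+         # SiD-style few-step sampling: at each step, re-noise the current clean estimate D_x
+         # at a decreasing noise level t, query the flow transformer once, and recompute
+         # D_x = x_t - t * v_pred. `noise_type` selects where the injected noise comes from:
+         #   - 'fresh': draw new Gaussian noise at every step
+         #   - 'ddim':  reuse the noise implied by the previous step's x_t and D_x
+         #   - 'fixed': reuse the initial noise for all steps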
+         for i in range(num_inference_steps):
+             if noise_type == 'fresh':
+                 noise = latents if i == 0 else torch.randn_like(latents).to(latents.device)
+             elif noise_type == 'ddim':
+                 noise = latents if i == 0 else ((latents - (1.0 - t_view) * D_x) / t_view).detach()
+             elif noise_type == 'fixed':
+                 noise = initial_latents  # Use the initial, unmodified latents
+             else:
+                 raise ValueError(f"Unknown noise_type: {noise_type}")
+
+             # Compute t value, normalized to [0, 1]
+             t_val = 1.0 - float(i) / float(num_inference_steps)
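+             # Optional SD3-style timestep shift: maps the uniform t to
+             # t' = shift * t / (1 + (shift - 1) * t), so more steps land at high noise;
+             # shift = 3.0 matches the flow shift used in the SD3-medium scheduler config.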
+ if use_sd3_shift:
780
+ shift = 3.0
781
+ t_val = shift * t_val / (1 + (shift - 1) * t_val)
782
+ t = torch.full((latents.shape[0],), t_val, device=latents.device, dtype=latents.dtype)
783
+ t_flatten = t.flatten()
784
+ if t.numel() > 1:
785
+ t_view = t.view(-1, 1, 1, 1)
786
+ else:
787
+ t_view = t
788
+ # SiD update
789
+ latents = (1.0 - t_view) * D_x + t_view * noise
790
+ flow_pred = self.transformer(
791
+ hidden_states=latents,
792
+ encoder_hidden_states=prompt_embeds,
793
+ pooled_projections=pooled_prompt_embeds,
794
+ timestep=t_flatten,
795
+ return_dict=False,
796
+ )[0]
797
+ D_x = latents - t_view * flow_pred
798
+
799
+ # 5. Decode latent to image
800
+ image = self.vae.decode((D_x / self.vae.config.scaling_factor) + self.vae.config.shift_factor, return_dict=False)[0]
801
+
802
+ # 6. Return output
803
+ if not return_dict:
804
+ return (image,)
805
+
806
+ return SiDPipelineOutput(images=image)
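
For quick reference, a minimal usage sketch of `SiDSD3Pipeline` (illustrative only: the checkpoint id is a placeholder and would need to point at SD3 weights with a SiD-distilled transformer, and it assumes the package is importable as sketched above):

```py
import torch

# Assumes sid/ is placed under diffusers/pipelines/ so the package imports resolve.
from diffusers.pipelines.sid import SiDSD3Pipeline

# Placeholder repo id; a real run needs SiD-distilled SD3 weights.
pipe = SiDSD3Pipeline.from_pretrained(
    "path/to/sid-sd3-checkpoint", torch_dtype=torch.float16
)
pipe.to("cuda")

# Few-step sampling; noise_type picks the re-noising strategy ('fresh', 'ddim', 'fixed').
image = pipe(
    "A cat holding a sign that says hello world",
    num_inference_steps=4,
    use_sd3_shift=True,
    noise_type="fresh",
).images[0]
image.save("sid_sd3.png")
```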