linjinpeng commited on
Commit
4049887
·
1 Parent(s): a19eadb

fix checkpoint, ready to merge to diffusers

Browse files
config.json CHANGED
@@ -4,13 +4,6 @@
4
  "_name_or_path": "./model_hub_tmp_0/.",
5
  "attention_head_dim": 64,
6
  "caption_projection_dim": 1536,
7
- "conditioning_channels": 3,
8
- "conditioning_embedding_out_channels": [
9
- 16,
10
- 32,
11
- 96,
12
- 256
13
- ],
14
  "in_channels": 16,
15
  "joint_attention_dim": 4096,
16
  "num_attention_heads": 24,
 
4
  "_name_or_path": "./model_hub_tmp_0/.",
5
  "attention_head_dim": 64,
6
  "caption_projection_dim": 1536,
 
 
 
 
 
 
 
7
  "in_channels": 16,
8
  "joint_attention_dim": 4096,
9
  "num_attention_heads": 24,
controlnet_sd3.py DELETED
@@ -1,552 +0,0 @@
1
- # Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- from dataclasses import dataclass
17
- from typing import Any, Dict, List, Optional, Tuple, Union
18
-
19
- import torch
20
- import torch.nn as nn
21
-
22
- import diffusers
23
- from diffusers.configuration_utils import ConfigMixin, register_to_config
24
- from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
25
- from diffusers.models.attention import JointTransformerBlock
26
- from diffusers.models.attention_processor import Attention, AttentionProcessor
27
- from diffusers.models.modeling_utils import ModelMixin
28
- from diffusers.utils import (
29
- USE_PEFT_BACKEND,
30
- is_torch_version,
31
- logging,
32
- scale_lora_layers,
33
- unscale_lora_layers,
34
- )
35
- from diffusers.models.controlnet import BaseOutput, zero_module
36
- from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
37
- from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput
38
- from torch.nn import functional as F
39
-
40
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
- from packaging import version
42
-
43
- class ControlNetConditioningEmbedding(nn.Module):
44
- """
45
- Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
46
- [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
47
- training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
48
- convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
49
- (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
50
- model) to encode image-space conditions ... into feature maps ..."
51
- """
52
-
53
- def __init__(
54
- self,
55
- conditioning_embedding_channels: int,
56
- conditioning_channels: int = 3,
57
- block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
58
- ):
59
- super().__init__()
60
-
61
- self.conv_in = nn.Conv2d(
62
- conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
63
- )
64
-
65
- self.blocks = nn.ModuleList([])
66
-
67
- for i in range(len(block_out_channels) - 1):
68
- channel_in = block_out_channels[i]
69
- channel_out = block_out_channels[i + 1]
70
- self.blocks.append(
71
- nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)
72
- )
73
- self.blocks.append(
74
- nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)
75
- )
76
-
77
- self.conv_out = zero_module(
78
- nn.Conv2d(
79
- block_out_channels[-1],
80
- conditioning_embedding_channels,
81
- kernel_size=3,
82
- padding=1,
83
- )
84
- )
85
-
86
- def forward(self, conditioning):
87
- embedding = self.conv_in(conditioning)
88
- embedding = F.silu(embedding)
89
-
90
- for block in self.blocks:
91
- embedding = block(embedding)
92
- embedding = F.silu(embedding)
93
-
94
- embedding = self.conv_out(embedding)
95
-
96
- return embedding
97
-
98
-
99
- @dataclass
100
- class SD3ControlNetOutput(BaseOutput):
101
- controlnet_block_samples: Tuple[torch.Tensor]
102
-
103
-
104
- class SD3ControlNetModel(
105
- ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
106
- ):
107
- _supports_gradient_checkpointing = True
108
-
109
- @register_to_config
110
- def __init__(
111
- self,
112
- sample_size: int = 128,
113
- patch_size: int = 2,
114
- in_channels: int = 16,
115
- num_layers: int = 18,
116
- attention_head_dim: int = 64,
117
- num_attention_heads: int = 18,
118
- joint_attention_dim: int = 4096,
119
- caption_projection_dim: int = 1152,
120
- pooled_projection_dim: int = 2048,
121
- out_channels: int = 16,
122
- pos_embed_max_size: int = 96,
123
- conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (
124
- 16,
125
- 32,
126
- 96,
127
- 256,
128
- ),
129
- conditioning_channels: int = 3,
130
- ):
131
- """
132
- conditioning_channels: condition image pixel space channels
133
- conditioning_embedding_out_channels: intermediate channels
134
-
135
- """
136
- super().__init__()
137
- default_out_channels = in_channels
138
- self.out_channels = (
139
- out_channels if out_channels is not None else default_out_channels
140
- )
141
- self.inner_dim = num_attention_heads * attention_head_dim
142
-
143
- self.pos_embed = PatchEmbed(
144
- height=sample_size,
145
- width=sample_size,
146
- patch_size=patch_size,
147
- in_channels=in_channels,
148
- embed_dim=self.inner_dim,
149
- pos_embed_max_size=pos_embed_max_size, # hard-code for now.
150
- )
151
- self.time_text_embed = CombinedTimestepTextProjEmbeddings(
152
- embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
153
- )
154
- self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
155
-
156
- # control net conditioning embedding
157
- # self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
158
- # conditioning_embedding_channels=default_out_channels,
159
- # block_out_channels=conditioning_embedding_out_channels,
160
- # conditioning_channels=conditioning_channels,
161
- # )
162
-
163
- # `attention_head_dim` is doubled to account for the mixing.
164
- # It needs to crafted when we get the actual checkpoints.
165
- self.transformer_blocks = nn.ModuleList(
166
- [
167
- JointTransformerBlock(
168
- dim=self.inner_dim,
169
- num_attention_heads=num_attention_heads,
170
- attention_head_dim=attention_head_dim if version.parse(diffusers.__version__) >= version.parse('0.30.0.dev0') else self.inner_dim,
171
- context_pre_only=False,
172
- )
173
- for _ in range(num_layers)
174
- ]
175
- )
176
-
177
- # controlnet_blocks
178
- self.controlnet_blocks = nn.ModuleList([])
179
- for _ in range(len(self.transformer_blocks)):
180
- controlnet_block = zero_module(nn.Linear(self.inner_dim, self.inner_dim))
181
- self.controlnet_blocks.append(controlnet_block)
182
-
183
- # control condition embedding
184
- pos_embed_cond = PatchEmbed(
185
- height=sample_size,
186
- width=sample_size,
187
- patch_size=patch_size,
188
- in_channels=in_channels + 1,
189
- embed_dim=self.inner_dim,
190
- pos_embed_type=None,
191
- )
192
- # pos_embed_cond = nn.Linear(in_channels + 1, self.inner_dim)
193
- self.pos_embed_cond = zero_module(pos_embed_cond)
194
-
195
- self.gradient_checkpointing = False
196
-
197
- # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
198
- def enable_forward_chunking(
199
- self, chunk_size: Optional[int] = None, dim: int = 0
200
- ) -> None:
201
- """
202
- Sets the attention processor to use [feed forward
203
- chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
204
-
205
- Parameters:
206
- chunk_size (`int`, *optional*):
207
- The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
208
- over each tensor of dim=`dim`.
209
- dim (`int`, *optional*, defaults to `0`):
210
- The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
211
- or dim=1 (sequence length).
212
- """
213
- if dim not in [0, 1]:
214
- raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
215
-
216
- # By default chunk size is 1
217
- chunk_size = chunk_size or 1
218
-
219
- def fn_recursive_feed_forward(
220
- module: torch.nn.Module, chunk_size: int, dim: int
221
- ):
222
- if hasattr(module, "set_chunk_feed_forward"):
223
- module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
224
-
225
- for child in module.children():
226
- fn_recursive_feed_forward(child, chunk_size, dim)
227
-
228
- for module in self.children():
229
- fn_recursive_feed_forward(module, chunk_size, dim)
230
-
231
- @property
232
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
233
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
234
- r"""
235
- Returns:
236
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
237
- indexed by its weight name.
238
- """
239
- # set recursively
240
- processors = {}
241
-
242
- def fn_recursive_add_processors(
243
- name: str,
244
- module: torch.nn.Module,
245
- processors: Dict[str, AttentionProcessor],
246
- ):
247
- if hasattr(module, "get_processor"):
248
- processors[f"{name}.processor"] = module.get_processor(
249
- return_deprecated_lora=True
250
- )
251
-
252
- for sub_name, child in module.named_children():
253
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
254
-
255
- return processors
256
-
257
- for name, module in self.named_children():
258
- fn_recursive_add_processors(name, module, processors)
259
-
260
- return processors
261
-
262
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
263
- def set_attn_processor(
264
- self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
265
- ):
266
- r"""
267
- Sets the attention processor to use to compute attention.
268
-
269
- Parameters:
270
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
271
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
272
- for **all** `Attention` layers.
273
-
274
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
275
- processor. This is strongly recommended when setting trainable attention processors.
276
-
277
- """
278
- count = len(self.attn_processors.keys())
279
-
280
- if isinstance(processor, dict) and len(processor) != count:
281
- raise ValueError(
282
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
283
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
284
- )
285
-
286
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
287
- if hasattr(module, "set_processor"):
288
- if not isinstance(processor, dict):
289
- module.set_processor(processor)
290
- else:
291
- module.set_processor(processor.pop(f"{name}.processor"))
292
-
293
- for sub_name, child in module.named_children():
294
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
295
-
296
- for name, module in self.named_children():
297
- fn_recursive_attn_processor(name, module, processor)
298
-
299
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
300
- def fuse_qkv_projections(self):
301
- """
302
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
303
- are fused. For cross-attention modules, key and value projection matrices are fused.
304
-
305
- <Tip warning={true}>
306
-
307
- This API is 🧪 experimental.
308
-
309
- </Tip>
310
- """
311
- self.original_attn_processors = None
312
-
313
- for _, attn_processor in self.attn_processors.items():
314
- if "Added" in str(attn_processor.__class__.__name__):
315
- raise ValueError(
316
- "`fuse_qkv_projections()` is not supported for models having added KV projections."
317
- )
318
-
319
- self.original_attn_processors = self.attn_processors
320
-
321
- for module in self.modules():
322
- if isinstance(module, Attention):
323
- module.fuse_projections(fuse=True)
324
-
325
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
326
- def unfuse_qkv_projections(self):
327
- """Disables the fused QKV projection if enabled.
328
-
329
- <Tip warning={true}>
330
-
331
- This API is 🧪 experimental.
332
-
333
- </Tip>
334
-
335
- """
336
- if self.original_attn_processors is not None:
337
- self.set_attn_processor(self.original_attn_processors)
338
-
339
- def _set_gradient_checkpointing(self, module, value=False):
340
- if hasattr(module, "gradient_checkpointing"):
341
- module.gradient_checkpointing = value
342
-
343
- @classmethod
344
- def from_transformer(
345
- cls, transformer, num_layers=None, load_weights_from_transformer=True
346
- ):
347
- config = transformer.config
348
- config["num_layers"] = num_layers or config.num_layers
349
- controlnet = cls(**config)
350
-
351
- if load_weights_from_transformer:
352
- controlnet.pos_embed.load_state_dict(
353
- transformer.pos_embed.state_dict(), strict=False
354
- )
355
- controlnet.time_text_embed.load_state_dict(
356
- transformer.time_text_embed.state_dict(), strict=False
357
- )
358
- controlnet.context_embedder.load_state_dict(
359
- transformer.context_embedder.state_dict(), strict=False
360
- )
361
- controlnet.transformer_blocks.load_state_dict(
362
- transformer.transformer_blocks.state_dict(), strict=False
363
- )
364
-
365
- return controlnet
366
-
367
- def forward(
368
- self,
369
- hidden_states: torch.FloatTensor,
370
- controlnet_cond: torch.Tensor,
371
- conditioning_scale: float = 1.0,
372
- encoder_hidden_states: torch.FloatTensor = None,
373
- pooled_projections: torch.FloatTensor = None,
374
- timestep: torch.LongTensor = None,
375
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
376
- return_dict: bool = True,
377
- ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
378
- """
379
- The [`SD3Transformer2DModel`] forward method.
380
-
381
- Args:
382
- hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
383
- Input `hidden_states`.
384
- controlnet_cond (`torch.Tensor`):
385
- The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
386
- conditioning_scale (`float`, defaults to `1.0`):
387
- The scale factor for ControlNet outputs.
388
- encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
389
- Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
390
- pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
391
- from the embeddings of input conditions.
392
- timestep ( `torch.LongTensor`):
393
- Used to indicate denoising step.
394
- joint_attention_kwargs (`dict`, *optional*):
395
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
396
- `self.processor` in
397
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
398
- return_dict (`bool`, *optional*, defaults to `True`):
399
- Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
400
- tuple.
401
-
402
- Returns:
403
- If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
404
- `tuple` where the first element is the sample tensor.
405
- """
406
- if joint_attention_kwargs is not None:
407
- joint_attention_kwargs = joint_attention_kwargs.copy()
408
- lora_scale = joint_attention_kwargs.pop("scale", 1.0)
409
- else:
410
- lora_scale = 1.0
411
-
412
- if USE_PEFT_BACKEND:
413
- # weight the lora layers by setting `lora_scale` for each PEFT layer
414
- scale_lora_layers(self, lora_scale)
415
- else:
416
- if (
417
- joint_attention_kwargs is not None
418
- and joint_attention_kwargs.get("scale", None) is not None
419
- ):
420
- logger.warning(
421
- "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
422
- )
423
-
424
- height, width = hidden_states.shape[-2:]
425
-
426
- hidden_states = self.pos_embed(
427
- hidden_states
428
- ) # takes care of adding positional embeddings too. b,c,H,W -> b, N, C
429
- temb = self.time_text_embed(timestep, pooled_projections)
430
- encoder_hidden_states = self.context_embedder(encoder_hidden_states)
431
-
432
- # add condition
433
- hidden_states = hidden_states + self.pos_embed_cond(controlnet_cond)
434
-
435
- block_res_samples = ()
436
-
437
- for block in self.transformer_blocks:
438
- if self.training and self.gradient_checkpointing:
439
-
440
- def create_custom_forward(module, return_dict=None):
441
- def custom_forward(*inputs):
442
- if return_dict is not None:
443
- return module(*inputs, return_dict=return_dict)
444
- else:
445
- return module(*inputs)
446
-
447
- return custom_forward
448
-
449
- ckpt_kwargs: Dict[str, Any] = (
450
- {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
451
- )
452
- hidden_states = torch.utils.checkpoint.checkpoint(
453
- create_custom_forward(block),
454
- hidden_states,
455
- encoder_hidden_states,
456
- temb,
457
- **ckpt_kwargs,
458
- )
459
-
460
- else:
461
- encoder_hidden_states, hidden_states = block(
462
- hidden_states=hidden_states,
463
- encoder_hidden_states=encoder_hidden_states,
464
- temb=temb,
465
- )
466
-
467
- block_res_samples = block_res_samples + (hidden_states,)
468
-
469
- controlnet_block_res_samples = ()
470
- for block_res_sample, controlnet_block in zip(
471
- block_res_samples, self.controlnet_blocks
472
- ):
473
- block_res_sample = controlnet_block(block_res_sample)
474
- controlnet_block_res_samples = controlnet_block_res_samples + (
475
- block_res_sample,
476
- )
477
-
478
- # 6. scaling
479
- controlnet_block_res_samples = [
480
- sample * conditioning_scale for sample in controlnet_block_res_samples
481
- ]
482
-
483
- if USE_PEFT_BACKEND:
484
- # remove `lora_scale` from each PEFT layer
485
- unscale_lora_layers(self, lora_scale)
486
-
487
- if not return_dict:
488
- return (controlnet_block_res_samples,)
489
-
490
- return SD3ControlNetOutput(
491
- controlnet_block_samples=controlnet_block_res_samples
492
- )
493
-
494
- def invert_copy_paste(self, controlnet_block_samples):
495
- controlnet_block_samples = controlnet_block_samples + controlnet_block_samples[::-1]
496
- return controlnet_block_samples
497
-
498
- class SD3MultiControlNetModel(ModelMixin):
499
- r"""
500
- `SD3ControlNetModel` wrapper class for Multi-SD3ControlNet
501
-
502
- This module is a wrapper for multiple instances of the `SD3ControlNetModel`. The `forward()` API is designed to be
503
- compatible with `SD3ControlNetModel`.
504
-
505
- Args:
506
- controlnets (`List[SD3ControlNetModel]`):
507
- Provides additional conditioning to the unet during the denoising process. You must set multiple
508
- `SD3ControlNetModel` as a list.
509
- """
510
-
511
- def __init__(self, controlnets):
512
- super().__init__()
513
- self.nets = nn.ModuleList(controlnets)
514
-
515
- def forward(
516
- self,
517
- hidden_states: torch.FloatTensor,
518
- controlnet_cond: List[torch.tensor],
519
- conditioning_scale: List[float],
520
- pooled_projections: torch.FloatTensor,
521
- encoder_hidden_states: torch.FloatTensor = None,
522
- timestep: torch.LongTensor = None,
523
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
524
- return_dict: bool = True,
525
- ) -> Union[SD3ControlNetOutput, Tuple]:
526
- for i, (image, scale, controlnet) in enumerate(
527
- zip(controlnet_cond, conditioning_scale, self.nets)
528
- ):
529
- block_samples = controlnet(
530
- hidden_states=hidden_states,
531
- timestep=timestep,
532
- encoder_hidden_states=encoder_hidden_states,
533
- pooled_projections=pooled_projections,
534
- controlnet_cond=image,
535
- conditioning_scale=scale,
536
- joint_attention_kwargs=joint_attention_kwargs,
537
- return_dict=return_dict,
538
- )
539
-
540
- # merge samples
541
- if i == 0:
542
- control_block_samples = block_samples
543
- else:
544
- control_block_samples = [
545
- control_block_sample + block_sample
546
- for control_block_sample, block_sample in zip(
547
- control_block_samples[0], block_samples[0]
548
- )
549
- ]
550
- control_block_samples = (tuple(control_block_samples),)
551
-
552
- return control_block_samples
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
demo.py DELETED
@@ -1,53 +0,0 @@
1
- from diffusers.utils import load_image, check_min_version
2
- import torch
3
-
4
- # Local File
5
- from pipeline_sd3_controlnet_inpainting import StableDiffusion3ControlNetInpaintingPipeline, one_image_and_mask
6
- from controlnet_sd3 import SD3ControlNetModel
7
-
8
- check_min_version("0.29.2")
9
-
10
- # Build model
11
- controlnet = SD3ControlNetModel.from_pretrained(
12
- "alimama-creative/SD3-Controlnet-Inpainting",
13
- use_safetensors=True,
14
- )
15
- pipe = StableDiffusion3ControlNetInpaintingPipeline.from_pretrained(
16
- "stabilityai/stable-diffusion-3-medium-diffusers",
17
- controlnet=controlnet,
18
- torch_dtype=torch.float16,
19
- )
20
- pipe.text_encoder.to(torch.float16)
21
- pipe.controlnet.to(torch.float16)
22
- pipe.to("cuda")
23
-
24
- # Load image
25
- image = load_image(
26
- "https://huggingface.co/alimama-creative/SD3-Controlnet-Inpainting/blob/main/images/prod.png"
27
- )
28
- mask = load_image(
29
- "https://huggingface.co/alimama-creative/SD3-Controlnet-Inpainting/blob/main/images/mask.jpeg"
30
- )
31
-
32
- # Set args
33
- width = 1024
34
- height = 1024
35
- prompt="a woman wearing a white jacket, black hat and black pants is standing in a field, the hat writes SD3"
36
- generator = torch.Generator(device="cuda").manual_seed(24)
37
- input_dict = one_image_and_mask(image, mask, size=(width, height), latent_scale=pipe.vae_scale_factor, invert_mask = True)
38
-
39
- # Inference
40
- res_image = pipe(
41
- negative_prompt='deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, NSFW',
42
- prompt=prompt,
43
- height=height,
44
- width=width,
45
- control_image= input_dict['pil_masked_image'], # H, W, C,
46
- control_mask=input_dict["mask"] > 0.5, # B,1,H,W
47
- num_inference_steps=28,
48
- generator=generator,
49
- controlnet_conditioning_scale=0.95,
50
- guidance_scale=7,
51
- ).images[0]
52
-
53
- res_image.save(f'res.png')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1818c32db9a7541572a5ebb41b82f0fd4859643a1aba87c91d9f352d09e523c7
3
- size 4160564320
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f7e2bcf98ed0989558dd05d857cc49c0aff14dfa3197050d65e51a9d37008dde
3
+ size 4160564288
pipeline_sd3_controlnet_inpainting.py DELETED
@@ -1,1333 +0,0 @@
1
- # Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import inspect
16
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
-
18
- import torch
19
- from transformers import (
20
- CLIPTextModelWithProjection,
21
- CLIPTokenizer,
22
- T5EncoderModel,
23
- T5TokenizerFast,
24
- )
25
-
26
- from PIL import Image, ImageOps
27
- import numpy as np
28
- import os
29
- from torchvision.transforms import v2
30
-
31
- from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
32
- from diffusers.loaders import FromSingleFileMixin, SD3LoraLoaderMixin
33
- from diffusers.models.autoencoders import AutoencoderKL
34
- from diffusers.models.transformers import SD3Transformer2DModel
35
- from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
36
- from diffusers.utils import (
37
- is_torch_xla_available,
38
- logging,
39
- replace_example_docstring,
40
- )
41
- from diffusers.utils.torch_utils import randn_tensor
42
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
43
- from diffusers.pipelines.stable_diffusion_3.pipeline_output import (
44
- StableDiffusion3PipelineOutput,
45
- )
46
- from torchvision.transforms.functional import resize, InterpolationMode
47
-
48
- from controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
49
-
50
- if is_torch_xla_available():
51
- import torch_xla.core.xla_model as xm
52
-
53
- XLA_AVAILABLE = True
54
- else:
55
- XLA_AVAILABLE = False
56
-
57
-
58
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
59
-
60
- EXAMPLE_DOC_STRING = """
61
- Examples:
62
- ```py
63
- >>> import torch
64
- >>> from diffusers import StableDiffusion3ControlNetPipeline
65
- >>> from diffusers.models import SD3ControlNetModel, SD3MultiControlNetModel
66
- >>> from diffusers.utils import load_image
67
-
68
- >>> controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16)
69
-
70
- >>> pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
71
- ... "stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
72
- ... )
73
- >>> pipe.to("cuda")
74
- >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
75
- >>> prompt = "A girl holding a sign that says InstantX"
76
- >>> image = pipe(prompt, control_image=control_image, controlnet_conditioning_scale=0.7).images[0]
77
- >>> image.save("sd3.png")
78
- ```
79
- """
80
-
81
- def one_image_and_mask(image, mask, size = None, latent_scale = 8 , invert_mask = False):
82
- '''
83
- Image : PIL Image, Torch Tensor [-1, 1], Path, B,C,H,W
84
- Mask : PIL Image , Torch Tensor [0, 1], Path, B,1,H,W
85
- '''
86
- # size = (W, H)
87
- if size is not None:
88
- if not ( type(size) == list or type(size) == tuple):
89
- size = (size, size)
90
-
91
- # Get image @ torch tensor
92
- if type(image) == str and os.path.exists(image):
93
- image = Image.open(image)
94
-
95
- if isinstance(image, Image.Image):
96
- image = image.convert("RGB")
97
- if size is not None:
98
- image = image.resize(size, Image.Resampling.LANCZOS)
99
- pil_image = image
100
- image_arr = np.array(image)
101
- assert image_arr.ndim == 3
102
- assert image_arr.shape[2] == 3
103
- th_image = torch.from_numpy(image_arr).float() / 127. - 1
104
- th_image = th_image.permute(2, 0, 1)
105
- else:
106
- th_image = image
107
- pil_image = None
108
-
109
- # Get BCHW
110
- assert isinstance(th_image, torch.Tensor)
111
- if len(th_image.shape) == 3:
112
- th_image = th_image.unsqueeze(0)
113
- H, W = th_image.shape[-2:]
114
- assert H % 8 == 0 and W % 8 == 0
115
-
116
- # Get mask @ torch tensor
117
- if type(mask) == str and os.path.exists(mask):
118
- mask = Image.open(mask)
119
-
120
- if isinstance(mask, Image.Image):
121
- mask = mask.convert("L")
122
- if invert_mask:
123
- mask = ImageOps.invert(mask)
124
- mask = mask.resize((W, H), Image.Resampling.LANCZOS)
125
- pil_mask = mask
126
- mask_arr = np.array(mask)
127
- if mask_arr.ndim == 3 and mask_arr.shape[2] == 3:
128
- mask_arr = mask_arr[:, :, 0] # H, W
129
- th_mask = torch.from_numpy(mask_arr).float() / 255.
130
- th_mask = th_mask.unsqueeze(0)
131
- else:
132
- th_mask = mask
133
- pil_mask = None
134
-
135
- assert isinstance(th_mask, torch.Tensor)
136
- if len(th_mask.shape) == 3:
137
- th_mask = th_mask.unsqueeze(0)
138
-
139
- # Get mask at latent space
140
- th_mask_latent = torch.nn.functional.interpolate(
141
- th_mask, size=(H // latent_scale, W // latent_scale), mode="bilinear", antialias=True
142
- )
143
-
144
- # Get masked image for vae-cond
145
- masked_image = th_image.clone()
146
- masked_image[(th_mask < 0.5).repeat(1,3,1,1)] = - 1. # set 0. like power paint @ https://github.com/open-mmlab/PowerPaint/blob/main/powerpaint/pipelines/pipeline_PowerPaint.py
147
-
148
- # Get pil masked image
149
- pil_masked_image = v2.ToPILImage()((masked_image/2 + 1/2).clip(0, 1).squeeze(0))
150
-
151
- # Get masked image
152
- return {
153
- 'image': th_image,
154
- 'mask': th_mask,
155
- 'mask_latent': th_mask_latent,
156
- 'masked_image': masked_image,
157
- 'pil_image': pil_image,
158
- 'pil_mask': pil_mask,
159
- 'pil_masked_image': pil_masked_image
160
- }
161
-
162
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
163
- def retrieve_timesteps(
164
- scheduler,
165
- num_inference_steps: Optional[int] = None,
166
- device: Optional[Union[str, torch.device]] = None,
167
- timesteps: Optional[List[int]] = None,
168
- sigmas: Optional[List[float]] = None,
169
- **kwargs,
170
- ):
171
- """
172
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
173
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
174
-
175
- Args:
176
- scheduler (`SchedulerMixin`):
177
- The scheduler to get timesteps from.
178
- num_inference_steps (`int`):
179
- The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
180
- must be `None`.
181
- device (`str` or `torch.device`, *optional*):
182
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
183
- timesteps (`List[int]`, *optional*):
184
- Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
185
- `num_inference_steps` and `sigmas` must be `None`.
186
- sigmas (`List[float]`, *optional*):
187
- Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
188
- `num_inference_steps` and `timesteps` must be `None`.
189
-
190
- Returns:
191
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
192
- second element is the number of inference steps.
193
- """
194
- if timesteps is not None and sigmas is not None:
195
- raise ValueError(
196
- "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
197
- )
198
- if timesteps is not None:
199
- accepts_timesteps = "timesteps" in set(
200
- inspect.signature(scheduler.set_timesteps).parameters.keys()
201
- )
202
- if not accepts_timesteps:
203
- raise ValueError(
204
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
205
- f" timestep schedules. Please check whether you are using the correct scheduler."
206
- )
207
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
208
- timesteps = scheduler.timesteps
209
- num_inference_steps = len(timesteps)
210
- elif sigmas is not None:
211
- accept_sigmas = "sigmas" in set(
212
- inspect.signature(scheduler.set_timesteps).parameters.keys()
213
- )
214
- if not accept_sigmas:
215
- raise ValueError(
216
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
217
- f" sigmas schedules. Please check whether you are using the correct scheduler."
218
- )
219
- scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
220
- timesteps = scheduler.timesteps
221
- num_inference_steps = len(timesteps)
222
- else:
223
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
224
- timesteps = scheduler.timesteps
225
- return timesteps, num_inference_steps
226
-
227
-
228
- class StableDiffusion3ControlNetInpaintingPipeline(
229
- DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin
230
- ):
231
- r"""
232
- Args:
233
- transformer ([`SD3Transformer2DModel`]):
234
- Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
235
- scheduler ([`FlowMatchEulerDiscreteScheduler`]):
236
- A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
237
- vae ([`AutoencoderKL`]):
238
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
239
- text_encoder ([`CLIPTextModelWithProjection`]):
240
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
241
- specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
242
- with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
243
- as its dimension.
244
- text_encoder_2 ([`CLIPTextModelWithProjection`]):
245
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
246
- specifically the
247
- [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
248
- variant.
249
- text_encoder_3 ([`T5EncoderModel`]):
250
- Frozen text-encoder. Stable Diffusion 3 uses
251
- [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
252
- [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
253
- tokenizer (`CLIPTokenizer`):
254
- Tokenizer of class
255
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
256
- tokenizer_2 (`CLIPTokenizer`):
257
- Second Tokenizer of class
258
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
259
- tokenizer_3 (`T5TokenizerFast`):
260
- Tokenizer of class
261
- [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
262
- controlnet ([`SD3ControlNetModel`] or `List[SD3ControlNetModel]` or [`SD3MultiControlNetModel`]):
263
- Provides additional conditioning to the `unet` during the denoising process. If you set multiple
264
- ControlNets as a list, the outputs from each ControlNet are added together to create one combined
265
- additional conditioning.
266
- """
267
-
268
- model_cpu_offload_seq = (
269
- "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
270
- )
271
- _optional_components = []
272
- _callback_tensor_inputs = [
273
- "latents",
274
- "prompt_embeds",
275
- "negative_prompt_embeds",
276
- "negative_pooled_prompt_embeds",
277
- ]
278
-
279
- def __init__(
280
- self,
281
- transformer: SD3Transformer2DModel,
282
- scheduler: FlowMatchEulerDiscreteScheduler,
283
- vae: AutoencoderKL,
284
- text_encoder: CLIPTextModelWithProjection,
285
- tokenizer: CLIPTokenizer,
286
- text_encoder_2: CLIPTextModelWithProjection,
287
- tokenizer_2: CLIPTokenizer,
288
- text_encoder_3: T5EncoderModel,
289
- tokenizer_3: T5TokenizerFast,
290
- controlnet: Union[
291
- SD3ControlNetModel,
292
- List[SD3ControlNetModel],
293
- Tuple[SD3ControlNetModel],
294
- SD3MultiControlNetModel,
295
- ],
296
- ):
297
- super().__init__()
298
-
299
- self.register_modules(
300
- vae=vae,
301
- text_encoder=text_encoder,
302
- text_encoder_2=text_encoder_2,
303
- text_encoder_3=text_encoder_3,
304
- tokenizer=tokenizer,
305
- tokenizer_2=tokenizer_2,
306
- tokenizer_3=tokenizer_3,
307
- transformer=transformer,
308
- scheduler=scheduler,
309
- controlnet=controlnet,
310
- )
311
- self.vae_scale_factor = (
312
- 2 ** (len(self.vae.config.block_out_channels) - 1)
313
- if hasattr(self, "vae") and self.vae is not None
314
- else 8
315
- )
316
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
317
- self.control_image_processor = VaeImageProcessor(
318
- vae_scale_factor=self.vae_scale_factor,
319
- do_convert_rgb=True,
320
- do_normalize=False,
321
- )
322
- self.tokenizer_max_length = (
323
- self.tokenizer.model_max_length
324
- if hasattr(self, "tokenizer") and self.tokenizer is not None
325
- else 77
326
- )
327
- self.default_sample_size = (
328
- self.transformer.config.sample_size
329
- if hasattr(self, "transformer") and self.transformer is not None
330
- else 128
331
- )
332
-
333
- # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
334
- def _get_t5_prompt_embeds(
335
- self,
336
- prompt: Union[str, List[str]] = None,
337
- num_images_per_prompt: int = 1,
338
- device: Optional[torch.device] = None,
339
- dtype: Optional[torch.dtype] = None,
340
- ):
341
- device = device or self._execution_device
342
- dtype = dtype or self.text_encoder.dtype
343
-
344
- prompt = [prompt] if isinstance(prompt, str) else prompt
345
- batch_size = len(prompt)
346
-
347
- if self.text_encoder_3 is None:
348
- return torch.zeros(
349
- (
350
- batch_size,
351
- self.tokenizer_max_length,
352
- self.transformer.config.joint_attention_dim,
353
- ),
354
- device=device,
355
- dtype=dtype,
356
- )
357
-
358
- text_inputs = self.tokenizer_3(
359
- prompt,
360
- padding="max_length",
361
- max_length=self.tokenizer_max_length,
362
- truncation=True,
363
- add_special_tokens=True,
364
- return_tensors="pt",
365
- )
366
- text_input_ids = text_inputs.input_ids
367
- untruncated_ids = self.tokenizer_3(
368
- prompt, padding="longest", return_tensors="pt"
369
- ).input_ids
370
-
371
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
372
- text_input_ids, untruncated_ids
373
- ):
374
- removed_text = self.tokenizer_3.batch_decode(
375
- untruncated_ids[:, self.tokenizer_max_length - 1 : -1]
376
- )
377
- logger.warning(
378
- "The following part of your input was truncated because CLIP can only handle sequences up to"
379
- f" {self.tokenizer_max_length} tokens: {removed_text}"
380
- )
381
-
382
- prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
383
-
384
- dtype = self.text_encoder_3.dtype
385
- prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
386
-
387
- _, seq_len, _ = prompt_embeds.shape
388
-
389
- # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
390
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
391
- prompt_embeds = prompt_embeds.view(
392
- batch_size * num_images_per_prompt, seq_len, -1
393
- )
394
-
395
- return prompt_embeds
396
-
397
- # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds
398
- def _get_clip_prompt_embeds(
399
- self,
400
- prompt: Union[str, List[str]],
401
- num_images_per_prompt: int = 1,
402
- device: Optional[torch.device] = None,
403
- clip_skip: Optional[int] = None,
404
- clip_model_index: int = 0,
405
- ):
406
- device = device or self._execution_device
407
-
408
- clip_tokenizers = [self.tokenizer, self.tokenizer_2]
409
- clip_text_encoders = [self.text_encoder, self.text_encoder_2]
410
-
411
- tokenizer = clip_tokenizers[clip_model_index]
412
- text_encoder = clip_text_encoders[clip_model_index]
413
-
414
- prompt = [prompt] if isinstance(prompt, str) else prompt
415
- batch_size = len(prompt)
416
-
417
- text_inputs = tokenizer(
418
- prompt,
419
- padding="max_length",
420
- max_length=self.tokenizer_max_length,
421
- truncation=True,
422
- return_tensors="pt",
423
- )
424
-
425
- text_input_ids = text_inputs.input_ids
426
- untruncated_ids = tokenizer(
427
- prompt, padding="longest", return_tensors="pt"
428
- ).input_ids
429
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
430
- text_input_ids, untruncated_ids
431
- ):
432
- removed_text = tokenizer.batch_decode(
433
- untruncated_ids[:, self.tokenizer_max_length - 1 : -1]
434
- )
435
- logger.warning(
436
- "The following part of your input was truncated because CLIP can only handle sequences up to"
437
- f" {self.tokenizer_max_length} tokens: {removed_text}"
438
- )
439
- prompt_embeds = text_encoder(
440
- text_input_ids.to(device), output_hidden_states=True
441
- )
442
- pooled_prompt_embeds = prompt_embeds[0]
443
-
444
- if clip_skip is None:
445
- prompt_embeds = prompt_embeds.hidden_states[-2]
446
- else:
447
- prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
448
-
449
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
450
-
451
- _, seq_len, _ = prompt_embeds.shape
452
- # duplicate text embeddings for each generation per prompt, using mps friendly method
453
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
454
- prompt_embeds = prompt_embeds.view(
455
- batch_size * num_images_per_prompt, seq_len, -1
456
- )
457
-
458
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
459
- pooled_prompt_embeds = pooled_prompt_embeds.view(
460
- batch_size * num_images_per_prompt, -1
461
- )
462
-
463
- return prompt_embeds, pooled_prompt_embeds
464
-
465
- # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt
466
- def encode_prompt(
467
- self,
468
- prompt: Union[str, List[str]],
469
- prompt_2: Union[str, List[str]],
470
- prompt_3: Union[str, List[str]],
471
- device: Optional[torch.device] = None,
472
- num_images_per_prompt: int = 1,
473
- do_classifier_free_guidance: bool = True,
474
- negative_prompt: Optional[Union[str, List[str]]] = None,
475
- negative_prompt_2: Optional[Union[str, List[str]]] = None,
476
- negative_prompt_3: Optional[Union[str, List[str]]] = None,
477
- prompt_embeds: Optional[torch.FloatTensor] = None,
478
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
479
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
480
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
481
- clip_skip: Optional[int] = None,
482
- ):
483
- r"""
484
-
485
- Args:
486
- prompt (`str` or `List[str]`, *optional*):
487
- prompt to be encoded
488
- prompt_2 (`str` or `List[str]`, *optional*):
489
- The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
490
- used in all text-encoders
491
- prompt_3 (`str` or `List[str]`, *optional*):
492
- The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
493
- used in all text-encoders
494
- device: (`torch.device`):
495
- torch device
496
- num_images_per_prompt (`int`):
497
- number of images that should be generated per prompt
498
- do_classifier_free_guidance (`bool`):
499
- whether to use classifier free guidance or not
500
- negative_prompt (`str` or `List[str]`, *optional*):
501
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
502
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
503
- less than `1`).
504
- negative_prompt_2 (`str` or `List[str]`, *optional*):
505
- The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
506
- `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
507
- negative_prompt_2 (`str` or `List[str]`, *optional*):
508
- The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
509
- `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
510
- prompt_embeds (`torch.FloatTensor`, *optional*):
511
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
512
- provided, text embeddings will be generated from `prompt` input argument.
513
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
514
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
515
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
516
- argument.
517
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
518
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
519
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
520
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
521
- Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
522
- weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
523
- input argument.
524
- clip_skip (`int`, *optional*):
525
- Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
526
- the output of the pre-final layer will be used for computing the prompt embeddings.
527
- """
528
- device = device or self._execution_device
529
-
530
- prompt = [prompt] if isinstance(prompt, str) else prompt
531
- if prompt is not None:
532
- batch_size = len(prompt)
533
- else:
534
- batch_size = prompt_embeds.shape[0]
535
-
536
- if prompt_embeds is None:
537
- prompt_2 = prompt_2 or prompt
538
- prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
539
-
540
- prompt_3 = prompt_3 or prompt
541
- prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
542
-
543
- prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
544
- prompt=prompt,
545
- device=device,
546
- num_images_per_prompt=num_images_per_prompt,
547
- clip_skip=clip_skip,
548
- clip_model_index=0,
549
- )
550
- prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
551
- prompt=prompt_2,
552
- device=device,
553
- num_images_per_prompt=num_images_per_prompt,
554
- clip_skip=clip_skip,
555
- clip_model_index=1,
556
- )
557
- clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
558
-
559
- t5_prompt_embed = self._get_t5_prompt_embeds(
560
- prompt=prompt_3,
561
- num_images_per_prompt=num_images_per_prompt,
562
- device=device,
563
- )
564
-
565
- clip_prompt_embeds = torch.nn.functional.pad(
566
- clip_prompt_embeds,
567
- (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]),
568
- )
569
-
570
- prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
571
- pooled_prompt_embeds = torch.cat(
572
- [pooled_prompt_embed, pooled_prompt_2_embed], dim=-1
573
- )
574
-
575
- if do_classifier_free_guidance and negative_prompt_embeds is None:
576
- negative_prompt = negative_prompt or ""
577
- negative_prompt_2 = negative_prompt_2 or negative_prompt
578
- negative_prompt_3 = negative_prompt_3 or negative_prompt
579
-
580
- # normalize str to list
581
- negative_prompt = (
582
- batch_size * [negative_prompt]
583
- if isinstance(negative_prompt, str)
584
- else negative_prompt
585
- )
586
- negative_prompt_2 = (
587
- batch_size * [negative_prompt_2]
588
- if isinstance(negative_prompt_2, str)
589
- else negative_prompt_2
590
- )
591
- negative_prompt_3 = (
592
- batch_size * [negative_prompt_3]
593
- if isinstance(negative_prompt_3, str)
594
- else negative_prompt_3
595
- )
596
-
597
- if prompt is not None and type(prompt) is not type(negative_prompt):
598
- raise TypeError(
599
- f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
600
- f" {type(prompt)}."
601
- )
602
- elif batch_size != len(negative_prompt):
603
- raise ValueError(
604
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
605
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
606
- " the batch size of `prompt`."
607
- )
608
-
609
- negative_prompt_embed, negative_pooled_prompt_embed = (
610
- self._get_clip_prompt_embeds(
611
- negative_prompt,
612
- device=device,
613
- num_images_per_prompt=num_images_per_prompt,
614
- clip_skip=None,
615
- clip_model_index=0,
616
- )
617
- )
618
- negative_prompt_2_embed, negative_pooled_prompt_2_embed = (
619
- self._get_clip_prompt_embeds(
620
- negative_prompt_2,
621
- device=device,
622
- num_images_per_prompt=num_images_per_prompt,
623
- clip_skip=None,
624
- clip_model_index=1,
625
- )
626
- )
627
- negative_clip_prompt_embeds = torch.cat(
628
- [negative_prompt_embed, negative_prompt_2_embed], dim=-1
629
- )
630
-
631
- t5_negative_prompt_embed = self._get_t5_prompt_embeds(
632
- prompt=negative_prompt_3,
633
- num_images_per_prompt=num_images_per_prompt,
634
- device=device,
635
- )
636
-
637
- negative_clip_prompt_embeds = torch.nn.functional.pad(
638
- negative_clip_prompt_embeds,
639
- (
640
- 0,
641
- t5_negative_prompt_embed.shape[-1]
642
- - negative_clip_prompt_embeds.shape[-1],
643
- ),
644
- )
645
-
646
- negative_prompt_embeds = torch.cat(
647
- [negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2
648
- )
649
- negative_pooled_prompt_embeds = torch.cat(
650
- [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
651
- )
652
-
653
- return (
654
- prompt_embeds,
655
- negative_prompt_embeds,
656
- pooled_prompt_embeds,
657
- negative_pooled_prompt_embeds,
658
- )
659
-
660
- def check_inputs(
661
- self,
662
- prompt,
663
- prompt_2,
664
- prompt_3,
665
- height,
666
- width,
667
- negative_prompt=None,
668
- negative_prompt_2=None,
669
- negative_prompt_3=None,
670
- prompt_embeds=None,
671
- negative_prompt_embeds=None,
672
- pooled_prompt_embeds=None,
673
- negative_pooled_prompt_embeds=None,
674
- callback_on_step_end_tensor_inputs=None,
675
- ):
676
- if height % 8 != 0 or width % 8 != 0:
677
- raise ValueError(
678
- f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
679
- )
680
-
681
- if callback_on_step_end_tensor_inputs is not None and not all(
682
- k in self._callback_tensor_inputs
683
- for k in callback_on_step_end_tensor_inputs
684
- ):
685
- raise ValueError(
686
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
687
- )
688
-
689
- if prompt is not None and prompt_embeds is not None:
690
- raise ValueError(
691
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
692
- " only forward one of the two."
693
- )
694
- elif prompt_2 is not None and prompt_embeds is not None:
695
- raise ValueError(
696
- f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
697
- " only forward one of the two."
698
- )
699
- elif prompt_3 is not None and prompt_embeds is not None:
700
- raise ValueError(
701
- f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
702
- " only forward one of the two."
703
- )
704
- elif prompt is None and prompt_embeds is None:
705
- raise ValueError(
706
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
707
- )
708
- elif prompt is not None and (
709
- not isinstance(prompt, str) and not isinstance(prompt, list)
710
- ):
711
- raise ValueError(
712
- f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
713
- )
714
- elif prompt_2 is not None and (
715
- not isinstance(prompt_2, str) and not isinstance(prompt_2, list)
716
- ):
717
- raise ValueError(
718
- f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}"
719
- )
720
- elif prompt_3 is not None and (
721
- not isinstance(prompt_3, str) and not isinstance(prompt_3, list)
722
- ):
723
- raise ValueError(
724
- f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}"
725
- )
726
-
727
- if negative_prompt is not None and negative_prompt_embeds is not None:
728
- raise ValueError(
729
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
730
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
731
- )
732
- elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
733
- raise ValueError(
734
- f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
735
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
736
- )
737
- elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
738
- raise ValueError(
739
- f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
740
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
741
- )
742
-
743
- if prompt_embeds is not None and negative_prompt_embeds is not None:
744
- if prompt_embeds.shape != negative_prompt_embeds.shape:
745
- raise ValueError(
746
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
747
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
748
- f" {negative_prompt_embeds.shape}."
749
- )
750
-
751
- if prompt_embeds is not None and pooled_prompt_embeds is None:
752
- raise ValueError(
753
- "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
754
- )
755
-
756
- if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
757
- raise ValueError(
758
- "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
759
- )
760
-
761
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
762
- def prepare_latents(
763
- self,
764
- batch_size,
765
- num_channels_latents,
766
- height,
767
- width,
768
- dtype,
769
- device,
770
- generator,
771
- latents=None,
772
- ):
773
- shape = (
774
- batch_size,
775
- num_channels_latents,
776
- int(height) // self.vae_scale_factor,
777
- int(width) // self.vae_scale_factor,
778
- )
779
-
780
- if isinstance(generator, list) and len(generator) != batch_size:
781
- raise ValueError(
782
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
783
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
784
- )
785
-
786
- if latents is None:
787
- latents = randn_tensor(
788
- shape, generator=generator, device=device, dtype=dtype
789
- )
790
- else:
791
- latents = latents.to(device=device, dtype=dtype)
792
-
793
- return latents
794
-
795
- # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
796
- def prepare_image(
797
- self,
798
- image,
799
- width,
800
- height,
801
- batch_size,
802
- num_images_per_prompt,
803
- device,
804
- dtype,
805
- do_classifier_free_guidance=False,
806
- guess_mode=False,
807
- ):
808
- image = self.control_image_processor.preprocess(
809
- image, height=height, width=width
810
- ).to(dtype=torch.float32)
811
- image_batch_size = image.shape[0]
812
-
813
- if image_batch_size == 1:
814
- repeat_by = batch_size
815
- else:
816
- # image batch size is the same as prompt batch size
817
- repeat_by = num_images_per_prompt
818
-
819
- image = image.repeat_interleave(repeat_by, dim=0)
820
-
821
- image = image.to(device=device, dtype=dtype)
822
-
823
- if do_classifier_free_guidance and not guess_mode:
824
- image = torch.cat([image] * 2)
825
-
826
- return image
827
-
828
- def prepare_image_with_mask(
829
- self,
830
- image,
831
- mask,
832
- width,
833
- height,
834
- batch_size,
835
- num_images_per_prompt,
836
- device,
837
- dtype,
838
- do_classifier_free_guidance=False,
839
- guess_mode=False,
840
- ):
841
-
842
- if isinstance(image, torch.Tensor):
843
- pass
844
- else:
845
- image = self.image_processor.preprocess(
846
- image, height=height, width=width
847
- ) # C,H,W
848
-
849
- if isinstance(mask, torch.Tensor):
850
- pass
851
- else:
852
- raise "Control Mask must be tensor"
853
-
854
- image_batch_size = image.shape[0]
855
-
856
- if image_batch_size == 1:
857
- repeat_by = batch_size
858
- else:
859
- # image batch size is the same as prompt batch size
860
- repeat_by = num_images_per_prompt
861
-
862
- image = image.repeat_interleave(repeat_by, dim=0)
863
- mask = mask.repeat_interleave(repeat_by, dim=0)
864
-
865
- image = image.to(device=device, dtype=self.vae.dtype)
866
- mask = mask.to(device=device, dtype=dtype)
867
-
868
- image_latents = self.vae.encode(image).latent_dist.sample()
869
- image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
870
- image_latents = image_latents.to(dtype)
871
-
872
- # cat image and mask
873
- mask = torch.nn.functional.interpolate(
874
- mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
875
- )
876
-
877
- control_image = torch.cat([image_latents, mask], dim=1)
878
-
879
- if do_classifier_free_guidance and not guess_mode:
880
- control_image = torch.cat([control_image] * 2)
881
- return control_image
882
-
883
- @property
884
- def guidance_scale(self):
885
- return self._guidance_scale
886
-
887
- @property
888
- def clip_skip(self):
889
- return self._clip_skip
890
-
891
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
892
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
893
- # corresponds to doing no classifier free guidance.
894
- @property
895
- def do_classifier_free_guidance(self):
896
- return self._guidance_scale > 1
897
-
898
- @property
899
- def joint_attention_kwargs(self):
900
- return self._joint_attention_kwargs
901
-
902
- @property
903
- def num_timesteps(self):
904
- return self._num_timesteps
905
-
906
- @property
907
- def interrupt(self):
908
- return self._interrupt
909
-
910
- @torch.no_grad()
911
- @replace_example_docstring(EXAMPLE_DOC_STRING)
912
- def __call__(
913
- self,
914
- prompt: Union[str, List[str]] = None,
915
- prompt_2: Optional[Union[str, List[str]]] = None,
916
- prompt_3: Optional[Union[str, List[str]]] = None,
917
- height: Optional[int] = None,
918
- width: Optional[int] = None,
919
- num_inference_steps: int = 28,
920
- timesteps: List[int] = None,
921
- guidance_scale: float = 7.0,
922
- control_guidance_start: Union[float, List[float]] = 0.0,
923
- control_guidance_end: Union[float, List[float]] = 1.0,
924
- control_image: Union[
925
- PipelineImageInput,
926
- List[PipelineImageInput],
927
- ] = None,
928
- control_mask=None,
929
- controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
930
- controlnet_pooled_projections: Optional[torch.FloatTensor] = None,
931
- negative_prompt: Optional[Union[str, List[str]]] = None,
932
- negative_prompt_2: Optional[Union[str, List[str]]] = None,
933
- negative_prompt_3: Optional[Union[str, List[str]]] = None,
934
- num_images_per_prompt: Optional[int] = 1,
935
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
936
- latents: Optional[torch.FloatTensor] = None,
937
- prompt_embeds: Optional[torch.FloatTensor] = None,
938
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
939
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
940
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
941
- output_type: Optional[str] = "pil",
942
- return_dict: bool = True,
943
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
944
- clip_skip: Optional[int] = None,
945
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
946
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
947
- ):
948
- r"""
949
- Function invoked when calling the pipeline for generation.
950
-
951
- Args:
952
- prompt (`str` or `List[str]`, *optional*):
953
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
954
- instead.
955
- prompt_2 (`str` or `List[str]`, *optional*):
956
- The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
957
- will be used instead
958
- prompt_3 (`str` or `List[str]`, *optional*):
959
- The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
960
- will be used instead
961
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
962
- The height in pixels of the generated image. This is set to 1024 by default for the best results.
963
- width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
964
- The width in pixels of the generated image. This is set to 1024 by default for the best results.
965
- num_inference_steps (`int`, *optional*, defaults to 50):
966
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
967
- expense of slower inference.
968
- timesteps (`List[int]`, *optional*):
969
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
970
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
971
- passed will be used. Must be in descending order.
972
- guidance_scale (`float`, *optional*, defaults to 5.0):
973
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
974
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
975
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
976
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
977
- usually at the expense of lower image quality.
978
- control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
979
- The percentage of total steps at which the ControlNet starts applying.
980
- control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
981
- The percentage of total steps at which the ControlNet stops applying.
982
- control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
983
- `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
984
- The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
985
- specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
986
- as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
987
- width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
988
- images must be passed as a list such that each element of the list can be correctly batched for input
989
- to a single ControlNet.
990
- controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
991
- The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
992
- to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
993
- the corresponding scale as a list.
994
- controlnet_pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
995
- Embeddings projected from the embeddings of controlnet input conditions.
996
- negative_prompt (`str` or `List[str]`, *optional*):
997
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
998
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
999
- less than `1`).
1000
- negative_prompt_2 (`str` or `List[str]`, *optional*):
1001
- The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
1002
- `text_encoder_2`. If not defined, `negative_prompt` is used instead
1003
- negative_prompt_3 (`str` or `List[str]`, *optional*):
1004
- The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
1005
- `text_encoder_3`. If not defined, `negative_prompt` is used instead
1006
- num_images_per_prompt (`int`, *optional*, defaults to 1):
1007
- The number of images to generate per prompt.
1008
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1009
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1010
- to make generation deterministic.
1011
- latents (`torch.FloatTensor`, *optional*):
1012
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1013
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1014
- tensor will ge generated by sampling using the supplied random `generator`.
1015
- prompt_embeds (`torch.FloatTensor`, *optional*):
1016
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1017
- provided, text embeddings will be generated from `prompt` input argument.
1018
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1019
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1020
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1021
- argument.
1022
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1023
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1024
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
1025
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1026
- Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1027
- weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
1028
- input argument.
1029
- output_type (`str`, *optional*, defaults to `"pil"`):
1030
- The output format of the generate image. Choose between
1031
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1032
- return_dict (`bool`, *optional*, defaults to `True`):
1033
- Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
1034
- of a plain tuple.
1035
- joint_attention_kwargs (`dict`, *optional*):
1036
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1037
- `self.processor` in
1038
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1039
- callback_on_step_end (`Callable`, *optional*):
1040
- A function that calls at the end of each denoising steps during the inference. The function is called
1041
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
1042
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1043
- `callback_on_step_end_tensor_inputs`.
1044
- callback_on_step_end_tensor_inputs (`List`, *optional*):
1045
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1046
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1047
- `._callback_tensor_inputs` attribute of your pipeline class.
1048
-
1049
- Examples:
1050
-
1051
- Returns:
1052
- [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
1053
- [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
1054
- `tuple`. When returning a tuple, the first element is a list with the generated images.
1055
- """
1056
-
1057
- height = height or self.default_sample_size * self.vae_scale_factor
1058
- width = width or self.default_sample_size * self.vae_scale_factor
1059
-
1060
- # align format for control guidance
1061
- if not isinstance(control_guidance_start, list) and isinstance(
1062
- control_guidance_end, list
1063
- ):
1064
- control_guidance_start = len(control_guidance_end) * [
1065
- control_guidance_start
1066
- ]
1067
- elif not isinstance(control_guidance_end, list) and isinstance(
1068
- control_guidance_start, list
1069
- ):
1070
- control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1071
- elif not isinstance(control_guidance_start, list) and not isinstance(
1072
- control_guidance_end, list
1073
- ):
1074
- mult = (
1075
- len(self.controlnet.nets)
1076
- if isinstance(self.controlnet, SD3MultiControlNetModel)
1077
- else 1
1078
- )
1079
- control_guidance_start, control_guidance_end = (
1080
- mult * [control_guidance_start],
1081
- mult * [control_guidance_end],
1082
- )
1083
-
1084
- # 1. Check inputs. Raise error if not correct
1085
- self.check_inputs(
1086
- prompt,
1087
- prompt_2,
1088
- prompt_3,
1089
- height,
1090
- width,
1091
- negative_prompt=negative_prompt,
1092
- negative_prompt_2=negative_prompt_2,
1093
- negative_prompt_3=negative_prompt_3,
1094
- prompt_embeds=prompt_embeds,
1095
- negative_prompt_embeds=negative_prompt_embeds,
1096
- pooled_prompt_embeds=pooled_prompt_embeds,
1097
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1098
- callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
1099
- )
1100
-
1101
- self._guidance_scale = guidance_scale
1102
- self._clip_skip = clip_skip
1103
- self._joint_attention_kwargs = joint_attention_kwargs
1104
- self._interrupt = False
1105
-
1106
- # 2. Define call parameters
1107
- if prompt is not None and isinstance(prompt, str):
1108
- batch_size = 1
1109
- elif prompt is not None and isinstance(prompt, list):
1110
- batch_size = len(prompt)
1111
- else:
1112
- batch_size = prompt_embeds.shape[0]
1113
-
1114
- device = self._execution_device
1115
- dtype = self.transformer.dtype
1116
-
1117
- (
1118
- prompt_embeds,
1119
- negative_prompt_embeds,
1120
- pooled_prompt_embeds,
1121
- negative_pooled_prompt_embeds,
1122
- ) = self.encode_prompt(
1123
- prompt=prompt,
1124
- prompt_2=prompt_2,
1125
- prompt_3=prompt_3,
1126
- negative_prompt=negative_prompt,
1127
- negative_prompt_2=negative_prompt_2,
1128
- negative_prompt_3=negative_prompt_3,
1129
- do_classifier_free_guidance=self.do_classifier_free_guidance,
1130
- prompt_embeds=prompt_embeds,
1131
- negative_prompt_embeds=negative_prompt_embeds,
1132
- pooled_prompt_embeds=pooled_prompt_embeds,
1133
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1134
- device=device,
1135
- clip_skip=self.clip_skip,
1136
- num_images_per_prompt=num_images_per_prompt,
1137
- )
1138
-
1139
- if self.do_classifier_free_guidance:
1140
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1141
- pooled_prompt_embeds = torch.cat(
1142
- [negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0
1143
- )
1144
-
1145
- # 3. Prepare control image
1146
- if isinstance(self.controlnet, SD3ControlNetModel):
1147
- control_image = self.prepare_image_with_mask(
1148
- image=control_image,
1149
- mask=control_mask,
1150
- width=width,
1151
- height=height,
1152
- batch_size=batch_size * num_images_per_prompt,
1153
- num_images_per_prompt=num_images_per_prompt,
1154
- device=device,
1155
- dtype=self.controlnet.dtype,
1156
- do_classifier_free_guidance=self.do_classifier_free_guidance,
1157
- )
1158
- height, width = control_image.shape[-2:]
1159
- height = height * self.vae_scale_factor
1160
- width = width * self.vae_scale_factor
1161
- elif isinstance(self.controlnet, SD3MultiControlNetModel):
1162
- images = []
1163
- for image_ in control_image:
1164
- image_ = self.prepare_image_with_mask(
1165
- image=image_,
1166
- mask=control_mask,
1167
- width=width,
1168
- height=height,
1169
- batch_size=batch_size * num_images_per_prompt,
1170
- num_images_per_prompt=num_images_per_prompt,
1171
- device=device,
1172
- dtype=self.controlnet.dtype,
1173
- do_classifier_free_guidance=self.do_classifier_free_guidance,
1174
- )
1175
- images.append(image_)
1176
-
1177
- control_image = images
1178
- height, width = control_image[0].shape[-2:]
1179
- height = height * self.vae_scale_factor
1180
- width = width * self.vae_scale_factor
1181
- else:
1182
- raise ValueError("ControlNet must be of type SD3ControlNetModel")
1183
-
1184
- if controlnet_pooled_projections is None:
1185
- controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
1186
- else:
1187
- controlnet_pooled_projections = (
1188
- controlnet_pooled_projections or pooled_prompt_embeds
1189
- )
1190
-
1191
- # 4. Prepare timesteps
1192
- timesteps, num_inference_steps = retrieve_timesteps(
1193
- self.scheduler, num_inference_steps, device, timesteps
1194
- )
1195
- num_warmup_steps = max(
1196
- len(timesteps) - num_inference_steps * self.scheduler.order, 0
1197
- )
1198
- self._num_timesteps = len(timesteps)
1199
-
1200
- # 5. Prepare latent variables
1201
- num_channels_latents = self.transformer.config.in_channels
1202
- latents = self.prepare_latents(
1203
- batch_size * num_images_per_prompt,
1204
- num_channels_latents,
1205
- height,
1206
- width,
1207
- prompt_embeds.dtype,
1208
- device,
1209
- generator,
1210
- latents,
1211
- )
1212
-
1213
- # 6. Create tensor stating which controlnets to keep
1214
- controlnet_keep = []
1215
- for i in range(len(timesteps)):
1216
- keeps = [
1217
- 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
1218
- for s, e in zip(control_guidance_start, control_guidance_end)
1219
- ]
1220
- controlnet_keep.append(
1221
- keeps[0] if isinstance(self.controlnet, SD3ControlNetModel) else keeps
1222
- )
1223
-
1224
- # 7. Denoising loop
1225
- with self.progress_bar(total=num_inference_steps) as progress_bar:
1226
- for i, t in enumerate(timesteps):
1227
- if self.interrupt:
1228
- continue
1229
-
1230
- # expand the latents if we are doing classifier free guidance
1231
- latent_model_input = (
1232
- torch.cat([latents] * 2)
1233
- if self.do_classifier_free_guidance
1234
- else latents
1235
- )
1236
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1237
- timestep = t.expand(latent_model_input.shape[0])
1238
-
1239
- if isinstance(controlnet_keep[i], list):
1240
- cond_scale = [
1241
- c * s
1242
- for c, s in zip(
1243
- controlnet_conditioning_scale, controlnet_keep[i]
1244
- )
1245
- ]
1246
- else:
1247
- controlnet_cond_scale = controlnet_conditioning_scale
1248
- if isinstance(controlnet_cond_scale, list):
1249
- controlnet_cond_scale = controlnet_cond_scale[0]
1250
- cond_scale = controlnet_cond_scale * controlnet_keep[i]
1251
-
1252
- # controlnet(s) inference
1253
- control_block_samples = self.controlnet(
1254
- hidden_states=latent_model_input,
1255
- timestep=timestep,
1256
- encoder_hidden_states=prompt_embeds,
1257
- pooled_projections=controlnet_pooled_projections,
1258
- joint_attention_kwargs=self.joint_attention_kwargs,
1259
- controlnet_cond=control_image,
1260
- conditioning_scale=cond_scale,
1261
- return_dict=False,
1262
- )[0]
1263
-
1264
- noise_pred = self.transformer(
1265
- hidden_states=latent_model_input,
1266
- timestep=timestep,
1267
- encoder_hidden_states=prompt_embeds,
1268
- pooled_projections=pooled_prompt_embeds,
1269
- block_controlnet_hidden_states=control_block_samples,
1270
- joint_attention_kwargs=self.joint_attention_kwargs,
1271
- return_dict=False,
1272
- )[0]
1273
-
1274
- # perform guidance
1275
- if self.do_classifier_free_guidance:
1276
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1277
- noise_pred = noise_pred_uncond + self.guidance_scale * (
1278
- noise_pred_text - noise_pred_uncond
1279
- )
1280
-
1281
- # compute the previous noisy sample x_t -> x_t-1
1282
- latents_dtype = latents.dtype
1283
- latents = self.scheduler.step(
1284
- noise_pred, t, latents, return_dict=False
1285
- )[0]
1286
-
1287
- if latents.dtype != latents_dtype:
1288
- if torch.backends.mps.is_available():
1289
- # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1290
- latents = latents.to(latents_dtype)
1291
-
1292
- if callback_on_step_end is not None:
1293
- callback_kwargs = {}
1294
- for k in callback_on_step_end_tensor_inputs:
1295
- callback_kwargs[k] = locals()[k]
1296
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1297
-
1298
- latents = callback_outputs.pop("latents", latents)
1299
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1300
- negative_prompt_embeds = callback_outputs.pop(
1301
- "negative_prompt_embeds", negative_prompt_embeds
1302
- )
1303
- negative_pooled_prompt_embeds = callback_outputs.pop(
1304
- "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1305
- )
1306
-
1307
- # call the callback, if provided
1308
- if i == len(timesteps) - 1 or (
1309
- (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
1310
- ):
1311
- progress_bar.update()
1312
-
1313
- if XLA_AVAILABLE:
1314
- xm.mark_step()
1315
-
1316
- if output_type == "latent":
1317
- image = latents
1318
-
1319
- else:
1320
- latents = (
1321
- latents / self.vae.config.scaling_factor
1322
- ) + self.vae.config.shift_factor
1323
- latents = latents.to(dtype=self.vae.dtype)
1324
- image = self.vae.decode(latents, return_dict=False)[0]
1325
- image = self.image_processor.postprocess(image, output_type=output_type)
1326
-
1327
- # Offload all models
1328
- self.maybe_free_model_hooks()
1329
-
1330
- if not return_dict:
1331
- return (image,)
1332
-
1333
- return StableDiffusion3PipelineOutput(images=image)