yizhangliu committed on
Commit c861f8d · 1 Parent(s): 9ba66a8

update app.py && remove kolors/

app.py CHANGED
@@ -78,12 +78,6 @@ from io import BytesIO
 from diffusers import StableDiffusionInpaintPipeline
 from huggingface_hub import hf_hub_download
 
-# from huggingface_hub import snapshot_download
-# from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_inpainting import StableDiffusionXLInpaintPipeline
-# from kolors.models.modeling_chatglm import ChatGLMModel
-# from kolors.models.tokenization_chatglm import ChatGLMTokenizer
-# from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
-
 from util_computer import computer_info
 
 # relate anything
@@ -334,24 +328,6 @@ def load_sd_model(device):
 # torch_dtype=torch.float16,
 # )
 # sd_model = sd_model.to(device)
-
-ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors-Inpainting")
-text_encoder = ChatGLMModel.from_pretrained(f'{ckpt_dir}/text_encoder',torch_dtype=torch.float16).half().to(device)
-tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
-vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to(device)
-scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
-unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
-
-sd_model = StableDiffusionXLInpaintPipeline(
-    vae=vae,
-    text_encoder=text_encoder,
-    tokenizer=tokenizer,
-    unet=unet,
-    scheduler=scheduler
-)
-
-sd_model.to(device)
-sd_model.enable_attention_slicing()
 '''
 
 def load_lama_cleaner_model(device):
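
With the Kolors pipeline assembly removed, app.py keeps only the stock diffusers imports from the first hunk. As a hedged reference, loading the still-imported StableDiffusionInpaintPipeline typically looks like the sketch below; the checkpoint id is an assumption and app.py's actual load_sd_model body is not reproduced in this diff:

import torch
from diffusers import StableDiffusionInpaintPipeline

def load_sd_model(device):
    # Illustrative sketch only: the repo id is an assumption, not taken from app.py.
    sd_model = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        torch_dtype=torch.float16,
    )
    return sd_model.to(device)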
kolors/__init__.py DELETED
File without changes
kolors/models/__init__.py DELETED
File without changes
kolors/models/configuration_chatglm.py DELETED
@@ -1,61 +0,0 @@
-from transformers import PretrainedConfig
-
-
-class ChatGLMConfig(PretrainedConfig):
-    model_type = "chatglm"
-    def __init__(
-        self,
-        num_layers=28,
-        padded_vocab_size=65024,
-        hidden_size=4096,
-        ffn_hidden_size=13696,
-        kv_channels=128,
-        num_attention_heads=32,
-        seq_length=2048,
-        hidden_dropout=0.0,
-        classifier_dropout=None,
-        attention_dropout=0.0,
-        layernorm_epsilon=1e-5,
-        rmsnorm=True,
-        apply_residual_connection_post_layernorm=False,
-        post_layer_norm=True,
-        add_bias_linear=False,
-        add_qkv_bias=False,
-        bias_dropout_fusion=True,
-        multi_query_attention=False,
-        multi_query_group_num=1,
-        apply_query_key_layer_scaling=True,
-        attention_softmax_in_fp32=True,
-        fp32_residual_connection=False,
-        quantization_bit=0,
-        pre_seq_len=None,
-        prefix_projection=False,
-        **kwargs
-    ):
-        self.num_layers = num_layers
-        self.vocab_size = padded_vocab_size
-        self.padded_vocab_size = padded_vocab_size
-        self.hidden_size = hidden_size
-        self.ffn_hidden_size = ffn_hidden_size
-        self.kv_channels = kv_channels
-        self.num_attention_heads = num_attention_heads
-        self.seq_length = seq_length
-        self.hidden_dropout = hidden_dropout
-        self.classifier_dropout = classifier_dropout
-        self.attention_dropout = attention_dropout
-        self.layernorm_epsilon = layernorm_epsilon
-        self.rmsnorm = rmsnorm
-        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
-        self.post_layer_norm = post_layer_norm
-        self.add_bias_linear = add_bias_linear
-        self.add_qkv_bias = add_qkv_bias
-        self.bias_dropout_fusion = bias_dropout_fusion
-        self.multi_query_attention = multi_query_attention
-        self.multi_query_group_num = multi_query_group_num
-        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
-        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
-        self.fp32_residual_connection = fp32_residual_connection
-        self.quantization_bit = quantization_bit
-        self.pre_seq_len = pre_seq_len
-        self.prefix_projection = prefix_projection
-        super().__init__(**kwargs)
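
The defaults above describe a ChatGLM3-6B-class text encoder: kv_channels (128) equals hidden_size / num_attention_heads (4096 / 32), with 28 layers and a 13696-wide FFN. A quick hypothetical sanity check of that relationship, not part of the deleted file:

hidden_size, num_attention_heads, kv_channels = 4096, 32, 128
assert hidden_size // num_attention_heads == kv_channels  # 128-dimensional attention heads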
kolors/models/controlnet.py DELETED
@@ -1,887 +0,0 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataclasses import dataclass
15
- from typing import Any, Dict, List, Optional, Tuple, Union
16
-
17
- import torch
18
- from torch import nn
19
- from torch.nn import functional as F
20
-
21
- from diffusers.configuration_utils import ConfigMixin, register_to_config
22
- from diffusers.loaders.single_file_model import FromOriginalModelMixin
23
- from diffusers.utils import BaseOutput, logging
24
- from diffusers.models.attention_processor import (
25
- ADDED_KV_ATTENTION_PROCESSORS,
26
- CROSS_ATTENTION_PROCESSORS,
27
- AttentionProcessor,
28
- AttnAddedKVProcessor,
29
- AttnProcessor,
30
- )
31
- from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
32
- from diffusers.models.modeling_utils import ModelMixin
33
-
34
- try:
35
- from diffusers.unets.unet_2d_blocks import (
36
- CrossAttnDownBlock2D,
37
- DownBlock2D,
38
- UNetMidBlock2D,
39
- UNetMidBlock2DCrossAttn,
40
- get_down_block,
41
- )
42
- from diffusers.unets.unet_2d_condition import UNet2DConditionModel
43
- except:
44
- from diffusers.models.unets.unet_2d_blocks import (
45
- CrossAttnDownBlock2D,
46
- DownBlock2D,
47
- UNetMidBlock2D,
48
- UNetMidBlock2DCrossAttn,
49
- get_down_block,
50
- )
51
- from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
52
-
53
-
54
-
55
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
56
-
57
-
58
- @dataclass
59
- class ControlNetOutput(BaseOutput):
60
- """
61
- The output of [`ControlNetModel`].
62
-
63
- Args:
64
- down_block_res_samples (`tuple[torch.Tensor]`):
65
- A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
66
- be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
67
- used to condition the original UNet's downsampling activations.
68
- mid_down_block_re_sample (`torch.Tensor`):
69
- The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
70
- `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
71
- Output can be used to condition the original UNet's middle block activation.
72
- """
73
-
74
- down_block_res_samples: Tuple[torch.Tensor]
75
- mid_block_res_sample: torch.Tensor
76
-
77
-
78
- class ControlNetConditioningEmbedding(nn.Module):
79
- """
80
- Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
81
- [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
82
- training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
83
- convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
84
- (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
85
- model) to encode image-space conditions ... into feature maps ..."
86
- """
87
-
88
- def __init__(
89
- self,
90
- conditioning_embedding_channels: int,
91
- conditioning_channels: int = 3,
92
- block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
93
- ):
94
- super().__init__()
95
-
96
- self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
97
-
98
- self.blocks = nn.ModuleList([])
99
-
100
- for i in range(len(block_out_channels) - 1):
101
- channel_in = block_out_channels[i]
102
- channel_out = block_out_channels[i + 1]
103
- self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
104
- self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
105
-
106
- self.conv_out = zero_module(
107
- nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
108
- )
109
-
110
- def forward(self, conditioning):
111
- embedding = self.conv_in(conditioning)
112
- embedding = F.silu(embedding)
113
-
114
- for block in self.blocks:
115
- embedding = block(embedding)
116
- embedding = F.silu(embedding)
117
-
118
- embedding = self.conv_out(embedding)
119
-
120
- return embedding
121
-
122
-
123
- class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
124
- """
125
- A ControlNet model.
126
-
127
- Args:
128
- in_channels (`int`, defaults to 4):
129
- The number of channels in the input sample.
130
- flip_sin_to_cos (`bool`, defaults to `True`):
131
- Whether to flip the sin to cos in the time embedding.
132
- freq_shift (`int`, defaults to 0):
133
- The frequency shift to apply to the time embedding.
134
- down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
135
- The tuple of downsample blocks to use.
136
- only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
137
- block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
138
- The tuple of output channels for each block.
139
- layers_per_block (`int`, defaults to 2):
140
- The number of layers per block.
141
- downsample_padding (`int`, defaults to 1):
142
- The padding to use for the downsampling convolution.
143
- mid_block_scale_factor (`float`, defaults to 1):
144
- The scale factor to use for the mid block.
145
- act_fn (`str`, defaults to "silu"):
146
- The activation function to use.
147
- norm_num_groups (`int`, *optional*, defaults to 32):
148
- The number of groups to use for the normalization. If None, normalization and activation layers is skipped
149
- in post-processing.
150
- norm_eps (`float`, defaults to 1e-5):
151
- The epsilon to use for the normalization.
152
- cross_attention_dim (`int`, defaults to 1280):
153
- The dimension of the cross attention features.
154
- transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
155
- The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
156
- [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
157
- [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
158
- encoder_hid_dim (`int`, *optional*, defaults to None):
159
- If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
160
- dimension to `cross_attention_dim`.
161
- encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
162
- If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
163
- embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
164
- attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
165
- The dimension of the attention heads.
166
- use_linear_projection (`bool`, defaults to `False`):
167
- class_embed_type (`str`, *optional*, defaults to `None`):
168
- The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
169
- `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
170
- addition_embed_type (`str`, *optional*, defaults to `None`):
171
- Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
172
- "text". "text" will use the `TextTimeEmbedding` layer.
173
- num_class_embeds (`int`, *optional*, defaults to 0):
174
- Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
175
- class conditioning with `class_embed_type` equal to `None`.
176
- upcast_attention (`bool`, defaults to `False`):
177
- resnet_time_scale_shift (`str`, defaults to `"default"`):
178
- Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
179
- projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
180
- The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
181
- `class_embed_type="projection"`.
182
- controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
183
- The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
184
- conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
185
- The tuple of output channel for each block in the `conditioning_embedding` layer.
186
- global_pool_conditions (`bool`, defaults to `False`):
187
- TODO(Patrick) - unused parameter.
188
- addition_embed_type_num_heads (`int`, defaults to 64):
189
- The number of heads to use for the `TextTimeEmbedding` layer.
190
- """
191
-
192
- _supports_gradient_checkpointing = True
193
-
194
- @register_to_config
195
- def __init__(
196
- self,
197
- in_channels: int = 4,
198
- conditioning_channels: int = 3,
199
- flip_sin_to_cos: bool = True,
200
- freq_shift: int = 0,
201
- down_block_types: Tuple[str, ...] = (
202
- "CrossAttnDownBlock2D",
203
- "CrossAttnDownBlock2D",
204
- "CrossAttnDownBlock2D",
205
- "DownBlock2D",
206
- ),
207
- mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
208
- only_cross_attention: Union[bool, Tuple[bool]] = False,
209
- block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
210
- layers_per_block: int = 2,
211
- downsample_padding: int = 1,
212
- mid_block_scale_factor: float = 1,
213
- act_fn: str = "silu",
214
- norm_num_groups: Optional[int] = 32,
215
- norm_eps: float = 1e-5,
216
- cross_attention_dim: int = 1280,
217
- transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
218
- encoder_hid_dim: Optional[int] = None,
219
- encoder_hid_dim_type: Optional[str] = None,
220
- attention_head_dim: Union[int, Tuple[int, ...]] = 8,
221
- num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
222
- use_linear_projection: bool = False,
223
- class_embed_type: Optional[str] = None,
224
- addition_embed_type: Optional[str] = None,
225
- addition_time_embed_dim: Optional[int] = None,
226
- num_class_embeds: Optional[int] = None,
227
- upcast_attention: bool = False,
228
- resnet_time_scale_shift: str = "default",
229
- projection_class_embeddings_input_dim: Optional[int] = None,
230
- controlnet_conditioning_channel_order: str = "rgb",
231
- conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
232
- global_pool_conditions: bool = False,
233
- addition_embed_type_num_heads: int = 64,
234
- ):
235
- super().__init__()
236
-
237
- # If `num_attention_heads` is not defined (which is the case for most models)
238
- # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
239
- # The reason for this behavior is to correct for incorrectly named variables that were introduced
240
- # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
241
- # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
242
- # which is why we correct for the naming here.
243
- num_attention_heads = num_attention_heads or attention_head_dim
244
-
245
- # Check inputs
246
- if len(block_out_channels) != len(down_block_types):
247
- raise ValueError(
248
- f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
249
- )
250
-
251
- if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
252
- raise ValueError(
253
- f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
254
- )
255
-
256
- if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
257
- raise ValueError(
258
- f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
259
- )
260
-
261
- if isinstance(transformer_layers_per_block, int):
262
- transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
263
-
264
- # input
265
- conv_in_kernel = 3
266
- conv_in_padding = (conv_in_kernel - 1) // 2
267
- self.conv_in = nn.Conv2d(
268
- in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
269
- )
270
-
271
- # time
272
- time_embed_dim = block_out_channels[0] * 4
273
- self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
274
- timestep_input_dim = block_out_channels[0]
275
- self.time_embedding = TimestepEmbedding(
276
- timestep_input_dim,
277
- time_embed_dim,
278
- act_fn=act_fn,
279
- )
280
-
281
- if encoder_hid_dim_type is None and encoder_hid_dim is not None:
282
- encoder_hid_dim_type = "text_proj"
283
- self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
284
- logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
285
-
286
- if encoder_hid_dim is None and encoder_hid_dim_type is not None:
287
- raise ValueError(
288
- f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
289
- )
290
-
291
- if encoder_hid_dim_type == "text_proj":
292
- self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
293
- elif encoder_hid_dim_type == "text_image_proj":
294
- # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
295
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
296
- # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
297
- self.encoder_hid_proj = TextImageProjection(
298
- text_embed_dim=encoder_hid_dim,
299
- image_embed_dim=cross_attention_dim,
300
- cross_attention_dim=cross_attention_dim,
301
- )
302
-
303
- elif encoder_hid_dim_type is not None:
304
- raise ValueError(
305
- f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
306
- )
307
- else:
308
- self.encoder_hid_proj = None
309
-
310
- # class embedding
311
- if class_embed_type is None and num_class_embeds is not None:
312
- self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
313
- elif class_embed_type == "timestep":
314
- self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
315
- elif class_embed_type == "identity":
316
- self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
317
- elif class_embed_type == "projection":
318
- if projection_class_embeddings_input_dim is None:
319
- raise ValueError(
320
- "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
321
- )
322
- # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
323
- # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
324
- # 2. it projects from an arbitrary input dimension.
325
- #
326
- # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
327
- # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
328
- # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
329
- self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
330
- else:
331
- self.class_embedding = None
332
-
333
- if addition_embed_type == "text":
334
- if encoder_hid_dim is not None:
335
- text_time_embedding_from_dim = encoder_hid_dim
336
- else:
337
- text_time_embedding_from_dim = cross_attention_dim
338
-
339
- self.add_embedding = TextTimeEmbedding(
340
- text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
341
- )
342
- elif addition_embed_type == "text_image":
343
- # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
344
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
345
- # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
346
- self.add_embedding = TextImageTimeEmbedding(
347
- text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
348
- )
349
- elif addition_embed_type == "text_time":
350
- self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
351
- self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
352
-
353
- elif addition_embed_type is not None:
354
- raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
355
-
356
- # control net conditioning embedding
357
- self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
358
- conditioning_embedding_channels=block_out_channels[0],
359
- block_out_channels=conditioning_embedding_out_channels,
360
- conditioning_channels=conditioning_channels,
361
- )
362
-
363
- self.down_blocks = nn.ModuleList([])
364
- self.controlnet_down_blocks = nn.ModuleList([])
365
-
366
- if isinstance(only_cross_attention, bool):
367
- only_cross_attention = [only_cross_attention] * len(down_block_types)
368
-
369
- if isinstance(attention_head_dim, int):
370
- attention_head_dim = (attention_head_dim,) * len(down_block_types)
371
-
372
- if isinstance(num_attention_heads, int):
373
- num_attention_heads = (num_attention_heads,) * len(down_block_types)
374
-
375
- # down
376
- output_channel = block_out_channels[0]
377
-
378
- controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
379
- controlnet_block = zero_module(controlnet_block)
380
- self.controlnet_down_blocks.append(controlnet_block)
381
-
382
- for i, down_block_type in enumerate(down_block_types):
383
- input_channel = output_channel
384
- output_channel = block_out_channels[i]
385
- is_final_block = i == len(block_out_channels) - 1
386
-
387
- down_block = get_down_block(
388
- down_block_type,
389
- num_layers=layers_per_block,
390
- transformer_layers_per_block=transformer_layers_per_block[i],
391
- in_channels=input_channel,
392
- out_channels=output_channel,
393
- temb_channels=time_embed_dim,
394
- add_downsample=not is_final_block,
395
- resnet_eps=norm_eps,
396
- resnet_act_fn=act_fn,
397
- resnet_groups=norm_num_groups,
398
- cross_attention_dim=cross_attention_dim,
399
- num_attention_heads=num_attention_heads[i],
400
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
401
- downsample_padding=downsample_padding,
402
- use_linear_projection=use_linear_projection,
403
- only_cross_attention=only_cross_attention[i],
404
- upcast_attention=upcast_attention,
405
- resnet_time_scale_shift=resnet_time_scale_shift,
406
- )
407
- self.down_blocks.append(down_block)
408
-
409
- for _ in range(layers_per_block):
410
- controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
411
- controlnet_block = zero_module(controlnet_block)
412
- self.controlnet_down_blocks.append(controlnet_block)
413
-
414
- if not is_final_block:
415
- controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
416
- controlnet_block = zero_module(controlnet_block)
417
- self.controlnet_down_blocks.append(controlnet_block)
418
-
419
- # mid
420
- mid_block_channel = block_out_channels[-1]
421
-
422
- controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
423
- controlnet_block = zero_module(controlnet_block)
424
- self.controlnet_mid_block = controlnet_block
425
-
426
- if mid_block_type == "UNetMidBlock2DCrossAttn":
427
- self.mid_block = UNetMidBlock2DCrossAttn(
428
- transformer_layers_per_block=transformer_layers_per_block[-1],
429
- in_channels=mid_block_channel,
430
- temb_channels=time_embed_dim,
431
- resnet_eps=norm_eps,
432
- resnet_act_fn=act_fn,
433
- output_scale_factor=mid_block_scale_factor,
434
- resnet_time_scale_shift=resnet_time_scale_shift,
435
- cross_attention_dim=cross_attention_dim,
436
- num_attention_heads=num_attention_heads[-1],
437
- resnet_groups=norm_num_groups,
438
- use_linear_projection=use_linear_projection,
439
- upcast_attention=upcast_attention,
440
- )
441
- elif mid_block_type == "UNetMidBlock2D":
442
- self.mid_block = UNetMidBlock2D(
443
- in_channels=block_out_channels[-1],
444
- temb_channels=time_embed_dim,
445
- num_layers=0,
446
- resnet_eps=norm_eps,
447
- resnet_act_fn=act_fn,
448
- output_scale_factor=mid_block_scale_factor,
449
- resnet_groups=norm_num_groups,
450
- resnet_time_scale_shift=resnet_time_scale_shift,
451
- add_attention=False,
452
- )
453
- else:
454
- raise ValueError(f"unknown mid_block_type : {mid_block_type}")
455
-
456
- @classmethod
457
- def from_unet(
458
- cls,
459
- unet: UNet2DConditionModel,
460
- controlnet_conditioning_channel_order: str = "rgb",
461
- conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
462
- load_weights_from_unet: bool = True,
463
- conditioning_channels: int = 3,
464
- ):
465
- r"""
466
- Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
467
-
468
- Parameters:
469
- unet (`UNet2DConditionModel`):
470
- The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
471
- where applicable.
472
- """
473
- transformer_layers_per_block = (
474
- unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
475
- )
476
- encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
477
- encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
478
- addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
479
- addition_time_embed_dim = (
480
- unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
481
- )
482
-
483
- controlnet = cls(
484
- encoder_hid_dim=encoder_hid_dim,
485
- encoder_hid_dim_type=encoder_hid_dim_type,
486
- addition_embed_type=addition_embed_type,
487
- addition_time_embed_dim=addition_time_embed_dim,
488
- transformer_layers_per_block=transformer_layers_per_block,
489
- in_channels=unet.config.in_channels,
490
- flip_sin_to_cos=unet.config.flip_sin_to_cos,
491
- freq_shift=unet.config.freq_shift,
492
- down_block_types=unet.config.down_block_types,
493
- only_cross_attention=unet.config.only_cross_attention,
494
- block_out_channels=unet.config.block_out_channels,
495
- layers_per_block=unet.config.layers_per_block,
496
- downsample_padding=unet.config.downsample_padding,
497
- mid_block_scale_factor=unet.config.mid_block_scale_factor,
498
- act_fn=unet.config.act_fn,
499
- norm_num_groups=unet.config.norm_num_groups,
500
- norm_eps=unet.config.norm_eps,
501
- cross_attention_dim=unet.config.cross_attention_dim,
502
- attention_head_dim=unet.config.attention_head_dim,
503
- num_attention_heads=unet.config.num_attention_heads,
504
- use_linear_projection=unet.config.use_linear_projection,
505
- class_embed_type=unet.config.class_embed_type,
506
- num_class_embeds=unet.config.num_class_embeds,
507
- upcast_attention=unet.config.upcast_attention,
508
- resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
509
- projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
510
- mid_block_type=unet.config.mid_block_type,
511
- controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
512
- conditioning_embedding_out_channels=conditioning_embedding_out_channels,
513
- conditioning_channels=conditioning_channels,
514
- )
515
-
516
- if load_weights_from_unet:
517
- controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
518
- controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
519
- controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
520
-
521
- if controlnet.class_embedding:
522
- controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
523
-
524
- if hasattr(controlnet, "add_embedding"):
525
- controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())
526
-
527
- controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
528
- controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
529
-
530
- return controlnet
531
-
532
- @property
533
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
534
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
535
- r"""
536
- Returns:
537
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
538
- indexed by its weight name.
539
- """
540
- # set recursively
541
- processors = {}
542
-
543
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
544
- if hasattr(module, "get_processor"):
545
- processors[f"{name}.processor"] = module.get_processor()
546
-
547
- for sub_name, child in module.named_children():
548
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
549
-
550
- return processors
551
-
552
- for name, module in self.named_children():
553
- fn_recursive_add_processors(name, module, processors)
554
-
555
- return processors
556
-
557
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
558
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
559
- r"""
560
- Sets the attention processor to use to compute attention.
561
-
562
- Parameters:
563
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
564
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
565
- for **all** `Attention` layers.
566
-
567
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
568
- processor. This is strongly recommended when setting trainable attention processors.
569
-
570
- """
571
- count = len(self.attn_processors.keys())
572
-
573
- if isinstance(processor, dict) and len(processor) != count:
574
- raise ValueError(
575
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
576
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
577
- )
578
-
579
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
580
- if hasattr(module, "set_processor"):
581
- if not isinstance(processor, dict):
582
- module.set_processor(processor)
583
- else:
584
- module.set_processor(processor.pop(f"{name}.processor"))
585
-
586
- for sub_name, child in module.named_children():
587
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
588
-
589
- for name, module in self.named_children():
590
- fn_recursive_attn_processor(name, module, processor)
591
-
592
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
593
- def set_default_attn_processor(self):
594
- """
595
- Disables custom attention processors and sets the default attention implementation.
596
- """
597
- if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
598
- processor = AttnAddedKVProcessor()
599
- elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
600
- processor = AttnProcessor()
601
- else:
602
- raise ValueError(
603
- f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
604
- )
605
-
606
- self.set_attn_processor(processor)
607
-
608
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
609
- def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
610
- r"""
611
- Enable sliced attention computation.
612
-
613
- When this option is enabled, the attention module splits the input tensor in slices to compute attention in
614
- several steps. This is useful for saving some memory in exchange for a small decrease in speed.
615
-
616
- Args:
617
- slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
618
- When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
619
- `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
620
- provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
621
- must be a multiple of `slice_size`.
622
- """
623
- sliceable_head_dims = []
624
-
625
- def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
626
- if hasattr(module, "set_attention_slice"):
627
- sliceable_head_dims.append(module.sliceable_head_dim)
628
-
629
- for child in module.children():
630
- fn_recursive_retrieve_sliceable_dims(child)
631
-
632
- # retrieve number of attention layers
633
- for module in self.children():
634
- fn_recursive_retrieve_sliceable_dims(module)
635
-
636
- num_sliceable_layers = len(sliceable_head_dims)
637
-
638
- if slice_size == "auto":
639
- # half the attention head size is usually a good trade-off between
640
- # speed and memory
641
- slice_size = [dim // 2 for dim in sliceable_head_dims]
642
- elif slice_size == "max":
643
- # make smallest slice possible
644
- slice_size = num_sliceable_layers * [1]
645
-
646
- slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
647
-
648
- if len(slice_size) != len(sliceable_head_dims):
649
- raise ValueError(
650
- f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
651
- f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
652
- )
653
-
654
- for i in range(len(slice_size)):
655
- size = slice_size[i]
656
- dim = sliceable_head_dims[i]
657
- if size is not None and size > dim:
658
- raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
659
-
660
- # Recursively walk through all the children.
661
- # Any children which exposes the set_attention_slice method
662
- # gets the message
663
- def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
664
- if hasattr(module, "set_attention_slice"):
665
- module.set_attention_slice(slice_size.pop())
666
-
667
- for child in module.children():
668
- fn_recursive_set_attention_slice(child, slice_size)
669
-
670
- reversed_slice_size = list(reversed(slice_size))
671
- for module in self.children():
672
- fn_recursive_set_attention_slice(module, reversed_slice_size)
673
-
674
- def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
675
- if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
676
- module.gradient_checkpointing = value
677
-
678
- def forward(
679
- self,
680
- sample: torch.Tensor,
681
- timestep: Union[torch.Tensor, float, int],
682
- encoder_hidden_states: torch.Tensor,
683
- controlnet_cond: torch.Tensor,
684
- conditioning_scale: float = 1.0,
685
- class_labels: Optional[torch.Tensor] = None,
686
- timestep_cond: Optional[torch.Tensor] = None,
687
- attention_mask: Optional[torch.Tensor] = None,
688
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
689
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
690
- guess_mode: bool = False,
691
- return_dict: bool = True,
692
- ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
693
- """
694
- The [`ControlNetModel`] forward method.
695
-
696
- Args:
697
- sample (`torch.Tensor`):
698
- The noisy input tensor.
699
- timestep (`Union[torch.Tensor, float, int]`):
700
- The number of timesteps to denoise an input.
701
- encoder_hidden_states (`torch.Tensor`):
702
- The encoder hidden states.
703
- controlnet_cond (`torch.Tensor`):
704
- The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
705
- conditioning_scale (`float`, defaults to `1.0`):
706
- The scale factor for ControlNet outputs.
707
- class_labels (`torch.Tensor`, *optional*, defaults to `None`):
708
- Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
709
- timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
710
- Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
711
- timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
712
- embeddings.
713
- attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
714
- An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
715
- is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
716
- negative values to the attention scores corresponding to "discard" tokens.
717
- added_cond_kwargs (`dict`):
718
- Additional conditions for the Stable Diffusion XL UNet.
719
- cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
720
- A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
721
- guess_mode (`bool`, defaults to `False`):
722
- In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
723
- you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
724
- return_dict (`bool`, defaults to `True`):
725
- Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
726
-
727
- Returns:
728
- [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
729
- If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
730
- returned where the first element is the sample tensor.
731
- """
732
- # check channel order
733
- channel_order = self.config.controlnet_conditioning_channel_order
734
-
735
- if channel_order == "rgb":
736
- # in rgb order by default
737
- ...
738
- elif channel_order == "bgr":
739
- controlnet_cond = torch.flip(controlnet_cond, dims=[1])
740
- else:
741
- raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
742
-
743
- # prepare attention_mask
744
- if attention_mask is not None:
745
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
746
- attention_mask = attention_mask.unsqueeze(1)
747
-
748
- #Todo
749
- if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
750
- encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
751
-
752
- # 1. time
753
- timesteps = timestep
754
- if not torch.is_tensor(timesteps):
755
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
756
- # This would be a good case for the `match` statement (Python 3.10+)
757
- is_mps = sample.device.type == "mps"
758
- if isinstance(timestep, float):
759
- dtype = torch.float32 if is_mps else torch.float64
760
- else:
761
- dtype = torch.int32 if is_mps else torch.int64
762
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
763
- elif len(timesteps.shape) == 0:
764
- timesteps = timesteps[None].to(sample.device)
765
-
766
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
767
- timesteps = timesteps.expand(sample.shape[0])
768
-
769
- t_emb = self.time_proj(timesteps)
770
-
771
- # timesteps does not contain any weights and will always return f32 tensors
772
- # but time_embedding might actually be running in fp16. so we need to cast here.
773
- # there might be better ways to encapsulate this.
774
- t_emb = t_emb.to(dtype=sample.dtype)
775
-
776
- emb = self.time_embedding(t_emb, timestep_cond)
777
- aug_emb = None
778
-
779
- if self.class_embedding is not None:
780
- if class_labels is None:
781
- raise ValueError("class_labels should be provided when num_class_embeds > 0")
782
-
783
- if self.config.class_embed_type == "timestep":
784
- class_labels = self.time_proj(class_labels)
785
-
786
- class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
787
- emb = emb + class_emb
788
-
789
- if self.config.addition_embed_type is not None:
790
- if self.config.addition_embed_type == "text":
791
- aug_emb = self.add_embedding(encoder_hidden_states)
792
-
793
- elif self.config.addition_embed_type == "text_time":
794
- if "text_embeds" not in added_cond_kwargs:
795
- raise ValueError(
796
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
797
- )
798
- text_embeds = added_cond_kwargs.get("text_embeds")
799
- if "time_ids" not in added_cond_kwargs:
800
- raise ValueError(
801
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
802
- )
803
- time_ids = added_cond_kwargs.get("time_ids")
804
- time_embeds = self.add_time_proj(time_ids.flatten())
805
- time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
806
-
807
- add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
808
- add_embeds = add_embeds.to(emb.dtype)
809
- aug_emb = self.add_embedding(add_embeds)
810
-
811
- emb = emb + aug_emb if aug_emb is not None else emb
812
-
813
- # 2. pre-process
814
- sample = self.conv_in(sample)
815
-
816
- controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
817
- sample = sample + controlnet_cond
818
-
819
- # 3. down
820
- down_block_res_samples = (sample,)
821
- for downsample_block in self.down_blocks:
822
- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
823
- sample, res_samples = downsample_block(
824
- hidden_states=sample,
825
- temb=emb,
826
- encoder_hidden_states=encoder_hidden_states,
827
- attention_mask=attention_mask,
828
- cross_attention_kwargs=cross_attention_kwargs,
829
- )
830
- else:
831
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
832
-
833
- down_block_res_samples += res_samples
834
-
835
- # 4. mid
836
- if self.mid_block is not None:
837
- if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
838
- sample = self.mid_block(
839
- sample,
840
- emb,
841
- encoder_hidden_states=encoder_hidden_states,
842
- attention_mask=attention_mask,
843
- cross_attention_kwargs=cross_attention_kwargs,
844
- )
845
- else:
846
- sample = self.mid_block(sample, emb)
847
-
848
- # 5. Control net blocks
849
-
850
- controlnet_down_block_res_samples = ()
851
-
852
- for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
853
- down_block_res_sample = controlnet_block(down_block_res_sample)
854
- controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
855
-
856
- down_block_res_samples = controlnet_down_block_res_samples
857
-
858
- mid_block_res_sample = self.controlnet_mid_block(sample)
859
-
860
- # 6. scaling
861
- if guess_mode and not self.config.global_pool_conditions:
862
- scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
863
- scales = scales * conditioning_scale
864
- down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
865
- mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
866
- else:
867
- down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
868
- mid_block_res_sample = mid_block_res_sample * conditioning_scale
869
-
870
- if self.config.global_pool_conditions:
871
- down_block_res_samples = [
872
- torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
873
- ]
874
- mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
875
-
876
- if not return_dict:
877
- return (down_block_res_samples, mid_block_res_sample)
878
-
879
- return ControlNetOutput(
880
- down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
881
- )
882
-
883
-
884
- def zero_module(module):
885
- for p in module.parameters():
886
- nn.init.zeros_(p)
887
- return module
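
The deleted controlnet.py is a Kolors-adapted copy of the diffusers ControlNetModel; its from_unet classmethod above bootstraps a ControlNet by copying the conv_in, time-embedding, down-block and mid-block weights from an existing UNet. A hedged usage sketch against the upstream diffusers class, which exposes the same entry point (the checkpoint id is an assumption):

import torch
from diffusers import ControlNetModel, UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)
# Mirrors from_unet above: the configuration is copied and the weights are loaded from the UNet.
controlnet = ControlNetModel.from_unet(unet, load_weights_from_unet=True)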
kolors/models/modeling_chatglm.py DELETED
@@ -1,1298 +0,0 @@
1
- """ PyTorch ChatGLM model. """
2
-
3
- import math
4
- import copy
5
- import warnings
6
- import re
7
- import sys
8
-
9
- import torch
10
- import torch.utils.checkpoint
11
- import torch.nn.functional as F
12
- from torch import nn
13
- from torch.nn import CrossEntropyLoss, LayerNorm
14
- from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
15
- from torch.nn.utils import skip_init
16
- from typing import Optional, Tuple, Union, List, Callable, Dict, Any
17
- from copy import deepcopy
18
-
19
- from transformers.modeling_outputs import (
20
- BaseModelOutputWithPast,
21
- CausalLMOutputWithPast,
22
- SequenceClassifierOutputWithPast,
23
- )
24
- from transformers.modeling_utils import PreTrainedModel
25
- from transformers.utils import logging
26
- from transformers.generation.logits_process import LogitsProcessor
27
- from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
28
-
29
- try:
30
- from .configuration_chatglm import ChatGLMConfig
31
- except:
32
- from configuration_chatglm import ChatGLMConfig
33
-
34
-
35
- # flags required to enable jit fusion kernels
36
-
37
- if sys.platform != 'darwin':
38
- torch._C._jit_set_profiling_mode(False)
39
- torch._C._jit_set_profiling_executor(False)
40
- torch._C._jit_override_can_fuse_on_cpu(True)
41
- torch._C._jit_override_can_fuse_on_gpu(True)
42
-
43
- logger = logging.get_logger(__name__)
44
-
45
- _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
46
- _CONFIG_FOR_DOC = "ChatGLM6BConfig"
47
-
48
- CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
49
- "THUDM/chatglm3-6b-base",
50
- # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
51
- ]
52
-
53
-
54
- def default_init(cls, *args, **kwargs):
55
- return cls(*args, **kwargs)
56
-
57
-
58
- class InvalidScoreLogitsProcessor(LogitsProcessor):
59
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
60
- if torch.isnan(scores).any() or torch.isinf(scores).any():
61
- scores.zero_()
62
- scores[..., 5] = 5e4
63
- return scores
64
-
65
-
66
- class PrefixEncoder(torch.nn.Module):
67
- """
68
- The torch.nn model to encode the prefix
69
- Input shape: (batch-size, prefix-length)
70
- Output shape: (batch-size, prefix-length, 2*layers*hidden)
71
- """
72
-
73
- def __init__(self, config: ChatGLMConfig):
74
- super().__init__()
75
- self.prefix_projection = config.prefix_projection
76
- if self.prefix_projection:
77
- # Use a two-layer MLP to encode the prefix
78
- kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
79
- self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
80
- self.trans = torch.nn.Sequential(
81
- torch.nn.Linear(kv_size, config.hidden_size),
82
- torch.nn.Tanh(),
83
- torch.nn.Linear(config.hidden_size, kv_size)
84
- )
85
- else:
86
- self.embedding = torch.nn.Embedding(config.pre_seq_len,
87
- config.num_layers * config.kv_channels * config.multi_query_group_num * 2)
88
-
89
- def forward(self, prefix: torch.Tensor):
90
- if self.prefix_projection:
91
- prefix_tokens = self.embedding(prefix)
92
- past_key_values = self.trans(prefix_tokens)
93
- else:
94
- past_key_values = self.embedding(prefix)
95
- return past_key_values
96
-
97
-
98
- def split_tensor_along_last_dim(
99
- tensor: torch.Tensor,
100
- num_partitions: int,
101
- contiguous_split_chunks: bool = False,
102
- ) -> List[torch.Tensor]:
103
- """Split a tensor along its last dimension.
104
-
105
- Arguments:
106
- tensor: input tensor.
107
- num_partitions: number of partitions to split the tensor
108
- contiguous_split_chunks: If True, make each chunk contiguous
109
- in memory.
110
-
111
- Returns:
112
- A list of Tensors
113
- """
114
- # Get the size and dimension.
115
- last_dim = tensor.dim() - 1
116
- last_dim_size = tensor.size()[last_dim] // num_partitions
117
- # Split.
118
- tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
119
- # Note: torch.split does not create contiguous tensors by default.
120
- if contiguous_split_chunks:
121
- return tuple(chunk.contiguous() for chunk in tensor_list)
122
-
123
- return tensor_list
124
-
125
-
126
- class RotaryEmbedding(nn.Module):
127
- def __init__(self, dim, original_impl=False, device=None, dtype=None):
128
- super().__init__()
129
- inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
130
- self.register_buffer("inv_freq", inv_freq)
131
- self.dim = dim
132
- self.original_impl = original_impl
133
-
134
- def forward_impl(
135
- self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
136
- ):
137
- """Enhanced Transformer with Rotary Position Embedding.
138
-
139
- Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
140
- transformers/rope/__init__.py. MIT License:
141
- https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
142
- """
143
- # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
144
- theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))
145
-
146
- # Create position indexes `[0, 1, ..., seq_len - 1]`
147
- seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)
148
-
149
- # Calculate the product of position index and $\theta_i$
150
- idx_theta = torch.outer(seq_idx, theta).float()
151
-
152
- cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
153
-
154
- # this is to mimic the behaviour of complex32, else we will get different results
155
- if dtype in (torch.float16, torch.bfloat16, torch.int8):
156
- cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
157
- return cache
158
-
159
- def forward(self, max_seq_len, offset=0):
160
- return self.forward_impl(
161
- max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
162
- )
163
-
164
-
165
- @torch.jit.script
166
- def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
167
- # x: [sq, b, np, hn]
168
- sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
169
- rot_dim = rope_cache.shape[-2] * 2
170
- x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
171
- # truncate to support variable sizes
172
- rope_cache = rope_cache[:sq]
173
- xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
174
- rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
175
- x_out2 = torch.stack(
176
- [
177
- xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
178
- xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
179
- ],
180
- -1,
181
- )
182
- x_out2 = x_out2.flatten(3)
183
- return torch.cat((x_out2, x_pass), dim=-1)
184
-
185
-
186
- class RMSNorm(torch.nn.Module):
187
- def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
188
- super().__init__()
189
- self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
190
- self.eps = eps
191
-
192
- def forward(self, hidden_states: torch.Tensor):
193
- input_dtype = hidden_states.dtype
194
- variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
195
- hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
196
-
197
- return (self.weight * hidden_states).to(input_dtype)
198
-
199
-
200
- class CoreAttention(torch.nn.Module):
201
- def __init__(self, config: ChatGLMConfig, layer_number):
202
- super(CoreAttention, self).__init__()
203
-
204
- self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
205
- self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
206
- if self.apply_query_key_layer_scaling:
207
- self.attention_softmax_in_fp32 = True
208
- self.layer_number = max(1, layer_number)
209
-
210
- projection_size = config.kv_channels * config.num_attention_heads
211
-
212
- # Per attention head and per partition values.
213
- self.hidden_size_per_partition = projection_size
214
- self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
215
- self.num_attention_heads_per_partition = config.num_attention_heads
216
-
217
- coeff = None
218
- self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
219
- if self.apply_query_key_layer_scaling:
220
- coeff = self.layer_number
221
- self.norm_factor *= coeff
222
- self.coeff = coeff
223
-
224
- self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
225
-
226
- def forward(self, query_layer, key_layer, value_layer, attention_mask):
227
- pytorch_major_version = int(torch.__version__.split('.')[0])
228
- if pytorch_major_version >= 2:
229
- query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
230
- if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
231
- context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
232
- is_causal=True)
233
- else:
234
- if attention_mask is not None:
235
- attention_mask = ~attention_mask
236
- context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
237
- attention_mask)
238
- context_layer = context_layer.permute(2, 0, 1, 3)
239
- new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
240
- context_layer = context_layer.reshape(*new_context_layer_shape)
241
- else:
242
- # Raw attention scores
243
-
244
- # [b, np, sq, sk]
245
- output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
246
-
247
- # [sq, b, np, hn] -> [sq, b * np, hn]
248
- query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
249
- # [sk, b, np, hn] -> [sk, b * np, hn]
250
- key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
251
-
252
- # preallocating input tensor: [b * np, sq, sk]
253
- matmul_input_buffer = torch.empty(
254
- output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
255
- device=query_layer.device
256
- )
257
-
258
- # Raw attention scores. [b * np, sq, sk]
259
- matmul_result = torch.baddbmm(
260
- matmul_input_buffer,
261
- query_layer.transpose(0, 1), # [b * np, sq, hn]
262
- key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
263
- beta=0.0,
264
- alpha=(1.0 / self.norm_factor),
265
- )
266
-
267
- # change view to [b, np, sq, sk]
268
- attention_scores = matmul_result.view(*output_size)
269
-
270
- # ===========================
271
- # Attention probs and dropout
272
- # ===========================
273
-
274
- # attention scores and attention mask [b, np, sq, sk]
275
- if self.attention_softmax_in_fp32:
276
- attention_scores = attention_scores.float()
277
- if self.coeff is not None:
278
- attention_scores = attention_scores * self.coeff
279
- if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
280
- attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
281
- device=attention_scores.device, dtype=torch.bool)
282
- attention_mask.tril_()
283
- attention_mask = ~attention_mask
284
- if attention_mask is not None:
285
- attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
286
- attention_probs = F.softmax(attention_scores, dim=-1)
287
- attention_probs = attention_probs.type_as(value_layer)
288
-
289
- # This is actually dropping out entire tokens to attend to, which might
290
- # seem a bit unusual, but is taken from the original Transformer paper.
291
- attention_probs = self.attention_dropout(attention_probs)
292
- # =========================
293
- # Context layer. [sq, b, hp]
294
- # =========================
295
-
296
- # value_layer -> context layer.
297
- # [sk, b, np, hn] --> [b, np, sq, hn]
298
-
299
- # context layer shape: [b, np, sq, hn]
300
- output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
301
- # change view [sk, b * np, hn]
302
- value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
303
- # change view [b * np, sq, sk]
304
- attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
305
- # matmul: [b * np, sq, hn]
306
- context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
307
- # change view [b, np, sq, hn]
308
- context_layer = context_layer.view(*output_size)
309
- # [b, np, sq, hn] --> [sq, b, np, hn]
310
- context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
311
- # [sq, b, np, hn] --> [sq, b, hp]
312
- new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
313
- context_layer = context_layer.view(*new_context_layer_shape)
314
-
315
- return context_layer
316
-
317
-
318
- class SelfAttention(torch.nn.Module):
319
- """Parallel self-attention layer abstract class.
320
-
321
- Self-attention layer takes input with size [s, b, h]
322
- and returns output of the same size.
323
- """
324
-
325
- def __init__(self, config: ChatGLMConfig, layer_number, device=None):
326
- super(SelfAttention, self).__init__()
327
- self.layer_number = max(1, layer_number)
328
-
329
- self.projection_size = config.kv_channels * config.num_attention_heads
330
-
331
- # Per attention head and per partition values.
332
- self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
333
- self.num_attention_heads_per_partition = config.num_attention_heads
334
-
335
- self.multi_query_attention = config.multi_query_attention
336
- self.qkv_hidden_size = 3 * self.projection_size
337
- if self.multi_query_attention:
338
- self.num_multi_query_groups_per_partition = config.multi_query_group_num
339
- self.qkv_hidden_size = (
340
- self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
341
- )
342
- self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
343
- bias=config.add_bias_linear or config.add_qkv_bias,
344
- device=device, **_config_to_kwargs(config)
345
- )
346
-
347
- self.core_attention = CoreAttention(config, self.layer_number)
348
-
349
- # Output.
350
- self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
351
- device=device, **_config_to_kwargs(config)
352
- )
353
-
354
- def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
355
- if self.multi_query_attention:
356
- num_attention_heads = self.num_multi_query_groups_per_partition
357
- else:
358
- num_attention_heads = self.num_attention_heads_per_partition
359
- return torch.empty(
360
- inference_max_sequence_len,
361
- batch_size,
362
- num_attention_heads,
363
- self.hidden_size_per_attention_head,
364
- dtype=dtype,
365
- device=device,
366
- )
367
-
368
- def forward(
369
- self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
370
- ):
371
- # hidden_states: [sq, b, h]
372
-
373
- # =================================================
374
- # Pre-allocate memory for key-values for inference.
375
- # =================================================
376
- # =====================
377
- # Query, Key, and Value
378
- # =====================
379
-
380
- # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
381
- mixed_x_layer = self.query_key_value(hidden_states)
382
-
383
- if self.multi_query_attention:
384
- (query_layer, key_layer, value_layer) = mixed_x_layer.split(
385
- [
386
- self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
387
- self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
388
- self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
389
- ],
390
- dim=-1,
391
- )
392
- query_layer = query_layer.view(
393
- query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
394
- )
395
- key_layer = key_layer.view(
396
- key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
397
- )
398
- value_layer = value_layer.view(
399
- value_layer.size()[:-1]
400
- + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
401
- )
402
- else:
403
- new_tensor_shape = mixed_x_layer.size()[:-1] + \
404
- (self.num_attention_heads_per_partition,
405
- 3 * self.hidden_size_per_attention_head)
406
- mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
407
-
408
- # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
409
- (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
410
-
411
- # apply relative positional encoding (rotary embedding)
412
- if rotary_pos_emb is not None:
413
- query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
414
- key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
415
-
416
- # adjust key and value for inference
417
- if kv_cache is not None:
418
- cache_k, cache_v = kv_cache
419
- key_layer = torch.cat((cache_k, key_layer), dim=0)
420
- value_layer = torch.cat((cache_v, value_layer), dim=0)
421
- if use_cache:
422
- kv_cache = (key_layer, value_layer)
423
- else:
424
- kv_cache = None
425
-
426
- if self.multi_query_attention:
427
- key_layer = key_layer.unsqueeze(-2)
428
- key_layer = key_layer.expand(
429
- -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
430
- )
431
- key_layer = key_layer.contiguous().view(
432
- key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
433
- )
434
- value_layer = value_layer.unsqueeze(-2)
435
- value_layer = value_layer.expand(
436
- -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
437
- )
438
- value_layer = value_layer.contiguous().view(
439
- value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
440
- )
441
-
442
- # ==================================
443
- # core attention computation
444
- # ==================================
445
-
446
- context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
447
-
448
- # =================
449
- # Output. [sq, b, h]
450
- # =================
451
-
452
- output = self.dense(context_layer)
453
-
454
- return output, kv_cache
455
-
456
-
457
- def _config_to_kwargs(args):
458
- common_kwargs = {
459
- "dtype": args.torch_dtype,
460
- }
461
- return common_kwargs
462
-
463
-
464
- class MLP(torch.nn.Module):
465
- """MLP.
466
-
467
- MLP will take the input with h hidden state, project it to 4*h
468
- hidden dimension, perform nonlinear transformation, and project the
469
- state back into h hidden dimension.
470
- """
471
-
472
- def __init__(self, config: ChatGLMConfig, device=None):
473
- super(MLP, self).__init__()
474
-
475
- self.add_bias = config.add_bias_linear
476
-
477
- # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
478
- self.dense_h_to_4h = nn.Linear(
479
- config.hidden_size,
480
- config.ffn_hidden_size * 2,
481
- bias=self.add_bias,
482
- device=device,
483
- **_config_to_kwargs(config)
484
- )
485
-
486
- def swiglu(x):
487
- x = torch.chunk(x, 2, dim=-1)
488
- return F.silu(x[0]) * x[1]
489
-
490
- self.activation_func = swiglu
491
-
492
- # Project back to h.
493
- self.dense_4h_to_h = nn.Linear(
494
- config.ffn_hidden_size,
495
- config.hidden_size,
496
- bias=self.add_bias,
497
- device=device,
498
- **_config_to_kwargs(config)
499
- )
500
-
501
- def forward(self, hidden_states):
502
- # [s, b, 4hp]
503
- intermediate_parallel = self.dense_h_to_4h(hidden_states)
504
- intermediate_parallel = self.activation_func(intermediate_parallel)
505
- # [s, b, h]
506
- output = self.dense_4h_to_h(intermediate_parallel)
507
- return output
508
-
509
-
510
- class GLMBlock(torch.nn.Module):
511
- """A single transformer layer.
512
-
513
- Transformer layer takes input with size [s, b, h] and returns an
514
- output of the same size.
515
- """
516
-
517
- def __init__(self, config: ChatGLMConfig, layer_number, device=None):
518
- super(GLMBlock, self).__init__()
519
- self.layer_number = layer_number
520
-
521
- self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
522
-
523
- self.fp32_residual_connection = config.fp32_residual_connection
524
-
525
- LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
526
- # Layernorm on the input data.
527
- self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
528
- dtype=config.torch_dtype)
529
-
530
- # Self attention.
531
- self.self_attention = SelfAttention(config, layer_number, device=device)
532
- self.hidden_dropout = config.hidden_dropout
533
-
534
- # Layernorm on the attention output
535
- self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
536
- dtype=config.torch_dtype)
537
-
538
- # MLP
539
- self.mlp = MLP(config, device=device)
540
-
541
- def forward(
542
- self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
543
- ):
544
- # hidden_states: [s, b, h]
545
-
546
- # Layer norm at the beginning of the transformer layer.
547
- layernorm_output = self.input_layernorm(hidden_states)
548
- # Self attention.
549
- attention_output, kv_cache = self.self_attention(
550
- layernorm_output,
551
- attention_mask,
552
- rotary_pos_emb,
553
- kv_cache=kv_cache,
554
- use_cache=use_cache
555
- )
556
-
557
- # Residual connection.
558
- if self.apply_residual_connection_post_layernorm:
559
- residual = layernorm_output
560
- else:
561
- residual = hidden_states
562
-
563
- layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
564
- layernorm_input = residual + layernorm_input
565
-
566
- # Layer norm post the self attention.
567
- layernorm_output = self.post_attention_layernorm(layernorm_input)
568
-
569
- # MLP.
570
- mlp_output = self.mlp(layernorm_output)
571
-
572
- # Second residual connection.
573
- if self.apply_residual_connection_post_layernorm:
574
- residual = layernorm_output
575
- else:
576
- residual = layernorm_input
577
-
578
- output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
579
- output = residual + output
580
-
581
- return output, kv_cache
582
-
583
-
584
- class GLMTransformer(torch.nn.Module):
585
- """Transformer class."""
586
-
587
- def __init__(self, config: ChatGLMConfig, device=None):
588
- super(GLMTransformer, self).__init__()
589
-
590
- self.fp32_residual_connection = config.fp32_residual_connection
591
- self.post_layer_norm = config.post_layer_norm
592
-
593
- # Number of layers.
594
- self.num_layers = config.num_layers
595
-
596
- # Transformer layers.
597
- def build_layer(layer_number):
598
- return GLMBlock(config, layer_number, device=device)
599
-
600
- self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
601
-
602
- if self.post_layer_norm:
603
- LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
604
- # Final layer norm before output.
605
- self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
606
- dtype=config.torch_dtype)
607
-
608
- self.gradient_checkpointing = False
609
-
610
- def _get_layer(self, layer_number):
611
- return self.layers[layer_number]
612
-
613
- def forward(
614
- self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
615
- use_cache: Optional[bool] = True,
616
- output_hidden_states: Optional[bool] = False,
617
- ):
618
- if not kv_caches:
619
- kv_caches = [None for _ in range(self.num_layers)]
620
- presents = () if use_cache else None
621
- if self.gradient_checkpointing and self.training:
622
- if use_cache:
623
- logger.warning_once(
624
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
625
- )
626
- use_cache = False
627
-
628
- all_self_attentions = None
629
- all_hidden_states = () if output_hidden_states else None
630
- for index in range(self.num_layers):
631
- if output_hidden_states:
632
- all_hidden_states = all_hidden_states + (hidden_states,)
633
-
634
- layer = self._get_layer(index)
635
- if self.gradient_checkpointing and self.training:
636
- layer_ret = torch.utils.checkpoint.checkpoint(
637
- layer,
638
- hidden_states,
639
- attention_mask,
640
- rotary_pos_emb,
641
- kv_caches[index],
642
- use_cache
643
- )
644
- else:
645
- layer_ret = layer(
646
- hidden_states,
647
- attention_mask,
648
- rotary_pos_emb,
649
- kv_cache=kv_caches[index],
650
- use_cache=use_cache
651
- )
652
- hidden_states, kv_cache = layer_ret
653
- if use_cache:
654
- presents = presents + (kv_cache,)
655
-
656
- if output_hidden_states:
657
- all_hidden_states = all_hidden_states + (hidden_states,)
658
-
659
- # Final layer norm.
660
- if self.post_layer_norm:
661
- hidden_states = self.final_layernorm(hidden_states)
662
-
663
- return hidden_states, presents, all_hidden_states, all_self_attentions
664
-
665
-
666
- class ChatGLMPreTrainedModel(PreTrainedModel):
667
- """
668
- An abstract class to handle weights initialization and
669
- a simple interface for downloading and loading pretrained models.
670
- """
671
-
672
- is_parallelizable = False
673
- supports_gradient_checkpointing = True
674
- config_class = ChatGLMConfig
675
- base_model_prefix = "transformer"
676
- _no_split_modules = ["GLMBlock"]
677
-
678
- def _init_weights(self, module: nn.Module):
679
- """Initialize the weights."""
680
- return
681
-
682
- def get_masks(self, input_ids, past_key_values, padding_mask=None):
683
- batch_size, seq_length = input_ids.shape
684
- full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
685
- full_attention_mask.tril_()
686
- past_length = 0
687
- if past_key_values:
688
- past_length = past_key_values[0][0].shape[0]
689
- if past_length:
690
- full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
691
- device=input_ids.device), full_attention_mask), dim=-1)
692
- if padding_mask is not None:
693
- full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
694
- if not past_length and padding_mask is not None:
695
- full_attention_mask -= padding_mask.unsqueeze(-1) - 1
696
- full_attention_mask = (full_attention_mask < 0.5).bool()
697
- full_attention_mask.unsqueeze_(1)
698
- return full_attention_mask
699
-
700
- def get_position_ids(self, input_ids, device):
701
- batch_size, seq_length = input_ids.shape
702
- position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
703
- return position_ids
704
-
705
- def _set_gradient_checkpointing(self, module, value=False):
706
- if isinstance(module, GLMTransformer):
707
- module.gradient_checkpointing = value
708
-
709
-
710
- class Embedding(torch.nn.Module):
711
- """Language model embeddings."""
712
-
713
- def __init__(self, config: ChatGLMConfig, device=None):
714
- super(Embedding, self).__init__()
715
-
716
- self.hidden_size = config.hidden_size
717
- # Word embeddings (parallel).
718
- self.word_embeddings = nn.Embedding(
719
- config.padded_vocab_size,
720
- self.hidden_size,
721
- dtype=config.torch_dtype,
722
- device=device
723
- )
724
- self.fp32_residual_connection = config.fp32_residual_connection
725
-
726
- def forward(self, input_ids):
727
- # Embeddings.
728
- words_embeddings = self.word_embeddings(input_ids)
729
- embeddings = words_embeddings
730
- # Data format change to avoid explicit transposes: [b s h] --> [s b h].
731
- embeddings = embeddings.transpose(0, 1).contiguous()
732
- # If the input flag for fp32 residual connection is set, convert for float.
733
- if self.fp32_residual_connection:
734
- embeddings = embeddings.float()
735
- return embeddings
736
-
737
-
738
- class ChatGLMModel(ChatGLMPreTrainedModel):
739
- def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
740
- super().__init__(config)
741
- if empty_init:
742
- init_method = skip_init
743
- else:
744
- init_method = default_init
745
- init_kwargs = {}
746
- if device is not None:
747
- init_kwargs["device"] = device
748
- self.embedding = init_method(Embedding, config, **init_kwargs)
749
- self.num_layers = config.num_layers
750
- self.multi_query_group_num = config.multi_query_group_num
751
- self.kv_channels = config.kv_channels
752
-
753
- # Rotary positional embeddings
754
- self.seq_length = config.seq_length
755
- rotary_dim = (
756
- config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
757
- )
758
-
759
- self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
760
- dtype=config.torch_dtype)
761
- self.encoder = init_method(GLMTransformer, config, **init_kwargs)
762
- self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
763
- dtype=config.torch_dtype, **init_kwargs)
764
- self.pre_seq_len = config.pre_seq_len
765
- self.prefix_projection = config.prefix_projection
766
- if self.pre_seq_len is not None:
767
- for param in self.parameters():
768
- param.requires_grad = False
769
- self.prefix_tokens = torch.arange(self.pre_seq_len).long()
770
- self.prefix_encoder = PrefixEncoder(config)
771
- self.dropout = torch.nn.Dropout(0.1)
772
-
773
- def get_input_embeddings(self):
774
- return self.embedding.word_embeddings
775
-
776
- def get_prompt(self, batch_size, device, dtype=torch.half):
777
- prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
778
- past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
779
- past_key_values = past_key_values.view(
780
- batch_size,
781
- self.pre_seq_len,
782
- self.num_layers * 2,
783
- self.multi_query_group_num,
784
- self.kv_channels
785
- )
786
- # seq_len, b, nh, hidden_size
787
- past_key_values = self.dropout(past_key_values)
788
- past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
789
- return past_key_values
790
-
791
- def forward(
792
- self,
793
- input_ids,
794
- position_ids: Optional[torch.Tensor] = None,
795
- attention_mask: Optional[torch.BoolTensor] = None,
796
- full_attention_mask: Optional[torch.BoolTensor] = None,
797
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
798
- inputs_embeds: Optional[torch.Tensor] = None,
799
- use_cache: Optional[bool] = None,
800
- output_hidden_states: Optional[bool] = None,
801
- return_dict: Optional[bool] = None,
802
- ):
803
- output_hidden_states = (
804
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
805
- )
806
- use_cache = use_cache if use_cache is not None else self.config.use_cache
807
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
808
-
809
- batch_size, seq_length = input_ids.shape
810
-
811
- if inputs_embeds is None:
812
- inputs_embeds = self.embedding(input_ids)
813
-
814
- if self.pre_seq_len is not None:
815
- if past_key_values is None:
816
- past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
817
- dtype=inputs_embeds.dtype)
818
- if attention_mask is not None:
819
- attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),
820
- attention_mask], dim=-1)
821
-
822
- if full_attention_mask is None:
823
- if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
824
- full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
825
-
826
- # Rotary positional embeddings
827
- rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
828
- if position_ids is not None:
829
- rotary_pos_emb = rotary_pos_emb[position_ids]
830
- else:
831
- rotary_pos_emb = rotary_pos_emb[None, :seq_length]
832
- rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
833
-
834
- # Run encoder.
835
- hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
836
- inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
837
- kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
838
- )
839
-
840
- if not return_dict:
841
- return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
842
-
843
- return BaseModelOutputWithPast(
844
- last_hidden_state=hidden_states,
845
- past_key_values=presents,
846
- hidden_states=all_hidden_states,
847
- attentions=all_self_attentions,
848
- )
849
-
850
- def quantize(self, weight_bit_width: int):
851
- from .quantization import quantize
852
- quantize(self.encoder, weight_bit_width)
853
- return self
854
-
855
-
856
- class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
857
- def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
858
- super().__init__(config)
859
-
860
- self.max_sequence_length = config.max_length
861
- self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
862
- self.config = config
863
- self.quantized = False
864
-
865
- if self.config.quantization_bit:
866
- self.quantize(self.config.quantization_bit, empty_init=True)
867
-
868
- def _update_model_kwargs_for_generation(
869
- self,
870
- outputs: ModelOutput,
871
- model_kwargs: Dict[str, Any],
872
- is_encoder_decoder: bool = False,
873
- standardize_cache_format: bool = False,
874
- ) -> Dict[str, Any]:
875
- # update past_key_values
876
- model_kwargs["past_key_values"] = self._extract_past_from_model_output(
877
- outputs, standardize_cache_format=standardize_cache_format
878
- )
879
-
880
- # update attention mask
881
- if "attention_mask" in model_kwargs:
882
- attention_mask = model_kwargs["attention_mask"]
883
- model_kwargs["attention_mask"] = torch.cat(
884
- [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
885
- )
886
-
887
- # update position ids
888
- if "position_ids" in model_kwargs:
889
- position_ids = model_kwargs["position_ids"]
890
- new_position_id = position_ids[..., -1:].clone()
891
- new_position_id += 1
892
- model_kwargs["position_ids"] = torch.cat(
893
- [position_ids, new_position_id], dim=-1
894
- )
895
-
896
- model_kwargs["is_first_forward"] = False
897
- return model_kwargs
898
-
899
- def prepare_inputs_for_generation(
900
- self,
901
- input_ids: torch.LongTensor,
902
- past_key_values: Optional[torch.Tensor] = None,
903
- attention_mask: Optional[torch.Tensor] = None,
904
- position_ids: Optional[torch.Tensor] = None,
905
- use_cache: Optional[bool] = None,
906
- is_first_forward: bool = True,
907
- **kwargs
908
- ) -> dict:
909
- # only last token for input_ids if past is not None
910
- if position_ids is None:
911
- position_ids = self.get_position_ids(input_ids, device=input_ids.device)
912
- if not is_first_forward:
913
- if past_key_values is not None:
914
- position_ids = position_ids[..., -1:]
915
- input_ids = input_ids[:, -1:]
916
- return {
917
- "input_ids": input_ids,
918
- "past_key_values": past_key_values,
919
- "position_ids": position_ids,
920
- "attention_mask": attention_mask,
921
- "return_last_logit": True,
922
- "use_cache": use_cache
923
- }
924
-
925
- def forward(
926
- self,
927
- input_ids: Optional[torch.Tensor] = None,
928
- position_ids: Optional[torch.Tensor] = None,
929
- attention_mask: Optional[torch.Tensor] = None,
930
- past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
931
- inputs_embeds: Optional[torch.Tensor] = None,
932
- labels: Optional[torch.Tensor] = None,
933
- use_cache: Optional[bool] = None,
934
- output_attentions: Optional[bool] = None,
935
- output_hidden_states: Optional[bool] = None,
936
- return_dict: Optional[bool] = None,
937
- return_last_logit: Optional[bool] = False,
938
- ):
939
- use_cache = use_cache if use_cache is not None else self.config.use_cache
940
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
941
-
942
- transformer_outputs = self.transformer(
943
- input_ids=input_ids,
944
- position_ids=position_ids,
945
- attention_mask=attention_mask,
946
- past_key_values=past_key_values,
947
- inputs_embeds=inputs_embeds,
948
- use_cache=use_cache,
949
- output_hidden_states=output_hidden_states,
950
- return_dict=return_dict,
951
- )
952
-
953
- hidden_states = transformer_outputs[0]
954
- if return_last_logit:
955
- hidden_states = hidden_states[-1:]
956
- lm_logits = self.transformer.output_layer(hidden_states)
957
- lm_logits = lm_logits.transpose(0, 1).contiguous()
958
-
959
- loss = None
960
- if labels is not None:
961
- lm_logits = lm_logits.to(torch.float32)
962
-
963
- # Shift so that tokens < n predict n
964
- shift_logits = lm_logits[..., :-1, :].contiguous()
965
- shift_labels = labels[..., 1:].contiguous()
966
- # Flatten the tokens
967
- loss_fct = CrossEntropyLoss(ignore_index=-100)
968
- loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
969
-
970
- lm_logits = lm_logits.to(hidden_states.dtype)
971
- loss = loss.to(hidden_states.dtype)
972
-
973
- if not return_dict:
974
- output = (lm_logits,) + transformer_outputs[1:]
975
- return ((loss,) + output) if loss is not None else output
976
-
977
- return CausalLMOutputWithPast(
978
- loss=loss,
979
- logits=lm_logits,
980
- past_key_values=transformer_outputs.past_key_values,
981
- hidden_states=transformer_outputs.hidden_states,
982
- attentions=transformer_outputs.attentions,
983
- )
984
-
985
- @staticmethod
986
- def _reorder_cache(
987
- past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
988
- ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
989
- """
990
- This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
991
- [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
992
- beam_idx at every generation step.
993
-
994
- Output shares the same memory storage as `past`.
995
- """
996
- return tuple(
997
- (
998
- layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
999
- layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
1000
- )
1001
- for layer_past in past
1002
- )
1003
-
1004
- def process_response(self, output, history):
1005
- content = ""
1006
- history = deepcopy(history)
1007
- for response in output.split("<|assistant|>"):
1008
- metadata, content = response.split("\n", maxsplit=1)
1009
- if not metadata.strip():
1010
- content = content.strip()
1011
- history.append({"role": "assistant", "metadata": metadata, "content": content})
1012
- content = content.replace("[[训练时间]]", "2023年")
1013
- else:
1014
- history.append({"role": "assistant", "metadata": metadata, "content": content})
1015
- if history[0]["role"] == "system" and "tools" in history[0]:
1016
- content = "\n".join(content.split("\n")[1:-1])
1017
- def tool_call(**kwargs):
1018
- return kwargs
1019
- parameters = eval(content)
1020
- content = {"name": metadata.strip(), "parameters": parameters}
1021
- else:
1022
- content = {"name": metadata.strip(), "content": content}
1023
- return content, history
1024
-
1025
- @torch.inference_mode()
1026
- def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
1027
- max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
1028
- **kwargs):
1029
- if history is None:
1030
- history = []
1031
- if logits_processor is None:
1032
- logits_processor = LogitsProcessorList()
1033
- logits_processor.append(InvalidScoreLogitsProcessor())
1034
- gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
1035
- "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1036
- inputs = tokenizer.build_chat_input(query, history=history, role=role)
1037
- inputs = inputs.to(self.device)
1038
- eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
1039
- tokenizer.get_command("<|observation|>")]
1040
- outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
1041
- outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
1042
- response = tokenizer.decode(outputs)
1043
- history.append({"role": role, "content": query})
1044
- response, history = self.process_response(response, history)
1045
- return response, history
1046
-
1047
- @torch.inference_mode()
1048
- def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
1049
- past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
1050
- logits_processor=None, return_past_key_values=False, **kwargs):
1051
- if history is None:
1052
- history = []
1053
- if logits_processor is None:
1054
- logits_processor = LogitsProcessorList()
1055
- logits_processor.append(InvalidScoreLogitsProcessor())
1056
- eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
1057
- tokenizer.get_command("<|observation|>")]
1058
- gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
1059
- "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1060
- if past_key_values is None:
1061
- inputs = tokenizer.build_chat_input(query, history=history, role=role)
1062
- else:
1063
- inputs = tokenizer.build_chat_input(query, role=role)
1064
- inputs = inputs.to(self.device)
1065
- if past_key_values is not None:
1066
- past_length = past_key_values[0][0].shape[0]
1067
- if self.transformer.pre_seq_len is not None:
1068
- past_length -= self.transformer.pre_seq_len
1069
- inputs.position_ids += past_length
1070
- attention_mask = inputs.attention_mask
1071
- attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
1072
- inputs['attention_mask'] = attention_mask
1073
- history.append({"role": role, "content": query})
1074
- for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
1075
- eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
1076
- **gen_kwargs):
1077
- if return_past_key_values:
1078
- outputs, past_key_values = outputs
1079
- outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
1080
- response = tokenizer.decode(outputs)
1081
- if response and response[-1] != "�":
1082
- response, new_history = self.process_response(response, history)
1083
- if return_past_key_values:
1084
- yield response, new_history, past_key_values
1085
- else:
1086
- yield response, new_history
1087
-
1088
- @torch.inference_mode()
1089
- def stream_generate(
1090
- self,
1091
- input_ids,
1092
- generation_config: Optional[GenerationConfig] = None,
1093
- logits_processor: Optional[LogitsProcessorList] = None,
1094
- stopping_criteria: Optional[StoppingCriteriaList] = None,
1095
- prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
1096
- return_past_key_values=False,
1097
- **kwargs,
1098
- ):
1099
- batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
1100
-
1101
- if generation_config is None:
1102
- generation_config = self.generation_config
1103
- generation_config = copy.deepcopy(generation_config)
1104
- model_kwargs = generation_config.update(**kwargs)
1105
- model_kwargs["use_cache"] = generation_config.use_cache
1106
- bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
1107
-
1108
- if isinstance(eos_token_id, int):
1109
- eos_token_id = [eos_token_id]
1110
- eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
1111
-
1112
- has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
1113
- if has_default_max_length and generation_config.max_new_tokens is None:
1114
- warnings.warn(
1115
- f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
1116
- "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
1117
- " recommend using `max_new_tokens` to control the maximum length of the generation.",
1118
- UserWarning,
1119
- )
1120
- elif generation_config.max_new_tokens is not None:
1121
- generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
1122
- if not has_default_max_length:
1123
- logger.warn(
1124
- f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
1125
- f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
1126
- "Please refer to the documentation for more information. "
1127
- "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
1128
- UserWarning,
1129
- )
1130
-
1131
- if input_ids_seq_length >= generation_config.max_length:
1132
- input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
1133
- logger.warning(
1134
- f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
1135
- f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
1136
- " increasing `max_new_tokens`."
1137
- )
1138
-
1139
- # 2. Set generation parameters if not already defined
1140
- logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
1141
- stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
1142
-
1143
- logits_processor = self._get_logits_processor(
1144
- generation_config=generation_config,
1145
- input_ids_seq_length=input_ids_seq_length,
1146
- encoder_input_ids=input_ids,
1147
- prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
1148
- logits_processor=logits_processor,
1149
- )
1150
-
1151
- stopping_criteria = self._get_stopping_criteria(
1152
- generation_config=generation_config, stopping_criteria=stopping_criteria
1153
- )
1154
- logits_warper = self._get_logits_warper(generation_config)
1155
-
1156
- unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
1157
- scores = None
1158
- while True:
1159
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1160
- # forward pass to get next token
1161
- outputs = self(
1162
- **model_inputs,
1163
- return_dict=True,
1164
- output_attentions=False,
1165
- output_hidden_states=False,
1166
- )
1167
-
1168
- next_token_logits = outputs.logits[:, -1, :]
1169
-
1170
- # pre-process distribution
1171
- next_token_scores = logits_processor(input_ids, next_token_logits)
1172
- next_token_scores = logits_warper(input_ids, next_token_scores)
1173
-
1174
- # sample
1175
- probs = nn.functional.softmax(next_token_scores, dim=-1)
1176
- if generation_config.do_sample:
1177
- next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
1178
- else:
1179
- next_tokens = torch.argmax(probs, dim=-1)
1180
- # update generated ids, model inputs, and length for next step
1181
- input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
1182
- model_kwargs = self._update_model_kwargs_for_generation(
1183
- outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
1184
- )
1185
- unfinished_sequences = unfinished_sequences.mul(
1186
- next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
1187
- )
1188
- if return_past_key_values:
1189
- yield input_ids, outputs.past_key_values
1190
- else:
1191
- yield input_ids
1192
- # stop when each sentence is finished, or if we exceed the maximum length
1193
- if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
1194
- break
1195
-
1196
- def quantize(self, bits: int, empty_init=False, device=None, **kwargs):
1197
- if bits == 0:
1198
- return
1199
-
1200
- from .quantization import quantize
1201
-
1202
- if self.quantized:
1203
- logger.info("Already quantized.")
1204
- return self
1205
-
1206
- self.quantized = True
1207
-
1208
- self.config.quantization_bit = bits
1209
-
1210
- self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device,
1211
- **kwargs)
1212
- return self
1213
-
1214
-
1215
- class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
1216
- def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
1217
- super().__init__(config)
1218
-
1219
- self.num_labels = config.num_labels
1220
- self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
1221
-
1222
- self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
1223
- if config.classifier_dropout is not None:
1224
- self.dropout = nn.Dropout(config.classifier_dropout)
1225
- else:
1226
- self.dropout = None
1227
- self.config = config
1228
-
1229
- if self.config.quantization_bit:
1230
- self.quantize(self.config.quantization_bit, empty_init=True)
1231
-
1232
- def forward(
1233
- self,
1234
- input_ids: Optional[torch.LongTensor] = None,
1235
- position_ids: Optional[torch.LongTensor] = None,
1236
- attention_mask: Optional[torch.Tensor] = None,
1237
- full_attention_mask: Optional[torch.Tensor] = None,
1238
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1239
- inputs_embeds: Optional[torch.LongTensor] = None,
1240
- labels: Optional[torch.LongTensor] = None,
1241
- use_cache: Optional[bool] = None,
1242
- output_hidden_states: Optional[bool] = None,
1243
- return_dict: Optional[bool] = None,
1244
- ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
1245
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1246
-
1247
- transformer_outputs = self.transformer(
1248
- input_ids=input_ids,
1249
- position_ids=position_ids,
1250
- attention_mask=attention_mask,
1251
- full_attention_mask=full_attention_mask,
1252
- past_key_values=past_key_values,
1253
- inputs_embeds=inputs_embeds,
1254
- use_cache=use_cache,
1255
- output_hidden_states=output_hidden_states,
1256
- return_dict=return_dict,
1257
- )
1258
-
1259
- hidden_states = transformer_outputs[0]
1260
- pooled_hidden_states = hidden_states[-1]
1261
- if self.dropout is not None:
1262
- pooled_hidden_states = self.dropout(pooled_hidden_states)
1263
- logits = self.classifier_head(pooled_hidden_states)
1264
-
1265
- loss = None
1266
- if labels is not None:
1267
- if self.config.problem_type is None:
1268
- if self.num_labels == 1:
1269
- self.config.problem_type = "regression"
1270
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1271
- self.config.problem_type = "single_label_classification"
1272
- else:
1273
- self.config.problem_type = "multi_label_classification"
1274
-
1275
- if self.config.problem_type == "regression":
1276
- loss_fct = MSELoss()
1277
- if self.num_labels == 1:
1278
- loss = loss_fct(logits.squeeze().float(), labels.squeeze())
1279
- else:
1280
- loss = loss_fct(logits.float(), labels)
1281
- elif self.config.problem_type == "single_label_classification":
1282
- loss_fct = CrossEntropyLoss()
1283
- loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1))
1284
- elif self.config.problem_type == "multi_label_classification":
1285
- loss_fct = BCEWithLogitsLoss()
1286
- loss = loss_fct(logits.float(), labels.view(-1, self.num_labels))
1287
-
1288
- if not return_dict:
1289
- output = (logits,) + transformer_outputs[1:]
1290
- return ((loss,) + output) if loss is not None else output
1291
-
1292
- return SequenceClassifierOutputWithPast(
1293
- loss=loss,
1294
- logits=logits,
1295
- past_key_values=transformer_outputs.past_key_values,
1296
- hidden_states=transformer_outputs.hidden_states,
1297
- attentions=transformer_outputs.attentions,
1298
- )
 
kolors/models/tokenization_chatglm.py DELETED
@@ -1,300 +0,0 @@
1
- import json
2
- import os
3
- import re
4
- from typing import List, Optional, Union, Dict
5
- from sentencepiece import SentencePieceProcessor
6
- from transformers import PreTrainedTokenizer
7
- from transformers.utils import logging, PaddingStrategy
8
- from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
9
-
10
-
11
- class SPTokenizer:
12
- def __init__(self, model_path: str):
13
- # reload tokenizer
14
- assert os.path.isfile(model_path), model_path
15
- self.sp_model = SentencePieceProcessor(model_file=model_path)
16
-
17
- # BOS / EOS token IDs
18
- self.n_words: int = self.sp_model.vocab_size()
19
- self.bos_id: int = self.sp_model.bos_id()
20
- self.eos_id: int = self.sp_model.eos_id()
21
- self.pad_id: int = self.sp_model.unk_id()
22
- assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
23
-
24
- role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
25
- special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
26
- self.special_tokens = {}
27
- self.index_special_tokens = {}
28
- for token in special_tokens:
29
- self.special_tokens[token] = self.n_words
30
- self.index_special_tokens[self.n_words] = token
31
- self.n_words += 1
32
- self.role_special_token_expression = "|".join([re.escape(token) for token in role_special_tokens])
33
-
34
- def tokenize(self, s: str, encode_special_tokens=False):
35
- if encode_special_tokens:
36
- last_index = 0
37
- t = []
38
- for match in re.finditer(self.role_special_token_expression, s):
39
- if last_index < match.start():
40
- t.extend(self.sp_model.EncodeAsPieces(s[last_index:match.start()]))
41
- t.append(s[match.start():match.end()])
42
- last_index = match.end()
43
- if last_index < len(s):
44
- t.extend(self.sp_model.EncodeAsPieces(s[last_index:]))
45
- return t
46
- else:
47
- return self.sp_model.EncodeAsPieces(s)
48
-
49
- def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
50
- assert type(s) is str
51
- t = self.sp_model.encode(s)
52
- if bos:
53
- t = [self.bos_id] + t
54
- if eos:
55
- t = t + [self.eos_id]
56
- return t
57
-
58
- def decode(self, t: List[int]) -> str:
59
- text, buffer = "", []
60
- for token in t:
61
- if token in self.index_special_tokens:
62
- if buffer:
63
- text += self.sp_model.decode(buffer)
64
- buffer = []
65
- text += self.index_special_tokens[token]
66
- else:
67
- buffer.append(token)
68
- if buffer:
69
- text += self.sp_model.decode(buffer)
70
- return text
71
-
72
- def decode_tokens(self, tokens: List[str]) -> str:
73
- text = self.sp_model.DecodePieces(tokens)
74
- return text
75
-
76
- def convert_token_to_id(self, token):
77
- """ Converts a token (str) in an id using the vocab. """
78
- if token in self.special_tokens:
79
- return self.special_tokens[token]
80
- return self.sp_model.PieceToId(token)
81
-
82
- def convert_id_to_token(self, index):
83
- """Converts an index (integer) in a token (str) using the vocab."""
84
- if index in self.index_special_tokens:
85
- return self.index_special_tokens[index]
86
- if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
87
- return ""
88
- return self.sp_model.IdToPiece(index)
89
-
90
-
91
- class ChatGLMTokenizer(PreTrainedTokenizer):
92
- vocab_files_names = {"vocab_file": "tokenizer.model"}
93
-
94
- model_input_names = ["input_ids", "attention_mask", "position_ids"]
95
-
96
- def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, encode_special_tokens=False,
97
- **kwargs):
98
- self.name = "GLMTokenizer"
99
-
100
- self.vocab_file = vocab_file
101
- self.tokenizer = SPTokenizer(vocab_file)
102
- self.special_tokens = {
103
- "<bos>": self.tokenizer.bos_id,
104
- "<eos>": self.tokenizer.eos_id,
105
- "<pad>": self.tokenizer.pad_id
106
- }
107
- self.encode_special_tokens = encode_special_tokens
108
- super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
109
- encode_special_tokens=encode_special_tokens,
110
- **kwargs)
111
-
112
- def get_command(self, token):
113
- if token in self.special_tokens:
114
- return self.special_tokens[token]
115
- assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
116
- return self.tokenizer.special_tokens[token]
117
-
118
- @property
119
- def unk_token(self) -> str:
120
- return "<unk>"
121
-
122
- @property
123
- def pad_token(self) -> str:
124
- return "<unk>"
125
-
126
- @property
127
- def pad_token_id(self):
128
- return self.get_command("<pad>")
129
-
130
- @property
131
- def eos_token(self) -> str:
132
- return "</s>"
133
-
134
- @property
135
- def eos_token_id(self):
136
- return self.get_command("<eos>")
137
-
138
- @property
139
- def vocab_size(self):
140
- return self.tokenizer.n_words
141
-
142
- def get_vocab(self):
143
- """ Returns vocab as a dict """
144
- vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
145
- vocab.update(self.added_tokens_encoder)
146
- return vocab
147
-
148
- def _tokenize(self, text, **kwargs):
149
- return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens)
150
-
151
- def _convert_token_to_id(self, token):
152
- """ Converts a token (str) in an id using the vocab. """
153
- return self.tokenizer.convert_token_to_id(token)
154
-
155
- def _convert_id_to_token(self, index):
156
- """Converts an index (integer) in a token (str) using the vocab."""
157
- return self.tokenizer.convert_id_to_token(index)
158
-
159
- def convert_tokens_to_string(self, tokens: List[str]) -> str:
160
- return self.tokenizer.decode_tokens(tokens)
161
-
162
- def save_vocabulary(self, save_directory, filename_prefix=None):
163
- """
164
- Save the vocabulary and special tokens file to a directory.
165
-
166
- Args:
167
- save_directory (`str`):
168
- The directory in which to save the vocabulary.
169
- filename_prefix (`str`, *optional*):
170
- An optional prefix to add to the names of the saved files.
171
-
172
- Returns:
173
- `Tuple(str)`: Paths to the files saved.
174
- """
175
- if os.path.isdir(save_directory):
176
- vocab_file = os.path.join(
177
- save_directory, self.vocab_files_names["vocab_file"]
178
- )
179
- else:
180
- vocab_file = save_directory
181
-
182
- with open(self.vocab_file, 'rb') as fin:
183
- proto_str = fin.read()
184
-
185
- with open(vocab_file, "wb") as writer:
186
- writer.write(proto_str)
187
-
188
- return (vocab_file,)
189
-
190
- def get_prefix_tokens(self):
191
- prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
192
- return prefix_tokens
193
-
194
- def build_single_message(self, role, metadata, message):
195
- assert role in ["system", "user", "assistant", "observation"], role
196
- role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
197
- message_tokens = self.tokenizer.encode(message)
198
- tokens = role_tokens + message_tokens
199
- return tokens
200
-
201
- def build_chat_input(self, query, history=None, role="user"):
202
- if history is None:
203
- history = []
204
- input_ids = []
205
- for item in history:
206
- content = item["content"]
207
- if item["role"] == "system" and "tools" in item:
208
- content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
209
- input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
210
- input_ids.extend(self.build_single_message(role, "", query))
211
- input_ids.extend([self.get_command("<|assistant|>")])
212
- return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
213
-
214
- def build_inputs_with_special_tokens(
215
- self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
216
- ) -> List[int]:
217
- """
218
- Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
219
- adding special tokens. A BERT sequence has the following format:
220
-
221
- - single sequence: `[CLS] X [SEP]`
222
- - pair of sequences: `[CLS] A [SEP] B [SEP]`
223
-
224
- Args:
225
- token_ids_0 (`List[int]`):
226
- List of IDs to which the special tokens will be added.
227
- token_ids_1 (`List[int]`, *optional*):
228
- Optional second list of IDs for sequence pairs.
229
-
230
- Returns:
231
- `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
232
- """
233
- prefix_tokens = self.get_prefix_tokens()
234
- token_ids_0 = prefix_tokens + token_ids_0
235
- if token_ids_1 is not None:
236
- token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
237
- return token_ids_0
238
-
239
- def _pad(
240
- self,
241
- encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
242
- max_length: Optional[int] = None,
243
- padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
244
- pad_to_multiple_of: Optional[int] = None,
245
- return_attention_mask: Optional[bool] = None,
246
- ) -> dict:
247
- """
248
- Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
249
-
250
- Args:
251
- encoded_inputs:
252
- Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
253
- max_length: maximum length of the returned list and optionally padding length (see below).
254
- Will truncate by taking into account the special tokens.
255
- padding_strategy: PaddingStrategy to use for padding.
256
-
257
- - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
258
- - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
259
- - PaddingStrategy.DO_NOT_PAD: Do not pad
260
- The tokenizer padding sides are defined in self.padding_side:
261
-
262
- - 'left': pads on the left of the sequences
263
- - 'right': pads on the right of the sequences
264
- pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
265
- This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
266
- `>= 7.5` (Volta).
267
- return_attention_mask:
268
- (optional) Set to False to avoid returning attention mask (default: set to model specifics)
269
- """
270
- # Load from model defaults
271
- assert self.padding_side == "left"
272
-
273
- required_input = encoded_inputs[self.model_input_names[0]]
274
- seq_length = len(required_input)
275
-
276
- if padding_strategy == PaddingStrategy.LONGEST:
277
- max_length = len(required_input)
278
-
279
- if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
280
- max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
281
-
282
- needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
283
-
284
- # Initialize attention mask if not present.
285
- if "attention_mask" not in encoded_inputs:
286
- encoded_inputs["attention_mask"] = [1] * seq_length
287
-
288
- if "position_ids" not in encoded_inputs:
289
- encoded_inputs["position_ids"] = list(range(seq_length))
290
-
291
- if needs_to_be_padded:
292
- difference = max_length - len(required_input)
293
-
294
- if "attention_mask" in encoded_inputs:
295
- encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
296
- if "position_ids" in encoded_inputs:
297
- encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
298
- encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
299
-
300
- return encoded_inputs
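The removed `_pad` override pads on the left and extends `attention_mask` and `position_ids` alongside `input_ids`. A minimal standalone sketch of that behaviour, using toy token IDs and a placeholder `pad_token_id` of 0 rather than the real ChatGLM vocabulary:

# Sketch of the left-padding performed by the deleted ChatGLMTokenizer._pad.
def left_pad(encoded: dict, max_length: int, pad_token_id: int = 0) -> dict:
    input_ids = encoded["input_ids"]
    seq_length = len(input_ids)
    attention_mask = encoded.get("attention_mask", [1] * seq_length)
    position_ids = encoded.get("position_ids", list(range(seq_length)))
    difference = max(0, max_length - seq_length)  # no-op if already long enough
    return {
        "input_ids": [pad_token_id] * difference + input_ids,
        "attention_mask": [0] * difference + attention_mask,
        "position_ids": [0] * difference + position_ids,
    }

print(left_pad({"input_ids": [11, 22, 33]}, max_length=5))
# {'input_ids': [0, 0, 11, 22, 33], 'attention_mask': [0, 0, 1, 1, 1], 'position_ids': [0, 0, 0, 1, 2]}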
 
kolors/models/unet_2d_condition.py DELETED
@@ -1,1318 +0,0 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataclasses import dataclass
15
- from typing import Any, Dict, List, Optional, Tuple, Union
16
-
17
- import torch
18
- import torch.nn as nn
19
- import torch.utils.checkpoint
20
-
21
- from diffusers.configuration_utils import ConfigMixin, register_to_config
22
- from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
23
- from diffusers.loaders.single_file_model import FromOriginalModelMixin
24
- from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
25
- from diffusers.models.activations import get_activation
26
- from diffusers.models.attention_processor import (
27
- ADDED_KV_ATTENTION_PROCESSORS,
28
- CROSS_ATTENTION_PROCESSORS,
29
- Attention,
30
- AttentionProcessor,
31
- AttnAddedKVProcessor,
32
- AttnProcessor,
33
- )
34
- from diffusers.models.embeddings import (
35
- GaussianFourierProjection,
36
- GLIGENTextBoundingboxProjection,
37
- ImageHintTimeEmbedding,
38
- ImageProjection,
39
- ImageTimeEmbedding,
40
- TextImageProjection,
41
- TextImageTimeEmbedding,
42
- TextTimeEmbedding,
43
- TimestepEmbedding,
44
- Timesteps,
45
- )
46
- from diffusers.models.modeling_utils import ModelMixin
47
-
48
- try:
49
- from diffusers.models.unet_2d_blocks import (
50
- get_down_block,
51
- get_mid_block,
52
- get_up_block,
53
- )
54
- except ImportError:
55
- from diffusers.models.unets.unet_2d_blocks import (
56
- get_down_block,
57
- get_mid_block,
58
- get_up_block,
59
- )
60
-
61
-
62
-
63
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
64
-
65
-
66
- @dataclass
67
- class UNet2DConditionOutput(BaseOutput):
68
- """
69
- The output of [`UNet2DConditionModel`].
70
-
71
- Args:
72
- sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
73
- The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
74
- """
75
-
76
- sample: torch.Tensor = None
77
-
78
-
79
- class UNet2DConditionModel(
80
- ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin
81
- ):
82
- r"""
83
- A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
84
- shaped output.
85
-
86
- This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
87
- for all models (such as downloading or saving).
88
-
89
- Parameters:
90
- sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
91
- Height and width of input/output sample.
92
- in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
93
- out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
94
- center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
95
- flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
96
- Whether to flip the sin to cos in the time embedding.
97
- freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
98
- down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
99
- The tuple of downsample blocks to use.
100
- mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
101
- Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
102
- `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
103
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
104
- The tuple of upsample blocks to use.
105
- only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
106
- Whether to include self-attention in the basic transformer blocks, see
107
- [`~models.attention.BasicTransformerBlock`].
108
- block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
109
- The tuple of output channels for each block.
110
- layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
111
- downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
112
- mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
113
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
114
- act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
115
- norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
116
- If `None`, normalization and activation layers are skipped in post-processing.
117
- norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
118
- cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
119
- The dimension of the cross attention features.
120
- transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
121
- The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
122
- [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`],
123
- [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
124
- reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
125
- The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
126
- blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
127
- [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`],
128
- [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
129
- encoder_hid_dim (`int`, *optional*, defaults to None):
130
- If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
131
- dimension to `cross_attention_dim`.
132
- encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
133
- If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
134
- embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
135
- attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
136
- num_attention_heads (`int`, *optional*):
137
- The number of attention heads. If not defined, defaults to `attention_head_dim`
138
- resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
139
- for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
140
- class_embed_type (`str`, *optional*, defaults to `None`):
141
- The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
142
- `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
143
- addition_embed_type (`str`, *optional*, defaults to `None`):
144
- Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
145
- "text". "text" will use the `TextTimeEmbedding` layer.
146
- addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
147
- Dimension for the timestep embeddings.
148
- num_class_embeds (`int`, *optional*, defaults to `None`):
149
- Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
150
- class conditioning with `class_embed_type` equal to `None`.
151
- time_embedding_type (`str`, *optional*, defaults to `positional`):
152
- The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
153
- time_embedding_dim (`int`, *optional*, defaults to `None`):
154
- An optional override for the dimension of the projected time embedding.
155
- time_embedding_act_fn (`str`, *optional*, defaults to `None`):
156
- Optional activation function to use only once on the time embeddings before they are passed to the rest of
157
- the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
158
- timestep_post_act (`str`, *optional*, defaults to `None`):
159
- The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
160
- time_cond_proj_dim (`int`, *optional*, defaults to `None`):
161
- The dimension of `cond_proj` layer in the timestep embedding.
162
- conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_in` layer.
163
- conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_out` layer.
164
- projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
165
- `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
166
- class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
167
- embeddings with the class embeddings.
168
- mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
169
- Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
170
- `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
171
- `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
172
- otherwise.
173
- """
174
-
175
- _supports_gradient_checkpointing = True
176
- _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
177
-
178
- @register_to_config
179
- def __init__(
180
- self,
181
- sample_size: Optional[int] = None,
182
- in_channels: int = 4,
183
- out_channels: int = 4,
184
- center_input_sample: bool = False,
185
- flip_sin_to_cos: bool = True,
186
- freq_shift: int = 0,
187
- down_block_types: Tuple[str] = (
188
- "CrossAttnDownBlock2D",
189
- "CrossAttnDownBlock2D",
190
- "CrossAttnDownBlock2D",
191
- "DownBlock2D",
192
- ),
193
- mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
194
- up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
195
- only_cross_attention: Union[bool, Tuple[bool]] = False,
196
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
197
- layers_per_block: Union[int, Tuple[int]] = 2,
198
- downsample_padding: int = 1,
199
- mid_block_scale_factor: float = 1,
200
- dropout: float = 0.0,
201
- act_fn: str = "silu",
202
- norm_num_groups: Optional[int] = 32,
203
- norm_eps: float = 1e-5,
204
- cross_attention_dim: Union[int, Tuple[int]] = 1280,
205
- transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
206
- reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
207
- encoder_hid_dim: Optional[int] = None,
208
- encoder_hid_dim_type: Optional[str] = None,
209
- attention_head_dim: Union[int, Tuple[int]] = 8,
210
- num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
211
- dual_cross_attention: bool = False,
212
- use_linear_projection: bool = False,
213
- class_embed_type: Optional[str] = None,
214
- addition_embed_type: Optional[str] = None,
215
- addition_time_embed_dim: Optional[int] = None,
216
- num_class_embeds: Optional[int] = None,
217
- upcast_attention: bool = False,
218
- resnet_time_scale_shift: str = "default",
219
- resnet_skip_time_act: bool = False,
220
- resnet_out_scale_factor: float = 1.0,
221
- time_embedding_type: str = "positional",
222
- time_embedding_dim: Optional[int] = None,
223
- time_embedding_act_fn: Optional[str] = None,
224
- timestep_post_act: Optional[str] = None,
225
- time_cond_proj_dim: Optional[int] = None,
226
- conv_in_kernel: int = 3,
227
- conv_out_kernel: int = 3,
228
- projection_class_embeddings_input_dim: Optional[int] = None,
229
- attention_type: str = "default",
230
- class_embeddings_concat: bool = False,
231
- mid_block_only_cross_attention: Optional[bool] = None,
232
- cross_attention_norm: Optional[str] = None,
233
- addition_embed_type_num_heads: int = 64,
234
- ):
235
- super().__init__()
236
-
237
- self.sample_size = sample_size
238
-
239
- if num_attention_heads is not None:
240
- raise ValueError(
241
- "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
242
- )
243
-
244
- # If `num_attention_heads` is not defined (which is the case for most models)
245
- # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
246
- # The reason for this behavior is to correct for incorrectly named variables that were introduced
247
- # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
248
- # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
249
- # which is why we correct for the naming here.
250
- num_attention_heads = num_attention_heads or attention_head_dim
251
-
252
- # Check inputs
253
- self._check_config(
254
- down_block_types=down_block_types,
255
- up_block_types=up_block_types,
256
- only_cross_attention=only_cross_attention,
257
- block_out_channels=block_out_channels,
258
- layers_per_block=layers_per_block,
259
- cross_attention_dim=cross_attention_dim,
260
- transformer_layers_per_block=transformer_layers_per_block,
261
- reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
262
- attention_head_dim=attention_head_dim,
263
- num_attention_heads=num_attention_heads,
264
- )
265
-
266
- # input
267
- conv_in_padding = (conv_in_kernel - 1) // 2
268
- self.conv_in = nn.Conv2d(
269
- in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
270
- )
271
-
272
- # time
273
- time_embed_dim, timestep_input_dim = self._set_time_proj(
274
- time_embedding_type,
275
- block_out_channels=block_out_channels,
276
- flip_sin_to_cos=flip_sin_to_cos,
277
- freq_shift=freq_shift,
278
- time_embedding_dim=time_embedding_dim,
279
- )
280
-
281
- self.time_embedding = TimestepEmbedding(
282
- timestep_input_dim,
283
- time_embed_dim,
284
- act_fn=act_fn,
285
- post_act_fn=timestep_post_act,
286
- cond_proj_dim=time_cond_proj_dim,
287
- )
288
-
289
- self._set_encoder_hid_proj(
290
- encoder_hid_dim_type,
291
- cross_attention_dim=cross_attention_dim,
292
- encoder_hid_dim=encoder_hid_dim,
293
- )
294
-
295
- # class embedding
296
- self._set_class_embedding(
297
- class_embed_type,
298
- act_fn=act_fn,
299
- num_class_embeds=num_class_embeds,
300
- projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
301
- time_embed_dim=time_embed_dim,
302
- timestep_input_dim=timestep_input_dim,
303
- )
304
-
305
- self._set_add_embedding(
306
- addition_embed_type,
307
- addition_embed_type_num_heads=addition_embed_type_num_heads,
308
- addition_time_embed_dim=addition_time_embed_dim,
309
- cross_attention_dim=cross_attention_dim,
310
- encoder_hid_dim=encoder_hid_dim,
311
- flip_sin_to_cos=flip_sin_to_cos,
312
- freq_shift=freq_shift,
313
- projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
314
- time_embed_dim=time_embed_dim,
315
- )
316
-
317
- if time_embedding_act_fn is None:
318
- self.time_embed_act = None
319
- else:
320
- self.time_embed_act = get_activation(time_embedding_act_fn)
321
-
322
- self.down_blocks = nn.ModuleList([])
323
- self.up_blocks = nn.ModuleList([])
324
-
325
- if isinstance(only_cross_attention, bool):
326
- if mid_block_only_cross_attention is None:
327
- mid_block_only_cross_attention = only_cross_attention
328
-
329
- only_cross_attention = [only_cross_attention] * len(down_block_types)
330
-
331
- if mid_block_only_cross_attention is None:
332
- mid_block_only_cross_attention = False
333
-
334
- if isinstance(num_attention_heads, int):
335
- num_attention_heads = (num_attention_heads,) * len(down_block_types)
336
-
337
- if isinstance(attention_head_dim, int):
338
- attention_head_dim = (attention_head_dim,) * len(down_block_types)
339
-
340
- if isinstance(cross_attention_dim, int):
341
- cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
342
-
343
- if isinstance(layers_per_block, int):
344
- layers_per_block = [layers_per_block] * len(down_block_types)
345
-
346
- if isinstance(transformer_layers_per_block, int):
347
- transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
348
-
349
- if class_embeddings_concat:
350
- # The time embeddings are concatenated with the class embeddings. The dimension of the
351
- # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
352
- # regular time embeddings
353
- blocks_time_embed_dim = time_embed_dim * 2
354
- else:
355
- blocks_time_embed_dim = time_embed_dim
356
-
357
- # down
358
- output_channel = block_out_channels[0]
359
- for i, down_block_type in enumerate(down_block_types):
360
- input_channel = output_channel
361
- output_channel = block_out_channels[i]
362
- is_final_block = i == len(block_out_channels) - 1
363
-
364
- down_block = get_down_block(
365
- down_block_type,
366
- num_layers=layers_per_block[i],
367
- transformer_layers_per_block=transformer_layers_per_block[i],
368
- in_channels=input_channel,
369
- out_channels=output_channel,
370
- temb_channels=blocks_time_embed_dim,
371
- add_downsample=not is_final_block,
372
- resnet_eps=norm_eps,
373
- resnet_act_fn=act_fn,
374
- resnet_groups=norm_num_groups,
375
- cross_attention_dim=cross_attention_dim[i],
376
- num_attention_heads=num_attention_heads[i],
377
- downsample_padding=downsample_padding,
378
- dual_cross_attention=dual_cross_attention,
379
- use_linear_projection=use_linear_projection,
380
- only_cross_attention=only_cross_attention[i],
381
- upcast_attention=upcast_attention,
382
- resnet_time_scale_shift=resnet_time_scale_shift,
383
- attention_type=attention_type,
384
- resnet_skip_time_act=resnet_skip_time_act,
385
- resnet_out_scale_factor=resnet_out_scale_factor,
386
- cross_attention_norm=cross_attention_norm,
387
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
388
- dropout=dropout,
389
- )
390
- self.down_blocks.append(down_block)
391
-
392
- # mid
393
- self.mid_block = get_mid_block(
394
- mid_block_type,
395
- temb_channels=blocks_time_embed_dim,
396
- in_channels=block_out_channels[-1],
397
- resnet_eps=norm_eps,
398
- resnet_act_fn=act_fn,
399
- resnet_groups=norm_num_groups,
400
- output_scale_factor=mid_block_scale_factor,
401
- transformer_layers_per_block=transformer_layers_per_block[-1],
402
- num_attention_heads=num_attention_heads[-1],
403
- cross_attention_dim=cross_attention_dim[-1],
404
- dual_cross_attention=dual_cross_attention,
405
- use_linear_projection=use_linear_projection,
406
- mid_block_only_cross_attention=mid_block_only_cross_attention,
407
- upcast_attention=upcast_attention,
408
- resnet_time_scale_shift=resnet_time_scale_shift,
409
- attention_type=attention_type,
410
- resnet_skip_time_act=resnet_skip_time_act,
411
- cross_attention_norm=cross_attention_norm,
412
- attention_head_dim=attention_head_dim[-1],
413
- dropout=dropout,
414
- )
415
-
416
- # count how many layers upsample the images
417
- self.num_upsamplers = 0
418
-
419
- # up
420
- reversed_block_out_channels = list(reversed(block_out_channels))
421
- reversed_num_attention_heads = list(reversed(num_attention_heads))
422
- reversed_layers_per_block = list(reversed(layers_per_block))
423
- reversed_cross_attention_dim = list(reversed(cross_attention_dim))
424
- reversed_transformer_layers_per_block = (
425
- list(reversed(transformer_layers_per_block))
426
- if reverse_transformer_layers_per_block is None
427
- else reverse_transformer_layers_per_block
428
- )
429
- only_cross_attention = list(reversed(only_cross_attention))
430
-
431
- output_channel = reversed_block_out_channels[0]
432
- for i, up_block_type in enumerate(up_block_types):
433
- is_final_block = i == len(block_out_channels) - 1
434
-
435
- prev_output_channel = output_channel
436
- output_channel = reversed_block_out_channels[i]
437
- input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
438
-
439
- # add upsample block for all BUT final layer
440
- if not is_final_block:
441
- add_upsample = True
442
- self.num_upsamplers += 1
443
- else:
444
- add_upsample = False
445
-
446
- up_block = get_up_block(
447
- up_block_type,
448
- num_layers=reversed_layers_per_block[i] + 1,
449
- transformer_layers_per_block=reversed_transformer_layers_per_block[i],
450
- in_channels=input_channel,
451
- out_channels=output_channel,
452
- prev_output_channel=prev_output_channel,
453
- temb_channels=blocks_time_embed_dim,
454
- add_upsample=add_upsample,
455
- resnet_eps=norm_eps,
456
- resnet_act_fn=act_fn,
457
- resolution_idx=i,
458
- resnet_groups=norm_num_groups,
459
- cross_attention_dim=reversed_cross_attention_dim[i],
460
- num_attention_heads=reversed_num_attention_heads[i],
461
- dual_cross_attention=dual_cross_attention,
462
- use_linear_projection=use_linear_projection,
463
- only_cross_attention=only_cross_attention[i],
464
- upcast_attention=upcast_attention,
465
- resnet_time_scale_shift=resnet_time_scale_shift,
466
- attention_type=attention_type,
467
- resnet_skip_time_act=resnet_skip_time_act,
468
- resnet_out_scale_factor=resnet_out_scale_factor,
469
- cross_attention_norm=cross_attention_norm,
470
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
471
- dropout=dropout,
472
- )
473
- self.up_blocks.append(up_block)
474
- prev_output_channel = output_channel
475
-
476
- # out
477
- if norm_num_groups is not None:
478
- self.conv_norm_out = nn.GroupNorm(
479
- num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
480
- )
481
-
482
- self.conv_act = get_activation(act_fn)
483
-
484
- else:
485
- self.conv_norm_out = None
486
- self.conv_act = None
487
-
488
- conv_out_padding = (conv_out_kernel - 1) // 2
489
- self.conv_out = nn.Conv2d(
490
- block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
491
- )
492
-
493
- self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
494
-
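# Illustrative sketch (hypothetical toy configuration, not a Kolors checkpoint): since
# passing `num_attention_heads` raises a ValueError above, head counts go through
# `attention_head_dim`. A small two-block instance of the class defined here:
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    block_out_channels=(64, 128),
    cross_attention_dim=128,
    attention_head_dim=8,
)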
495
- def _check_config(
496
- self,
497
- down_block_types: Tuple[str],
498
- up_block_types: Tuple[str],
499
- only_cross_attention: Union[bool, Tuple[bool]],
500
- block_out_channels: Tuple[int],
501
- layers_per_block: Union[int, Tuple[int]],
502
- cross_attention_dim: Union[int, Tuple[int]],
503
- transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
504
- reverse_transformer_layers_per_block: bool,
505
- attention_head_dim: int,
506
- num_attention_heads: Optional[Union[int, Tuple[int]]],
507
- ):
508
- if len(down_block_types) != len(up_block_types):
509
- raise ValueError(
510
- f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
511
- )
512
-
513
- if len(block_out_channels) != len(down_block_types):
514
- raise ValueError(
515
- f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
516
- )
517
-
518
- if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
519
- raise ValueError(
520
- f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
521
- )
522
-
523
- if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
524
- raise ValueError(
525
- f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
526
- )
527
-
528
- if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
529
- raise ValueError(
530
- f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
531
- )
532
-
533
- if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
534
- raise ValueError(
535
- f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
536
- )
537
-
538
- if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
539
- raise ValueError(
540
- f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
541
- )
542
- if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
543
- for layer_number_per_block in transformer_layers_per_block:
544
- if isinstance(layer_number_per_block, list):
545
- raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
546
-
547
- def _set_time_proj(
548
- self,
549
- time_embedding_type: str,
550
- block_out_channels: int,
551
- flip_sin_to_cos: bool,
552
- freq_shift: float,
553
- time_embedding_dim: int,
554
- ) -> Tuple[int, int]:
555
- if time_embedding_type == "fourier":
556
- time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
557
- if time_embed_dim % 2 != 0:
558
- raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
559
- self.time_proj = GaussianFourierProjection(
560
- time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
561
- )
562
- timestep_input_dim = time_embed_dim
563
- elif time_embedding_type == "positional":
564
- time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
565
-
566
- self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
567
- timestep_input_dim = block_out_channels[0]
568
- else:
569
- raise ValueError(
570
- f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
571
- )
572
-
573
- return time_embed_dim, timestep_input_dim
574
-
575
- def _set_encoder_hid_proj(
576
- self,
577
- encoder_hid_dim_type: Optional[str],
578
- cross_attention_dim: Union[int, Tuple[int]],
579
- encoder_hid_dim: Optional[int],
580
- ):
581
- if encoder_hid_dim_type is None and encoder_hid_dim is not None:
582
- encoder_hid_dim_type = "text_proj"
583
- self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
584
- logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
585
-
586
- if encoder_hid_dim is None and encoder_hid_dim_type is not None:
587
- raise ValueError(
588
- f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
589
- )
590
-
591
- if encoder_hid_dim_type == "text_proj":
592
- self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
593
- elif encoder_hid_dim_type == "text_image_proj":
594
- # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
595
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
596
- # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
597
- self.encoder_hid_proj = TextImageProjection(
598
- text_embed_dim=encoder_hid_dim,
599
- image_embed_dim=cross_attention_dim,
600
- cross_attention_dim=cross_attention_dim,
601
- )
602
- elif encoder_hid_dim_type == "image_proj":
603
- # Kandinsky 2.2
604
- self.encoder_hid_proj = ImageProjection(
605
- image_embed_dim=encoder_hid_dim,
606
- cross_attention_dim=cross_attention_dim,
607
- )
608
- elif encoder_hid_dim_type is not None:
609
- raise ValueError(
610
- f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
611
- )
612
- else:
613
- self.encoder_hid_proj = None
614
-
615
- def _set_class_embedding(
616
- self,
617
- class_embed_type: Optional[str],
618
- act_fn: str,
619
- num_class_embeds: Optional[int],
620
- projection_class_embeddings_input_dim: Optional[int],
621
- time_embed_dim: int,
622
- timestep_input_dim: int,
623
- ):
624
- if class_embed_type is None and num_class_embeds is not None:
625
- self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
626
- elif class_embed_type == "timestep":
627
- self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
628
- elif class_embed_type == "identity":
629
- self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
630
- elif class_embed_type == "projection":
631
- if projection_class_embeddings_input_dim is None:
632
- raise ValueError(
633
- "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
634
- )
635
- # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
636
- # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
637
- # 2. it projects from an arbitrary input dimension.
638
- #
639
- # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
640
- # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
641
- # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
642
- self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
643
- elif class_embed_type == "simple_projection":
644
- if projection_class_embeddings_input_dim is None:
645
- raise ValueError(
646
- "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
647
- )
648
- self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
649
- else:
650
- self.class_embedding = None
651
-
652
- def _set_add_embedding(
653
- self,
654
- addition_embed_type: str,
655
- addition_embed_type_num_heads: int,
656
- addition_time_embed_dim: Optional[int],
657
- flip_sin_to_cos: bool,
658
- freq_shift: float,
659
- cross_attention_dim: Optional[int],
660
- encoder_hid_dim: Optional[int],
661
- projection_class_embeddings_input_dim: Optional[int],
662
- time_embed_dim: int,
663
- ):
664
- if addition_embed_type == "text":
665
- if encoder_hid_dim is not None:
666
- text_time_embedding_from_dim = encoder_hid_dim
667
- else:
668
- text_time_embedding_from_dim = cross_attention_dim
669
-
670
- self.add_embedding = TextTimeEmbedding(
671
- text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
672
- )
673
- elif addition_embed_type == "text_image":
674
- # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
675
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
676
- # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
677
- self.add_embedding = TextImageTimeEmbedding(
678
- text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
679
- )
680
- elif addition_embed_type == "text_time":
681
- self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
682
- self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
683
- elif addition_embed_type == "image":
684
- # Kandinsky 2.2
685
- self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
686
- elif addition_embed_type == "image_hint":
687
- # Kandinsky 2.2 ControlNet
688
- self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
689
- elif addition_embed_type is not None:
690
- raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
691
-
692
- def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
693
- if attention_type in ["gated", "gated-text-image"]:
694
- positive_len = 768
695
- if isinstance(cross_attention_dim, int):
696
- positive_len = cross_attention_dim
697
- elif isinstance(cross_attention_dim, (list, tuple)):
698
- positive_len = cross_attention_dim[0]
699
-
700
- feature_type = "text-only" if attention_type == "gated" else "text-image"
701
- self.position_net = GLIGENTextBoundingboxProjection(
702
- positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
703
- )
704
-
705
- @property
706
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
707
- r"""
708
- Returns:
709
- `dict` of attention processors: A dictionary containing all attention processors used in the model,
710
- indexed by their weight names.
711
- """
712
- # set recursively
713
- processors = {}
714
-
715
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
716
- if hasattr(module, "get_processor"):
717
- processors[f"{name}.processor"] = module.get_processor()
718
-
719
- for sub_name, child in module.named_children():
720
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
721
-
722
- return processors
723
-
724
- for name, module in self.named_children():
725
- fn_recursive_add_processors(name, module, processors)
726
-
727
- return processors
728
-
729
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
730
- r"""
731
- Sets the attention processor to use to compute attention.
732
-
733
- Parameters:
734
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
735
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
736
- for **all** `Attention` layers.
737
-
738
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
739
- processor. This is strongly recommended when setting trainable attention processors.
740
-
741
- """
742
- count = len(self.attn_processors.keys())
743
-
744
- if isinstance(processor, dict) and len(processor) != count:
745
- raise ValueError(
746
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
747
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
748
- )
749
-
750
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
751
- if hasattr(module, "set_processor"):
752
- if not isinstance(processor, dict):
753
- module.set_processor(processor)
754
- else:
755
- module.set_processor(processor.pop(f"{name}.processor"))
756
-
757
- for sub_name, child in module.named_children():
758
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
759
-
760
- for name, module in self.named_children():
761
- fn_recursive_attn_processor(name, module, processor)
762
-
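# Illustrative usage sketch, continuing the hypothetical toy `unet` from the sketch above:
# either one processor shared by every attention layer, or a dict keyed by the
# "<module path>.processor" names returned by `attn_processors`.
unet.set_attn_processor(AttnProcessor())
unet.set_attn_processor({name: AttnProcessor() for name in unet.attn_processors.keys()})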
763
- def set_default_attn_processor(self):
764
- """
765
- Disables custom attention processors and sets the default attention implementation.
766
- """
767
- if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
768
- processor = AttnAddedKVProcessor()
769
- elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
770
- processor = AttnProcessor()
771
- else:
772
- raise ValueError(
773
- f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
774
- )
775
-
776
- self.set_attn_processor(processor)
777
-
778
- def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
779
- r"""
780
- Enable sliced attention computation.
781
-
782
- When this option is enabled, the attention module splits the input tensor in slices to compute attention in
783
- several steps. This is useful for saving some memory in exchange for a small decrease in speed.
784
-
785
- Args:
786
- slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
787
- When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
788
- `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
789
- provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
790
- must be a multiple of `slice_size`.
791
- """
792
- sliceable_head_dims = []
793
-
794
- def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
795
- if hasattr(module, "set_attention_slice"):
796
- sliceable_head_dims.append(module.sliceable_head_dim)
797
-
798
- for child in module.children():
799
- fn_recursive_retrieve_sliceable_dims(child)
800
-
801
- # retrieve number of attention layers
802
- for module in self.children():
803
- fn_recursive_retrieve_sliceable_dims(module)
804
-
805
- num_sliceable_layers = len(sliceable_head_dims)
806
-
807
- if slice_size == "auto":
808
- # half the attention head size is usually a good trade-off between
809
- # speed and memory
810
- slice_size = [dim // 2 for dim in sliceable_head_dims]
811
- elif slice_size == "max":
812
- # make smallest slice possible
813
- slice_size = num_sliceable_layers * [1]
814
-
815
- slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
816
-
817
- if len(slice_size) != len(sliceable_head_dims):
818
- raise ValueError(
819
- f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
820
- f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
821
- )
822
-
823
- for i in range(len(slice_size)):
824
- size = slice_size[i]
825
- dim = sliceable_head_dims[i]
826
- if size is not None and size > dim:
827
- raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
828
-
829
- # Recursively walk through all the children.
830
- # Any children which exposes the set_attention_slice method
831
- # gets the message
832
- def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
833
- if hasattr(module, "set_attention_slice"):
834
- module.set_attention_slice(slice_size.pop())
835
-
836
- for child in module.children():
837
- fn_recursive_set_attention_slice(child, slice_size)
838
-
839
- reversed_slice_size = list(reversed(slice_size))
840
- for module in self.children():
841
- fn_recursive_set_attention_slice(module, reversed_slice_size)
842
-
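# Illustrative usage sketch on the toy `unet` above: "auto" halves each sliceable head
# dimension, "max" keeps one slice per step, and an integer N yields
# attention_head_dim // N slices per layer.
unet.set_attention_slice("auto")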
843
- def _set_gradient_checkpointing(self, module, value=False):
844
- if hasattr(module, "gradient_checkpointing"):
845
- module.gradient_checkpointing = value
846
-
847
- def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
848
- r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
849
-
850
- The suffixes after the scaling factors represent the stage blocks where they are being applied.
851
-
852
- Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
853
- are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
854
-
855
- Args:
856
- s1 (`float`):
857
- Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
858
- mitigate the "oversmoothing effect" in the enhanced denoising process.
859
- s2 (`float`):
860
- Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
861
- mitigate the "oversmoothing effect" in the enhanced denoising process.
862
- b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
863
- b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
864
- """
865
- for i, upsample_block in enumerate(self.up_blocks):
866
- setattr(upsample_block, "s1", s1)
867
- setattr(upsample_block, "s2", s2)
868
- setattr(upsample_block, "b1", b1)
869
- setattr(upsample_block, "b2", b2)
870
-
871
- def disable_freeu(self):
872
- """Disables the FreeU mechanism."""
873
- freeu_keys = {"s1", "s2", "b1", "b2"}
874
- for i, upsample_block in enumerate(self.up_blocks):
875
- for k in freeu_keys:
876
- if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
877
- setattr(upsample_block, k, None)
878
-
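# Illustrative usage sketch on the toy `unet` above. The scaling factors are example
# values only; refer to the FreeU repository linked in the docstring for values tuned
# to a specific pipeline.
unet.enable_freeu(s1=0.9, s2=0.2, b1=1.3, b2=1.4)
unet.disable_freeu()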
879
- def fuse_qkv_projections(self):
880
- """
881
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
882
- are fused. For cross-attention modules, key and value projection matrices are fused.
883
-
884
- <Tip warning={true}>
885
-
886
- This API is 🧪 experimental.
887
-
888
- </Tip>
889
- """
890
- self.original_attn_processors = None
891
-
892
- for _, attn_processor in self.attn_processors.items():
893
- if "Added" in str(attn_processor.__class__.__name__):
894
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
895
-
896
- self.original_attn_processors = self.attn_processors
897
-
898
- for module in self.modules():
899
- if isinstance(module, Attention):
900
- module.fuse_projections(fuse=True)
901
-
902
- def unfuse_qkv_projections(self):
903
- """Disables the fused QKV projection if enabled.
904
-
905
- <Tip warning={true}>
906
-
907
- This API is 🧪 experimental.
908
-
909
- </Tip>
910
-
911
- """
912
- if self.original_attn_processors is not None:
913
- self.set_attn_processor(self.original_attn_processors)
914
-
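# Illustrative usage sketch on the toy `unet` above: fuse the QKV projections before
# inference and restore the original attention processors afterwards (experimental API,
# per the docstrings above).
unet.fuse_qkv_projections()
# ... run the denoising loop here ...
unet.unfuse_qkv_projections()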
915
- def get_time_embed(
916
- self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
917
- ) -> Optional[torch.Tensor]:
918
- timesteps = timestep
919
- if not torch.is_tensor(timesteps):
920
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
921
- # This would be a good case for the `match` statement (Python 3.10+)
922
- is_mps = sample.device.type == "mps"
923
- if isinstance(timestep, float):
924
- dtype = torch.float32 if is_mps else torch.float64
925
- else:
926
- dtype = torch.int32 if is_mps else torch.int64
927
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
928
- elif len(timesteps.shape) == 0:
929
- timesteps = timesteps[None].to(sample.device)
930
-
931
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
932
- timesteps = timesteps.expand(sample.shape[0])
933
-
934
- t_emb = self.time_proj(timesteps)
935
- # `Timesteps` does not contain any weights and will always return f32 tensors
936
- # but time_embedding might actually be running in fp16. so we need to cast here.
937
- # there might be better ways to encapsulate this.
938
- t_emb = t_emb.to(dtype=sample.dtype)
939
- return t_emb
940
-
941
- def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
942
- class_emb = None
943
- if self.class_embedding is not None:
944
- if class_labels is None:
945
- raise ValueError("class_labels should be provided when num_class_embeds > 0")
946
-
947
- if self.config.class_embed_type == "timestep":
948
- class_labels = self.time_proj(class_labels)
949
-
950
- # `Timesteps` does not contain any weights and will always return f32 tensors
951
- # there might be better ways to encapsulate this.
952
- class_labels = class_labels.to(dtype=sample.dtype)
953
-
954
- class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
955
- return class_emb
956
-
957
- def get_aug_embed(
958
- self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
959
- ) -> Optional[torch.Tensor]:
960
- aug_emb = None
961
- if self.config.addition_embed_type == "text":
962
- aug_emb = self.add_embedding(encoder_hidden_states)
963
- elif self.config.addition_embed_type == "text_image":
964
- # Kandinsky 2.1 - style
965
- if "image_embeds" not in added_cond_kwargs:
966
- raise ValueError(
967
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
968
- )
969
-
970
- image_embs = added_cond_kwargs.get("image_embeds")
971
- text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
972
- aug_emb = self.add_embedding(text_embs, image_embs)
973
- elif self.config.addition_embed_type == "text_time":
974
- # SDXL - style
975
- if "text_embeds" not in added_cond_kwargs:
976
- raise ValueError(
977
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
978
- )
979
- text_embeds = added_cond_kwargs.get("text_embeds")
980
- if "time_ids" not in added_cond_kwargs:
981
- raise ValueError(
982
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
983
- )
984
- time_ids = added_cond_kwargs.get("time_ids")
985
- time_embeds = self.add_time_proj(time_ids.flatten())
986
- time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
987
- add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
988
- add_embeds = add_embeds.to(emb.dtype)
989
- aug_emb = self.add_embedding(add_embeds)
990
- elif self.config.addition_embed_type == "image":
991
- # Kandinsky 2.2 - style
992
- if "image_embeds" not in added_cond_kwargs:
993
- raise ValueError(
994
- f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
995
- )
996
- image_embs = added_cond_kwargs.get("image_embeds")
997
- aug_emb = self.add_embedding(image_embs)
998
- elif self.config.addition_embed_type == "image_hint":
999
- # Kandinsky 2.2 - style
1000
- if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
1001
- raise ValueError(
1002
- f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1003
- )
1004
- image_embs = added_cond_kwargs.get("image_embeds")
1005
- hint = added_cond_kwargs.get("hint")
1006
- aug_emb = self.add_embedding(image_embs, hint)
1007
- return aug_emb
1008
-
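# Illustrative sketch of the `added_cond_kwargs` required when `addition_embed_type`
# is "text_time" (SDXL-style conditioning); the batch size and shapes are toy values.
added_cond_kwargs = {
    "text_embeds": torch.zeros(2, 1280),  # pooled text embeddings
    "time_ids": torch.zeros(2, 6),        # e.g. original size, crop coords, target size
}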
1009
- def process_encoder_hidden_states(
1010
- self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
1011
- ) -> torch.Tensor:
1012
- if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1013
- encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1014
- elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1015
- # Kandinsky 2.1 - style
1016
- if "image_embeds" not in added_cond_kwargs:
1017
- raise ValueError(
1018
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1019
- )
1020
-
1021
- image_embeds = added_cond_kwargs.get("image_embeds")
1022
- encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1023
- elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1024
- # Kandinsky 2.2 - style
1025
- if "image_embeds" not in added_cond_kwargs:
1026
- raise ValueError(
1027
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1028
- )
1029
- image_embeds = added_cond_kwargs.get("image_embeds")
1030
- encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1031
- elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
1032
- if "image_embeds" not in added_cond_kwargs:
1033
- raise ValueError(
1034
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1035
- )
1036
-
1037
- if hasattr(self, 'text_encoder_hid_proj') and self.text_encoder_hid_proj is not None:
1038
- encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states)
1039
-
1040
- image_embeds = added_cond_kwargs.get("image_embeds")
1041
- image_embeds = self.encoder_hid_proj(image_embeds)
1042
- encoder_hidden_states = (encoder_hidden_states, image_embeds)
1043
- return encoder_hidden_states
1044
-
1045
- def forward(
1046
- self,
1047
- sample: torch.Tensor,
1048
- timestep: Union[torch.Tensor, float, int],
1049
- encoder_hidden_states: torch.Tensor,
1050
- class_labels: Optional[torch.Tensor] = None,
1051
- timestep_cond: Optional[torch.Tensor] = None,
1052
- attention_mask: Optional[torch.Tensor] = None,
1053
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1054
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
1055
- down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1056
- mid_block_additional_residual: Optional[torch.Tensor] = None,
1057
- down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1058
- encoder_attention_mask: Optional[torch.Tensor] = None,
1059
- return_dict: bool = True,
1060
- ) -> Union[UNet2DConditionOutput, Tuple]:
1061
- r"""
1062
- The [`UNet2DConditionModel`] forward method.
1063
-
1064
- Args:
1065
- sample (`torch.Tensor`):
1066
- The noisy input tensor with the following shape `(batch, channel, height, width)`.
1067
- timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
1068
- encoder_hidden_states (`torch.Tensor`):
1069
- The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
1070
- class_labels (`torch.Tensor`, *optional*, defaults to `None`):
1071
- Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
1072
- timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
1073
- Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
1074
- through the `self.time_embedding` layer to obtain the timestep embeddings.
1075
- attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
1076
- An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
1077
- is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
1078
- negative values to the attention scores corresponding to "discard" tokens.
1079
- cross_attention_kwargs (`dict`, *optional*):
1080
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1081
- `self.processor` in
1082
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1083
- added_cond_kwargs: (`dict`, *optional*):
1084
- A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
1085
- are passed along to the UNet blocks.
1086
- down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
1087
- A tuple of tensors that if specified are added to the residuals of down unet blocks.
1088
- mid_block_additional_residual: (`torch.Tensor`, *optional*):
1089
- A tensor that if specified is added to the residual of the middle unet block.
1090
- down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
1091
- additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
1092
- encoder_attention_mask (`torch.Tensor`):
1093
- A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
1094
- `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
1095
- which adds large negative values to the attention scores corresponding to "discard" tokens.
1096
- return_dict (`bool`, *optional*, defaults to `True`):
1097
- Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
1098
- tuple.
1099
-
1100
- Returns:
1101
- [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
1102
- If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
1103
- otherwise a `tuple` is returned where the first element is the sample tensor.
1104
- """
1105
- # By default samples have to be at least a multiple of the overall upsampling factor.
1106
- # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
1107
- # However, the upsampling interpolation output size can be forced to fit any upsampling size
1108
- # on the fly if necessary.
1109
- default_overall_up_factor = 2**self.num_upsamplers
1110
-
1111
- # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
1112
- forward_upsample_size = False
1113
- upsample_size = None
1114
-
1115
- for dim in sample.shape[-2:]:
1116
- if dim % default_overall_up_factor != 0:
1117
- # Forward upsample size to force interpolation output size.
1118
- forward_upsample_size = True
1119
- break
1120
-
1121
- # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
1122
- # expects mask of shape:
1123
- # [batch, key_tokens]
1124
- # adds singleton query_tokens dimension:
1125
- # [batch, 1, key_tokens]
1126
- # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
1127
- # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
1128
- # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
1129
- if attention_mask is not None:
1130
- # assume that mask is expressed as:
1131
- # (1 = keep, 0 = discard)
1132
- # convert mask into a bias that can be added to attention scores:
1133
- # (keep = +0, discard = -10000.0)
1134
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
1135
- attention_mask = attention_mask.unsqueeze(1)
1136
-
1137
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
1138
- if encoder_attention_mask is not None:
1139
- encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
1140
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
1141
-
1142
- # 0. center input if necessary
1143
- if self.config.center_input_sample:
1144
- sample = 2 * sample - 1.0
1145
-
1146
- # 1. time
1147
- t_emb = self.get_time_embed(sample=sample, timestep=timestep)
1148
- emb = self.time_embedding(t_emb, timestep_cond)
1149
- aug_emb = None
1150
-
1151
- class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
1152
- if class_emb is not None:
1153
- if self.config.class_embeddings_concat:
1154
- emb = torch.cat([emb, class_emb], dim=-1)
1155
- else:
1156
- emb = emb + class_emb
1157
-
1158
- aug_emb = self.get_aug_embed(
1159
- emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1160
- )
1161
- if self.config.addition_embed_type == "image_hint":
1162
- aug_emb, hint = aug_emb
1163
- sample = torch.cat([sample, hint], dim=1)
1164
-
1165
- emb = emb + aug_emb if aug_emb is not None else emb
1166
-
1167
- if self.time_embed_act is not None:
1168
- emb = self.time_embed_act(emb)
1169
-
1170
- encoder_hidden_states = self.process_encoder_hidden_states(
1171
- encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1172
- )
1173
-
1174
- # 2. pre-process
1175
- sample = self.conv_in(sample)
1176
-
1177
- # 2.5 GLIGEN position net
1178
- if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
1179
- cross_attention_kwargs = cross_attention_kwargs.copy()
1180
- gligen_args = cross_attention_kwargs.pop("gligen")
1181
- cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
1182
-
1183
- # 3. down
1184
- # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
1185
- # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
1186
- if cross_attention_kwargs is not None:
1187
- cross_attention_kwargs = cross_attention_kwargs.copy()
1188
- lora_scale = cross_attention_kwargs.pop("scale", 1.0)
1189
- else:
1190
- lora_scale = 1.0
1191
-
1192
- if USE_PEFT_BACKEND:
1193
- # weight the lora layers by setting `lora_scale` for each PEFT layer
1194
- scale_lora_layers(self, lora_scale)
1195
-
1196
- is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1197
- # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1198
- is_adapter = down_intrablock_additional_residuals is not None
1199
- # maintain backward compatibility for legacy usage, where
1200
- # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1201
- # but can only use one or the other
1202
- if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
1203
- deprecate(
1204
- "T2I should not use down_block_additional_residuals",
1205
- "1.3.0",
1206
- "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1207
- and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1208
-                     for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
1209
- standard_warn=False,
1210
- )
1211
- down_intrablock_additional_residuals = down_block_additional_residuals
1212
- is_adapter = True
1213
-
1214
- down_block_res_samples = (sample,)
1215
- for downsample_block in self.down_blocks:
1216
- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1217
- # For t2i-adapter CrossAttnDownBlock2D
1218
- additional_residuals = {}
1219
- if is_adapter and len(down_intrablock_additional_residuals) > 0:
1220
- additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
1221
-
1222
- sample, res_samples = downsample_block(
1223
- hidden_states=sample,
1224
- temb=emb,
1225
- encoder_hidden_states=encoder_hidden_states,
1226
- attention_mask=attention_mask,
1227
- cross_attention_kwargs=cross_attention_kwargs,
1228
- encoder_attention_mask=encoder_attention_mask,
1229
- **additional_residuals,
1230
- )
1231
- else:
1232
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1233
- if is_adapter and len(down_intrablock_additional_residuals) > 0:
1234
- sample += down_intrablock_additional_residuals.pop(0)
1235
-
1236
- down_block_res_samples += res_samples
1237
-
1238
- if is_controlnet:
1239
- new_down_block_res_samples = ()
1240
-
1241
- for down_block_res_sample, down_block_additional_residual in zip(
1242
- down_block_res_samples, down_block_additional_residuals
1243
- ):
1244
- down_block_res_sample = down_block_res_sample + down_block_additional_residual
1245
- new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1246
-
1247
- down_block_res_samples = new_down_block_res_samples
1248
-
1249
- # 4. mid
1250
- if self.mid_block is not None:
1251
- if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
1252
- sample = self.mid_block(
1253
- sample,
1254
- emb,
1255
- encoder_hidden_states=encoder_hidden_states,
1256
- attention_mask=attention_mask,
1257
- cross_attention_kwargs=cross_attention_kwargs,
1258
- encoder_attention_mask=encoder_attention_mask,
1259
- )
1260
- else:
1261
- sample = self.mid_block(sample, emb)
1262
-
1263
- # To support T2I-Adapter-XL
1264
- if (
1265
- is_adapter
1266
- and len(down_intrablock_additional_residuals) > 0
1267
- and sample.shape == down_intrablock_additional_residuals[0].shape
1268
- ):
1269
- sample += down_intrablock_additional_residuals.pop(0)
1270
-
1271
- if is_controlnet:
1272
- sample = sample + mid_block_additional_residual
1273
-
1274
- # 5. up
1275
- for i, upsample_block in enumerate(self.up_blocks):
1276
- is_final_block = i == len(self.up_blocks) - 1
1277
-
1278
- res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1279
- down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1280
-
1281
- # if we have not reached the final block and need to forward the
1282
- # upsample size, we do it here
1283
- if not is_final_block and forward_upsample_size:
1284
- upsample_size = down_block_res_samples[-1].shape[2:]
1285
-
1286
- if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1287
- sample = upsample_block(
1288
- hidden_states=sample,
1289
- temb=emb,
1290
- res_hidden_states_tuple=res_samples,
1291
- encoder_hidden_states=encoder_hidden_states,
1292
- cross_attention_kwargs=cross_attention_kwargs,
1293
- upsample_size=upsample_size,
1294
- attention_mask=attention_mask,
1295
- encoder_attention_mask=encoder_attention_mask,
1296
- )
1297
- else:
1298
- sample = upsample_block(
1299
- hidden_states=sample,
1300
- temb=emb,
1301
- res_hidden_states_tuple=res_samples,
1302
- upsample_size=upsample_size,
1303
- )
1304
-
1305
- # 6. post-process
1306
- if self.conv_norm_out:
1307
- sample = self.conv_norm_out(sample)
1308
- sample = self.conv_act(sample)
1309
- sample = self.conv_out(sample)
1310
-
1311
- if USE_PEFT_BACKEND:
1312
- # remove `lora_scale` from each PEFT layer
1313
- unscale_lora_layers(self, lora_scale)
1314
-
1315
- if not return_dict:
1316
- return (sample,)
1317
-
1318
- return UNet2DConditionOutput(sample=sample)
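
The deleted forward pass above converts 0/1 attention masks into additive biases before they reach the attention layers. A minimal standalone sketch of that conversion, assuming a binary mask tensor; the `mask_to_bias` helper name is illustrative and not part of the deleted module:

import torch

def mask_to_bias(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # illustrative helper, not part of the deleted module
    # 1 = keep, 0 = discard; discarded key positions receive a large negative bias
    bias = (1 - attention_mask.to(dtype)) * -10000.0
    # add a singleton query_tokens dimension so the bias broadcasts over attention scores
    return bias.unsqueeze(1)

mask = torch.tensor([[1, 1, 0, 0]])        # [batch, key_tokens]
bias = mask_to_bias(mask, torch.float32)   # shape [1, 1, 4]; masked positions become -10000.0
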
kolors/pipelines/__init__.py DELETED
File without changes
kolors/pipelines/pipeline_controlnet_xl_kolors_img2img.py DELETED
@@ -1,1365 +0,0 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- import inspect
17
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
-
19
- import numpy as np
20
- import PIL.Image
21
- import torch
22
- import torch.nn.functional as F
23
- from transformers import (
24
- CLIPImageProcessor,
25
- CLIPTextModel,
26
- CLIPTextModelWithProjection,
27
- CLIPTokenizer,
28
- CLIPVisionModelWithProjection,
29
- )
30
-
31
- from diffusers.utils.import_utils import is_invisible_watermark_available
32
-
33
- from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
34
- from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
35
- from diffusers.loaders import (
36
- FromSingleFileMixin,
37
- IPAdapterMixin,
38
- StableDiffusionXLLoraLoaderMixin,
39
- TextualInversionLoaderMixin,
40
- )
41
- from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
42
- from diffusers.models.attention_processor import (
43
- AttnProcessor2_0,
44
- XFormersAttnProcessor,
45
- )
46
- from diffusers.models.lora import adjust_lora_scale_text_encoder
47
- from diffusers.schedulers import KarrasDiffusionSchedulers
48
- from diffusers.utils import (
49
- USE_PEFT_BACKEND,
50
- deprecate,
51
- logging,
52
- replace_example_docstring,
53
- scale_lora_layers,
54
- unscale_lora_layers,
55
- )
56
- from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
57
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
58
- from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
59
- from diffusers.pipelines.controlnet import MultiControlNetModel
60
-
61
- from ..models.controlnet import ControlNetModel
62
-
63
-
64
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
65
-
66
-
67
-
68
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
69
- def retrieve_latents(
70
- encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
71
- ):
72
- if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
73
- return encoder_output.latent_dist.sample(generator)
74
- elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
75
- return encoder_output.latent_dist.mode()
76
- elif hasattr(encoder_output, "latents"):
77
- return encoder_output.latents
78
- else:
79
- raise AttributeError("Could not access latents of provided encoder_output")
80
-
81
-
82
- class StableDiffusionXLControlNetImg2ImgPipeline(
83
- DiffusionPipeline,
84
- StableDiffusionMixin,
85
- TextualInversionLoaderMixin,
86
- StableDiffusionXLLoraLoaderMixin,
87
- FromSingleFileMixin,
88
- IPAdapterMixin,
89
- ):
90
- r"""
91
- Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance.
92
-
93
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
94
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
95
-
96
- The pipeline also inherits the following loading methods:
97
- - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
98
- - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
99
- - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
100
- - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
101
-
102
- Args:
103
- vae ([`AutoencoderKL`]):
104
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
105
- text_encoder ([`CLIPTextModel`]):
106
- Frozen text-encoder. Stable Diffusion uses the text portion of
107
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
108
- the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
109
- tokenizer (`CLIPTokenizer`):
110
- Tokenizer of class
111
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
112
- unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
113
- controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
114
- Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
115
- as a list, the outputs from each ControlNet are added together to create one combined additional
116
- conditioning.
117
- scheduler ([`SchedulerMixin`]):
118
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
119
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
120
- requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
121
- Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the
122
- config of `stabilityai/stable-diffusion-xl-refiner-1-0`.
123
- force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
124
- Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
125
- `stabilityai/stable-diffusion-xl-base-1-0`.
126
- add_watermarker (`bool`, *optional*):
127
- Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
128
- watermark output images. If not defined, it will default to True if the package is installed, otherwise no
129
- watermarker will be used.
130
- feature_extractor ([`~transformers.CLIPImageProcessor`]):
131
- A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
132
- """
133
-
134
- model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
135
- _optional_components = [
136
- "tokenizer",
137
- "text_encoder",
138
- "feature_extractor",
139
- "image_encoder",
140
- ]
141
- _callback_tensor_inputs = [
142
- "latents",
143
- "prompt_embeds",
144
- "negative_prompt_embeds",
145
- "add_text_embeds",
146
- "add_time_ids",
147
- "negative_pooled_prompt_embeds",
148
- "add_neg_time_ids",
149
- ]
150
-
151
- def __init__(
152
- self,
153
- vae: AutoencoderKL,
154
- text_encoder: CLIPTextModel,
155
- tokenizer: CLIPTokenizer,
156
- unet: UNet2DConditionModel,
157
- controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
158
- scheduler: KarrasDiffusionSchedulers,
159
- requires_aesthetics_score: bool = False,
160
- force_zeros_for_empty_prompt: bool = True,
161
- feature_extractor: CLIPImageProcessor = None,
162
- image_encoder: CLIPVisionModelWithProjection = None,
163
- ):
164
- super().__init__()
165
-
166
- if isinstance(controlnet, (list, tuple)):
167
- controlnet = MultiControlNetModel(controlnet)
168
-
169
- self.register_modules(
170
- vae=vae,
171
- text_encoder=text_encoder,
172
- tokenizer=tokenizer,
173
- unet=unet,
174
- controlnet=controlnet,
175
- scheduler=scheduler,
176
- feature_extractor=feature_extractor,
177
- image_encoder=image_encoder,
178
- )
179
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
180
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
181
- self.control_image_processor = VaeImageProcessor(
182
- vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
183
- )
184
-
185
- self.watermark = None
186
-
187
- self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
188
- self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
189
-
190
-
191
- def encode_prompt(
192
- self,
193
- prompt,
194
- device: Optional[torch.device] = None,
195
- num_images_per_prompt: int = 1,
196
- do_classifier_free_guidance: bool = True,
197
- negative_prompt=None,
198
- prompt_embeds: Optional[torch.FloatTensor] = None,
199
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
200
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
201
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
202
- lora_scale: Optional[float] = None,
203
- ):
204
- r"""
205
- Encodes the prompt into text encoder hidden states.
206
-
207
- Args:
208
- prompt (`str` or `List[str]`, *optional*):
209
- prompt to be encoded
210
- device: (`torch.device`):
211
- torch device
212
- num_images_per_prompt (`int`):
213
- number of images that should be generated per prompt
214
- do_classifier_free_guidance (`bool`):
215
- whether to use classifier free guidance or not
216
- negative_prompt (`str` or `List[str]`, *optional*):
217
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
218
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
219
- less than `1`).
220
- prompt_embeds (`torch.FloatTensor`, *optional*):
221
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
222
- provided, text embeddings will be generated from `prompt` input argument.
223
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
224
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
225
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
226
- argument.
227
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
228
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
229
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
230
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
231
- Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
232
- weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
233
- input argument.
234
- lora_scale (`float`, *optional*):
235
- A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
236
- """
237
- # from IPython import embed; embed(); exit()
238
- device = device or self._execution_device
239
-
240
- # set lora scale so that monkey patched LoRA
241
- # function of text encoder can correctly access it
242
-         if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
243
- self._lora_scale = lora_scale
244
-
245
- if prompt is not None and isinstance(prompt, str):
246
- batch_size = 1
247
- elif prompt is not None and isinstance(prompt, list):
248
- batch_size = len(prompt)
249
- else:
250
- batch_size = prompt_embeds.shape[0]
251
-
252
- # Define tokenizers and text encoders
253
- tokenizers = [self.tokenizer]
254
- text_encoders = [self.text_encoder]
255
-
256
- if prompt_embeds is None:
257
-             # textual inversion: process multi-vector tokens if necessary
258
- prompt_embeds_list = []
259
- for tokenizer, text_encoder in zip(tokenizers, text_encoders):
260
- if isinstance(self, TextualInversionLoaderMixin):
261
- prompt = self.maybe_convert_prompt(prompt, tokenizer)
262
-
263
- text_inputs = tokenizer(
264
- prompt,
265
- padding="max_length",
266
- max_length=256,
267
- truncation=True,
268
- return_tensors="pt",
269
- ).to('cuda')
270
- output = text_encoder(
271
- input_ids=text_inputs['input_ids'] ,
272
- attention_mask=text_inputs['attention_mask'],
273
- position_ids=text_inputs['position_ids'],
274
- output_hidden_states=True)
275
- prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
276
- pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() # [batch_size, 4096]
277
- bs_embed, seq_len, _ = prompt_embeds.shape
278
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
279
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
280
-
281
- prompt_embeds_list.append(prompt_embeds)
282
-
283
- # prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
284
- prompt_embeds = prompt_embeds_list[0]
285
-
286
- # get unconditional embeddings for classifier free guidance
287
- zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
288
- if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
289
- negative_prompt_embeds = torch.zeros_like(prompt_embeds)
290
- negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
291
- elif do_classifier_free_guidance and negative_prompt_embeds is None:
292
- # negative_prompt = negative_prompt or ""
293
- uncond_tokens: List[str]
294
- if negative_prompt is None:
295
- uncond_tokens = [""] * batch_size
296
- elif prompt is not None and type(prompt) is not type(negative_prompt):
297
- raise TypeError(
298
-                     f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
299
- f" {type(prompt)}."
300
- )
301
- elif isinstance(negative_prompt, str):
302
- uncond_tokens = [negative_prompt]
303
- elif batch_size != len(negative_prompt):
304
- raise ValueError(
305
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
306
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
307
- " the batch size of `prompt`."
308
- )
309
- else:
310
- uncond_tokens = negative_prompt
311
-
312
- negative_prompt_embeds_list = []
313
- for tokenizer, text_encoder in zip(tokenizers, text_encoders):
314
-                 # textual inversion: process multi-vector tokens if necessary
315
- if isinstance(self, TextualInversionLoaderMixin):
316
- uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
317
-
318
- max_length = prompt_embeds.shape[1]
319
- uncond_input = tokenizer(
320
- uncond_tokens,
321
- padding="max_length",
322
- max_length=max_length,
323
- truncation=True,
324
- return_tensors="pt",
325
- ).to('cuda')
326
- output = text_encoder(
327
- input_ids=uncond_input['input_ids'] ,
328
- attention_mask=uncond_input['attention_mask'],
329
- position_ids=uncond_input['position_ids'],
330
- output_hidden_states=True)
331
- negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
332
- negative_pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() # [batch_size, 4096]
333
-
334
- if do_classifier_free_guidance:
335
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
336
- seq_len = negative_prompt_embeds.shape[1]
337
-
338
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
339
-
340
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
341
- negative_prompt_embeds = negative_prompt_embeds.view(
342
- batch_size * num_images_per_prompt, seq_len, -1
343
- )
344
-
345
- # For classifier free guidance, we need to do two forward passes.
346
- # Here we concatenate the unconditional and text embeddings into a single batch
347
- # to avoid doing two forward passes
348
-
349
- negative_prompt_embeds_list.append(negative_prompt_embeds)
350
-
351
- # negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
352
- negative_prompt_embeds = negative_prompt_embeds_list[0]
353
-
354
- bs_embed = pooled_prompt_embeds.shape[0]
355
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
356
- bs_embed * num_images_per_prompt, -1
357
- )
358
- if do_classifier_free_guidance:
359
- negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
360
- bs_embed * num_images_per_prompt, -1
361
- )
362
-
363
- return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
364
-
365
-
366
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
367
- def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
368
- dtype = next(self.image_encoder.parameters()).dtype
369
-
370
- if not isinstance(image, torch.Tensor):
371
- image = self.feature_extractor(image, return_tensors="pt").pixel_values
372
-
373
- image = image.to(device=device, dtype=dtype)
374
- if output_hidden_states:
375
- image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
376
- image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
377
- uncond_image_enc_hidden_states = self.image_encoder(
378
- torch.zeros_like(image), output_hidden_states=True
379
- ).hidden_states[-2]
380
- uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
381
- num_images_per_prompt, dim=0
382
- )
383
- return image_enc_hidden_states, uncond_image_enc_hidden_states
384
- else:
385
- image_embeds = self.image_encoder(image).image_embeds
386
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
387
- uncond_image_embeds = torch.zeros_like(image_embeds)
388
-
389
- return image_embeds, uncond_image_embeds
390
-
391
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
392
- def prepare_ip_adapter_image_embeds(
393
- self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
394
- ):
395
- image_embeds = []
396
- if do_classifier_free_guidance:
397
- negative_image_embeds = []
398
- if ip_adapter_image_embeds is None:
399
- if not isinstance(ip_adapter_image, list):
400
- ip_adapter_image = [ip_adapter_image]
401
-
402
- if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
403
- raise ValueError(
404
- f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
405
- )
406
-
407
- for single_ip_adapter_image, image_proj_layer in zip(
408
- ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
409
- ):
410
- output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
411
- single_image_embeds, single_negative_image_embeds = self.encode_image(
412
- single_ip_adapter_image, device, 1, output_hidden_state
413
- )
414
-
415
- image_embeds.append(single_image_embeds[None, :])
416
- if do_classifier_free_guidance:
417
- negative_image_embeds.append(single_negative_image_embeds[None, :])
418
- else:
419
- for single_image_embeds in ip_adapter_image_embeds:
420
- if do_classifier_free_guidance:
421
- single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
422
- negative_image_embeds.append(single_negative_image_embeds)
423
- image_embeds.append(single_image_embeds)
424
-
425
- ip_adapter_image_embeds = []
426
- for i, single_image_embeds in enumerate(image_embeds):
427
- single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
428
- if do_classifier_free_guidance:
429
- single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
430
- single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
431
-
432
- single_image_embeds = single_image_embeds.to(device=device)
433
- ip_adapter_image_embeds.append(single_image_embeds)
434
-
435
- return ip_adapter_image_embeds
436
-
437
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
438
- def prepare_extra_step_kwargs(self, generator, eta):
439
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
440
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
441
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
442
- # and should be between [0, 1]
443
-
444
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
445
- extra_step_kwargs = {}
446
- if accepts_eta:
447
- extra_step_kwargs["eta"] = eta
448
-
449
- # check if the scheduler accepts generator
450
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
451
- if accepts_generator:
452
- extra_step_kwargs["generator"] = generator
453
- return extra_step_kwargs
454
-
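
prepare_extra_step_kwargs above forwards `eta` and `generator` only when the scheduler's `step()` signature accepts them. A minimal sketch of that signature-inspection pattern; `filter_step_kwargs` and `fake_step` are illustrative names, not part of the pipeline:

import inspect

def filter_step_kwargs(step_fn, **candidate_kwargs):
    # keep only the keyword arguments that step_fn's signature actually accepts
    accepted = set(inspect.signature(step_fn).parameters.keys())
    return {k: v for k, v in candidate_kwargs.items() if k in accepted}

def fake_step(model_output, timestep, sample, eta=0.0):  # stand-in for a scheduler's step()
    return sample

print(filter_step_kwargs(fake_step, eta=0.1, generator=None))  # -> {'eta': 0.1}
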
455
- def check_inputs(
456
- self,
457
- prompt,
458
- image,
459
- strength,
460
- num_inference_steps,
461
- callback_steps,
462
- negative_prompt=None,
463
- prompt_embeds=None,
464
- negative_prompt_embeds=None,
465
- pooled_prompt_embeds=None,
466
- negative_pooled_prompt_embeds=None,
467
- ip_adapter_image=None,
468
- ip_adapter_image_embeds=None,
469
- controlnet_conditioning_scale=1.0,
470
- control_guidance_start=0.0,
471
- control_guidance_end=1.0,
472
- callback_on_step_end_tensor_inputs=None,
473
- ):
474
- if strength < 0 or strength > 1:
475
- raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
476
- if num_inference_steps is None:
477
- raise ValueError("`num_inference_steps` cannot be None.")
478
- elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0:
479
- raise ValueError(
480
- f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
481
- f" {type(num_inference_steps)}."
482
- )
483
-
484
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
485
- raise ValueError(
486
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
487
- f" {type(callback_steps)}."
488
- )
489
-
490
- if callback_on_step_end_tensor_inputs is not None and not all(
491
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
492
- ):
493
- raise ValueError(
494
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
495
- )
496
-
497
- if prompt is not None and prompt_embeds is not None:
498
- raise ValueError(
499
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
500
- " only forward one of the two."
501
- )
502
- elif prompt is None and prompt_embeds is None:
503
- raise ValueError(
504
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
505
- )
506
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
507
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
508
-
509
- if negative_prompt is not None and negative_prompt_embeds is not None:
510
- raise ValueError(
511
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
512
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
513
- )
514
-
515
- if prompt_embeds is not None and negative_prompt_embeds is not None:
516
- if prompt_embeds.shape != negative_prompt_embeds.shape:
517
- raise ValueError(
518
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
519
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
520
- f" {negative_prompt_embeds.shape}."
521
- )
522
-
523
- if prompt_embeds is not None and pooled_prompt_embeds is None:
524
- raise ValueError(
525
- "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
526
- )
527
-
528
- if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
529
- raise ValueError(
530
- "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
531
- )
532
-
533
- # `prompt` needs more sophisticated handling when there are multiple
534
- # conditionings.
535
- if isinstance(self.controlnet, MultiControlNetModel):
536
- if isinstance(prompt, list):
537
- logger.warning(
538
- f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
539
- " prompts. The conditionings will be fixed across the prompts."
540
- )
541
-
542
- # Check `image`
543
- is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
544
- self.controlnet, torch._dynamo.eval_frame.OptimizedModule
545
- )
546
- if (
547
- isinstance(self.controlnet, ControlNetModel)
548
- or is_compiled
549
- and isinstance(self.controlnet._orig_mod, ControlNetModel)
550
- ):
551
- self.check_image(image, prompt, prompt_embeds)
552
- elif (
553
- isinstance(self.controlnet, MultiControlNetModel)
554
- or is_compiled
555
- and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
556
- ):
557
- if not isinstance(image, list):
558
- raise TypeError("For multiple controlnets: `image` must be type `list`")
559
-
560
- # When `image` is a nested list:
561
- # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
562
- elif any(isinstance(i, list) for i in image):
563
-                 raise ValueError("Only a single batch of multiple conditionings is supported at the moment.")
564
- elif len(image) != len(self.controlnet.nets):
565
- raise ValueError(
566
- f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
567
- )
568
-
569
- for image_ in image:
570
- self.check_image(image_, prompt, prompt_embeds)
571
- else:
572
- assert False
573
-
574
- # Check `controlnet_conditioning_scale`
575
- if (
576
- isinstance(self.controlnet, ControlNetModel)
577
- or is_compiled
578
- and isinstance(self.controlnet._orig_mod, ControlNetModel)
579
- ):
580
- if not isinstance(controlnet_conditioning_scale, float):
581
- raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
582
- elif (
583
- isinstance(self.controlnet, MultiControlNetModel)
584
- or is_compiled
585
- and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
586
- ):
587
- if isinstance(controlnet_conditioning_scale, list):
588
- if any(isinstance(i, list) for i in controlnet_conditioning_scale):
589
-                     raise ValueError("Only a single batch of multiple conditionings is supported at the moment.")
590
- elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
591
- self.controlnet.nets
592
- ):
593
- raise ValueError(
594
- "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
595
- " the same length as the number of controlnets"
596
- )
597
- else:
598
- assert False
599
-
600
- if not isinstance(control_guidance_start, (tuple, list)):
601
- control_guidance_start = [control_guidance_start]
602
-
603
- if not isinstance(control_guidance_end, (tuple, list)):
604
- control_guidance_end = [control_guidance_end]
605
-
606
- if len(control_guidance_start) != len(control_guidance_end):
607
- raise ValueError(
608
- f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
609
- )
610
-
611
- if isinstance(self.controlnet, MultiControlNetModel):
612
- if len(control_guidance_start) != len(self.controlnet.nets):
613
- raise ValueError(
614
- f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
615
- )
616
-
617
- for start, end in zip(control_guidance_start, control_guidance_end):
618
- if start >= end:
619
- raise ValueError(
620
- f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
621
- )
622
- if start < 0.0:
623
- raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
624
- if end > 1.0:
625
- raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
626
-
627
- if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
628
- raise ValueError(
629
- "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
630
- )
631
-
632
- if ip_adapter_image_embeds is not None:
633
- if not isinstance(ip_adapter_image_embeds, list):
634
- raise ValueError(
635
- f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
636
- )
637
- elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
638
- raise ValueError(
639
- f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
640
- )
641
-
642
- # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image
643
- def check_image(self, image, prompt, prompt_embeds):
644
- image_is_pil = isinstance(image, PIL.Image.Image)
645
- image_is_tensor = isinstance(image, torch.Tensor)
646
- image_is_np = isinstance(image, np.ndarray)
647
- image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
648
- image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
649
- image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
650
-
651
- if (
652
- not image_is_pil
653
- and not image_is_tensor
654
- and not image_is_np
655
- and not image_is_pil_list
656
- and not image_is_tensor_list
657
- and not image_is_np_list
658
- ):
659
- raise TypeError(
660
- f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
661
- )
662
-
663
- if image_is_pil:
664
- image_batch_size = 1
665
- else:
666
- image_batch_size = len(image)
667
-
668
- if prompt is not None and isinstance(prompt, str):
669
- prompt_batch_size = 1
670
- elif prompt is not None and isinstance(prompt, list):
671
- prompt_batch_size = len(prompt)
672
- elif prompt_embeds is not None:
673
- prompt_batch_size = prompt_embeds.shape[0]
674
-
675
- if image_batch_size != 1 and image_batch_size != prompt_batch_size:
676
- raise ValueError(
677
- f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
678
- )
679
-
680
- # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image
681
- def prepare_control_image(
682
- self,
683
- image,
684
- width,
685
- height,
686
- batch_size,
687
- num_images_per_prompt,
688
- device,
689
- dtype,
690
- do_classifier_free_guidance=False,
691
- guess_mode=False,
692
- ):
693
- image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
694
- image_batch_size = image.shape[0]
695
-
696
- if image_batch_size == 1:
697
- repeat_by = batch_size
698
- else:
699
- # image batch size is the same as prompt batch size
700
- repeat_by = num_images_per_prompt
701
-
702
- image = image.repeat_interleave(repeat_by, dim=0)
703
-
704
- image = image.to(device=device, dtype=dtype)
705
-
706
- if do_classifier_free_guidance and not guess_mode:
707
- image = torch.cat([image] * 2)
708
-
709
- return image
710
-
711
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
712
- def get_timesteps(self, num_inference_steps, strength, device):
713
- # get the original timestep using init_timestep
714
- init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
715
-
716
- t_start = max(num_inference_steps - init_timestep, 0)
717
- timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
718
- if hasattr(self.scheduler, "set_begin_index"):
719
- self.scheduler.set_begin_index(t_start * self.scheduler.order)
720
-
721
- return timesteps, num_inference_steps - t_start
722
-
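
As a quick sanity check of the strength-based truncation in `get_timesteps` above: with 50 inference steps and strength 0.8, only the last 40 scheduler timesteps are run. A plain-Python sketch of the arithmetic, assuming a scheduler order of 1:

num_inference_steps, strength = 50, 0.8
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 40
t_start = max(num_inference_steps - init_timestep, 0)                          # 10
# with scheduler.order == 1, the pipeline runs timesteps[10:], i.e. 40 denoising steps
print(init_timestep, t_start)  # 40 10
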
723
- # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents
724
- def prepare_latents(
725
- self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
726
- ):
727
- if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
728
- raise ValueError(
729
- f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
730
- )
731
-
732
- # Offload text encoder if `enable_model_cpu_offload` was enabled
733
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
734
- torch.cuda.empty_cache()
735
-
736
- image = image.to(device=device, dtype=dtype)
737
-
738
- batch_size = batch_size * num_images_per_prompt
739
-
740
- if image.shape[1] == 4:
741
- init_latents = image
742
-
743
- else:
744
- # make sure the VAE is in float32 mode, as it overflows in float16
745
- if self.vae.config.force_upcast:
746
- image = image.float()
747
- self.vae.to(dtype=torch.float32)
748
-
749
- if isinstance(generator, list) and len(generator) != batch_size:
750
- raise ValueError(
751
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
752
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
753
- )
754
-
755
- elif isinstance(generator, list):
756
- init_latents = [
757
- retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
758
- for i in range(batch_size)
759
- ]
760
- init_latents = torch.cat(init_latents, dim=0)
761
- else:
762
- init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
763
-
764
- if self.vae.config.force_upcast:
765
- self.vae.to(dtype)
766
-
767
- init_latents = init_latents.to(dtype)
768
-
769
- init_latents = self.vae.config.scaling_factor * init_latents
770
-
771
- if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
772
- # expand init_latents for batch_size
773
- additional_image_per_prompt = batch_size // init_latents.shape[0]
774
- init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
775
- elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
776
- raise ValueError(
777
- f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
778
- )
779
- else:
780
- init_latents = torch.cat([init_latents], dim=0)
781
-
782
- if add_noise:
783
- shape = init_latents.shape
784
- noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
785
- # get latents
786
- init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
787
-
788
- latents = init_latents
789
-
790
- return latents
791
-
792
-
793
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
794
- def prepare_latents_t2i(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
795
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
796
- if isinstance(generator, list) and len(generator) != batch_size:
797
- raise ValueError(
798
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
799
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
800
- )
801
-
802
- if latents is None:
803
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
804
- else:
805
- latents = latents.to(device)
806
-
807
- # scale the initial noise by the standard deviation required by the scheduler
808
- latents = latents * self.scheduler.init_noise_sigma
809
- return latents
810
-
811
-
812
-
813
- def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
814
- add_time_ids = list(original_size + crops_coords_top_left + target_size)
815
-
816
- passed_add_embed_dim = (
817
- self.unet.config.addition_time_embed_dim * len(add_time_ids) + 4096
818
- )
819
- expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
820
-
821
- if expected_add_embed_dim != passed_add_embed_dim:
822
- raise ValueError(
823
- f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
824
- )
825
-
826
- add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
827
- return add_time_ids
828
-
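
The micro-conditioning IDs built by `_get_add_time_ids` above are just the six integers of original size, crop offset, and target size packed into a tensor; a minimal illustration with arbitrary example values:

import torch

original_size, crops_coords_top_left, target_size = (1024, 1024), (0, 0), (1024, 1024)
add_time_ids = list(original_size + crops_coords_top_left + target_size)
print(add_time_ids)                                             # [1024, 1024, 0, 0, 1024, 1024]
print(torch.tensor([add_time_ids], dtype=torch.float32).shape)  # torch.Size([1, 6])
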
829
-
830
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
831
- def upcast_vae(self):
832
- dtype = self.vae.dtype
833
- self.vae.to(dtype=torch.float32)
834
- use_torch_2_0_or_xformers = isinstance(
835
- self.vae.decoder.mid_block.attentions[0].processor,
836
- (
837
- AttnProcessor2_0,
838
- XFormersAttnProcessor,
839
- ),
840
- )
841
- # if xformers or torch_2_0 is used attention block does not need
842
- # to be in float32 which can save lots of memory
843
- if use_torch_2_0_or_xformers:
844
- self.vae.post_quant_conv.to(dtype)
845
- self.vae.decoder.conv_in.to(dtype)
846
- self.vae.decoder.mid_block.to(dtype)
847
-
848
- @property
849
- def guidance_scale(self):
850
- return self._guidance_scale
851
-
852
- @property
853
- def clip_skip(self):
854
- return self._clip_skip
855
-
856
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
857
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
858
- # corresponds to doing no classifier free guidance.
859
- @property
860
- def do_classifier_free_guidance(self):
861
- return self._guidance_scale > 1
862
-
863
- @property
864
- def cross_attention_kwargs(self):
865
- return self._cross_attention_kwargs
866
-
867
- @property
868
- def num_timesteps(self):
869
- return self._num_timesteps
870
-
871
- @torch.no_grad()
872
- def __call__(
873
- self,
874
- prompt: Union[str, List[str]] = None,
875
- image: PipelineImageInput = None,
876
- control_image: PipelineImageInput = None,
877
- height: Optional[int] = None,
878
- width: Optional[int] = None,
879
- strength: float = 0.8,
880
- num_inference_steps: int = 50,
881
- guidance_scale: float = 5.0,
882
- negative_prompt: Optional[Union[str, List[str]]] = None,
883
- num_images_per_prompt: Optional[int] = 1,
884
- eta: float = 0.0,
885
- guess_mode: bool = False,
886
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
887
- latents: Optional[torch.Tensor] = None,
888
- prompt_embeds: Optional[torch.Tensor] = None,
889
- negative_prompt_embeds: Optional[torch.Tensor] = None,
890
- pooled_prompt_embeds: Optional[torch.Tensor] = None,
891
- negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
892
- ip_adapter_image: Optional[PipelineImageInput] = None,
893
- ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
894
- output_type: Optional[str] = "pil",
895
- return_dict: bool = True,
896
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
897
- controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
898
- control_guidance_start: Union[float, List[float]] = 0.0,
899
- control_guidance_end: Union[float, List[float]] = 1.0,
900
- original_size: Tuple[int, int] = None,
901
- crops_coords_top_left: Tuple[int, int] = (0, 0),
902
- target_size: Tuple[int, int] = None,
903
- clip_skip: Optional[int] = None,
904
- callback_on_step_end: Optional[
905
- Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
906
- ] = None,
907
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
908
- **kwargs,
909
- ):
910
- r"""
911
- Function invoked when calling the pipeline for generation.
912
-
913
- Args:
914
- prompt (`str` or `List[str]`, *optional*):
915
-                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
916
- instead.
917
- image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
918
- `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
919
- The initial image will be used as the starting point for the image generation process. Can also accept
920
- image latents as `image`, if passing latents directly, it will not be encoded again.
921
- control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
922
- `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
923
- The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
924
- the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also
925
-                 be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If height
926
- and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in
927
- init, images must be passed as a list such that each element of the list can be correctly batched for
928
- input to a single controlnet.
929
- height (`int`, *optional*, defaults to the size of control_image):
930
- The height in pixels of the generated image. Anything below 512 pixels won't work well for
931
- [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
932
- and checkpoints that are not specifically fine-tuned on low resolutions.
933
- width (`int`, *optional*, defaults to the size of control_image):
934
- The width in pixels of the generated image. Anything below 512 pixels won't work well for
935
- [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
936
- and checkpoints that are not specifically fine-tuned on low resolutions.
937
- strength (`float`, *optional*, defaults to 0.8):
938
- Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
939
- starting point and more noise is added the higher the `strength`. The number of denoising steps depends
940
- on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
941
- process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
942
- essentially ignores `image`.
943
- num_inference_steps (`int`, *optional*, defaults to 50):
944
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
945
- expense of slower inference.
946
-             guidance_scale (`float`, *optional*, defaults to 5.0):
947
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
948
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
949
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
950
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
951
- usually at the expense of lower image quality.
952
- negative_prompt (`str` or `List[str]`, *optional*):
953
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
954
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
955
- less than `1`).
956
- num_images_per_prompt (`int`, *optional*, defaults to 1):
957
- The number of images to generate per prompt.
958
- eta (`float`, *optional*, defaults to 0.0):
959
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
960
- [`schedulers.DDIMScheduler`], will be ignored for others.
961
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
962
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
963
- to make generation deterministic.
964
- latents (`torch.Tensor`, *optional*):
965
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
966
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
967
-                 tensor will be generated by sampling using the supplied random `generator`.
968
- prompt_embeds (`torch.Tensor`, *optional*):
969
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
970
- provided, text embeddings will be generated from `prompt` input argument.
971
- negative_prompt_embeds (`torch.Tensor`, *optional*):
972
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
973
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
974
- argument.
975
- pooled_prompt_embeds (`torch.Tensor`, *optional*):
976
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
977
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
978
- negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
979
- Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
980
- weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
981
- input argument.
982
- ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
983
- ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
984
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
985
- IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
986
- contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
987
- provided, embeddings are computed from the `ip_adapter_image` input argument.
988
- output_type (`str`, *optional*, defaults to `"pil"`):
989
-                 The output format of the generated image. Choose between
990
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
991
- return_dict (`bool`, *optional*, defaults to `True`):
992
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
993
- plain tuple.
994
- cross_attention_kwargs (`dict`, *optional*):
995
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
996
- `self.processor` in
997
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
998
- controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
999
- The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
1000
- to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
1001
- corresponding scale as a list.
1002
- control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
1003
- The percentage of total steps at which the controlnet starts applying.
1004
- control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
1005
- The percentage of total steps at which the controlnet stops applying.
1006
- original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1007
- If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
1008
- `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
1009
- explained in section 2.2 of
1010
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1011
- crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1012
- `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
1013
- `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
1014
- `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
1015
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1016
- target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1017
- For most cases, `target_size` should be set to the desired height and width of the generated image. If
1018
- not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
1019
- section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1020
- clip_skip (`int`, *optional*):
1021
- Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1022
- the output of the pre-final layer will be used for computing the prompt embeddings.
1023
- callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
1024
- A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
1025
- each denoising step during inference, with the following arguments: `callback_on_step_end(self:
1026
- DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
1027
- list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
1028
- callback_on_step_end_tensor_inputs (`List`, *optional*):
1029
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1030
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1031
- `._callback_tensor_inputs` attribute of your pipeline class.
1032
-
1033
- Examples:
1034
-
1035
- Returns:
1036
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1037
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`
1038
- containing the output images.
1039
- """
1040
-
1041
- callback = kwargs.pop("callback", None)
1042
- callback_steps = kwargs.pop("callback_steps", None)
1043
-
1044
- if callback is not None:
1045
- deprecate(
1046
- "callback",
1047
- "1.0.0",
1048
- "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1049
- )
1050
- if callback_steps is not None:
1051
- deprecate(
1052
- "callback_steps",
1053
- "1.0.0",
1054
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1055
- )
1056
-
1057
- if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1058
- callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1059
-
1060
- controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
1061
-
1062
- # align format for control guidance
1063
- if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
1064
- control_guidance_start = len(control_guidance_end) * [control_guidance_start]
1065
- elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
1066
- control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1067
- elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
1068
- mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
1069
- control_guidance_start, control_guidance_end = (
1070
- mult * [control_guidance_start],
1071
- mult * [control_guidance_end],
1072
- )
1073
-
1074
- # from IPython import embed; embed()
1075
- # 1. Check inputs. Raise error if not correct
1076
- self.check_inputs(
1077
- prompt,
1078
- control_image,
1079
- strength,
1080
- num_inference_steps,
1081
- callback_steps,
1082
- negative_prompt,
1083
- prompt_embeds,
1084
- negative_prompt_embeds,
1085
- pooled_prompt_embeds,
1086
- negative_pooled_prompt_embeds,
1087
- ip_adapter_image,
1088
- ip_adapter_image_embeds,
1089
- controlnet_conditioning_scale,
1090
- control_guidance_start,
1091
- control_guidance_end,
1092
- callback_on_step_end_tensor_inputs,
1093
- )
1094
-
1095
- self._guidance_scale = guidance_scale
1096
- self._clip_skip = clip_skip
1097
- self._cross_attention_kwargs = cross_attention_kwargs
1098
-
1099
- # 2. Define call parameters
1100
- if prompt is not None and isinstance(prompt, str):
1101
- batch_size = 1
1102
- elif prompt is not None and isinstance(prompt, list):
1103
- batch_size = len(prompt)
1104
- else:
1105
- batch_size = prompt_embeds.shape[0]
1106
-
1107
- device = self._execution_device
1108
-
1109
- if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
1110
- controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
1111
-
1112
- # 3.1. Encode input prompt
1113
- text_encoder_lora_scale = (
1114
- self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1115
- )
1116
- (
1117
- prompt_embeds,
1118
- negative_prompt_embeds,
1119
- pooled_prompt_embeds,
1120
- negative_pooled_prompt_embeds,
1121
- ) = self.encode_prompt(
1122
- prompt,
1123
- device,
1124
- num_images_per_prompt,
1125
- self.do_classifier_free_guidance,
1126
- negative_prompt,
1127
- prompt_embeds=prompt_embeds,
1128
- negative_prompt_embeds=negative_prompt_embeds,
1129
- pooled_prompt_embeds=pooled_prompt_embeds,
1130
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1131
- lora_scale=text_encoder_lora_scale,
1132
- )
1133
-
1134
- # 3.2 Encode ip_adapter_image
1135
- if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1136
- image_embeds = self.prepare_ip_adapter_image_embeds(
1137
- ip_adapter_image,
1138
- ip_adapter_image_embeds,
1139
- device,
1140
- batch_size * num_images_per_prompt,
1141
- self.do_classifier_free_guidance,
1142
- )
1143
-
1144
- # 4. Prepare image and controlnet_conditioning_image
1145
- image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
1146
-
1147
- if isinstance(controlnet, ControlNetModel):
1148
- control_image = self.prepare_control_image(
1149
- image=control_image,
1150
- width=width,
1151
- height=height,
1152
- batch_size=batch_size * num_images_per_prompt,
1153
- num_images_per_prompt=num_images_per_prompt,
1154
- device=device,
1155
- dtype=controlnet.dtype,
1156
- do_classifier_free_guidance=self.do_classifier_free_guidance,
1157
- guess_mode=guess_mode,
1158
- )
1159
- height, width = control_image.shape[-2:]
1160
- elif isinstance(controlnet, MultiControlNetModel):
1161
- control_images = []
1162
-
1163
- for control_image_ in control_image:
1164
- control_image_ = self.prepare_control_image(
1165
- image=control_image_,
1166
- width=width,
1167
- height=height,
1168
- batch_size=batch_size * num_images_per_prompt,
1169
- num_images_per_prompt=num_images_per_prompt,
1170
- device=device,
1171
- dtype=controlnet.dtype,
1172
- do_classifier_free_guidance=self.do_classifier_free_guidance,
1173
- guess_mode=guess_mode,
1174
- )
1175
-
1176
- control_images.append(control_image_)
1177
-
1178
- control_image = control_images
1179
- height, width = control_image[0].shape[-2:]
1180
- else:
1181
- assert False
1182
-
1183
- # 5. Prepare timesteps
1184
- self.scheduler.set_timesteps(num_inference_steps, device=device)
1185
- timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
1186
- latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
1187
- self._num_timesteps = len(timesteps)
1188
-
1189
- # 6. Prepare latent variables
1190
-
1191
- num_channels_latents = self.unet.config.in_channels
1192
- if latents is None:
1193
- if strength >= 1.0:
1194
- latents = self.prepare_latents_t2i(
1195
- batch_size * num_images_per_prompt,
1196
- num_channels_latents,
1197
- height,
1198
- width,
1199
- prompt_embeds.dtype,
1200
- device,
1201
- generator,
1202
- latents,
1203
- )
1204
- else:
1205
- latents = self.prepare_latents(
1206
- image,
1207
- latent_timestep,
1208
- batch_size,
1209
- num_images_per_prompt,
1210
- prompt_embeds.dtype,
1211
- device,
1212
- generator,
1213
- True,
1214
- )
1215
-
1216
-
1217
- # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1218
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1219
-
1220
- # 7.1 Create tensor stating which controlnets to keep
1221
- controlnet_keep = []
1222
- for i in range(len(timesteps)):
1223
- keeps = [
1224
- 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
1225
- for s, e in zip(control_guidance_start, control_guidance_end)
1226
- ]
1227
- controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
1228
-
1229
- # 7.2 Prepare added time ids & embeddings
1230
- if isinstance(control_image, list):
1231
- original_size = original_size or control_image[0].shape[-2:]
1232
- else:
1233
- original_size = original_size or control_image.shape[-2:]
1234
- target_size = target_size or (height, width)
1235
-
1236
- # 7. Prepare added time ids & embeddings
1237
- add_text_embeds = pooled_prompt_embeds
1238
- add_time_ids = self._get_add_time_ids(
1239
- original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
1240
- )
1241
-
1242
- if self.do_classifier_free_guidance:
1243
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1244
- add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1245
- add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
1246
-
1247
- prompt_embeds = prompt_embeds.to(device)
1248
- add_text_embeds = add_text_embeds.to(device)
1249
- add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
1250
-
1251
- # 8. Denoising loop
1252
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1253
- with self.progress_bar(total=num_inference_steps) as progress_bar:
1254
- for i, t in enumerate(timesteps):
1255
- # expand the latents if we are doing classifier free guidance
1256
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1257
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1258
-
1259
- added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1260
-
1261
- # controlnet(s) inference
1262
- if guess_mode and self.do_classifier_free_guidance:
1263
- # Infer ControlNet only for the conditional batch.
1264
- control_model_input = latents
1265
- control_model_input = self.scheduler.scale_model_input(control_model_input, t)
1266
- controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
1267
- controlnet_added_cond_kwargs = {
1268
- "text_embeds": add_text_embeds.chunk(2)[1],
1269
- "time_ids": add_time_ids.chunk(2)[1],
1270
- }
1271
- else:
1272
- control_model_input = latent_model_input
1273
- controlnet_prompt_embeds = prompt_embeds
1274
- controlnet_added_cond_kwargs = added_cond_kwargs
1275
-
1276
- if isinstance(controlnet_keep[i], list):
1277
- cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
1278
- else:
1279
- controlnet_cond_scale = controlnet_conditioning_scale
1280
- if isinstance(controlnet_cond_scale, list):
1281
- controlnet_cond_scale = controlnet_cond_scale[0]
1282
- cond_scale = controlnet_cond_scale * controlnet_keep[i]
1283
-
1284
- down_block_res_samples, mid_block_res_sample = self.controlnet(
1285
- control_model_input,
1286
- t,
1287
- encoder_hidden_states=controlnet_prompt_embeds,
1288
- controlnet_cond=control_image,
1289
- conditioning_scale=cond_scale,
1290
- guess_mode=guess_mode,
1291
- added_cond_kwargs=controlnet_added_cond_kwargs,
1292
- return_dict=False,
1293
- )
1294
-
1295
- if guess_mode and self.do_classifier_free_guidance:
1296
- # Inferred ControlNet only for the conditional batch.
1297
- # To apply the output of ControlNet to both the unconditional and conditional batches,
1298
- # add 0 to the unconditional batch to keep it unchanged.
1299
- down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
1300
- mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
1301
-
1302
- if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1303
- added_cond_kwargs["image_embeds"] = image_embeds
1304
-
1305
- # predict the noise residual
1306
- noise_pred = self.unet(
1307
- latent_model_input,
1308
- t,
1309
- encoder_hidden_states=prompt_embeds,
1310
- cross_attention_kwargs=self.cross_attention_kwargs,
1311
- down_block_additional_residuals=down_block_res_samples,
1312
- mid_block_additional_residual=mid_block_res_sample,
1313
- added_cond_kwargs=added_cond_kwargs,
1314
- return_dict=False,
1315
- )[0]
1316
-
1317
- # perform guidance
1318
- if self.do_classifier_free_guidance:
1319
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1320
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1321
-
1322
- # compute the previous noisy sample x_t -> x_t-1
1323
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1324
-
1325
- # call the callback, if provided
1326
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1327
- progress_bar.update()
1328
- if callback is not None and i % callback_steps == 0:
1329
- step_idx = i // getattr(self.scheduler, "order", 1)
1330
- callback(step_idx, t, latents)
1331
-
1332
- # If we do sequential model offloading, let's offload unet and controlnet
1333
- # manually for max memory savings
1334
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1335
- self.unet.to("cpu")
1336
- self.controlnet.to("cpu")
1337
- torch.cuda.empty_cache()
1338
-
1339
- if not output_type == "latent":
1340
- # make sure the VAE is in float32 mode, as it overflows in float16
1341
- needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1342
-
1343
- if needs_upcasting:
1344
- self.upcast_vae()
1345
- latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1346
-
1347
- latents = latents / self.vae.config.scaling_factor
1348
- image = self.vae.decode(latents, return_dict=False)[0]
1349
-
1350
- # cast back to fp16 if needed
1351
- if needs_upcasting:
1352
- self.vae.to(dtype=torch.float16)
1353
- else:
1354
- image = latents
1355
- return StableDiffusionXLPipelineOutput(images=image)
1356
-
1357
- image = self.image_processor.postprocess(image, output_type=output_type)
1358
-
1359
- # Offload all models
1360
- self.maybe_free_model_hooks()
1361
-
1362
- if not return_dict:
1363
- return (image,)
1364
-
1365
- return StableDiffusionXLPipelineOutput(images=image)
 
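The ControlNet-based SDXL pipeline deleted above gates each ControlNet with a per-step keep mask computed from `control_guidance_start`/`control_guidance_end`, and then combines the unconditional and conditional noise predictions with classifier-free guidance. The snippet below is a minimal, standalone sketch of just those two steps; the function names and tensor shapes are illustrative and are not taken from the removed file.

```python
import torch


def controlnet_keep_schedule(num_steps, starts, ends):
    # Per-step keep masks: 1.0 while the step fraction lies inside [start, end], else 0.0,
    # mirroring the keep-mask loop built before the denoising loop in the deleted pipeline.
    keeps = []
    for i in range(num_steps):
        keeps.append([
            1.0 - float(i / num_steps < s or (i + 1) / num_steps > e)
            for s, e in zip(starts, ends)
        ])
    return keeps


def cfg_combine(noise_pred_uncond, noise_pred_text, guidance_scale):
    # Classifier-free guidance: move from the unconditional prediction toward the
    # text-conditioned one by a factor of guidance_scale (w in Imagen, eq. 2).
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)


if __name__ == "__main__":
    keeps = controlnet_keep_schedule(10, starts=[0.0], ends=[0.5])
    print([k[0] for k in keeps])  # 1.0 for the first five steps, 0.0 afterwards
    uncond = torch.zeros(1, 4, 8, 8)
    cond = torch.ones(1, 4, 8, 8)
    print(cfg_combine(uncond, cond, guidance_scale=5.0).mean().item())  # prints 5.0
```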
kolors/pipelines/pipeline_stable_diffusion_xl_chatglm_256.py DELETED
@@ -1,841 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import sys
15
- import os
16
- sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
17
- from kolors.models.modeling_chatglm import ChatGLMModel
18
- from kolors.models.tokenization_chatglm import ChatGLMTokenizer
19
- import inspect
20
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
21
- import torch
22
- from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
23
- from transformers import XLMRobertaModel, ChineseCLIPTextModel
24
-
25
- from diffusers.image_processor import VaeImageProcessor
26
- from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
27
- from diffusers.models import AutoencoderKL, UNet2DConditionModel
28
- from diffusers.models.attention_processor import (
29
- AttnProcessor2_0,
30
- LoRAAttnProcessor2_0,
31
- LoRAXFormersAttnProcessor,
32
- XFormersAttnProcessor,
33
- )
34
- from diffusers.schedulers import KarrasDiffusionSchedulers
35
- from diffusers.utils import (
36
- is_accelerate_available,
37
- is_accelerate_version,
38
- logging,
39
- replace_example_docstring,
40
- )
41
- try:
42
- from diffusers.utils import randn_tensor
43
- except:
44
- from diffusers.utils.torch_utils import randn_tensor
45
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
46
- from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
47
-
48
-
49
-
50
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
51
-
52
- EXAMPLE_DOC_STRING = """
53
- Examples:
54
- ```py
55
- >>> import torch
56
- >>> from diffusers import StableDiffusionXLPipeline
57
-
58
- >>> pipe = StableDiffusionXLPipeline.from_pretrained(
59
- ... "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16
60
- ... )
61
- >>> pipe = pipe.to("cuda")
62
-
63
- >>> prompt = "a photo of an astronaut riding a horse on mars"
64
- >>> image = pipe(prompt).images[0]
65
- ```
66
- """
67
-
68
-
69
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
70
- def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
71
- """
72
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
73
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
74
- """
75
- std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
76
- std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
77
- # rescale the results from guidance (fixes overexposure)
78
- noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
79
- # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
80
- noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
81
- return noise_cfg
82
-
83
-
84
- class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
85
- r"""
86
- Pipeline for text-to-image generation using Stable Diffusion XL.
87
-
88
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
89
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
90
-
91
- In addition the pipeline inherits the following loading methods:
92
- - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
93
- - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
94
- - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
95
-
96
- as well as the following saving methods:
97
- - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
98
-
99
- Args:
100
- vae ([`AutoencoderKL`]):
101
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
102
- text_encoder ([`CLIPTextModel`]):
103
- Frozen text-encoder. Stable Diffusion XL uses the text portion of
104
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
105
- the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
106
-
107
- tokenizer (`CLIPTokenizer`):
108
- Tokenizer of class
109
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
110
-
111
- unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
112
- scheduler ([`SchedulerMixin`]):
113
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
114
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
115
- """
116
-
117
- def __init__(
118
- self,
119
- vae: AutoencoderKL,
120
- text_encoder: ChatGLMModel,
121
- tokenizer: ChatGLMTokenizer,
122
- unet: UNet2DConditionModel,
123
- scheduler: KarrasDiffusionSchedulers,
124
- force_zeros_for_empty_prompt: bool = True,
125
- ):
126
- super().__init__()
127
-
128
- self.register_modules(
129
- vae=vae,
130
- text_encoder=text_encoder,
131
- tokenizer=tokenizer,
132
- unet=unet,
133
- scheduler=scheduler,
134
- )
135
- self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
136
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
137
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
138
- self.default_sample_size = self.unet.config.sample_size
139
-
140
- # self.watermark = StableDiffusionXLWatermarker()
141
-
142
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
143
- def enable_vae_slicing(self):
144
- r"""
145
- Enable sliced VAE decoding.
146
-
147
- When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
148
- steps. This is useful to save some memory and allow larger batch sizes.
149
- """
150
- self.vae.enable_slicing()
151
-
152
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
153
- def disable_vae_slicing(self):
154
- r"""
155
- Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
156
- computing decoding in one step.
157
- """
158
- self.vae.disable_slicing()
159
-
160
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
161
- def enable_vae_tiling(self):
162
- r"""
163
- Enable tiled VAE decoding.
164
-
165
- When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
166
- several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
167
- """
168
- self.vae.enable_tiling()
169
-
170
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
171
- def disable_vae_tiling(self):
172
- r"""
173
- Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
174
- computing decoding in one step.
175
- """
176
- self.vae.disable_tiling()
177
-
178
- def enable_sequential_cpu_offload(self, gpu_id=0):
179
- r"""
180
- Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
181
- text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
182
- `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
183
- Note that offloading happens on a submodule basis. Memory savings are higher than with
184
- `enable_model_cpu_offload`, but performance is lower.
185
- """
186
- if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
187
- from accelerate import cpu_offload
188
- else:
189
- raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
190
-
191
- device = torch.device(f"cuda:{gpu_id}")
192
-
193
- if self.device.type != "cpu":
194
- self.to("cpu", silence_dtype_warnings=True)
195
- torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
196
-
197
- for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
198
- cpu_offload(cpu_offloaded_model, device)
199
-
200
- def enable_model_cpu_offload(self, gpu_id=0):
201
- r"""
202
- Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
203
- to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
204
- method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
205
- `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
206
- """
207
- if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
208
- from accelerate import cpu_offload_with_hook
209
- else:
210
- raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
211
-
212
- device = torch.device(f"cuda:{gpu_id}")
213
-
214
- if self.device.type != "cpu":
215
- self.to("cpu", silence_dtype_warnings=True)
216
- torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
217
-
218
- model_sequence = (
219
- [self.text_encoder]
220
- )
221
- model_sequence.extend([self.unet, self.vae])
222
-
223
- hook = None
224
- for cpu_offloaded_model in model_sequence:
225
- _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
226
-
227
- # We'll offload the last model manually.
228
- self.final_offload_hook = hook
229
-
230
- @property
231
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
232
- def _execution_device(self):
233
- r"""
234
- Returns the device on which the pipeline's models will be executed. After calling
235
- `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
236
- hooks.
237
- """
238
- if not hasattr(self.unet, "_hf_hook"):
239
- return self.device
240
- for module in self.unet.modules():
241
- if (
242
- hasattr(module, "_hf_hook")
243
- and hasattr(module._hf_hook, "execution_device")
244
- and module._hf_hook.execution_device is not None
245
- ):
246
- return torch.device(module._hf_hook.execution_device)
247
- return self.device
248
-
249
- def encode_prompt(
250
- self,
251
- prompt,
252
- device: Optional[torch.device] = None,
253
- num_images_per_prompt: int = 1,
254
- do_classifier_free_guidance: bool = True,
255
- negative_prompt=None,
256
- prompt_embeds: Optional[torch.FloatTensor] = None,
257
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
258
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
259
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
260
- lora_scale: Optional[float] = None,
261
- ):
262
- r"""
263
- Encodes the prompt into text encoder hidden states.
264
-
265
- Args:
266
- prompt (`str` or `List[str]`, *optional*):
267
- prompt to be encoded
268
- device: (`torch.device`):
269
- torch device
270
- num_images_per_prompt (`int`):
271
- number of images that should be generated per prompt
272
- do_classifier_free_guidance (`bool`):
273
- whether to use classifier free guidance or not
274
- negative_prompt (`str` or `List[str]`, *optional*):
275
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
276
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
277
- less than `1`).
278
- prompt_embeds (`torch.FloatTensor`, *optional*):
279
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
280
- provided, text embeddings will be generated from `prompt` input argument.
281
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
282
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
283
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
284
- argument.
285
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
286
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
287
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
288
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
289
- Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
290
- weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
291
- input argument.
292
- lora_scale (`float`, *optional*):
293
- A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
294
- """
295
- # from IPython import embed; embed(); exit()
296
- device = device or self._execution_device
297
-
298
- # set lora scale so that monkey patched LoRA
299
- # function of text encoder can correctly access it
300
- if lora_scale is not None and isinstance(self, LoraLoaderMixin):
301
- self._lora_scale = lora_scale
302
-
303
- if prompt is not None and isinstance(prompt, str):
304
- batch_size = 1
305
- elif prompt is not None and isinstance(prompt, list):
306
- batch_size = len(prompt)
307
- else:
308
- batch_size = prompt_embeds.shape[0]
309
-
310
- # Define tokenizers and text encoders
311
- tokenizers = [self.tokenizer]
312
- text_encoders = [self.text_encoder]
313
-
314
- if prompt_embeds is None:
315
- # textual inversion: process multi-vector tokens if necessary
316
- prompt_embeds_list = []
317
- for tokenizer, text_encoder in zip(tokenizers, text_encoders):
318
- if isinstance(self, TextualInversionLoaderMixin):
319
- prompt = self.maybe_convert_prompt(prompt, tokenizer)
320
-
321
- text_inputs = tokenizer(
322
- prompt,
323
- padding="max_length",
324
- max_length=256,
325
- truncation=True,
326
- return_tensors="pt",
327
- ).to('cuda')
328
- output = text_encoder(
329
- input_ids=text_inputs['input_ids'] ,
330
- attention_mask=text_inputs['attention_mask'],
331
- position_ids=text_inputs['position_ids'],
332
- output_hidden_states=True)
333
- prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
334
- pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() # [batch_size, 4096]
335
- bs_embed, seq_len, _ = prompt_embeds.shape
336
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
337
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
338
-
339
- prompt_embeds_list.append(prompt_embeds)
340
-
341
- # prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
342
- prompt_embeds = prompt_embeds_list[0]
343
-
344
- # get unconditional embeddings for classifier free guidance
345
- zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
346
- if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
347
- negative_prompt_embeds = torch.zeros_like(prompt_embeds)
348
- negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
349
- elif do_classifier_free_guidance and negative_prompt_embeds is None:
350
- # negative_prompt = negative_prompt or ""
351
- uncond_tokens: List[str]
352
- if negative_prompt is None:
353
- uncond_tokens = [""] * batch_size
354
- elif prompt is not None and type(prompt) is not type(negative_prompt):
355
- raise TypeError(
356
- f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
357
- f" {type(prompt)}."
358
- )
359
- elif isinstance(negative_prompt, str):
360
- uncond_tokens = [negative_prompt]
361
- elif batch_size != len(negative_prompt):
362
- raise ValueError(
363
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
364
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
365
- " the batch size of `prompt`."
366
- )
367
- else:
368
- uncond_tokens = negative_prompt
369
-
370
- negative_prompt_embeds_list = []
371
- for tokenizer, text_encoder in zip(tokenizers, text_encoders):
372
- # textual inversion: process multi-vector tokens if necessary
373
- if isinstance(self, TextualInversionLoaderMixin):
374
- uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
375
-
376
- max_length = prompt_embeds.shape[1]
377
- uncond_input = tokenizer(
378
- uncond_tokens,
379
- padding="max_length",
380
- max_length=max_length,
381
- truncation=True,
382
- return_tensors="pt",
383
- ).to('cuda')
384
- output = text_encoder(
385
- input_ids=uncond_input['input_ids'] ,
386
- attention_mask=uncond_input['attention_mask'],
387
- position_ids=uncond_input['position_ids'],
388
- output_hidden_states=True)
389
- negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
390
- negative_pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() # [batch_size, 4096]
391
-
392
- if do_classifier_free_guidance:
393
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
394
- seq_len = negative_prompt_embeds.shape[1]
395
-
396
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
397
-
398
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
399
- negative_prompt_embeds = negative_prompt_embeds.view(
400
- batch_size * num_images_per_prompt, seq_len, -1
401
- )
402
-
403
- # For classifier free guidance, we need to do two forward passes.
404
- # Here we concatenate the unconditional and text embeddings into a single batch
405
- # to avoid doing two forward passes
406
-
407
- negative_prompt_embeds_list.append(negative_prompt_embeds)
408
-
409
- # negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
410
- negative_prompt_embeds = negative_prompt_embeds_list[0]
411
-
412
- bs_embed = pooled_prompt_embeds.shape[0]
413
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
414
- bs_embed * num_images_per_prompt, -1
415
- )
416
- if do_classifier_free_guidance:
417
- negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
418
- bs_embed * num_images_per_prompt, -1
419
- )
420
-
421
- return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
422
-
423
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
424
- def prepare_extra_step_kwargs(self, generator, eta):
425
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
426
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
427
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
428
- # and should be between [0, 1]
429
-
430
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
431
- extra_step_kwargs = {}
432
- if accepts_eta:
433
- extra_step_kwargs["eta"] = eta
434
-
435
- # check if the scheduler accepts generator
436
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
437
- if accepts_generator:
438
- extra_step_kwargs["generator"] = generator
439
- return extra_step_kwargs
440
-
441
- def check_inputs(
442
- self,
443
- prompt,
444
- height,
445
- width,
446
- callback_steps,
447
- negative_prompt=None,
448
- prompt_embeds=None,
449
- negative_prompt_embeds=None,
450
- pooled_prompt_embeds=None,
451
- negative_pooled_prompt_embeds=None,
452
- ):
453
- if height % 8 != 0 or width % 8 != 0:
454
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
455
-
456
- if (callback_steps is None) or (
457
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
458
- ):
459
- raise ValueError(
460
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
461
- f" {type(callback_steps)}."
462
- )
463
-
464
- if prompt is not None and prompt_embeds is not None:
465
- raise ValueError(
466
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
467
- " only forward one of the two."
468
- )
469
- elif prompt is None and prompt_embeds is None:
470
- raise ValueError(
471
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
472
- )
473
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
474
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
475
-
476
- if negative_prompt is not None and negative_prompt_embeds is not None:
477
- raise ValueError(
478
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
479
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
480
- )
481
-
482
- if prompt_embeds is not None and negative_prompt_embeds is not None:
483
- if prompt_embeds.shape != negative_prompt_embeds.shape:
484
- raise ValueError(
485
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
486
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
487
- f" {negative_prompt_embeds.shape}."
488
- )
489
-
490
- if prompt_embeds is not None and pooled_prompt_embeds is None:
491
- raise ValueError(
492
- "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
493
- )
494
-
495
- if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
496
- raise ValueError(
497
- "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
498
- )
499
-
500
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
501
- def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
502
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
503
- if isinstance(generator, list) and len(generator) != batch_size:
504
- raise ValueError(
505
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
506
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
507
- )
508
-
509
- if latents is None:
510
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
511
- else:
512
- latents = latents.to(device)
513
-
514
- # scale the initial noise by the standard deviation required by the scheduler
515
- latents = latents * self.scheduler.init_noise_sigma
516
- return latents
517
-
518
- def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
519
- add_time_ids = list(original_size + crops_coords_top_left + target_size)
520
-
521
- passed_add_embed_dim = (
522
- self.unet.config.addition_time_embed_dim * len(add_time_ids) + 4096
523
- )
524
- expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
525
-
526
- if expected_add_embed_dim != passed_add_embed_dim:
527
- raise ValueError(
528
- f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
529
- )
530
-
531
- add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
532
- return add_time_ids
533
-
534
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
535
- def upcast_vae(self):
536
- dtype = self.vae.dtype
537
- self.vae.to(dtype=torch.float32)
538
- use_torch_2_0_or_xformers = isinstance(
539
- self.vae.decoder.mid_block.attentions[0].processor,
540
- (
541
- AttnProcessor2_0,
542
- XFormersAttnProcessor,
543
- LoRAXFormersAttnProcessor,
544
- LoRAAttnProcessor2_0,
545
- ),
546
- )
547
- # if xformers or torch_2_0 is used attention block does not need
548
- # to be in float32 which can save lots of memory
549
- if use_torch_2_0_or_xformers:
550
- self.vae.post_quant_conv.to(dtype)
551
- self.vae.decoder.conv_in.to(dtype)
552
- self.vae.decoder.mid_block.to(dtype)
553
-
554
- @torch.no_grad()
555
- @replace_example_docstring(EXAMPLE_DOC_STRING)
556
- def __call__(
557
- self,
558
- prompt: Union[str, List[str]] = None,
559
- height: Optional[int] = None,
560
- width: Optional[int] = None,
561
- num_inference_steps: int = 50,
562
- denoising_end: Optional[float] = None,
563
- guidance_scale: float = 5.0,
564
- negative_prompt: Optional[Union[str, List[str]]] = None,
565
- num_images_per_prompt: Optional[int] = 1,
566
- eta: float = 0.0,
567
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
568
- latents: Optional[torch.FloatTensor] = None,
569
- prompt_embeds: Optional[torch.FloatTensor] = None,
570
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
571
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
572
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
573
- output_type: Optional[str] = "pil",
574
- return_dict: bool = True,
575
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
576
- callback_steps: int = 1,
577
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
578
- guidance_rescale: float = 0.0,
579
- original_size: Optional[Tuple[int, int]] = None,
580
- crops_coords_top_left: Tuple[int, int] = (0, 0),
581
- target_size: Optional[Tuple[int, int]] = None,
582
- use_dynamic_threshold: Optional[bool] = False,
583
- ):
584
- r"""
585
- Function invoked when calling the pipeline for generation.
586
-
587
- Args:
588
- prompt (`str` or `List[str]`, *optional*):
589
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
590
- instead.
591
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
592
- The height in pixels of the generated image.
593
- width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
594
- The width in pixels of the generated image.
595
- num_inference_steps (`int`, *optional*, defaults to 50):
596
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
597
- expense of slower inference.
598
- denoising_end (`float`, *optional*):
599
- When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
600
- completed before it is intentionally prematurely terminated. For instance, if denoising_end is set to
601
- 0.7 and `num_inference_steps` is fixed at 50, the process will execute only 35 (i.e., 0.7 * 50) denoising steps. This is ideally used when the pipeline forms part of a "Mixture of
602
- Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
603
- Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
604
- guidance_scale (`float`, *optional*, defaults to 5.0):
605
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
606
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
607
- 1`. A higher guidance scale encourages images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
608
- negative_prompt (`str` or `List[str]`, *optional*):
609
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
610
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
611
- less than `1`).
612
- num_images_per_prompt (`int`, *optional*, defaults to 1):
613
- The number of images to generate per prompt.
614
- eta (`float`, *optional*, defaults to 0.0):
615
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
616
- [`schedulers.DDIMScheduler`], will be ignored for others.
617
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
618
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
619
- to make generation deterministic.
620
- latents (`torch.FloatTensor`, *optional*):
621
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
622
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
623
- tensor will be generated by sampling using the supplied random `generator`.
624
- prompt_embeds (`torch.FloatTensor`, *optional*):
625
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
626
- provided, text embeddings will be generated from `prompt` input argument.
627
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
628
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
629
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
630
- argument.
631
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
632
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
633
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
634
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative pooled text embeddings.
635
- output_type (`str`, *optional*, defaults to `"pil"`):
636
- The output format of the generated image. Choose between
637
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
638
- return_dict (`bool`, *optional*, defaults to `True`):
639
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a plain tuple.
640
- callback (`Callable`, *optional*):
641
- A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
642
- callback_steps (`int`, *optional*, defaults to 1):
643
- The frequency at which the `callback` function will be called. If not specified, the callback will be
644
- called at every step.
645
- cross_attention_kwargs (`dict`, *optional*):
646
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
647
- `self.processor` in
648
- [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
649
- guidance_rescale (`float`, *optional*, defaults to 0.0):
650
- Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
651
- Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
652
- [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
653
- Guidance rescale factor should fix overexposure when using zero terminal SNR.
654
- original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
655
- TODO
656
- crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
657
- TODO
658
- target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
659
- TODO
660
-
661
- Examples:
662
-
663
- Returns:
664
- [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
665
- [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
666
- `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
667
- element is a list of `bool`s denoting whether the corresponding generated image likely represents
668
- "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
669
- """
670
- # 0. Default height and width to unet
671
- height = height or self.default_sample_size * self.vae_scale_factor
672
- width = width or self.default_sample_size * self.vae_scale_factor
673
-
674
- original_size = original_size or (height, width)
675
- target_size = target_size or (height, width)
676
-
677
- # 1. Check inputs. Raise error if not correct
678
- self.check_inputs(
679
- prompt,
680
- height,
681
- width,
682
- callback_steps,
683
- negative_prompt,
684
- prompt_embeds,
685
- negative_prompt_embeds,
686
- pooled_prompt_embeds,
687
- negative_pooled_prompt_embeds,
688
- )
689
-
690
- # 2. Define call parameters
691
- if prompt is not None and isinstance(prompt, str):
692
- batch_size = 1
693
- elif prompt is not None and isinstance(prompt, list):
694
- batch_size = len(prompt)
695
- else:
696
- batch_size = prompt_embeds.shape[0]
697
-
698
- device = self._execution_device
699
-
700
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
701
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
702
- # corresponds to doing no classifier free guidance.
703
- do_classifier_free_guidance = guidance_scale > 1.0
704
-
705
- # 3. Encode input prompt
706
- text_encoder_lora_scale = (
707
- cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
708
- )
709
- (
710
- prompt_embeds,
711
- negative_prompt_embeds,
712
- pooled_prompt_embeds,
713
- negative_pooled_prompt_embeds,
714
- ) = self.encode_prompt(
715
- prompt,
716
- device,
717
- num_images_per_prompt,
718
- do_classifier_free_guidance,
719
- negative_prompt,
720
- prompt_embeds=prompt_embeds,
721
- negative_prompt_embeds=negative_prompt_embeds,
722
- pooled_prompt_embeds=pooled_prompt_embeds,
723
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
724
- lora_scale=text_encoder_lora_scale,
725
- )
726
-
727
- # 4. Prepare timesteps
728
- self.scheduler.set_timesteps(num_inference_steps, device=device)
729
-
730
- timesteps = self.scheduler.timesteps
731
-
732
- # 5. Prepare latent variables
733
- num_channels_latents = self.unet.config.in_channels
734
- latents = self.prepare_latents(
735
- batch_size * num_images_per_prompt,
736
- num_channels_latents,
737
- height,
738
- width,
739
- prompt_embeds.dtype,
740
- device,
741
- generator,
742
- latents,
743
- )
744
-
745
- # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
746
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
747
-
748
- # 7. Prepare added time ids & embeddings
749
- add_text_embeds = pooled_prompt_embeds
750
- add_time_ids = self._get_add_time_ids(
751
- original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
752
- )
753
-
754
- if do_classifier_free_guidance:
755
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
756
- add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
757
- add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
758
-
759
- prompt_embeds = prompt_embeds.to(device)
760
- add_text_embeds = add_text_embeds.to(device)
761
- add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
762
-
763
- # 8. Denoising loop
764
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
765
-
766
- # 7.1 Apply denoising_end
767
- if denoising_end is not None:
768
- num_inference_steps = int(round(denoising_end * num_inference_steps))
769
- timesteps = timesteps[: num_warmup_steps + self.scheduler.order * num_inference_steps]
770
-
771
- with self.progress_bar(total=num_inference_steps) as progress_bar:
772
- for i, t in enumerate(timesteps):
773
- # expand the latents if we are doing classifier free guidance
774
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
775
-
776
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
777
-
778
- # predict the noise residual
779
- added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
780
- noise_pred = self.unet(
781
- latent_model_input,
782
- t,
783
- encoder_hidden_states=prompt_embeds,
784
- cross_attention_kwargs=cross_attention_kwargs,
785
- added_cond_kwargs=added_cond_kwargs,
786
- return_dict=False,
787
- )[0]
788
-
789
- # perform guidance
790
- if do_classifier_free_guidance:
791
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
792
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
793
- if use_dynamic_threshold:
794
- DynamicThresh = DynThresh(maxSteps=num_inference_steps, experiment_mode=0)
795
- noise_pred = DynamicThresh.dynthresh(noise_pred_text,
796
- noise_pred_uncond,
797
- guidance_scale,
798
- None)
799
-
800
- if do_classifier_free_guidance and guidance_rescale > 0.0:
801
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
802
- noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
803
-
804
- # compute the previous noisy sample x_t -> x_t-1
805
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
806
-
807
- # call the callback, if provided
808
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
809
- progress_bar.update()
810
- if callback is not None and i % callback_steps == 0:
811
- callback(i, t, latents)
812
-
813
- # make sure the VAE is in float32 mode, as it overflows in float16
814
- # torch.cuda.empty_cache()
815
- if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
816
- self.upcast_vae()
817
- latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
818
-
819
-
820
- if not output_type == "latent":
821
- latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
822
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
823
- else:
824
- image = latents
825
- return StableDiffusionXLPipelineOutput(images=image)
826
-
827
- # image = self.watermark.apply_watermark(image)
828
- image = self.image_processor.postprocess(image, output_type=output_type)
829
-
830
- # Offload last model to CPU
831
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
832
- self.final_offload_hook.offload()
833
-
834
- if not return_dict:
835
- return (image,)
836
-
837
- return StableDiffusionXLPipelineOutput(images=image)
838
-
839
-
840
- if __name__ == "__main__":
841
- pass
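Editor's note: the denoising loop in the deleted pipeline above combines the unconditional and text-conditioned noise predictions via classifier-free guidance and optionally rescales the result per Sec. 3.4 of arXiv:2305.08891. The sketch below is a minimal, standalone illustration of that arithmetic on dummy tensors; it is not part of the diff, the function name is illustrative, and the `DynThresh` dynamic-thresholding path is omitted because its implementation is not shown here.

```py
import torch

def cfg_combine(noise_uncond, noise_text, guidance_scale, guidance_rescale=0.0):
    # classifier-free guidance: push the prediction away from the unconditional branch
    noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
    if guidance_rescale > 0.0:
        # rescale toward the std of the text-conditioned prediction (fixes overexposure)
        std_text = noise_text.std(dim=list(range(1, noise_text.ndim)), keepdim=True)
        std_cfg = noise_pred.std(dim=list(range(1, noise_pred.ndim)), keepdim=True)
        rescaled = noise_pred * (std_text / std_cfg)
        noise_pred = guidance_rescale * rescaled + (1 - guidance_rescale) * noise_pred
    return noise_pred

# dummy latents-shaped tensors: batch 1, 4 channels, 64x64
uncond = torch.randn(1, 4, 64, 64)
text = torch.randn(1, 4, 64, 64)
out = cfg_combine(uncond, text, guidance_scale=5.0, guidance_rescale=0.7)
print(out.shape)  # torch.Size([1, 4, 64, 64])
```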
kolors/pipelines/pipeline_stable_diffusion_xl_chatglm_256_inpainting.py DELETED
@@ -1,1790 +0,0 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import inspect
16
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
-
18
- import numpy as np
19
- import PIL.Image
20
- import torch
21
- from transformers import (
22
- CLIPImageProcessor,
23
- CLIPTextModel,
24
- CLIPTextModelWithProjection,
25
- CLIPTokenizer,
26
- CLIPVisionModelWithProjection,
27
- )
28
-
29
- from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
30
- from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
31
- from diffusers.loaders import (
32
- FromSingleFileMixin,
33
- IPAdapterMixin,
34
- StableDiffusionXLLoraLoaderMixin,
35
- TextualInversionLoaderMixin,
36
- )
37
- from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
38
- from diffusers.models.attention_processor import (
39
- AttnProcessor2_0,
40
- LoRAAttnProcessor2_0,
41
- LoRAXFormersAttnProcessor,
42
- XFormersAttnProcessor,
43
- )
44
- from diffusers.models.lora import adjust_lora_scale_text_encoder
45
- from diffusers.schedulers import KarrasDiffusionSchedulers
46
- from diffusers.utils import (
47
- USE_PEFT_BACKEND,
48
- deprecate,
49
- is_invisible_watermark_available,
50
- is_torch_xla_available,
51
- logging,
52
- replace_example_docstring,
53
- scale_lora_layers,
54
- unscale_lora_layers,
55
- )
56
- from diffusers.utils.torch_utils import randn_tensor
57
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
58
- from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
59
-
60
-
61
- if is_invisible_watermark_available():
62
- from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
63
-
64
- if is_torch_xla_available():
65
- import torch_xla.core.xla_model as xm
66
-
67
- XLA_AVAILABLE = True
68
- else:
69
- XLA_AVAILABLE = False
70
-
71
-
72
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
73
-
74
-
75
- EXAMPLE_DOC_STRING = """
76
- Examples:
77
- ```py
78
- >>> import torch
79
- >>> from diffusers import StableDiffusionXLInpaintPipeline
80
- >>> from diffusers.utils import load_image
81
-
82
- >>> pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
83
- ... "stabilityai/stable-diffusion-xl-base-1.0",
84
- ... torch_dtype=torch.float16,
85
- ... variant="fp16",
86
- ... use_safetensors=True,
87
- ... )
88
- >>> pipe.to("cuda")
89
-
90
- >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
91
- >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
92
-
93
- >>> init_image = load_image(img_url).convert("RGB")
94
- >>> mask_image = load_image(mask_url).convert("RGB")
95
-
96
- >>> prompt = "A majestic tiger sitting on a bench"
97
- >>> image = pipe(
98
- ... prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80
99
- ... ).images[0]
100
- ```
101
- """
102
-
103
-
104
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
105
- def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
106
- """
107
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
108
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
109
- """
110
- std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
111
- std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
112
- # rescale the results from guidance (fixes overexposure)
113
- noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
114
- # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
115
- noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
116
- return noise_cfg
117
-
118
-
119
- def mask_pil_to_torch(mask, height, width):
120
- # preprocess mask
121
- if isinstance(mask, (PIL.Image.Image, np.ndarray)):
122
- mask = [mask]
123
-
124
- if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
125
- mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
126
- mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
127
- mask = mask.astype(np.float32) / 255.0
128
- elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
129
- mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
130
-
131
- mask = torch.from_numpy(mask)
132
- return mask
133
-
134
-
135
- def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
136
- """
137
- Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
138
- converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
139
- ``image`` and ``1`` for the ``mask``.
140
-
141
- The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
142
- binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
143
-
144
- Args:
145
- image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
146
- It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
147
- ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
148
- mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
149
- It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
150
- ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
151
-
152
-
153
- Raises:
154
- ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
155
- should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
156
- TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
157
- (or the other way around).
158
-
159
- Returns:
160
- tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
161
- dimensions: ``batch x channels x height x width``.
162
- """
163
-
164
- # TODO(Yiyi) - need to clean this up later
165
- deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
166
- deprecate(
167
- "prepare_mask_and_masked_image",
168
- "0.30.0",
169
- deprecation_message,
170
- )
171
- if image is None:
172
- raise ValueError("`image` input cannot be undefined.")
173
-
174
- if mask is None:
175
- raise ValueError("`mask_image` input cannot be undefined.")
176
-
177
- if isinstance(image, torch.Tensor):
178
- if not isinstance(mask, torch.Tensor):
179
- mask = mask_pil_to_torch(mask, height, width)
180
-
181
- if image.ndim == 3:
182
- image = image.unsqueeze(0)
183
-
184
- # Batch and add channel dim for single mask
185
- if mask.ndim == 2:
186
- mask = mask.unsqueeze(0).unsqueeze(0)
187
-
188
- # Batch single mask or add channel dim
189
- if mask.ndim == 3:
190
- # Single batched mask, no channel dim or single mask not batched but channel dim
191
- if mask.shape[0] == 1:
192
- mask = mask.unsqueeze(0)
193
-
194
- # Batched masks no channel dim
195
- else:
196
- mask = mask.unsqueeze(1)
197
-
198
- assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
199
- # assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
200
- assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
201
-
202
- # Check image is in [-1, 1]
203
- # if image.min() < -1 or image.max() > 1:
204
- # raise ValueError("Image should be in [-1, 1] range")
205
-
206
- # Check mask is in [0, 1]
207
- if mask.min() < 0 or mask.max() > 1:
208
- raise ValueError("Mask should be in [0, 1] range")
209
-
210
- # Binarize mask
211
- mask[mask < 0.5] = 0
212
- mask[mask >= 0.5] = 1
213
-
214
- # Image as float32
215
- image = image.to(dtype=torch.float32)
216
- elif isinstance(mask, torch.Tensor):
217
- raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
218
- else:
219
- # preprocess image
220
- if isinstance(image, (PIL.Image.Image, np.ndarray)):
221
- image = [image]
222
- if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
223
- # resize all images w.r.t. the passed height and width
224
- image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
225
- image = [np.array(i.convert("RGB"))[None, :] for i in image]
226
- image = np.concatenate(image, axis=0)
227
- elif isinstance(image, list) and isinstance(image[0], np.ndarray):
228
- image = np.concatenate([i[None, :] for i in image], axis=0)
229
-
230
- image = image.transpose(0, 3, 1, 2)
231
- image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
232
-
233
- mask = mask_pil_to_torch(mask, height, width)
234
- mask[mask < 0.5] = 0
235
- mask[mask >= 0.5] = 1
236
-
237
- if image.shape[1] == 4:
238
- # images are in latent space and thus can't
239
- # be masked, so set masked_image to None;
240
- # we assume that the checkpoint is not an inpainting
241
- # checkpoint. TODO(Yiyi) - need to clean this up later
242
- masked_image = None
243
- else:
244
- masked_image = image * (mask < 0.5)
245
-
246
- # n.b. ensure backwards compatibility as old function does not return image
247
- if return_image:
248
- return mask, masked_image, image
249
-
250
- return mask, masked_image
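Editor's note: the binarization and masking convention used by `prepare_mask_and_masked_image` above (deprecated in favour of `VaeImageProcessor.preprocess`) is easy to check on small dummy tensors; the snippet below is a sketch, not part of the diff.

```py
import torch

# image in [-1, 1], shape B x C x H x W; mask in [0, 1], shape B x 1 x H x W
image = torch.rand(1, 3, 8, 8) * 2 - 1
mask = torch.zeros(1, 1, 8, 8)
mask[..., 2:6, 2:6] = 0.8            # region to inpaint

# binarize exactly as above: values >= 0.5 mean "repaint this pixel"
mask = (mask >= 0.5).float()

# the masked image keeps only the pixels that are NOT repainted
masked_image = image * (mask < 0.5)

print(mask.sum().item())                                # 16.0 pixels flagged for inpainting
print(masked_image[:, :, 2:6, 2:6].abs().max().item())  # 0.0 inside the masked region
```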
251
-
252
-
253
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
254
- def retrieve_latents(
255
- encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
256
- ):
257
- if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
258
- return encoder_output.latent_dist.sample(generator)
259
- elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
260
- return encoder_output.latent_dist.mode()
261
- elif hasattr(encoder_output, "latents"):
262
- return encoder_output.latents
263
- else:
264
- raise AttributeError("Could not access latents of provided encoder_output")
265
-
266
-
267
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
268
- def retrieve_timesteps(
269
- scheduler,
270
- num_inference_steps: Optional[int] = None,
271
- device: Optional[Union[str, torch.device]] = None,
272
- timesteps: Optional[List[int]] = None,
273
- sigmas: Optional[List[float]] = None,
274
- **kwargs,
275
- ):
276
- """
277
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
278
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
279
-
280
- Args:
281
- scheduler (`SchedulerMixin`):
282
- The scheduler to get timesteps from.
283
- num_inference_steps (`int`):
284
- The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
285
- must be `None`.
286
- device (`str` or `torch.device`, *optional*):
287
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
288
- timesteps (`List[int]`, *optional*):
289
- Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
290
- `num_inference_steps` and `sigmas` must be `None`.
291
- sigmas (`List[float]`, *optional*):
292
- Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
293
- `num_inference_steps` and `timesteps` must be `None`.
294
-
295
- Returns:
296
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
297
- second element is the number of inference steps.
298
- """
299
- if timesteps is not None and sigmas is not None:
300
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
301
- if timesteps is not None:
302
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
303
- if not accepts_timesteps:
304
- raise ValueError(
305
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
306
- f" timestep schedules. Please check whether you are using the correct scheduler."
307
- )
308
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
309
- timesteps = scheduler.timesteps
310
- num_inference_steps = len(timesteps)
311
- elif sigmas is not None:
312
- accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
313
- if not accept_sigmas:
314
- raise ValueError(
315
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
316
- f" sigmas schedules. Please check whether you are using the correct scheduler."
317
- )
318
- scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
319
- timesteps = scheduler.timesteps
320
- num_inference_steps = len(timesteps)
321
- else:
322
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
323
- timesteps = scheduler.timesteps
324
- return timesteps, num_inference_steps
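Editor's note: `retrieve_timesteps` above dispatches between `num_inference_steps`, custom `timesteps`, and custom `sigmas`. A hedged usage sketch of the plain path follows; it assumes the helper defined above is in scope, and the scheduler choice is illustrative.

```py
from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")

# plain path: let the scheduler build its own 25-step spacing
timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=25, device="cpu")
print(n, timesteps[:3])  # 25 and the first three (descending) timesteps
```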
325
-
326
-
327
- class StableDiffusionXLInpaintPipeline(
328
- DiffusionPipeline,
329
- StableDiffusionMixin,
330
- TextualInversionLoaderMixin,
331
- StableDiffusionXLLoraLoaderMixin,
332
- FromSingleFileMixin,
333
- IPAdapterMixin,
334
- ):
335
- r"""
336
- Pipeline for text-guided image inpainting using Stable Diffusion XL.
337
-
338
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
339
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
340
-
341
- The pipeline also inherits the following loading methods:
342
- - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
343
- - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
344
- - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
345
- - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
346
- - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
347
-
348
- Args:
349
- vae ([`AutoencoderKL`]):
350
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
351
- text_encoder ([`CLIPTextModel`]):
352
- Frozen text-encoder. Stable Diffusion XL uses the text portion of
353
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
354
- the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
355
- text_encoder_2 ([`CLIPTextModelWithProjection`]):
356
- Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
357
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
358
- specifically the
359
- [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
360
- variant.
361
- tokenizer (`CLIPTokenizer`):
362
- Tokenizer of class
363
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
364
- tokenizer_2 (`CLIPTokenizer`):
365
- Second Tokenizer of class
366
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
367
- unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
368
- scheduler ([`SchedulerMixin`]):
369
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
370
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
371
- requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
372
- Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the config
373
- of `stabilityai/stable-diffusion-xl-refiner-1-0`.
374
- force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
375
- Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
376
- `stabilityai/stable-diffusion-xl-base-1-0`.
377
- add_watermarker (`bool`, *optional*):
378
- Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
379
- watermark output images. If not defined, it will default to True if the package is installed, otherwise no
380
- watermarker will be used.
381
- """
382
-
383
- model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
384
-
385
- _optional_components = [
386
- "tokenizer",
387
- "tokenizer_2",
388
- "text_encoder",
389
- "text_encoder_2",
390
- "image_encoder",
391
- "feature_extractor",
392
- ]
393
- _callback_tensor_inputs = [
394
- "latents",
395
- "prompt_embeds",
396
- "negative_prompt_embeds",
397
- "add_text_embeds",
398
- "add_time_ids",
399
- "negative_pooled_prompt_embeds",
400
- "add_neg_time_ids",
401
- "mask",
402
- "masked_image_latents",
403
- ]
404
-
405
- def __init__(
406
- self,
407
- vae: AutoencoderKL,
408
- text_encoder: CLIPTextModel,
409
- tokenizer: CLIPTokenizer,
410
- unet: UNet2DConditionModel,
411
- scheduler: KarrasDiffusionSchedulers,
412
- tokenizer_2: CLIPTokenizer = None,
413
- text_encoder_2: CLIPTextModelWithProjection = None,
414
- image_encoder: CLIPVisionModelWithProjection = None,
415
- feature_extractor: CLIPImageProcessor = None,
416
- requires_aesthetics_score: bool = False,
417
- force_zeros_for_empty_prompt: bool = True,
418
- add_watermarker: Optional[bool] = None,
419
- ):
420
- super().__init__()
421
-
422
- self.register_modules(
423
- vae=vae,
424
- text_encoder=text_encoder,
425
- text_encoder_2=text_encoder_2,
426
- tokenizer=tokenizer,
427
- tokenizer_2=tokenizer_2,
428
- unet=unet,
429
- image_encoder=image_encoder,
430
- feature_extractor=feature_extractor,
431
- scheduler=scheduler,
432
- )
433
- self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
434
- self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
435
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
436
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
437
- self.mask_processor = VaeImageProcessor(
438
- vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
439
- )
440
-
441
- add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
442
-
443
- if add_watermarker:
444
- self.watermark = StableDiffusionXLWatermarker()
445
- else:
446
- self.watermark = None
447
-
448
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
449
- def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
450
- dtype = next(self.image_encoder.parameters()).dtype
451
-
452
- if not isinstance(image, torch.Tensor):
453
- image = self.feature_extractor(image, return_tensors="pt").pixel_values
454
-
455
- image = image.to(device=device, dtype=dtype)
456
- if output_hidden_states:
457
- image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
458
- image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
459
- uncond_image_enc_hidden_states = self.image_encoder(
460
- torch.zeros_like(image), output_hidden_states=True
461
- ).hidden_states[-2]
462
- uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
463
- num_images_per_prompt, dim=0
464
- )
465
- return image_enc_hidden_states, uncond_image_enc_hidden_states
466
- else:
467
- image_embeds = self.image_encoder(image).image_embeds
468
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
469
- uncond_image_embeds = torch.zeros_like(image_embeds)
470
-
471
- return image_embeds, uncond_image_embeds
472
-
473
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
474
- def prepare_ip_adapter_image_embeds(
475
- self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
476
- ):
477
- if ip_adapter_image_embeds is None:
478
- if not isinstance(ip_adapter_image, list):
479
- ip_adapter_image = [ip_adapter_image]
480
-
481
- if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
482
- raise ValueError(
483
- f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
484
- )
485
-
486
- image_embeds = []
487
- for single_ip_adapter_image, image_proj_layer in zip(
488
- ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
489
- ):
490
- output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
491
- single_image_embeds, single_negative_image_embeds = self.encode_image(
492
- single_ip_adapter_image, device, 1, output_hidden_state
493
- )
494
- single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
495
- single_negative_image_embeds = torch.stack(
496
- [single_negative_image_embeds] * num_images_per_prompt, dim=0
497
- )
498
-
499
- if do_classifier_free_guidance:
500
- single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
501
- single_image_embeds = single_image_embeds.to(device)
502
-
503
- image_embeds.append(single_image_embeds)
504
- else:
505
- repeat_dims = [1]
506
- image_embeds = []
507
- for single_image_embeds in ip_adapter_image_embeds:
508
- if do_classifier_free_guidance:
509
- single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
510
- single_image_embeds = single_image_embeds.repeat(
511
- num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
512
- )
513
- single_negative_image_embeds = single_negative_image_embeds.repeat(
514
- num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
515
- )
516
- single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
517
- else:
518
- single_image_embeds = single_image_embeds.repeat(
519
- num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
520
- )
521
- image_embeds.append(single_image_embeds)
522
-
523
- return image_embeds
524
-
525
- def encode_prompt(
526
- self,
527
- prompt,
528
- device: Optional[torch.device] = None,
529
- num_images_per_prompt: int = 1,
530
- do_classifier_free_guidance: bool = True,
531
- negative_prompt=None,
532
- prompt_embeds: Optional[torch.FloatTensor] = None,
533
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
534
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
535
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
536
- lora_scale: Optional[float] = None,
537
- ):
538
- r"""
539
- Encodes the prompt into text encoder hidden states.
540
-
541
- Args:
542
- prompt (`str` or `List[str]`, *optional*):
543
- prompt to be encoded
544
- device: (`torch.device`):
545
- torch device
546
- num_images_per_prompt (`int`):
547
- number of images that should be generated per prompt
548
- do_classifier_free_guidance (`bool`):
549
- whether to use classifier free guidance or not
550
- negative_prompt (`str` or `List[str]`, *optional*):
551
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
552
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
553
- less than `1`).
554
- prompt_embeds (`torch.FloatTensor`, *optional*):
555
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
556
- provided, text embeddings will be generated from `prompt` input argument.
557
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
558
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
559
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
560
- argument.
561
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
562
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
563
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
564
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
565
- Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
566
- weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
567
- input argument.
568
- lora_scale (`float`, *optional*):
569
- A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
570
- """
571
- # from IPython import embed; embed(); exit()
572
- device = device or self._execution_device
573
-
574
- # set lora scale so that monkey patched LoRA
575
- # function of text encoder can correctly access it
576
- if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
577
- self._lora_scale = lora_scale
578
-
579
- if prompt is not None and isinstance(prompt, str):
580
- batch_size = 1
581
- elif prompt is not None and isinstance(prompt, list):
582
- batch_size = len(prompt)
583
- else:
584
- batch_size = prompt_embeds.shape[0]
585
-
586
- # Define tokenizers and text encoders
587
- tokenizers = [self.tokenizer]
588
- text_encoders = [self.text_encoder]
589
-
590
- if prompt_embeds is None:
591
- # textual inversion: process multi-vector tokens if necessary
592
- prompt_embeds_list = []
593
- for tokenizer, text_encoder in zip(tokenizers, text_encoders):
594
- if isinstance(self, TextualInversionLoaderMixin):
595
- prompt = self.maybe_convert_prompt(prompt, tokenizer)
596
-
597
- text_inputs = tokenizer(
598
- prompt,
599
- padding="max_length",
600
- max_length=256,
601
- truncation=True,
602
- return_tensors="pt",
603
- ).to('cuda')
604
- output = text_encoder(
605
- input_ids=text_inputs['input_ids'],
606
- attention_mask=text_inputs['attention_mask'],
607
- position_ids=text_inputs['position_ids'],
608
- output_hidden_states=True)
609
- prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
610
- text_proj = output.hidden_states[-1][-1, :, :].clone()
611
- bs_embed, seq_len, _ = prompt_embeds.shape
612
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
613
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
614
- prompt_embeds_list.append(prompt_embeds)
615
-
616
- prompt_embeds = prompt_embeds_list[0]
617
-
618
- # get unconditional embeddings for classifier free guidance
619
- zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
620
- if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
621
- negative_prompt_embeds = torch.zeros_like(prompt_embeds)
622
- negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
623
- elif do_classifier_free_guidance and negative_prompt_embeds is None:
624
- # negative_prompt = negative_prompt or ""
625
- uncond_tokens: List[str]
626
- if negative_prompt is None:
627
- uncond_tokens = [""] * batch_size
628
- elif prompt is not None and type(prompt) is not type(negative_prompt):
629
- raise TypeError(
630
- f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
631
- f" {type(prompt)}."
632
- )
633
- elif isinstance(negative_prompt, str):
634
- uncond_tokens = [negative_prompt]
635
- elif batch_size != len(negative_prompt):
636
- raise ValueError(
637
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
638
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
639
- " the batch size of `prompt`."
640
- )
641
- else:
642
- uncond_tokens = negative_prompt
643
-
644
- negative_prompt_embeds_list = []
645
- for tokenizer, text_encoder in zip(tokenizers, text_encoders):
646
- # textual inversion: process multi-vector tokens if necessary
647
- if isinstance(self, TextualInversionLoaderMixin):
648
- uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
649
-
650
- max_length = prompt_embeds.shape[1]
651
- uncond_input = tokenizer(
652
- uncond_tokens,
653
- padding="max_length",
654
- max_length=max_length,
655
- truncation=True,
656
- return_tensors="pt",
657
- ).to('cuda')
658
- output = text_encoder(
659
- input_ids=uncond_input['input_ids'],
660
- attention_mask=uncond_input['attention_mask'],
661
- position_ids=uncond_input['position_ids'],
662
- output_hidden_states=True)
663
- negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
664
- negative_text_proj = output.hidden_states[-1][-1, :, :].clone()
665
-
666
- if do_classifier_free_guidance:
667
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
668
- seq_len = negative_prompt_embeds.shape[1]
669
-
670
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
671
-
672
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
673
- negative_prompt_embeds = negative_prompt_embeds.view(
674
- batch_size * num_images_per_prompt, seq_len, -1
675
- )
676
-
677
- # For classifier free guidance, we need to do two forward passes.
678
- # Here we concatenate the unconditional and text embeddings into a single batch
679
- # to avoid doing two forward passes
680
-
681
- negative_prompt_embeds_list.append(negative_prompt_embeds)
682
-
683
- negative_prompt_embeds = negative_prompt_embeds_list[0]
684
-
685
- bs_embed = text_proj.shape[0]
686
- text_proj = text_proj.repeat(1, num_images_per_prompt).view(
687
- bs_embed * num_images_per_prompt, -1
688
- )
689
- negative_text_proj = negative_text_proj.repeat(1, num_images_per_prompt).view(
690
- bs_embed * num_images_per_prompt, -1
691
- )
692
-
693
- return prompt_embeds, negative_prompt_embeds, text_proj, negative_text_proj
694
-
695
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
696
- def prepare_extra_step_kwargs(self, generator, eta):
697
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
698
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
699
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
700
- # and should be between [0, 1]
701
-
702
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
703
- extra_step_kwargs = {}
704
- if accepts_eta:
705
- extra_step_kwargs["eta"] = eta
706
-
707
- # check if the scheduler accepts generator
708
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
709
- if accepts_generator:
710
- extra_step_kwargs["generator"] = generator
711
- return extra_step_kwargs
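Editor's note: the signature inspection in `prepare_extra_step_kwargs` above is a generic pattern for passing optional kwargs only to schedulers whose `step` accepts them. The standalone sketch below illustrates the idea with hypothetical step functions; it is not part of the diff.

```py
import inspect

def step_a(model_output, timestep, sample, eta=0.0, generator=None):
    ...

def step_b(model_output, timestep, sample):
    ...

def optional_step_kwargs(step_fn, eta, generator):
    params = set(inspect.signature(step_fn).parameters.keys())
    kwargs = {}
    if "eta" in params:
        kwargs["eta"] = eta
    if "generator" in params:
        kwargs["generator"] = generator
    return kwargs

print(optional_step_kwargs(step_a, 0.0, None))  # {'eta': 0.0, 'generator': None}
print(optional_step_kwargs(step_b, 0.0, None))  # {}
```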
712
-
713
- def check_inputs(
714
- self,
715
- prompt,
716
- prompt_2,
717
- image,
718
- mask_image,
719
- height,
720
- width,
721
- strength,
722
- callback_steps,
723
- output_type,
724
- negative_prompt=None,
725
- negative_prompt_2=None,
726
- prompt_embeds=None,
727
- negative_prompt_embeds=None,
728
- ip_adapter_image=None,
729
- ip_adapter_image_embeds=None,
730
- callback_on_step_end_tensor_inputs=None,
731
- padding_mask_crop=None,
732
- ):
733
- if strength < 0 or strength > 1:
734
- raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
735
-
736
- if height % 8 != 0 or width % 8 != 0:
737
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
738
-
739
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
740
- raise ValueError(
741
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
742
- f" {type(callback_steps)}."
743
- )
744
-
745
- if callback_on_step_end_tensor_inputs is not None and not all(
746
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
747
- ):
748
- raise ValueError(
749
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
750
- )
751
-
752
- if prompt is not None and prompt_embeds is not None:
753
- raise ValueError(
754
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
755
- " only forward one of the two."
756
- )
757
- elif prompt_2 is not None and prompt_embeds is not None:
758
- raise ValueError(
759
- f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
760
- " only forward one of the two."
761
- )
762
- elif prompt is None and prompt_embeds is None:
763
- raise ValueError(
764
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
765
- )
766
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
767
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
768
- elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
769
- raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
770
-
771
- if negative_prompt is not None and negative_prompt_embeds is not None:
772
- raise ValueError(
773
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
774
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
775
- )
776
- elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
777
- raise ValueError(
778
- f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
779
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
780
- )
781
-
782
- if prompt_embeds is not None and negative_prompt_embeds is not None:
783
- if prompt_embeds.shape != negative_prompt_embeds.shape:
784
- raise ValueError(
785
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
786
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
787
- f" {negative_prompt_embeds.shape}."
788
- )
789
- if padding_mask_crop is not None:
790
- if not isinstance(image, PIL.Image.Image):
791
- raise ValueError(
792
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
793
- )
794
- if not isinstance(mask_image, PIL.Image.Image):
795
- raise ValueError(
796
- f"The mask image should be a PIL image when inpainting mask crop, but is of type"
797
- f" {type(mask_image)}."
798
- )
799
- if output_type != "pil":
800
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
801
-
802
- if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
803
- raise ValueError(
804
- "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
805
- )
806
-
807
- if ip_adapter_image_embeds is not None:
808
- if not isinstance(ip_adapter_image_embeds, list):
809
- raise ValueError(
810
- f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
811
- )
812
- elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
813
- raise ValueError(
814
- f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
815
- )
816
-
817
- def prepare_latents(
818
- self,
819
- batch_size,
820
- num_channels_latents,
821
- height,
822
- width,
823
- dtype,
824
- device,
825
- generator,
826
- latents=None,
827
- image=None,
828
- timestep=None,
829
- is_strength_max=True,
830
- add_noise=True,
831
- return_noise=False,
832
- return_image_latents=False,
833
- ):
834
- shape = (
835
- batch_size,
836
- num_channels_latents,
837
- int(height) // self.vae_scale_factor,
838
- int(width) // self.vae_scale_factor,
839
- )
840
- if isinstance(generator, list) and len(generator) != batch_size:
841
- raise ValueError(
842
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
843
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
844
- )
845
-
846
- if (image is None or timestep is None) and not is_strength_max:
847
- raise ValueError(
848
- "Since strength < 1, initial latents are to be initialised as a combination of Image + Noise. "
849
- "However, either the image or the noise timestep has not been provided."
850
- )
851
-
852
- if image.shape[1] == 4:
853
- image_latents = image.to(device=device, dtype=dtype)
854
- image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
855
- elif return_image_latents or (latents is None and not is_strength_max):
856
- image = image.to(device=device, dtype=dtype)
857
- image_latents = self._encode_vae_image(image=image, generator=generator)
858
- image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
859
-
860
- if latents is None and add_noise:
861
- noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
862
- # if strength is 1, then initialise the latents to noise, else initialise to image + noise
863
- latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
864
- # if pure noise then scale the initial latents by the Scheduler's init sigma
865
- latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
866
- elif add_noise:
867
- noise = latents.to(device)
868
- latents = noise * self.scheduler.init_noise_sigma
869
- else:
870
- noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
871
- latents = image_latents.to(device)
872
-
873
- outputs = (latents,)
874
-
875
- if return_noise:
876
- outputs += (noise,)
877
-
878
- if return_image_latents:
879
- outputs += (image_latents,)
880
-
881
- return outputs
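Editor's note: how `prepare_latents` above seeds the latents depends on `strength`: pure noise scaled by `init_noise_sigma` when `strength == 1`, otherwise image latents noised to the first kept timestep. The sketch below shows the two branches with dummy tensors; the scheduler choice is illustrative and the tensors stand in for VAE-encoded image latents.

```py
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(50)

shape = (1, 4, 64, 64)
noise = torch.randn(shape)
image_latents = torch.randn(shape)     # stand-in for VAE-encoded image latents

# strength == 1.0: start from pure noise scaled by the scheduler's initial sigma
latents_full_strength = noise * scheduler.init_noise_sigma

# strength < 1.0: start from the image latents noised to the first kept timestep
t0 = scheduler.timesteps[10]           # e.g. skip the first 10 of 50 steps
latents_partial_strength = scheduler.add_noise(image_latents, noise, t0)

print(latents_full_strength.shape, latents_partial_strength.shape)
```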
882
-
883
- def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
884
- dtype = image.dtype
885
- if self.vae.config.force_upcast:
886
- image = image.float()
887
- self.vae.to(dtype=torch.float32)
888
-
889
- if isinstance(generator, list):
890
- image_latents = [
891
- retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
892
- for i in range(image.shape[0])
893
- ]
894
- image_latents = torch.cat(image_latents, dim=0)
895
- else:
896
- image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
897
-
898
- if self.vae.config.force_upcast:
899
- self.vae.to(dtype)
900
-
901
- image_latents = image_latents.to(dtype)
902
- image_latents = self.vae.config.scaling_factor * image_latents
903
-
904
- return image_latents
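Editor's note: `_encode_vae_image` above multiplies the encoded latents by `vae.config.scaling_factor`, and the decode path earlier divides by the same factor before `vae.decode`. A schematic round-trip with a stand-in factor (0.13025 is the usual SDXL VAE default; treat it as illustrative here):

```py
import torch

scaling_factor = 0.13025                  # SDXL VAE default; illustrative stand-in
raw_latents = torch.randn(1, 4, 128, 128)

stored = raw_latents * scaling_factor     # what _encode_vae_image returns
restored = stored / scaling_factor        # what the decode path feeds to vae.decode
print(torch.allclose(raw_latents, restored))  # True
```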
905
-
906
- def prepare_mask_latents(
907
- self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
908
- ):
909
- # resize the mask to latents shape as we concatenate the mask to the latents
910
- # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
911
- # and half precision
912
- mask = torch.nn.functional.interpolate(
913
- mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
914
- )
915
- mask = mask.to(device=device, dtype=dtype)
916
-
917
- # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
918
- if mask.shape[0] < batch_size:
919
- if not batch_size % mask.shape[0] == 0:
920
- raise ValueError(
921
- "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
922
- f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
923
- " of masks that you pass is divisible by the total requested batch size."
924
- )
925
- mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
926
-
927
- mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
928
-
929
- if masked_image is not None and masked_image.shape[1] == 4:
930
- masked_image_latents = masked_image
931
- else:
932
- masked_image_latents = None
933
-
934
- if masked_image is not None:
935
- if masked_image_latents is None:
936
- masked_image = masked_image.to(device=device, dtype=dtype)
937
- masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
938
-
939
- if masked_image_latents.shape[0] < batch_size:
940
- if not batch_size % masked_image_latents.shape[0] == 0:
941
- raise ValueError(
942
- "The passed images and the required batch size don't match. Images are supposed to be duplicated"
943
- f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
944
- " Make sure the number of images that you pass is divisible by the total requested batch size."
945
- )
946
- masked_image_latents = masked_image_latents.repeat(
947
- batch_size // masked_image_latents.shape[0], 1, 1, 1
948
- )
949
-
950
- masked_image_latents = (
951
- torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
952
- )
953
-
954
- # aligning device to prevent device errors when concatenating it with the latent model input
955
- masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
956
-
957
- return mask, masked_image_latents
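Editor's note: `prepare_mask_latents` above resizes the binary mask to the latent grid (÷8 for an SDXL-style VAE) before it is concatenated with the latents. A tiny standalone check, not part of the diff:

```py
import torch

vae_scale_factor = 8
mask = torch.zeros(1, 1, 1024, 1024)
mask[..., 256:768, 256:768] = 1.0

latent_mask = torch.nn.functional.interpolate(
    mask, size=(1024 // vae_scale_factor, 1024 // vae_scale_factor)
)
print(latent_mask.shape)  # torch.Size([1, 1, 128, 128])
```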
958
-
959
- # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
960
- def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
961
- # get the original timestep using init_timestep
962
- if denoising_start is None:
963
- init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
964
- t_start = max(num_inference_steps - init_timestep, 0)
965
- else:
966
- t_start = 0
967
-
968
- timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
969
-
970
- # Strength is irrelevant if we directly request a timestep to start at;
971
- # that is, strength is determined by the denoising_start instead.
972
- if denoising_start is not None:
973
- discrete_timestep_cutoff = int(
974
- round(
975
- self.scheduler.config.num_train_timesteps
976
- - (denoising_start * self.scheduler.config.num_train_timesteps)
977
- )
978
- )
979
-
980
- num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
981
- if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
982
- # if the scheduler is a 2nd order scheduler we might have to do +1
983
- # because `num_inference_steps` might be even given that every timestep
984
- # (except the highest one) is duplicated. If `num_inference_steps` is even it would
985
- # mean that we cut the timesteps in the middle of the denoising step
986
- # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
987
- # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
988
- num_inference_steps = num_inference_steps + 1
989
-
990
- # because t_n+1 >= t_n, we slice the timesteps starting from the end
991
- timesteps = timesteps[-num_inference_steps:]
992
- return timesteps, num_inference_steps
993
-
994
- return timesteps, num_inference_steps - t_start
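Editor's note: the strength-to-timestep arithmetic in `get_timesteps` above can be illustrated without a scheduler: with `strength=0.8` and 50 steps, the first 10 timesteps are skipped and 40 denoising steps remain. A minimal sketch:

```py
num_inference_steps = 50
strength = 0.8

init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 40
t_start = max(num_inference_steps - init_timestep, 0)                          # 10
print(init_timestep, t_start, num_inference_steps - t_start)                   # 40 10 40
```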
995
-
996
- # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
997
- def _get_add_time_ids(
998
- self,
999
- original_size,
1000
- crops_coords_top_left,
1001
- target_size,
1002
- aesthetic_score,
1003
- negative_aesthetic_score,
1004
- negative_original_size,
1005
- negative_crops_coords_top_left,
1006
- negative_target_size,
1007
- dtype,
1008
- text_encoder_projection_dim=None,
1009
- ):
1010
- if self.config.requires_aesthetics_score:
1011
- add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
1012
- add_neg_time_ids = list(
1013
- negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
1014
- )
1015
- else:
1016
- add_time_ids = list(original_size + crops_coords_top_left + target_size)
1017
- add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
1018
-
1019
- passed_add_embed_dim = (
1020
- self.unet.config.addition_time_embed_dim * len(add_time_ids) + 4096
1021
- )
1022
- expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
1023
-
1024
- if (
1025
- expected_add_embed_dim > passed_add_embed_dim
1026
- and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
1027
- ):
1028
- raise ValueError(
1029
- f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
1030
- )
1031
- elif (
1032
- expected_add_embed_dim < passed_add_embed_dim
1033
- and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
1034
- ):
1035
- raise ValueError(
1036
- f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
1037
- )
1038
- elif expected_add_embed_dim != passed_add_embed_dim:
1039
- raise ValueError(
1040
- f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
1041
- )
1042
-
1043
- add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
1044
- add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
1045
-
1046
- return add_time_ids, add_neg_time_ids
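Editor's note: the micro-conditioning vector built by `_get_add_time_ids` above is just the concatenation of size/crop tuples (plus aesthetic scores for refiner-style checkpoints); in this Kolors variant the expected embedding width adds 4096 per the `+ 4096` term above, rather than a CLIP text projection dim. A minimal sketch of the non-aesthetic branch:

```py
import torch

original_size = (1024, 1024)
crops_coords_top_left = (0, 0)
target_size = (1024, 1024)

add_time_ids = list(original_size + crops_coords_top_left + target_size)
print(add_time_ids)                        # [1024, 1024, 0, 0, 1024, 1024]
print(torch.tensor([add_time_ids]).shape)  # torch.Size([1, 6])
```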
1047
-
1048
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
1049
- def upcast_vae(self):
1050
- dtype = self.vae.dtype
1051
- self.vae.to(dtype=torch.float32)
1052
- use_torch_2_0_or_xformers = isinstance(
1053
- self.vae.decoder.mid_block.attentions[0].processor,
1054
- (
1055
- AttnProcessor2_0,
1056
- XFormersAttnProcessor,
1057
- LoRAXFormersAttnProcessor,
1058
- LoRAAttnProcessor2_0,
1059
- ),
1060
- )
1061
- # if xformers or torch_2_0 is used attention block does not need
1062
- # to be in float32 which can save lots of memory
1063
- if use_torch_2_0_or_xformers:
1064
- self.vae.post_quant_conv.to(dtype)
1065
- self.vae.decoder.conv_in.to(dtype)
1066
- self.vae.decoder.mid_block.to(dtype)
1067
-
1068
- # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
1069
- def get_guidance_scale_embedding(
1070
- self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
1071
- ) -> torch.Tensor:
1072
- """
1073
- See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
1074
-
1075
- Args:
1076
- w (`torch.Tensor`):
1077
- Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
1078
- embedding_dim (`int`, *optional*, defaults to 512):
1079
- Dimension of the embeddings to generate.
1080
- dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
1081
- Data type of the generated embeddings.
1082
-
1083
- Returns:
1084
- `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
1085
- """
1086
- assert len(w.shape) == 1
1087
- w = w * 1000.0
1088
-
1089
- half_dim = embedding_dim // 2
1090
- emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
1091
- emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
1092
- emb = w.to(dtype)[:, None] * emb[None, :]
1093
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
1094
- if embedding_dim % 2 == 1: # zero pad
1095
- emb = torch.nn.functional.pad(emb, (0, 1))
1096
- assert emb.shape == (w.shape[0], embedding_dim)
1097
- return emb
1098
-
1099
- @property
1100
- def guidance_scale(self):
1101
- return self._guidance_scale
1102
-
1103
- @property
1104
- def guidance_rescale(self):
1105
- return self._guidance_rescale
1106
-
1107
- @property
1108
- def clip_skip(self):
1109
- return self._clip_skip
1110
-
1111
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1112
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1113
- # corresponds to doing no classifier free guidance.
1114
- @property
1115
- def do_classifier_free_guidance(self):
1116
- return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
1117
-
1118
- @property
1119
- def cross_attention_kwargs(self):
1120
- return self._cross_attention_kwargs
1121
-
1122
- @property
1123
- def denoising_end(self):
1124
- return self._denoising_end
1125
-
1126
- @property
1127
- def denoising_start(self):
1128
- return self._denoising_start
1129
-
1130
- @property
1131
- def num_timesteps(self):
1132
- return self._num_timesteps
1133
-
1134
- @property
1135
- def interrupt(self):
1136
- return self._interrupt
1137
-
1138
- @torch.no_grad()
1139
- @replace_example_docstring(EXAMPLE_DOC_STRING)
1140
- def __call__(
1141
- self,
1142
- prompt: Union[str, List[str]] = None,
1143
- prompt_2: Optional[Union[str, List[str]]] = None,
1144
- image: PipelineImageInput = None,
1145
- mask_image: PipelineImageInput = None,
1146
- masked_image_latents: torch.Tensor = None,
1147
- height: Optional[int] = None,
1148
- width: Optional[int] = None,
1149
- padding_mask_crop: Optional[int] = None,
1150
- strength: float = 0.9999,
1151
- num_inference_steps: int = 50,
1152
- timesteps: List[int] = None,
1153
- sigmas: List[float] = None,
1154
- denoising_start: Optional[float] = None,
1155
- denoising_end: Optional[float] = None,
1156
- guidance_scale: float = 7.5,
1157
- negative_prompt: Optional[Union[str, List[str]]] = None,
1158
- negative_prompt_2: Optional[Union[str, List[str]]] = None,
1159
- num_images_per_prompt: Optional[int] = 1,
1160
- eta: float = 0.0,
1161
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1162
- latents: Optional[torch.Tensor] = None,
1163
- prompt_embeds: Optional[torch.Tensor] = None,
1164
- negative_prompt_embeds: Optional[torch.Tensor] = None,
1165
- pooled_prompt_embeds: Optional[torch.Tensor] = None,
1166
- negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
1167
- ip_adapter_image: Optional[PipelineImageInput] = None,
1168
- ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
1169
- output_type: Optional[str] = "pil",
1170
- return_dict: bool = True,
1171
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1172
- guidance_rescale: float = 0.0,
1173
- original_size: Tuple[int, int] = None,
1174
- crops_coords_top_left: Tuple[int, int] = (0, 0),
1175
- target_size: Tuple[int, int] = None,
1176
- negative_original_size: Optional[Tuple[int, int]] = None,
1177
- negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
1178
- negative_target_size: Optional[Tuple[int, int]] = None,
1179
- aesthetic_score: float = 6.0,
1180
- negative_aesthetic_score: float = 2.5,
1181
- clip_skip: Optional[int] = None,
1182
- callback_on_step_end: Optional[
1183
- Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
1184
- ] = None,
1185
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1186
- **kwargs,
1187
- ):
1188
- r"""
1189
- Function invoked when calling the pipeline for generation.
1190
-
1191
- Args:
1192
- prompt (`str` or `List[str]`, *optional*):
1193
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
1194
- instead.
1195
- prompt_2 (`str` or `List[str]`, *optional*):
1196
- The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
1197
- used in both text-encoders
1198
- image (`PIL.Image.Image`):
1199
- `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
1200
- be masked out with `mask_image` and repainted according to `prompt`.
1201
- mask_image (`PIL.Image.Image`):
1202
- `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1203
- repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
1204
- to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
1205
- instead of 3, so the expected shape would be `(B, H, W, 1)`.
1206
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1207
- The height in pixels of the generated image. This is set to 1024 by default for the best results.
1208
- Anything below 512 pixels won't work well for
1209
- [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1210
- and checkpoints that are not specifically fine-tuned on low resolutions.
1211
- width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1212
- The width in pixels of the generated image. This is set to 1024 by default for the best results.
1213
- Anything below 512 pixels won't work well for
1214
- [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1215
- and checkpoints that are not specifically fine-tuned on low resolutions.
1216
- padding_mask_crop (`int`, *optional*, defaults to `None`):
1217
- The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
1218
- image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
1219
- with the same aspect ratio as the image that contains the entire masked area, and then expand that area based
1220
- on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
1221
- resizing to the original image size for inpainting. This is useful when the masked area is small while
1222
- the image is large and contains information irrelevant to inpainting, such as the background.
1223
- strength (`float`, *optional*, defaults to 0.9999):
1224
- Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
1225
- between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
1226
- `strength`. The number of denoising steps depends on the amount of noise initially added. When
1227
- `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
1228
- iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
1229
- portion of the reference `image`. Note that in the case of `denoising_start` being declared as an
1230
- integer, the value of `strength` will be ignored.
1231
- num_inference_steps (`int`, *optional*, defaults to 50):
1232
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1233
- expense of slower inference.
1234
- timesteps (`List[int]`, *optional*):
1235
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
1236
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
1237
- passed will be used. Must be in descending order.
1238
- sigmas (`List[float]`, *optional*):
1239
- Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
1240
- their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
1241
- will be used.
1242
- denoising_start (`float`, *optional*):
1243
- When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
1244
- bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
1245
- it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
1246
- strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
1247
- is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
1248
- Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
1249
- denoising_end (`float`, *optional*):
1250
- When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
1251
- completed before it is intentionally prematurely terminated. As a result, the returned sample will
1252
- still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
1253
- denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
1254
- final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
1255
- forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
1256
- Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
1257
- guidance_scale (`float`, *optional*, defaults to 7.5):
1258
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1259
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
1260
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1261
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1262
- usually at the expense of lower image quality.
1263
- negative_prompt (`str` or `List[str]`, *optional*):
1264
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
1265
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
1266
- less than `1`).
1267
- negative_prompt_2 (`str` or `List[str]`, *optional*):
1268
- The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
1269
- `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
1270
- prompt_embeds (`torch.Tensor`, *optional*):
1271
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1272
- provided, text embeddings will be generated from `prompt` input argument.
1273
- negative_prompt_embeds (`torch.Tensor`, *optional*):
1274
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1275
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1276
- argument.
1277
- pooled_prompt_embeds (`torch.Tensor`, *optional*):
1278
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1279
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
1280
- negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
1281
- Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1282
- weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
1283
- input argument.
1284
- ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
1285
- ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
1286
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
1287
- IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
1288
- contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
1289
- provided, embeddings are computed from the `ip_adapter_image` input argument.
1290
- num_images_per_prompt (`int`, *optional*, defaults to 1):
1291
- The number of images to generate per prompt.
1292
- eta (`float`, *optional*, defaults to 0.0):
1293
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1294
- [`schedulers.DDIMScheduler`], will be ignored for others.
1295
- generator (`torch.Generator`, *optional*):
1296
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1297
- to make generation deterministic.
1298
- latents (`torch.Tensor`, *optional*):
1299
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1300
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1301
- tensor will be generated by sampling using the supplied random `generator`.
1302
- output_type (`str`, *optional*, defaults to `"pil"`):
1303
- The output format of the generated image. Choose between
1304
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1305
- return_dict (`bool`, *optional*, defaults to `True`):
1306
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1307
- plain tuple.
1308
- cross_attention_kwargs (`dict`, *optional*):
1309
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1310
- `self.processor` in
1311
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1312
- original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1313
- If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
1314
- `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
1315
- explained in section 2.2 of
1316
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1317
- crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1318
- `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
1319
- `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
1320
- `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
1321
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1322
- target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1323
- For most cases, `target_size` should be set to the desired height and width of the generated image. If
1324
- not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
1325
- section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1326
- negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1327
- To negatively condition the generation process based on a specific image resolution. Part of SDXL's
1328
- micro-conditioning as explained in section 2.2 of
1329
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1330
- information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1331
- negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1332
- To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
1333
- micro-conditioning as explained in section 2.2 of
1334
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1335
- information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1336
- negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1337
- To negatively condition the generation process based on a target image resolution. It should be the same
1338
- as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
1339
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1340
- information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1341
- aesthetic_score (`float`, *optional*, defaults to 6.0):
1342
- Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
1343
- Part of SDXL's micro-conditioning as explained in section 2.2 of
1344
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1345
- negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
1346
- Part of SDXL's micro-conditioning as explained in section 2.2 of
1347
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
1348
- simulate an aesthetic score of the generated image by influencing the negative text condition.
1349
- clip_skip (`int`, *optional*):
1350
- Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1351
- the output of the pre-final layer will be used for computing the prompt embeddings.
1352
- callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
1353
- A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
1354
- each denoising step during inference, with the following arguments: `callback_on_step_end(self:
1355
- DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
1356
- list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
1357
- callback_on_step_end_tensor_inputs (`List`, *optional*):
1358
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1359
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1360
- `._callback_tensor_inputs` attribute of your pipeline class.
1361
-
1362
- Examples:
1363
-
1364
- Returns:
1365
- [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
1366
- [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
1367
- `tuple`. When returning a tuple, the first element is a list with the generated images.
1368
- """
1369
-
1370
- callback = kwargs.pop("callback", None)
1371
- callback_steps = kwargs.pop("callback_steps", None)
1372
-
1373
- if callback is not None:
1374
- deprecate(
1375
- "callback",
1376
- "1.0.0",
1377
- "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
1378
- )
1379
- if callback_steps is not None:
1380
- deprecate(
1381
- "callback_steps",
1382
- "1.0.0",
1383
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
1384
- )
1385
-
1386
- if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1387
- callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1388
-
1389
- # 0. Default height and width to unet
1390
- height = height or self.unet.config.sample_size * self.vae_scale_factor
1391
- width = width or self.unet.config.sample_size * self.vae_scale_factor
1392
-
1393
- # 1. Check inputs
1394
- self.check_inputs(
1395
- prompt,
1396
- prompt_2,
1397
- image,
1398
- mask_image,
1399
- height,
1400
- width,
1401
- strength,
1402
- callback_steps,
1403
- output_type,
1404
- negative_prompt,
1405
- negative_prompt_2,
1406
- prompt_embeds,
1407
- negative_prompt_embeds,
1408
- ip_adapter_image,
1409
- ip_adapter_image_embeds,
1410
- callback_on_step_end_tensor_inputs,
1411
- padding_mask_crop,
1412
- )
1413
-
1414
- self._guidance_scale = guidance_scale
1415
- self._guidance_rescale = guidance_rescale
1416
- self._clip_skip = clip_skip
1417
- self._cross_attention_kwargs = cross_attention_kwargs
1418
- self._denoising_end = denoising_end
1419
- self._denoising_start = denoising_start
1420
- self._interrupt = False
1421
-
1422
- # 2. Define call parameters
1423
- if prompt is not None and isinstance(prompt, str):
1424
- batch_size = 1
1425
- elif prompt is not None and isinstance(prompt, list):
1426
- batch_size = len(prompt)
1427
- else:
1428
- batch_size = prompt_embeds.shape[0]
1429
-
1430
- device = self._execution_device
1431
-
1432
- # 3. Encode input prompt
1433
- text_encoder_lora_scale = (
1434
- self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1435
- )
1436
-
1437
- (
1438
- prompt_embeds,
1439
- negative_prompt_embeds,
1440
- pooled_prompt_embeds,
1441
- negative_pooled_prompt_embeds,
1442
- ) = self.encode_prompt(
1443
- prompt=prompt,
1444
- device=device,
1445
- num_images_per_prompt=num_images_per_prompt,
1446
- do_classifier_free_guidance=self.do_classifier_free_guidance,
1447
- negative_prompt=negative_prompt,
1448
- prompt_embeds=prompt_embeds,
1449
- negative_prompt_embeds=negative_prompt_embeds,
1450
- pooled_prompt_embeds=pooled_prompt_embeds,
1451
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1452
- lora_scale=text_encoder_lora_scale,
1453
- )
1454
-
1455
- # 4. set timesteps
1456
- def denoising_value_valid(dnv):
1457
- return isinstance(dnv, float) and 0 < dnv < 1
1458
-
1459
- timesteps, num_inference_steps = retrieve_timesteps(
1460
- self.scheduler, num_inference_steps, device, timesteps, sigmas
1461
- )
1462
- timesteps, num_inference_steps = self.get_timesteps(
1463
- num_inference_steps,
1464
- strength,
1465
- device,
1466
- denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
1467
- )
1468
- # check that number of inference steps is not < 1 - as this doesn't make sense
1469
- if num_inference_steps < 1:
1470
- raise ValueError(
1471
- f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
1472
- f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
1473
- )
1474
- # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
1475
- latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
1476
- # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
1477
- is_strength_max = strength == 1.0
1478
-
1479
- # 5. Preprocess mask and image
1480
- if padding_mask_crop is not None:
1481
- crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
1482
- resize_mode = "fill"
1483
- else:
1484
- crops_coords = None
1485
- resize_mode = "default"
1486
-
1487
- original_image = image
1488
- init_image = self.image_processor.preprocess(
1489
- image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
1490
- )
1491
- init_image = init_image.to(dtype=torch.float32)
1492
-
1493
- mask = self.mask_processor.preprocess(
1494
- mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
1495
- )
1496
-
1497
- if masked_image_latents is not None:
1498
- masked_image = masked_image_latents
1499
- elif init_image.shape[1] == 4:
1500
- # if images are in latent space, we can't mask it
1501
- masked_image = None
1502
- else:
1503
- masked_image = init_image * (mask < 0.5)
1504
-
1505
- # 6. Prepare latent variables
1506
- num_channels_latents = self.vae.config.latent_channels
1507
- num_channels_unet = self.unet.config.in_channels
1508
- return_image_latents = num_channels_unet == 4
1509
-
1510
- add_noise = True if self.denoising_start is None else False
1511
- latents_outputs = self.prepare_latents(
1512
- batch_size * num_images_per_prompt,
1513
- num_channels_latents,
1514
- height,
1515
- width,
1516
- prompt_embeds.dtype,
1517
- device,
1518
- generator,
1519
- latents,
1520
- image=init_image,
1521
- timestep=latent_timestep,
1522
- is_strength_max=is_strength_max,
1523
- add_noise=add_noise,
1524
- return_noise=True,
1525
- return_image_latents=return_image_latents,
1526
- )
1527
-
1528
- if return_image_latents:
1529
- latents, noise, image_latents = latents_outputs
1530
- else:
1531
- latents, noise = latents_outputs
1532
-
1533
- # 7. Prepare mask latent variables
1534
- mask, masked_image_latents = self.prepare_mask_latents(
1535
- mask,
1536
- masked_image,
1537
- batch_size * num_images_per_prompt,
1538
- height,
1539
- width,
1540
- prompt_embeds.dtype,
1541
- device,
1542
- generator,
1543
- self.do_classifier_free_guidance,
1544
- )
1545
-
1546
- # 8. Check that sizes of mask, masked image and latents match
1547
- if num_channels_unet == 9:
1548
- # default case for runwayml/stable-diffusion-inpainting
1549
- num_channels_mask = mask.shape[1]
1550
- num_channels_masked_image = masked_image_latents.shape[1]
1551
- if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
1552
- raise ValueError(
1553
- f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
1554
- f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
1555
- f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
1556
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
1557
- " `pipeline.unet` or your `mask_image` or `image` input."
1558
- )
1559
- elif num_channels_unet != 4:
1560
- raise ValueError(
1561
- f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
1562
- )
1563
- # 8.1 Prepare extra step kwargs.
1564
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1565
-
1566
- # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1567
- height, width = latents.shape[-2:]
1568
- height = height * self.vae_scale_factor
1569
- width = width * self.vae_scale_factor
1570
-
1571
- original_size = original_size or (height, width)
1572
- target_size = target_size or (height, width)
1573
-
1574
- # 10. Prepare added time ids & embeddings
1575
- if negative_original_size is None:
1576
- negative_original_size = original_size
1577
- if negative_target_size is None:
1578
- negative_target_size = target_size
1579
-
1580
- add_text_embeds = pooled_prompt_embeds
1581
- if self.text_encoder_2 is None:
1582
- text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1583
- else:
1584
- text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1585
-
1586
- add_time_ids, add_neg_time_ids = self._get_add_time_ids(
1587
- original_size,
1588
- crops_coords_top_left,
1589
- target_size,
1590
- aesthetic_score,
1591
- negative_aesthetic_score,
1592
- negative_original_size,
1593
- negative_crops_coords_top_left,
1594
- negative_target_size,
1595
- dtype=prompt_embeds.dtype,
1596
- text_encoder_projection_dim=text_encoder_projection_dim,
1597
- )
1598
- add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
1599
-
1600
- if self.do_classifier_free_guidance:
1601
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1602
- add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1603
- add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
1604
- add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
1605
-
1606
- prompt_embeds = prompt_embeds.to(device)
1607
- add_text_embeds = add_text_embeds.to(device)
1608
- add_time_ids = add_time_ids.to(device)
1609
-
1610
- if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1611
- image_embeds = self.prepare_ip_adapter_image_embeds(
1612
- ip_adapter_image,
1613
- ip_adapter_image_embeds,
1614
- device,
1615
- batch_size * num_images_per_prompt,
1616
- self.do_classifier_free_guidance,
1617
- )
1618
-
1619
-
1620
- # 11. Denoising loop
1621
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1622
-
1623
- if (
1624
- self.denoising_end is not None
1625
- and self.denoising_start is not None
1626
- and denoising_value_valid(self.denoising_end)
1627
- and denoising_value_valid(self.denoising_start)
1628
- and self.denoising_start >= self.denoising_end
1629
- ):
1630
- raise ValueError(
1631
- f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
1632
- + f" {self.denoising_end} when using type float."
1633
- )
1634
- elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
1635
- discrete_timestep_cutoff = int(
1636
- round(
1637
- self.scheduler.config.num_train_timesteps
1638
- - (self.denoising_end * self.scheduler.config.num_train_timesteps)
1639
- )
1640
- )
1641
- num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
1642
- timesteps = timesteps[:num_inference_steps]
1643
-
1644
- # 11.1 Optionally get Guidance Scale Embedding
1645
- timestep_cond = None
1646
- if self.unet.config.time_cond_proj_dim is not None:
1647
- guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1648
- timestep_cond = self.get_guidance_scale_embedding(
1649
- guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1650
- ).to(device=device, dtype=latents.dtype)
1651
-
1652
- self._num_timesteps = len(timesteps)
1653
- with self.progress_bar(total=num_inference_steps) as progress_bar:
1654
- for i, t in enumerate(timesteps):
1655
- if self.interrupt:
1656
- continue
1657
- # expand the latents if we are doing classifier free guidance
1658
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1659
-
1660
- # concat latents, mask, masked_image_latents in the channel dimension
1661
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1662
-
1663
- if num_channels_unet == 9:
1664
- latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
1665
-
1666
- # predict the noise residual
1667
- added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1668
- if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1669
- added_cond_kwargs["image_embeds"] = image_embeds
1670
- noise_pred = self.unet(
1671
- latent_model_input,
1672
- t,
1673
- encoder_hidden_states=prompt_embeds,
1674
- timestep_cond=timestep_cond,
1675
- cross_attention_kwargs=self.cross_attention_kwargs,
1676
- added_cond_kwargs=added_cond_kwargs,
1677
- return_dict=False,
1678
- )[0]
1679
-
1680
- # perform guidance
1681
- if self.do_classifier_free_guidance:
1682
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1683
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1684
-
1685
- if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1686
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1687
- noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
1688
-
1689
- # compute the previous noisy sample x_t -> x_t-1
1690
- latents_dtype = latents.dtype
1691
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1692
- if latents.dtype != latents_dtype:
1693
- if torch.backends.mps.is_available():
1694
- # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1695
- latents = latents.to(latents_dtype)
1696
-
1697
- if num_channels_unet == 4:
1698
- init_latents_proper = image_latents
1699
- if self.do_classifier_free_guidance:
1700
- init_mask, _ = mask.chunk(2)
1701
- else:
1702
- init_mask = mask
1703
-
1704
- if i < len(timesteps) - 1:
1705
- noise_timestep = timesteps[i + 1]
1706
- init_latents_proper = self.scheduler.add_noise(
1707
- init_latents_proper, noise, torch.tensor([noise_timestep])
1708
- )
1709
-
1710
- latents = (1 - init_mask) * init_latents_proper + init_mask * latents
1711
-
1712
- if callback_on_step_end is not None:
1713
- callback_kwargs = {}
1714
- for k in callback_on_step_end_tensor_inputs:
1715
- callback_kwargs[k] = locals()[k]
1716
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1717
-
1718
- latents = callback_outputs.pop("latents", latents)
1719
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1720
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1721
- add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
1722
- negative_pooled_prompt_embeds = callback_outputs.pop(
1723
- "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1724
- )
1725
- add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
1726
- add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
1727
- mask = callback_outputs.pop("mask", mask)
1728
- masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
1729
-
1730
- # call the callback, if provided
1731
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1732
- progress_bar.update()
1733
- if callback is not None and i % callback_steps == 0:
1734
- step_idx = i // getattr(self.scheduler, "order", 1)
1735
- callback(step_idx, t, latents)
1736
-
1737
- if XLA_AVAILABLE:
1738
- xm.mark_step()
1739
-
1740
- if not output_type == "latent":
1741
- # make sure the VAE is in float32 mode, as it overflows in float16
1742
- needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1743
-
1744
- if needs_upcasting:
1745
- self.upcast_vae()
1746
- latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1747
- elif latents.dtype != self.vae.dtype:
1748
- if torch.backends.mps.is_available():
1749
- # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1750
- self.vae = self.vae.to(latents.dtype)
1751
-
1752
- # unscale/denormalize the latents
1753
- # denormalize with the mean and std if available and not None
1754
- has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
1755
- has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
1756
- if has_latents_mean and has_latents_std:
1757
- latents_mean = (
1758
- torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1759
- )
1760
- latents_std = (
1761
- torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1762
- )
1763
- latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
1764
- else:
1765
- latents = latents / self.vae.config.scaling_factor
1766
-
1767
- image = self.vae.decode(latents, return_dict=False)[0]
1768
-
1769
- # cast back to fp16 if needed
1770
- if needs_upcasting:
1771
- self.vae.to(dtype=torch.float16)
1772
- else:
1773
- return StableDiffusionXLPipelineOutput(images=latents)
1774
-
1775
- # apply watermark if available
1776
- if self.watermark is not None:
1777
- image = self.watermark.apply_watermark(image)
1778
-
1779
- image = self.image_processor.postprocess(image, output_type=output_type)
1780
-
1781
- if padding_mask_crop is not None:
1782
- image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
1783
-
1784
- # Offload all models
1785
- self.maybe_free_model_hooks()
1786
-
1787
- if not return_dict:
1788
- return (image,)
1789
-
1790
- return StableDiffusionXLPipelineOutput(images=image)
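For reference, a minimal usage sketch of the `__call__` deleted above. This is hypothetical and not part of the diff: `pipe` is assumed to be an already-constructed instance of this inpainting pipeline on a CUDA device, and the file paths and prompt are placeholders.

```py
import torch
from PIL import Image

# Placeholder inputs; only the keyword arguments mirror the deleted __call__ signature.
init_image = Image.open("example_input.png").convert("RGB").resize((1024, 1024))
mask_image = Image.open("example_mask.png").convert("L").resize((1024, 1024))  # white = repaint

result = pipe(
    prompt="a red sofa in a bright living room",
    image=init_image,
    mask_image=mask_image,
    strength=0.9999,           # default: near-full re-noising of the masked region
    guidance_scale=7.5,        # classifier-free guidance weight
    num_inference_steps=50,
    generator=torch.Generator(device="cuda").manual_seed(0),
)
result.images[0].save("example_output.png")  # StableDiffusionXLPipelineOutput.images
```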
kolors/pipelines/pipeline_stable_diffusion_xl_chatglm_256_ipadapter.py DELETED
@@ -1,948 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import sys
15
- import os
16
- sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
17
- from kolors.models.modeling_chatglm import ChatGLMModel
18
- from kolors.models.tokenization_chatglm import ChatGLMTokenizer
19
- import inspect
20
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
21
- import torch
22
- from transformers import (
23
- CLIPImageProcessor,
24
- CLIPTextModel,
25
- CLIPTextModelWithProjection,
26
- CLIPTokenizer,
27
- CLIPVisionModelWithProjection,
28
- )
29
- from transformers import XLMRobertaModel, ChineseCLIPTextModel
30
-
31
- from diffusers.image_processor import VaeImageProcessor,PipelineImageInput
32
- from diffusers.loaders import (
33
- FromSingleFileMixin,
34
- IPAdapterMixin,
35
- LoraLoaderMixin,
36
- TextualInversionLoaderMixin
37
- )
38
- from diffusers.models import AutoencoderKL, UNet2DConditionModel,ImageProjection
39
- from diffusers.models.attention_processor import (
40
- AttnProcessor2_0,
41
- LoRAAttnProcessor2_0,
42
- LoRAXFormersAttnProcessor,
43
- XFormersAttnProcessor,
44
- )
45
- from diffusers.schedulers import KarrasDiffusionSchedulers
46
- from diffusers.utils import (
47
- is_accelerate_available,
48
- is_accelerate_version,
49
- logging,
50
- replace_example_docstring,
51
- )
52
- try:
53
- from diffusers.utils import randn_tensor
54
- except:
55
- from diffusers.utils.torch_utils import randn_tensor
56
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
57
- from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
58
-
59
-
60
-
61
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
62
-
63
- EXAMPLE_DOC_STRING = """
64
- Examples:
65
- ```py
66
- >>> import torch
67
- >>> from diffusers import StableDiffusionXLPipeline
68
-
69
- >>> pipe = StableDiffusionXLPipeline.from_pretrained(
70
- ... "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16
71
- ... )
72
- >>> pipe = pipe.to("cuda")
73
-
74
- >>> prompt = "a photo of an astronaut riding a horse on mars"
75
- >>> image = pipe(prompt).images[0]
76
- ```
77
- """
78
-
79
-
80
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
81
- def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
82
- """
83
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
84
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
85
- """
86
- std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
87
- std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
88
- # rescale the results from guidance (fixes overexposure)
89
- noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
90
- # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
91
- noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
92
- return noise_cfg
93
-
94
-
95
- class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin, IPAdapterMixin,):
96
- r"""
97
- Pipeline for text-to-image generation using Stable Diffusion XL.
98
-
99
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
100
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
101
-
102
- In addition the pipeline inherits the following loading methods:
103
- - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
104
- - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
105
- - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
106
-
107
- as well as the following saving methods:
108
- - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
109
-
110
- Args:
111
- vae ([`AutoencoderKL`]):
112
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
113
- text_encoder ([`CLIPTextModel`]):
114
- Frozen text-encoder. Stable Diffusion XL uses the text portion of
115
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
116
- the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
117
-
118
- tokenizer (`CLIPTokenizer`):
119
- Tokenizer of class
120
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
121
-
122
- unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
123
- scheduler ([`SchedulerMixin`]):
124
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
125
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
126
- """
127
-
128
- def __init__(
129
- self,
130
- vae: AutoencoderKL,
131
- text_encoder: ChatGLMModel,
132
- tokenizer: ChatGLMTokenizer,
133
- unet: UNet2DConditionModel,
134
- scheduler: KarrasDiffusionSchedulers,
135
- image_encoder: CLIPVisionModelWithProjection = None,
136
- feature_extractor: CLIPImageProcessor = None,
137
- force_zeros_for_empty_prompt: bool = True,
138
- ):
139
- super().__init__()
140
-
141
- self.register_modules(
142
- vae=vae,
143
- text_encoder=text_encoder,
144
- tokenizer=tokenizer,
145
- unet=unet,
146
- scheduler=scheduler,
147
- image_encoder=image_encoder,
148
- feature_extractor=feature_extractor,
149
- )
150
- self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
151
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
152
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
153
- self.default_sample_size = self.unet.config.sample_size
154
-
155
- # self.watermark = StableDiffusionXLWatermarker()
156
-
157
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
158
- def enable_vae_slicing(self):
159
- r"""
160
- Enable sliced VAE decoding.
161
-
162
- When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
163
- steps. This is useful to save some memory and allow larger batch sizes.
164
- """
165
- self.vae.enable_slicing()
166
-
167
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
168
- def disable_vae_slicing(self):
169
- r"""
170
- Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
171
- computing decoding in one step.
172
- """
173
- self.vae.disable_slicing()
174
-
175
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
176
- def enable_vae_tiling(self):
177
- r"""
178
- Enable tiled VAE decoding.
179
-
180
- When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
181
- several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
182
- """
183
- self.vae.enable_tiling()
184
-
185
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
186
- def disable_vae_tiling(self):
187
- r"""
188
- Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
189
- computing decoding in one step.
190
- """
191
- self.vae.disable_tiling()
192
-
193
- def enable_sequential_cpu_offload(self, gpu_id=0):
194
- r"""
195
- Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
196
- text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
197
- `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
198
- Note that offloading happens on a submodule basis. Memory savings are higher than with
199
- `enable_model_cpu_offload`, but performance is lower.
200
- """
201
- if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
202
- from accelerate import cpu_offload
203
- else:
204
- raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
205
-
206
- device = torch.device(f"cuda:{gpu_id}")
207
-
208
- if self.device.type != "cpu":
209
- self.to("cpu", silence_dtype_warnings=True)
210
- torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
211
-
212
- for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
213
- cpu_offload(cpu_offloaded_model, device)
214
-
215
- def enable_model_cpu_offload(self, gpu_id=0):
216
- r"""
217
- Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
218
- to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
219
- method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
220
- `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
221
- """
222
- if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
223
- from accelerate import cpu_offload_with_hook
224
- else:
225
- raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
226
-
227
- device = torch.device(f"cuda:{gpu_id}")
228
-
229
- if self.device.type != "cpu":
230
- self.to("cpu", silence_dtype_warnings=True)
231
- torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
232
-
233
- model_sequence = (
234
- [self.text_encoder, self.image_encoder]
235
- )
236
- model_sequence.extend([self.unet, self.vae])
237
-
238
- hook = None
239
- for cpu_offloaded_model in model_sequence:
240
- _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
241
-
242
- # We'll offload the last model manually.
243
- self.final_offload_hook = hook
244
-
245
- @property
246
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
247
- def _execution_device(self):
248
- r"""
249
- Returns the device on which the pipeline's models will be executed. After calling
250
- `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
251
- hooks.
252
- """
253
- if not hasattr(self.unet, "_hf_hook"):
254
- return self.device
255
- for module in self.unet.modules():
256
- if (
257
- hasattr(module, "_hf_hook")
258
- and hasattr(module._hf_hook, "execution_device")
259
- and module._hf_hook.execution_device is not None
260
- ):
261
- return torch.device(module._hf_hook.execution_device)
262
- return self.device
263
-
264
- def encode_prompt(
265
- self,
266
- prompt,
267
- device: Optional[torch.device] = None,
268
- num_images_per_prompt: int = 1,
269
- do_classifier_free_guidance: bool = True,
270
- negative_prompt=None,
271
- prompt_embeds: Optional[torch.FloatTensor] = None,
272
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
273
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
274
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
275
- lora_scale: Optional[float] = None,
276
- ):
277
- r"""
278
- Encodes the prompt into text encoder hidden states.
279
-
280
- Args:
281
- prompt (`str` or `List[str]`, *optional*):
282
- prompt to be encoded
283
- device: (`torch.device`):
284
- torch device
285
- num_images_per_prompt (`int`):
286
- number of images that should be generated per prompt
287
- do_classifier_free_guidance (`bool`):
288
- whether to use classifier free guidance or not
289
- negative_prompt (`str` or `List[str]`, *optional*):
290
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
291
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
292
- less than `1`).
293
- prompt_embeds (`torch.FloatTensor`, *optional*):
294
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
295
- provided, text embeddings will be generated from `prompt` input argument.
296
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
297
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
298
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
299
- argument.
300
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
301
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
302
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
303
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
304
- Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
305
- weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
306
- input argument.
307
- lora_scale (`float`, *optional*):
308
- A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
309
- """
310
- # from IPython import embed; embed(); exit()
311
- device = device or self._execution_device
312
-
313
- # set lora scale so that monkey patched LoRA
314
- # function of text encoder can correctly access it
315
- if lora_scale is not None and isinstance(self, LoraLoaderMixin):
316
- self._lora_scale = lora_scale
317
-
318
- if prompt is not None and isinstance(prompt, str):
319
- batch_size = 1
320
- elif prompt is not None and isinstance(prompt, list):
321
- batch_size = len(prompt)
322
- else:
323
- batch_size = prompt_embeds.shape[0]
324
-
325
- # Define tokenizers and text encoders
326
- tokenizers = [self.tokenizer]
327
- text_encoders = [self.text_encoder]
328
-
329
- if prompt_embeds is None:
330
- # textual inversion: process multi-vector tokens if necessary
331
- prompt_embeds_list = []
332
- for tokenizer, text_encoder in zip(tokenizers, text_encoders):
333
- if isinstance(self, TextualInversionLoaderMixin):
334
- prompt = self.maybe_convert_prompt(prompt, tokenizer)
335
-
336
- text_inputs = tokenizer(
337
- prompt,
338
- padding="max_length",
339
- max_length=256,
340
- truncation=True,
341
- return_tensors="pt",
342
- ).to('cuda')
343
- output = text_encoder(
344
- input_ids=text_inputs['input_ids'] ,
345
- attention_mask=text_inputs['attention_mask'],
346
- position_ids=text_inputs['position_ids'],
347
- output_hidden_states=True)
348
- prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
349
- pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() # [batch_size, 4096]
350
- bs_embed, seq_len, _ = prompt_embeds.shape
351
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
352
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
353
-
354
- prompt_embeds_list.append(prompt_embeds)
355
-
356
- # prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
357
- prompt_embeds = prompt_embeds_list[0]
358
-
359
- # get unconditional embeddings for classifier free guidance
360
- zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
361
- if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
362
- negative_prompt_embeds = torch.zeros_like(prompt_embeds)
363
- negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
364
- elif do_classifier_free_guidance and negative_prompt_embeds is None:
365
- # negative_prompt = negative_prompt or ""
366
- uncond_tokens: List[str]
367
- if negative_prompt is None:
368
- uncond_tokens = [""] * batch_size
369
- elif prompt is not None and type(prompt) is not type(negative_prompt):
370
- raise TypeError(
371
- f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
372
- f" {type(prompt)}."
373
- )
374
- elif isinstance(negative_prompt, str):
375
- uncond_tokens = [negative_prompt]
376
- elif batch_size != len(negative_prompt):
377
- raise ValueError(
378
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
379
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
380
- " the batch size of `prompt`."
381
- )
382
- else:
383
- uncond_tokens = negative_prompt
384
-
385
- negative_prompt_embeds_list = []
386
- for tokenizer, text_encoder in zip(tokenizers, text_encoders):
387
- # textual inversion: process multi-vector tokens if necessary
388
- if isinstance(self, TextualInversionLoaderMixin):
389
- uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
390
-
391
- max_length = prompt_embeds.shape[1]
392
- uncond_input = tokenizer(
393
- uncond_tokens,
394
- padding="max_length",
395
- max_length=max_length,
396
- truncation=True,
397
- return_tensors="pt",
398
- ).to('cuda')
399
- output = text_encoder(
400
- input_ids=uncond_input['input_ids'] ,
401
- attention_mask=uncond_input['attention_mask'],
402
- position_ids=uncond_input['position_ids'],
403
- output_hidden_states=True)
404
- negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
405
- negative_pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() # [batch_size, 4096]
406
-
407
- if do_classifier_free_guidance:
408
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
409
- seq_len = negative_prompt_embeds.shape[1]
410
-
411
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
412
-
413
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
414
- negative_prompt_embeds = negative_prompt_embeds.view(
415
- batch_size * num_images_per_prompt, seq_len, -1
416
- )
417
-
418
- # For classifier free guidance, we need to do two forward passes.
419
- # Here we concatenate the unconditional and text embeddings into a single batch
420
- # to avoid doing two forward passes
421
-
422
- negative_prompt_embeds_list.append(negative_prompt_embeds)
423
-
424
- # negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
425
- negative_prompt_embeds = negative_prompt_embeds_list[0]
426
-
427
- bs_embed = pooled_prompt_embeds.shape[0]
428
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
429
- bs_embed * num_images_per_prompt, -1
430
- )
431
- if do_classifier_free_guidance:
432
- negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
433
- bs_embed * num_images_per_prompt, -1
434
- )
435
-
436
- return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
437
-
438
-
439
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
440
- def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
441
- dtype = next(self.image_encoder.parameters()).dtype
442
-
443
- if not isinstance(image, torch.Tensor):
444
- image = self.feature_extractor(image, return_tensors="pt").pixel_values
445
-
446
- image = image.to(device=device, dtype=dtype)
447
- if output_hidden_states:
448
- image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
449
- image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
450
- uncond_image_enc_hidden_states = self.image_encoder(
451
- torch.zeros_like(image), output_hidden_states=True
452
- ).hidden_states[-2]
453
- uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
454
- num_images_per_prompt, dim=0
455
- )
456
- return image_enc_hidden_states, uncond_image_enc_hidden_states
457
- else:
458
- image_embeds = self.image_encoder(image).image_embeds
459
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
460
- uncond_image_embeds = torch.zeros_like(image_embeds)
461
-
462
- return image_embeds, uncond_image_embeds
463
-
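A sketch (not from the repo) of the conditional/unconditional pairing that `encode_image` produces; a random tensor stands in for the CLIP image-encoder output.

import torch

num_images_per_prompt = 2
image_embeds = torch.randn(1, 1280)            # stand-in for image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)   # zeros act as the "no image" branch
assert image_embeds.shape == uncond_image_embeds.shape == (2, 1280)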
464
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
465
- def prepare_ip_adapter_image_embeds(
466
- self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
467
- ):
468
- image_embeds = []
469
- if do_classifier_free_guidance:
470
- negative_image_embeds = []
471
- if ip_adapter_image_embeds is None:
472
- if not isinstance(ip_adapter_image, list):
473
- ip_adapter_image = [ip_adapter_image]
474
-
475
- if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
476
- raise ValueError(
477
- f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
478
- )
479
-
480
- for single_ip_adapter_image, image_proj_layer in zip(
481
- ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
482
- ):
483
- output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
484
- single_image_embeds, single_negative_image_embeds = self.encode_image(
485
- single_ip_adapter_image, device, 1, output_hidden_state
486
- )
487
-
488
- image_embeds.append(single_image_embeds[None, :])
489
- if do_classifier_free_guidance:
490
- negative_image_embeds.append(single_negative_image_embeds[None, :])
491
- else:
492
- for single_image_embeds in ip_adapter_image_embeds:
493
- if do_classifier_free_guidance:
494
- single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
495
- negative_image_embeds.append(single_negative_image_embeds)
496
- image_embeds.append(single_image_embeds)
497
-
498
- ip_adapter_image_embeds = []
499
- for i, single_image_embeds in enumerate(image_embeds):
500
- single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
501
- if do_classifier_free_guidance:
502
- single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
503
- single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
504
-
505
- single_image_embeds = single_image_embeds.to(device=device)
506
- ip_adapter_image_embeds.append(single_image_embeds)
507
-
508
- return ip_adapter_image_embeds
509
-
510
-
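Illustration (assumed shapes, not the repo's code) of how the method above stacks negative and positive IP-Adapter embeddings along the batch axis when classifier-free guidance is active.

import torch

num_images_per_prompt = 2
single_image_embeds = torch.randn(1, 1, 1280)
single_negative_image_embeds = torch.zeros_like(single_image_embeds)

single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)

# negative first, positive second, matching the doubled latent batch built later in __call__
stacked = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
assert stacked.shape[0] == 2 * num_images_per_prompt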
511
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
512
- def prepare_extra_step_kwargs(self, generator, eta):
513
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
514
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
515
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
516
- # and should be between [0, 1]
517
-
518
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
519
- extra_step_kwargs = {}
520
- if accepts_eta:
521
- extra_step_kwargs["eta"] = eta
522
-
523
- # check if the scheduler accepts generator
524
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
525
- if accepts_generator:
526
- extra_step_kwargs["generator"] = generator
527
- return extra_step_kwargs
528
-
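The signature probing above generalizes to any callable; here is a self-contained illustration with two toy `step` functions (the names are made up for the example).

import inspect

def step_with_eta(model_output, timestep, sample, eta=0.0, generator=None):
    return sample

def step_plain(model_output, timestep, sample):
    return sample

def extra_kwargs_for(step_fn, generator=None, eta=0.0):
    params = set(inspect.signature(step_fn).parameters)
    kwargs = {}
    if "eta" in params:            # only DDIM-style steppers accept eta
        kwargs["eta"] = eta
    if "generator" in params:
        kwargs["generator"] = generator
    return kwargs

print(extra_kwargs_for(step_with_eta, eta=0.5))   # {'eta': 0.5, 'generator': None}
print(extra_kwargs_for(step_plain))               # {}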
529
- def check_inputs(
530
- self,
531
- prompt,
532
- height,
533
- width,
534
- callback_steps,
535
- negative_prompt=None,
536
- prompt_embeds=None,
537
- negative_prompt_embeds=None,
538
- pooled_prompt_embeds=None,
539
- negative_pooled_prompt_embeds=None,
540
- ):
541
- if height % 8 != 0 or width % 8 != 0:
542
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
543
-
544
- if (callback_steps is None) or (
545
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
546
- ):
547
- raise ValueError(
548
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
549
- f" {type(callback_steps)}."
550
- )
551
-
552
- if prompt is not None and prompt_embeds is not None:
553
- raise ValueError(
554
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
555
- " only forward one of the two."
556
- )
557
- elif prompt is None and prompt_embeds is None:
558
- raise ValueError(
559
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
560
- )
561
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
562
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
563
-
564
- if negative_prompt is not None and negative_prompt_embeds is not None:
565
- raise ValueError(
566
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
567
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
568
- )
569
-
570
- if prompt_embeds is not None and negative_prompt_embeds is not None:
571
- if prompt_embeds.shape != negative_prompt_embeds.shape:
572
- raise ValueError(
573
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
574
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
575
- f" {negative_prompt_embeds.shape}."
576
- )
577
-
578
- if prompt_embeds is not None and pooled_prompt_embeds is None:
579
- raise ValueError(
580
- "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
581
- )
582
-
583
- if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
584
- raise ValueError(
585
- "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
586
- )
587
-
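The divisible-by-8 requirement comes from the VAE's spatial downsampling; a quick worked example (8 is assumed here, the real factor is derived from the VAE config):

vae_scale_factor = 8                      # assumed 8x spatial downsampling
height, width = 1024, 768
if height % vae_scale_factor or width % vae_scale_factor:
    raise ValueError(f"{height}x{width} is not divisible by {vae_scale_factor}")
print(height // vae_scale_factor, width // vae_scale_factor)   # 128 96 latent resolution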
588
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
589
- def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
590
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
591
- if isinstance(generator, list) and len(generator) != batch_size:
592
- raise ValueError(
593
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
594
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
595
- )
596
-
597
- if latents is None:
598
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
599
- else:
600
- latents = latents.to(device)
601
-
602
- # scale the initial noise by the standard deviation required by the scheduler
603
- latents = latents * self.scheduler.init_noise_sigma
604
- return latents
605
-
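A sketch of the same latent initialization using plain `torch.randn` in place of diffusers' `randn_tensor` helper; the `init_noise_sigma` value is illustrative (it is 1.0 for many schedulers and larger for Euler-style ones).

import torch

batch, channels, vae_scale_factor = 1, 4, 8
height, width = 1024, 1024
generator = torch.Generator().manual_seed(0)

shape = (batch, channels, height // vae_scale_factor, width // vae_scale_factor)
latents = torch.randn(shape, generator=generator)
latents = latents * 14.6          # scheduler.init_noise_sigma (illustrative value)
print(latents.shape)              # torch.Size([1, 4, 128, 128])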
606
- def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
607
- add_time_ids = list(original_size + crops_coords_top_left + target_size)
608
-
609
- passed_add_embed_dim = (
610
- self.unet.config.addition_time_embed_dim * len(add_time_ids) + 4096
611
- )
612
- expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
613
-
614
- if expected_add_embed_dim != passed_add_embed_dim:
615
- raise ValueError(
616
- f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
617
- )
618
-
619
- add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
620
- return add_time_ids
621
-
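Worked example of the dimension check above. The `4096` term is the pooled ChatGLM text-embedding width; `addition_time_embed_dim = 256` is an assumed UNet config value.

import torch

original_size, crops_coords_top_left, target_size = (1024, 1024), (0, 0), (1024, 1024)
add_time_ids = list(original_size + crops_coords_top_left + target_size)   # 6 integers

addition_time_embed_dim = 256
passed_add_embed_dim = addition_time_embed_dim * len(add_time_ids) + 4096  # 256*6 + 4096 = 5632
print(torch.tensor([add_time_ids]).shape, passed_add_embed_dim)            # torch.Size([1, 6]) 5632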
622
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
623
- def upcast_vae(self):
624
- dtype = self.vae.dtype
625
- self.vae.to(dtype=torch.float32)
626
- use_torch_2_0_or_xformers = isinstance(
627
- self.vae.decoder.mid_block.attentions[0].processor,
628
- (
629
- AttnProcessor2_0,
630
- XFormersAttnProcessor,
631
- LoRAXFormersAttnProcessor,
632
- LoRAAttnProcessor2_0,
633
- ),
634
- )
635
- # if xformers or torch_2_0 is used, the attention block does not need
636
- # to be in float32, which can save lots of memory
637
- if use_torch_2_0_or_xformers:
638
- self.vae.post_quant_conv.to(dtype)
639
- self.vae.decoder.conv_in.to(dtype)
640
- self.vae.decoder.mid_block.to(dtype)
641
-
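The idea behind `upcast_vae`, shown on a toy module (the module layout and the attention flag are assumptions for this sketch): upcast everything to float32 to avoid fp16 overflow, then cast selected blocks back down when a memory-efficient attention backend makes that safe.

import torch
import torch.nn as nn

vae_like = nn.Sequential(nn.Conv2d(4, 8, 3), nn.Conv2d(8, 4, 3)).half()
vae_like.to(dtype=torch.float32)                 # global upcast
use_memory_efficient_attention = True            # stand-in for the processor check above
if use_memory_efficient_attention:
    vae_like[0].to(torch.float16)                # cast selected submodules back to fp16
print(sorted({str(p.dtype) for p in vae_like.parameters()}))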
642
- @torch.no_grad()
643
- @replace_example_docstring(EXAMPLE_DOC_STRING)
644
- def __call__(
645
- self,
646
- prompt: Union[str, List[str]] = None,
647
- height: Optional[int] = None,
648
- width: Optional[int] = None,
649
- num_inference_steps: int = 50,
650
- denoising_end: Optional[float] = None,
651
- guidance_scale: float = 5.0,
652
- negative_prompt: Optional[Union[str, List[str]]] = None,
653
- num_images_per_prompt: Optional[int] = 1,
654
- eta: float = 0.0,
655
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
656
- latents: Optional[torch.FloatTensor] = None,
657
- prompt_embeds: Optional[torch.FloatTensor] = None,
658
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
659
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
660
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
661
-
662
- ip_adapter_image: Optional[PipelineImageInput] = None,
663
- ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
664
-
665
- output_type: Optional[str] = "pil",
666
- return_dict: bool = True,
667
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
668
- callback_steps: int = 1,
669
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
670
- guidance_rescale: float = 0.0,
671
- original_size: Optional[Tuple[int, int]] = None,
672
- crops_coords_top_left: Tuple[int, int] = (0, 0),
673
- target_size: Optional[Tuple[int, int]] = None,
674
- use_dynamic_threshold: Optional[bool] = False,
675
- ):
676
- r"""
677
- Function invoked when calling the pipeline for generation.
678
-
679
- Args:
680
- prompt (`str` or `List[str]`, *optional*):
681
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
682
- instead.
683
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
684
- The height in pixels of the generated image.
685
- width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
686
- The width in pixels of the generated image.
687
- num_inference_steps (`int`, *optional*, defaults to 50):
688
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
689
- expense of slower inference.
690
- denoising_end (`float`, *optional*):
691
- When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
692
- completed before it is intentionally prematurely terminated. For instance, if denoising_end is set to
693
- 0.7 and `num_inference_steps` is fixed at 50, the process will execute only 35 (i.e., 0.7 * 50) denoising steps. This parameter should ideally be used when this pipeline forms part of a "Mixture of
694
- Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
695
- Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
696
- guidance_scale (`float`, *optional*, defaults to 5.0):
697
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
698
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
699
- 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
700
- negative_prompt (`str` or `List[str]`, *optional*):
701
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
702
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
703
- less than `1`).
704
- num_images_per_prompt (`int`, *optional*, defaults to 1):
705
- The number of images to generate per prompt.
706
- eta (`float`, *optional*, defaults to 0.0):
707
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
708
- [`schedulers.DDIMScheduler`], will be ignored for others.
709
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
710
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
711
- to make generation deterministic.
712
- latents (`torch.FloatTensor`, *optional*):
713
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
714
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
715
- tensor will be generated by sampling using the supplied random `generator`.
716
- prompt_embeds (`torch.FloatTensor`, *optional*):
717
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
718
- provided, text embeddings will be generated from `prompt` input argument.
719
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
720
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
721
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
722
- argument.
723
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
724
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
725
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
726
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
- Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative prompt embeddings will be generated from the `negative_prompt` input argument.
727
- output_type (`str`, *optional*, defaults to `"pil"`):
728
- The output format of the generated image. Choose between
729
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
730
- return_dict (`bool`, *optional*, defaults to `True`):
731
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a plain tuple.
732
- callback (`Callable`, *optional*):
733
- A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
734
- callback_steps (`int`, *optional*, defaults to 1):
735
- The frequency at which the `callback` function will be called. If not specified, the callback will be
736
- called at every step.
737
- cross_attention_kwargs (`dict`, *optional*):
738
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
739
- `self.processor` in
740
- [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
741
- guidance_rescale (`float`, *optional*, defaults to 0.0):
742
- Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
743
- Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of
744
- [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
745
- Guidance rescale factor should fix overexposure when using zero terminal SNR.
746
- original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
747
- TODO
748
- crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
749
- TODO
750
- target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
751
- TODO
752
-
753
- Examples:
754
-
755
- Returns:
756
- [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
757
- [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
758
- `tuple`. When returning a tuple, the first element is a list with the generated images. Unlike
759
- the original Stable Diffusion pipeline, no "not-safe-for-work" (nsfw) flags are returned,
760
- as this pipeline does not run a `safety_checker` on its outputs.
761
- """
762
- # 0. Default height and width to unet
763
- height = height or self.default_sample_size * self.vae_scale_factor
764
- width = width or self.default_sample_size * self.vae_scale_factor
765
-
766
- original_size = original_size or (height, width)
767
- target_size = target_size or (height, width)
768
-
769
- # 1. Check inputs. Raise error if not correct
770
- self.check_inputs(
771
- prompt,
772
- height,
773
- width,
774
- callback_steps,
775
- negative_prompt,
776
- prompt_embeds,
777
- negative_prompt_embeds,
778
- pooled_prompt_embeds,
779
- negative_pooled_prompt_embeds,
780
- )
781
-
782
- # 2. Define call parameters
783
- if prompt is not None and isinstance(prompt, str):
784
- batch_size = 1
785
- elif prompt is not None and isinstance(prompt, list):
786
- batch_size = len(prompt)
787
- else:
788
- batch_size = prompt_embeds.shape[0]
789
-
790
- device = self._execution_device
791
-
792
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
793
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
794
- # corresponds to doing no classifier free guidance.
795
- do_classifier_free_guidance = guidance_scale > 1.0
796
-
797
- # 3. Encode input prompt
798
- text_encoder_lora_scale = (
799
- cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
800
- )
801
- (
802
- prompt_embeds,
803
- negative_prompt_embeds,
804
- pooled_prompt_embeds,
805
- negative_pooled_prompt_embeds,
806
- ) = self.encode_prompt(
807
- prompt,
808
- device,
809
- num_images_per_prompt,
810
- do_classifier_free_guidance,
811
- negative_prompt,
812
- prompt_embeds=prompt_embeds,
813
- negative_prompt_embeds=negative_prompt_embeds,
814
- pooled_prompt_embeds=pooled_prompt_embeds,
815
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
816
- lora_scale=text_encoder_lora_scale,
817
- )
818
-
819
- # 4. Prepare timesteps
820
- self.scheduler.set_timesteps(num_inference_steps, device=device)
821
-
822
- timesteps = self.scheduler.timesteps
823
-
824
- # 5. Prepare latent variables
825
- num_channels_latents = self.unet.config.in_channels
826
- latents = self.prepare_latents(
827
- batch_size * num_images_per_prompt,
828
- num_channels_latents,
829
- height,
830
- width,
831
- prompt_embeds.dtype,
832
- device,
833
- generator,
834
- latents,
835
- )
836
-
837
- # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
838
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
839
-
840
- # 7. Prepare added time ids & embeddings
841
- add_text_embeds = pooled_prompt_embeds
842
- add_time_ids = self._get_add_time_ids(
843
- original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
844
- )
845
-
846
- if do_classifier_free_guidance:
847
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
848
- add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
849
- add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
850
-
851
- prompt_embeds = prompt_embeds.to(device)
852
- add_text_embeds = add_text_embeds.to(device)
853
- add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
854
-
855
- if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
856
- image_embeds = self.prepare_ip_adapter_image_embeds(
857
- ip_adapter_image,
858
- ip_adapter_image_embeds,
859
- device,
860
- batch_size * num_images_per_prompt,
861
- do_classifier_free_guidance,
862
- )
863
-
864
- # 8. Denoising loop
865
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
866
-
867
- # 8.1 Apply denoising_end
868
- if denoising_end is not None:
869
- num_inference_steps = int(round(denoising_end * num_inference_steps))
870
- timesteps = timesteps[: num_warmup_steps + self.scheduler.order * num_inference_steps]
871
-
872
- with self.progress_bar(total=num_inference_steps) as progress_bar:
873
- for i, t in enumerate(timesteps):
874
- # expand the latents if we are doing classifier free guidance
875
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
876
-
877
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
878
-
879
- # predict the noise residual
880
- added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
881
-
882
- if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
883
- added_cond_kwargs["image_embeds"] = image_embeds
884
-
885
- # import pdb; pdb.set_trace()
886
-
887
- noise_pred = self.unet(
888
- latent_model_input,
889
- t,
890
- encoder_hidden_states=prompt_embeds,
891
- cross_attention_kwargs=cross_attention_kwargs,
892
- added_cond_kwargs=added_cond_kwargs,
893
- return_dict=False,
894
- )[0]
895
-
896
- # perform guidance
897
- if do_classifier_free_guidance:
898
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
899
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
900
- if use_dynamic_threshold:
901
- DynamicThresh = DynThresh(maxSteps=num_inference_steps, experiment_mode=0)
902
- noise_pred = DynamicThresh.dynthresh(noise_pred_text,
903
- noise_pred_uncond,
904
- guidance_scale,
905
- None)
906
-
907
- if do_classifier_free_guidance and guidance_rescale > 0.0:
908
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
909
- noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
910
-
911
- # compute the previous noisy sample x_t -> x_t-1
912
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
913
-
914
- # call the callback, if provided
915
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
916
- progress_bar.update()
917
- if callback is not None and i % callback_steps == 0:
918
- callback(i, t, latents)
919
-
920
- # make sure the VAE is in float32 mode, as it overflows in float16
921
- # torch.cuda.empty_cache()
922
- if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
923
- self.upcast_vae()
924
- latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
925
-
926
-
927
- if not output_type == "latent":
928
- latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
929
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
930
- else:
931
- image = latents
932
- return StableDiffusionXLPipelineOutput(images=image)
933
-
934
- # image = self.watermark.apply_watermark(image)
935
- image = self.image_processor.postprocess(image, output_type=output_type)
936
-
937
- # Offload last model to CPU
938
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
939
- self.final_offload_hook.offload()
940
-
941
- if not return_dict:
942
- return (image,)
943
-
944
- return StableDiffusionXLPipelineOutput(images=image)
945
-
946
-
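Not part of the diff: a compact re-implementation of the guidance math in the denoising loop above, classifier-free guidance followed by the optional std-matching rescale from https://arxiv.org/pdf/2305.08891.pdf (the pipeline itself calls `rescale_noise_cfg` for the latter).

import torch

def guided_noise(noise_uncond, noise_text, guidance_scale=5.0, guidance_rescale=0.0):
    noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
    if guidance_rescale > 0.0:
        # rescale so the guided prediction's std matches the text prediction's std
        dims = list(range(1, noise_text.ndim))
        std_text = noise_text.std(dim=dims, keepdim=True)
        std_cfg = noise_pred.std(dim=dims, keepdim=True)
        rescaled = noise_pred * (std_text / std_cfg)
        noise_pred = guidance_rescale * rescaled + (1 - guidance_rescale) * noise_pred
    return noise_pred

uncond = torch.randn(1, 4, 128, 128)
text = torch.randn(1, 4, 128, 128)
print(guided_noise(uncond, text, guidance_scale=5.0, guidance_rescale=0.7).shape)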
947
- if __name__ == "__main__":
948
- pass
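Finally, the `denoising_end` handling in the loop above amounts to dropping the tail of the timestep list; a self-contained illustration with made-up timesteps:

num_inference_steps, denoising_end = 50, 0.7
scheduler_order, num_warmup_steps = 1, 0
timesteps = list(range(981, 0, -20))                                      # 50 stand-in timesteps

num_inference_steps = int(round(denoising_end * num_inference_steps))     # 35
timesteps = timesteps[: num_warmup_steps + scheduler_order * num_inference_steps]
print(len(timesteps))                                                     # 35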