hyoungwoncho commited on
Commit
017c6d6
1 Parent(s): 1090980

initial commit of StableDiffusionPAGPipeline

Files changed (2) hide show
  1. README.md +35 -0
  2. pipeline_stable_diffusion_pag.py +1510 -0
README.md ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Perturbed-Attention Guidance
2
+
3
+ This repository is based on [Diffusers](https://huggingface.co/docs/diffusers/index). StableDiffusionPAGPipeline is a modification of StableDiffusionPipeline to support Perturbed-Attention Guidance (PAG).
4
+
5
+ ## Quickstart
6
+
7
+ Load StableDiffusionPAGPipeline as below:
8
+
9
+ ```
10
+ pipe = StableDiffusionPipeline.from_pretrained(
11
+ "runwayml/stable-diffusion-v1-5",
12
+ custom_pipeline="hyoungwoncho/sd_perturbed_attention_guidance",
13
+ torch_dtype=torch.float16,
14
+ safety_checker=None
15
+ )
16
+ ```
17
+
18
+ Sampling:
19
+
20
+ ```
21
+ output_baseline = pipe(
22
+ prompts,
23
+ width=512,
24
+ height=512,
25
+ num_inference_steps=50,
26
+ guidance_scale=0.0,
27
+ pag_scale=4.5,
28
+ pag_applied_layers_index=pag_applied_layers_index
29
+ ).images
30
+ ```
31
+
32
+ Parameters:
33
+
34
+ pag_scale : gudiance scale of PAG (ex: 4.5)
35
+ pag_applied_layers_index = index of the layer to apply perturbation (ex: ['m0'])
pipeline_stable_diffusion_pag.py ADDED
@@ -0,0 +1,1510 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Implementation of StableDiffusionPAGPipeline
2
+
3
+ import inspect
4
+ from typing import Any, Callable, Dict, List, Optional, Union
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from packaging import version
9
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
10
+
11
+ from diffusers.src.diffusers.configuration_utils import FrozenDict
12
+ from diffusers.src.diffusers.image_processor import PipelineImageInput, VaeImageProcessor
13
+ from diffusers.src.diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
14
+ from diffusers.src.diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
15
+ from diffusers.src.diffusers.models.attention_processor import FusedAttnProcessor2_0
16
+ from diffusers.src.diffusers.models.lora import adjust_lora_scale_text_encoder
17
+ from diffusers.src.diffusers.schedulers import KarrasDiffusionSchedulers
18
+ from diffusers.src.diffusers.utils import (
19
+ USE_PEFT_BACKEND,
20
+ deprecate,
21
+ logging,
22
+ replace_example_docstring,
23
+ scale_lora_layers,
24
+ unscale_lora_layers,
25
+ )
26
+ from diffusers.src.diffusers.utils.torch_utils import randn_tensor
27
+ from diffusers.src.diffusers.pipelines.pipeline_utils import DiffusionPipeline
28
+ from diffusers.src.diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
29
+ from diffusers.src.diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
30
+
31
+ from diffusers.src.diffusers.models.attention_processor import Attention, AttnProcessor2_0
32
+
33
+
34
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
+
36
+ EXAMPLE_DOC_STRING = """
37
+ Examples:
38
+ ```py
39
+ >>> import torch
40
+ >>> from diffusers import StableDiffusionPipeline
41
+
42
+ >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
43
+ >>> pipe = pipe.to("cuda")
44
+
45
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
46
+ >>> image = pipe(prompt).images[0]
47
+ ```
48
+ """
49
+
50
+
51
+ class PAGIdentitySelfAttnProcessor:
52
+ r"""
53
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
54
+ """
55
+
56
+ def __init__(self):
57
+ if not hasattr(F, "scaled_dot_product_attention"):
58
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
59
+
60
+ def __call__(
61
+ self,
62
+ attn: Attention,
63
+ hidden_states: torch.FloatTensor,
64
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
65
+ attention_mask: Optional[torch.FloatTensor] = None,
66
+ temb: Optional[torch.FloatTensor] = None,
67
+ scale: float = 1.0,
68
+ ) -> torch.FloatTensor:
69
+
70
+ residual = hidden_states
71
+ if attn.spatial_norm is not None:
72
+ hidden_states = attn.spatial_norm(hidden_states, temb)
73
+
74
+ input_ndim = hidden_states.ndim
75
+ if input_ndim == 4:
76
+ batch_size, channel, height, width = hidden_states.shape
77
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
78
+
79
+ # chunk
80
+ hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
81
+
82
+ # original path
83
+ batch_size, sequence_length, _ = hidden_states_org.shape
84
+
85
+ if attention_mask is not None:
86
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
87
+ # scaled_dot_product_attention expects attention_mask shape to be
88
+ # (batch, heads, source_length, target_length)
89
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
90
+
91
+ if attn.group_norm is not None:
92
+ hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
93
+
94
+ args = () if USE_PEFT_BACKEND else (scale,)
95
+
96
+ query = attn.to_q(hidden_states_org, *args)
97
+ key = attn.to_k(hidden_states_org, *args)
98
+ value = attn.to_v(hidden_states_org, *args)
99
+
100
+ inner_dim = key.shape[-1]
101
+ head_dim = inner_dim // attn.heads
102
+
103
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
104
+
105
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
106
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
107
+
108
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
109
+ # TODO: add support for attn.scale when we move to Torch 2.1
110
+ hidden_states_org = F.scaled_dot_product_attention(
111
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
112
+ )
113
+
114
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
115
+ hidden_states_org = hidden_states_org.to(query.dtype)
116
+
117
+ # linear proj
118
+ hidden_states_org = attn.to_out[0](hidden_states_org, *args)
119
+ # dropout
120
+ hidden_states_org = attn.to_out[1](hidden_states_org)
121
+
122
+ if input_ndim == 4:
123
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
124
+
125
+ # perturbed path (identity attention)
126
+ batch_size, sequence_length, _ = hidden_states_ptb.shape
127
+
128
+ if attention_mask is not None:
129
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
130
+ # scaled_dot_product_attention expects attention_mask shape to be
131
+ # (batch, heads, source_length, target_length)
132
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
133
+
134
+ if attn.group_norm is not None:
135
+ hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
136
+
137
+ args = () if USE_PEFT_BACKEND else (scale,)
138
+
139
+ value = attn.to_v(hidden_states_ptb, *args)
140
+
141
+ hidden_states_ptb = torch.zeros(value.shape).to(value.get_device())
142
+ #hidden_states_ptb = value
143
+
144
+ hidden_states_ptb = hidden_states_ptb.to(query.dtype)
145
+
146
+ # linear proj
147
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb, *args)
148
+ # dropout
149
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
150
+
151
+ if input_ndim == 4:
152
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
153
+
154
+ # cat
155
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
156
+
157
+ if attn.residual_connection:
158
+ hidden_states = hidden_states + residual
159
+
160
+ hidden_states = hidden_states / attn.rescale_output_factor
161
+
162
+ return hidden_states
163
+
164
+
165
+ class PAGCFGIdentitySelfAttnProcessor:
166
+ r"""
167
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
168
+ """
169
+
170
+ def __init__(self):
171
+ if not hasattr(F, "scaled_dot_product_attention"):
172
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
173
+
174
+ def __call__(
175
+ self,
176
+ attn: Attention,
177
+ hidden_states: torch.FloatTensor,
178
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
179
+ attention_mask: Optional[torch.FloatTensor] = None,
180
+ temb: Optional[torch.FloatTensor] = None,
181
+ scale: float = 1.0,
182
+ ) -> torch.FloatTensor:
183
+
184
+ residual = hidden_states
185
+ if attn.spatial_norm is not None:
186
+ hidden_states = attn.spatial_norm(hidden_states, temb)
187
+
188
+ input_ndim = hidden_states.ndim
189
+ if input_ndim == 4:
190
+ batch_size, channel, height, width = hidden_states.shape
191
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
192
+
193
+ # chunk
194
+ hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
195
+ hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
196
+
197
+ # original path
198
+ batch_size, sequence_length, _ = hidden_states_org.shape
199
+
200
+ if attention_mask is not None:
201
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
202
+ # scaled_dot_product_attention expects attention_mask shape to be
203
+ # (batch, heads, source_length, target_length)
204
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
205
+
206
+ if attn.group_norm is not None:
207
+ hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
208
+
209
+ args = () if USE_PEFT_BACKEND else (scale,)
210
+
211
+ query = attn.to_q(hidden_states_org, *args)
212
+ key = attn.to_k(hidden_states_org, *args)
213
+ value = attn.to_v(hidden_states_org, *args)
214
+
215
+ inner_dim = key.shape[-1]
216
+ head_dim = inner_dim // attn.heads
217
+
218
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
219
+
220
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
221
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
222
+
223
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
224
+ # TODO: add support for attn.scale when we move to Torch 2.1
225
+ hidden_states_org = F.scaled_dot_product_attention(
226
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
227
+ )
228
+
229
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
230
+ hidden_states_org = hidden_states_org.to(query.dtype)
231
+
232
+ # linear proj
233
+ hidden_states_org = attn.to_out[0](hidden_states_org, *args)
234
+ # dropout
235
+ hidden_states_org = attn.to_out[1](hidden_states_org)
236
+
237
+ if input_ndim == 4:
238
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
239
+
240
+ # perturbed path (identity attention)
241
+ batch_size, sequence_length, _ = hidden_states_ptb.shape
242
+
243
+ if attention_mask is not None:
244
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
245
+ # scaled_dot_product_attention expects attention_mask shape to be
246
+ # (batch, heads, source_length, target_length)
247
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
248
+
249
+ if attn.group_norm is not None:
250
+ hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
251
+
252
+ args = () if USE_PEFT_BACKEND else (scale,)
253
+
254
+ value = attn.to_v(hidden_states_ptb, *args)
255
+ hidden_states_ptb = value
256
+ hidden_states_ptb = hidden_states_ptb.to(query.dtype)
257
+
258
+ # linear proj
259
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb, *args)
260
+ # dropout
261
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
262
+
263
+ if input_ndim == 4:
264
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
265
+
266
+ # cat
267
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
268
+
269
+ if attn.residual_connection:
270
+ hidden_states = hidden_states + residual
271
+
272
+ hidden_states = hidden_states / attn.rescale_output_factor
273
+
274
+ return hidden_states
275
+
276
+
277
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
278
+ """
279
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
280
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
281
+ """
282
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
283
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
284
+ # rescale the results from guidance (fixes overexposure)
285
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
286
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
287
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
288
+ return noise_cfg
289
+
290
+
291
+ def retrieve_timesteps(
292
+ scheduler,
293
+ num_inference_steps: Optional[int] = None,
294
+ device: Optional[Union[str, torch.device]] = None,
295
+ timesteps: Optional[List[int]] = None,
296
+ **kwargs,
297
+ ):
298
+ """
299
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
300
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
301
+
302
+ Args:
303
+ scheduler (`SchedulerMixin`):
304
+ The scheduler to get timesteps from.
305
+ num_inference_steps (`int`):
306
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
307
+ `timesteps` must be `None`.
308
+ device (`str` or `torch.device`, *optional*):
309
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
310
+ timesteps (`List[int]`, *optional*):
311
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
312
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
313
+ must be `None`.
314
+
315
+ Returns:
316
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
317
+ second element is the number of inference steps.
318
+ """
319
+ if timesteps is not None:
320
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
321
+ if not accepts_timesteps:
322
+ raise ValueError(
323
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
324
+ f" timestep schedules. Please check whether you are using the correct scheduler."
325
+ )
326
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
327
+ timesteps = scheduler.timesteps
328
+ num_inference_steps = len(timesteps)
329
+ else:
330
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
331
+ timesteps = scheduler.timesteps
332
+ return timesteps, num_inference_steps
333
+
334
+
335
+ class StableDiffusionPAGPipeline(
336
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
337
+ ):
338
+ r"""
339
+ Pipeline for text-to-image generation using Stable Diffusion.
340
+
341
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
342
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
343
+
344
+ The pipeline also inherits the following loading methods:
345
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
346
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
347
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
348
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
349
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
350
+
351
+ Args:
352
+ vae ([`AutoencoderKL`]):
353
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
354
+ text_encoder ([`~transformers.CLIPTextModel`]):
355
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
356
+ tokenizer ([`~transformers.CLIPTokenizer`]):
357
+ A `CLIPTokenizer` to tokenize text.
358
+ unet ([`UNet2DConditionModel`]):
359
+ A `UNet2DConditionModel` to denoise the encoded image latents.
360
+ scheduler ([`SchedulerMixin`]):
361
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
362
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
363
+ safety_checker ([`StableDiffusionSafetyChecker`]):
364
+ Classification module that estimates whether generated images could be considered offensive or harmful.
365
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
366
+ about a model's potential harms.
367
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
368
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
369
+ """
370
+
371
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
372
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
373
+ _exclude_from_cpu_offload = ["safety_checker"]
374
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
375
+
376
+ def __init__(
377
+ self,
378
+ vae: AutoencoderKL,
379
+ text_encoder: CLIPTextModel,
380
+ tokenizer: CLIPTokenizer,
381
+ unet: UNet2DConditionModel,
382
+ scheduler: KarrasDiffusionSchedulers,
383
+ safety_checker: StableDiffusionSafetyChecker,
384
+ feature_extractor: CLIPImageProcessor,
385
+ image_encoder: CLIPVisionModelWithProjection = None,
386
+ requires_safety_checker: bool = True,
387
+ ):
388
+ super().__init__()
389
+
390
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
391
+ deprecation_message = (
392
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
393
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
394
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
395
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
396
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
397
+ " file"
398
+ )
399
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
400
+ new_config = dict(scheduler.config)
401
+ new_config["steps_offset"] = 1
402
+ scheduler._internal_dict = FrozenDict(new_config)
403
+
404
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
405
+ deprecation_message = (
406
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
407
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
408
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
409
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
410
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
411
+ )
412
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
413
+ new_config = dict(scheduler.config)
414
+ new_config["clip_sample"] = False
415
+ scheduler._internal_dict = FrozenDict(new_config)
416
+
417
+ if safety_checker is None and requires_safety_checker:
418
+ logger.warning(
419
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
420
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
421
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
422
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
423
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
424
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
425
+ )
426
+
427
+ if safety_checker is not None and feature_extractor is None:
428
+ raise ValueError(
429
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
430
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
431
+ )
432
+
433
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
434
+ version.parse(unet.config._diffusers_version).base_version
435
+ ) < version.parse("0.9.0.dev0")
436
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
437
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
438
+ deprecation_message = (
439
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
440
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
441
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
442
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
443
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
444
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
445
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
446
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
447
+ " the `unet/config.json` file"
448
+ )
449
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
450
+ new_config = dict(unet.config)
451
+ new_config["sample_size"] = 64
452
+ unet._internal_dict = FrozenDict(new_config)
453
+
454
+ self.register_modules(
455
+ vae=vae,
456
+ text_encoder=text_encoder,
457
+ tokenizer=tokenizer,
458
+ unet=unet,
459
+ scheduler=scheduler,
460
+ safety_checker=safety_checker,
461
+ feature_extractor=feature_extractor,
462
+ image_encoder=image_encoder,
463
+ )
464
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
465
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
466
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
467
+
468
+ def enable_vae_slicing(self):
469
+ r"""
470
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
471
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
472
+ """
473
+ self.vae.enable_slicing()
474
+
475
+ def disable_vae_slicing(self):
476
+ r"""
477
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
478
+ computing decoding in one step.
479
+ """
480
+ self.vae.disable_slicing()
481
+
482
+ def enable_vae_tiling(self):
483
+ r"""
484
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
485
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
486
+ processing larger images.
487
+ """
488
+ self.vae.enable_tiling()
489
+
490
+ def disable_vae_tiling(self):
491
+ r"""
492
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
493
+ computing decoding in one step.
494
+ """
495
+ self.vae.disable_tiling()
496
+
497
+ def _encode_prompt(
498
+ self,
499
+ prompt,
500
+ device,
501
+ num_images_per_prompt,
502
+ do_classifier_free_guidance,
503
+ negative_prompt=None,
504
+ prompt_embeds: Optional[torch.FloatTensor] = None,
505
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
506
+ lora_scale: Optional[float] = None,
507
+ **kwargs,
508
+ ):
509
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
510
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
511
+
512
+ prompt_embeds_tuple = self.encode_prompt(
513
+ prompt=prompt,
514
+ device=device,
515
+ num_images_per_prompt=num_images_per_prompt,
516
+ do_classifier_free_guidance=do_classifier_free_guidance,
517
+ negative_prompt=negative_prompt,
518
+ prompt_embeds=prompt_embeds,
519
+ negative_prompt_embeds=negative_prompt_embeds,
520
+ lora_scale=lora_scale,
521
+ **kwargs,
522
+ )
523
+
524
+ # concatenate for backwards comp
525
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
526
+
527
+ return prompt_embeds
528
+
529
+ def encode_prompt(
530
+ self,
531
+ prompt,
532
+ device,
533
+ num_images_per_prompt,
534
+ do_classifier_free_guidance,
535
+ negative_prompt=None,
536
+ prompt_embeds: Optional[torch.FloatTensor] = None,
537
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
538
+ lora_scale: Optional[float] = None,
539
+ clip_skip: Optional[int] = None,
540
+ ):
541
+ r"""
542
+ Encodes the prompt into text encoder hidden states.
543
+
544
+ Args:
545
+ prompt (`str` or `List[str]`, *optional*):
546
+ prompt to be encoded
547
+ device: (`torch.device`):
548
+ torch device
549
+ num_images_per_prompt (`int`):
550
+ number of images that should be generated per prompt
551
+ do_classifier_free_guidance (`bool`):
552
+ whether to use classifier free guidance or not
553
+ negative_prompt (`str` or `List[str]`, *optional*):
554
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
555
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
556
+ less than `1`).
557
+ prompt_embeds (`torch.FloatTensor`, *optional*):
558
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
559
+ provided, text embeddings will be generated from `prompt` input argument.
560
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
561
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
562
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
563
+ argument.
564
+ lora_scale (`float`, *optional*):
565
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
566
+ clip_skip (`int`, *optional*):
567
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
568
+ the output of the pre-final layer will be used for computing the prompt embeddings.
569
+ """
570
+ # set lora scale so that monkey patched LoRA
571
+ # function of text encoder can correctly access it
572
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
573
+ self._lora_scale = lora_scale
574
+
575
+ # dynamically adjust the LoRA scale
576
+ if not USE_PEFT_BACKEND:
577
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
578
+ else:
579
+ scale_lora_layers(self.text_encoder, lora_scale)
580
+
581
+ if prompt is not None and isinstance(prompt, str):
582
+ batch_size = 1
583
+ elif prompt is not None and isinstance(prompt, list):
584
+ batch_size = len(prompt)
585
+ else:
586
+ batch_size = prompt_embeds.shape[0]
587
+
588
+ if prompt_embeds is None:
589
+ # textual inversion: process multi-vector tokens if necessary
590
+ if isinstance(self, TextualInversionLoaderMixin):
591
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
592
+
593
+ text_inputs = self.tokenizer(
594
+ prompt,
595
+ padding="max_length",
596
+ max_length=self.tokenizer.model_max_length,
597
+ truncation=True,
598
+ return_tensors="pt",
599
+ )
600
+ text_input_ids = text_inputs.input_ids
601
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
602
+
603
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
604
+ text_input_ids, untruncated_ids
605
+ ):
606
+ removed_text = self.tokenizer.batch_decode(
607
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
608
+ )
609
+ logger.warning(
610
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
611
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
612
+ )
613
+
614
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
615
+ attention_mask = text_inputs.attention_mask.to(device)
616
+ else:
617
+ attention_mask = None
618
+
619
+ if clip_skip is None:
620
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
621
+ prompt_embeds = prompt_embeds[0]
622
+ else:
623
+ prompt_embeds = self.text_encoder(
624
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
625
+ )
626
+ # Access the `hidden_states` first, that contains a tuple of
627
+ # all the hidden states from the encoder layers. Then index into
628
+ # the tuple to access the hidden states from the desired layer.
629
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
630
+ # We also need to apply the final LayerNorm here to not mess with the
631
+ # representations. The `last_hidden_states` that we typically use for
632
+ # obtaining the final prompt representations passes through the LayerNorm
633
+ # layer.
634
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
635
+
636
+ if self.text_encoder is not None:
637
+ prompt_embeds_dtype = self.text_encoder.dtype
638
+ elif self.unet is not None:
639
+ prompt_embeds_dtype = self.unet.dtype
640
+ else:
641
+ prompt_embeds_dtype = prompt_embeds.dtype
642
+
643
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
644
+
645
+ bs_embed, seq_len, _ = prompt_embeds.shape
646
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
647
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
648
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
649
+
650
+ # get unconditional embeddings for classifier free guidance
651
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
652
+ uncond_tokens: List[str]
653
+ if negative_prompt is None:
654
+ uncond_tokens = [""] * batch_size
655
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
656
+ raise TypeError(
657
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
658
+ f" {type(prompt)}."
659
+ )
660
+ elif isinstance(negative_prompt, str):
661
+ uncond_tokens = [negative_prompt]
662
+ elif batch_size != len(negative_prompt):
663
+ raise ValueError(
664
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
665
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
666
+ " the batch size of `prompt`."
667
+ )
668
+ else:
669
+ uncond_tokens = negative_prompt
670
+
671
+ # textual inversion: process multi-vector tokens if necessary
672
+ if isinstance(self, TextualInversionLoaderMixin):
673
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
674
+
675
+ max_length = prompt_embeds.shape[1]
676
+ uncond_input = self.tokenizer(
677
+ uncond_tokens,
678
+ padding="max_length",
679
+ max_length=max_length,
680
+ truncation=True,
681
+ return_tensors="pt",
682
+ )
683
+
684
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
685
+ attention_mask = uncond_input.attention_mask.to(device)
686
+ else:
687
+ attention_mask = None
688
+
689
+ negative_prompt_embeds = self.text_encoder(
690
+ uncond_input.input_ids.to(device),
691
+ attention_mask=attention_mask,
692
+ )
693
+ negative_prompt_embeds = negative_prompt_embeds[0]
694
+
695
+ if do_classifier_free_guidance:
696
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
697
+ seq_len = negative_prompt_embeds.shape[1]
698
+
699
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
700
+
701
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
702
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
703
+
704
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
705
+ # Retrieve the original scale by scaling back the LoRA layers
706
+ unscale_lora_layers(self.text_encoder, lora_scale)
707
+
708
+ return prompt_embeds, negative_prompt_embeds
709
+
710
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
711
+ dtype = next(self.image_encoder.parameters()).dtype
712
+
713
+ if not isinstance(image, torch.Tensor):
714
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
715
+
716
+ image = image.to(device=device, dtype=dtype)
717
+ if output_hidden_states:
718
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
719
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
720
+ uncond_image_enc_hidden_states = self.image_encoder(
721
+ torch.zeros_like(image), output_hidden_states=True
722
+ ).hidden_states[-2]
723
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
724
+ num_images_per_prompt, dim=0
725
+ )
726
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
727
+ else:
728
+ image_embeds = self.image_encoder(image).image_embeds
729
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
730
+ uncond_image_embeds = torch.zeros_like(image_embeds)
731
+
732
+ return image_embeds, uncond_image_embeds
733
+
734
+ def prepare_ip_adapter_image_embeds(
735
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
736
+ ):
737
+ if ip_adapter_image_embeds is None:
738
+ if not isinstance(ip_adapter_image, list):
739
+ ip_adapter_image = [ip_adapter_image]
740
+
741
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
742
+ raise ValueError(
743
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
744
+ )
745
+
746
+ image_embeds = []
747
+ for single_ip_adapter_image, image_proj_layer in zip(
748
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
749
+ ):
750
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
751
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
752
+ single_ip_adapter_image, device, 1, output_hidden_state
753
+ )
754
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
755
+ single_negative_image_embeds = torch.stack(
756
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
757
+ )
758
+
759
+ if self.do_classifier_free_guidance:
760
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
761
+ single_image_embeds = single_image_embeds.to(device)
762
+
763
+ image_embeds.append(single_image_embeds)
764
+ else:
765
+ image_embeds = ip_adapter_image_embeds
766
+ return image_embeds
767
+
768
+ def run_safety_checker(self, image, device, dtype):
769
+ if self.safety_checker is None:
770
+ has_nsfw_concept = None
771
+ else:
772
+ if torch.is_tensor(image):
773
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
774
+ else:
775
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
776
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
777
+ image, has_nsfw_concept = self.safety_checker(
778
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
779
+ )
780
+ return image, has_nsfw_concept
781
+
782
+ def decode_latents(self, latents):
783
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
784
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
785
+
786
+ latents = 1 / self.vae.config.scaling_factor * latents
787
+ image = self.vae.decode(latents, return_dict=False)[0]
788
+ image = (image / 2 + 0.5).clamp(0, 1)
789
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
790
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
791
+ return image
792
+
793
+ def prepare_extra_step_kwargs(self, generator, eta):
794
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
795
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
796
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
797
+ # and should be between [0, 1]
798
+
799
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
800
+ extra_step_kwargs = {}
801
+ if accepts_eta:
802
+ extra_step_kwargs["eta"] = eta
803
+
804
+ # check if the scheduler accepts generator
805
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
806
+ if accepts_generator:
807
+ extra_step_kwargs["generator"] = generator
808
+ return extra_step_kwargs
809
+
810
+ def check_inputs(
811
+ self,
812
+ prompt,
813
+ height,
814
+ width,
815
+ callback_steps,
816
+ negative_prompt=None,
817
+ prompt_embeds=None,
818
+ negative_prompt_embeds=None,
819
+ ip_adapter_image=None,
820
+ ip_adapter_image_embeds=None,
821
+ callback_on_step_end_tensor_inputs=None,
822
+ ):
823
+ if height % 8 != 0 or width % 8 != 0:
824
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
825
+
826
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
827
+ raise ValueError(
828
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
829
+ f" {type(callback_steps)}."
830
+ )
831
+ if callback_on_step_end_tensor_inputs is not None and not all(
832
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
833
+ ):
834
+ raise ValueError(
835
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
836
+ )
837
+
838
+ if prompt is not None and prompt_embeds is not None:
839
+ raise ValueError(
840
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
841
+ " only forward one of the two."
842
+ )
843
+ elif prompt is None and prompt_embeds is None:
844
+ raise ValueError(
845
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
846
+ )
847
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
848
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
849
+
850
+ if negative_prompt is not None and negative_prompt_embeds is not None:
851
+ raise ValueError(
852
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
853
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
854
+ )
855
+
856
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
857
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
858
+ raise ValueError(
859
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
860
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
861
+ f" {negative_prompt_embeds.shape}."
862
+ )
863
+
864
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
865
+ raise ValueError(
866
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
867
+ )
868
+
869
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
870
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
871
+ if isinstance(generator, list) and len(generator) != batch_size:
872
+ raise ValueError(
873
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
874
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
875
+ )
876
+
877
+ if latents is None:
878
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
879
+ else:
880
+ latents = latents.to(device)
881
+
882
+ # scale the initial noise by the standard deviation required by the scheduler
883
+ latents = latents * self.scheduler.init_noise_sigma
884
+ return latents
885
+
886
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
887
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
888
+
889
+ The suffixes after the scaling factors represent the stages where they are being applied.
890
+
891
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
892
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
893
+
894
+ Args:
895
+ s1 (`float`):
896
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
897
+ mitigate "oversmoothing effect" in the enhanced denoising process.
898
+ s2 (`float`):
899
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
900
+ mitigate "oversmoothing effect" in the enhanced denoising process.
901
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
902
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
903
+ """
904
+ if not hasattr(self, "unet"):
905
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
906
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
907
+
908
+ def disable_freeu(self):
909
+ """Disables the FreeU mechanism if enabled."""
910
+ self.unet.disable_freeu()
911
+
912
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
913
+ def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
914
+ """
915
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
916
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
917
+
918
+ <Tip warning={true}>
919
+
920
+ This API is 🧪 experimental.
921
+
922
+ </Tip>
923
+
924
+ Args:
925
+ unet (`bool`, defaults to `True`): To apply fusion on the UNet.
926
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
927
+ """
928
+ self.fusing_unet = False
929
+ self.fusing_vae = False
930
+
931
+ if unet:
932
+ self.fusing_unet = True
933
+ self.unet.fuse_qkv_projections()
934
+ self.unet.set_attn_processor(FusedAttnProcessor2_0())
935
+
936
+ if vae:
937
+ if not isinstance(self.vae, AutoencoderKL):
938
+ raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
939
+
940
+ self.fusing_vae = True
941
+ self.vae.fuse_qkv_projections()
942
+ self.vae.set_attn_processor(FusedAttnProcessor2_0())
943
+
944
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
945
+ def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
946
+ """Disable QKV projection fusion if enabled.
947
+
948
+ <Tip warning={true}>
949
+
950
+ This API is 🧪 experimental.
951
+
952
+ </Tip>
953
+
954
+ Args:
955
+ unet (`bool`, defaults to `True`): To apply fusion on the UNet.
956
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
957
+
958
+ """
959
+ if unet:
960
+ if not self.fusing_unet:
961
+ logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
962
+ else:
963
+ self.unet.unfuse_qkv_projections()
964
+ self.fusing_unet = False
965
+
966
+ if vae:
967
+ if not self.fusing_vae:
968
+ logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
969
+ else:
970
+ self.vae.unfuse_qkv_projections()
971
+ self.fusing_vae = False
972
+
973
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
974
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
975
+ """
976
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
977
+
978
+ Args:
979
+ timesteps (`torch.Tensor`):
980
+ generate embedding vectors at these timesteps
981
+ embedding_dim (`int`, *optional*, defaults to 512):
982
+ dimension of the embeddings to generate
983
+ dtype:
984
+ data type of the generated embeddings
985
+
986
+ Returns:
987
+ `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
988
+ """
989
+ assert len(w.shape) == 1
990
+ w = w * 1000.0
991
+
992
+ half_dim = embedding_dim // 2
993
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
994
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
995
+ emb = w.to(dtype)[:, None] * emb[None, :]
996
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
997
+ if embedding_dim % 2 == 1: # zero pad
998
+ emb = torch.nn.functional.pad(emb, (0, 1))
999
+ assert emb.shape == (w.shape[0], embedding_dim)
1000
+ return emb
1001
+
1002
+ def pred_z0(self, sample, model_output, timestep):
1003
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep].to(sample.device)
1004
+
1005
+ beta_prod_t = 1 - alpha_prod_t
1006
+ if self.scheduler.config.prediction_type == "epsilon":
1007
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
1008
+ elif self.scheduler.config.prediction_type == "sample":
1009
+ pred_original_sample = model_output
1010
+ elif self.scheduler.config.prediction_type == "v_prediction":
1011
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
1012
+ # predict V
1013
+ model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
1014
+ else:
1015
+ raise ValueError(
1016
+ f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
1017
+ " or `v_prediction`"
1018
+ )
1019
+
1020
+ return pred_original_sample
1021
+
1022
+ def pred_x0(self, latents, noise_pred, t, generator, device, prompt_embeds, output_type):
1023
+
1024
+ pred_z0 = self.pred_z0(latents, noise_pred, t)
1025
+ pred_x0 = self.vae.decode(
1026
+ pred_z0 / self.vae.config.scaling_factor,
1027
+ return_dict=False,
1028
+ generator=generator
1029
+ )[0]
1030
+ pred_x0, ____ = self.run_safety_checker(pred_x0, device, prompt_embeds.dtype)
1031
+ do_denormalize = [True] * pred_x0.shape[0]
1032
+ pred_x0 = self.image_processor.postprocess(pred_x0, output_type=output_type, do_denormalize=do_denormalize)
1033
+
1034
+ return pred_x0
1035
+
1036
+ @property
1037
+ def guidance_scale(self):
1038
+ return self._guidance_scale
1039
+
1040
+ @property
1041
+ def guidance_rescale(self):
1042
+ return self._guidance_rescale
1043
+
1044
+ @property
1045
+ def clip_skip(self):
1046
+ return self._clip_skip
1047
+
1048
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1049
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1050
+ # corresponds to doing no classifier free guidance.
1051
+ @property
1052
+ def do_classifier_free_guidance(self):
1053
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
1054
+
1055
+ @property
1056
+ def cross_attention_kwargs(self):
1057
+ return self._cross_attention_kwargs
1058
+
1059
+ @property
1060
+ def num_timesteps(self):
1061
+ return self._num_timesteps
1062
+
1063
+ @property
1064
+ def interrupt(self):
1065
+ return self._interrupt
1066
+
1067
+ @property
1068
+ def pag_scale(self):
1069
+ return self._pag_scale
1070
+
1071
+ @property
1072
+ def do_adversarial_guidance(self):
1073
+ return self._pag_scale > 0
1074
+
1075
+ @property
1076
+ def pag_adaptive_scaling(self):
1077
+ return self._pag_adaptive_scaling
1078
+
1079
+ @property
1080
+ def do_pag_adaptive_scaling(self):
1081
+ return self._pag_adaptive_scaling > 0
1082
+
1083
+ @property
1084
+ def pag_drop_rate(self):
1085
+ return self._pag_drop_rate
1086
+
1087
+ @property
1088
+ def pag_applied_layers(self):
1089
+ return self._pag_applied_layers
1090
+
1091
+ @property
1092
+ def pag_applied_layers_index(self):
1093
+ return self._pag_applied_layers_index
1094
+
1095
+
1096
+ @torch.no_grad()
1097
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
1098
+ def __call__(
1099
+ self,
1100
+ prompt: Union[str, List[str]] = None,
1101
+ height: Optional[int] = None,
1102
+ width: Optional[int] = None,
1103
+ num_inference_steps: int = 50,
1104
+ timesteps: List[int] = None,
1105
+ guidance_scale: float = 7.5,
1106
+ pag_scale: float = 0.0,
1107
+ pag_adaptive_scaling: float = 0.0,
1108
+ pag_drop_rate: float = 0.5,
1109
+ pag_applied_layers: List[str] = ['down'], #['down', 'mid', 'up']
1110
+ pag_applied_layers_index: List[str] = ['d4'], #['d4', 'd5', 'm0']
1111
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1112
+ num_images_per_prompt: Optional[int] = 1,
1113
+ eta: float = 0.0,
1114
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1115
+ latents: Optional[torch.FloatTensor] = None,
1116
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1117
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1118
+ ip_adapter_image: Optional[PipelineImageInput] = None,
1119
+ ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
1120
+ output_type: Optional[str] = "pil",
1121
+ return_dict: bool = True,
1122
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1123
+ guidance_rescale: float = 0.0,
1124
+ clip_skip: Optional[int] = None,
1125
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
1126
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1127
+ **kwargs,
1128
+ ):
1129
+ r"""
1130
+ The call function to the pipeline for generation.
1131
+
1132
+ Args:
1133
+ prompt (`str` or `List[str]`, *optional*):
1134
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
1135
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
1136
+ The height in pixels of the generated image.
1137
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
1138
+ The width in pixels of the generated image.
1139
+ num_inference_steps (`int`, *optional*, defaults to 50):
1140
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1141
+ expense of slower inference.
1142
+ timesteps (`List[int]`, *optional*):
1143
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
1144
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
1145
+ passed will be used. Must be in descending order.
1146
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1147
+ A higher guidance scale value encourages the model to generate images closely linked to the text
1148
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
1149
+ negative_prompt (`str` or `List[str]`, *optional*):
1150
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
1151
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
1152
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1153
+ The number of images to generate per prompt.
1154
+ eta (`float`, *optional*, defaults to 0.0):
1155
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
1156
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
1157
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1158
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
1159
+ generation deterministic.
1160
+ latents (`torch.FloatTensor`, *optional*):
1161
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
1162
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1163
+ tensor is generated by sampling using the supplied random `generator`.
1164
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1165
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
1166
+ provided, text embeddings are generated from the `prompt` input argument.
1167
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1168
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
1169
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
1170
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
1171
+ ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
1172
+ Pre-generated image embeddings for IP-Adapter. If not
1173
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
1174
+ output_type (`str`, *optional*, defaults to `"pil"`):
1175
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
1176
+ return_dict (`bool`, *optional*, defaults to `True`):
1177
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1178
+ plain tuple.
1179
+ cross_attention_kwargs (`dict`, *optional*):
1180
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
1181
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1182
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
1183
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
1184
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
1185
+ using zero terminal SNR.
1186
+ clip_skip (`int`, *optional*):
1187
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1188
+ the output of the pre-final layer will be used for computing the prompt embeddings.
1189
+ callback_on_step_end (`Callable`, *optional*):
1190
+ A function that calls at the end of each denoising steps during the inference. The function is called
1191
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
1192
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1193
+ `callback_on_step_end_tensor_inputs`.
1194
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1195
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1196
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1197
+ `._callback_tensor_inputs` attribute of your pipeline class.
1198
+
1199
+ Examples:
1200
+
1201
+ Returns:
1202
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1203
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
1204
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
1205
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
1206
+ "not-safe-for-work" (nsfw) content.
1207
+ """
1208
+
1209
+ callback = kwargs.pop("callback", None)
1210
+ callback_steps = kwargs.pop("callback_steps", None)
1211
+
1212
+ if callback is not None:
1213
+ deprecate(
1214
+ "callback",
1215
+ "1.0.0",
1216
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1217
+ )
1218
+ if callback_steps is not None:
1219
+ deprecate(
1220
+ "callback_steps",
1221
+ "1.0.0",
1222
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1223
+ )
1224
+
1225
+ # 0. Default height and width to unet
1226
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
1227
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
1228
+ # to deal with lora scaling and other possible forward hooks
1229
+
1230
+ # 1. Check inputs. Raise error if not correct
1231
+ self.check_inputs(
1232
+ prompt,
1233
+ height,
1234
+ width,
1235
+ callback_steps,
1236
+ negative_prompt,
1237
+ prompt_embeds,
1238
+ negative_prompt_embeds,
1239
+ ip_adapter_image,
1240
+ ip_adapter_image_embeds,
1241
+ callback_on_step_end_tensor_inputs,
1242
+ )
1243
+
1244
+ self._guidance_scale = guidance_scale
1245
+ self._guidance_rescale = guidance_rescale
1246
+ self._clip_skip = clip_skip
1247
+ self._cross_attention_kwargs = cross_attention_kwargs
1248
+ self._interrupt = False
1249
+
1250
+ self._pag_scale = pag_scale
1251
+ self._pag_adaptive_scaling = pag_adaptive_scaling
1252
+ self._pag_drop_rate = pag_drop_rate
1253
+ self._pag_applied_layers = pag_applied_layers
1254
+ self._pag_applied_layers_index = pag_applied_layers_index
1255
+
1256
+ # 2. Define call parameters
1257
+ if prompt is not None and isinstance(prompt, str):
1258
+ batch_size = 1
1259
+ elif prompt is not None and isinstance(prompt, list):
1260
+ batch_size = len(prompt)
1261
+ else:
1262
+ batch_size = prompt_embeds.shape[0]
1263
+
1264
+ device = self._execution_device
1265
+
1266
+ # 3. Encode input prompt
1267
+ lora_scale = (
1268
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1269
+ )
1270
+
1271
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
1272
+ prompt,
1273
+ device,
1274
+ num_images_per_prompt,
1275
+ self.do_classifier_free_guidance,
1276
+ negative_prompt,
1277
+ prompt_embeds=prompt_embeds,
1278
+ negative_prompt_embeds=negative_prompt_embeds,
1279
+ lora_scale=lora_scale,
1280
+ clip_skip=self.clip_skip,
1281
+ )
1282
+
1283
+ # For classifier free guidance, we need to do two forward passes.
1284
+ # Here we concatenate the unconditional and text embeddings into a single batch
1285
+ # to avoid doing two forward passes
1286
+
1287
+ #cfg
1288
+ if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
1289
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
1290
+ #pag
1291
+ elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
1292
+ prompt_embeds = torch.cat([prompt_embeds, prompt_embeds])
1293
+ #both
1294
+ elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
1295
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds])
1296
+
1297
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1298
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1299
+ ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
1300
+ )
1301
+
1302
+ # 4. Prepare timesteps
1303
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
1304
+
1305
+ # 5. Prepare latent variables
1306
+ num_channels_latents = self.unet.config.in_channels
1307
+ latents = self.prepare_latents(
1308
+ batch_size * num_images_per_prompt,
1309
+ num_channels_latents,
1310
+ height,
1311
+ width,
1312
+ prompt_embeds.dtype,
1313
+ device,
1314
+ generator,
1315
+ latents,
1316
+ )
1317
+
1318
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1319
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1320
+
1321
+ # 6.1 Add image embeds for IP-Adapter
1322
+ added_cond_kwargs = (
1323
+ {"image_embeds": image_embeds}
1324
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
1325
+ else None
1326
+ )
1327
+
1328
+ # 6.2 Optionally get Guidance Scale Embedding
1329
+ timestep_cond = None
1330
+ if self.unet.config.time_cond_proj_dim is not None:
1331
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1332
+ timestep_cond = self.get_guidance_scale_embedding(
1333
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1334
+ ).to(device=device, dtype=latents.dtype)
1335
+
1336
+ # 7. Denoising loop
1337
+ if self.do_adversarial_guidance:
1338
+ down_layers = []
1339
+ mid_layers = []
1340
+ up_layers = []
1341
+ for name, module in self.unet.named_modules():
1342
+ if 'attn1' in name and 'to' not in name:
1343
+ layer_type = name.split('.')[0].split('_')[0]
1344
+ if layer_type == 'down':
1345
+ down_layers.append(module)
1346
+ elif layer_type == 'mid':
1347
+ mid_layers.append(module)
1348
+ elif layer_type == 'up':
1349
+ up_layers.append(module)
1350
+ else:
1351
+ raise ValueError(f"Invalid layer type: {layer_type}")
1352
+
1353
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1354
+ self._num_timesteps = len(timesteps)
1355
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1356
+ for i, t in enumerate(timesteps):
1357
+ if self.interrupt:
1358
+ continue
1359
+
1360
+ #cfg
1361
+ if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
1362
+ latent_model_input = torch.cat([latents] * 2)
1363
+ #pag
1364
+ elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
1365
+ latent_model_input = torch.cat([latents] * 2)
1366
+ #both
1367
+ elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
1368
+ latent_model_input = torch.cat([latents] * 3)
1369
+ #no
1370
+ else:
1371
+ latent_model_input = latents
1372
+
1373
+ # change attention layer in UNet if use PAG
1374
+ if self.do_adversarial_guidance:
1375
+
1376
+ if self.do_classifier_free_guidance:
1377
+ replace_processor = PAGCFGIdentitySelfAttnProcessor()
1378
+ else:
1379
+ replace_processor = PAGIdentitySelfAttnProcessor()
1380
+
1381
+ drop_layers = self.pag_applied_layers_index
1382
+ for drop_layer in drop_layers:
1383
+ try:
1384
+ if drop_layer[0] == 'd':
1385
+ down_layers[int(drop_layer[1])].processor = replace_processor
1386
+ elif drop_layer[0] == 'm':
1387
+ mid_layers[int(drop_layer[1])].processor = replace_processor
1388
+ elif drop_layer[0] == 'u':
1389
+ up_layers[int(drop_layer[1])].processor = replace_processor
1390
+ else:
1391
+ raise ValueError(f"Invalid layer type: {drop_layer[0]}")
1392
+ except IndexError:
1393
+ raise ValueError(
1394
+ f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
1395
+ )
1396
+
1397
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1398
+
1399
+ # predict the noise residual
1400
+ noise_pred = self.unet(
1401
+ latent_model_input,
1402
+ t,
1403
+ encoder_hidden_states=prompt_embeds,
1404
+ timestep_cond=timestep_cond,
1405
+ cross_attention_kwargs=self.cross_attention_kwargs,
1406
+ added_cond_kwargs=added_cond_kwargs,
1407
+ return_dict=False,
1408
+ )[0]
1409
+
1410
+ # perform guidance
1411
+
1412
+ # cfg
1413
+ if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
1414
+
1415
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1416
+
1417
+ delta = noise_pred_text - noise_pred_uncond
1418
+ noise_pred = noise_pred_uncond + self.guidance_scale * delta
1419
+
1420
+ # pag
1421
+ elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
1422
+
1423
+ noise_pred_original, noise_pred_perturb = noise_pred.chunk(2)
1424
+
1425
+ signal_scale = self.pag_scale
1426
+ if self.do_pag_adaptive_scaling:
1427
+ signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000-t)
1428
+ if signal_scale<0:
1429
+ signal_scale = 0
1430
+
1431
+ noise_pred = noise_pred_original + signal_scale * (noise_pred_original - noise_pred_perturb)
1432
+
1433
+ # both
1434
+ elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
1435
+
1436
+ noise_pred_uncond, noise_pred_text, noise_pred_text_perturb = noise_pred.chunk(3)
1437
+
1438
+ signal_scale = self.pag_scale
1439
+ if self.do_pag_adaptive_scaling:
1440
+ signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000-t)
1441
+ if signal_scale<0:
1442
+ signal_scale = 0
1443
+
1444
+ noise_pred = noise_pred_text + (self.guidance_scale-1.0) * (noise_pred_text - noise_pred_uncond) + signal_scale * (noise_pred_text - noise_pred_text_perturb)
1445
+
1446
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1447
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1448
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
1449
+
1450
+ # compute the previous noisy sample x_t -> x_t-1
1451
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1452
+
1453
+ if callback_on_step_end is not None:
1454
+ callback_kwargs = {}
1455
+ for k in callback_on_step_end_tensor_inputs:
1456
+ callback_kwargs[k] = locals()[k]
1457
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1458
+
1459
+ latents = callback_outputs.pop("latents", latents)
1460
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1461
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1462
+
1463
+ # call the callback, if provided
1464
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1465
+ progress_bar.update()
1466
+ if callback is not None and i % callback_steps == 0:
1467
+ step_idx = i // getattr(self.scheduler, "order", 1)
1468
+ callback(step_idx, t, latents)
1469
+
1470
+ if not output_type == "latent":
1471
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
1472
+ 0
1473
+ ]
1474
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1475
+ else:
1476
+ image = latents
1477
+ has_nsfw_concept = None
1478
+
1479
+ if has_nsfw_concept is None:
1480
+ do_denormalize = [True] * image.shape[0]
1481
+ else:
1482
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1483
+
1484
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1485
+
1486
+ # Offload all models
1487
+ self.maybe_free_model_hooks()
1488
+
1489
+ if not return_dict:
1490
+ return (image, has_nsfw_concept)
1491
+
1492
+ # change attention layer in UNet if use PAG
1493
+ if self.do_adversarial_guidance:
1494
+ drop_layers = self.pag_applied_layers_index
1495
+ for drop_layer in drop_layers:
1496
+ try:
1497
+ if drop_layer[0] == 'd':
1498
+ down_layers[int(drop_layer[1])].processor = AttnProcessor2_0()
1499
+ elif drop_layer[0] == 'm':
1500
+ mid_layers[int(drop_layer[1])].processor = AttnProcessor2_0()
1501
+ elif drop_layer[0] == 'u':
1502
+ up_layers[int(drop_layer[1])].processor = AttnProcessor2_0()
1503
+ else:
1504
+ raise ValueError(f"Invalid layer type: {drop_layer[0]}")
1505
+ except IndexError:
1506
+ raise ValueError(
1507
+ f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
1508
+ )
1509
+
1510
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)