hyoungwoncho committed on
Commit
0bd0ab1
1 Parent(s): 9044708

Upload pipeline.py

Files changed (1)
  1. pipeline.py +1168 -0
pipeline.py ADDED
@@ -0,0 +1,1168 @@
1
+ # Implementation of Stable Diffusion Upscale Pipeline with Perturbed-Attention Guidance
2
+
3
+ import inspect
4
+ import warnings
5
+ from typing import Any, Callable, Dict, List, Optional, Union
6
+
7
+ import numpy as np
8
+ import PIL.Image
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
12
+
13
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
14
+ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
15
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
16
+ from diffusers.models.attention_processor import (
17
+ Attention,
18
+ AttnProcessor2_0,
19
+ LoRAAttnProcessor2_0,
20
+ LoRAXFormersAttnProcessor,
21
+ XFormersAttnProcessor,
22
+ )
23
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
24
+ from diffusers.schedulers import DDPMScheduler, KarrasDiffusionSchedulers
25
+ from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
26
+ from diffusers.utils.torch_utils import randn_tensor
27
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
28
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
29
+
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+ class PAGIdentitySelfAttnProcessor:
34
+ r"""
35
+ Processor implementing Perturbed-Attention Guidance (PAG) with an identity (perturbed) self-attention path, built on scaled dot-product attention (requires PyTorch 2.0).
36
+ """
37
+
38
+ def __init__(self):
39
+ if not hasattr(F, "scaled_dot_product_attention"):
40
+ raise ImportError("PAGIdentitySelfAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
41
+
42
+ def __call__(
43
+ self,
44
+ attn: Attention,
45
+ hidden_states: torch.FloatTensor,
46
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
47
+ attention_mask: Optional[torch.FloatTensor] = None,
48
+ temb: Optional[torch.FloatTensor] = None,
49
+ *args,
50
+ **kwargs,
51
+ ) -> torch.FloatTensor:
52
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
53
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
54
+ deprecate("scale", "1.0.0", deprecation_message)
55
+
56
+ residual = hidden_states
57
+ if attn.spatial_norm is not None:
58
+ hidden_states = attn.spatial_norm(hidden_states, temb)
59
+
60
+ input_ndim = hidden_states.ndim
61
+ if input_ndim == 4:
62
+ batch_size, channel, height, width = hidden_states.shape
63
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
64
+
65
+ # chunk
66
+ hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
67
+
68
+ # original path
69
+ batch_size, sequence_length, _ = hidden_states_org.shape
70
+
71
+ if attention_mask is not None:
72
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
73
+ # scaled_dot_product_attention expects attention_mask shape to be
74
+ # (batch, heads, source_length, target_length)
75
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
76
+
77
+ if attn.group_norm is not None:
78
+ hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
79
+
80
+ query = attn.to_q(hidden_states_org)
81
+ key = attn.to_k(hidden_states_org)
82
+ value = attn.to_v(hidden_states_org)
83
+
84
+ inner_dim = key.shape[-1]
85
+ head_dim = inner_dim // attn.heads
86
+
87
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
88
+
89
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
90
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
91
+
92
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
93
+ # TODO: add support for attn.scale when we move to Torch 2.1
94
+ hidden_states_org = F.scaled_dot_product_attention(
95
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
96
+ )
97
+
98
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
99
+ hidden_states_org = hidden_states_org.to(query.dtype)
100
+
101
+ # linear proj
102
+ hidden_states_org = attn.to_out[0](hidden_states_org)
103
+ # dropout
104
+ hidden_states_org = attn.to_out[1](hidden_states_org)
105
+
106
+ if input_ndim == 4:
107
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
108
+
109
+ # perturbed path (identity attention)
110
+ batch_size, sequence_length, _ = hidden_states_ptb.shape
111
+
112
+ if attention_mask is not None:
113
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
114
+ # scaled_dot_product_attention expects attention_mask shape to be
115
+ # (batch, heads, source_length, target_length)
116
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
117
+
118
+ if attn.group_norm is not None:
119
+ hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
120
+
121
+ value = attn.to_v(hidden_states_ptb)
122
+
123
+ # hidden_states_ptb = torch.zeros(value.shape).to(value.get_device())
124
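+ # identity self-attention: the attention map is replaced by the identity, so each token attends only to itself and the output is simply the value projection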
+ hidden_states_ptb = value
125
+
126
+ hidden_states_ptb = hidden_states_ptb.to(query.dtype)
127
+
128
+ # linear proj
129
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
130
+ # dropout
131
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
132
+
133
+ if input_ndim == 4:
134
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
135
+
136
+ # cat
137
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
138
+
139
+ if attn.residual_connection:
140
+ hidden_states = hidden_states + residual
141
+
142
+ hidden_states = hidden_states / attn.rescale_output_factor
143
+
144
+ return hidden_states
145
+
146
+
147
+ class PAGCFGIdentitySelfAttnProcessor:
148
+ r"""
149
+ Processor implementing Perturbed-Attention Guidance (PAG) combined with classifier-free guidance, using an identity (perturbed) self-attention path built on scaled dot-product attention (requires PyTorch 2.0).
150
+ """
151
+
152
+ def __init__(self):
153
+ if not hasattr(F, "scaled_dot_product_attention"):
154
+ raise ImportError("PAGCFGIdentitySelfAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
155
+
156
+ def __call__(
157
+ self,
158
+ attn: Attention,
159
+ hidden_states: torch.FloatTensor,
160
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
161
+ attention_mask: Optional[torch.FloatTensor] = None,
162
+ temb: Optional[torch.FloatTensor] = None,
163
+ *args,
164
+ **kwargs,
165
+ ) -> torch.FloatTensor:
166
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
167
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
168
+ deprecate("scale", "1.0.0", deprecation_message)
169
+
170
+ residual = hidden_states
171
+ if attn.spatial_norm is not None:
172
+ hidden_states = attn.spatial_norm(hidden_states, temb)
173
+
174
+ input_ndim = hidden_states.ndim
175
+ if input_ndim == 4:
176
+ batch_size, channel, height, width = hidden_states.shape
177
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
178
+
179
+ # chunk
180
+ hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
181
+ hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
182
+
183
+ # original path
184
+ batch_size, sequence_length, _ = hidden_states_org.shape
185
+
186
+ if attention_mask is not None:
187
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
188
+ # scaled_dot_product_attention expects attention_mask shape to be
189
+ # (batch, heads, source_length, target_length)
190
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
191
+
192
+ if attn.group_norm is not None:
193
+ hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
194
+
195
+ query = attn.to_q(hidden_states_org)
196
+ key = attn.to_k(hidden_states_org)
197
+ value = attn.to_v(hidden_states_org)
198
+
199
+ inner_dim = key.shape[-1]
200
+ head_dim = inner_dim // attn.heads
201
+
202
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
203
+
204
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
205
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
206
+
207
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
208
+ # TODO: add support for attn.scale when we move to Torch 2.1
209
+ hidden_states_org = F.scaled_dot_product_attention(
210
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
211
+ )
212
+
213
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
214
+ hidden_states_org = hidden_states_org.to(query.dtype)
215
+
216
+ # linear proj
217
+ hidden_states_org = attn.to_out[0](hidden_states_org)
218
+ # dropout
219
+ hidden_states_org = attn.to_out[1](hidden_states_org)
220
+
221
+ if input_ndim == 4:
222
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
223
+
224
+ # perturbed path (identity attention)
225
+ batch_size, sequence_length, _ = hidden_states_ptb.shape
226
+
227
+ if attention_mask is not None:
228
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
229
+ # scaled_dot_product_attention expects attention_mask shape to be
230
+ # (batch, heads, source_length, target_length)
231
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
232
+
233
+ if attn.group_norm is not None:
234
+ hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
235
+
236
+ value = attn.to_v(hidden_states_ptb)
237
+ hidden_states_ptb = value
238
+ hidden_states_ptb = hidden_states_ptb.to(query.dtype)
239
+
240
+ # linear proj
241
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
242
+ # dropout
243
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
244
+
245
+ if input_ndim == 4:
246
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
247
+
248
+ # cat
249
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
250
+
251
+ if attn.residual_connection:
252
+ hidden_states = hidden_states + residual
253
+
254
+ hidden_states = hidden_states / attn.rescale_output_factor
255
+
256
+ return hidden_states
257
+
258
+ def preprocess(image):
259
+ warnings.warn(
260
+ "The preprocess method is deprecated and will be removed in a future version. Please"
261
+ " use VaeImageProcessor.preprocess instead",
262
+ FutureWarning,
263
+ )
264
+ if isinstance(image, torch.Tensor):
265
+ return image
266
+ elif isinstance(image, PIL.Image.Image):
267
+ image = [image]
268
+
269
+ if isinstance(image[0], PIL.Image.Image):
270
+ w, h = image[0].size
271
+ w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64
272
+
273
+ image = [np.array(i.resize((w, h)))[None, :] for i in image]
274
+ image = np.concatenate(image, axis=0)
275
+ image = np.array(image).astype(np.float32) / 255.0
276
+ image = image.transpose(0, 3, 1, 2)
277
+ image = 2.0 * image - 1.0
278
+ image = torch.from_numpy(image)
279
+ elif isinstance(image[0], torch.Tensor):
280
+ image = torch.cat(image, dim=0)
281
+ return image
282
+
283
+
284
+ class StableDiffusionUpscalePipeline(
285
+ DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
286
+ ):
287
+ r"""
288
+ Pipeline for text-guided image super-resolution using Stable Diffusion 2.
289
+
290
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
291
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
292
+
293
+ The pipeline also inherits the following loading methods:
294
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
295
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
296
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
297
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
298
+
299
+ Args:
300
+ vae ([`AutoencoderKL`]):
301
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
302
+ text_encoder ([`~transformers.CLIPTextModel`]):
303
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
304
+ tokenizer ([`~transformers.CLIPTokenizer`]):
305
+ A `CLIPTokenizer` to tokenize text.
306
+ unet ([`UNet2DConditionModel`]):
307
+ A `UNet2DConditionModel` to denoise the encoded image latents.
308
+ low_res_scheduler ([`SchedulerMixin`]):
309
+ A scheduler used to add initial noise to the low resolution conditioning image. It must be an instance of
310
+ [`DDPMScheduler`].
311
+ scheduler ([`SchedulerMixin`]):
312
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
313
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
314
+ """
315
+
316
+ model_cpu_offload_seq = "text_encoder->unet->vae"
317
+ _optional_components = ["watermarker", "safety_checker", "feature_extractor"]
318
+ _exclude_from_cpu_offload = ["safety_checker"]
319
+
320
+ def __init__(
321
+ self,
322
+ vae: AutoencoderKL,
323
+ text_encoder: CLIPTextModel,
324
+ tokenizer: CLIPTokenizer,
325
+ unet: UNet2DConditionModel,
326
+ low_res_scheduler: DDPMScheduler,
327
+ scheduler: KarrasDiffusionSchedulers,
328
+ safety_checker: Optional[Any] = None,
329
+ feature_extractor: Optional[CLIPImageProcessor] = None,
330
+ watermarker: Optional[Any] = None,
331
+ max_noise_level: int = 350,
332
+ ):
333
+ super().__init__()
334
+
335
+ if hasattr(
336
+ vae, "config"
337
+ ): # check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate
338
+ is_vae_scaling_factor_set_to_0_08333 = (
339
+ hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333
340
+ )
341
+ if not is_vae_scaling_factor_set_to_0_08333:
342
+ deprecation_message = (
343
+ "The configuration file of the vae does not contain `scaling_factor` or it is set to"
344
+ f" {vae.config.scaling_factor}, which seems highly unlikely. If your checkpoint is a fine-tuned"
345
+ " version of `stabilityai/stable-diffusion-x4-upscaler` you should change 'scaling_factor' to"
346
+ " 0.08333 Please make sure to update the config accordingly, as not doing so might lead to"
347
+ " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging"
348
+ " Face Hub, it would be very nice if you could open a Pull Request for the `vae/config.json` file"
349
+ )
350
+ deprecate("wrong scaling_factor", "1.0.0", deprecation_message, standard_warn=False)
351
+ vae.register_to_config(scaling_factor=0.08333)
352
+
353
+ self.register_modules(
354
+ vae=vae,
355
+ text_encoder=text_encoder,
356
+ tokenizer=tokenizer,
357
+ unet=unet,
358
+ low_res_scheduler=low_res_scheduler,
359
+ scheduler=scheduler,
360
+ safety_checker=safety_checker,
361
+ watermarker=watermarker,
362
+ feature_extractor=feature_extractor,
363
+ )
364
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
365
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")
366
+ self.register_to_config(max_noise_level=max_noise_level)
367
+
368
+ def run_safety_checker(self, image, device, dtype):
369
+ if self.safety_checker is not None:
370
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
371
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
372
+ image, nsfw_detected, watermark_detected = self.safety_checker(
373
+ images=image,
374
+ clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
375
+ )
376
+ else:
377
+ nsfw_detected = None
378
+ watermark_detected = None
379
+
380
+ if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
381
+ self.unet_offload_hook.offload()
382
+
383
+ return image, nsfw_detected, watermark_detected
384
+
385
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
386
+ def _encode_prompt(
387
+ self,
388
+ prompt,
389
+ device,
390
+ num_images_per_prompt,
391
+ do_classifier_free_guidance,
392
+ negative_prompt=None,
393
+ prompt_embeds: Optional[torch.FloatTensor] = None,
394
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
395
+ lora_scale: Optional[float] = None,
396
+ **kwargs,
397
+ ):
398
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
399
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
400
+
401
+ prompt_embeds_tuple = self.encode_prompt(
402
+ prompt=prompt,
403
+ device=device,
404
+ num_images_per_prompt=num_images_per_prompt,
405
+ do_classifier_free_guidance=do_classifier_free_guidance,
406
+ negative_prompt=negative_prompt,
407
+ prompt_embeds=prompt_embeds,
408
+ negative_prompt_embeds=negative_prompt_embeds,
409
+ lora_scale=lora_scale,
410
+ **kwargs,
411
+ )
412
+
413
+ # concatenate for backwards comp
414
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
415
+
416
+ return prompt_embeds
417
+
418
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
419
+ def encode_prompt(
420
+ self,
421
+ prompt,
422
+ device,
423
+ num_images_per_prompt,
424
+ do_classifier_free_guidance,
425
+ negative_prompt=None,
426
+ prompt_embeds: Optional[torch.FloatTensor] = None,
427
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
428
+ lora_scale: Optional[float] = None,
429
+ clip_skip: Optional[int] = None,
430
+ ):
431
+ r"""
432
+ Encodes the prompt into text encoder hidden states.
433
+
434
+ Args:
435
+ prompt (`str` or `List[str]`, *optional*):
436
+ prompt to be encoded
437
+ device: (`torch.device`):
438
+ torch device
439
+ num_images_per_prompt (`int`):
440
+ number of images that should be generated per prompt
441
+ do_classifier_free_guidance (`bool`):
442
+ whether to use classifier free guidance or not
443
+ negative_prompt (`str` or `List[str]`, *optional*):
444
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
445
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
446
+ less than `1`).
447
+ prompt_embeds (`torch.FloatTensor`, *optional*):
448
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
449
+ provided, text embeddings will be generated from `prompt` input argument.
450
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
451
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
452
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
453
+ argument.
454
+ lora_scale (`float`, *optional*):
455
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
456
+ clip_skip (`int`, *optional*):
457
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
458
+ the output of the pre-final layer will be used for computing the prompt embeddings.
459
+ """
460
+ # set lora scale so that monkey patched LoRA
461
+ # function of text encoder can correctly access it
462
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
463
+ self._lora_scale = lora_scale
464
+
465
+ # dynamically adjust the LoRA scale
466
+ if not USE_PEFT_BACKEND:
467
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
468
+ else:
469
+ scale_lora_layers(self.text_encoder, lora_scale)
470
+
471
+ if prompt is not None and isinstance(prompt, str):
472
+ batch_size = 1
473
+ elif prompt is not None and isinstance(prompt, list):
474
+ batch_size = len(prompt)
475
+ else:
476
+ batch_size = prompt_embeds.shape[0]
477
+
478
+ if prompt_embeds is None:
479
+ # textual inversion: process multi-vector tokens if necessary
480
+ if isinstance(self, TextualInversionLoaderMixin):
481
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
482
+
483
+ text_inputs = self.tokenizer(
484
+ prompt,
485
+ padding="max_length",
486
+ max_length=self.tokenizer.model_max_length,
487
+ truncation=True,
488
+ return_tensors="pt",
489
+ )
490
+ text_input_ids = text_inputs.input_ids
491
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
492
+
493
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
494
+ text_input_ids, untruncated_ids
495
+ ):
496
+ removed_text = self.tokenizer.batch_decode(
497
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
498
+ )
499
+ logger.warning(
500
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
501
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
502
+ )
503
+
504
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
505
+ attention_mask = text_inputs.attention_mask.to(device)
506
+ else:
507
+ attention_mask = None
508
+
509
+ if clip_skip is None:
510
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
511
+ prompt_embeds = prompt_embeds[0]
512
+ else:
513
+ prompt_embeds = self.text_encoder(
514
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
515
+ )
516
+ # Access the `hidden_states` first, that contains a tuple of
517
+ # all the hidden states from the encoder layers. Then index into
518
+ # the tuple to access the hidden states from the desired layer.
519
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
520
+ # We also need to apply the final LayerNorm here to not mess with the
521
+ # representations. The `last_hidden_states` that we typically use for
522
+ # obtaining the final prompt representations passes through the LayerNorm
523
+ # layer.
524
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
525
+
526
+ if self.text_encoder is not None:
527
+ prompt_embeds_dtype = self.text_encoder.dtype
528
+ elif self.unet is not None:
529
+ prompt_embeds_dtype = self.unet.dtype
530
+ else:
531
+ prompt_embeds_dtype = prompt_embeds.dtype
532
+
533
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
534
+
535
+ bs_embed, seq_len, _ = prompt_embeds.shape
536
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
537
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
538
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
539
+
540
+ # get unconditional embeddings for classifier free guidance
541
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
542
+ uncond_tokens: List[str]
543
+ if negative_prompt is None:
544
+ uncond_tokens = [""] * batch_size
545
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
546
+ raise TypeError(
547
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
548
+ f" {type(prompt)}."
549
+ )
550
+ elif isinstance(negative_prompt, str):
551
+ uncond_tokens = [negative_prompt]
552
+ elif batch_size != len(negative_prompt):
553
+ raise ValueError(
554
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
555
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
556
+ " the batch size of `prompt`."
557
+ )
558
+ else:
559
+ uncond_tokens = negative_prompt
560
+
561
+ # textual inversion: process multi-vector tokens if necessary
562
+ if isinstance(self, TextualInversionLoaderMixin):
563
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
564
+
565
+ max_length = prompt_embeds.shape[1]
566
+ uncond_input = self.tokenizer(
567
+ uncond_tokens,
568
+ padding="max_length",
569
+ max_length=max_length,
570
+ truncation=True,
571
+ return_tensors="pt",
572
+ )
573
+
574
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
575
+ attention_mask = uncond_input.attention_mask.to(device)
576
+ else:
577
+ attention_mask = None
578
+
579
+ negative_prompt_embeds = self.text_encoder(
580
+ uncond_input.input_ids.to(device),
581
+ attention_mask=attention_mask,
582
+ )
583
+ negative_prompt_embeds = negative_prompt_embeds[0]
584
+
585
+ if do_classifier_free_guidance:
586
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
587
+ seq_len = negative_prompt_embeds.shape[1]
588
+
589
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
590
+
591
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
592
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
593
+
594
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
595
+ # Retrieve the original scale by scaling back the LoRA layers
596
+ unscale_lora_layers(self.text_encoder, lora_scale)
597
+
598
+ return prompt_embeds, negative_prompt_embeds
599
+
600
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
601
+ def prepare_extra_step_kwargs(self, generator, eta):
602
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
603
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
604
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
605
+ # and should be between [0, 1]
606
+
607
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
608
+ extra_step_kwargs = {}
609
+ if accepts_eta:
610
+ extra_step_kwargs["eta"] = eta
611
+
612
+ # check if the scheduler accepts generator
613
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
614
+ if accepts_generator:
615
+ extra_step_kwargs["generator"] = generator
616
+ return extra_step_kwargs
617
+
618
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
619
+ def decode_latents(self, latents):
620
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
621
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
622
+
623
+ latents = 1 / self.vae.config.scaling_factor * latents
624
+ image = self.vae.decode(latents, return_dict=False)[0]
625
+ image = (image / 2 + 0.5).clamp(0, 1)
626
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
627
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
628
+ return image
629
+
630
+ def check_inputs(
631
+ self,
632
+ prompt,
633
+ image,
634
+ noise_level,
635
+ callback_steps,
636
+ negative_prompt=None,
637
+ prompt_embeds=None,
638
+ negative_prompt_embeds=None,
639
+ ):
640
+ if (callback_steps is None) or (
641
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
642
+ ):
643
+ raise ValueError(
644
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
645
+ f" {type(callback_steps)}."
646
+ )
647
+
648
+ if prompt is not None and prompt_embeds is not None:
649
+ raise ValueError(
650
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
651
+ " only forward one of the two."
652
+ )
653
+ elif prompt is None and prompt_embeds is None:
654
+ raise ValueError(
655
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
656
+ )
657
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
658
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
659
+
660
+ if negative_prompt is not None and negative_prompt_embeds is not None:
661
+ raise ValueError(
662
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
663
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
664
+ )
665
+
666
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
667
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
668
+ raise ValueError(
669
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
670
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
671
+ f" {negative_prompt_embeds.shape}."
672
+ )
673
+
674
+ if (
675
+ not isinstance(image, torch.Tensor)
676
+ and not isinstance(image, PIL.Image.Image)
677
+ and not isinstance(image, np.ndarray)
678
+ and not isinstance(image, list)
679
+ ):
680
+ raise ValueError(
681
+ f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}"
682
+ )
683
+
684
+ # verify batch size of prompt and image are same if image is a list or tensor or numpy array
685
+ if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
686
+ if prompt is not None and isinstance(prompt, str):
687
+ batch_size = 1
688
+ elif prompt is not None and isinstance(prompt, list):
689
+ batch_size = len(prompt)
690
+ else:
691
+ batch_size = prompt_embeds.shape[0]
692
+
693
+ if isinstance(image, list):
694
+ image_batch_size = len(image)
695
+ else:
696
+ image_batch_size = image.shape[0]
697
+ if batch_size != image_batch_size:
698
+ raise ValueError(
699
+ f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}."
700
+ " Please make sure that passed `prompt` matches the batch size of `image`."
701
+ )
702
+
703
+ # check noise level
704
+ if noise_level > self.config.max_noise_level:
705
+ raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}")
706
+
707
+ if (callback_steps is None) or (
708
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
709
+ ):
710
+ raise ValueError(
711
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
712
+ f" {type(callback_steps)}."
713
+ )
714
+
715
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
716
+ shape = (batch_size, num_channels_latents, height, width)
717
+ if latents is None:
718
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
719
+ else:
720
+ if latents.shape != shape:
721
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
722
+ latents = latents.to(device)
723
+
724
+ # scale the initial noise by the standard deviation required by the scheduler
725
+ latents = latents * self.scheduler.init_noise_sigma
726
+ return latents
727
+
728
+ def upcast_vae(self):
729
+ dtype = self.vae.dtype
730
+ self.vae.to(dtype=torch.float32)
731
+ use_torch_2_0_or_xformers = isinstance(
732
+ self.vae.decoder.mid_block.attentions[0].processor,
733
+ (
734
+ AttnProcessor2_0,
735
+ XFormersAttnProcessor,
736
+ LoRAXFormersAttnProcessor,
737
+ LoRAAttnProcessor2_0,
738
+ ),
739
+ )
740
+ # if xformers or torch_2_0 is used attention block does not need
741
+ # to be in float32 which can save lots of memory
742
+ if use_torch_2_0_or_xformers:
743
+ self.vae.post_quant_conv.to(dtype)
744
+ self.vae.decoder.conv_in.to(dtype)
745
+ self.vae.decoder.mid_block.to(dtype)
746
+
747
+ @property
748
+ def guidance_scale(self):
749
+ return self._guidance_scale
750
+
751
+ @property
752
+ def do_classifier_free_guidance(self):
753
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
754
+
755
+ @property
756
+ def pag_scale(self):
757
+ return self._pag_scale
758
+
759
+ @property
760
+ def do_perturbed_attention_guidance(self):
761
+ return self._pag_scale > 0
762
+
763
+ @property
764
+ def pag_adaptive_scaling(self):
765
+ return self._pag_adaptive_scaling
766
+
767
+ @property
768
+ def do_pag_adaptive_scaling(self):
769
+ return self._pag_adaptive_scaling > 0
770
+
771
+ @property
772
+ def pag_applied_layers_index(self):
773
+ return self._pag_applied_layers_index
774
+
775
+ @torch.no_grad()
776
+ def __call__(
777
+ self,
778
+ prompt: Union[str, List[str]] = None,
779
+ image: PipelineImageInput = None,
780
+ num_inference_steps: int = 75,
781
+ guidance_scale: float = 9.0,
782
+ pag_scale: float = 0.0,
783
+ pag_adaptive_scaling: float = 0.0,
784
+ pag_applied_layers_index: List[str] = ["d4"], # ['d4', 'd5', 'm0']
785
+ noise_level: int = 20,
786
+ negative_prompt: Optional[Union[str, List[str]]] = None,
787
+ num_images_per_prompt: Optional[int] = 1,
788
+ eta: float = 0.0,
789
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
790
+ latents: Optional[torch.FloatTensor] = None,
791
+ prompt_embeds: Optional[torch.FloatTensor] = None,
792
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
793
+ output_type: Optional[str] = "pil",
794
+ return_dict: bool = True,
795
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
796
+ callback_steps: int = 1,
797
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
798
+ clip_skip: int = None,
799
+ ):
800
+ r"""
801
+ The call function to the pipeline for generation.
802
+
803
+ Args:
804
+ prompt (`str` or `List[str]`, *optional*):
805
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
806
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
807
+ `Image` or tensor representing an image batch to be upscaled.
808
+ num_inference_steps (`int`, *optional*, defaults to 75):
809
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
810
+ expense of slower inference.
811
+ guidance_scale (`float`, *optional*, defaults to 9.0):
812
+ A higher guidance scale value encourages the model to generate images closely linked to the text
813
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
814
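+ pag_scale (`float`, *optional*, defaults to 0.0):
+ Scale of perturbed-attention guidance (PAG). PAG is only applied when `pag_scale > 0`.
+ pag_adaptive_scaling (`float`, *optional*, defaults to 0.0):
+ If greater than 0, the PAG scale is reduced as denoising progresses, following
+ `pag_scale - pag_adaptive_scaling * (1000 - t)` and clamped at 0.
+ pag_applied_layers_index (`List[str]`, *optional*, defaults to `["d4"]`):
+ Indices of the UNet self-attention layers to perturb, given as a block prefix ("d" for down,
+ "m" for mid, "u" for up) followed by the layer index, e.g. `["d4", "d5", "m0"]`.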
+ negative_prompt (`str` or `List[str]`, *optional*):
815
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
816
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
817
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
818
+ The number of images to generate per prompt.
819
+ eta (`float`, *optional*, defaults to 0.0):
820
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
821
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
822
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
823
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
824
+ generation deterministic.
825
+ latents (`torch.FloatTensor`, *optional*):
826
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
827
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
828
+ tensor is generated by sampling using the supplied random `generator`.
829
+ prompt_embeds (`torch.FloatTensor`, *optional*):
830
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
831
+ provided, text embeddings are generated from the `prompt` input argument.
832
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
833
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
834
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
835
+ output_type (`str`, *optional*, defaults to `"pil"`):
836
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
837
+ return_dict (`bool`, *optional*, defaults to `True`):
838
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
839
+ plain tuple.
840
+ callback (`Callable`, *optional*):
841
+ A function that calls every `callback_steps` steps during inference. The function is called with the
842
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
843
+ callback_steps (`int`, *optional*, defaults to 1):
844
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
845
+ every step.
846
+ cross_attention_kwargs (`dict`, *optional*):
847
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
848
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
849
+ clip_skip (`int`, *optional*):
850
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
851
+ the output of the pre-final layer will be used for computing the prompt embeddings.
852
+ Examples:
853
+ ```py
854
+ >>> import requests
855
+ >>> from PIL import Image
856
+ >>> from io import BytesIO
857
+ >>> from diffusers import StableDiffusionUpscalePipeline
858
+ >>> import torch
859
+
860
+ >>> # load model and scheduler
861
+ >>> model_id = "stabilityai/stable-diffusion-x4-upscaler"
862
+ >>> pipeline = StableDiffusionUpscalePipeline.from_pretrained(
863
+ ... model_id, revision="fp16", torch_dtype=torch.float16
864
+ ... )
865
+ >>> pipeline = pipeline.to("cuda")
866
+
867
+ >>> # let's download an image
868
+ >>> url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png"
869
+ >>> response = requests.get(url)
870
+ >>> low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
871
+ >>> low_res_img = low_res_img.resize((128, 128))
872
+ >>> prompt = "a white cat"
873
+
874
+ >>> upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0]
875
+ >>> upscaled_image.save("upsampled_cat.png")
876
+ ```
877
+
878
+ Returns:
879
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
880
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
881
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
882
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
883
+ "not-safe-for-work" (nsfw) content.
884
+ """
885
+
886
+ # 1. Check inputs
887
+ self.check_inputs(
888
+ prompt,
889
+ image,
890
+ noise_level,
891
+ callback_steps,
892
+ negative_prompt,
893
+ prompt_embeds,
894
+ negative_prompt_embeds,
895
+ )
896
+
897
+ self._guidance_scale = guidance_scale
898
+
899
+ self._pag_scale = pag_scale
900
+ self._pag_adaptive_scaling = pag_adaptive_scaling
901
+ self._pag_applied_layers_index = pag_applied_layers_index
902
+
903
+ if image is None:
904
+ raise ValueError("`image` input cannot be undefined.")
905
+
906
+ # 2. Define call parameters
907
+ if prompt is not None and isinstance(prompt, str):
908
+ batch_size = 1
909
+ elif prompt is not None and isinstance(prompt, list):
910
+ batch_size = len(prompt)
911
+ else:
912
+ batch_size = prompt_embeds.shape[0]
913
+
914
+ device = self._execution_device
915
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
916
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
917
+ # corresponds to doing no classifier free guidance.
918
+
919
+ # 3. Encode input prompt
920
+ text_encoder_lora_scale = (
921
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
922
+ )
923
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
924
+ prompt,
925
+ device,
926
+ num_images_per_prompt,
927
+ self.do_classifier_free_guidance,
928
+ negative_prompt,
929
+ prompt_embeds=prompt_embeds,
930
+ negative_prompt_embeds=negative_prompt_embeds,
931
+ lora_scale=text_encoder_lora_scale,
932
+ clip_skip=clip_skip,
933
+ )
934
+ # For classifier free guidance, we need to do two forward passes.
935
+ # Here we concatenate the unconditional and text embeddings into a single batch
936
+ # to avoid doing two forward passes
937
+
938
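+ # resulting batch layout for the UNet: [uncond, text] for CFG only, [text, text] for PAG only, [uncond, text, text] for CFG + PAG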
+ # cfg
939
+ if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
940
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
941
+ # pag
942
+ elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
943
+ prompt_embeds = torch.cat([prompt_embeds, prompt_embeds])
944
+ # both
945
+ elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
946
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds])
947
+
948
+ # 4. Preprocess image
949
+ image = self.image_processor.preprocess(image)
950
+ image = image.to(dtype=prompt_embeds.dtype, device=device)
951
+
952
+ # 5. set timesteps
953
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
954
+ timesteps = self.scheduler.timesteps
955
+
956
+ # 5. Add noise to image
957
+ noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
958
+ noise = randn_tensor(image.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
959
+ image = self.low_res_scheduler.add_noise(image, noise, noise_level)
960
+
961
+ image = torch.cat([image] * num_images_per_prompt)
962
+ noise_level = torch.cat([noise_level] * image.shape[0])
963
+
964
+ # 6. Prepare latent variables
965
+ height, width = image.shape[2:]
966
+ num_channels_latents = self.vae.config.latent_channels
967
+ latents = self.prepare_latents(
968
+ batch_size * num_images_per_prompt,
969
+ num_channels_latents,
970
+ height,
971
+ width,
972
+ prompt_embeds.dtype,
973
+ device,
974
+ generator,
975
+ latents,
976
+ )
977
+
978
+ # 7. Check that sizes of image and latents match
979
+ num_channels_image = image.shape[1]
980
+ if num_channels_latents + num_channels_image != self.unet.config.in_channels:
981
+ raise ValueError(
982
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
983
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
984
+ f" `num_channels_image`: {num_channels_image} "
985
+ f" = {num_channels_latents+num_channels_image}. Please verify the config of"
986
+ " `pipeline.unet` or your `image` input."
987
+ )
988
+
989
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
990
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
991
+
992
+ # 9. Denoising loop
993
+ if self.do_perturbed_attention_guidance:
994
+ down_layers = []
995
+ mid_layers = []
996
+ up_layers = []
997
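+ # collect the UNet self-attention modules ('attn1'), skipping their to_q/to_k/to_v/to_out children, and group them by block type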
+ for name, module in self.unet.named_modules():
998
+ if "attn1" in name and "to" not in name:
999
+ layer_type = name.split(".")[0].split("_")[0]
1000
+ if layer_type == "down":
1001
+ down_layers.append(module)
1002
+ elif layer_type == "mid":
1003
+ mid_layers.append(module)
1004
+ elif layer_type == "up":
1005
+ up_layers.append(module)
1006
+ else:
1007
+ raise ValueError(f"Invalid layer type: {layer_type}")
1008
+
1009
+ # change attention layer in UNet if use PAG
1010
+ if self.do_perturbed_attention_guidance:
1011
+ if self.do_classifier_free_guidance:
1012
+ replace_processor = PAGCFGIdentitySelfAttnProcessor()
1013
+ else:
1014
+ replace_processor = PAGIdentitySelfAttnProcessor()
1015
+
1016
+ drop_layers = self.pag_applied_layers_index
1017
+ for drop_layer in drop_layers:
1018
+ try:
1019
+ if drop_layer[0] == "d":
1020
+ down_layers[int(drop_layer[1])].processor = replace_processor
1021
+ elif drop_layer[0] == "m":
1022
+ mid_layers[int(drop_layer[1])].processor = replace_processor
1023
+ elif drop_layer[0] == "u":
1024
+ up_layers[int(drop_layer[1])].processor = replace_processor
1025
+ else:
1026
+ raise ValueError(f"Invalid layer type: {drop_layer[0]}")
1027
+ except IndexError:
1028
+ raise ValueError(
1029
+ f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
1030
+ )
1031
+
1032
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1033
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1034
+ for i, t in enumerate(timesteps):
1035
+
1036
+ # cfg
1037
+ if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
1038
+ latent_model_input = torch.cat([latents] * 2)
1039
+ image_input = torch.cat([image] * 2)
1040
+ # pag
1041
+ elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
1042
+ latent_model_input = torch.cat([latents] * 2)
1043
+ image_input = torch.cat([image] * 2)
1044
+ # both
1045
+ elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
1046
+ latent_model_input = torch.cat([latents] * 3)
1047
+ image_input = torch.cat([image] * 3)
1048
+ # no
1049
+ else:
1050
+ latent_model_input = latents
1051
+ image_input = image
1052
+
1053
+ # concat latents, mask, masked_image_latents in the channel dimension
1054
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1055
+ latent_model_input = torch.cat([latent_model_input, image_input], dim=1)
1056
+
1057
+ # predict the noise residual
1058
+ noise_pred = self.unet(
1059
+ latent_model_input,
1060
+ t,
1061
+ encoder_hidden_states=prompt_embeds,
1062
+ cross_attention_kwargs=cross_attention_kwargs,
1063
+ class_labels=noise_level,
1064
+ return_dict=False,
1065
+ )[0]
1066
+
1067
+ # perform guidance
1068
+ # cfg
1069
+ if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
1070
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1071
+
1072
+ delta = noise_pred_text - noise_pred_uncond
1073
+ noise_pred = noise_pred_uncond + self.guidance_scale * delta
1074
+
1075
+ # pag
1076
+ elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
1077
+ noise_pred_original, noise_pred_perturb = noise_pred.chunk(2)
1078
+
1079
+ signal_scale = self.pag_scale
1080
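+ # adaptive PAG: linearly shrink the PAG weight as t decreases (later denoising steps), never letting it go below zero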
+ if self.do_pag_adaptive_scaling:
1081
+ signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000 - t)
1082
+ if signal_scale < 0:
1083
+ signal_scale = 0
1084
+
1085
+ noise_pred = noise_pred_original + signal_scale * (noise_pred_original - noise_pred_perturb)
1086
+
1087
+ # both
1088
+ elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
1089
+ noise_pred_uncond, noise_pred_text, noise_pred_text_perturb = noise_pred.chunk(3)
1090
+
1091
+ signal_scale = self.pag_scale
1092
+ if self.do_pag_adaptive_scaling:
1093
+ signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000 - t)
1094
+ if signal_scale < 0:
1095
+ signal_scale = 0
1096
+
1097
+ noise_pred = (
1098
+ noise_pred_text
1099
+ + (self.guidance_scale - 1.0) * (noise_pred_text - noise_pred_uncond)
1100
+ + signal_scale * (noise_pred_text - noise_pred_text_perturb)
1101
+ )
1102
+
1103
+ # compute the previous noisy sample x_t -> x_t-1
1104
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1105
+
1106
+ # call the callback, if provided
1107
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1108
+ progress_bar.update()
1109
+ if callback is not None and i % callback_steps == 0:
1110
+ step_idx = i // getattr(self.scheduler, "order", 1)
1111
+ callback(step_idx, t, latents)
1112
+
1113
+ if not output_type == "latent":
1114
+ # make sure the VAE is in float32 mode, as it overflows in float16
1115
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1116
+
1117
+ if needs_upcasting:
1118
+ self.upcast_vae()
1119
+
1120
+ # Ensure latents are always the same type as the VAE
1121
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1122
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1123
+
1124
+ # cast back to fp16 if needed
1125
+ if needs_upcasting:
1126
+ self.vae.to(dtype=torch.float16)
1127
+
1128
+ image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)
1129
+ else:
1130
+ image = latents
1131
+ has_nsfw_concept = None
1132
+
1133
+ if has_nsfw_concept is None:
1134
+ do_denormalize = [True] * image.shape[0]
1135
+ else:
1136
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1137
+
1138
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1139
+
1140
+ # 11. Apply watermark
1141
+ if output_type == "pil" and self.watermarker is not None:
1142
+ image = self.watermarker.apply_watermark(image)
1143
+
1144
+ # Offload all models
1145
+ self.maybe_free_model_hooks()
1146
+
1147
+ # change attention layer in UNet if use PAG
1148
+ if self.do_perturbed_attention_guidance:
1149
+ drop_layers = self.pag_applied_layers_index
1150
+ for drop_layer in drop_layers:
1151
+ try:
1152
+ if drop_layer[0] == "d":
1153
+ down_layers[int(drop_layer[1])].processor = AttnProcessor2_0()
1154
+ elif drop_layer[0] == "m":
1155
+ mid_layers[int(drop_layer[1])].processor = AttnProcessor2_0()
1156
+ elif drop_layer[0] == "u":
1157
+ up_layers[int(drop_layer[1])].processor = AttnProcessor2_0()
1158
+ else:
1159
+ raise ValueError(f"Invalid layer type: {drop_layer[0]}")
1160
+ except IndexError:
1161
+ raise ValueError(
1162
+ f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
1163
+ )
1164
+
1165
+ if not return_dict:
1166
+ return (image, has_nsfw_concept)
1167
+
1168
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
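
A minimal usage sketch for the pipeline added in this commit, assuming it is loaded as a `diffusers` community pipeline via the `custom_pipeline` argument; the `custom_pipeline` repository id below is a placeholder for wherever this `pipeline.py` is hosted, and `pag_scale=3.0` / `["d4"]` are illustrative settings, not recommendations:

```py
import requests
import torch
from io import BytesIO
from PIL import Image

from diffusers import DiffusionPipeline

# Load the Stable Diffusion x4 upscaler weights and route generation through this
# community pipeline file (the custom_pipeline id is an assumed placeholder).
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler",
    custom_pipeline="<repo-hosting-this-pipeline.py>",
    torch_dtype=torch.float16,
).to("cuda")

# Low-resolution conditioning image (same sample as in the docstring example).
url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png"
low_res_img = Image.open(BytesIO(requests.get(url).content)).convert("RGB").resize((128, 128))

# pag_scale > 0 switches on perturbed-attention guidance; with guidance_scale > 1 the
# pipeline runs the combined CFG + PAG branch of the denoising loop.
upscaled = pipe(
    prompt="a white cat",
    image=low_res_img,
    guidance_scale=9.0,
    pag_scale=3.0,
    pag_applied_layers_index=["d4"],
).images[0]
upscaled.save("upscaled_cat_pag.png")
```

With `pag_scale` left at its default of 0.0, the perturbed branch is skipped and the call behaves like the standard upscale pipeline with classifier-free guidance.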