sihanxu committed on
Commit 0da639c
1 Parent(s): a8a7bbe

Upload pipeline_ddcm.py

Files changed (1)
  1. pipeline_ddcm.py +676 -0
pipeline_ddcm.py ADDED
@@ -0,0 +1,676 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import PIL.Image
import torch
from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.schedulers import LCMScheduler
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
    deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
    deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
    if isinstance(image, torch.Tensor):
        return image
    elif isinstance(image, PIL.Image.Image):
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        w, h = image[0].size
        w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8

        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
        image = np.concatenate(image, axis=0)
        image = np.array(image).astype(np.float32) / 255.0
        image = image.transpose(0, 3, 1, 2)
        image = 2.0 * image - 1.0
        image = torch.from_numpy(image)
    elif isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0)
    return image

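# The DDCM sampling step below advances the source latent `x_s` and the target latent `x_t` by one
# scheduler step in lockstep. Given the noise predictions for both branches (`e_s`, `e_t`), the clean
# source latent `x_0`, a fresh noise sample, and `eta` (which scales the injected noise), it returns
# `(prev_xs, prev_xt, pred_x0)`: the next source latent, the next target latent, and the current
# estimate of the clean target latent. When only a single timestep is scheduled, the update is
# deterministic and the noise argument is unused.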
def ddcm_sampler(scheduler, x_s, x_t, timestep, e_s, e_t, x_0, noise, eta):
    if scheduler.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    if scheduler.step_index is None:
        scheduler._init_step_index(timestep)

    prev_step_index = scheduler.step_index + 1
    if prev_step_index < len(scheduler.timesteps):
        prev_timestep = scheduler.timesteps[prev_step_index]
    else:
        prev_timestep = timestep

    alpha_prod_t = scheduler.alphas_cumprod[timestep]
    alpha_prod_t_prev = (
        scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
    )
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    variance = beta_prod_t_prev
    std_dev_t = eta * variance
    noise = std_dev_t ** (0.5) * noise

    e_c = (x_s - alpha_prod_t ** (0.5) * x_0) / (1 - alpha_prod_t) ** (0.5)

    pred_x0 = x_0 + ((x_t - x_s) - beta_prod_t ** (0.5) * (e_t - e_s)) / alpha_prod_t ** (0.5)
    eps = (e_t - e_s) + e_c
    dir_xt = (beta_prod_t_prev - std_dev_t) ** (0.5) * eps

    # Noise is not used for one-step sampling.
    if len(scheduler.timesteps) > 1:
        prev_xt = alpha_prod_t_prev ** (0.5) * pred_x0 + dir_xt + noise
        prev_xs = alpha_prod_t_prev ** (0.5) * x_0 + dir_xt + noise
    else:
        prev_xt = pred_x0
        prev_xs = x_0

    scheduler._step_index += 1
    return prev_xs, prev_xt, pred_x0

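# `DDCMPipeline` is an image-to-image editing pipeline: the input image is encoded into VAE latents,
# a source branch conditioned on `source_prompt` and a target branch conditioned on `prompt` are
# denoised with the same UNet, and the two trajectories are coupled at every step through
# `ddcm_sampler` above. The constructor and most helper methods are copied from the standard
# Stable Diffusion pipelines in `diffusers`, as noted in the per-method comments.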
class DDCMPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
    model_cpu_offload_seq = "text_encoder->unet->vae"
    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: LCMScheduler,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly, as leaving `steps_offset` might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64, which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        lora_scale: Optional[float] = None,
    ):
        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)

        prompt_embeds_tuple = self.encode_prompt(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=lora_scale,
        )

        # concatenate for backwards comp
        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])

        return prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        lora_scale: Optional[float] = None,
    ):
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]

        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        return prompt_embeds, negative_prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
    def check_inputs(
        self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
    ):
        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
    def get_timesteps(self, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]

        return timesteps, num_inference_steps - t_start

    def prepare_latents(
        self, image, timestep, batch_size, num_images_per_prompt, dtype, device, denoise_model, generator=None
    ):
        image = image.to(device=device, dtype=dtype)

        batch_size = image.shape[0]

        if image.shape[1] == 4:
            init_latents = image

        else:
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

            if isinstance(generator, list):
                init_latents = [
                    self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
                ]
                init_latents = torch.cat(init_latents, dim=0)
            else:
                init_latents = self.vae.encode(image).latent_dist.sample(generator)

            init_latents = self.vae.config.scaling_factor * init_latents

        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
            # expand init_latents for batch_size
            deprecation_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
                " images (`image`). Initial images are now being duplicated to match the number of text prompts. Note"
                " that this behavior is deprecated and will be removed in version 1.0.0. Please make sure to update"
                " your script to pass as many initial images as text prompts to suppress this warning."
            )
            deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
            additional_image_per_prompt = batch_size // init_latents.shape[0]
            init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)

        # add noise to latents using the timestep
        shape = init_latents.shape
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        # get latents
        clean_latents = init_latents
        if denoise_model:
            init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
            latents = init_latents
        else:
            latents = noise

        return latents, clean_latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        source_prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]] = None,
        positive_prompt: Union[str, List[str]] = None,
        image: PipelineImageInput = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        original_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        source_guidance_scale: Optional[float] = 1,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 1.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        denoise_model: Optional[bool] = True,
    ):
        # 1. Check inputs
        self.check_inputs(prompt, strength, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analogous to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        prompt_embeds_tuple = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )
        source_prompt_embeds_tuple = self.encode_prompt(
            source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, positive_prompt, None
        )
        if prompt_embeds_tuple[1] is not None:
            prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
        else:
            prompt_embeds = prompt_embeds_tuple[0]
        if source_prompt_embeds_tuple[1] is not None:
            source_prompt_embeds = torch.cat([source_prompt_embeds_tuple[1], source_prompt_embeds_tuple[0]])
        else:
            source_prompt_embeds = source_prompt_embeds_tuple[0]

        # 4. Preprocess image
        image = self.image_processor.preprocess(image)

        # 5. Prepare timesteps
        self.scheduler.set_timesteps(
            num_inference_steps=num_inference_steps,
            device=device,
            original_inference_steps=original_inference_steps,
        )
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 6. Prepare latent variables
        latents, clean_latents = self.prepare_latents(
            image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, denoise_model, generator
        )
        source_latents = latents

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        generator = extra_step_kwargs.pop("generator", None)

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                source_latent_model_input = (
                    torch.cat([source_latents] * 2) if do_classifier_free_guidance else source_latents
                )
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t)

                # predict the noise residual
                if do_classifier_free_guidance:
                    concat_latent_model_input = torch.stack(
                        [
                            source_latent_model_input[0],
                            latent_model_input[0],
                            source_latent_model_input[1],
                            latent_model_input[1],
                        ],
                        dim=0,
                    )
                    concat_prompt_embeds = torch.stack(
                        [
                            source_prompt_embeds[0],
                            prompt_embeds[0],
                            source_prompt_embeds[1],
                            prompt_embeds[1],
                        ],
                        dim=0,
                    )
                else:
                    concat_latent_model_input = torch.cat(
                        [
                            source_latent_model_input,
                            latent_model_input,
                        ],
                        dim=0,
                    )
                    concat_prompt_embeds = torch.cat(
                        [
                            source_prompt_embeds,
                            prompt_embeds,
                        ],
                        dim=0,
                    )

                concat_noise_pred = self.unet(
                    concat_latent_model_input,
                    t,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_hidden_states=concat_prompt_embeds,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    (
                        source_noise_pred_uncond,
                        noise_pred_uncond,
                        source_noise_pred_text,
                        noise_pred_text,
                    ) = concat_noise_pred.chunk(4, dim=0)

                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                    source_noise_pred = source_noise_pred_uncond + source_guidance_scale * (
                        source_noise_pred_text - source_noise_pred_uncond
                    )

                else:
                    (source_noise_pred, noise_pred) = concat_noise_pred.chunk(2, dim=0)

                noise = torch.randn(
                    latents.shape, dtype=latents.dtype, device=latents.device, generator=generator
                )

                source_latents, latents, pred_x0 = ddcm_sampler(
                    self.scheduler,
                    source_latents,
                    latents,
                    t,
                    source_noise_pred,
                    noise_pred,
                    clean_latents,
                    noise=noise,
                    eta=eta,
                    **extra_step_kwargs,
                )

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # 9. Post-processing
        if output_type != "latent":
            image = self.vae.decode(pred_x0 / self.vae.config.scaling_factor, return_dict=False)[0]
            has_nsfw_concept = [False] * len(image)
        else:
            image = pred_x0
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
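
For reference, a minimal usage sketch, assuming the file is importable as `pipeline_ddcm`, that a Stable Diffusion v1.5-style checkpoint (the placeholder `runwayml/stable-diffusion-v1-5` below) provides components matching the constructor, and that the installed `diffusers` version is recent enough for `LCMScheduler.set_timesteps` to accept `original_inference_steps`; the prompts, image path, and step/guidance values are illustrative only.

import torch
from diffusers import LCMScheduler
from diffusers.utils import load_image

from pipeline_ddcm import DDCMPipeline

# Placeholder base checkpoint; any Stable Diffusion v1.x-style checkpoint whose components
# match the pipeline's __init__ signature should load here.
pipe = DDCMPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
# The pipeline is written against LCMScheduler, so swap out whatever scheduler the checkpoint ships with.
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

init_image = load_image("input.png")  # placeholder input image

result = pipe(
    prompt="a photo of a cat wearing sunglasses",  # target prompt for the edited image
    source_prompt="a photo of a cat",              # prompt describing the input image
    image=init_image,
    strength=0.6,
    num_inference_steps=4,
    guidance_scale=2.0,
)
result.images[0].save("edited.png")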