sachit-menon committed
Commit ef02a1a
1 Parent(s): 4115a6e

Create snt_pipeline.py

Files changed (1)
  1. snt_pipeline.py +756 -0
snt_pipeline.py ADDED
# Copyright 2023 The InstructPix2Pix Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import warnings
from typing import Callable, List, Optional, Union

import PIL
import torch
from transformers import CLIPImageProcessor

from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.utils import (
    deprecate,
    is_accelerate_available,
    is_accelerate_version,
    logging,
)

try:
    from diffusers.utils import randn_tensor
except ImportError:
    from diffusers.utils.torch_utils import randn_tensor

from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

from trainer.models.sd_model import SDModel

# from hydra.utils import instantiate
from einops import rearrange, repeat

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class ShowNotTellPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
    r"""
    Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).

    In addition, the pipeline inherits the following loading methods:
        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]

    as well as the following saving methods:
        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]

    Args:
        model ([`SDModel`]):
            Wrapper module that bundles the components below; they are accessed as `model.vae`,
            `model.text_encoder`, `model.tokenizer`, `model.unet`, and `model.noise_scheduler`.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """
    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        # cfg: SDModelConfig,
        model: SDModel,
        safety_checker: StableDiffusionSafetyChecker = None,
        feature_extractor: CLIPImageProcessor = None,
        requires_safety_checker: bool = False,
    ):
        super().__init__()
        # self.model.cfg = cfg
        self.register_modules(model=model, safety_checker=safety_checker, feature_extractor=feature_extractor)
        # self.register_to_config(cfg=dataclasses.asdict(cfg))

        self.model.vae_scale_factor = 2 ** (len(self.model.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.model.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

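    # Layout used throughout this pipeline (a summary of the tensor handling in `__call__`): each
    # sample is a sequence of `model.cfg.sequence_length` frames stacked along the latent height
    # axis, so working latents have shape (batch, channels, sequence_length * h, w). The input-image
    # latents are tiled the same way in `prepare_image_latents`, and the stack is split back into
    # individual frames just before VAE decoding.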
    @torch.no_grad()
    def __call__(
        self,
        prompts,
        image,
        num_inference_steps: int = 100,
        guidance_scale: float = 7.5,
        image_guidance_scale: float = 1.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
    ):
        if isinstance(prompts, str):
            prompts = [prompts]
        if isinstance(prompts, list):
            # `fancy_get_input_ids` returns `(input_ids, attention_mask)`; only the ids are used here.
            # TODO see if reshaping needed to match train dataloader
            input_ids, _ = self.fancy_get_input_ids(prompts, self.model.text_encoder.device)
        else:
            input_ids = prompts

        if isinstance(image, PIL.Image.Image):
            image = [image]
        if isinstance(image, list):
            preprocessed_images = self.image_processor.preprocess(image)
        else:
            preprocessed_images = image

        batch_size = input_ids.shape[0]

        # device = self._execution_device
        device = self.model.text_encoder.device  # TODO figure out execution device stuff
        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0 and image_guidance_scale >= 1.0
        # check if the scheduler is in sigma space
        scheduler_is_in_sigma_space = hasattr(self.model.noise_scheduler, "sigmas")

        prompt_embeds = self.encode_prompt_batch(
            input_ids,
            batch_size,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
        )

        # 4. set timesteps
        self.model.noise_scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.model.noise_scheduler.timesteps

        # 5. Prepare image latents
        image_latents = self.prepare_image_latents(
            preprocessed_images,
            batch_size,
            num_images_per_prompt,
            prompt_embeds.dtype,
            device,
            do_classifier_free_guidance,
            generator,
        )

        height, width = image_latents.shape[-2:]
        height = height * self.model.vae_scale_factor
        width = width * self.model.vae_scale_factor

        # 6. Prepare latent variables
        num_channels_latents = self.model.vae.config.latent_channels

        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 7. Check that shapes of latents and image match the UNet channels
        num_channels_image = image_latents.shape[1]
        if num_channels_latents + num_channels_image != self.model.unet.config.in_channels:
            raise ValueError(
                f"Incorrect configuration settings! The config of `pipeline.model.unet`: {self.model.unet.config} expects"
                f" {self.model.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                f" `num_channels_image`: {num_channels_image} "
                f" = {num_channels_latents + num_channels_image}. Please verify the config of"
                " `pipeline.model.unet` or your `image` input."
            )

        # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

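        # Batch layout under classifier-free guidance: the latents are tripled below and ordered
        # [text + image conditioned, image-only conditioned, fully unconditional]. This matches
        # `encode_prompt_batch`, which stacks [prompt, negative, negative] text embeddings, and
        # `prepare_image_latents`, which stacks [image, image, zeros] image latents.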
        # 9. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.model.noise_scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Expand the latents if we are doing classifier free guidance.
                # The latents are expanded 3 times because for pix2pix the guidance
                # is applied to both the text and the input image.
                latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents
                # if i == 0:
                #     if self.model.cfg.image_positional_encoding_type is not None:
                #         third = latents.shape[0]//3
                #         cond_latents = latents[third:2*third]
                #         cond_latents = rearrange(cond_latents, 'b c (s h) w -> (b s) c h w', s=self.model.cfg.sequence_length)
                #         cond_latents = self.model.apply_image_positional_encoding(cond_latents, self.model.cfg.sequence_length)
                #         cond_latents = rearrange(cond_latents, '(b s) c h w -> b c (s h) w', s=self.model.cfg.sequence_length)
                #         latents[third:2*third] = cond_latents

                # concat latents, image_latents in the channel dimension
                scaled_latent_model_input = self.model.noise_scheduler.scale_model_input(latent_model_input, t)

                scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1)

                # predict the noise residual
                noise_pred = self.model.unet(
                    scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False
                )[0]

                # Hack:
                # For Karras-style schedulers the model does classifier free guidance using the
                # predicted_original_sample instead of the noise_pred. So we need to compute the
                # predicted_original_sample here if we are using a Karras-style scheduler.
                if scheduler_is_in_sigma_space:
                    step_index = (self.model.noise_scheduler.timesteps == t).nonzero().item()
                    sigma = self.model.noise_scheduler.sigmas[step_index]
                    noise_pred = latent_model_input - sigma * noise_pred

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
                    noise_pred = (
                        noise_pred_uncond
                        + guidance_scale * (noise_pred_text - noise_pred_image)
                        + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
                    )
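                    # The combination above is the InstructPix2Pix two-term guidance:
                    #   noise_pred = eps(none, none)
                    #       + guidance_scale * (eps(text, image) - eps(none, image))
                    #       + image_guidance_scale * (eps(none, image) - eps(none, none))
                    # With image_guidance_scale == 1 this reduces to standard text CFG taken around
                    # the image-conditioned prediction.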

                # Hack:
                # For Karras-style schedulers the model does classifier free guidance using the
                # predicted_original_sample instead of the noise_pred. But the scheduler.step function
                # expects the noise_pred and computes the predicted_original_sample internally. So we
                # need to overwrite the noise_pred here such that the value of the computed
                # predicted_original_sample is correct.
                if scheduler_is_in_sigma_space:
                    noise_pred = (noise_pred - latents) / (-sigma)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.model.noise_scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.model.noise_scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if not output_type == "latent":
            # these are image latents, so sequence_length instead of text_sequence_length
            latents = rearrange(latents, 'b c (s h) w -> (b s) c h w', s=self.model.cfg.sequence_length)
            image = self.model.vae.decode(latents / self.model.vae.config.scaling_factor, return_dict=False)[0]
            # image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents

        has_nsfw_concept = None
        do_denormalize = [True] * image.shape[0]
        # if has_nsfw_concept is None:
        #     do_denormalize = [True] * image.shape[0]
        # else:
        #     do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
            from accelerate import cpu_offload
        else:
            raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        for cpu_offloaded_model in [self.model.unet, self.model.text_encoder, self.model.vae]:
            cpu_offload(cpu_offloaded_model, device)

        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than
        with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the
        `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        hook = None
        for cpu_offloaded_model in [self.model.text_encoder, self.model.unet, self.model.vae]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        if self.safety_checker is not None:
            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.model.unet, "_hf_hook"):
            return self.device
        for module in self.model.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

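    # Note: `__call__` builds its text conditioning with `encode_prompt_batch` (defined at the end
    # of this file); `_encode_prompt` below appears to be retained from the upstream InstructPix2Pix
    # pipeline and is not invoked by `__call__`.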
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from the `negative_prompt` input
                argument.
        """
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.model.tokenizer)

            text_inputs = self.model.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.model.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.model.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.model.tokenizer.batch_decode(
                    untruncated_ids[:, self.model.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.model.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.model.text_encoder.config, "use_attention_mask") and self.model.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            prompt_embeds = self.model.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.model.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.model.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.model.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.model.text_encoder.config, "use_attention_mask") and self.model.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.model.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.model.text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered
            # [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
            prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds])

        return prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.model.noise_scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.model.noise_scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        warnings.warn(
            "The decode_latents method is deprecated and will be removed in a future version. Please"
            " use VaeImageProcessor instead",
            FutureWarning,
        )
        latents = 1 / self.model.vae.config.scaling_factor * latents
        image = self.model.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def check_inputs(
        self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
    ):
        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels_latents, height // self.model.vae_scale_factor, width // self.model.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.model.noise_scheduler.init_noise_sigma
        return latents

    def original_prepare_image_latents(
        self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None
    ):
        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        image = image.to(device=device, dtype=dtype)

        batch_size = batch_size * num_images_per_prompt

        if image.shape[1] == 4:
            image_latents = image
        else:
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

            if isinstance(generator, list):
                image_latents = [self.model.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
                image_latents = torch.cat(image_latents, dim=0)
            else:
                image_latents = self.model.vae.encode(image).latent_dist.mode()

        if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
            # expand image_latents for batch_size
            deprecation_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
                " images (`image`). Initial images are now duplicated to match the number of text prompts. Note"
                " that this behavior is deprecated and will be removed in version 1.0.0. Please make sure to update"
                " your script to pass as many initial images as text prompts to suppress this warning."
            )
            deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
            additional_image_per_prompt = batch_size // image_latents.shape[0]
            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
        elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            image_latents = torch.cat([image_latents], dim=0)

        if do_classifier_free_guidance:
            uncond_image_latents = torch.zeros_like(image_latents)
            image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0)

        return image_latents

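    # ShowNotTell-specific step: `prepare_image_latents` tiles the conditioning latents from
    # (b, c, h, w) to (b, c, s*h, w) with s = model.cfg.sequence_length, so every generated frame in
    # the stacked latent is aligned with a copy of the input-image latents. `__call__` later undoes
    # the stacking with the inverse rearrange 'b c (s h) w -> (b s) c h w' before decoding.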
    def prepare_image_latents(self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None):
        image_latents = self.original_prepare_image_latents(image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator)
        return repeat(image_latents, 'b c h w -> b c (s h) w', s=self.model.cfg.sequence_length)

    def fancy_get_input_ids(self, prompt, device):
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.model.tokenizer)

        text_inputs = self.model.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.model.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.model.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.model.tokenizer.batch_decode(
                untruncated_ids[:, self.model.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.model.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(self.model.text_encoder.config, "use_attention_mask") and self.model.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None
        return text_input_ids, attention_mask

    def encode_prompt_batch(
        self,
        input_ids,
        batch_size,
        device,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = False,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        encoder_hidden_states = self.model.input_ids_to_text_condition(input_ids)
        if self.model.cfg.positional_encoding_type is not None:
            encoder_hidden_states = self.model.apply_step_positional_encoding(encoder_hidden_states)
        prompt_embeds = encoder_hidden_states
        prompt_embeds = prompt_embeds.to(dtype=self.model.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        if do_classifier_free_guidance:
            if negative_prompt_embeds is None:
                negative_prompt_embeds = self.model.get_null_conditioning()
                negative_prompt_embeds = repeat(negative_prompt_embeds, 'o t l -> (b o) t l', b=batch_size)  # , o=1
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.model.text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, the unconditional and text embeddings are concatenated
            # into a single batch to avoid doing multiple forward passes.
            # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered
            # [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
            prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds])
        return prompt_embeds
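
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). How an `SDModel` is constructed or
# restored is project-specific; the `load_model` helper and checkpoint path
# below are hypothetical placeholders.
#
#     from trainer.models.sd_model import SDModel
#
#     model: SDModel = load_model("path/to/checkpoint")  # hypothetical loader
#     pipe = ShowNotTellPipeline(model=model).to("cuda")
#     out = pipe(
#         prompts=["step 1: ...", "step 2: ..."],  # one prompt per generated frame sequence
#         image=init_image,                        # PIL.Image.Image, list of images, or preprocessed tensor
#         num_inference_steps=50,
#         guidance_scale=7.5,
#         image_guidance_scale=1.5,
#     )
#     frames = out.images  # sequence_length frames per prompt, flattened into one list
# ---------------------------------------------------------------------------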