AlanB commited on
Commit
dab0ba6
1 Parent(s): 1618ac2

Many updates from the Github

Browse files
Files changed (1) hide show
  1. pipeline.py +433 -208
pipeline.py CHANGED
@@ -1,25 +1,52 @@
1
- """
2
- modified based on diffusion library from Huggingface: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
3
- """
 
 
 
 
 
 
 
 
 
 
 
4
  import inspect
5
- import warnings
6
  from typing import Callable, List, Optional, Union
7
 
8
  import torch
9
 
10
- from diffusers.models import AutoencoderKL, UNet2DConditionModel
11
- from diffusers.pipeline_utils import DiffusionPipeline
12
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
13
- from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
14
- from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler
15
  from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  class ComposableStableDiffusionPipeline(DiffusionPipeline):
19
  r"""
20
  Pipeline for text-to-image generation using Stable Diffusion.
 
21
  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
22
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
 
23
  Args:
24
  vae ([`AutoencoderKL`]):
25
  Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
@@ -35,11 +62,12 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
35
  A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
36
  [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
37
  safety_checker ([`StableDiffusionSafetyChecker`]):
38
- Classification module that estimates whether generated images could be considered offsensive or harmful.
39
- Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
40
  feature_extractor ([`CLIPFeatureExtractor`]):
41
  Model that extracts features from generated images to be used as inputs for the `safety_checker`.
42
  """
 
43
 
44
  def __init__(
45
  self,
@@ -47,11 +75,84 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
47
  text_encoder: CLIPTextModel,
48
  tokenizer: CLIPTokenizer,
49
  unet: UNet2DConditionModel,
50
- scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler],
 
 
 
 
 
 
 
51
  safety_checker: StableDiffusionSafetyChecker,
52
  feature_extractor: CLIPFeatureExtractor,
 
53
  ):
54
  super().__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  self.register_modules(
56
  vae=vae,
57
  text_encoder=text_encoder,
@@ -61,39 +162,13 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
61
  safety_checker=safety_checker,
62
  feature_extractor=feature_extractor,
63
  )
64
-
65
- def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
66
- r"""
67
- Enable sliced attention computation.
68
- When this option is enabled, the attention module will split the input tensor in slices, to compute attention
69
- in several steps. This is useful to save some memory in exchange for a small speed decrease.
70
- Args:
71
- slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
72
- When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
73
- a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
74
- `attention_head_dim` must be a multiple of `slice_size`.
75
- """
76
- if slice_size == "auto":
77
- if isinstance(self.unet.config.attention_head_dim, int):
78
- # half the attention head size is usually a good trade-off between
79
- # speed and memory
80
- slice_size = self.unet.config.attention_head_dim // 2
81
- else:
82
- # if `attention_head_dim` is a list, take the smallest head size
83
- slice_size = min(self.unet.config.attention_head_dim)
84
- self.unet.set_attention_slice(slice_size)
85
-
86
- def disable_attention_slicing(self):
87
- r"""
88
- Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
89
- back to computing attention in one step.
90
- """
91
- # set slice_size = `None` to disable `attention slicing`
92
- self.enable_attention_slicing(None)
93
 
94
  def enable_vae_slicing(self):
95
  r"""
96
  Enable sliced VAE decoding.
 
97
  When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
98
  steps. This is useful to save some memory and allow larger batch sizes.
99
  """
@@ -106,15 +181,229 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
106
  """
107
  self.vae.disable_slicing()
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  @torch.no_grad()
110
  def __call__(
111
  self,
112
  prompt: Union[str, List[str]],
113
- height: Optional[int] = 512,
114
- width: Optional[int] = 512,
115
- num_inference_steps: Optional[int] = 50,
116
- guidance_scale: Optional[float] = 7.5,
117
- eta: Optional[float] = 0.0,
 
 
118
  generator: Optional[torch.Generator] = None,
119
  latents: Optional[torch.FloatTensor] = None,
120
  output_type: Optional[str] = "pil",
@@ -122,16 +411,16 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
122
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
123
  callback_steps: Optional[int] = 1,
124
  weights: Optional[str] = "",
125
- **kwargs,
126
  ):
127
  r"""
128
  Function invoked when calling the pipeline for generation.
 
129
  Args:
130
  prompt (`str` or `List[str]`):
131
  The prompt or prompts to guide the image generation.
132
- height (`int`, *optional*, defaults to 512):
133
  The height in pixels of the generated image.
134
- width (`int`, *optional*, defaults to 512):
135
  The width in pixels of the generated image.
136
  num_inference_steps (`int`, *optional*, defaults to 50):
137
  The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -142,6 +431,11 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
142
  Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
143
  1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
144
  usually at the expense of lower image quality.
 
 
 
 
 
145
  eta (`float`, *optional*, defaults to 0.0):
146
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
147
  [`schedulers.DDIMScheduler`], will be ignored for others.
@@ -171,186 +465,117 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
171
  list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
172
  (nsfw) content, according to the `safety_checker`.
173
  """
 
 
 
174
 
175
- if "torch_device" in kwargs:
176
- device = kwargs.pop("torch_device")
177
- warnings.warn(
178
- "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
179
- " Consider using `pipe.to(torch_device)` instead."
180
- )
181
-
182
- # Set device as before (to be removed in 0.3.0)
183
- if device is None:
184
- device = "cuda" if torch.cuda.is_available() else "cpu"
185
- self.to(device)
186
-
187
- if isinstance(prompt, str):
188
- batch_size = 1
189
- elif isinstance(prompt, list):
190
- batch_size = len(prompt)
191
- else:
192
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
193
-
194
- if height % 8 != 0 or width % 8 != 0:
195
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
196
-
197
- if "|" in prompt:
198
- prompt = [x.strip() for x in prompt.split("|")]
199
- print(f"composing {prompt}...")
200
-
201
- # get prompt text embeddings
202
- text_input = self.tokenizer(
203
- prompt,
204
- padding="max_length",
205
- max_length=self.tokenizer.model_max_length,
206
- truncation=True,
207
- return_tensors="pt",
208
- )
209
- text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
210
-
211
- if not weights:
212
- # specify weights for prompts (excluding the unconditional score)
213
- print("using equal weights for all prompts...")
214
- pos_weights = torch.tensor(
215
- [1 / (text_embeddings.shape[0] - 1)] * (text_embeddings.shape[0] - 1), device=self.device
216
- ).reshape(-1, 1, 1, 1)
217
- neg_weights = torch.tensor([1.0], device=self.device).reshape(-1, 1, 1, 1)
218
- mask = torch.tensor([False] + [True] * pos_weights.shape[0], dtype=torch.bool)
219
- else:
220
- # set prompt weight for each
221
- num_prompts = len(prompt) if isinstance(prompt, list) else 1
222
- weights = [float(w.strip()) for w in weights.split("|")]
223
- if len(weights) < num_prompts:
224
- weights.append(1.0)
225
- weights = torch.tensor(weights, device=self.device)
226
- assert len(weights) == text_embeddings.shape[0], "weights specified are not equal to the number of prompts"
227
- pos_weights = []
228
- neg_weights = []
229
- mask = [] # first one is unconditional score
230
- for w in weights:
231
- if w > 0:
232
- pos_weights.append(w)
233
- mask.append(True)
234
- else:
235
- neg_weights.append(abs(w))
236
- mask.append(False)
237
- # normalize the weights
238
- pos_weights = torch.tensor(pos_weights, device=self.device).reshape(-1, 1, 1, 1)
239
- pos_weights = pos_weights / pos_weights.sum()
240
- neg_weights = torch.tensor(neg_weights, device=self.device).reshape(-1, 1, 1, 1)
241
- neg_weights = neg_weights / neg_weights.sum()
242
- mask = torch.tensor(mask, device=self.device, dtype=torch.bool)
243
 
 
 
 
244
  # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
245
  # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
246
  # corresponds to doing no classifier free guidance.
247
  do_classifier_free_guidance = guidance_scale > 1.0
248
- # get unconditional embeddings for classifier free guidance
249
- if do_classifier_free_guidance:
250
- max_length = text_input.input_ids.shape[-1]
251
-
252
- if torch.all(mask):
253
- # no negative prompts, so we use empty string as the negative prompt
254
- uncond_input = self.tokenizer(
255
- [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
256
- )
257
- uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
258
 
259
- # For classifier free guidance, we need to do two forward passes.
260
- # Here we concatenate the unconditional and text embeddings into a single batch
261
- # to avoid doing two forward passes
262
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
263
-
264
- # update negative weights
265
- neg_weights = torch.tensor([1.0], device=self.device)
266
- mask = torch.tensor([False] + mask.detach().tolist(), device=self.device, dtype=torch.bool)
267
-
268
- # get the initial random noise unless the user supplied it
269
 
270
- # Unlike in other pipelines, latents need to be generated in the target device
271
- # for 1-to-1 results reproducibility with the CompVis implementation.
272
- # However this currently doesn't work in `mps`.
273
- latents_device = "cpu" if self.device.type == "mps" else self.device
274
- latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
275
- if latents is None:
276
- latents = torch.randn(
277
- latents_shape,
278
- generator=generator,
279
- device=latents_device,
280
- )
 
 
 
 
281
  else:
282
- if latents.shape != latents_shape:
283
- raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
284
- latents = latents.to(self.device)
285
 
286
- # set timesteps
287
- accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
288
- extra_set_kwargs = {}
289
- if accepts_offset:
290
- extra_set_kwargs["offset"] = 1
291
-
292
- self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
293
-
294
- # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
295
- if isinstance(self.scheduler, LMSDiscreteScheduler):
296
- latents = latents * self.scheduler.sigmas[0]
297
 
298
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
299
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
300
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
301
- # and should be between [0, 1]
302
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
303
- extra_step_kwargs = {}
304
- if accepts_eta:
305
- extra_step_kwargs["eta"] = eta
 
 
 
 
 
 
 
 
306
 
307
- for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
308
- # expand the latents if we are doing classifier free guidance
309
- latent_model_input = (
310
- torch.cat([latents] * text_embeddings.shape[0]) if do_classifier_free_guidance else latents
311
- )
312
- if isinstance(self.scheduler, LMSDiscreteScheduler):
313
- sigma = self.scheduler.sigmas[i]
314
- # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
315
- latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
316
-
317
- # reduce memory by predicting each score sequentially
318
- noise_preds = []
319
- # predict the noise residual
320
- for latent_in, text_embedding_in in zip(
321
- torch.chunk(latent_model_input, chunks=latent_model_input.shape[0], dim=0),
322
- torch.chunk(text_embeddings, chunks=text_embeddings.shape[0], dim=0),
323
- ):
324
- noise_preds.append(self.unet(latent_in, t, encoder_hidden_states=text_embedding_in).sample)
325
- noise_preds = torch.cat(noise_preds, dim=0)
326
-
327
- # perform guidance
328
- if do_classifier_free_guidance:
329
- noise_pred_uncond = (noise_preds[~mask] * neg_weights).sum(dim=0, keepdims=True)
330
- noise_pred_text = (noise_preds[mask] * pos_weights).sum(dim=0, keepdims=True)
331
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
332
-
333
- # compute the previous noisy sample x_t -> x_t-1
334
- if isinstance(self.scheduler, LMSDiscreteScheduler):
335
- latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
336
- else:
 
 
 
337
  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
338
 
339
  # call the callback, if provided
340
  if callback is not None and i % callback_steps == 0:
341
  callback(i, t, latents)
342
 
343
- # scale and decode the image latents with vae
344
- latents = 1 / 0.18215 * latents
345
- image = self.vae.decode(latents).sample
 
 
346
 
347
- image = (image / 2 + 0.5).clamp(0, 1)
348
- image = image.cpu().permute(0, 2, 3, 1).numpy()
349
 
350
- # run safety checker
351
- safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
352
- image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
353
 
 
354
  if output_type == "pil":
355
  image = self.numpy_to_pil(image)
356
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
  import inspect
 
16
  from typing import Callable, List, Optional, Union
17
 
18
  import torch
19
 
20
+ from diffusers.utils import is_accelerate_available
21
+ from packaging import version
 
 
 
22
  from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
23
 
24
+ from ...configuration_utils import FrozenDict
25
+ from ...models import AutoencoderKL, UNet2DConditionModel
26
+ from ...pipeline_utils import DiffusionPipeline
27
+ from ...schedulers import (
28
+ DDIMScheduler,
29
+ DPMSolverMultistepScheduler,
30
+ EulerAncestralDiscreteScheduler,
31
+ EulerDiscreteScheduler,
32
+ LMSDiscreteScheduler,
33
+ PNDMScheduler,
34
+ )
35
+ from ...utils import deprecate, logging
36
+ from . import StableDiffusionPipelineOutput
37
+ from .safety_checker import StableDiffusionSafetyChecker
38
+
39
+
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
 
43
  class ComposableStableDiffusionPipeline(DiffusionPipeline):
44
  r"""
45
  Pipeline for text-to-image generation using Stable Diffusion.
46
+
47
  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
48
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
49
+
50
  Args:
51
  vae ([`AutoencoderKL`]):
52
  Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
 
62
  A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
63
  [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
64
  safety_checker ([`StableDiffusionSafetyChecker`]):
65
+ Classification module that estimates whether generated images could be considered offensive or harmful.
66
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
67
  feature_extractor ([`CLIPFeatureExtractor`]):
68
  Model that extracts features from generated images to be used as inputs for the `safety_checker`.
69
  """
70
+ _optional_components = ["safety_checker", "feature_extractor"]
71
 
72
  def __init__(
73
  self,
 
75
  text_encoder: CLIPTextModel,
76
  tokenizer: CLIPTokenizer,
77
  unet: UNet2DConditionModel,
78
+ scheduler: Union[
79
+ DDIMScheduler,
80
+ PNDMScheduler,
81
+ LMSDiscreteScheduler,
82
+ EulerDiscreteScheduler,
83
+ EulerAncestralDiscreteScheduler,
84
+ DPMSolverMultistepScheduler,
85
+ ],
86
  safety_checker: StableDiffusionSafetyChecker,
87
  feature_extractor: CLIPFeatureExtractor,
88
+ requires_safety_checker: bool = True,
89
  ):
90
  super().__init__()
91
+
92
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
93
+ deprecation_message = (
94
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
95
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
96
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
97
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
98
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
99
+ " file"
100
+ )
101
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
102
+ new_config = dict(scheduler.config)
103
+ new_config["steps_offset"] = 1
104
+ scheduler._internal_dict = FrozenDict(new_config)
105
+
106
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
107
+ deprecation_message = (
108
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
109
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
110
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
111
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
112
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
113
+ )
114
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
115
+ new_config = dict(scheduler.config)
116
+ new_config["clip_sample"] = False
117
+ scheduler._internal_dict = FrozenDict(new_config)
118
+
119
+ if safety_checker is None and requires_safety_checker:
120
+ logger.warning(
121
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
122
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
123
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
124
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
125
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
126
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
127
+ )
128
+
129
+ if safety_checker is not None and feature_extractor is None:
130
+ raise ValueError(
131
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
132
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
133
+ )
134
+
135
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
136
+ version.parse(unet.config._diffusers_version).base_version
137
+ ) < version.parse("0.9.0.dev0")
138
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
139
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
140
+ deprecation_message = (
141
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
142
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
143
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
144
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
145
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
146
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
147
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
148
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
149
+ " the `unet/config.json` file"
150
+ )
151
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
152
+ new_config = dict(unet.config)
153
+ new_config["sample_size"] = 64
154
+ unet._internal_dict = FrozenDict(new_config)
155
+
156
  self.register_modules(
157
  vae=vae,
158
  text_encoder=text_encoder,
 
162
  safety_checker=safety_checker,
163
  feature_extractor=feature_extractor,
164
  )
165
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
166
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
  def enable_vae_slicing(self):
169
  r"""
170
  Enable sliced VAE decoding.
171
+
172
  When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
173
  steps. This is useful to save some memory and allow larger batch sizes.
174
  """
 
181
  """
182
  self.vae.disable_slicing()
183
 
184
+ def enable_sequential_cpu_offload(self, gpu_id=0):
185
+ r"""
186
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
187
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
188
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
189
+ """
190
+ if is_accelerate_available():
191
+ from accelerate import cpu_offload
192
+ else:
193
+ raise ImportError("Please install accelerate via `pip install accelerate`")
194
+
195
+ device = torch.device(f"cuda:{gpu_id}")
196
+
197
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
198
+ if cpu_offloaded_model is not None:
199
+ cpu_offload(cpu_offloaded_model, device)
200
+
201
+ if self.safety_checker is not None:
202
+ # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
203
+ # fix by only offloading self.safety_checker for now
204
+ cpu_offload(self.safety_checker.vision_model, device)
205
+
206
+ @property
207
+ def _execution_device(self):
208
+ r"""
209
+ Returns the device on which the pipeline's models will be executed. After calling
210
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
211
+ hooks.
212
+ """
213
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
214
+ return self.device
215
+ for module in self.unet.modules():
216
+ if (
217
+ hasattr(module, "_hf_hook")
218
+ and hasattr(module._hf_hook, "execution_device")
219
+ and module._hf_hook.execution_device is not None
220
+ ):
221
+ return torch.device(module._hf_hook.execution_device)
222
+ return self.device
223
+
224
+ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
225
+ r"""
226
+ Encodes the prompt into text encoder hidden states.
227
+
228
+ Args:
229
+ prompt (`str` or `list(int)`):
230
+ prompt to be encoded
231
+ device: (`torch.device`):
232
+ torch device
233
+ num_images_per_prompt (`int`):
234
+ number of images that should be generated per prompt
235
+ do_classifier_free_guidance (`bool`):
236
+ whether to use classifier free guidance or not
237
+ negative_prompt (`str` or `List[str]`):
238
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
239
+ if `guidance_scale` is less than `1`).
240
+ """
241
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
242
+
243
+ text_inputs = self.tokenizer(
244
+ prompt,
245
+ padding="max_length",
246
+ max_length=self.tokenizer.model_max_length,
247
+ truncation=True,
248
+ return_tensors="pt",
249
+ )
250
+ text_input_ids = text_inputs.input_ids
251
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
252
+
253
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
254
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
255
+ logger.warning(
256
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
257
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
258
+ )
259
+
260
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
261
+ attention_mask = text_inputs.attention_mask.to(device)
262
+ else:
263
+ attention_mask = None
264
+
265
+ text_embeddings = self.text_encoder(
266
+ text_input_ids.to(device),
267
+ attention_mask=attention_mask,
268
+ )
269
+ text_embeddings = text_embeddings[0]
270
+
271
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
272
+ bs_embed, seq_len, _ = text_embeddings.shape
273
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
274
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
275
+
276
+ # get unconditional embeddings for classifier free guidance
277
+ if do_classifier_free_guidance:
278
+ uncond_tokens: List[str]
279
+ if negative_prompt is None:
280
+ uncond_tokens = [""] * batch_size
281
+ elif type(prompt) is not type(negative_prompt):
282
+ raise TypeError(
283
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
284
+ f" {type(prompt)}."
285
+ )
286
+ elif isinstance(negative_prompt, str):
287
+ uncond_tokens = [negative_prompt]
288
+ elif batch_size != len(negative_prompt):
289
+ raise ValueError(
290
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
291
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
292
+ " the batch size of `prompt`."
293
+ )
294
+ else:
295
+ uncond_tokens = negative_prompt
296
+
297
+ max_length = text_input_ids.shape[-1]
298
+ uncond_input = self.tokenizer(
299
+ uncond_tokens,
300
+ padding="max_length",
301
+ max_length=max_length,
302
+ truncation=True,
303
+ return_tensors="pt",
304
+ )
305
+
306
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
307
+ attention_mask = uncond_input.attention_mask.to(device)
308
+ else:
309
+ attention_mask = None
310
+
311
+ uncond_embeddings = self.text_encoder(
312
+ uncond_input.input_ids.to(device),
313
+ attention_mask=attention_mask,
314
+ )
315
+ uncond_embeddings = uncond_embeddings[0]
316
+
317
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
318
+ seq_len = uncond_embeddings.shape[1]
319
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
320
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
321
+
322
+ # For classifier free guidance, we need to do two forward passes.
323
+ # Here we concatenate the unconditional and text embeddings into a single batch
324
+ # to avoid doing two forward passes
325
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
326
+
327
+ return text_embeddings
328
+
329
+ def run_safety_checker(self, image, device, dtype):
330
+ if self.safety_checker is not None:
331
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
332
+ image, has_nsfw_concept = self.safety_checker(
333
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
334
+ )
335
+ else:
336
+ has_nsfw_concept = None
337
+ return image, has_nsfw_concept
338
+
339
+ def decode_latents(self, latents):
340
+ latents = 1 / 0.18215 * latents
341
+ image = self.vae.decode(latents).sample
342
+ image = (image / 2 + 0.5).clamp(0, 1)
343
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
344
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
345
+ return image
346
+
347
+ def prepare_extra_step_kwargs(self, generator, eta):
348
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
349
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
350
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
351
+ # and should be between [0, 1]
352
+
353
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
354
+ extra_step_kwargs = {}
355
+ if accepts_eta:
356
+ extra_step_kwargs["eta"] = eta
357
+
358
+ # check if the scheduler accepts generator
359
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
360
+ if accepts_generator:
361
+ extra_step_kwargs["generator"] = generator
362
+ return extra_step_kwargs
363
+
364
+ def check_inputs(self, prompt, height, width, callback_steps):
365
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
366
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
367
+
368
+ if height % 8 != 0 or width % 8 != 0:
369
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
370
+
371
+ if (callback_steps is None) or (
372
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
373
+ ):
374
+ raise ValueError(
375
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
376
+ f" {type(callback_steps)}."
377
+ )
378
+
379
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
380
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
381
+ if latents is None:
382
+ if device.type == "mps":
383
+ # randn does not work reproducibly on mps
384
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
385
+ else:
386
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
387
+ else:
388
+ if latents.shape != shape:
389
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
390
+ latents = latents.to(device)
391
+
392
+ # scale the initial noise by the standard deviation required by the scheduler
393
+ latents = latents * self.scheduler.init_noise_sigma
394
+ return latents
395
+
396
  @torch.no_grad()
397
  def __call__(
398
  self,
399
  prompt: Union[str, List[str]],
400
+ height: Optional[int] = None,
401
+ width: Optional[int] = None,
402
+ num_inference_steps: int = 50,
403
+ guidance_scale: float = 7.5,
404
+ negative_prompt: Optional[Union[str, List[str]]] = None,
405
+ num_images_per_prompt: Optional[int] = 1,
406
+ eta: float = 0.0,
407
  generator: Optional[torch.Generator] = None,
408
  latents: Optional[torch.FloatTensor] = None,
409
  output_type: Optional[str] = "pil",
 
411
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
412
  callback_steps: Optional[int] = 1,
413
  weights: Optional[str] = "",
 
414
  ):
415
  r"""
416
  Function invoked when calling the pipeline for generation.
417
+
418
  Args:
419
  prompt (`str` or `List[str]`):
420
  The prompt or prompts to guide the image generation.
421
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
422
  The height in pixels of the generated image.
423
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
424
  The width in pixels of the generated image.
425
  num_inference_steps (`int`, *optional*, defaults to 50):
426
  The number of denoising steps. More denoising steps usually lead to a higher quality image at the
 
431
  Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
432
  1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
433
  usually at the expense of lower image quality.
434
+ negative_prompt (`str` or `List[str]`, *optional*):
435
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
436
+ if `guidance_scale` is less than `1`).
437
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
438
+ The number of images to generate per prompt.
439
  eta (`float`, *optional*, defaults to 0.0):
440
  Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
441
  [`schedulers.DDIMScheduler`], will be ignored for others.
 
465
  list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
466
  (nsfw) content, according to the `safety_checker`.
467
  """
468
+ # 0. Default height and width to unet
469
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
470
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
471
 
472
+ # 1. Check inputs. Raise error if not correct
473
+ self.check_inputs(prompt, height, width, callback_steps)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
474
 
475
+ # 2. Define call parameters
476
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
477
+ device = self._execution_device
478
  # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
479
  # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
480
  # corresponds to doing no classifier free guidance.
481
  do_classifier_free_guidance = guidance_scale > 1.0
 
 
 
 
 
 
 
 
 
 
482
 
483
+ if "|" in prompt:
484
+ prompt = [x.strip() for x in prompt.split("|")]
485
+ print(f"composing {prompt}...")
 
 
 
 
 
 
 
486
 
487
+ if not weights:
488
+ # specify weights for prompts (excluding the unconditional score)
489
+ print("using equal positive weights (conjunction) for all prompts...")
490
+ weights = torch.tensor([guidance_scale] * len(prompt), device=self.device).reshape(-1, 1, 1, 1)
491
+ else:
492
+ # set prompt weight for each
493
+ num_prompts = len(prompt) if isinstance(prompt, list) else 1
494
+ weights = [float(w.strip()) for w in weights.split("|")]
495
+ # guidance scale as the default
496
+ if len(weights) < num_prompts:
497
+ weights.append(guidance_scale)
498
+ else:
499
+ weights = weights[:num_prompts]
500
+ assert len(weights) == len(prompt), "weights specified are not equal to the number of prompts"
501
+ weights = torch.tensor(weights, device=self.device).reshape(-1, 1, 1, 1)
502
  else:
503
+ weights = guidance_scale
 
 
504
 
505
+ # 3. Encode input prompt
506
+ text_embeddings = self._encode_prompt(
507
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
508
+ )
 
 
 
 
 
 
 
509
 
510
+ # 4. Prepare timesteps
511
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
512
+ timesteps = self.scheduler.timesteps
513
+
514
+ # 5. Prepare latent variables
515
+ num_channels_latents = self.unet.in_channels
516
+ latents = self.prepare_latents(
517
+ batch_size * num_images_per_prompt,
518
+ num_channels_latents,
519
+ height,
520
+ width,
521
+ text_embeddings.dtype,
522
+ device,
523
+ generator,
524
+ latents,
525
+ )
526
 
527
+ # composable diffusion
528
+ if isinstance(prompt, list) and batch_size == 1:
529
+ # remove extra unconditional embedding
530
+ # N = one unconditional embed + conditional embeds
531
+ text_embeddings = text_embeddings[len(prompt) - 1 :]
532
+
533
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
534
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
535
+
536
+ # 7. Denoising loop
537
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
538
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
539
+ for i, t in enumerate(timesteps):
540
+ # expand the latents if we are doing classifier free guidance
541
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
542
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
543
+
544
+ # predict the noise residual
545
+ noise_pred = []
546
+ for j in range(text_embeddings.shape[0]):
547
+ noise_pred.append(
548
+ self.unet(latent_model_input[:1], t, encoder_hidden_states=text_embeddings[j : j + 1]).sample
549
+ )
550
+ noise_pred = torch.cat(noise_pred, dim=0)
551
+
552
+ # perform guidance
553
+ if do_classifier_free_guidance:
554
+ noise_pred_uncond, noise_pred_text = noise_pred[:1], noise_pred[1:]
555
+ noise_pred = noise_pred_uncond + (weights * (noise_pred_text - noise_pred_uncond)).sum(
556
+ dim=0, keepdims=True
557
+ )
558
+
559
+ # compute the previous noisy sample x_t -> x_t-1
560
  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
561
 
562
  # call the callback, if provided
563
  if callback is not None and i % callback_steps == 0:
564
  callback(i, t, latents)
565
 
566
+ # call the callback, if provided
567
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
568
+ progress_bar.update()
569
+ if callback is not None and i % callback_steps == 0:
570
+ callback(i, t, latents)
571
 
572
+ # 8. Post-processing
573
+ image = self.decode_latents(latents)
574
 
575
+ # 9. Run safety checker
576
+ image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
 
577
 
578
+ # 10. Convert to PIL
579
  if output_type == "pil":
580
  image = self.numpy_to_pil(image)
581