Shuang59 committed on
Commit
471b0b2
•
1 Parent(s): d469c4f

Update composable_stable_diffusion_pipeline.py

composable_stable_diffusion_pipeline.py CHANGED
@@ -15,8 +15,60 @@ from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from safety_checker import StableDiffusionSafetyChecker
 
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+
+import PIL
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class StableDiffusionPipelineOutput(BaseOutput):
+    """
+    Output class for Stable Diffusion pipelines.
+    Args:
+        images (`List[PIL.Image.Image]` or `np.ndarray`)
+            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+        nsfw_content_detected (`List[bool]`)
+            List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
+            (nsfw) content.
+    """
+
+    images: Union[List[PIL.Image.Image], np.ndarray]
+    nsfw_content_detected: List[bool]
 
 class ComposableStableDiffusionPipeline(DiffusionPipeline):
+    r"""
+    Pipeline for text-to-image generation using Stable Diffusion.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            Frozen text-encoder. Stable Diffusion uses the text portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offsensive or harmful.
+            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+        feature_extractor ([`CLIPFeatureExtractor`]):
+            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+    """
+
     def __init__(
         self,
         vae: AutoencoderKL,
@@ -39,6 +91,33 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
             feature_extractor=feature_extractor,
         )
 
+    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+        r"""
+        Enable sliced attention computation.
+
+        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+        in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+        Args:
+            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+                `attention_head_dim` must be a multiple of `slice_size`.
+        """
+        if slice_size == "auto":
+            # half the attention head size is usually a good trade-off between
+            # speed and memory
+            slice_size = self.unet.config.attention_head_dim // 2
+        self.unet.set_attention_slice(slice_size)
+
+    def disable_attention_slicing(self):
+        r"""
+        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+        back to computing attention in one step.
+        """
+        # set slice_size = `None` to disable `attention slicing`
+        self.enable_attention_slicing(None)
+
     @torch.no_grad()
     def __call__(
         self,
@@ -49,9 +128,56 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
+        latents: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        weights: Optional[str] = "",
         **kwargs,
     ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the image generation.
+            height (`int`, *optional*, defaults to 512):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to 512):
+                The width in pixels of the generated image.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will ge generated by sampling using the supplied random `generator`.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generate image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
+            When returning a tuple, the first element is a list with the generated images, and the second element is a
+            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+            (nsfw) content, according to the `safety_checker`.
+        """
+
         if "torch_device" in kwargs:
             device = kwargs.pop("torch_device")
             warnings.warn(
@@ -76,7 +202,7 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
 
         if '|' in prompt:
             prompt = [x.strip() for x in prompt.split('|')]
-            print(prompt)
+            print(f"composing {prompt}...")
 
         # get prompt text embeddings
         text_input = self.tokenizer(
@@ -88,6 +214,38 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
         )
         text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
 
+        if not weights:
+            # specify weights for prompts (excluding the unconditional score)
+            print('using equal weights for all prompts...')
+            pos_weights = torch.tensor([1 / (text_embeddings.shape[0] - 1)] * (text_embeddings.shape[0] - 1),
+                                       device=self.device).reshape(-1, 1, 1, 1)
+            neg_weights = torch.tensor([1.], device=self.device).reshape(-1, 1, 1, 1)
+            mask = torch.tensor([False] + [True] * pos_weights.shape[0], dtype=torch.bool)
+        else:
+            # set prompt weight for each
+            num_prompts = len(prompt) if isinstance(prompt, list) else 1
+            weights = [float(w.strip()) for w in weights.split("|")]
+            if len(weights) < num_prompts:
+                weights.append(1.)
+            weights = torch.tensor(weights, device=self.device)
+            assert len(weights) == text_embeddings.shape[0], "weights specified are not equal to the number of prompts"
+            pos_weights = []
+            neg_weights = []
+            mask = []  # first one is unconditional score
+            for w in weights:
+                if w > 0:
+                    pos_weights.append(w)
+                    mask.append(True)
+                else:
+                    neg_weights.append(abs(w))
+                    mask.append(False)
+            # normalize the weights
+            pos_weights = torch.tensor(pos_weights, device=self.device).reshape(-1, 1, 1, 1)
+            pos_weights = pos_weights / pos_weights.sum()
+            neg_weights = torch.tensor(neg_weights, device=self.device).reshape(-1, 1, 1, 1)
+            neg_weights = neg_weights / neg_weights.sum()
+            mask = torch.tensor(mask, device=self.device, dtype=torch.bool)
+
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
@@ -95,22 +253,40 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
             max_length = text_input.input_ids.shape[-1]
-            uncond_input = self.tokenizer(
-                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+
+            if torch.all(mask):
+                # no negative prompts, so we use empty string as the negative prompt
+                uncond_input = self.tokenizer(
+                    [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+                )
+                uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+                # For classifier free guidance, we need to do two forward passes.
+                # Here we concatenate the unconditional and text embeddings into a single batch
+                # to avoid doing two forward passes
+                text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+                # update negative weights
+                neg_weights = torch.tensor([1.], device=self.device)
+                mask = torch.tensor([False] + mask.detach().tolist(), device=self.device, dtype=torch.bool)
+
+        # get the initial random noise unless the user supplied it
+
+        # Unlike in other pipelines, latents need to be generated in the target device
+        # for 1-to-1 results reproducibility with the CompVis implementation.
+        # However this currently doesn't work in `mps`.
+        latents_device = "cpu" if self.device.type == "mps" else self.device
+        latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        if latents is None:
+            latents = torch.randn(
+                latents_shape,
+                generator=generator,
+                device=latents_device,
             )
-            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
-
-            # For classifier free guidance, we need to do two forward passes.
-            # Here we concatenate the unconditional and text embeddings into a single batch
-            # to avoid doing two forward passes
-            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-
-        # get the intial random noise
-        latents = torch.randn(
-            (batch_size, self.unet.in_channels, height // 8, width // 8),
-            generator=generator,
-            device=self.device,
-        )
+        else:
+            if latents.shape != latents_shape:
+                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+            latents = latents.to(self.device)
 
         # set timesteps
         accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
@@ -133,31 +309,38 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
             if accepts_eta:
                 extra_step_kwargs["eta"] = eta
 
-        for i, t in tqdm(enumerate(self.scheduler.timesteps)):
+        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * text_embeddings.shape[0]) if do_classifier_free_guidance else latents
             if isinstance(self.scheduler, LMSDiscreteScheduler):
                 sigma = self.scheduler.sigmas[i]
+                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
                 latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
 
+            # reduce memory by predicting each score sequentially
+            noise_preds = []
             # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
+            for latent_in, text_embedding_in in zip(
+                    torch.chunk(latent_model_input, chunks=latent_model_input.shape[0], dim=0),
+                    torch.chunk(text_embeddings, chunks=text_embeddings.shape[0], dim=0)):
+                noise_preds.append(self.unet(latent_in, t, encoder_hidden_states=text_embedding_in).sample)
+            noise_preds = torch.cat(noise_preds, dim=0)
 
             # perform guidance
             if do_classifier_free_guidance:
-                pred_decomp = noise_pred.chunk(text_embeddings.shape[0])
-                noise_pred_uncond, noise_pred_text = pred_decomp[0], torch.cat(pred_decomp[1:], dim=0).mean(dim=0, keepdim=True)
+                noise_pred_uncond = noise_preds[~mask] * neg_weights
+                noise_pred_text = (noise_preds[mask] * pos_weights).sum(dim=0, keepdims=True)
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)["prev_sample"]
+                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
            else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents)
+        image = self.vae.decode(latents).sample
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
@@ -169,4 +352,7 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
         if output_type == "pil":
             image = self.numpy_to_pil(image)
 
-        return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
+        if not return_dict:
+            return (image, has_nsfw_concept)
+
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
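A minimal usage sketch of the updated `__call__` interface, not part of the commit itself: prompts are composed by separating them with `|`, the new `weights` argument takes a matching `|`-separated list (a negative weight subtracts that concept's score), and the result is returned as `StableDiffusionPipelineOutput` unless `return_dict=False`. The checkpoint id, device, and loading call below are illustrative assumptions; the commit does not pin down how the pipeline is instantiated.

# Illustrative sketch (assumes this repo's files are importable and that from_pretrained
# can resolve the usual Stable Diffusion v1-4 components; both are assumptions).
import torch
from composable_stable_diffusion_pipeline import ComposableStableDiffusionPipeline

pipe = ComposableStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4"  # hypothetical checkpoint choice
).to("cuda")

out = pipe(
    "a photo of a forest | a photo of fog",  # '|' separates the composed prompts
    weights="1 | 1",                         # per-prompt weights; a negative value subtracts that concept
    num_inference_steps=50,
    guidance_scale=7.5,
    generator=torch.Generator(device="cuda").manual_seed(0),
)
out.images[0].save("composed.png")           # .images holds PIL images when output_type="pil"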