AlanB committed on
Commit
6093aac
1 Parent(s): bcf485e

Upload pipeline.py

Files changed (1)
  1. pipeline.py +345 -0
pipeline.py ADDED
@@ -0,0 +1,345 @@
+ """
+ Modified from the diffusers library by Hugging Face: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+ """
+ import inspect
+ import warnings
+ from typing import Callable, List, Optional, Union
+
+ import torch
+
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
+ from diffusers.pipeline_utils import DiffusionPipeline
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+ from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+
+ class ComposableStableDiffusionPipeline(DiffusionPipeline):
+     r"""
+     Pipeline for text-to-image generation using Stable Diffusion, extended so that several prompts
+     separated by `|` can be composed into a single image.
+     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+     Args:
+         vae ([`AutoencoderKL`]):
+             Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+         text_encoder ([`CLIPTextModel`]):
+             Frozen text encoder. Stable Diffusion uses the text portion of
+             [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+             the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+         tokenizer (`CLIPTokenizer`):
+             Tokenizer of class
+             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+         unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+         scheduler ([`SchedulerMixin`]):
+             A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+             [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+         safety_checker ([`StableDiffusionSafetyChecker`]):
+             Classification module that estimates whether generated images could be considered offensive or harmful.
+             Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+         feature_extractor ([`CLIPFeatureExtractor`]):
+             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
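+
+     Example (an illustrative usage sketch, not part of the original file; the checkpoint id and the
+     prompt/weight values below are placeholders):
+
+         >>> pipe = ComposableStableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+         >>> pipe = pipe.to("cuda")
+         >>> # prompts are separated by "|"; `weights` optionally gives one signed weight per prompt
+         >>> image = pipe("a red car | a snowy forest", weights="1 | 1").images[0]
+         >>> image.save("composed.png")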
+     """
+
+     def __init__(
+         self,
+         vae: AutoencoderKL,
+         text_encoder: CLIPTextModel,
+         tokenizer: CLIPTokenizer,
+         unet: UNet2DConditionModel,
+         scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler],
+         safety_checker: StableDiffusionSafetyChecker,
+         feature_extractor: CLIPFeatureExtractor,
+     ):
+         super().__init__()
+         self.register_modules(
+             vae=vae,
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+             unet=unet,
+             scheduler=scheduler,
+             safety_checker=safety_checker,
+             feature_extractor=feature_extractor,
+         )
+
+     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+         r"""
+         Enable sliced attention computation.
+         When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+         in several steps. This is useful to save some memory in exchange for a small speed decrease.
+         Args:
+             slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+                 When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+                 a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+                 `attention_head_dim` must be a multiple of `slice_size`.
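+
+         For example (illustrative), `pipe.enable_attention_slicing()` keeps the default `"auto"` slicing,
+         and `pipe.disable_attention_slicing()` restores full attention afterwards.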
+         """
+         if slice_size == "auto":
+             if isinstance(self.unet.config.attention_head_dim, int):
+                 # half the attention head size is usually a good trade-off between
+                 # speed and memory
+                 slice_size = self.unet.config.attention_head_dim // 2
+             else:
+                 # if `attention_head_dim` is a list, take the smallest head size
+                 slice_size = min(self.unet.config.attention_head_dim)
+         self.unet.set_attention_slice(slice_size)
+
+     def disable_attention_slicing(self):
+         r"""
+         Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+         back to computing attention in one step.
+         """
+         # set slice_size = `None` to disable `attention slicing`
+         self.enable_attention_slicing(None)
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         prompt: Union[str, List[str]],
+         height: Optional[int] = 512,
+         width: Optional[int] = 512,
+         num_inference_steps: Optional[int] = 50,
+         guidance_scale: Optional[float] = 7.5,
+         eta: Optional[float] = 0.0,
+         generator: Optional[torch.Generator] = None,
+         latents: Optional[torch.FloatTensor] = None,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+         callback_steps: Optional[int] = 1,
+         weights: Optional[str] = "",
+         **kwargs,
+     ):
+         r"""
+         Function invoked when calling the pipeline for generation.
+         Args:
+             prompt (`str` or `List[str]`):
+                 The prompt or prompts to guide the image generation. Several prompts separated by `|` are
+                 composed into a single image.
+             height (`int`, *optional*, defaults to 512):
+                 The height in pixels of the generated image.
+             width (`int`, *optional*, defaults to 512):
+                 The width in pixels of the generated image.
+             num_inference_steps (`int`, *optional*, defaults to 50):
+                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                 expense of slower inference.
+             guidance_scale (`float`, *optional*, defaults to 7.5):
+                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                 1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+                 usually at the expense of lower image quality.
+             eta (`float`, *optional*, defaults to 0.0):
+                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                 [`schedulers.DDIMScheduler`], will be ignored for others.
+             generator (`torch.Generator`, *optional*):
+                 A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                 deterministic.
+             latents (`torch.FloatTensor`, *optional*):
+                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                 tensor will be generated by sampling using the supplied random `generator`.
+             output_type (`str`, *optional*, defaults to `"pil"`):
+                 The output format of the generated image. Choose between
+                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                 plain tuple.
+             callback (`Callable`, *optional*):
+                 A function that will be called every `callback_steps` steps during inference. The function will be
+                 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+             callback_steps (`int`, *optional*, defaults to 1):
+                 The frequency at which the `callback` function will be called. If not specified, the callback will be
+                 called at every step.
+             weights (`str`, *optional*, defaults to `""`):
+                 A `|`-separated string with one weight per prompt (e.g. `"1 | 1"`). Positive weights are normalized
+                 and applied to their prompts' scores; a non-positive weight treats the corresponding prompt as a
+                 negative score. If empty, equal weights are used for all prompts.
+         Returns:
+             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+             When returning a tuple, the first element is a list with the generated images, and the second element is a
+             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+             (nsfw) content, according to the `safety_checker`.
+         """
+
+         if "torch_device" in kwargs:
+             device = kwargs.pop("torch_device")
+             warnings.warn(
+                 "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+                 " Consider using `pipe.to(torch_device)` instead."
+             )
+
+             # Set device as before (to be removed in 0.3.0)
+             if device is None:
+                 device = "cuda" if torch.cuda.is_available() else "cpu"
+             self.to(device)
+
+         if isinstance(prompt, str):
+             batch_size = 1
+         elif isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+         if height % 8 != 0 or width % 8 != 0:
+             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+         if "|" in prompt:
+             prompt = [x.strip() for x in prompt.split("|")]
+             print(f"composing {prompt}...")
+
+         # get prompt text embeddings
+         text_input = self.tokenizer(
+             prompt,
+             padding="max_length",
+             max_length=self.tokenizer.model_max_length,
+             truncation=True,
+             return_tensors="pt",
+         )
+         text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+
+         if not weights:
+             # specify weights for prompts (excluding the unconditional score)
+             print("using equal weights for all prompts...")
+             pos_weights = torch.tensor(
+                 [1 / (text_embeddings.shape[0] - 1)] * (text_embeddings.shape[0] - 1), device=self.device
+             ).reshape(-1, 1, 1, 1)
+             neg_weights = torch.tensor([1.0], device=self.device).reshape(-1, 1, 1, 1)
+             mask = torch.tensor([False] + [True] * pos_weights.shape[0], dtype=torch.bool)
+         else:
+             # set prompt weight for each
+             num_prompts = len(prompt) if isinstance(prompt, list) else 1
+             weights = [float(w.strip()) for w in weights.split("|")]
+             if len(weights) < num_prompts:
+                 weights.append(1.0)
+             weights = torch.tensor(weights, device=self.device)
+             assert len(weights) == text_embeddings.shape[0], "weights specified are not equal to the number of prompts"
+             pos_weights = []
+             neg_weights = []
+             mask = []  # first one is unconditional score
+             for w in weights:
+                 if w > 0:
+                     pos_weights.append(w)
+                     mask.append(True)
+                 else:
+                     neg_weights.append(abs(w))
+                     mask.append(False)
+             # normalize the weights
+             pos_weights = torch.tensor(pos_weights, device=self.device).reshape(-1, 1, 1, 1)
+             pos_weights = pos_weights / pos_weights.sum()
+             neg_weights = torch.tensor(neg_weights, device=self.device).reshape(-1, 1, 1, 1)
+             neg_weights = neg_weights / neg_weights.sum()
+             mask = torch.tensor(mask, device=self.device, dtype=torch.bool)
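+         # Illustrative note (not from the original file): a `weights` string such as "1 | 1 | -1"
+         # places the first two prompts in the positive group and the third in the negative group;
+         # positive and negative weights are normalized separately above.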
+
+         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+         # corresponds to doing no classifier free guidance.
+         do_classifier_free_guidance = guidance_scale > 1.0
+         # get unconditional embeddings for classifier free guidance
+         if do_classifier_free_guidance:
+             max_length = text_input.input_ids.shape[-1]
+
+             if torch.all(mask):
+                 # no negative prompts, so we use empty string as the negative prompt
+                 uncond_input = self.tokenizer(
+                     [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+                 )
+                 uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+                 # For classifier free guidance, we need to do two forward passes.
+                 # Here we concatenate the unconditional and text embeddings into a single batch
+                 # to avoid doing two forward passes
+                 text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+                 # update negative weights
+                 neg_weights = torch.tensor([1.0], device=self.device)
+                 mask = torch.tensor([False] + mask.detach().tolist(), device=self.device, dtype=torch.bool)
+
+         # get the initial random noise unless the user supplied it
+
+         # Unlike in other pipelines, latents need to be generated in the target device
+         # for 1-to-1 results reproducibility with the CompVis implementation.
+         # However this currently doesn't work in `mps`.
+         latents_device = "cpu" if self.device.type == "mps" else self.device
+         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+         if latents is None:
+             latents = torch.randn(
+                 latents_shape,
+                 generator=generator,
+                 device=latents_device,
+             )
+         else:
+             if latents.shape != latents_shape:
+                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+             latents = latents.to(self.device)
+
+         # set timesteps
+         accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+         extra_set_kwargs = {}
+         if accepts_offset:
+             extra_set_kwargs["offset"] = 1
+
+         self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+
+         # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
+         if isinstance(self.scheduler, LMSDiscreteScheduler):
+             latents = latents * self.scheduler.sigmas[0]
+
+         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+         # and should be between [0, 1]
+         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+         extra_step_kwargs = {}
+         if accepts_eta:
+             extra_step_kwargs["eta"] = eta
+
+         for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+             # expand the latents if we are doing classifier free guidance
+             latent_model_input = (
+                 torch.cat([latents] * text_embeddings.shape[0]) if do_classifier_free_guidance else latents
+             )
+             if isinstance(self.scheduler, LMSDiscreteScheduler):
+                 sigma = self.scheduler.sigmas[i]
+                 # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
+                 latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+
+             # reduce memory by predicting each score sequentially
+             noise_preds = []
+             # predict the noise residual
+             for latent_in, text_embedding_in in zip(
+                 torch.chunk(latent_model_input, chunks=latent_model_input.shape[0], dim=0),
+                 torch.chunk(text_embeddings, chunks=text_embeddings.shape[0], dim=0),
+             ):
+                 noise_preds.append(self.unet(latent_in, t, encoder_hidden_states=text_embedding_in).sample)
+             noise_preds = torch.cat(noise_preds, dim=0)
+
+             # perform guidance
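+             # Illustrative summary (not a comment from the original file): with s = guidance_scale and
+             # normalized weights w_i, the composed update below is
+             #   eps_hat = eps_neg + s * (sum_i w_i * eps_i - eps_neg),
+             # where eps_neg is the weighted negative/unconditional score and eps_i are the positive scores.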
+             if do_classifier_free_guidance:
+                 noise_pred_uncond = (noise_preds[~mask] * neg_weights).sum(dim=0, keepdims=True)
+                 noise_pred_text = (noise_preds[mask] * pos_weights).sum(dim=0, keepdims=True)
+                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+             # compute the previous noisy sample x_t -> x_t-1
+             if isinstance(self.scheduler, LMSDiscreteScheduler):
+                 latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
+             else:
+                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+             # call the callback, if provided
+             if callback is not None and i % callback_steps == 0:
+                 callback(i, t, latents)
+
+         # scale and decode the image latents with vae
+         latents = 1 / 0.18215 * latents
+         image = self.vae.decode(latents).sample
+
+         image = (image / 2 + 0.5).clamp(0, 1)
+         image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+         # run safety checker
+         safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
+         image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+
+         if output_type == "pil":
+             image = self.numpy_to_pil(image)
+
+         if not return_dict:
+             return (image, has_nsfw_concept)
+
+         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)