Kevin Turner committed on
Commit 843dc97
1 Parent(s): 48fbe62

add pipeline-with-callback


based on the proposal in diffusers#521
https://github.com/huggingface/diffusers/pull/521/files#diff-ab952f41078da66b9fcbbd913b419f8c334badceefac03a5f7edcd6dd986a8ef

reset diffusers requirement to the main repo; specify versions for various other dependencies

Files changed (3)
  1. app.py +3 -4
  2. pipeline_with_callback.py +335 -0
  3. requirements.txt +9 -5
app.py CHANGED
@@ -1,8 +1,7 @@
  import gradio as gr
 
  import torch
- from torch import autocast
- from diffusers import StableDiffusionPipeline
+ from pipeline_with_callback import StableDiffusionPipelineWithCallback
  from datasets import load_dataset
  from PIL import Image
  import re
@@ -11,7 +10,7 @@ model_id = "CompVis/stable-diffusion-v1-4"
  device = "cuda"
 
  #If you are running this code locally, you need to either do a 'huggingface-cli login` or paste your User Access Token from here https://huggingface.co/settings/tokens into the use_auth_token field below.
- pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True, revision="fp16", torch_dtype=torch.float16)
+ pipe = StableDiffusionPipelineWithCallback.from_pretrained(model_id, use_auth_token=True, revision="fp16", torch_dtype=torch.float16)
  pipe = pipe.to(device)
  torch.backends.cudnn.benchmark = True
@@ -299,4 +298,4 @@ Despite how impressive being able to turn text into image is, beware to the fact
  """
  )
 
- block.queue(max_size=25).launch()
+ block.queue(max_size=25).launch()
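
Not part of this diff — a minimal sketch of how app code like the one above could pass a progress callback to the new pipeline class. The `log_progress` helper, the example prompt, and the step interval are illustrative only; the callback signature `(step, timestep, latents)` matches the call made inside pipeline_with_callback.py.

```python
# Illustrative only, not included in this commit.
# `pipe` is the StableDiffusionPipelineWithCallback instance created in app.py.
def log_progress(step, timestep, latents):
    # latents has shape (batch, channels, height // 8, width // 8)
    print(f"denoising step {step}, latents shape {tuple(latents.shape)}")

output = pipe(
    "a photograph of an astronaut riding a horse",  # example prompt
    num_inference_steps=50,
    guidance_scale=7.5,
    callback=log_progress,
    callback_frequency=5,  # invoke the callback every 5 denoising steps
)
images = output.images
```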
pipeline_with_callback.py ADDED
@@ -0,0 +1,335 @@
+ import inspect
+ import warnings
+ from typing import Callable, List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
+ from diffusers.pipeline_utils import DiffusionPipeline
+ from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+
+
+ class StableDiffusionPipelineWithCallback(DiffusionPipeline):
+     r"""
+     Pipeline for text-to-image generation using Stable Diffusion.
+
+     ** based on https://github.com/huggingface/diffusers/pull/521/files#diff-ab952f41078da66b9fcbbd913b419f8c334badceefac03a5f7edcd6dd986a8ef **
+
+     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+     Args:
+         vae ([`AutoencoderKL`]):
+             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+         text_encoder ([`CLIPTextModel`]):
+             Frozen text-encoder. Stable Diffusion uses the text portion of
+             [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+             the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+         tokenizer (`CLIPTokenizer`):
+             Tokenizer of class
+             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+         unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+         scheduler ([`SchedulerMixin`]):
+             A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+             [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+         safety_checker ([`StableDiffusionSafetyChecker`]):
+             Classification module that estimates whether generated images could be considered offensive or harmful.
+             Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+         feature_extractor ([`CLIPFeatureExtractor`]):
+             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+     """
+
+     def __init__(
+         self,
+         vae: AutoencoderKL,
+         text_encoder: CLIPTextModel,
+         tokenizer: CLIPTokenizer,
+         unet: UNet2DConditionModel,
+         scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+         safety_checker: StableDiffusionSafetyChecker,
+         feature_extractor: CLIPFeatureExtractor,
+     ):
+         super().__init__()
+         scheduler = scheduler.set_format("pt")
+         self.register_modules(
+             vae=vae,
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+             unet=unet,
+             scheduler=scheduler,
+             safety_checker=safety_checker,
+             feature_extractor=feature_extractor,
+         )
+
+     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+         r"""
+         Enable sliced attention computation.
+
+         When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+         in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+         Args:
+             slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+                 When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+                 a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+                 `attention_head_dim` must be a multiple of `slice_size`.
+         """
+         if slice_size == "auto":
+             # half the attention head size is usually a good trade-off between
+             # speed and memory
+             slice_size = self.unet.config.attention_head_dim // 2
+         self.unet.set_attention_slice(slice_size)
+
+     def disable_attention_slicing(self):
+         r"""
+         Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will
+         go back to computing attention in one step.
+         """
+         # set slice_size = `None` to disable attention slicing
+         self.enable_attention_slicing(None)
+
+     @torch.no_grad()
+     def decode_latents(self, latents: torch.FloatTensor) -> np.ndarray:
+         r"""
+         Scale and decode the latent representations into images using the VAE.
+
+         Args:
+             latents (`torch.FloatTensor`):
+                 Latent representations to decode into images.
+
+         Returns:
+             `np.ndarray`: Decoded images.
+         """
+         latents = 1 / 0.18215 * latents
+         image = self.vae.decode(latents).sample
+
+         image = (image / 2 + 0.5).clamp(0, 1)
+         image = image.cpu().permute(0, 2, 3, 1).numpy()
+         return image
+
+     @torch.no_grad()
+     def run_safety_checker(self, image: np.ndarray) -> Tuple[np.ndarray, List[bool]]:
+         r"""
+         Run the safety checker on the generated images. If potential NSFW content was detected, a warning will be
+         raised and a black image will be returned instead.
+
+         Args:
+             image (`np.ndarray`):
+                 Images to run the safety checker on.
+
+         Returns:
+             `Tuple[np.ndarray, List[bool]]`: The first element contains the images that have been processed by the
+             safety checker. The second element is a list of booleans indicating whether each image contains NSFW
+             content.
+         """
+         safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
+         image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+         return image, has_nsfw_concept
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         prompt: Union[str, List[str]],
+         height: Optional[int] = 512,
+         width: Optional[int] = 512,
+         num_inference_steps: Optional[int] = 50,
+         guidance_scale: Optional[float] = 7.5,
+         eta: Optional[float] = 0.0,
+         generator: Optional[torch.Generator] = None,
+         latents: Optional[torch.FloatTensor] = None,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         callback: Optional[
+             Callable[[int, np.ndarray, torch.FloatTensor], None]
+         ] = None,
+         callback_frequency: Optional[int] = 1,
+         **kwargs,
+     ):
+         r"""
+         Function invoked when calling the pipeline for generation.
+
+         Args:
+             prompt (`str` or `List[str]`):
+                 The prompt or prompts to guide the image generation.
+             height (`int`, *optional*, defaults to 512):
+                 The height in pixels of the generated image.
+             width (`int`, *optional*, defaults to 512):
+                 The width in pixels of the generated image.
+             num_inference_steps (`int`, *optional*, defaults to 50):
+                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                 expense of slower inference.
+             guidance_scale (`float`, *optional*, defaults to 7.5):
+                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                 1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+                 usually at the expense of lower image quality.
+             eta (`float`, *optional*, defaults to 0.0):
+                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                 [`schedulers.DDIMScheduler`], will be ignored for others.
+             generator (`torch.Generator`, *optional*):
+                 A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                 deterministic.
+             latents (`torch.FloatTensor`, *optional*):
+                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                 generation. Can be used to tweak the same generation with different prompts. If not provided, a
+                 latents tensor will be generated by sampling using the supplied random `generator`.
+             output_type (`str`, *optional*, defaults to `"pil"`):
+                 The output format of the generated image. Choose between
+                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.ndarray`.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                 plain tuple.
+             callback (`Callable`, *optional*):
+                 A function that will be called every `callback_frequency` steps during inference. The function will be
+                 called with the following arguments: `callback(step: int, timestep: np.ndarray, latents:
+                 torch.FloatTensor)`.
+             callback_frequency (`int`, *optional*, defaults to 1):
+                 The frequency at which the `callback` function will be called. If not specified, the callback will be
+                 called at every step.
+
+         Returns:
+             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
+             `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
+             element is a list of `bool`s denoting whether the corresponding generated image likely represents
+             "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
+         """
+
+         if "torch_device" in kwargs:
+             device = kwargs.pop("torch_device")
+             warnings.warn(
+                 "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+                 " Consider using `pipe.to(torch_device)` instead."
+             )
+
+             # Set device as before (to be removed in 0.3.0)
+             if device is None:
+                 device = "cuda" if torch.cuda.is_available() else "cpu"
+             self.to(device)
+
+         if isinstance(prompt, str):
+             batch_size = 1
+         elif isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+         if height % 8 != 0 or width % 8 != 0:
+             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+         if (callback_frequency is None) or (
+             callback_frequency is not None and (not isinstance(callback_frequency, int) or callback_frequency <= 0)
+         ):
+             raise ValueError(
+                 f"`callback_frequency` has to be a positive integer but is {callback_frequency} of type"
+                 f" {type(callback_frequency)}."
+             )
+
+         # get prompt text embeddings
+         text_input = self.tokenizer(
+             prompt,
+             padding="max_length",
+             max_length=self.tokenizer.model_max_length,
+             truncation=True,
+             return_tensors="pt",
+         )
+         text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+
+         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+         # corresponds to doing no classifier free guidance.
+         do_classifier_free_guidance = guidance_scale > 1.0
+         # get unconditional embeddings for classifier free guidance
+         if do_classifier_free_guidance:
+             max_length = text_input.input_ids.shape[-1]
+             uncond_input = self.tokenizer(
+                 [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+             )
+             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+             # For classifier free guidance, we need to do two forward passes.
+             # Here we concatenate the unconditional and text embeddings into a single batch
+             # to avoid doing two forward passes
+             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+         # get the initial random noise unless the user supplied it
+
+         # Unlike in other pipelines, latents need to be generated in the target device
+         # for 1-to-1 results reproducibility with the CompVis implementation.
+         # However this currently doesn't work in `mps`.
+         latents_device = "cpu" if self.device.type == "mps" else self.device
+         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+         if latents is None:
+             latents = torch.randn(
+                 latents_shape,
+                 generator=generator,
+                 device=latents_device,
+             )
+         else:
+             if latents.shape != latents_shape:
+                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+         latents = latents.to(self.device)
+
+         # set timesteps
+         accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+         extra_set_kwargs = {}
+         if accepts_offset:
+             extra_set_kwargs["offset"] = 1
+
+         self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+
+         # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
+         if isinstance(self.scheduler, LMSDiscreteScheduler):
+             latents = latents * self.scheduler.sigmas[0]
+
+         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+         # and should be between [0, 1]
+         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+         extra_step_kwargs = {}
+         if accepts_eta:
+             extra_step_kwargs["eta"] = eta
+
+         for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+             # expand the latents if we are doing classifier free guidance
+             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+             if isinstance(self.scheduler, LMSDiscreteScheduler):
+                 sigma = self.scheduler.sigmas[i]
+                 # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
+                 latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+
+             # predict the noise residual
+             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+             # perform guidance
+             if do_classifier_free_guidance:
+                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+             # compute the previous noisy sample x_t -> x_t-1
+             if isinstance(self.scheduler, LMSDiscreteScheduler):
+                 latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
+             else:
+                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+             # call the callback, if provided
+             if callback is not None and i % callback_frequency == 0:
+                 callback(i, t, latents)
+
+         image = self.decode_latents(latents)
+
+         image, has_nsfw_concept = self.run_safety_checker(image)
+
+         if output_type == "pil":
+             image = self.numpy_to_pil(image)
+
+         if not return_dict:
+             return (image, has_nsfw_concept)
+
+         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
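
Usage note (not part of the commit): the callback can also be combined with the pipeline's own `decode_latents` and `numpy_to_pil` helpers to produce intermediate previews. The sketch below assumes a `pipe` instance created as in app.py; the example prompt and output file names are hypothetical.

```python
# Illustrative sketch, not included in this commit.
def save_preview(step, timestep, latents):
    # decode_latents is decorated with @torch.no_grad(); it scales the latents,
    # runs the VAE decoder, and returns an array of shape (batch, H, W, 3) in [0, 1].
    images = pipe.decode_latents(latents)
    for idx, pil_image in enumerate(pipe.numpy_to_pil(images)):
        pil_image.save(f"preview_step{step:03d}_{idx}.png")  # hypothetical filename

result = pipe(
    "a watercolor painting of a lighthouse",  # example prompt
    callback=save_preview,
    callback_frequency=10,  # decoding previews adds overhead, so do it sparingly
)
```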
requirements.txt CHANGED
@@ -1,6 +1,10 @@
- -e git+https://github.com/Narsil/diffusers.git@5ad7ed8d3eb2216b8f8c45d3d1fca929882441be#egg=diffusers
- transformers
- nvidia-ml-py3
+ -e git+https://github.com/huggingface/diffusers.git@429dace10a356a776f935fc11e16d5b321b496f3#egg=diffusers
+ datasets~=2.4.0
  ftfy
- --extra-index-url https://download.pytorch.org/whl/cu113
- torch
+ gradio~=3.3.1
+ numpy~=1.23.2
+ nvidia-ml-py3
+ Pillow~=9.2.0
+ transformers~=4.21.3
+ --extra-index-url https://download.pytorch.org/whl/cu113
+ torch~=1.12.1