r3gm committed
Commit 99f75c3
1 Parent(s): 732b9b3

Upload 4 files

Files changed (4):
  1. lcm/lcm_i2i_pipeline.py +805 -0
  2. lcm/lcm_pipeline.py +269 -0
  3. lcm/lcm_scheduler.py +498 -0
  4. scripts/main.py +613 -0
lcm/lcm_i2i_pipeline.py ADDED
@@ -0,0 +1,805 @@
1
+ # Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
16
+ # and https://github.com/hojonathanho/diffusion
17
+
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import Any, Dict, List, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
25
+
26
+ from diffusers import AutoencoderKL, ConfigMixin, DiffusionPipeline, SchedulerMixin, UNet2DConditionModel, logging
27
+ from diffusers.configuration_utils import register_to_config
28
+ from diffusers.image_processor import VaeImageProcessor, PipelineImageInput
29
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
30
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
31
+ from diffusers.utils import BaseOutput
32
+
33
+ from diffusers.utils.torch_utils import randn_tensor
34
+
35
+
36
+ import PIL.Image
37
+
38
+
39
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
+
41
+
42
+ class LatentConsistencyModelImg2ImgPipeline(DiffusionPipeline):
43
+ _optional_components = ["scheduler"]
44
+
45
+ def __init__(
46
+ self,
47
+ vae: AutoencoderKL,
48
+ text_encoder: CLIPTextModel,
49
+ tokenizer: CLIPTokenizer,
50
+ unet: UNet2DConditionModel,
51
+ scheduler: "LCMSchedulerWithTimestamp",
52
+ safety_checker: StableDiffusionSafetyChecker,
53
+ feature_extractor: CLIPImageProcessor,
54
+ requires_safety_checker: bool = False,
55
+ ):
56
+ super().__init__()
57
+
58
+ scheduler = (
59
+ scheduler
60
+ if scheduler is not None
61
+ else LCMSchedulerWithTimestamp(
62
+ beta_start=0.00085, beta_end=0.0120, beta_schedule="scaled_linear", prediction_type="epsilon"
63
+ )
64
+ )
65
+
66
+ self.register_modules(
67
+ vae=vae,
68
+ text_encoder=text_encoder,
69
+ tokenizer=tokenizer,
70
+ unet=unet,
71
+ scheduler=scheduler,
72
+ safety_checker=safety_checker,
73
+ feature_extractor=feature_extractor,
74
+ )
75
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
76
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
77
+
78
+ def _encode_prompt(
79
+ self,
80
+ prompt,
81
+ device,
82
+ num_images_per_prompt,
83
+ prompt_embeds: None,
84
+ ):
85
+ r"""
86
+ Encodes the prompt into text encoder hidden states.
87
+ Args:
88
+ prompt (`str` or `List[str]`, *optional*):
89
+ prompt to be encoded
90
+ device: (`torch.device`):
91
+ torch device
92
+ num_images_per_prompt (`int`):
93
+ number of images that should be generated per prompt
94
+ prompt_embeds (`torch.FloatTensor`, *optional*):
95
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
96
+ provided, text embeddings will be generated from `prompt` input argument.
97
+ """
98
+
99
+ if prompt is not None and isinstance(prompt, str):
100
+ batch_size = 1
101
+ elif prompt is not None and isinstance(prompt, list):
102
+ batch_size = len(prompt)
103
+ else:
104
+ batch_size = prompt_embeds.shape[0]
105
+
106
+ if prompt_embeds is None:
107
+ text_inputs = self.tokenizer(
108
+ prompt,
109
+ padding="max_length",
110
+ max_length=self.tokenizer.model_max_length,
111
+ truncation=True,
112
+ return_tensors="pt",
113
+ )
114
+ text_input_ids = text_inputs.input_ids
115
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
116
+
117
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
118
+ text_input_ids, untruncated_ids
119
+ ):
120
+ removed_text = self.tokenizer.batch_decode(
121
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
122
+ )
123
+ logger.warning(
124
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
125
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
126
+ )
127
+
128
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
129
+ attention_mask = text_inputs.attention_mask.to(device)
130
+ else:
131
+ attention_mask = None
132
+
133
+ prompt_embeds = self.text_encoder(
134
+ text_input_ids.to(device),
135
+ attention_mask=attention_mask,
136
+ )
137
+ prompt_embeds = prompt_embeds[0]
138
+
139
+ if self.text_encoder is not None:
140
+ prompt_embeds_dtype = self.text_encoder.dtype
141
+ elif self.unet is not None:
142
+ prompt_embeds_dtype = self.unet.dtype
143
+ else:
144
+ prompt_embeds_dtype = prompt_embeds.dtype
145
+
146
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
147
+
148
+ bs_embed, seq_len, _ = prompt_embeds.shape
149
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
150
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
151
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
152
+
153
+ # Don't need to get uncond prompt embedding because of LCM Guided Distillation
154
+ return prompt_embeds
155
+
156
+ # ¯\_(ツ)_/¯  safety checking is intentionally a no-op: the image is returned unchanged with no NSFW flags.
157
+ def run_safety_checker(self, image, device, dtype):
158
+ return image, None
159
+
160
+ def prepare_latents(self, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, latents=None, generator=None):
161
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
162
+
163
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
164
+ raise ValueError(
165
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
166
+ )
167
+
168
+ image = image.to(device=device, dtype=dtype)
169
+
170
+ # batch_size = batch_size * num_images_per_prompt
171
+
172
+ if image.shape[1] == 4:
173
+ init_latents = image
174
+
175
+ else:
176
+ if isinstance(generator, list) and len(generator) != batch_size:
177
+ raise ValueError(
178
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
179
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
180
+ )
181
+
182
+ elif isinstance(generator, list):
183
+ init_latents = [
184
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
185
+ ]
186
+ init_latents = torch.cat(init_latents, dim=0)
187
+ else:
188
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
189
+
190
+ init_latents = self.vae.config.scaling_factor * init_latents
191
+
192
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
193
+ # expand init_latents for batch_size
194
+ deprecation_message = (
195
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
196
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
197
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
198
+ " your script to pass as many initial images as text prompts to suppress this warning."
199
+ )
200
+ # deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
201
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
202
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
203
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
204
+ raise ValueError(
205
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
206
+ )
207
+ else:
208
+ init_latents = torch.cat([init_latents], dim=0)
209
+
210
+ shape = init_latents.shape
211
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
212
+
213
+ # get latents
214
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
215
+ latents = init_latents
216
+
217
+ return latents
218
+
226
+
227
+ def get_w_embedding(self, w, embedding_dim=512, dtype=torch.float32):
228
+ """
229
+ see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
230
+ Args:
231
+ w: torch.Tensor: guidance-scale values to embed, one per sample in the batch
232
+ embedding_dim: int: dimension of the embeddings to generate
233
+ dtype: data type of the generated embeddings
234
+ Returns:
235
+ embedding vectors with shape `(len(w), embedding_dim)`
236
+ """
237
+ assert len(w.shape) == 1
238
+ w = w * 1000.0
239
+
240
+ half_dim = embedding_dim // 2
241
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
242
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
243
+ emb = w.to(dtype)[:, None] * emb[None, :]
244
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
245
+ if embedding_dim % 2 == 1: # zero pad
246
+ emb = torch.nn.functional.pad(emb, (0, 1))
247
+ assert emb.shape == (w.shape[0], embedding_dim)
248
+ return emb
249
+
250
+ def get_timesteps(self, num_inference_steps, strength, device):
251
+ # get the original timestep using init_timestep
252
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
253
+
254
+ t_start = max(num_inference_steps - init_timestep, 0)
255
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
256
+
257
+ return timesteps, num_inference_steps - t_start
258
+
259
+ @torch.no_grad()
260
+ def __call__(
261
+ self,
262
+ prompt: Union[str, List[str]] = None,
263
+ image: PipelineImageInput = None,
264
+ strength: float = 0.8,
265
+ height: Optional[int] = 768,
266
+ width: Optional[int] = 768,
267
+ guidance_scale: float = 7.5,
268
+ num_images_per_prompt: Optional[int] = 1,
269
+ latents: Optional[torch.FloatTensor] = None,
270
+ num_inference_steps: int = 4,
271
+ original_inference_steps: int = 50,
272
+ prompt_embeds: Optional[torch.FloatTensor] = None,
273
+ output_type: Optional[str] = "pil",
274
+ return_dict: bool = True,
275
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
276
+ device: Optional[Union[str, torch.device]] = None,
277
+ ):
278
+ # 0. Default height and width to unet
279
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
280
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
281
+
282
+ # 2. Define call parameters
283
+ if prompt is not None and isinstance(prompt, str):
284
+ batch_size = 1
285
+ elif prompt is not None and isinstance(prompt, list):
286
+ batch_size = len(prompt)
287
+ else:
288
+ batch_size = prompt_embeds.shape[0]
289
+
290
+ device = device or self._execution_device
291
+ # do_classifier_free_guidance = guidance_scale > 0.0 # In LCM Implementation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond) , (cfg_scale > 0.0 using CFG)
292
+
293
+ # 3. Encode input prompt
294
+ prompt_embeds = self._encode_prompt(
295
+ prompt,
296
+ device,
297
+ num_images_per_prompt,
298
+ prompt_embeds=prompt_embeds,
299
+ )
300
+
301
+ # 3.5 encode image
302
+ image = self.image_processor.preprocess(image=image)
303
+
304
+ # 4. Prepare timesteps
305
+ self.scheduler.set_timesteps(strength, num_inference_steps, original_inference_steps)
306
+ # timesteps = self.scheduler.timesteps
307
+ # timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, 1.0, device)
308
+ timesteps = self.scheduler.timesteps
309
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
310
+
311
+ # 5. Prepare latent variable
312
+ num_channels_latents = self.unet.config.in_channels
313
+ latents = self.prepare_latents(
314
+ image,
315
+ latent_timestep,
316
+ batch_size * num_images_per_prompt,
317
+ num_channels_latents,
318
+ height,
319
+ width,
320
+ prompt_embeds.dtype,
321
+ device,
322
+ latents,
323
+ )
324
+ bs = batch_size * num_images_per_prompt
325
+
326
+ # 6. Get Guidance Scale Embedding
327
+ w = torch.tensor(guidance_scale).repeat(bs)
328
+ w_embedding = self.get_w_embedding(w, embedding_dim=256).to(device=device, dtype=latents.dtype)
329
+
330
+ # 7. LCM MultiStep Sampling Loop:
331
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
332
+ for i, t in enumerate(timesteps):
333
+ ts = torch.full((bs,), t, device=device, dtype=torch.long)
334
+ latents = latents.to(prompt_embeds.dtype)
335
+
336
+ # model prediction (v-prediction, eps, x)
337
+ model_pred = self.unet(
338
+ latents,
339
+ ts,
340
+ timestep_cond=w_embedding,
341
+ encoder_hidden_states=prompt_embeds,
342
+ cross_attention_kwargs=cross_attention_kwargs,
343
+ return_dict=False,
344
+ )[0]
345
+
346
+ # compute the previous noisy sample x_t -> x_t-1
347
+ latents, denoised = self.scheduler.step(model_pred, i, t, latents, return_dict=False)
348
+
349
+ # # call the callback, if provided
350
+ # if i == len(timesteps) - 1:
351
+ progress_bar.update()
352
+
353
+ denoised = denoised.to(prompt_embeds.dtype)
354
+ if not output_type == "latent":
355
+ image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
356
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
357
+ else:
358
+ image = denoised
359
+ has_nsfw_concept = None
360
+
361
+ if has_nsfw_concept is None:
362
+ do_denormalize = [True] * image.shape[0]
363
+ else:
364
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
365
+
366
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
367
+
368
+ if not return_dict:
369
+ return (image, has_nsfw_concept)
370
+
371
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
372
+
373
+
374
+ @dataclass
375
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
376
+ class LCMSchedulerOutput(BaseOutput):
377
+ """
378
+ Output class for the scheduler's `step` function output.
379
+ Args:
380
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
381
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
382
+ denoising loop.
383
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
384
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
385
+ `pred_original_sample` can be used to preview progress or for guidance.
386
+ """
387
+
388
+ prev_sample: torch.FloatTensor
389
+ denoised: Optional[torch.FloatTensor] = None
390
+
391
+
392
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
393
+ def betas_for_alpha_bar(
394
+ num_diffusion_timesteps,
395
+ max_beta=0.999,
396
+ alpha_transform_type="cosine",
397
+ ):
398
+ """
399
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
400
+ (1-beta) over time from t = [0,1].
401
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
402
+ to that part of the diffusion process.
403
+ Args:
404
+ num_diffusion_timesteps (`int`): the number of betas to produce.
405
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
406
+ prevent singularities.
407
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
408
+ Choose from `cosine` or `exp`
409
+ Returns:
410
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
411
+ """
412
+ if alpha_transform_type == "cosine":
413
+
414
+ def alpha_bar_fn(t):
415
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
416
+
417
+ elif alpha_transform_type == "exp":
418
+
419
+ def alpha_bar_fn(t):
420
+ return math.exp(t * -12.0)
421
+
422
+ else:
423
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
424
+
425
+ betas = []
426
+ for i in range(num_diffusion_timesteps):
427
+ t1 = i / num_diffusion_timesteps
428
+ t2 = (i + 1) / num_diffusion_timesteps
429
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
430
+ return torch.tensor(betas, dtype=torch.float32)
431
+
432
+
433
+ def rescale_zero_terminal_snr(betas):
434
+ """
435
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
436
+ Args:
437
+ betas (`torch.FloatTensor`):
438
+ the betas that the scheduler is being initialized with.
439
+ Returns:
440
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
441
+ """
442
+ # Convert betas to alphas_bar_sqrt
443
+ alphas = 1.0 - betas
444
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
445
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
446
+
447
+ # Store old values.
448
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
449
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
450
+
451
+ # Shift so the last timestep is zero.
452
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
453
+
454
+ # Scale so the first timestep is back to the old value.
455
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
456
+
457
+ # Convert alphas_bar_sqrt to betas
458
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
459
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
460
+ alphas = torch.cat([alphas_bar[0:1], alphas])
461
+ betas = 1 - alphas
462
+
463
+ return betas
464
+
465
+
466
+ class LCMSchedulerWithTimestamp(SchedulerMixin, ConfigMixin):
467
+ """
468
+ This class modifies LCMScheduler to add a timestamp argument to set_timesteps
469
+
470
+
471
+ `LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
472
+ non-Markovian guidance.
473
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
474
+ methods the library implements for all schedulers such as loading and saving.
475
+ Args:
476
+ num_train_timesteps (`int`, defaults to 1000):
477
+ The number of diffusion steps to train the model.
478
+ beta_start (`float`, defaults to 0.0001):
479
+ The starting `beta` value of inference.
480
+ beta_end (`float`, defaults to 0.02):
481
+ The final `beta` value.
482
+ beta_schedule (`str`, defaults to `"linear"`):
483
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
484
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
485
+ trained_betas (`np.ndarray`, *optional*):
486
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
487
+ clip_sample (`bool`, defaults to `True`):
488
+ Clip the predicted sample for numerical stability.
489
+ clip_sample_range (`float`, defaults to 1.0):
490
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
491
+ set_alpha_to_one (`bool`, defaults to `True`):
492
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
493
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
494
+ otherwise it uses the alpha value at step 0.
495
+ steps_offset (`int`, defaults to 0):
496
+ An offset added to the inference steps. You can use a combination of `offset=1` and
497
+ `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
498
+ Diffusion.
499
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
500
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
501
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
502
+ Video](https://imagen.research.google/video/paper.pdf) paper).
503
+ thresholding (`bool`, defaults to `False`):
504
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
505
+ as Stable Diffusion.
506
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
507
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
508
+ sample_max_value (`float`, defaults to 1.0):
509
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
510
+ timestep_spacing (`str`, defaults to `"leading"`):
511
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
512
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
513
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
514
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
515
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
516
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
517
+ """
518
+
519
+ # _compatibles = [e.name for e in KarrasDiffusionSchedulers]
520
+ order = 1
521
+
522
+ @register_to_config
523
+ def __init__(
524
+ self,
525
+ num_train_timesteps: int = 1000,
526
+ beta_start: float = 0.0001,
527
+ beta_end: float = 0.02,
528
+ beta_schedule: str = "linear",
529
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
530
+ clip_sample: bool = True,
531
+ set_alpha_to_one: bool = True,
532
+ steps_offset: int = 0,
533
+ prediction_type: str = "epsilon",
534
+ thresholding: bool = False,
535
+ dynamic_thresholding_ratio: float = 0.995,
536
+ clip_sample_range: float = 1.0,
537
+ sample_max_value: float = 1.0,
538
+ timestep_spacing: str = "leading",
539
+ rescale_betas_zero_snr: bool = False,
540
+ ):
541
+ if trained_betas is not None:
542
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
543
+ elif beta_schedule == "linear":
544
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
545
+ elif beta_schedule == "scaled_linear":
546
+ # this schedule is very specific to the latent diffusion model.
547
+ self.betas = (
548
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
549
+ )
550
+ elif beta_schedule == "squaredcos_cap_v2":
551
+ # Glide cosine schedule
552
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
553
+ else:
554
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
555
+
556
+ # Rescale for zero SNR
557
+ if rescale_betas_zero_snr:
558
+ self.betas = rescale_zero_terminal_snr(self.betas)
559
+
560
+ self.alphas = 1.0 - self.betas
561
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
562
+
563
+ # At every step in ddim, we are looking into the previous alphas_cumprod
564
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
565
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
566
+ # whether we use the final alpha of the "non-previous" one.
567
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
568
+
569
+ # standard deviation of the initial noise distribution
570
+ self.init_noise_sigma = 1.0
571
+
572
+ # setable values
573
+ self.num_inference_steps = None
574
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
575
+
576
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
577
+ """
578
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
579
+ current timestep.
580
+ Args:
581
+ sample (`torch.FloatTensor`):
582
+ The input sample.
583
+ timestep (`int`, *optional*):
584
+ The current timestep in the diffusion chain.
585
+ Returns:
586
+ `torch.FloatTensor`:
587
+ A scaled input sample.
588
+ """
589
+ return sample
590
+
591
+ def _get_variance(self, timestep, prev_timestep):
592
+ alpha_prod_t = self.alphas_cumprod[timestep]
593
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
594
+ beta_prod_t = 1 - alpha_prod_t
595
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
596
+
597
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
598
+
599
+ return variance
600
+
601
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
602
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
603
+ """
604
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
605
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
606
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
607
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
608
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
609
+ https://arxiv.org/abs/2205.11487
610
+ """
611
+ dtype = sample.dtype
612
+ batch_size, channels, height, width = sample.shape
613
+
614
+ if dtype not in (torch.float32, torch.float64):
615
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
616
+
617
+ # Flatten sample for doing quantile calculation along each image
618
+ sample = sample.reshape(batch_size, channels * height * width)
619
+
620
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
621
+
622
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
623
+ s = torch.clamp(
624
+ s, min=1, max=self.config.sample_max_value
625
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
626
+
627
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
628
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
629
+
630
+ sample = sample.reshape(batch_size, channels, height, width)
631
+ sample = sample.to(dtype)
632
+
633
+ return sample
634
+
635
+ def set_timesteps(self, strength, num_inference_steps: int, original_inference_steps: int, device: Union[str, torch.device] = None):
636
+ """
637
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
638
+ Args:
639
+ num_inference_steps (`int`):
640
+ The number of diffusion steps used when generating samples with a pre-trained model.
641
+ strength (`float`):
+ Img2img strength; only the first `int(original_inference_steps * strength)` timesteps of the LCM training schedule are kept for sampling.
+ original_inference_steps (`int`):
+ The number of timesteps in the original LCM training/distillation schedule.
+ """
642
+
643
+ if num_inference_steps > self.config.num_train_timesteps:
644
+ raise ValueError(
645
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
646
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
647
+ f" maximal {self.config.num_train_timesteps} timesteps."
648
+ )
649
+
650
+ self.num_inference_steps = num_inference_steps
651
+
652
+ # LCM Timesteps Setting: # Linear Spacing
653
+ c = self.config.num_train_timesteps // original_inference_steps
654
+ lcm_origin_timesteps = np.asarray(list(range(1, int(original_inference_steps * strength) + 1))) * c - 1 # LCM Training Steps Schedule
655
+ skipping_step = len(lcm_origin_timesteps) // num_inference_steps
656
+ timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps] # LCM Inference Steps Schedule
657
+
658
+ self.timesteps = torch.from_numpy(timesteps.copy()).to(device)
659
+
660
+ def get_scalings_for_boundary_condition_discrete(self, t):
661
+ self.sigma_data = 0.5 # Default: 0.5
662
+
663
+ # By dividing 0.1: This is almost a delta function at t=0.
664
+ c_skip = self.sigma_data**2 / ((t / 0.1) ** 2 + self.sigma_data**2)
665
+ c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5
666
+ return c_skip, c_out
667
+
668
+ def step(
669
+ self,
670
+ model_output: torch.FloatTensor,
671
+ timeindex: int,
672
+ timestep: int,
673
+ sample: torch.FloatTensor,
674
+ eta: float = 0.0,
675
+ use_clipped_model_output: bool = False,
676
+ generator=None,
677
+ variance_noise: Optional[torch.FloatTensor] = None,
678
+ return_dict: bool = True,
679
+ ) -> Union[LCMSchedulerOutput, Tuple]:
680
+ """
681
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
682
+ process from the learned model outputs (most often the predicted noise).
683
+ Args:
684
+ model_output (`torch.FloatTensor`):
685
+ The direct output from learned diffusion model.
686
+ timestep (`float`):
687
+ The current discrete timestep in the diffusion chain.
688
+ sample (`torch.FloatTensor`):
689
+ A current instance of a sample created by the diffusion process.
690
+ eta (`float`):
691
+ The weight of noise for added noise in diffusion step.
692
+ use_clipped_model_output (`bool`, defaults to `False`):
693
+ If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
694
+ because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
695
+ clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
696
+ `use_clipped_model_output` has no effect.
697
+ generator (`torch.Generator`, *optional*):
698
+ A random number generator.
699
+ variance_noise (`torch.FloatTensor`):
700
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
701
+ itself. Useful for methods such as [`CycleDiffusion`].
702
+ return_dict (`bool`, *optional*, defaults to `True`):
703
+ Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
704
+ Returns:
705
+ [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
706
+ If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
707
+ tuple is returned where the first element is the sample tensor.
708
+ """
709
+ if self.num_inference_steps is None:
710
+ raise ValueError(
711
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
712
+ )
713
+
714
+ # 1. get previous step value
715
+ prev_timeindex = timeindex + 1
716
+ if prev_timeindex < len(self.timesteps):
717
+ prev_timestep = self.timesteps[prev_timeindex]
718
+ else:
719
+ prev_timestep = timestep
720
+
721
+ # 2. compute alphas, betas
722
+ alpha_prod_t = self.alphas_cumprod[timestep]
723
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
724
+
725
+ beta_prod_t = 1 - alpha_prod_t
726
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
727
+
728
+ # 3. Get scalings for boundary conditions
729
+ c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
730
+
731
+ # 4. Different Parameterization:
732
+ parameterization = self.config.prediction_type
733
+
734
+ if parameterization == "epsilon": # noise-prediction
735
+ pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
736
+
737
+ elif parameterization == "sample": # x-prediction
738
+ pred_x0 = model_output
739
+
740
+ elif parameterization == "v_prediction": # v-prediction
741
+ pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
742
+
743
+ # 4. Denoise model output using boundary conditions
744
+ denoised = c_out * pred_x0 + c_skip * sample
745
+
746
+ # 5. Sample z ~ N(0, I), For MultiStep Inference
747
+ # Noise is not used for one-step sampling.
748
+ if len(self.timesteps) > 1:
749
+ noise = torch.randn(model_output.shape).to(model_output.device)
750
+ prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
751
+ else:
752
+ prev_sample = denoised
753
+
754
+ if not return_dict:
755
+ return (prev_sample, denoised)
756
+
757
+ return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
758
+
759
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
760
+ def add_noise(
761
+ self,
762
+ original_samples: torch.FloatTensor,
763
+ noise: torch.FloatTensor,
764
+ timesteps: torch.IntTensor,
765
+ ) -> torch.FloatTensor:
766
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
767
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
768
+ timesteps = timesteps.to(original_samples.device)
769
+
770
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
771
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
772
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
773
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
774
+
775
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
776
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
777
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
778
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
779
+
780
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
781
+ return noisy_samples
782
+
783
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
784
+ def get_velocity(
785
+ self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
786
+ ) -> torch.FloatTensor:
787
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
788
+ alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
789
+ timesteps = timesteps.to(sample.device)
790
+
791
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
792
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
793
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
794
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
795
+
796
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
797
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
798
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
799
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
800
+
801
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
802
+ return velocity
803
+
804
+ def __len__(self):
805
+ return self.config.num_train_timesteps
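
For readers trying the file out, a minimal usage sketch of LatentConsistencyModelImg2ImgPipeline follows. It is illustrative only and not part of the uploaded file: the checkpoint id (SimianLuo/LCM_Dreamshaper_v7), the input image path, and the dtype/device handling are assumptions, and passing scheduler=None simply lets __init__ fall back to the bundled LCMSchedulerWithTimestamp.

import torch
from PIL import Image

from lcm.lcm_i2i_pipeline import LatentConsistencyModelImg2ImgPipeline

# Assumed LCM-compatible checkpoint; scheduler=None makes __init__ build the
# default LCMSchedulerWithTimestamp defined in this file.
pipe = LatentConsistencyModelImg2ImgPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7",
    scheduler=None,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe.to(device)

init_image = Image.open("input.png").convert("RGB").resize((768, 768))  # hypothetical input path

result = pipe(
    prompt="a watercolor painting of a lighthouse at dusk",
    image=init_image,
    strength=0.6,            # fraction of the LCM schedule used to re-noise the init image
    num_inference_steps=4,   # LCM typically needs only a few steps
    guidance_scale=8.0,      # folded into the UNet via get_w_embedding; no extra CFG pass
    device=device,           # __call__ takes the device explicitly
)
result.images[0].save("lcm_img2img.png")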
lcm/lcm_pipeline.py ADDED
@@ -0,0 +1,269 @@
1
+ import torch
2
+ from diffusers import DiffusionPipeline, AutoencoderKL, UNet2DConditionModel
3
+ from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
4
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
5
+ from diffusers.image_processor import VaeImageProcessor
6
+ # import modules.shared
7
+ from typing import List, Optional, Union, Dict, Any
8
+
9
+ from diffusers import logging
10
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
11
+
12
+
13
+ class LatentConsistencyModelPipeline(DiffusionPipeline):
14
+ def __init__(
15
+ self,
16
+ vae: AutoencoderKL,
17
+ text_encoder: CLIPTextModel,
18
+ tokenizer: CLIPTokenizer,
19
+ unet: UNet2DConditionModel,
20
+ scheduler: None,
21
+ safety_checker: None,
22
+ feature_extractor: CLIPImageProcessor
23
+ ):
24
+ super().__init__()
25
+
26
+ self.register_modules(
27
+ vae=vae,
28
+ text_encoder=text_encoder,
29
+ tokenizer=tokenizer,
30
+ unet=unet,
31
+ scheduler=scheduler,
32
+ safety_checker=safety_checker,
33
+ feature_extractor=feature_extractor,
34
+ )
35
+ self.vae_scale_factor = 2 ** (
36
+ len(self.vae.config.block_out_channels) - 1)
37
+ self.image_processor = VaeImageProcessor(
38
+ vae_scale_factor=self.vae_scale_factor)
39
+
40
+ def _encode_prompt(
41
+ self,
42
+ prompt,
43
+ device,
44
+ num_images_per_prompt,
45
+ prompt_embeds: None,
46
+ ):
47
+ r"""
48
+ Encodes the prompt into text encoder hidden states.
49
+
50
+ Args:
51
+ prompt (`str` or `List[str]`, *optional*):
52
+ prompt to be encoded
53
+ device: (`torch.device`):
54
+ torch device
55
+ num_images_per_prompt (`int`):
56
+ number of images that should be generated per prompt
57
+ prompt_embeds (`torch.FloatTensor`, *optional*):
58
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
59
+ provided, text embeddings will be generated from `prompt` input argument.
60
+ """
61
+
62
+ if prompt is not None and isinstance(prompt, str):
63
+ batch_size = 1
64
+ elif prompt is not None and isinstance(prompt, list):
65
+ batch_size = len(prompt)
66
+ else:
67
+ batch_size = prompt_embeds.shape[0]
68
+
69
+ if prompt_embeds is None:
70
+
71
+ text_inputs = self.tokenizer(
72
+ prompt,
73
+ padding="max_length",
74
+ max_length=self.tokenizer.model_max_length,
75
+ truncation=True,
76
+ return_tensors="pt",
77
+ )
78
+ text_input_ids = text_inputs.input_ids
79
+ untruncated_ids = self.tokenizer(
80
+ prompt, padding="longest", return_tensors="pt").input_ids
81
+
82
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
83
+ text_input_ids, untruncated_ids
84
+ ):
85
+ removed_text = self.tokenizer.batch_decode(
86
+ untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
87
+ )
88
+ logger.warning(
89
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
90
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
91
+ )
92
+
93
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
94
+ attention_mask = text_inputs.attention_mask.to(device)
95
+ else:
96
+ attention_mask = None
97
+
98
+ prompt_embeds = self.text_encoder(
99
+ text_input_ids.to(device),
100
+ attention_mask=attention_mask,
101
+ )
102
+ prompt_embeds = prompt_embeds[0]
103
+
104
+ if self.text_encoder is not None:
105
+ prompt_embeds_dtype = self.text_encoder.dtype
106
+ elif self.unet is not None:
107
+ prompt_embeds_dtype = self.unet.dtype
108
+ else:
109
+ prompt_embeds_dtype = prompt_embeds.dtype
110
+
111
+ prompt_embeds = prompt_embeds.to(
112
+ dtype=prompt_embeds_dtype, device=device)
113
+
114
+ bs_embed, seq_len, _ = prompt_embeds.shape
115
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
116
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
117
+ prompt_embeds = prompt_embeds.view(
118
+ bs_embed * num_images_per_prompt, seq_len, -1)
119
+
120
+ # Don't need to get uncond prompt embedding because of LCM Guided Distillation
121
+ return prompt_embeds
122
+
123
+ # ¯\_(ツ)_/¯  safety checking is intentionally a no-op: the image is returned unchanged with no NSFW flags.
124
+ def run_safety_checker(self, image, device, dtype):
125
+ return image, None
126
+
127
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, latents=None):
128
+ shape = (batch_size, num_channels_latents, height //
129
+ self.vae_scale_factor, width // self.vae_scale_factor)
130
+ if latents is None:
131
+ latents = torch.randn(shape, dtype=dtype).to(device)
132
+ else:
133
+ latents = latents.to(device)
134
+ # scale the initial noise by the standard deviation required by the scheduler
135
+ latents = latents * self.scheduler.init_noise_sigma
136
+ return latents
137
+
138
+ def get_w_embedding(self, w, embedding_dim=512, dtype=torch.float32):
139
+ """
140
+ see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
141
+ Args:
142
+ w: torch.Tensor: guidance-scale values to embed, one per sample in the batch
143
+ embedding_dim: int: dimension of the embeddings to generate
144
+ dtype: data type of the generated embeddings
145
+
146
+ Returns:
147
+ embedding vectors with shape `(len(w), embedding_dim)`
148
+ """
149
+ assert len(w.shape) == 1
150
+ w = w * 1000.
151
+
152
+ half_dim = embedding_dim // 2
153
+ emb = torch.log(torch.tensor(10000.)) / (half_dim - 1)
154
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
155
+ emb = w.to(dtype)[:, None] * emb[None, :]
156
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
157
+ if embedding_dim % 2 == 1: # zero pad
158
+ emb = torch.nn.functional.pad(emb, (0, 1))
159
+ assert emb.shape == (w.shape[0], embedding_dim)
160
+ return emb
161
+
162
+ @torch.no_grad()
163
+ def __call__(
164
+ self,
165
+ prompt: Union[str, List[str]] = None,
166
+ height: Optional[int] = 768,
167
+ width: Optional[int] = 768,
168
+ guidance_scale: float = 7.5,
169
+ num_images_per_prompt: Optional[int] = 1,
170
+ latents: Optional[torch.FloatTensor] = None,
171
+ num_inference_steps: int = 4,
172
+ original_inference_steps: int = 50,
173
+ prompt_embeds: Optional[torch.FloatTensor] = None,
174
+ output_type: Optional[str] = "pil",
175
+ return_dict: bool = True,
176
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
177
+ device: Optional[Union[str, torch.device]] = None,
178
+ ):
179
+
180
+ # 0. Default height and width to unet
181
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
182
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
183
+
184
+ # 2. Define call parameters
185
+ if prompt is not None and isinstance(prompt, str):
186
+ batch_size = 1
187
+ elif prompt is not None and isinstance(prompt, list):
188
+ batch_size = len(prompt)
189
+ else:
190
+ batch_size = prompt_embeds.shape[0]
191
+
192
+ # do_classifier_free_guidance = guidance_scale > 0.0 # In LCM Implementation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond) , (cfg_scale > 0.0 using CFG)
193
+
194
+ # 3. Encode input prompt
195
+ prompt_embeds = self._encode_prompt(
196
+ prompt,
197
+ device,
198
+ num_images_per_prompt,
199
+ prompt_embeds=prompt_embeds,
200
+ )
201
+
202
+ # 4. Prepare timesteps
203
+ self.scheduler.set_timesteps(num_inference_steps, original_inference_steps)
204
+ timesteps = self.scheduler.timesteps
205
+
206
+ # 5. Prepare latent variable
207
+ num_channels_latents = self.unet.config.in_channels
208
+ latents = self.prepare_latents(
209
+ batch_size * num_images_per_prompt,
210
+ num_channels_latents,
211
+ height,
212
+ width,
213
+ prompt_embeds.dtype,
214
+ device,
215
+ latents,
216
+ )
217
+ bs = batch_size * num_images_per_prompt
218
+
219
+ # 6. Get Guidance Scale Embedding
220
+ w = torch.tensor(guidance_scale).repeat(bs)
221
+ w_embedding = self.get_w_embedding(w, embedding_dim=256).to(
222
+ device=device, dtype=latents.dtype)
223
+
224
+ # 7. LCM MultiStep Sampling Loop:
225
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
226
+ for i, t in enumerate(timesteps):
227
+
228
+ ts = torch.full((bs,), t, device=device, dtype=torch.long)
229
+ latents = latents.to(prompt_embeds.dtype)
230
+
231
+ # model prediction (v-prediction, eps, x)
232
+ model_pred = self.unet(
233
+ latents,
234
+ ts,
235
+ timestep_cond=w_embedding,
236
+ encoder_hidden_states=prompt_embeds,
237
+ cross_attention_kwargs=cross_attention_kwargs,
238
+ return_dict=False)[0]
239
+
240
+ # compute the previous noisy sample x_t -> x_t-1
241
+ latents, denoised = self.scheduler.step(
242
+ model_pred, i, t, latents, return_dict=False)
243
+
244
+ # # call the callback, if provided
245
+ # if i == len(timesteps) - 1:
246
+ progress_bar.update()
247
+
248
+ denoised = denoised.to(prompt_embeds.dtype)
249
+ if not output_type == "latent":
250
+ image = self.vae.decode(
251
+ denoised / self.vae.config.scaling_factor, return_dict=False)[0]
252
+ image, has_nsfw_concept = self.run_safety_checker(
253
+ image, device, prompt_embeds.dtype)
254
+ else:
255
+ image = denoised
256
+ has_nsfw_concept = None
257
+
258
+ if has_nsfw_concept is None:
259
+ do_denormalize = [True] * image.shape[0]
260
+ else:
261
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
262
+
263
+ image = self.image_processor.postprocess(
264
+ image, output_type=output_type, do_denormalize=do_denormalize)
265
+
266
+ if not return_dict:
267
+ return (image, has_nsfw_concept)
268
+
269
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
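
Similarly, a minimal text-to-image sketch for LatentConsistencyModelPipeline, pairing it with the LCMScheduler shipped in lcm/lcm_scheduler.py. Again this is illustrative rather than part of the upload: the checkpoint id and device handling are assumptions, and the scheduler is constructed explicitly because this pipeline's __init__ has no fallback.

import torch

from lcm.lcm_pipeline import LatentConsistencyModelPipeline
from lcm.lcm_scheduler import LCMScheduler

# Explicit scheduler from lcm/lcm_scheduler.py; the betas mirror the
# scaled-linear schedule used by the img2img variant's fallback scheduler.
scheduler = LCMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    prediction_type="epsilon",
)

pipe = LatentConsistencyModelPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7",   # assumed LCM-compatible checkpoint
    scheduler=scheduler,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe.to(device)

images = pipe(
    prompt="a cozy cabin in a snowy forest, cinematic lighting",
    width=768,
    height=768,
    num_inference_steps=4,
    guidance_scale=8.0,
    device=device,
).images
images[0].save("lcm_txt2img.png")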
lcm/lcm_scheduler.py ADDED
@@ -0,0 +1,498 @@
1
+ # Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
16
+ # and https://github.com/hojonathanho/diffusion
17
+
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+
25
+ from diffusers import ConfigMixin, SchedulerMixin
26
+ from diffusers.configuration_utils import register_to_config
27
+ from diffusers.utils import BaseOutput
28
+
29
+
30
+ @dataclass
31
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
32
+ class LCMSchedulerOutput(BaseOutput):
33
+ """
34
+ Output class for the scheduler's `step` function output.
35
+
36
+ Args:
37
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
38
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
39
+ denoising loop.
40
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
41
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
42
+ `pred_original_sample` can be used to preview progress or for guidance.
43
+ """
44
+
45
+ prev_sample: torch.FloatTensor
46
+ denoised: Optional[torch.FloatTensor] = None
47
+
48
+
49
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
50
+ def betas_for_alpha_bar(
51
+ num_diffusion_timesteps,
52
+ max_beta=0.999,
53
+ alpha_transform_type="cosine",
54
+ ):
55
+ """
56
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
57
+ (1-beta) over time from t = [0,1].
58
+
59
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
60
+ to that part of the diffusion process.
61
+
62
+
63
+ Args:
64
+ num_diffusion_timesteps (`int`): the number of betas to produce.
65
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
66
+ prevent singularities.
67
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
68
+ Choose from `cosine` or `exp`
69
+
70
+ Returns:
71
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
72
+ """
73
+ if alpha_transform_type == "cosine":
74
+
75
+ def alpha_bar_fn(t):
76
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
77
+
78
+ elif alpha_transform_type == "exp":
79
+
80
+ def alpha_bar_fn(t):
81
+ return math.exp(t * -12.0)
82
+
83
+ else:
84
+ raise ValueError(
85
+ f"Unsupported alpha_transform_type: {alpha_transform_type}")
86
+
87
+ betas = []
88
+ for i in range(num_diffusion_timesteps):
89
+ t1 = i / num_diffusion_timesteps
90
+ t2 = (i + 1) / num_diffusion_timesteps
91
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
92
+ return torch.tensor(betas, dtype=torch.float32)
93
+
94
+
95
+ def rescale_zero_terminal_snr(betas):
96
+ """
97
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
98
+
99
+
100
+ Args:
101
+ betas (`torch.FloatTensor`):
102
+ the betas that the scheduler is being initialized with.
103
+
104
+ Returns:
105
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
106
+ """
107
+ # Convert betas to alphas_bar_sqrt
108
+ alphas = 1.0 - betas
109
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
110
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
111
+
112
+ # Store old values.
113
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
114
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
115
+
116
+ # Shift so the last timestep is zero.
117
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
118
+
119
+ # Scale so the first timestep is back to the old value.
120
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / \
121
+ (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
122
+
123
+ # Convert alphas_bar_sqrt to betas
124
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
125
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
126
+ alphas = torch.cat([alphas_bar[0:1], alphas])
127
+ betas = 1 - alphas
128
+
129
+ return betas
130
+
131
+
132
+ class LCMScheduler(SchedulerMixin, ConfigMixin):
133
+ """
134
+ `LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
135
+ non-Markovian guidance.
136
+
137
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
138
+ methods the library implements for all schedulers such as loading and saving.
139
+
140
+ Args:
141
+ num_train_timesteps (`int`, defaults to 1000):
142
+ The number of diffusion steps to train the model.
143
+ beta_start (`float`, defaults to 0.0001):
144
+ The starting `beta` value of inference.
145
+ beta_end (`float`, defaults to 0.02):
146
+ The final `beta` value.
147
+ beta_schedule (`str`, defaults to `"linear"`):
148
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
149
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
150
+ trained_betas (`np.ndarray`, *optional*):
151
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
152
+ clip_sample (`bool`, defaults to `True`):
153
+ Clip the predicted sample for numerical stability.
154
+ clip_sample_range (`float`, defaults to 1.0):
155
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
156
+ set_alpha_to_one (`bool`, defaults to `True`):
157
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
158
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
159
+ otherwise it uses the alpha value at step 0.
160
+ steps_offset (`int`, defaults to 0):
161
+ An offset added to the inference steps. You can use a combination of `offset=1` and
162
+ `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
163
+ Diffusion.
164
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
165
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
166
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
167
+ Video](https://imagen.research.google/video/paper.pdf) paper).
168
+ thresholding (`bool`, defaults to `False`):
169
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
170
+ as Stable Diffusion.
171
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
172
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
173
+ sample_max_value (`float`, defaults to 1.0):
174
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
175
+ timestep_spacing (`str`, defaults to `"leading"`):
176
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
177
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
178
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
179
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
180
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
181
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
182
+ """
183
+
184
+ # _compatibles = [e.name for e in KarrasDiffusionSchedulers]
185
+ order = 1
186
+
187
+ @register_to_config
188
+ def __init__(
189
+ self,
190
+ num_train_timesteps: int = 1000,
191
+ beta_start: float = 0.0001,
192
+ beta_end: float = 0.02,
193
+ beta_schedule: str = "linear",
194
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
195
+ clip_sample: bool = True,
196
+ set_alpha_to_one: bool = True,
197
+ steps_offset: int = 0,
198
+ prediction_type: str = "epsilon",
199
+ thresholding: bool = False,
200
+ dynamic_thresholding_ratio: float = 0.995,
201
+ clip_sample_range: float = 1.0,
202
+ sample_max_value: float = 1.0,
203
+ timestep_spacing: str = "leading",
204
+ rescale_betas_zero_snr: bool = False,
205
+ ):
206
+ if trained_betas is not None:
207
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
208
+ elif beta_schedule == "linear":
209
+ self.betas = torch.linspace(
210
+ beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
211
+ elif beta_schedule == "scaled_linear":
212
+ # this schedule is very specific to the latent diffusion model.
213
+ self.betas = (
214
+ torch.linspace(beta_start**0.5, beta_end**0.5,
215
+ num_train_timesteps, dtype=torch.float32) ** 2
216
+ )
217
+ elif beta_schedule == "squaredcos_cap_v2":
218
+ # Glide cosine schedule
219
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
220
+ else:
221
+ raise NotImplementedError(
222
+ f"{beta_schedule} does is not implemented for {self.__class__}")
223
+
224
+ # Rescale for zero SNR
225
+ if rescale_betas_zero_snr:
226
+ self.betas = rescale_zero_terminal_snr(self.betas)
227
+
228
+ self.alphas = 1.0 - self.betas
229
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
230
+
231
+ # At every step in ddim, we are looking into the previous alphas_cumprod
232
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
233
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
234
+ # whether we use the final alpha of the "non-previous" one.
235
+ self.final_alpha_cumprod = torch.tensor(
236
+ 1.0) if set_alpha_to_one else self.alphas_cumprod[0]
237
+
238
+ # standard deviation of the initial noise distribution
239
+ self.init_noise_sigma = 1.0
240
+
241
+ # setable values
242
+ self.num_inference_steps = None
243
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
245
+
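+ # Editor's note (illustrative, not part of the upstream file): with the defaults above
+ # (num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02, beta_schedule="linear"),
+ # the schedule reduces to
+ #     betas          = torch.linspace(1e-4, 0.02, 1000)
+ #     alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
+ # and `step`/`add_noise` below index into `alphas_cumprod` by timestep.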
246
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
247
+ """
248
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
249
+ current timestep.
250
+
251
+ Args:
252
+ sample (`torch.FloatTensor`):
253
+ The input sample.
254
+ timestep (`int`, *optional*):
255
+ The current timestep in the diffusion chain.
256
+
257
+ Returns:
258
+ `torch.FloatTensor`:
259
+ A scaled input sample.
260
+ """
261
+ return sample
262
+
263
+ def _get_variance(self, timestep, prev_timestep):
264
+ alpha_prod_t = self.alphas_cumprod[timestep]
265
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
266
+ beta_prod_t = 1 - alpha_prod_t
267
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
268
+
269
+ variance = (beta_prod_t_prev / beta_prod_t) * \
270
+ (1 - alpha_prod_t / alpha_prod_t_prev)
271
+
272
+ return variance
273
+
274
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
275
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
276
+ """
277
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
278
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
279
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
280
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
281
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
282
+
283
+ https://arxiv.org/abs/2205.11487
284
+ """
285
+ dtype = sample.dtype
286
+ batch_size, channels, height, width = sample.shape
287
+
288
+ if dtype not in (torch.float32, torch.float64):
289
+ # upcast for quantile calculation, and clamp not implemented for cpu half
290
+ sample = sample.float()
291
+
292
+ # Flatten sample for doing quantile calculation along each image
293
+ sample = sample.reshape(batch_size, channels * height * width)
294
+
295
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
296
+
297
+ s = torch.quantile(
298
+ abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
299
+ s = torch.clamp(
300
+ s, min=1, max=self.config.sample_max_value
301
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
302
+
303
+ # (batch_size, 1) because clamp will broadcast along dim=0
304
+ s = s.unsqueeze(1)
305
+ # "we threshold xt0 to the range [-s, s] and then divide by s"
306
+ sample = torch.clamp(sample, -s, s) / s
307
+
308
+ sample = sample.reshape(batch_size, channels, height, width)
309
+ sample = sample.to(dtype)
310
+
311
+ return sample
312
+
313
+ def set_timesteps(self, num_inference_steps: int, original_inference_steps: int, device: Union[str, torch.device] = None):
314
+ """
315
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
316
+
317
+ Args:
318
+ num_inference_steps (`int`):
319
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ original_inference_steps (`int`):
+ The number of timesteps in the LCM training schedule from which the inference timesteps are subsampled.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved.
+ """
321
+
322
+ if num_inference_steps > self.config.num_train_timesteps:
323
+ raise ValueError(
324
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
325
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
326
+ f" maximal {self.config.num_train_timesteps} timesteps."
327
+ )
328
+
329
+ self.num_inference_steps = num_inference_steps
330
+
331
+ # LCM Timesteps Setting: # Linear Spacing
332
+ c = self.config.num_train_timesteps // original_inference_steps
333
+ lcm_origin_timesteps = np.asarray(
334
+ list(range(1, original_inference_steps + 1))) * c - 1 # LCM Training Steps Schedule
335
+ skipping_step = len(lcm_origin_timesteps) // num_inference_steps
336
+ # LCM Inference Steps Schedule
337
+ timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps]
339
+
340
+ self.timesteps = torch.from_numpy(timesteps.copy()).to(device)
341
+
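+ # Worked example (editor's sketch, assuming the default num_train_timesteps=1000):
+ # set_timesteps(num_inference_steps=4, original_inference_steps=50) gives
+ #   c                    = 1000 // 50 = 20
+ #   lcm_origin_timesteps = [19, 39, 59, ..., 999]   (50 values)
+ #   skipping_step        = 50 // 4 = 12
+ #   timesteps            = [999, 759, 519, 279]
+ # i.e. the inference schedule is a subsample of the 50-step LCM training schedule,
+ # walked from the most to the least noisy timestep.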
342
+ def get_scalings_for_boundary_condition_discrete(self, t):
343
+ self.sigma_data = 0.5 # Default: 0.5
344
+
345
+ # By dividing 0.1: This is almost a delta function at t=0.
346
+ c_skip = self.sigma_data**2 / (
347
+ (t / 0.1) ** 2 + self.sigma_data**2
348
+ )
349
+ c_out = ((t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5)
350
+ return c_skip, c_out
351
+
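+ # Editor's note (illustrative): with sigma_data = 0.5 these scalings enforce the
+ # consistency-model boundary condition. At t = 0, c_skip = 1 and c_out = 0, so the
+ # sample is passed through unchanged; as t grows, c_skip -> 0 and c_out -> 1, so
+ # `denoised` is dominated by the model's x0 prediction.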
352
+ def step(
353
+ self,
354
+ model_output: torch.FloatTensor,
355
+ timeindex: int,
356
+ timestep: int,
357
+ sample: torch.FloatTensor,
358
+ eta: float = 0.0,
359
+ use_clipped_model_output: bool = False,
360
+ generator=None,
361
+ variance_noise: Optional[torch.FloatTensor] = None,
362
+ return_dict: bool = True,
363
+ ) -> Union[LCMSchedulerOutput, Tuple]:
364
+ """
365
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
366
+ process from the learned model outputs (most often the predicted noise).
367
+
368
+ Args:
369
+ model_output (`torch.FloatTensor`):
370
+ The direct output from learned diffusion model.
+ timeindex (`int`):
+ The index of the current timestep in the `self.timesteps` schedule, used to look up the previous timestep.
+ timestep (`int`):
372
+ The current discrete timestep in the diffusion chain.
373
+ sample (`torch.FloatTensor`):
374
+ A current instance of a sample created by the diffusion process.
375
+ eta (`float`):
376
+ The weight of noise for added noise in diffusion step.
377
+ use_clipped_model_output (`bool`, defaults to `False`):
378
+ If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
379
+ because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
380
+ clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
381
+ `use_clipped_model_output` has no effect.
382
+ generator (`torch.Generator`, *optional*):
383
+ A random number generator.
384
+ variance_noise (`torch.FloatTensor`):
385
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
386
+ itself. Useful for methods such as [`CycleDiffusion`].
387
+ return_dict (`bool`, *optional*, defaults to `True`):
388
+ Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
389
+
390
+ Returns:
391
+ [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
392
+ If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
393
+ tuple is returned where the first element is the sample tensor.
394
+
395
+ """
396
+ if self.num_inference_steps is None:
397
+ raise ValueError(
398
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
399
+ )
400
+
401
+ # 1. get previous step value
402
+ prev_timeindex = timeindex + 1
403
+ if prev_timeindex < len(self.timesteps):
404
+ prev_timestep = self.timesteps[prev_timeindex]
405
+ else:
406
+ prev_timestep = timestep
407
+
408
+ # 2. compute alphas, betas
409
+ alpha_prod_t = self.alphas_cumprod[timestep]
410
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
411
+
412
+ beta_prod_t = 1 - alpha_prod_t
413
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
414
+
415
+ # 3. Get scalings for boundary conditions
416
+ c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(
417
+ timestep)
418
+
419
+ # 4. Different Parameterization:
420
+ parameterization = self.config.prediction_type
421
+
422
+ if parameterization == "epsilon": # noise-prediction
423
+ pred_x0 = (sample - beta_prod_t.sqrt() *
424
+ model_output) / alpha_prod_t.sqrt()
425
+
426
+ elif parameterization == "sample": # x-prediction
427
+ pred_x0 = model_output
428
+
429
+ elif parameterization == "v_prediction": # v-prediction
430
+ pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
431
+
432
+ # 5. Denoise model output using boundary conditions
433
+ denoised = c_out * pred_x0 + c_skip * sample
434
+
435
+ # 6. Sample z ~ N(0, I), for multi-step inference
436
+ # Noise is not used for one-step sampling.
437
+ if len(self.timesteps) > 1:
438
+ noise = torch.randn(model_output.shape).to(model_output.device)
439
+ prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
440
+ else:
441
+ prev_sample = denoised
442
+
443
+ if not return_dict:
444
+ return (prev_sample, denoised)
445
+
446
+ return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
447
+
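+ # Minimal usage sketch (editor's illustration; `unet`, `latents` and `prompt_embeds`
+ # are placeholders for whatever model and conditioning the surrounding pipeline provides):
+ #
+ #   scheduler.set_timesteps(num_inference_steps=4, original_inference_steps=50, device="cuda")
+ #   for i, t in enumerate(scheduler.timesteps):
+ #       model_output = unet(latents, t, encoder_hidden_states=prompt_embeds).sample
+ #       latents, denoised = scheduler.step(model_output, i, t, latents, return_dict=False)
+ #   # `denoised` from the final step is what gets decoded by the VAE.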
448
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
449
+
450
+ def add_noise(
451
+ self,
452
+ original_samples: torch.FloatTensor,
453
+ noise: torch.FloatTensor,
454
+ timesteps: torch.IntTensor,
455
+ ) -> torch.FloatTensor:
456
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
457
+ alphas_cumprod = self.alphas_cumprod.to(
458
+ device=original_samples.device, dtype=original_samples.dtype)
459
+ timesteps = timesteps.to(original_samples.device)
460
+
461
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
462
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
463
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
464
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
465
+
466
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
467
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
468
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
469
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
470
+
471
+ noisy_samples = sqrt_alpha_prod * original_samples + \
472
+ sqrt_one_minus_alpha_prod * noise
473
+ return noisy_samples
474
+
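+ # Editor's note (illustrative): this is the standard forward-diffusion identity
+ #   x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise,
+ # broadcast over the batch; it is what the img2img pipeline relies on to noise the
+ # encoded init image up to the starting timestep selected by `strength`.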
475
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
476
+ def get_velocity(
477
+ self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
478
+ ) -> torch.FloatTensor:
479
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
480
+ alphas_cumprod = self.alphas_cumprod.to(
481
+ device=sample.device, dtype=sample.dtype)
482
+ timesteps = timesteps.to(sample.device)
483
+
484
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
485
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
486
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
487
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
488
+
489
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
490
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
491
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
492
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
493
+
494
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
495
+ return velocity
496
+
497
+ def __len__(self):
498
+ return self.config.num_train_timesteps
scripts/main.py ADDED
@@ -0,0 +1,613 @@
+ from concurrent.futures import ThreadPoolExecutor
2
+ from pathlib import Path
3
+ from typing import Optional
4
+ import uuid
5
+ from lcm.lcm_scheduler import LCMScheduler
6
+ from lcm.lcm_pipeline import LatentConsistencyModelPipeline
7
+ from lcm.lcm_i2i_pipeline import LatentConsistencyModelImg2ImgPipeline, LCMSchedulerWithTimestamp
8
+ from diffusers.image_processor import PipelineImageInput
9
+ # import modules.scripts as scripts
10
+ # import modules.shared
11
+ # from modules import script_callbacks
12
+ import os
13
+ import random
14
+ import time
15
+ import numpy as np
16
+ import gradio as gr
17
+ from PIL import Image, PngImagePlugin
18
+ import torch
19
+
20
+ scheduler = LCMScheduler.from_pretrained(
21
+ "SimianLuo/LCM_Dreamshaper_v7", subfolder="scheduler")
22
+
23
+ pipe = LatentConsistencyModelPipeline.from_pretrained(
24
+ "SimianLuo/LCM_Dreamshaper_v7", scheduler = scheduler, safety_checker = None)
25
+
26
+
27
+
28
+ DESCRIPTION = '''# Latent Consistency Model
29
+ Running [LCM_Dreamshaper_v7](https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7) | [Project Page](https://latent-consistency-models.github.io) | [Extension Page](https://github.com/0xbitches/sd-webui-lcm)
30
+ '''
31
+
32
+ MAX_SEED = np.iinfo(np.int32).max
33
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "768"))
34
+
35
+
36
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
37
+ if randomize_seed:
38
+ seed = random.randint(0, MAX_SEED)
39
+ return seed
40
+
41
+
42
+ def save_image(img, metadata: dict):
43
+ save_dir = './outputs/LCM-txt2img/'
44
+ Path(save_dir).mkdir(exist_ok=True, parents=True)
45
+ seed = metadata["seed"]
46
+ unique_id = uuid.uuid4()
47
+ filename = save_dir + f"{unique_id}-{seed}" + ".png"
48
+
49
+ meta_tuples = [(k, str(v)) for k, v in metadata.items()]
50
+ png_info = PngImagePlugin.PngInfo()
51
+ for k, v in meta_tuples:
52
+ png_info.add_text(k, v)
53
+ img.save(filename, pnginfo=png_info)
54
+
55
+ return filename
56
+
57
+
58
+ def save_images(image_array, metadata: dict):
59
+ paths = []
60
+ with ThreadPoolExecutor() as executor:
61
+ paths = list(executor.map(save_image, image_array,
62
+ [metadata]*len(image_array)))
63
+ return paths
64
+
65
+
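+ # Editor's note (illustrative): the metadata written above is stored as PNG text
+ # chunks, so it can be read back later with Pillow, e.g.
+ #   from PIL import Image
+ #   info = Image.open(saved_path).text   # {'prompt': ..., 'seed': ..., ...}
+ # where `saved_path` is one of the paths returned by `save_images`.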
66
+ def generate(
67
+ prompt: str,
68
+ seed: int = 0,
69
+ width: int = 512,
70
+ height: int = 512,
71
+ guidance_scale: float = 8.0,
72
+ num_inference_steps: int = 4,
73
+ num_images: int = 4,
74
+ randomize_seed: bool = False,
75
+ use_fp16: bool = True,
76
+ use_torch_compile: bool = False,
77
+ use_cpu: bool = False,
78
+ progress=gr.Progress(track_tqdm=True)
79
+ ) -> tuple:
80
+ seed = randomize_seed_fn(seed, randomize_seed)
81
+ torch.manual_seed(seed)
82
+
83
+ selected_device = 'cuda'
84
+ if use_cpu:
85
+ selected_device = "cpu"
86
+ if use_fp16:
87
+ use_fp16 = False
88
+ print("LCM warning: running on CPU, overrode FP16 with FP32")
89
+ global pipe, scheduler
90
+ pipe = LatentConsistencyModelPipeline(
91
+ vae= pipe.vae,
92
+ text_encoder = pipe.text_encoder,
93
+ tokenizer = pipe.tokenizer,
94
+ unet = pipe.unet,
95
+ scheduler = scheduler,
96
+ safety_checker = pipe.safety_checker,
97
+ feature_extractor = pipe.feature_extractor,
98
+ )
99
+ # pipe = LatentConsistencyModelPipeline.from_pretrained(
100
+ # "SimianLuo/LCM_Dreamshaper_v7", scheduler = scheduler, safety_checker = None)
101
+
102
+ if use_fp16:
103
+ pipe.to(torch_device=selected_device, torch_dtype=torch.float16)
104
+ else:
105
+ pipe.to(torch_device=selected_device, torch_dtype=torch.float32)
106
+
107
+ # Windows does not support torch.compile for now
108
+ if os.name != 'nt' and use_torch_compile:
109
+ pipe.unet = torch.compile(pipe.unet, mode='max-autotune')
110
+
111
+ start_time = time.time()
112
+ result = pipe(
113
+ prompt=prompt,
114
+ width=width,
115
+ height=height,
116
+ guidance_scale=guidance_scale,
117
+ num_inference_steps=num_inference_steps,
118
+ num_images_per_prompt=num_images,
119
+ original_inference_steps=50,
120
+ output_type="pil",
121
+ device = selected_device
122
+ ).images
123
+ paths = save_images(result, metadata={"prompt": prompt, "seed": seed, "width": width,
124
+ "height": height, "guidance_scale": guidance_scale, "num_inference_steps": num_inference_steps})
125
+
126
+ elapsed_time = time.time() - start_time
127
+ print("LCM inference time: ", elapsed_time, "seconds")
128
+ return paths, seed
129
+
130
+
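+ # Editor's note (illustrative call, CPU settings assumed): the function can also be
+ # used outside the Gradio UI, e.g.
+ #   paths, seed = generate("a photo of a cat", seed=0, width=512, height=512,
+ #                          guidance_scale=8.0, num_inference_steps=4, num_images=1,
+ #                          randomize_seed=False, use_fp16=False,
+ #                          use_torch_compile=False, use_cpu=True)
+ # `paths` is the list of saved PNG files and `seed` is the seed actually used.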
131
+ def generate_i2i(
132
+ prompt: str,
133
+ image: PipelineImageInput = None,
134
+ strength: float = 0.8,
135
+ seed: int = 0,
136
+ guidance_scale: float = 8.0,
137
+ num_inference_steps: int = 4,
138
+ num_images: int = 4,
139
+ randomize_seed: bool = False,
140
+ use_fp16: bool = True,
141
+ use_torch_compile: bool = False,
142
+ use_cpu: bool = False,
143
+ progress=gr.Progress(track_tqdm=True),
144
+ width: Optional[int] = 512,
145
+ height: Optional[int] = 512,
146
+ ) -> tuple:
147
+ seed = randomize_seed_fn(seed, randomize_seed)
148
+ torch.manual_seed(seed)
149
+
150
+ selected_device = 'cuda'
151
+ if use_cpu:
152
+ selected_device = "cpu"
153
+ if use_fp16:
154
+ use_fp16 = False
155
+ print("LCM warning: running on CPU, overrode FP16 with FP32")
156
+ global pipe, scheduler
157
+ pipe = LatentConsistencyModelImg2ImgPipeline(
158
+ vae= pipe.vae,
159
+ text_encoder = pipe.text_encoder,
160
+ tokenizer = pipe.tokenizer,
161
+ unet = pipe.unet,
162
+ scheduler = None, #scheduler,
163
+ safety_checker = pipe.safety_checker,
164
+ feature_extractor = pipe.feature_extractor,
165
+ requires_safety_checker = False,
166
+ )
167
+ # pipe = LatentConsistencyModelImg2ImgPipeline.from_pretrained(
168
+ # "SimianLuo/LCM_Dreamshaper_v7", safety_checker = None)
169
+
170
+ if use_fp16:
171
+ pipe.to(torch_device=selected_device, torch_dtype=torch.float16)
172
+ else:
173
+ pipe.to(torch_device=selected_device, torch_dtype=torch.float32)
174
+
175
+ # Windows does not support torch.compile for now
176
+ if os.name != 'nt' and use_torch_compile:
177
+ pipe.unet = torch.compile(pipe.unet, mode='max-autotune')
178
+
179
+ width, height = image.size
180
+
181
+ start_time = time.time()
182
+ result = pipe(
183
+ prompt=prompt,
184
+ image=image,
185
+ strength=strength,
186
+ width=width,
187
+ height=height,
188
+ guidance_scale=guidance_scale,
189
+ num_inference_steps=num_inference_steps,
190
+ num_images_per_prompt=num_images,
191
+ original_inference_steps=50,
192
+ output_type="pil",
193
+ device = selected_device
194
+ ).images
195
+ paths = save_images(result, metadata={"prompt": prompt, "seed": seed, "width": width,
196
+ "height": height, "guidance_scale": guidance_scale, "num_inference_steps": num_inference_steps})
197
+
198
+ elapsed_time = time.time() - start_time
199
+ print("LCM inference time: ", elapsed_time, "seconds")
200
+ return paths, seed
201
+
202
+ import cv2
203
+
204
+ def video_to_frames(video_path):
205
+ # Open the video file
206
+ cap = cv2.VideoCapture(video_path)
207
+
208
+ # Check if the video opened successfully
209
+ if not cap.isOpened():
210
+ print("Error: LCM Could not open video.")
211
+ return
212
+
213
+ # Read frames from the video
214
+ pil_images = []
215
+ while True:
216
+ ret, frame = cap.read()
217
+ if not ret:
218
+ break
219
+
220
+ # Convert BGR to RGB (OpenCV uses BGR by default)
221
+ rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
222
+
223
+ # Convert numpy array to PIL Image
224
+ pil_image = Image.fromarray(rgb_frame)
225
+
226
+ # Append the PIL Image to the list
227
+ pil_images.append(pil_image)
228
+
229
+ # Release the video capture object
230
+ cap.release()
231
+
232
+ return pil_images
233
+
234
+ def frames_to_video(pil_images, output_path, fps):
235
+ if not pil_images:
236
+ print("Error: No images to convert.")
237
+ return
238
+
239
+ img_array = []
240
+ for pil_image in pil_images:
241
+ img_array.append(np.array(pil_image))
242
+
243
+ height, width, layers = img_array[0].shape
244
+ size = (width, height)
245
+
246
+ out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, size)
247
+ for i in range(len(img_array)):
248
+ out.write(cv2.cvtColor(img_array[i], cv2.COLOR_RGB2BGR))
249
+ out.release()
250
+
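+ # Editor's note (illustrative round trip; "input.mp4"/"output.mp4" are placeholder
+ # paths): the two helpers are inverses of each other up to re-encoding, e.g.
+ #   frames = video_to_frames("input.mp4")        # list of PIL RGB images
+ #   frames_to_video(frames, "output.mp4", fps=10)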
251
+ def generate_v2v(
252
+ prompt: str,
253
+ video: Optional[str] = None,
254
+ strength: float = 0.8,
255
+ seed: int = 0,
256
+ guidance_scale: float = 8.0,
257
+ num_inference_steps: int = 4,
258
+ randomize_seed: bool = False,
259
+ use_fp16: bool = True,
260
+ use_torch_compile: bool = False,
261
+ use_cpu: bool = False,
262
+ fps: int = 10,
263
+ save_frames: bool = False,
264
+ # progress=gr.Progress(track_tqdm=True),
265
+ width: Optional[int] = 512,
266
+ height: Optional[int] = 512,
267
+ num_images: Optional[int] = 1,
268
+ ) -> Optional[str]:
269
+ seed = randomize_seed_fn(seed, randomize_seed)
270
+ torch.manual_seed(seed)
271
+
272
+ selected_device = 'cuda'
273
+ if use_cpu:
274
+ selected_device = "cpu"
275
+ if use_fp16:
276
+ use_fp16 = False
277
+ print("LCM warning: running on CPU, overrode FP16 with FP32")
278
+ global pipe, scheduler
279
+ pipe = LatentConsistencyModelImg2ImgPipeline(
280
+ vae= pipe.vae,
281
+ text_encoder = pipe.text_encoder,
282
+ tokenizer = pipe.tokenizer,
283
+ unet = pipe.unet,
284
+ scheduler = None,
285
+ safety_checker = pipe.safety_checker,
286
+ feature_extractor = pipe.feature_extractor,
287
+ requires_safety_checker = False,
288
+ )
289
+ # pipe = LatentConsistencyModelImg2ImgPipeline.from_pretrained(
290
+ # "SimianLuo/LCM_Dreamshaper_v7", safety_checker = None)
291
+
292
+ if use_fp16:
293
+ pipe.to(torch_device=selected_device, torch_dtype=torch.float16)
294
+ else:
295
+ pipe.to(torch_device=selected_device, torch_dtype=torch.float32)
296
+
297
+ # Windows does not support torch.compile for now
298
+ if os.name != 'nt' and use_torch_compile:
299
+ pipe.unet = torch.compile(pipe.unet, mode='max-autotune')
300
+
301
+ frames = video_to_frames(video)
302
+ if not frames:
303
+ print("Error: LCM could not convert video.")
304
+ return
305
+ width, height = frames[0].size
306
+
307
+ start_time = time.time()
308
+
309
+ results = []
310
+ for frame in frames:
311
+ result = pipe(
312
+ prompt=prompt,
313
+ image=frame,
314
+ strength=strength,
315
+ width=width,
316
+ height=height,
317
+ guidance_scale=guidance_scale,
318
+ num_inference_steps=num_inference_steps,
319
+ num_images_per_prompt=1,
320
+ original_inference_steps=50,
321
+ output_type="pil",
322
+ device = selected_device
323
+ ).images
324
+ if save_frames:
325
+ paths = save_images(result, metadata={"prompt": prompt, "seed": seed, "width": width,
326
+ "height": height, "guidance_scale": guidance_scale, "num_inference_steps": num_inference_steps})
327
+ results.extend(result)
328
+
329
+ elapsed_time = time.time() - start_time
330
+ print("LCM vid2vid inference complete! Processing", len(frames), "frames took", elapsed_time, "seconds")
331
+
332
+ save_dir = './outputs/LCM-vid2vid/'
333
+ Path(save_dir).mkdir(exist_ok=True, parents=True)
334
+ unique_id = uuid.uuid4()
335
+ _, input_ext = os.path.splitext(video)
336
+ output_path = save_dir + f"{unique_id}-{seed}" + f"{input_ext}"
337
+ frames_to_video(results, output_path, fps)
338
+ return output_path
339
+
340
+
341
+
342
+ examples = [
343
+ "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
344
+ "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
345
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
346
+ "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
347
+ ]
348
+
349
+ with gr.Blocks() as lcm:
350
+ with gr.Tab("LCM txt2img"):
351
+ #gr.Markdown(DESCRIPTION)
352
+ with gr.Row():
353
+ prompt = gr.Textbox(label="Prompt",
354
+ show_label=False,
355
+ lines=3,
356
+ placeholder="Prompt",
357
+ elem_classes=["prompt"])
358
+ run_button = gr.Button("Run", scale=0)
359
+ with gr.Row():
360
+ result = gr.Gallery(
361
+ label="Generated images", show_label=False, elem_id="gallery", grid=[2], preview=True
362
+ )
363
+
364
+ with gr.Accordion("Advanced options", open=False):
365
+ seed = gr.Slider(
366
+ label="Seed",
367
+ minimum=0,
368
+ maximum=MAX_SEED,
369
+ step=1,
370
+ value=0,
371
+ randomize=True
372
+ )
373
+ randomize_seed = gr.Checkbox(
374
+ label="Randomize seed across runs", value=True)
375
+ use_fp16 = gr.Checkbox(
376
+ label="Run LCM in fp16 (for lower VRAM)", value=False)
377
+ use_torch_compile = gr.Checkbox(
378
+ label="Run LCM with torch.compile (currently not supported on Windows)", value=False)
379
+ use_cpu = gr.Checkbox(label="Run LCM on CPU", value=True)
380
+ with gr.Row():
381
+ width = gr.Slider(
382
+ label="Width",
383
+ minimum=256,
384
+ maximum=MAX_IMAGE_SIZE,
385
+ step=32,
386
+ value=512,
387
+ )
388
+ height = gr.Slider(
389
+ label="Height",
390
+ minimum=256,
391
+ maximum=MAX_IMAGE_SIZE,
392
+ step=32,
393
+ value=512,
394
+ )
395
+ with gr.Row():
396
+ guidance_scale = gr.Slider(
397
+ label="Guidance scale for base",
398
+ minimum=2,
399
+ maximum=14,
400
+ step=0.1,
401
+ value=8.0,
402
+ )
403
+ num_inference_steps = gr.Slider(
404
+ label="Number of inference steps for base",
405
+ minimum=1,
406
+ maximum=8,
407
+ step=1,
408
+ value=4,
409
+ )
410
+ with gr.Row():
411
+ num_images = gr.Slider(
412
+ label="Number of images (batch count)",
413
+ minimum=1,
414
+ maximum=int(os.getenv("MAX_NUM_IMAGES")),
415
+ step=1,
416
+ value=1,
417
+ )
418
+
419
+ gr.Examples(
420
+ examples=examples,
421
+ inputs=prompt,
422
+ outputs=result,
423
+ fn=generate
424
+ )
425
+
426
+ run_button.click(
427
+ fn=generate,
428
+ inputs=[
429
+ prompt,
430
+ seed,
431
+ width,
432
+ height,
433
+ guidance_scale,
434
+ num_inference_steps,
435
+ num_images,
436
+ randomize_seed,
437
+ use_fp16,
438
+ use_torch_compile,
439
+ use_cpu
440
+ ],
441
+ outputs=[result, seed],
442
+ )
443
+
444
+ with gr.Tab("LCM img2img"):
445
+ with gr.Row():
446
+ prompt = gr.Textbox(label="Prompt",
447
+ show_label=False,
448
+ lines=3,
449
+ placeholder="Prompt",
450
+ elem_classes=["prompt"])
451
+ run_i2i_button = gr.Button("Run", scale=0)
452
+ with gr.Row():
453
+ image_input = gr.Image(label="Upload your Image", type="pil")
454
+ result = gr.Gallery(
455
+ label="Generated images",
456
+ show_label=False,
457
+ elem_id="gallery",
458
+ preview=True
459
+ )
460
+
461
+ with gr.Accordion("Advanced options", open=False):
462
+ seed = gr.Slider(
463
+ label="Seed",
464
+ minimum=0,
465
+ maximum=MAX_SEED,
466
+ step=1,
467
+ value=0,
468
+ randomize=True
469
+ )
470
+ randomize_seed = gr.Checkbox(
471
+ label="Randomize seed across runs", value=True)
472
+ use_fp16 = gr.Checkbox(
473
+ label="Run LCM in fp16 (for lower VRAM)", value=False)
474
+ use_torch_compile = gr.Checkbox(
475
+ label="Run LCM with torch.compile (currently not supported on Windows)", value=False)
476
+ use_cpu = gr.Checkbox(label="Run LCM on CPU", value=True)
477
+ with gr.Row():
478
+ guidance_scale = gr.Slider(
479
+ label="Guidance scale for base",
480
+ minimum=2,
481
+ maximum=14,
482
+ step=0.1,
483
+ value=8.0,
484
+ )
485
+ num_inference_steps = gr.Slider(
486
+ label="Number of inference steps for base",
487
+ minimum=1,
488
+ maximum=8,
489
+ step=1,
490
+ value=4,
491
+ )
492
+ with gr.Row():
493
+ num_images = gr.Slider(
494
+ label="Number of images (batch count)",
495
+ minimum=1,
496
+ maximum=int(os.getenv("MAX_NUM_IMAGES")),
497
+ step=1,
498
+ value=1,
499
+ )
500
+ strength = gr.Slider(
501
+ label="Prompt Strength",
502
+ minimum=0.1,
503
+ maximum=1.0,
504
+ step=0.1,
505
+ value=0.5,
506
+ )
507
+
508
+ run_i2i_button.click(
509
+ fn=generate_i2i,
510
+ inputs=[
511
+ prompt,
512
+ image_input,
513
+ strength,
514
+ seed,
515
+ guidance_scale,
516
+ num_inference_steps,
517
+ num_images,
518
+ randomize_seed,
519
+ use_fp16,
520
+ use_torch_compile,
521
+ use_cpu
522
+ ],
523
+ outputs=[result, seed],
524
+ )
525
+
526
+
527
+ with gr.Tab("LCM vid2vid"):
528
+
529
+ show_v2v = os.getenv("SHOW_VID2VID") != "NO"
530
+ gr.Markdown("Not recommended for use with CPU. Duplicate the space and modify SHOW_VID2VID to enable it. 🚫💻")
531
+ with gr.Tabs(visible=show_v2v) as tabs:
532
+ #with gr.Tab("", visible=show_v2v):
533
+
534
+ with gr.Row():
535
+ prompt = gr.Textbox(label="Prompt",
536
+ show_label=False,
537
+ lines=3,
538
+ placeholder="Prompt",
539
+ elem_classes=["prompt"])
540
+ run_v2v_button = gr.Button("Run", scale=0)
541
+ with gr.Row():
542
+ video_input = gr.Video(label="Source Video")
543
+ video_output = gr.Video(label="Generated Video")
544
+
545
+ with gr.Accordion("Advanced options", open=False):
546
+ seed = gr.Slider(
547
+ label="Seed",
548
+ minimum=0,
549
+ maximum=MAX_SEED,
550
+ step=1,
551
+ value=0,
552
+ randomize=True
553
+ )
554
+ randomize_seed = gr.Checkbox(
555
+ label="Randomize seed across runs", value=True)
556
+ use_fp16 = gr.Checkbox(
557
+ label="Run LCM in fp16 (for lower VRAM)", value=False)
558
+ use_torch_compile = gr.Checkbox(
559
+ label="Run LCM with torch.compile (currently not supported on Windows)", value=False)
560
+ use_cpu = gr.Checkbox(label="Run LCM on CPU", value=True)
561
+ save_frames = gr.Checkbox(label="Save intermediate frames", value=False)
562
+ with gr.Row():
563
+ guidance_scale = gr.Slider(
564
+ label="Guidance scale for base",
565
+ minimum=2,
566
+ maximum=14,
567
+ step=0.1,
568
+ value=8.0,
569
+ )
570
+ num_inference_steps = gr.Slider(
571
+ label="Number of inference steps for base",
572
+ minimum=1,
573
+ maximum=8,
574
+ step=1,
575
+ value=4,
576
+ )
577
+ with gr.Row():
578
+ fps = gr.Slider(
579
+ label="Output FPS",
580
+ minimum=1,
581
+ maximum=200,
582
+ step=1,
583
+ value=10,
584
+ )
585
+ strength = gr.Slider(
586
+ label="Prompt Strength",
587
+ minimum=0.1,
588
+ maximum=1.0,
589
+ step=0.05,
590
+ value=0.5,
591
+ )
592
+
593
+ run_v2v_button.click(
594
+ fn=generate_v2v,
595
+ inputs=[
596
+ prompt,
597
+ video_input,
598
+ strength,
599
+ seed,
600
+ guidance_scale,
601
+ num_inference_steps,
602
+ randomize_seed,
603
+ use_fp16,
604
+ use_torch_compile,
605
+ use_cpu,
606
+ fps,
607
+ save_frames
608
+ ],
609
+ outputs=video_output,
610
+ )
611
+
612
+ if __name__ == "__main__":
613
+ lcm.queue().launch()
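+ # Editor's note (illustrative): `launch()` accepts the standard Gradio options when
+ # needed, e.g. lcm.queue().launch(server_name="0.0.0.0", share=True) to expose the
+ # UI on the local network or via a temporary public link.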