skytnt committed
Commit 070027b
1 parent: 5c94345

Update pipeline.py

Files changed (1)
  1. pipeline.py +62 -6
pipeline.py CHANGED
@@ -12,7 +12,7 @@ from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from diffusers.utils import deprecate, logging
+from diffusers.utils import deprecate, is_accelerate_available, logging
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 
@@ -40,7 +40,7 @@ re_attention = re.compile(
 
 def parse_prompt_attention(text):
     """
-    Parses a string with attention tokens and returns a list of pairs: text and its assoicated weight.
+    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
     Accepted tokens are:
       (abc) - increases attention to abc by a multiplier of 1.1
       (abc:3.12) - increases attention to abc by a multiplier of 3.12
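For reference, a minimal usage sketch of the docstring's contract, assuming `pipeline.py` is importable. The printed result is illustrative of the documented "list of [text, weight] pairs" shape, not captured from a real run, so the exact merging of text runs may differ:

```python
from pipeline import parse_prompt_attention

pairs = parse_prompt_attention("a (very beautiful:1.3) masterpiece")
# Expected, per the docstring above, something like:
# [['a ', 1.0], ['very beautiful', 1.3], [' masterpiece', 1.0]]
print(pairs)
```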
@@ -237,9 +237,9 @@ def get_weighted_text_embeddings(
     r"""
     Prompts can be assigned with local weights using brackets. For example,
     prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
-    and the embedding tokens corresponding to the words get multipled by a constant, 1.1.
+    and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
 
-    Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the origional mean.
+    Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
 
     Args:
         pipe (`DiffusionPipeline`):
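To make the mean-preserving scaling concrete, here is a hypothetical sketch of the idea described in that docstring (the function name and shapes are assumptions for illustration, not the pipeline's actual code):

```python
import torch

def apply_weights_preserving_mean(embeddings: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    # embeddings: (batch, tokens, dim); weights: (batch, tokens)
    previous_mean = embeddings.float().mean()
    # scale each token embedding by its prompt weight
    weighted = embeddings * weights.unsqueeze(-1)
    # rescale so the overall mean matches the unweighted embedding's mean
    return weighted * (previous_mean / weighted.float().mean())
```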
@@ -431,6 +431,19 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)
 
+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
         if safety_checker is None:
             logger.warn(
                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
@@ -451,6 +464,24 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             feature_extractor=feature_extractor,
         )
 
+    def enable_xformers_memory_efficient_attention(self):
+        r"""
+        Enable memory efficient attention as implemented in xformers.
+
+        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+        time. Speed up at training time is not guaranteed.
+
+        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+        is used.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(True)
+
+    def disable_xformers_memory_efficient_attention(self):
+        r"""
+        Disable memory efficient attention as implemented in xformers.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(False)
+
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.
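Typical usage of the new xformers toggles might look like the following, assuming xformers is installed. The model id and the `custom_pipeline` path are placeholders for this repo's `pipeline.py`, not values taken from the commit:

```python
import torch
from diffusers import DiffusionPipeline

# Placeholder ids: point custom_pipeline at the repo containing this pipeline.py.
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="path/to/this/pipeline",
    torch_dtype=torch.float16,
).to("cuda")

pipe.enable_xformers_memory_efficient_attention()   # lower memory, possible speedup
images = pipe("a (very beautiful:1.2) landscape").images
pipe.disable_xformers_memory_efficient_attention()  # back to default attention
```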
@@ -478,6 +509,23 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
+    def enable_sequential_cpu_offload(self):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+        """
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = self.device
+
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+            if cpu_offloaded_model is not None:
+                cpu_offload(cpu_offloaded_model, device)
+
     @torch.no_grad()
     def __call__(
         self,
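A minimal usage sketch for the offload hook. Since this implementation reads `self.device` as the execution device, the pipeline is moved to the GPU first here; this is an assumption about intended usage, not something stated in the commit:

```python
from diffusers import DiffusionPipeline

# Placeholder ids, as in the previous example.
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="path/to/this/pipeline",
).to("cuda")

# Weights stay on CPU; each submodule is moved to the GPU only while its
# forward() runs, trading speed for a much smaller peak memory footprint.
pipe.enable_sequential_cpu_offload()
image = pipe("a masterpiece").images[0]
```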
@@ -498,6 +546,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        is_cancelled_callback: Optional[Callable[[], bool]] = None,
         callback_steps: Optional[int] = 1,
         **kwargs,
     ):
@@ -560,11 +609,15 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. The function will be
                 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            is_cancelled_callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. If the function returns
+                `True`, the inference will be cancelled.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
 
         Returns:
+            `None` if cancelled by `is_cancelled_callback`,
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
             When returning a tuple, the first element is a list with the generated images, and the second element is a
@@ -757,8 +810,11 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
                     latents = (init_latents_proper * mask) + (latents * (1 - mask))
 
             # call the callback, if provided
-            if callback is not None and i % callback_steps == 0:
-                callback(i, t, latents)
+            if i % callback_steps == 0:
+                if callback is not None:
+                    callback(i, t, latents)
+                if is_cancelled_callback is not None and is_cancelled_callback():
+                    return None
 
         latents = 1 / 0.18215 * latents
         image = self.vae.decode(latents).sample
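Together, these three hunks let a caller abort a running generation. A minimal sketch of the intended pattern, assuming `pipe` was loaded as in the examples above: run inference in a worker thread and flip an event from the main thread (e.g. a UI "stop" button); the hook is polled every `callback_steps` steps and a cancelled run returns `None`:

```python
import threading

cancel_event = threading.Event()

def run():
    result = pipe(
        "a (very beautiful) masterpiece",
        is_cancelled_callback=cancel_event.is_set,  # zero-arg callable returning bool
    )
    if result is None:
        print("generation was cancelled")

worker = threading.Thread(target=run)
worker.start()
# ... later, when the user asks to stop:
cancel_event.set()
worker.join()
```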
 