skytnt committed on
Commit
bbc7326
1 Parent(s): 625683c

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +120 -53
pipeline.py CHANGED
@@ -5,14 +5,37 @@ from typing import Callable, List, Optional, Union
5
  import numpy as np
6
  import torch
7
 
 
8
  import PIL
9
  from diffusers import SchedulerMixin, StableDiffusionPipeline
10
  from diffusers.models import AutoencoderKL, UNet2DConditionModel
11
  from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
12
- from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
 
13
  from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
14
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
17
 
18
  re_attention = re.compile(
@@ -404,27 +427,75 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
404
  Model that extracts features from generated images to be used as inputs for the `safety_checker`.
405
  """
406
 
407
- def __init__(
408
- self,
409
- vae: AutoencoderKL,
410
- text_encoder: CLIPTextModel,
411
- tokenizer: CLIPTokenizer,
412
- unet: UNet2DConditionModel,
413
- scheduler: SchedulerMixin,
414
- safety_checker: StableDiffusionSafetyChecker,
415
- feature_extractor: CLIPFeatureExtractor,
416
- requires_safety_checker: bool = True,
417
- ):
418
- super().__init__(
419
- vae=vae,
420
- text_encoder=text_encoder,
421
- tokenizer=tokenizer,
422
- unet=unet,
423
- scheduler=scheduler,
424
- safety_checker=safety_checker,
425
- feature_extractor=feature_extractor,
426
- requires_safety_checker=requires_safety_checker,
427
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
428
 
429
  def _encode_prompt(
430
  self,
@@ -752,37 +823,33 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
752
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
753
 
754
  # 8. Denoising loop
755
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
756
- with self.progress_bar(total=num_inference_steps) as progress_bar:
757
- for i, t in enumerate(timesteps):
758
- # expand the latents if we are doing classifier free guidance
759
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
760
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
761
-
762
- # predict the noise residual
763
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
764
-
765
- # perform guidance
766
- if do_classifier_free_guidance:
767
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
768
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
769
-
770
- # compute the previous noisy sample x_t -> x_t-1
771
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
772
-
773
- if mask is not None:
774
- # masking
775
- init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
776
- latents = (init_latents_proper * mask) + (latents * (1 - mask))
777
-
778
- # call the callback, if provided
779
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
780
- progress_bar.update()
781
- if i % callback_steps == 0:
782
- if callback is not None:
783
- callback(i, t, latents)
784
- if is_cancelled_callback is not None and is_cancelled_callback():
785
- return None
786
 
787
  # 9. Post-processing
788
  image = self.decode_latents(latents)
 
5
  import numpy as np
6
  import torch
7
 
8
+ import diffusers
9
  import PIL
10
  from diffusers import SchedulerMixin, StableDiffusionPipeline
11
  from diffusers.models import AutoencoderKL, UNet2DConditionModel
12
  from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
13
+ from diffusers.utils import deprecate, logging
14
+ from packaging import version
15
  from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
16
 
17
 
18
+ try:
19
+ from diffusers.utils import PIL_INTERPOLATION
20
+ except ImportError:
21
+ if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
22
+ PIL_INTERPOLATION = {
23
+ "linear": PIL.Image.Resampling.BILINEAR,
24
+ "bilinear": PIL.Image.Resampling.BILINEAR,
25
+ "bicubic": PIL.Image.Resampling.BICUBIC,
26
+ "lanczos": PIL.Image.Resampling.LANCZOS,
27
+ "nearest": PIL.Image.Resampling.NEAREST,
28
+ }
29
+ else:
30
+ PIL_INTERPOLATION = {
31
+ "linear": PIL.Image.LINEAR,
32
+ "bilinear": PIL.Image.BILINEAR,
33
+ "bicubic": PIL.Image.BICUBIC,
34
+ "lanczos": PIL.Image.LANCZOS,
35
+ "nearest": PIL.Image.NEAREST,
36
+ }
37
+ # ------------------------------------------------------------------------------
38
+
39
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
 
41
  re_attention = re.compile(
 
427
  Model that extracts features from generated images to be used as inputs for the `safety_checker`.
428
  """
429
 
430
+ if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
431
+
432
+ def __init__(
433
+ self,
434
+ vae: AutoencoderKL,
435
+ text_encoder: CLIPTextModel,
436
+ tokenizer: CLIPTokenizer,
437
+ unet: UNet2DConditionModel,
438
+ scheduler: SchedulerMixin,
439
+ safety_checker: StableDiffusionSafetyChecker,
440
+ feature_extractor: CLIPFeatureExtractor,
441
+ requires_safety_checker: bool = True,
442
+ ):
443
+ super().__init__(
444
+ vae=vae,
445
+ text_encoder=text_encoder,
446
+ tokenizer=tokenizer,
447
+ unet=unet,
448
+ scheduler=scheduler,
449
+ safety_checker=safety_checker,
450
+ feature_extractor=feature_extractor,
451
+ requires_safety_checker=requires_safety_checker,
452
+ )
453
+ self.__init__additional__()
454
+
455
+ else:
456
+
457
+ def __init__(
458
+ self,
459
+ vae: AutoencoderKL,
460
+ text_encoder: CLIPTextModel,
461
+ tokenizer: CLIPTokenizer,
462
+ unet: UNet2DConditionModel,
463
+ scheduler: SchedulerMixin,
464
+ safety_checker: StableDiffusionSafetyChecker,
465
+ feature_extractor: CLIPFeatureExtractor,
466
+ ):
467
+ super().__init__(
468
+ vae=vae,
469
+ text_encoder=text_encoder,
470
+ tokenizer=tokenizer,
471
+ unet=unet,
472
+ scheduler=scheduler,
473
+ safety_checker=safety_checker,
474
+ feature_extractor=feature_extractor,
475
+ )
476
+ self.__init__additional__()
477
+
478
+ def __init__additional__(self):
479
+ if not hasattr(self, "vae_scale_factor"):
480
+ setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
481
+
482
+ @property
483
+ def _execution_device(self):
484
+ r"""
485
+ Returns the device on which the pipeline's models will be executed. After calling
486
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
487
+ hooks.
488
+ """
489
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
490
+ return self.device
491
+ for module in self.unet.modules():
492
+ if (
493
+ hasattr(module, "_hf_hook")
494
+ and hasattr(module._hf_hook, "execution_device")
495
+ and module._hf_hook.execution_device is not None
496
+ ):
497
+ return torch.device(module._hf_hook.execution_device)
498
+ return self.device
499
 
500
  def _encode_prompt(
501
  self,
 
823
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
824
 
825
  # 8. Denoising loop
826
+ for i, t in enumerate(self.progress_bar(timesteps)):
827
+ # expand the latents if we are doing classifier free guidance
828
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
829
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
830
+
831
+ # predict the noise residual
832
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
833
+
834
+ # perform guidance
835
+ if do_classifier_free_guidance:
836
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
837
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
838
+
839
+ # compute the previous noisy sample x_t -> x_t-1
840
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
841
+
842
+ if mask is not None:
843
+ # masking
844
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
845
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
846
+
847
+ # call the callback, if provided
848
+ if i % callback_steps == 0:
849
+ if callback is not None:
850
+ callback(i, t, latents)
851
+ if is_cancelled_callback is not None and is_cancelled_callback():
852
+ return None
 
 
 
 
853
 
854
  # 9. Post-processing
855
  image = self.decode_latents(latents)