nyanko7 commited on
Commit
15e4f70
1 Parent(s): bbd02ca

Update modules/model.py

Browse files
Files changed (1) hide show
  1. modules/model.py +22 -0
modules/model.py CHANGED
@@ -39,6 +39,20 @@ exists = lambda val: val is not None
39
  default = lambda val, d: val if exists(val) else d
40
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  def get_attention_scores(attn, query, key, attention_mask=None):
44
 
@@ -528,6 +542,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
528
  noise_pred = noise_pred_uncond + guidance_scale * (
529
  noise_pred_text - noise_pred_uncond
530
  )
 
 
 
 
531
  return noise_pred
532
 
533
  sampler_args = self.get_sampler_extra_args_i2i(sigma_sched, sampler)
@@ -696,6 +714,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
696
  noise_pred = noise_pred_uncond + guidance_scale * (
697
  noise_pred_text - noise_pred_uncond
698
  )
 
 
 
 
699
  return noise_pred
700
 
701
  extra_args = self.get_sampler_extra_args_t2i(
 
39
  default = lambda val, d: val if exists(val) else d
40
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
 
42
+ # from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
43
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
44
+ """
45
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
46
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
47
+ """
48
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
49
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
50
+ # rescale the results from guidance (fixes overexposure)
51
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
52
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
53
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
54
+ return noise_cfg
55
+
56
 
57
  def get_attention_scores(attn, query, key, attention_mask=None):
58
 
 
542
  noise_pred = noise_pred_uncond + guidance_scale * (
543
  noise_pred_text - noise_pred_uncond
544
  )
545
+
546
+ if guidance_rescale > 0.0:
547
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
548
+
549
  return noise_pred
550
 
551
  sampler_args = self.get_sampler_extra_args_i2i(sigma_sched, sampler)
 
714
  noise_pred = noise_pred_uncond + guidance_scale * (
715
  noise_pred_text - noise_pred_uncond
716
  )
717
+
718
+ if guidance_rescale > 0.0:
719
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
720
+
721
  return noise_pred
722
 
723
  extra_args = self.get_sampler_extra_args_t2i(