Spaces:
Running
on
T4
Running
on
T4
Update modules/model.py
Browse files - modules/model.py +22 -0
modules/model.py
CHANGED
@@ -39,6 +39,20 @@ exists = lambda val: val is not None
|
|
39 |
default = lambda val, d: val if exists(val) else d
|
40 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
def get_attention_scores(attn, query, key, attention_mask=None):
|
44 |
|
@@ -528,6 +542,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
|
528 |
noise_pred = noise_pred_uncond + guidance_scale * (
|
529 |
noise_pred_text - noise_pred_uncond
|
530 |
)
|
|
|
|
|
|
|
|
|
531 |
return noise_pred
|
532 |
|
533 |
sampler_args = self.get_sampler_extra_args_i2i(sigma_sched, sampler)
|
@@ -696,6 +714,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
|
696 |
noise_pred = noise_pred_uncond + guidance_scale * (
|
697 |
noise_pred_text - noise_pred_uncond
|
698 |
)
|
|
|
|
|
|
|
|
|
699 |
return noise_pred
|
700 |
|
701 |
extra_args = self.get_sampler_extra_args_t2i(
|
|
|
39 |
default = lambda val, d: val if exists(val) else d
|
40 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
41 |
|
42 |
+
# from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
43 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4

    Args:
        noise_cfg: Noise prediction after classifier-free guidance was applied (expected to be a tensor
            exposing `.std(dim=..., keepdim=...)` and `.ndim`, e.g. a `torch.Tensor`).
        noise_pred_text: Text-conditioned noise prediction; its per-sample standard deviation is the target
            the guided prediction is rescaled towards.
        guidance_rescale: Blend factor between the guided prediction (0.0) and its rescaled version (1.0).

    Returns:
        `guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg`, where `rescaled` is `noise_cfg`
        scaled so that its per-sample standard deviation matches that of `noise_pred_text`.
    """
    # Standard deviation over every dimension except dim 0 (presumably the batch dimension), kept
    # broadcastable against the inputs.
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # A sample with zero (or undefined) spread would make the ratio inf/NaN and poison the whole
    # prediction, even for small blend factors since 0 * inf is NaN. Leave such samples unscaled.
    scale = std_text / std_cfg
    scale = scale.masked_fill(~scale.isfinite(), 1.0)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * scale
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg
|
55 |
+
|
56 |
|
57 |
def get_attention_scores(attn, query, key, attention_mask=None):
|
58 |
|
|
|
542 |
noise_pred = noise_pred_uncond + guidance_scale * (
|
543 |
noise_pred_text - noise_pred_uncond
|
544 |
)
|
545 |
+
|
546 |
+
if guidance_rescale > 0.0:
|
547 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
548 |
+
|
549 |
return noise_pred
|
550 |
|
551 |
sampler_args = self.get_sampler_extra_args_i2i(sigma_sched, sampler)
|
|
|
714 |
noise_pred = noise_pred_uncond + guidance_scale * (
|
715 |
noise_pred_text - noise_pred_uncond
|
716 |
)
|
717 |
+
|
718 |
+
if guidance_rescale > 0.0:
|
719 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
720 |
+
|
721 |
return noise_pred
|
722 |
|
723 |
extra_args = self.get_sampler_extra_args_t2i(
|