| | import torch |
| | from modules import sd_samplers_kdiffusion, sd_samplers_common |
| |
|
| | from ldm_patched.k_diffusion import sampling as k_diffusion_sampling |
| | from ldm_patched.modules.samplers import calculate_sigmas_scheduler |
| | from modules import shared |
| |
|
| |
|
# Solver names for which sample_ode_custom forwards rtol/atol tolerances.
ADAPTIVE_SOLVERS = {
    "dopri8",
    "dopri5",
    "bosh3",
    "fehlberg2",
    "adaptive_heun",
}

# Solver names for which sample_ode_custom passes rtol=None and atol=None.
FIXED_SOLVERS = {
    "euler",
    "midpoint",
    "rk4",
    "heun3",
    "explicit_adams",
    "implicit_adams",
}

# Every known solver name, as an alphabetically sorted list.
ALL_SOLVERS = sorted(ADAPTIVE_SOLVERS.union(FIXED_SOLVERS))
| |
|
class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
    """Sampler that routes a sampler name to an ``ldm_patched`` k-diffusion function.

    The sampler name selects either a function from
    ``ldm_patched.k_diffusion.sampling`` or one of this class's ``sample_ode_*``
    methods. Sigma schedules are produced by :meth:`get_sigmas` from the
    scheduler chosen on the processing object.
    """

    # Scheduler display label used when no scheduler has been selected yet.
    _DEFAULT_SCHEDULER_LABEL = 'React Cosinusoidal DynSF'

    # Scheduler display label -> identifier passed to calculate_sigmas_scheduler.
    _FORGE_SCHEDULERS = {
        "Normal": "normal",
        "Karras": "karras",
        "Exponential": "exponential",
        "SGM Uniform": "sgm_uniform",
        "Simple": "simple",
        "DDIM": "ddim_uniform",
        "Align Your Steps": "ays",
        "Align Your Steps GITS": "ays_gits",
        "Align Your Steps 11": "ays_11steps",
        "Align Your Steps 32": "ays_32steps",
        "KL Optimal": "kl_optimal",
        "Beta": "beta",
        "Sinusoidal SF": "sinusoidal_sf",
        "Invcosinusoidal SF": "invcosinusoidal_sf",
        "React Cosinusoidal DynSF": "react_cosinusoidal_dynsf",
    }

    def __init__(self, sd_model, sampler_name, solver=None, rtol=None, atol=None):
        """Resolve ``sampler_name`` to a sampling function and initialise the base class.

        Args:
            sd_model: Loaded model; must expose ``forge_objects.unet``.
            sampler_name: Key identifying the sampling function to use.
            solver: Stored on the instance as ``self.solver``.
            rtol: Stored on the instance as ``self.rtol``.
            atol: Stored on the instance as ``self.atol``.

        Raises:
            ValueError: If ``sampler_name`` is not a known sampler key.
        """
        self.sampler_name = sampler_name
        self.scheduler_name = None
        self.unet = sd_model.forge_objects.unet
        self.model = sd_model
        self.solver = solver
        self.rtol = rtol
        self.atol = atol

        # Built per instance because the ODE entries are bound methods of self.
        sampler_functions = {
            'euler_comfy': k_diffusion_sampling.sample_euler,
            'euler_ancestral_comfy': k_diffusion_sampling.sample_euler_ancestral,
            'heun_comfy': k_diffusion_sampling.sample_heun,
            'dpmpp_2s_ancestral_comfy': k_diffusion_sampling.sample_dpmpp_2s_ancestral,
            'dpmpp_sde_comfy': k_diffusion_sampling.sample_dpmpp_sde,
            'dpmpp_2m_comfy': k_diffusion_sampling.sample_dpmpp_2m,
            'dpmpp_2m_sde_comfy': k_diffusion_sampling.sample_dpmpp_2m_sde,
            'dpmpp_3m_sde_comfy': k_diffusion_sampling.sample_dpmpp_3m_sde,
            'euler_ancestral_turbo': k_diffusion_sampling.sample_euler_ancestral,
            'dpmpp_2m_turbo': k_diffusion_sampling.sample_dpmpp_2m,
            'dpmpp_2m_sde_turbo': k_diffusion_sampling.sample_dpmpp_2m_sde,
            'ddpm': k_diffusion_sampling.sample_ddpm,
            'heunpp2': k_diffusion_sampling.sample_heunpp2,
            'ipndm': k_diffusion_sampling.sample_ipndm,
            'ipndm_v': k_diffusion_sampling.sample_ipndm_v,
            'deis': k_diffusion_sampling.sample_deis,
            'euler_cfg_pp': k_diffusion_sampling.sample_euler_cfg_pp,
            'euler_ancestral_cfg_pp': k_diffusion_sampling.sample_euler_ancestral_cfg_pp,
            'dpmpp_2s_ancestral_cfg_pp': k_diffusion_sampling.sample_dpmpp_2s_ancestral_cfg_pp,
            'dpmpp_2s_ancestral_cfg_pp_dyn': k_diffusion_sampling.sample_dpmpp_2s_ancestral_cfg_pp_dyn,
            'dpmpp_2s_ancestral_cfg_pp_intern': k_diffusion_sampling.sample_dpmpp_2s_ancestral_cfg_pp_intern,
            'dpmpp_sde_cfg_pp': k_diffusion_sampling.sample_dpmpp_sde_cfg_pp,
            'dpmpp_2m_cfg_pp': k_diffusion_sampling.sample_dpmpp_2m_cfg_pp,
            'ode_bosh3': self.sample_ode_bosh3,
            'ode_fehlberg2': self.sample_ode_fehlberg2,
            'ode_adaptive_heun': self.sample_ode_adaptive_heun,
            'ode_dopri5': self.sample_ode_dopri5,
            'ode_custom': self.sample_ode_custom,
        }

        sampler_function = sampler_functions.get(sampler_name)
        if sampler_function is None:
            raise ValueError(f"Unknown sampler: {sampler_name}")

        super().__init__(sampler_function, sd_model, None)

    def sample_func(self, model, x, sigmas, extra_args=None, callback=None, disable=None):
        """Run the ODE sampler matching ``self.sampler_name``, else defer to the base class.

        NOTE(review): the non-ODE path relies on the base class providing
        ``sample_func`` - confirm against ``KDiffusionSampler``.
        """
        ode_samplers = {
            'ode_bosh3': self.sample_ode_bosh3,
            'ode_fehlberg2': self.sample_ode_fehlberg2,
            'ode_adaptive_heun': self.sample_ode_adaptive_heun,
            'ode_dopri5': self.sample_ode_dopri5,
            'ode_custom': self.sample_ode_custom,
        }
        ode_sampler = ode_samplers.get(self.sampler_name)
        if ode_sampler is not None:
            return ode_sampler(model, x, sigmas, extra_args, callback, disable)

        return super().sample_func(model, x, sigmas, extra_args, callback, disable)

    def _sample_ode_preset(self, solver, model, x, sigmas, extra_args, callback, disable):
        """Call ``sample_ode`` for ``solver`` using its ``ode_<solver>_*`` options.

        Reads ``shared.opts.ode_<solver>_rtol``, ``ode_<solver>_atol`` and
        ``ode_<solver>_max_steps``. The tolerance options hold base-10
        exponents, so the tolerance passed on is ``10 ** option``.
        """
        opts = shared.opts
        return k_diffusion_sampling.sample_ode(
            model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable,
            solver=solver,
            rtol=10 ** getattr(opts, f"ode_{solver}_rtol"),
            atol=10 ** getattr(opts, f"ode_{solver}_atol"),
            max_steps=getattr(opts, f"ode_{solver}_max_steps"),
        )

    def sample_ode_bosh3(self, model, x, sigmas, extra_args=None, callback=None, disable=None):
        """Sample with the ``bosh3`` solver, configured by the ``ode_bosh3_*`` options."""
        return self._sample_ode_preset("bosh3", model, x, sigmas, extra_args, callback, disable)

    def sample_ode_fehlberg2(self, model, x, sigmas, extra_args=None, callback=None, disable=None):
        """Sample with the ``fehlberg2`` solver, configured by the ``ode_fehlberg2_*`` options."""
        return self._sample_ode_preset("fehlberg2", model, x, sigmas, extra_args, callback, disable)

    def sample_ode_adaptive_heun(self, model, x, sigmas, extra_args=None, callback=None, disable=None):
        """Sample with the ``adaptive_heun`` solver, configured by the ``ode_adaptive_heun_*`` options."""
        return self._sample_ode_preset("adaptive_heun", model, x, sigmas, extra_args, callback, disable)

    def sample_ode_dopri5(self, model, x, sigmas, extra_args=None, callback=None, disable=None):
        """Sample with the ``dopri5`` solver, configured by the ``ode_dopri5_*`` options."""
        return self._sample_ode_preset("dopri5", model, x, sigmas, extra_args, callback, disable)

    def sample_ode_custom(self, model, x, sigmas, extra_args=None, callback=None, disable=None):
        """Sample with the solver named by the ``ode_custom_solver`` option.

        Tolerances are only forwarded when the solver is listed in
        ``ADAPTIVE_SOLVERS``; otherwise ``rtol`` and ``atol`` are ``None``.
        """
        solver = shared.opts.ode_custom_solver
        is_adaptive = solver in ADAPTIVE_SOLVERS
        # Tolerance options hold base-10 exponents.
        rtol = 10 ** shared.opts.ode_custom_rtol if is_adaptive else None
        atol = 10 ** shared.opts.ode_custom_atol if is_adaptive else None
        max_steps = shared.opts.ode_custom_max_steps

        return k_diffusion_sampling.sample_ode(
            model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable,
            solver=solver, rtol=rtol, atol=atol, max_steps=max_steps,
        )

    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
        """Record ``p.scheduler`` for :meth:`get_sigmas`, then delegate to the base class."""
        self.scheduler_name = p.scheduler
        return super().sample(p, x, conditioning, unconditional_conditioning, steps, image_conditioning)

    def get_sigmas(self, p, steps):
        """Return the sigma schedule for ``steps`` steps on the UNet's load device.

        Samplers whose name ends in ``_turbo`` use evenly spaced timesteps
        mapped through the model's ``model_sampling.sigma`` with a trailing
        zero appended. All other samplers use ``calculate_sigmas_scheduler``
        with the scheduler recorded in ``self.scheduler_name``.

        Args:
            p: Processing object; not read by this method.
            steps: Number of sampling steps.
        """
        if self.scheduler_name is None:
            self.scheduler_name = self._DEFAULT_SCHEDULER_LABEL

        # An unrecognised label falls back to the default scheduler's
        # identifier. The fallback must be the identifier, not the display
        # label, because that is what calculate_sigmas_scheduler receives on
        # every other path.
        matched_scheduler = self._FORGE_SCHEDULERS.get(
            self.scheduler_name,
            self._FORGE_SCHEDULERS[self._DEFAULT_SCHEDULER_LABEL],
        )

        if self.sampler_name.endswith('_turbo'):
            # Descending timesteps at 1000/steps spacing, clipped to 0..999.
            timesteps = torch.flip(torch.arange(1, steps + 1) * float(1000.0 / steps) - 1, (0,)).round().long().clip(0, 999)
            sigmas = self.unet.model.model_sampling.sigma(timesteps)
            sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
        else:
            sigmas = calculate_sigmas_scheduler(self.unet.model, matched_scheduler, steps, is_sdxl=getattr(self.model, "is_sdxl", False))

        return sigmas.to(self.unet.load_device)
| |
|
| |
|
def build_constructor(sampler_name):
    """Return a one-argument factory that builds an ``AlterSampler`` for ``sampler_name``.

    The returned callable takes the model and passes it, together with the
    captured ``sampler_name``, to ``AlterSampler``.
    """

    def constructor(model):
        # sampler_name comes from the enclosing call, so each factory is
        # bound to exactly one sampler.
        sampler = AlterSampler(model, sampler_name)
        return sampler

    return constructor
| |
|
# Sampler registry entries. Each entry pairs a UI label with a constructor for
# the matching internal sampler name; the internal name is also the only alias
# and the options dict is empty.
samplers_data_alter = [
    sd_samplers_common.SamplerData(label, build_constructor(sampler_name=internal_name), [internal_name], {})
    for label, internal_name in (
        ('Euler Comfy', 'euler_comfy'),
        ('Euler A Comfy', 'euler_ancestral_comfy'),
        ('Heun Comfy', 'heun_comfy'),
        ('DPM++ 2S Ancestral Comfy', 'dpmpp_2s_ancestral_comfy'),
        ('DPM++ SDE Comfy', 'dpmpp_sde_comfy'),
        ('DPM++ 2M Comfy', 'dpmpp_2m_comfy'),
        ('DPM++ 2M SDE Comfy', 'dpmpp_2m_sde_comfy'),
        ('DPM++ 3M SDE Comfy', 'dpmpp_3m_sde_comfy'),
        ('Euler A Turbo', 'euler_ancestral_turbo'),
        ('DPM++ 2M Turbo', 'dpmpp_2m_turbo'),
        ('DPM++ 2M SDE Turbo', 'dpmpp_2m_sde_turbo'),
        ('DDPM', 'ddpm'),
        ('HeunPP2', 'heunpp2'),
        ('IPNDM', 'ipndm'),
        ('IPNDM_V', 'ipndm_v'),
        ('DEIS', 'deis'),
        ('Euler CFG++', 'euler_cfg_pp'),
        ('Euler Ancestral CFG++', 'euler_ancestral_cfg_pp'),
        ('DPM++ 2S Ancestral CFG++', 'dpmpp_2s_ancestral_cfg_pp'),
        ('DPM++ 2S Ancestral CFG++ Dyn', 'dpmpp_2s_ancestral_cfg_pp_dyn'),
        ('DPM++ 2S Ancestral CFG++ Intern', 'dpmpp_2s_ancestral_cfg_pp_intern'),
        ('DPM++ SDE CFG++', 'dpmpp_sde_cfg_pp'),
        ('DPM++ 2M CFG++', 'dpmpp_2m_cfg_pp'),
        ('ODE (Bosh3)', 'ode_bosh3'),
        ('ODE (Fehlberg2)', 'ode_fehlberg2'),
        ('ODE (Adaptive Heun)', 'ode_adaptive_heun'),
        ('ODE (Dopri5)', 'ode_dopri5'),
        ('ODE Custom', 'ode_custom'),
    )
]
| |
|