|
|
|
|
|
|
|
|
|
@@ -46,6 +46,18 @@ class CFGDenoiserParams: |
|
"""Total number of sampling steps planned""" |
|
|
|
|
|
+class CFGDenoisedParams: |
|
+ def __init__(self, x, sampling_step, total_sampling_steps): |
|
+ self.x = x |
|
+ """Latent image representation in the process of being denoised""" |
|
+ |
|
+ self.sampling_step = sampling_step |
|
+ """Current Sampling step number""" |
|
+ |
|
+ self.total_sampling_steps = total_sampling_steps |
|
+ """Total number of sampling steps planned""" |
|
+ |
|
+ |
|
class UiTrainTabParams: |
|
def __init__(self, txt2img_preview_params): |
|
self.txt2img_preview_params = txt2img_preview_params |
|
@@ -68,6 +80,7 @@ callback_map = dict( |
|
callbacks_before_image_saved=[], |
|
callbacks_image_saved=[], |
|
callbacks_cfg_denoiser=[], |
|
+ callbacks_cfg_denoised=[], |
|
callbacks_before_component=[], |
|
callbacks_after_component=[], |
|
callbacks_image_grid=[], |
|
@@ -150,6 +163,14 @@ def cfg_denoiser_callback(params: CFGDenoiserParams): |
|
report_exception(c, 'cfg_denoiser_callback') |
|
|
|
|
|
+def cfg_denoised_callback(params: CFGDenoisedParams): |
|
+ for c in callback_map['callbacks_cfg_denoised']: |
|
+ try: |
|
+ c.callback(params) |
|
+ except Exception: |
|
+ report_exception(c, 'cfg_denoised_callback') |
|
+ |
|
+ |
|
def before_component_callback(component, **kwargs): |
|
for c in callback_map['callbacks_before_component']: |
|
try: |
|
@@ -283,6 +304,14 @@ def on_cfg_denoiser(callback): |
|
add_callback(callback_map['callbacks_cfg_denoiser'], callback) |
|
|
|
|
|
+def on_cfg_denoised(callback): |
|
+ """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs. |
|
+ The callback is called with one argument: |
|
+ - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details. |
|
+ """ |
|
+ add_callback(callback_map['callbacks_cfg_denoised'], callback) |
|
+ |
|
+ |
|
def on_before_component(callback): |
|
"""register a function to be called before a component is created. |
|
The callback is called with arguments: |
|
|
|
|
|
|
|
|
|
@@ -8,6 +8,7 @@ from modules import prompt_parser, devices, sd_samplers_common |
|
from modules.shared import opts, state |
|
import modules.shared as shared |
|
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback |
|
+from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback |
|
|
|
samplers_k_diffusion = [ |
|
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}), |
|
@@ -136,6 +137,9 @@ class CFGDenoiser(torch.nn.Module): |
|
|
|
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) |
|
|
|
+ denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps) |
|
+ cfg_denoised_callback(denoised_params) |
|
+ |
|
devices.test_for_nans(x_out, "unet") |
|
|
|
if opts.live_preview_content == "Prompt": |
|
|