Using callback
Most 🤗 Diffusers pipelines now accept a callback_on_step_end argument that lets you change the default behavior of the denoising loop with custom-defined functions. Here is an example of a callback function that disables classifier-free guidance (CFG) after 40% of the inference steps to save compute with minimal loss in output quality.
def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
    # adjust the batch_size of prompt_embeds according to guidance_scale
    if step_index == int(pipe.num_timesteps * 0.4):
        prompt_embeds = callback_kwargs["prompt_embeds"]
        # keep only the last half of the batch (the conditional embeddings)
        prompt_embeds = prompt_embeds.chunk(2)[-1]

        # update guidance_scale and prompt_embeds
        pipe._guidance_scale = 0.0
        callback_kwargs["prompt_embeds"] = prompt_embeds
    return callback_kwargs
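To make the step condition concrete, here is a small plain-Python sketch of the threshold arithmetic. It assumes a run of 50 denoising steps, a value chosen only for illustration, and cfg_cutoff_step is a hypothetical helper that is not part of Diffusers.

```python
def cfg_cutoff_step(num_timesteps, fraction=0.4):
    """Return the step index at which the callback's condition is true."""
    return int(num_timesteps * fraction)


num_timesteps = 50  # assumed for illustration
cutoff = cfg_cutoff_step(num_timesteps)

# Step indices whose end-of-step callback matches the condition (exactly one).
matching_steps = [i for i in range(num_timesteps) if i == cutoff]

# The callback runs at the end of step `cutoff`, so steps 0..cutoff still
# use CFG and every later step runs without it.
steps_with_cfg = cutoff + 1
steps_without_cfg = num_timesteps - steps_with_cfg

print(cutoff, matching_steps, steps_with_cfg, steps_without_cfg)
```

Because the comparison uses ==, the body of the if statement runs once, so prompt_embeds is halved a single time rather than on every subsequent step.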
Your callback function takes the following arguments:
- pipe is the pipeline instance, which provides access to useful properties such as num_timesteps and guidance_scale. You can modify these properties by updating the underlying attributes. In this example, we disable CFG by setting pipe._guidance_scale to 0.0.
- step_index and timestep tell you where you are in the denoising loop. In our example, we use step_index to decide when to turn off CFG.
- callback_kwargs is a dict that contains tensor variables you can modify during the denoising loop. It only includes variables specified in the callback_on_step_end_tensor_inputs argument passed to the pipeline's __call__ method. Different pipelines may use different sets of variables, so check the pipeline class's _callback_tensor_inputs attribute for the list of variables that you can modify. Common variables include latents and prompt_embeds. In our example, we need to adjust the batch size of prompt_embeds after setting guidance_scale to 0.0 so that the remaining steps work properly.
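The way these arguments fit together can be sketched with a toy loop in plain Python. This is a simplified stand-in under stated assumptions, not the Diffusers implementation: ToyPipeline and drop_negative_half are hypothetical names, plain lists take the place of tensors, and the "denoising" is a dummy update.

```python
class ToyPipeline:
    """Hypothetical stand-in for a pipeline; not the Diffusers implementation."""

    # Names a callback may request (plays the role of _callback_tensor_inputs).
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(self, num_timesteps):
        self.num_timesteps = num_timesteps
        self._guidance_scale = 7.5

    def __call__(self, callback_on_step_end=None,
                 callback_on_step_end_tensor_inputs=("latents",)):
        unknown = [k for k in callback_on_step_end_tensor_inputs
                   if k not in self._callback_tensor_inputs]
        if unknown:
            raise ValueError(f"Unsupported tensor inputs: {unknown}")

        # Lists stand in for tensors; with CFG the embeddings hold two halves.
        state = {"latents": [0.0], "prompt_embeds": ["negative", "positive"]}
        for step_index in range(self.num_timesteps):
            timestep = self.num_timesteps - 1 - step_index
            state["latents"] = [x + 1.0 for x in state["latents"]]  # dummy step
            if callback_on_step_end is not None:
                # Only the requested variables are exposed to the callback.
                callback_kwargs = {k: state[k]
                                   for k in callback_on_step_end_tensor_inputs}
                outputs = callback_on_step_end(self, step_index, timestep,
                                               callback_kwargs)
                # Returned values replace the loop's own variables.
                state.update(outputs)
        return state


def drop_negative_half(pipe, step_index, timestep, callback_kwargs):
    if step_index == int(pipe.num_timesteps * 0.4):
        embeds = callback_kwargs["prompt_embeds"]
        callback_kwargs["prompt_embeds"] = embeds[len(embeds) // 2:]
        pipe._guidance_scale = 0.0
    return callback_kwargs


pipe = ToyPipeline(num_timesteps=10)
result = pipe(
    callback_on_step_end=drop_negative_half,
    callback_on_step_end_tensor_inputs=["prompt_embeds"],
)
print(result["prompt_embeds"], pipe._guidance_scale)
```

The callback only sees prompt_embeds because that is the only name requested, and the value it returns is what the loop carries into the following steps.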
You can pass the callback function as the callback_on_step_end argument to the pipeline, along with callback_on_step_end_tensor_inputs.
import torch
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
generator = torch.Generator(device="cuda").manual_seed(1)
out = pipe(prompt, generator=generator, callback_on_step_end=callback_dynamic_cfg, callback_on_step_end_tensor_inputs=['prompt_embeds'])
out.images[0].save("out_custom_cfg.png")
Your callback function is executed at the end of each denoising step and can modify pipeline attributes and tensor variables for the next denoising step. We added the “dynamic CFG” feature to the Stable Diffusion pipeline without modifying the pipeline code at all.
Currently, only callback_on_step_end is supported. If you have a solid use case that requires a callback function with a different execution point, please open a feature request so we can add it!