# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Tuple

import torch

from ..models.attention_processor import Attention, MochiAttention
from ..models.modeling_outputs import Transformer2DModelOutput
from ..utils import logging
from .hooks import HookRegistry, ModelHook


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


_FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser"
_FASTER_CACHE_BLOCK_HOOK = "faster_cache_block"

_ATTENTION_CLASSES = (Attention, MochiAttention)

_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
    "^blocks.*attn",
    "^transformer_blocks.*attn",
    "^single_transformer_blocks.*attn",
)
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS

_UNCOND_COND_INPUT_KWARGS_IDENTIFIERS = (
    "hidden_states",
    "encoder_hidden_states",
    "timestep",
    "attention_mask",
    "encoder_attention_mask",
)


@dataclass
class FasterCacheConfig:
    r"""
    Configuration for [FasterCache](https://huggingface.co/papers/2410.19355).

    Attributes:
        spatial_attention_block_skip_range (`int`, defaults to `2`):
            Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
            be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new
            attention states again.
        temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
            Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
            be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new
            attention states again.
        spatial_attention_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 681)`):
            The timestep range within which the spatial attention computation can be skipped without a significant
            loss in quality. This is to be determined by the user based on the underlying model. The first value in
            the tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
            denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
            timestep 0). For the default values, this means that spatial attention computation skipping becomes
            applicable only after denoising timestep 681 is reached, and continues until the end of the denoising
            process.
        temporal_attention_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 681)`):
            The timestep range within which the temporal attention computation can be skipped without a significant
            loss in quality. This is to be determined by the user based on the underlying model. The first value in
            the tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
            denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
            timestep 0).
        low_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(99, 901)`):
            The timestep range within which the low frequency weight scaling update is applied. The first value in
            the tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
            function for the update is called only within this range.
        high_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(-1, 301)`):
            The timestep range within which the high frequency weight scaling update is applied. The first value in
            the tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
            function for the update is called only within this range.
        alpha_low_frequency (`float`, defaults to `1.1`):
            The weight to scale the low frequency updates by. This is used to approximate the unconditional branch
            from the conditional branch outputs.
        alpha_high_frequency (`float`, defaults to `1.1`):
            The weight to scale the high frequency updates by. This is used to approximate the unconditional branch
            from the conditional branch outputs.
        unconditional_batch_skip_range (`int`, defaults to `5`):
            Process the unconditional branch every `N` iterations. If this is set to `N`, the unconditional branch
            computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be re-used)
            before computing the new unconditional branch states again.
        unconditional_batch_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 641)`):
            The timestep range within which the unconditional branch computation can be skipped without a significant
            loss in quality. This is to be determined by the user based on the underlying model. The first value in
            the tuple is the lower bound and the second value is the upper bound.
        spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("^blocks.*attn", "^transformer_blocks.*attn", "^single_transformer_blocks.*attn")`):
            The identifiers to match the spatial attention blocks in the model. If the name of the block contains any
            of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
            partial layer names, or regex patterns. Matching will always be done using a regex match.
        temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("^temporal_transformer_blocks.*attn",)`):
            The identifiers to match the temporal attention blocks in the model. If the name of the block contains any
            of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
            partial layer names, or regex patterns. Matching will always be done using a regex match.
        attention_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
            The callback function to determine the weight to scale the attention outputs by. This function should take
            the attention module as input and return a float value. This is used to approximate the unconditional
            branch from the conditional branch outputs. If not provided, the default weight is 0.5 for all timesteps.
            Typically, as described in the paper, this weight should gradually increase from 0 to 1 as the inference
            progresses. Users are encouraged to experiment and provide custom weight schedules that take into account
            the number of inference steps and underlying model behaviour as denoising progresses.
        low_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
            The callback function to determine the weight to scale the low frequency updates by. If not provided, the
            default weight is 1.1 for timesteps within the range specified (as described in the paper).
        high_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
            The callback function to determine the weight to scale the high frequency updates by. If not provided, the
            default weight is 1.1 for timesteps within the range specified (as described in the paper).
        tensor_format (`str`, defaults to `"BCFHW"`):
            The format of the input tensors. This should be one of `"BCFHW"`, `"BFCHW"`, or `"BCHW"` (batch, channels,
            frames, height and width, in the order indicated by the letters). The format is used to split individual
            latent frames so that low and high frequency components can be computed.
        is_guidance_distilled (`bool`, defaults to `False`):
            Whether the model is guidance distilled or not. If the model is guidance distilled, FasterCache will not
            be applied at the denoiser-level to skip the unconditional branch computation (as there is none).
        _unconditional_conditional_input_kwargs_identifiers (`List[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`):
            The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and
            conditional inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will
            split the inputs into unconditional and conditional branches. This must be a list of exact input kwargs
            names that contain the batchwise-concatenated unconditional and conditional inputs.
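
    Example:

    A minimal construction sketch (the values mirror the `apply_faster_cache` example later in this module and the
    defaults documented above; they are illustrative, not tuned recommendations):

    ```python
    >>> from diffusers import FasterCacheConfig

    >>> config = FasterCacheConfig(
    ...     spatial_attention_block_skip_range=2,
    ...     spatial_attention_timestep_skip_range=(-1, 681),
    ...     attention_weight_callback=lambda _: 0.5,
    ...     tensor_format="BCFHW",
    ... )
    ```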
| """ | |

    # In the paper and codebase, they hardcode these values to 2. However, it can be made configurable
    # after some testing. We default to 2 if these parameters are not provided.
    spatial_attention_block_skip_range: int = 2
    temporal_attention_block_skip_range: Optional[int] = None

    spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
    temporal_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)

    # Indicator functions for low/high frequency as mentioned in Equation 11 of the paper
    low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 901)
    high_frequency_weight_update_timestep_range: Tuple[int, int] = (-1, 301)

    # ⍺1 and ⍺2 as mentioned in Equation 11 of the paper
    alpha_low_frequency: float = 1.1
    alpha_high_frequency: float = 1.1

    # n as described in CFG-Cache explanation in the paper - dependent on the model
    unconditional_batch_skip_range: int = 5
    unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641)

    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS

    attention_weight_callback: Callable[[torch.nn.Module], float] = None
    low_frequency_weight_callback: Callable[[torch.nn.Module], float] = None
    high_frequency_weight_callback: Callable[[torch.nn.Module], float] = None

    tensor_format: str = "BCFHW"
    is_guidance_distilled: bool = False

    current_timestep_callback: Callable[[], int] = None

    _unconditional_conditional_input_kwargs_identifiers: List[str] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS

    def __repr__(self) -> str:
        return (
            f"FasterCacheConfig(\n"
            f"  spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
            f"  temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
            f"  spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n"
            f"  temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n"
            f"  low_frequency_weight_update_timestep_range={self.low_frequency_weight_update_timestep_range},\n"
            f"  high_frequency_weight_update_timestep_range={self.high_frequency_weight_update_timestep_range},\n"
            f"  alpha_low_frequency={self.alpha_low_frequency},\n"
            f"  alpha_high_frequency={self.alpha_high_frequency},\n"
            f"  unconditional_batch_skip_range={self.unconditional_batch_skip_range},\n"
            f"  unconditional_batch_timestep_skip_range={self.unconditional_batch_timestep_skip_range},\n"
            f"  spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n"
            f"  temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n"
            f"  tensor_format={self.tensor_format},\n"
            f")"
        )


class FasterCacheDenoiserState:
    r"""
    State for [FasterCache](https://huggingface.co/papers/2410.19355) top-level denoiser module.
    """

    def __init__(self) -> None:
        self.iteration: int = 0
        self.low_frequency_delta: torch.Tensor = None
        self.high_frequency_delta: torch.Tensor = None

    def reset(self):
        self.iteration = 0
        self.low_frequency_delta = None
        self.high_frequency_delta = None


class FasterCacheBlockState:
    r"""
    State for [FasterCache](https://huggingface.co/papers/2410.19355). Every underlying block that FasterCache is
    applied to will have an instance of this state.
    """

    def __init__(self) -> None:
        self.iteration: int = 0
        self.batch_size: int = None
        self.cache: Tuple[torch.Tensor, torch.Tensor] = None

    def reset(self):
        self.iteration = 0
        self.batch_size = None
        self.cache = None


class FasterCacheDenoiserHook(ModelHook):
    _is_stateful = True

    def __init__(
        self,
        unconditional_batch_skip_range: int,
        unconditional_batch_timestep_skip_range: Tuple[int, int],
        tensor_format: str,
        is_guidance_distilled: bool,
        uncond_cond_input_kwargs_identifiers: List[str],
        current_timestep_callback: Callable[[], int],
        low_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
        high_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
    ) -> None:
        super().__init__()

        self.unconditional_batch_skip_range = unconditional_batch_skip_range
        self.unconditional_batch_timestep_skip_range = unconditional_batch_timestep_skip_range
        # We can't easily detect what args are to be split in unconditional and conditional branches. We
        # can only do it for kwargs, hence they are the only ones we split. The args are passed as-is.
        # If a model is to be made compatible with FasterCache, the user must ensure that the inputs that
        # contain batchwise-concatenated unconditional and conditional inputs are passed as kwargs.
        self.uncond_cond_input_kwargs_identifiers = uncond_cond_input_kwargs_identifiers
        self.tensor_format = tensor_format
        self.is_guidance_distilled = is_guidance_distilled

        self.current_timestep_callback = current_timestep_callback
        self.low_frequency_weight_callback = low_frequency_weight_callback
        self.high_frequency_weight_callback = high_frequency_weight_callback

    def initialize_hook(self, module):
        self.state = FasterCacheDenoiserState()
        return module

    @staticmethod
    def _get_cond_input(input: torch.Tensor) -> torch.Tensor:
        # Note: this method assumes that the input tensor is batchwise-concatenated with unconditional inputs
        # followed by conditional inputs.
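        # For example, given `hidden_states` of shape (2 * B, ...), rows [0:B] hold the unconditional batch and
        # rows [B:2B] hold the conditional batch, so only the second chunk is kept.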
        _, cond = input.chunk(2, dim=0)
        return cond

    def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
        # Split the unconditional and conditional inputs. We only want to infer the conditional branch if the
        # requirements for skipping the unconditional branch are met as described in the paper.
        # We skip the unconditional branch only if the following conditions are met:
        #   1. We have completed at least one iteration of the denoiser
        #   2. The current timestep is within the range specified by the user. This is the optimal timestep range
        #      where approximating the unconditional branch from the computation of the conditional branch is
        #      possible without a significant loss in quality.
        #   3. The current iteration is not a multiple of the unconditional batch skip range. This is done so that
        #      we compute the unconditional branch at least once every few iterations to ensure minimal quality loss.
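        # For example, with `unconditional_batch_skip_range=5`, the unconditional branch is fully recomputed on
        # iterations 0, 5, 10, ... (and whenever the current timestep falls outside the configured range); on all
        # other iterations it is approximated from the conditional branch output below.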
        is_within_timestep_range = (
            self.unconditional_batch_timestep_skip_range[0]
            < self.current_timestep_callback()
            < self.unconditional_batch_timestep_skip_range[1]
        )
        should_skip_uncond = (
            self.state.iteration > 0
            and is_within_timestep_range
            and self.state.iteration % self.unconditional_batch_skip_range != 0
            and not self.is_guidance_distilled
        )

        if should_skip_uncond:
            is_any_kwarg_uncond = any(k in self.uncond_cond_input_kwargs_identifiers for k in kwargs.keys())
            if is_any_kwarg_uncond:
                logger.debug("FasterCache - Skipping unconditional branch computation")
                args = tuple([self._get_cond_input(arg) if torch.is_tensor(arg) else arg for arg in args])
                kwargs = {
                    k: v if k not in self.uncond_cond_input_kwargs_identifiers else self._get_cond_input(v)
                    for k, v in kwargs.items()
                }

        output = self.fn_ref.original_forward(*args, **kwargs)

        if self.is_guidance_distilled:
            self.state.iteration += 1
            return output

        if torch.is_tensor(output):
            hidden_states = output
        elif isinstance(output, (tuple, Transformer2DModelOutput)):
            hidden_states = output[0]

        batch_size = hidden_states.size(0)

        if should_skip_uncond:
            self.state.low_frequency_delta = self.state.low_frequency_delta * self.low_frequency_weight_callback(
                module
            )
            self.state.high_frequency_delta = self.state.high_frequency_delta * self.high_frequency_weight_callback(
                module
            )

            if self.tensor_format == "BCFHW":
                hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
            if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
                hidden_states = hidden_states.flatten(0, 1)
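            # For the video formats (`BCFHW`/`BFCHW`), `hidden_states` now has shape
            # (batch * frames, channels, height, width), so the 2D FFT in `_split_low_high_freq` operates on each
            # latent frame independently.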

            low_freq_cond, high_freq_cond = _split_low_high_freq(hidden_states.float())

            # Approximate/compute the unconditional branch outputs as described in Equation 9 and 10 of the paper
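            # (`low_frequency_delta`/`high_frequency_delta` hold the uncond-minus-cond frequency residuals cached
            # from the last full forward pass, rescaled above by the low/high frequency weight callbacks; adding
            # them to the current conditional frequencies yields the approximated unconditional frequencies.)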
            low_freq_uncond = self.state.low_frequency_delta + low_freq_cond
            high_freq_uncond = self.state.high_frequency_delta + high_freq_cond
            uncond_freq = low_freq_uncond + high_freq_uncond

            uncond_states = torch.fft.ifftshift(uncond_freq)
            uncond_states = torch.fft.ifft2(uncond_states).real

            if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
                uncond_states = uncond_states.unflatten(0, (batch_size, -1))
                hidden_states = hidden_states.unflatten(0, (batch_size, -1))
            if self.tensor_format == "BCFHW":
                uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
                hidden_states = hidden_states.permute(0, 2, 1, 3, 4)

            # Concatenate the approximated unconditional and predicted conditional branches
            uncond_states = uncond_states.to(hidden_states.dtype)
            hidden_states = torch.cat([uncond_states, hidden_states], dim=0)
        else:
            uncond_states, cond_states = hidden_states.chunk(2, dim=0)
            if self.tensor_format == "BCFHW":
                uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
                cond_states = cond_states.permute(0, 2, 1, 3, 4)
            if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
                uncond_states = uncond_states.flatten(0, 1)
                cond_states = cond_states.flatten(0, 1)

            low_freq_uncond, high_freq_uncond = _split_low_high_freq(uncond_states.float())
            low_freq_cond, high_freq_cond = _split_low_high_freq(cond_states.float())
            self.state.low_frequency_delta = low_freq_uncond - low_freq_cond
            self.state.high_frequency_delta = high_freq_uncond - high_freq_cond

        self.state.iteration += 1
        if torch.is_tensor(output):
            output = hidden_states
        elif isinstance(output, tuple):
            output = (hidden_states, *output[1:])
        else:
            output.sample = hidden_states

        return output

    def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
        self.state.reset()
        return module


class FasterCacheBlockHook(ModelHook):
    _is_stateful = True

    def __init__(
        self,
        block_skip_range: int,
        timestep_skip_range: Tuple[int, int],
        is_guidance_distilled: bool,
        weight_callback: Callable[[torch.nn.Module], float],
        current_timestep_callback: Callable[[], int],
    ) -> None:
        super().__init__()

        self.block_skip_range = block_skip_range
        self.timestep_skip_range = timestep_skip_range
        self.is_guidance_distilled = is_guidance_distilled

        self.weight_callback = weight_callback
        self.current_timestep_callback = current_timestep_callback

    def initialize_hook(self, module):
        self.state = FasterCacheBlockState()
        return module

    def _compute_approximated_attention_output(
        self, t_2_output: torch.Tensor, t_output: torch.Tensor, weight: float, batch_size: int
    ) -> torch.Tensor:
        if t_2_output.size(0) != batch_size:
            # The cached t_2_output contains both batchwise-concatenated unconditional-conditional branch outputs.
            # Just take the conditional branch outputs.
            assert t_2_output.size(0) == 2 * batch_size
            t_2_output = t_2_output[batch_size:]
        if t_output.size(0) != batch_size:
            # The cached t_output contains both batchwise-concatenated unconditional-conditional branch outputs.
            # Just take the conditional branch outputs.
            assert t_output.size(0) == 2 * batch_size
            t_output = t_output[batch_size:]
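        # Linear extrapolation from the two most recent computed attention outputs: continue the observed trend
        # `t_output - t_2_output`, scaled by `weight` (the `attention_weight_callback` value), past `t_output`.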
        return t_output + (t_output - t_2_output) * weight

    def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
        batch_size = [
            *[arg.size(0) for arg in args if torch.is_tensor(arg)],
            *[v.size(0) for v in kwargs.values() if torch.is_tensor(v)],
        ][0]
        if self.state.batch_size is None:
            # Will be updated on first forward pass through the denoiser
            self.state.batch_size = batch_size

        # If we have to skip due to the skip conditions, then let's skip as expected.
        # But we can't skip if the denoiser wants to infer both the unconditional and conditional branches. This
        # is because the expected output shapes of the attention layer will not match if we only return values from
        # the cache (which only caches conditional branch outputs). So, if state.batch_size (which is the true
        # unconditional-conditional batch size) is the same as the current batch size, we don't perform the layer
        # skip. Otherwise, we conditionally skip the layer based on the skip conditions computed below.
        is_within_timestep_range = (
            self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1]
        )

        if not is_within_timestep_range:
            should_skip_attention = False
        else:
            should_compute_attention = self.state.iteration > 0 and self.state.iteration % self.block_skip_range == 0
            should_skip_attention = not should_compute_attention

        if should_skip_attention:
            should_skip_attention = self.is_guidance_distilled or self.state.batch_size != batch_size

        if should_skip_attention:
            logger.debug("FasterCache - Skipping attention and using approximation")
            if torch.is_tensor(self.state.cache[-1]):
                t_2_output, t_output = self.state.cache
                weight = self.weight_callback(module)
                output = self._compute_approximated_attention_output(t_2_output, t_output, weight, batch_size)
            else:
                # The cache contains multiple tensors from the past N iterations (N=2 for FasterCache). We need to
                # handle all of them.
                # Diffusers blocks can return multiple tensors - let's call them [A, B, C, ...] for simplicity. In
                # our cache, we would have [[A_1, B_1, C_1, ...], [A_2, B_2, C_2, ...], ...] where each list is the
                # output from a forward pass of the block. We need to compute the approximated output for each of
                # these tensors.
                # The zip(*state.cache) operation gives us [(A_1, A_2, ...), (B_1, B_2, ...), (C_1, C_2, ...), ...],
                # which allows us to compute the approximated attention output for each tensor in the cache.
                output = ()
                for t_2_output, t_output in zip(*self.state.cache):
                    result = self._compute_approximated_attention_output(
                        t_2_output, t_output, self.weight_callback(module), batch_size
                    )
                    output += (result,)
        else:
            logger.debug("FasterCache - Computing attention")
            output = self.fn_ref.original_forward(*args, **kwargs)

        # Note that the following condition for getting hidden_states should suffice since Diffusers blocks either
        # return a single hidden_states tensor, or a tuple of (hidden_states, encoder_hidden_states) tensors. We
        # need to handle both cases.
        if torch.is_tensor(output):
            cache_output = output
            if not self.is_guidance_distilled and cache_output.size(0) == self.state.batch_size:
                # The output here can be either the unconditional-conditional branch outputs or just the conditional
                # branch outputs. This is determined at the higher-level denoiser module. We only want to cache the
                # conditional branch outputs.
                cache_output = cache_output.chunk(2, dim=0)[1]
        else:
            # Cache all return values and perform the same operation as above
            cache_output = ()
            for out in output:
                if not self.is_guidance_distilled and out.size(0) == self.state.batch_size:
                    out = out.chunk(2, dim=0)[1]
                cache_output += (out,)
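
        # Keep a sliding window of the two most recent computed outputs. On the very first computation both cache
        # slots hold the same output, so the first approximation step degenerates to re-using that output.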
        if self.state.cache is None:
            self.state.cache = [cache_output, cache_output]
        else:
            self.state.cache = [self.state.cache[-1], cache_output]

        self.state.iteration += 1
        return output

    def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
        self.state.reset()
        return module


def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> None:
    r"""
    Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given denoiser module.

    Args:
        module (`torch.nn.Module`):
            The denoiser module (for example, the transformer of a diffusion pipeline) to apply FasterCache to.
        config (`FasterCacheConfig`):
            The configuration to use for FasterCache.

    Example:

    ```python
    >>> import torch
    >>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache

    >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
    >>> pipe.to("cuda")

    >>> config = FasterCacheConfig(
    ...     spatial_attention_block_skip_range=2,
    ...     spatial_attention_timestep_skip_range=(-1, 681),
    ...     low_frequency_weight_update_timestep_range=(99, 641),
    ...     high_frequency_weight_update_timestep_range=(-1, 301),
    ...     spatial_attention_block_identifiers=["transformer_blocks"],
    ...     attention_weight_callback=lambda _: 0.3,
    ...     tensor_format="BFCHW",
    ... )
    >>> apply_faster_cache(pipe.transformer, config)
    ```
    """
    logger.warning(
        "FasterCache is a purely experimental feature and may not work as expected. Not all models support FasterCache. "
        "The API is subject to change in future releases, with no guarantee of backward compatibility. Please report any issues at "
        "https://github.com/huggingface/diffusers/issues."
    )

    if config.attention_weight_callback is None:
        # If the user has not provided a weight callback, we default to 0.5 for all timesteps.
        # In the paper, they recommend using a weight that gradually increases from 0 to 1 as inference progresses,
        # but this differs from model to model. The user must provide a weight callback if they want to use a
        # different weight function. Defaulting to 0.5 works well in practice for most cases.
        logger.warning(
            "No `attention_weight_callback` provided when enabling FasterCache. Defaulting to using a weight of 0.5 for all timesteps."
        )
        config.attention_weight_callback = lambda _: 0.5

    if config.low_frequency_weight_callback is None:
        logger.debug(
            "Low frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
        )

        def low_frequency_weight_callback(module: torch.nn.Module) -> float:
            is_within_range = (
                config.low_frequency_weight_update_timestep_range[0]
                < config.current_timestep_callback()
                < config.low_frequency_weight_update_timestep_range[1]
            )
            return config.alpha_low_frequency if is_within_range else 1.0

        config.low_frequency_weight_callback = low_frequency_weight_callback

    if config.high_frequency_weight_callback is None:
        logger.debug(
            "High frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
        )

        def high_frequency_weight_callback(module: torch.nn.Module) -> float:
            is_within_range = (
                config.high_frequency_weight_update_timestep_range[0]
                < config.current_timestep_callback()
                < config.high_frequency_weight_update_timestep_range[1]
            )
            return config.alpha_high_frequency if is_within_range else 1.0

        config.high_frequency_weight_callback = high_frequency_weight_callback

    supported_tensor_formats = ["BCFHW", "BFCHW", "BCHW"]  # TODO(aryan): Support BSC for LTX Video
    if config.tensor_format not in supported_tensor_formats:
        raise ValueError(f"`tensor_format` must be one of {supported_tensor_formats}, but got {config.tensor_format}.")

    _apply_faster_cache_on_denoiser(module, config)

    for name, submodule in module.named_modules():
        if not isinstance(submodule, _ATTENTION_CLASSES):
            continue
        if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
            _apply_faster_cache_on_attention_class(name, submodule, config)


def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCacheConfig) -> None:
    hook = FasterCacheDenoiserHook(
        config.unconditional_batch_skip_range,
        config.unconditional_batch_timestep_skip_range,
        config.tensor_format,
        config.is_guidance_distilled,
        config._unconditional_conditional_input_kwargs_identifiers,
        config.current_timestep_callback,
        config.low_frequency_weight_callback,
        config.high_frequency_weight_callback,
    )
    registry = HookRegistry.check_if_exists_or_initialize(module)
    registry.register_hook(hook, _FASTER_CACHE_DENOISER_HOOK)


def _apply_faster_cache_on_attention_class(name: str, module: Attention, config: FasterCacheConfig) -> None:
    is_spatial_self_attention = (
        any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
        and config.spatial_attention_block_skip_range is not None
        and not getattr(module, "is_cross_attention", False)
    )
    is_temporal_self_attention = (
        any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
        and config.temporal_attention_block_skip_range is not None
        and not module.is_cross_attention
    )

    block_skip_range, timestep_skip_range, block_type = None, None, None
    if is_spatial_self_attention:
        block_skip_range = config.spatial_attention_block_skip_range
        timestep_skip_range = config.spatial_attention_timestep_skip_range
        block_type = "spatial"
    elif is_temporal_self_attention:
        block_skip_range = config.temporal_attention_block_skip_range
        timestep_skip_range = config.temporal_attention_timestep_skip_range
        block_type = "temporal"

    if block_skip_range is None or timestep_skip_range is None:
        logger.debug(
            f'Unable to apply FasterCache to the selected layer: "{name}" because it does '
            f"not match any of the required criteria for spatial or temporal attention layers. Note, "
            f"however, that this layer may still be valid for other caching mechanisms such as Pyramid "
            f"Attention Broadcast. Please specify the correct block identifiers in the configuration if "
            f"FasterCache should be applied to this layer."
        )
        return

    logger.debug(f"Enabling FasterCache ({block_type}) for layer: {name}")
    hook = FasterCacheBlockHook(
        block_skip_range,
        timestep_skip_range,
        config.is_guidance_distilled,
        config.attention_weight_callback,
        config.current_timestep_callback,
    )
    registry = HookRegistry.check_if_exists_or_initialize(module)
    registry.register_hook(hook, _FASTER_CACHE_BLOCK_HOOK)


# Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/faster_cache_sample_latte.py#L127C1-L143C39
def _split_low_high_freq(x):
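    # Split a batch of 2D feature maps into low- and high-frequency components: take the 2D FFT, shift the zero
    # frequency to the center, and mask it with a circular low-pass region of radius min(H, W) // 5 around the
    # center (the complement gives the high-frequency part). Both returned tensors live in the shifted frequency
    # domain; callers invert with `ifftshift` followed by `ifft2` (see `FasterCacheDenoiserHook.new_forward`).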
    fft = torch.fft.fft2(x)
    fft_shifted = torch.fft.fftshift(fft)

    height, width = x.shape[-2:]
    radius = min(height, width) // 5

    y_grid, x_grid = torch.meshgrid(torch.arange(height), torch.arange(width))
    center_x, center_y = width // 2, height // 2
    mask = (x_grid - center_x) ** 2 + (y_grid - center_y) ** 2 <= radius**2

    low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(x.device)
    high_freq_mask = ~low_freq_mask

    low_freq_fft = fft_shifted * low_freq_mask
    high_freq_fft = fft_shifted * high_freq_mask

    return low_freq_fft, high_freq_fft