# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple, Union
import torch
from ..models.attention_processor import Attention, MochiAttention
from ..utils import logging
from .hooks import HookRegistry, ModelHook
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
_PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
_ATTENTION_CLASSES = (Attention, MochiAttention)
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
@dataclass
class PyramidAttentionBroadcastConfig:
r"""
Configuration for Pyramid Attention Broadcast.
Args:
spatial_attention_block_skip_range (`int`, *optional*, defaults to `None`):
The number of times a specific spatial attention broadcast is skipped before computing the attention states
to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
old attention states will be re-used) before computing the new attention states again.
temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
The number of times a specific temporal attention broadcast is skipped before computing the attention
states to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times
(i.e., old attention states will be re-used) before computing the new attention states again.
cross_attention_block_skip_range (`int`, *optional*, defaults to `None`):
The number of times a specific cross-attention broadcast is skipped before computing the attention states
to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
old attention states will be re-used) before computing the new attention states again.
spatial_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The range of timesteps to skip in the spatial attention layer. The attention computations will be
conditionally skipped if the current timestep is within the specified range.
temporal_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The range of timesteps to skip in the temporal attention layer. The attention computations will be
conditionally skipped if the current timestep is within the specified range.
cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The range of timesteps to skip in the cross-attention layer. The attention computations will be
conditionally skipped if the current timestep is within the specified range.
spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`):
The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
        cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
            The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
        current_timestep_callback (`Callable[[], int]`, defaults to `None`):
            A callback function that returns the current inference timestep.
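    Example (a minimal sketch; `pipe` is assumed to be a loaded diffusers pipeline exposing a `current_timestep`
    attribute, as in the `apply_pyramid_attention_broadcast` example below):

    ```python
    >>> from diffusers import PyramidAttentionBroadcastConfig

    >>> # With a skip range of 2, attention in the matching spatial blocks is recomputed on every other call
    >>> # (the cached states are broadcast once in between), but only for timesteps inside (100, 800).
    >>> config = PyramidAttentionBroadcastConfig(
    ...     spatial_attention_block_skip_range=2,
    ...     spatial_attention_timestep_skip_range=(100, 800),
    ...     current_timestep_callback=lambda: pipe.current_timestep,
    ... )
    ```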
"""
spatial_attention_block_skip_range: Optional[int] = None
temporal_attention_block_skip_range: Optional[int] = None
cross_attention_block_skip_range: Optional[int] = None
spatial_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
    current_timestep_callback: Optional[Callable[[], int]] = None
# TODO(aryan): add PAB for MLP layers (very limited speedup from testing with original codebase
# so not added for now)
def __repr__(self) -> str:
return (
f"PyramidAttentionBroadcastConfig(\n"
f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
f" cross_attention_block_skip_range={self.cross_attention_block_skip_range},\n"
f" spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n"
f" temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n"
f" cross_attention_timestep_skip_range={self.cross_attention_timestep_skip_range},\n"
f" spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n"
f" temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n"
f" cross_attention_block_identifiers={self.cross_attention_block_identifiers},\n"
f" current_timestep_callback={self.current_timestep_callback}\n"
")"
)
class PyramidAttentionBroadcastState:
r"""
State for Pyramid Attention Broadcast.
Attributes:
iteration (`int`):
The current iteration of the Pyramid Attention Broadcast. It is necessary to ensure that `reset_state` is
called before starting a new inference forward pass for PAB to work correctly.
cache (`Any`):
The cached output from the previous forward pass. This is used to re-use the attention states when the
attention computation is skipped. It is either a tensor or a tuple of tensors, depending on the module.
"""
def __init__(self) -> None:
self.iteration = 0
self.cache = None
def reset(self):
self.iteration = 0
self.cache = None
def __repr__(self):
cache_repr = ""
if self.cache is None:
cache_repr = "None"
        elif torch.is_tensor(self.cache):
            cache_repr = f"Tensor(shape={self.cache.shape}, dtype={self.cache.dtype})"
        else:
            # The cache may also hold a tuple of tensors, depending on the hooked module's output.
            cache_repr = f"Tuple(length={len(self.cache)})"
return f"PyramidAttentionBroadcastState(iteration={self.iteration}, cache={cache_repr})"
class PyramidAttentionBroadcastHook(ModelHook):
r"""A hook that applies Pyramid Attention Broadcast to a given module."""
_is_stateful = True
def __init__(
self, timestep_skip_range: Tuple[int, int], block_skip_range: int, current_timestep_callback: Callable[[], int]
) -> None:
super().__init__()
self.timestep_skip_range = timestep_skip_range
self.block_skip_range = block_skip_range
self.current_timestep_callback = current_timestep_callback
def initialize_hook(self, module):
self.state = PyramidAttentionBroadcastState()
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
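        # Broadcasting is only allowed while the current timestep lies strictly inside the configured window.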
is_within_timestep_range = (
self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1]
)
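        # Recompute attention if there is nothing cached yet, on the first iteration, outside the timestep
        # window, or on every `block_skip_range`-th call; otherwise re-use (broadcast) the cached states.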
should_compute_attention = (
self.state.cache is None
or self.state.iteration == 0
or not is_within_timestep_range
or self.state.iteration % self.block_skip_range == 0
)
if should_compute_attention:
output = self.fn_ref.original_forward(*args, **kwargs)
else:
output = self.state.cache
self.state.cache = output
self.state.iteration += 1
return output
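    # Called when the hook state is reset, e.g. before a new inference forward pass, so stale caches are cleared.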
def reset_state(self, module: torch.nn.Module) -> None:
self.state.reset()
return module
def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAttentionBroadcastConfig):
r"""
Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline.
    PAB is an attention approximation method that exploits the similarity of attention states across timesteps to
    reduce the cost of attention computation. The key observation from the paper is that attention states are highly
    similar across timesteps in the cross-attention layers, and less similar in the temporal and spatial layers. This
    allows attention computation to be skipped more frequently in the cross-attention layers than in the temporal and
    spatial layers. Applying PAB therefore speeds up inference.
Args:
module (`torch.nn.Module`):
The module to apply Pyramid Attention Broadcast to.
        config (`PyramidAttentionBroadcastConfig`):
The configuration to use for Pyramid Attention Broadcast.
Example:
```python
>>> import torch
>>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
>>> from diffusers.utils import export_to_video
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> config = PyramidAttentionBroadcastConfig(
... spatial_attention_block_skip_range=2,
... spatial_attention_timestep_skip_range=(100, 800),
... current_timestep_callback=lambda: pipe.current_timestep,
... )
>>> apply_pyramid_attention_broadcast(pipe.transformer, config)
```
"""
if config.current_timestep_callback is None:
raise ValueError(
"The `current_timestep_callback` function must be provided in the configuration to apply Pyramid Attention Broadcast."
)
if (
config.spatial_attention_block_skip_range is None
and config.temporal_attention_block_skip_range is None
and config.cross_attention_block_skip_range is None
):
logger.warning(
"Pyramid Attention Broadcast requires one or more of `spatial_attention_block_skip_range`, `temporal_attention_block_skip_range` "
"or `cross_attention_block_skip_range` parameters to be set to an integer, not `None`. Defaulting to using `spatial_attention_block_skip_range=2`. "
"To avoid this warning, please set one of the above parameters."
)
config.spatial_attention_block_skip_range = 2
for name, submodule in module.named_modules():
if not isinstance(submodule, _ATTENTION_CLASSES):
# PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
# cannot be applied to this layer. For custom layers, users can extend this functionality and implement
# their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
continue
_apply_pyramid_attention_broadcast_on_attention_class(name, submodule, config)
def _apply_pyramid_attention_broadcast_on_attention_class(
name: str, module: Attention, config: PyramidAttentionBroadcastConfig
) -> bool:
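    # Classify the attention module by matching its fully-qualified name against the configured identifiers and by
    # checking the `is_cross_attention` flag set on diffusers attention modules.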
is_spatial_self_attention = (
any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
and config.spatial_attention_block_skip_range is not None
and not getattr(module, "is_cross_attention", False)
)
is_temporal_self_attention = (
any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
and config.temporal_attention_block_skip_range is not None
and not getattr(module, "is_cross_attention", False)
)
is_cross_attention = (
any(re.search(identifier, name) is not None for identifier in config.cross_attention_block_identifiers)
and config.cross_attention_block_skip_range is not None
and getattr(module, "is_cross_attention", False)
)
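    # Select the skip parameters for whichever attention category matched above; a module matches at most one.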
block_skip_range, timestep_skip_range, block_type = None, None, None
if is_spatial_self_attention:
block_skip_range = config.spatial_attention_block_skip_range
timestep_skip_range = config.spatial_attention_timestep_skip_range
block_type = "spatial"
elif is_temporal_self_attention:
block_skip_range = config.temporal_attention_block_skip_range
timestep_skip_range = config.temporal_attention_timestep_skip_range
block_type = "temporal"
elif is_cross_attention:
block_skip_range = config.cross_attention_block_skip_range
timestep_skip_range = config.cross_attention_timestep_skip_range
block_type = "cross"
if block_skip_range is None or timestep_skip_range is None:
logger.info(
f'Unable to apply Pyramid Attention Broadcast to the selected layer: "{name}" because it does '
f"not match any of the required criteria for spatial, temporal or cross attention layers. Note, "
f"however, that this layer may still be valid for applying PAB. Please specify the correct "
f"block identifiers in the configuration."
)
return False
logger.debug(f"Enabling Pyramid Attention Broadcast ({block_type}) in layer: {name}")
_apply_pyramid_attention_broadcast_hook(
module, timestep_skip_range, block_skip_range, config.current_timestep_callback
)
return True
def _apply_pyramid_attention_broadcast_hook(
module: Union[Attention, MochiAttention],
timestep_skip_range: Tuple[int, int],
block_skip_range: int,
current_timestep_callback: Callable[[], int],
):
r"""
Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given torch.nn.Module.
Args:
module (`torch.nn.Module`):
The module to apply Pyramid Attention Broadcast to.
timestep_skip_range (`Tuple[int, int]`):
The range of timesteps to skip in the attention layer. The attention computations will be conditionally
skipped if the current timestep is within the specified range.
block_skip_range (`int`):
The number of times a specific attention broadcast is skipped before computing the attention states to
re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., old
attention states will be re-used) before computing the new attention states again.
current_timestep_callback (`Callable[[], int]`):
A callback function that returns the current inference timestep.
"""
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback)
registry.register_hook(hook, _PYRAMID_ATTENTION_BROADCAST_HOOK)