# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Dict, List, Tuple, Union

import torch
import torch.nn as nn

from ...models.attention_processor import (
    Attention,
    AttentionProcessor,
    PAGCFGIdentitySelfAttnProcessor2_0,
    PAGIdentitySelfAttnProcessor2_0,
)
from ...utils import logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class PAGMixin:
    r"""Mixin class for [Perturbed Attention Guidance](https://huggingface.co/papers/2403.17377v1)."""

    def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance):
        r"""
        Set the attention processor for the PAG layers.
        """
        pag_attn_processors = self._pag_attn_processors
        if pag_attn_processors is None:
            raise ValueError(
                "No PAG attention processors have been set. Set the attention processors by calling `set_pag_applied_layers` and passing the relevant parameters."
            )

        pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1]

        if hasattr(self, "unet"):
            model: nn.Module = self.unet
        else:
            model: nn.Module = self.transformer

        def is_self_attn(module: nn.Module) -> bool:
            r"""
            Check if the module is a self-attention module, based on its class and `is_cross_attention` flag.
            """
            return isinstance(module, Attention) and not module.is_cross_attention

        def is_fake_integral_match(layer_id, name):
            layer_id = layer_id.split(".")[-1]
            name = name.split(".")[-1]
            return layer_id.isnumeric() and name.isnumeric() and layer_id == name

        for layer_id in pag_applied_layers:
            # For each PAG layer input, find the corresponding self-attention layers in the denoiser model.
            target_modules = []

            for name, module in model.named_modules():
                # Identify the following simple cases:
                #   (1) the module is a self-attention layer,
                #   (2) the module name matches the PAG layer id, at least partially,
                #   (3) the match is not a fake integral match when the layer_id ends with a number.
                # For example, blocks.1 and blocks.10 must stay distinguishable when layer_id="blocks.1".
                if (
                    is_self_attn(module)
                    and re.search(layer_id, name) is not None
                    and not is_fake_integral_match(layer_id, name)
                ):
                    logger.debug(f"Applying PAG to layer: {name}")
                    target_modules.append(module)

            if len(target_modules) == 0:
                raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}")

            for module in target_modules:
                module.processor = pag_attn_proc
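
    # As a hedged illustration of the matching above (module names depend on the
    # actual denoiser architecture): on an SDXL-style UNet, the spec "mid" would
    # match self-attention modules such as
    # "mid_block.attentions.0.transformer_blocks.0.attn1", while a spec like
    # "up_blocks.0.attentions.0" pins the search down to a single block.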

    def _get_pag_scale(self, t):
        r"""
        Get the scale factor for the perturbed attention guidance at timestep `t`.
        """
        if self.do_pag_adaptive_scaling:
            signal_scale = self.pag_scale - self.pag_adaptive_scale * (1000 - t)
            if signal_scale < 0:
                signal_scale = 0
            return signal_scale
        else:
            return self.pag_scale
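
    # A worked example of the adaptive schedule above (values are illustrative,
    # not defaults): with pag_scale=3.0 and pag_adaptive_scale=0.01, the
    # effective scale is 3.0 at t=1000, decays linearly to 2.0 at t=900, and is
    # clamped to 0 once t drops below 1000 - 3.0 / 0.01 = 700.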

    def _apply_perturbed_attention_guidance(
        self, noise_pred, do_classifier_free_guidance, guidance_scale, t, return_pred_text=False
    ):
        r"""
        Apply perturbed attention guidance to the noise prediction.

        Args:
            noise_pred (torch.Tensor): The noise prediction tensor.
            do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.
            guidance_scale (float): The scale factor for the guidance term.
            t (int): The current time step.
            return_pred_text (bool): Whether to also return the text noise prediction.

        Returns:
            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: The updated noise prediction tensor after
            applying perturbed attention guidance, and optionally the text noise prediction.
        """
        pag_scale = self._get_pag_scale(t)
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
            noise_pred = (
                noise_pred_uncond
                + guidance_scale * (noise_pred_text - noise_pred_uncond)
                + pag_scale * (noise_pred_text - noise_pred_perturb)
            )
        else:
            noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
            noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
        if return_pred_text:
            return noise_pred, noise_pred_text
        return noise_pred
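
    # Note on the expected batch layout (a sketch, inferred from the chunking
    # above): with CFG enabled, the denoiser output stacks the unconditional,
    # text-conditioned, and perturbed predictions along the batch dimension,
    #   noise_pred = torch.cat([pred_uncond, pred_text, pred_perturb], dim=0)
    # so `.chunk(3)` recovers the three branches; without CFG, only the text
    # and perturbed branches are stacked and `.chunk(2)` is used.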

    def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):
        """
        Prepares the perturbed attention guidance for the PAG model.

        Args:
            cond (torch.Tensor): The conditional input tensor.
            uncond (torch.Tensor): The unconditional input tensor.
            do_classifier_free_guidance (bool): Flag indicating whether to perform classifier-free guidance.

        Returns:
            torch.Tensor: The prepared perturbed attention guidance tensor.
        """
        cond = torch.cat([cond] * 2, dim=0)

        if do_classifier_free_guidance:
            cond = torch.cat([uncond, cond], dim=0)
        return cond
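
    # Shape sketch for the helper above: a conditional embedding of shape
    # (B, ...) is duplicated to (2B, ...) so the text and perturbed branches
    # share it; with CFG enabled the unconditional embedding is prepended,
    # giving (3B, ...) ordered as [uncond, cond, cond].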

    def set_pag_applied_layers(
        self,
        pag_applied_layers: Union[str, List[str]],
        pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = (
            PAGCFGIdentitySelfAttnProcessor2_0(),
            PAGIdentitySelfAttnProcessor2_0(),
        ),
    ):
        r"""
        Set the self-attention layers to apply PAG. Raise ValueError if the input is invalid.

        Args:
            pag_applied_layers (`str` or `List[str]`):
                One or more strings identifying the layer names, or a simple regex for matching multiple layers,
                where PAG is to be applied. A few examples of expected usage are as follows:
                  - A single layer specified as - "blocks.{layer_index}"
                  - Multiple layers as a list - ["blocks.{layer_index_1}", "blocks.{layer_index_2}", ...]
                  - Multiple layers as a block name - "mid"
                  - Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})"
            pag_attn_processors:
                (`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(),
                PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention
                processor is for PAG with classifier-free guidance enabled (conditional and unconditional). The
                second attention processor is for PAG with CFG disabled (unconditional only).
        """

        if not hasattr(self, "_pag_attn_processors"):
            self._pag_attn_processors = None

        if not isinstance(pag_applied_layers, list):
            pag_applied_layers = [pag_applied_layers]
        if pag_attn_processors is not None:
            if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2:
                raise ValueError("Expected a tuple of two attention processors")

        for i in range(len(pag_applied_layers)):
            if not isinstance(pag_applied_layers[i], str):
                raise ValueError(
                    f"Expected either a string or a list of strings but got type {type(pag_applied_layers[i])}"
                )

        self.pag_applied_layers = pag_applied_layers
        self._pag_attn_processors = pag_attn_processors
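
    # Hedged usage sketch: a PAG-enabled pipeline typically calls this from its
    # `__init__`, e.g. `self.set_pag_applied_layers(pag_applied_layers)`.
    # Passing "blocks.(4|8)" would restrict PAG to the self-attention layers of
    # transformer blocks 4 and 8 only.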

    @property
    def pag_scale(self) -> float:
        r"""Get the scale factor for the perturbed attention guidance."""
        return self._pag_scale

    @property
    def pag_adaptive_scale(self) -> float:
        r"""Get the adaptive scale factor for the perturbed attention guidance."""
        return self._pag_adaptive_scale

    @property
    def do_pag_adaptive_scaling(self) -> bool:
        r"""Check if the adaptive scaling is enabled for the perturbed attention guidance."""
        return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0

    @property
    def do_perturbed_attention_guidance(self) -> bool:
        r"""Check if the perturbed attention guidance is enabled."""
        return self._pag_scale > 0 and len(self.pag_applied_layers) > 0

    @property
    def pag_attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of PAG attention processors: A dictionary containing all PAG attention processors used in the
            model, keyed by the layer name.
        """

        if self._pag_attn_processors is None:
            return {}

        valid_attn_processors = {x.__class__ for x in self._pag_attn_processors}

        processors = {}
        # We could have iterated through `self.components.items()` and checked if a component is
        # a `ModelMixin` subclass, but that can include a VAE too.
        if hasattr(self, "unet"):
            denoiser_module = self.unet
        elif hasattr(self, "transformer"):
            denoiser_module = self.transformer
        else:
            raise ValueError("No denoiser module found.")

        for name, proc in denoiser_module.attn_processors.items():
            if proc.__class__ in valid_attn_processors:
                processors[name] = proc

        return processors
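

# A minimal end-to-end usage sketch (not part of this module; the checkpoint id
# and layer spec below are illustrative assumptions). Pipelines that inherit
# `PAGMixin` expose PAG through `pag_applied_layers` at load time and
# `pag_scale` at call time, e.g. via `AutoPipelineForText2Image`:
#
#     from diffusers import AutoPipelineForText2Image
#
#     pipe = AutoPipelineForText2Image.from_pretrained(
#         "stabilityai/stable-diffusion-xl-base-1.0",
#         enable_pag=True,
#         pag_applied_layers=["mid"],  # matched against self-attention module names
#     )
#     image = pipe("an astronaut riding a horse", pag_scale=3.0).images[0]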