Spaces: Running on Zero
| from diffusers.utils import logging | |
| logger = logging.get_logger(__name__) | |
class TransformerDiffusionMixin:
    r"""
    Helper mixin for `DiffusionPipeline` subclasses built around a ``vae`` and a
    ``transformer`` (mainly for DiT-style pipelines).

    Expects the host class to expose ``self.vae`` and ``self.transformer``
    attributes; every method here simply delegates to them. The attributes
    ``fusing_transformer`` / ``fusing_vae`` track whether QKV fusion is active.
    """

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    def fuse_qkv_projections(self, transformer: bool = True, vae: bool = True):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        Args:
            transformer (`bool`, defaults to `True`): To apply fusion on the Transformer.
            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
        """
        # Reset both flags first so a repeated call reflects only the most
        # recent fusion request, not stale state from an earlier call.
        self.fusing_transformer = False
        self.fusing_vae = False

        if transformer:
            self.fusing_transformer = True
            self.transformer.fuse_qkv_projections()

        if vae:
            self.fusing_vae = True
            self.vae.fuse_qkv_projections()

    def unfuse_qkv_projections(self, transformer: bool = True, vae: bool = True):
        """Disable QKV projection fusion if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        Args:
            transformer (`bool`, defaults to `True`): Whether to unfuse the Transformer's QKV projections.
            vae (`bool`, defaults to `True`): Whether to unfuse the VAE's QKV projections.
        """
        if transformer:
            # getattr guards against calling unfuse before fuse was ever
            # invoked (the flags are only created by fuse_qkv_projections).
            if not getattr(self, "fusing_transformer", False):
                logger.warning(
                    "The Transformer was not initially fused for QKV projections. Doing nothing."
                )
            else:
                self.transformer.unfuse_qkv_projections()
                self.fusing_transformer = False

        if vae:
            if not getattr(self, "fusing_vae", False):
                logger.warning(
                    "The VAE was not initially fused for QKV projections. Doing nothing."
                )
            else:
                self.vae.unfuse_qkv_projections()
                self.fusing_vae = False