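"""Memory-saving helpers for fine-tuning a diffusers UNet alongside a CLIP text
encoder: gradient checkpointing plus memory-efficient attention, using either
xformers or PyTorch 2.0 scaled dot product attention."""
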
import torch
import torch.nn.functional as F

from diffusers.models.attention_processor import AttnProcessor2_0
from diffusers.models.attention import BasicTransformerBlock
from diffusers.utils.import_utils import is_xformers_available
from transformers.models.clip.modeling_clip import CLIPEncoder

# Default toggles for the helpers below.
GRADIENT_CHECKPOINTING = True
TEXT_ENCODER_GRADIENT_CHECKPOINTING = True
ENABLE_XFORMERS_MEMORY_EFFICIENT_ATTENTION = True
ENABLE_TORCH_2_ATTN = True

def is_attn(name):
    # True when a module path ends in a self-attention ('attn1') or cross-attention ('attn2') layer.
    return name.split('.')[-1] in ('attn1', 'attn2')

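# Example module paths as produced by unet.named_modules(); the exact names depend on
# the UNet architecture, but attention layers inside a transformer block end in attn1/attn2:
#   is_attn("down_blocks.0.attentions.0.transformer_blocks.0.attn1")  -> True
#   is_attn("down_blocks.0.attentions.0.transformer_blocks.0.norm1")  -> False
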
def unet_and_text_g_c(unet, text_encoder, unet_enable=GRADIENT_CHECKPOINTING, text_enable=TEXT_ENCODER_GRADIENT_CHECKPOINTING):
    # Toggle gradient checkpointing through the public diffusers / transformers APIs.
    # For the text encoder this affects its CLIPEncoder layer stack.
    unet.enable_gradient_checkpointing() if unet_enable else unet.disable_gradient_checkpointing()
    text_encoder.gradient_checkpointing_enable() if text_enable else text_encoder.gradient_checkpointing_disable()

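# What gradient checkpointing does, in miniature (illustrative helper, not used above):
# torch.utils.checkpoint re-runs the wrapped forward during backward instead of storing
# intermediate activations, trading extra compute for lower peak memory.
# use_reentrant=False assumes a reasonably recent PyTorch.
def _gradient_checkpointing_example():
    import torch.utils.checkpoint as cp
    layer = torch.nn.Linear(16, 16)
    x = torch.randn(4, 16, requires_grad=True)
    y = cp.checkpoint(layer, x, use_reentrant=False)  # activations recomputed in backward
    y.sum().backward()
    return x.grad.shape  # torch.Size([4, 16])
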
def set_processors(attentions):
    # Give each attention module the PyTorch 2.0 SDPA-backed processor.
    for attn in attentions:
        attn.set_processor(AttnProcessor2_0())

def set_torch_2_attn(unet):
    # Walk the UNet and switch every BasicTransformerBlock's attention layers
    # (attn1 = self-attention, attn2 = cross-attention) to AttnProcessor2_0.
    optim_count = 0
    for _, module in unet.named_modules():
        if isinstance(module, BasicTransformerBlock):
            set_processors([m for n, m in module.named_children() if is_attn(n)])
            optim_count += 1
    if optim_count > 0:
        print(f"Enabled scaled dot product attention in {optim_count} transformer blocks.")

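# Minimal illustration of what the processor swap buys (tensor shapes here are arbitrary
# example values): attention is computed by F.scaled_dot_product_attention, which
# dispatches to fused, memory-efficient kernels on PyTorch >= 2.0. After
# set_torch_2_attn(unet), unet.attn_processors should report AttnProcessor2_0 entries.
def _sdpa_example():
    q = torch.randn(2, 8, 64, 40)  # (batch, heads, tokens, head_dim)
    k = torch.randn(2, 8, 64, 40)
    v = torch.randn(2, 8, 64, 40)
    return F.scaled_dot_product_attention(q, k, v)  # same shape as q
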
def handle_memory_attention(
    unet,
    enable_xformers_memory_efficient_attention=ENABLE_XFORMERS_MEMORY_EFFICIENT_ATTENTION,
    enable_torch_2_attn=ENABLE_TORCH_2_ATTN,
):
    # Prefer PyTorch 2.0 scaled dot product attention when available; otherwise fall back to xformers.
    try:
        is_torch_2 = hasattr(F, 'scaled_dot_product_attention')
        enable_torch_2 = is_torch_2 and enable_torch_2_attn

        if enable_xformers_memory_efficient_attention and not enable_torch_2:
            if is_xformers_available():
                from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
                unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
            else:
                raise ValueError("xformers is not available. Make sure it is installed correctly.")

        if enable_torch_2:
            set_torch_2_attn(unet)
    except Exception as e:
        print(f"Could not enable memory efficient attention for xformers or Torch 2.0: {e}")
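

# Usage sketch. The checkpoint path below is a placeholder; DiffusionPipeline.from_pretrained
# should work for any pipeline that exposes .unet and .text_encoder with the APIs used above
# (e.g. Stable Diffusion-style pipelines).
if __name__ == "__main__":
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained("path/to/your-model")  # placeholder checkpoint
    unet_and_text_g_c(pipe.unet, pipe.text_encoder)
    handle_memory_attention(pipe.unet)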