import torch

from diffusers.models.attention import CrossAttention


class MyCrossAttnProcessor:
    """Performs the same computation as the default cross-attention processor,
    but additionally stores the attention probabilities on the module so they
    can be inspected after the forward pass."""

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states)

        # fall back to self-attention when no encoder states are given
        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        # new bookkeeping to save the attn probs
        attn.attn_probs = attention_probs

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
"""
A function that prepares a U-Net model for training by enabling gradient computation
for a specified set of parameters and setting the forward pass to be performed by a
custom cross attention processor.
Parameters:
unet: A U-Net model.
Returns:
unet: The prepared U-Net model.
"""
def prep_unet(unet):
# set the gradients for XA maps to be true
for name, params in unet.named_parameters():
if 'attn2' in name:
params.requires_grad = True
else:
params.requires_grad = False
# replace the fwd function
for name, module in unet.named_modules():
module_name = type(module).__name__
if module_name == "CrossAttention":
module.set_processor(MyCrossAttnProcessor())
return unet
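

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): prepare a diffusers UNet so that
# only its cross-attention ("attn2") weights are trainable and every
# CrossAttention block stores its attention maps. The checkpoint id and the
# dummy inputs below are assumptions for demonstration, not part of this
# module.
if __name__ == "__main__":
    from diffusers import UNet2DConditionModel

    # Any Stable Diffusion v1.x checkpoint works; this id is just an example.
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"
    )
    unet = prep_unet(unet)

    # Dummy latents / timestep / text embeddings with SD v1.x shapes.
    sample = torch.randn(1, 4, 64, 64)
    timestep = torch.tensor([10])
    encoder_hidden_states = torch.randn(1, 77, 768)

    with torch.no_grad():
        _ = unet(sample, timestep, encoder_hidden_states=encoder_hidden_states)

    # After the forward pass, each attention module exposes the probabilities
    # saved by MyCrossAttnProcessor under `attn_probs`
    # (shape: [batch * heads, query_len, key_len]).
    for name, module in unet.named_modules():
        if type(module).__name__ == "CrossAttention" and hasattr(module, "attn_probs"):
            print(name, tuple(module.attn_probs.shape))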