import numpy as np
import torch
from typing import Optional, Union, Tuple, Dict
from einops import rearrange, repeat


def register_attention_control(model, controller):
    # Patch every BasicTransformerBlock in model.unet so that its hidden
    # states pass through `controller` before each attention layer.
    def block_forward(self, place_in_unet):
        def forward(
            hidden_states,
            encoder_hidden_states=None,
            timestep=None,
            attention_mask=None,
            video_length=None,
        ):
            # Self-attention: normalize, then let the controller rewrite the
            # hidden states and supply the key/value inputs before attn1.
            norm_hidden_states = (
                self.norm1(hidden_states, timestep)
                if self.use_ada_layer_norm
                else self.norm1(hidden_states)
            )
            norm_hidden_states, k_input, v_input = controller(
                norm_hidden_states, video_length, place_in_unet
            )
            if self.unet_use_cross_frame_attention:
                hidden_states = (
                    self.attn1(
                        norm_hidden_states,
                        k_input=k_input,
                        v_input=v_input,
                        attention_mask=attention_mask,
                        video_length=video_length,
                    )
                    + hidden_states
                )
            else:
                hidden_states = (
                    self.attn1(
                        norm_hidden_states,
                        k_input=k_input,
                        v_input=v_input,
                        attention_mask=attention_mask,
                    )
                    + hidden_states
                )

            if self.attn2 is not None:
                # Cross-attention over the text embeddings, again routed
                # through the controller before attn2.
                norm_hidden_states = (
                    self.norm2(hidden_states, timestep)
                    if self.use_ada_layer_norm
                    else self.norm2(hidden_states)
                )
                norm_hidden_states, _, _ = controller(
                    norm_hidden_states, video_length, place_in_unet
                )
                hidden_states = (
                    self.attn2(
                        norm_hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        attention_mask=attention_mask,
                    )
                    + hidden_states
                )

            # Feed-forward with residual connection.
            hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

            if self.unet_use_temporal_attention:
                # Temporal attention: fold the spatial dimension into the
                # batch so attention runs across frames, then restore the
                # original (batch * frames, spatial, channels) layout.
                d = hidden_states.shape[1]
                hidden_states = rearrange(
                    hidden_states, "(b f) d c -> (b d) f c", f=video_length
                )
                norm_hidden_states = (
                    self.norm_temp(hidden_states, timestep)
                    if self.use_ada_layer_norm
                    else self.norm_temp(hidden_states)
                )
                norm_hidden_states, _, _ = controller(
                    norm_hidden_states, d, place_in_unet
                )
                hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
                hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)

            return hidden_states

        return forward

    def register_recr(net_, count, place_in_unet):
        # Recursively walk the module tree and replace the forward method of
        # every BasicTransformerBlock with the controlled version above.
        if net_.__class__.__name__ == "BasicTransformerBlock":
            net_.forward = block_forward(net_, place_in_unet)
            return count + 1
        elif hasattr(net_, "children"):
            for net__ in net_.children():
                count = register_recr(net__, count, place_in_unet)
        return count

    cross_att_count = 0
    sub_nets = model.unet.named_children()
    for net in sub_nets:
        if "down" in net[0]:
            cross_att_count += register_recr(net[1], 0, "down")
        elif "up" in net[0]:
            cross_att_count += register_recr(net[1], 0, "up")
        elif "mid" in net[0]:
            cross_att_count += register_recr(net[1], 0, "mid")
    # Two controller calls are counted per registered block (self- and
    # cross-attention).
    controller.num_att_layers = cross_att_count * 2


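# Usage sketch for register_attention_control (hypothetical objects, not
# defined in this file): `model` is assumed to be a pipeline exposing
# `model.unet`, and `controller` any callable that takes
# (hidden_states, video_length, place_in_unet), returns a
# (hidden_states, k_input, v_input) tuple, and accepts a `num_att_layers`
# attribute.
#
#     controller = MyAttentionController()          # hypothetical class
#     register_attention_control(pipeline, controller)
#     # every BasicTransformerBlock in pipeline.unet now routes its normalized
#     # hidden states through `controller` before each attention layer

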
def get_word_inds(text: str, word_place: int, tokenizer):
    # Map a word (given by value or by position in the whitespace-split text)
    # to the indices of its tokens in the tokenized prompt; the +1 offset
    # accounts for the start-of-text token stripped below.
    split_text = text.split(" ")
    if type(word_place) is str:
        word_place = [i for i, word in enumerate(split_text) if word_place == word]
    elif type(word_place) is int:
        word_place = [word_place]
    out = []
    if len(word_place) > 0:
        # Decode each token separately (dropping the BOS/EOS tokens) so tokens
        # can be matched back to words by accumulated character length.
        words_encode = [
            tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)
        ][1:-1]
        cur_len, ptr = 0, 0
        for i in range(len(words_encode)):
            cur_len += len(words_encode[i])
            if ptr in word_place:
                out.append(i + 1)
            if cur_len >= len(split_text[ptr]):
                ptr += 1
                cur_len = 0
    return np.array(out)


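# Example for get_word_inds (illustrative prompt, CLIP-style tokenizer
# assumed): get_word_inds("a silver jeep driving", "jeep", tokenizer) returns
# the index (or indices) of the token(s) spelling "jeep", shifted by one for
# the start-of-text token, so they index directly into the padded token
# sequence used by the attention maps.

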
def update_alpha_time_word(
    alpha,
    bounds: Union[float, Tuple[float, float]],
    prompt_ind: int,
    word_inds: Optional[torch.Tensor] = None,
):
    # Set alpha to 1 for the diffusion steps inside `bounds` (fractions of the
    # total number of steps) and to 0 elsewhere, for the given prompt index
    # and word indices.
    if type(bounds) is float:
        bounds = 0, bounds
    start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
    if word_inds is None:
        word_inds = torch.arange(alpha.shape[2])
    alpha[:start, prompt_ind, word_inds] = 0
    alpha[start:end, prompt_ind, word_inds] = 1
    alpha[end:, prompt_ind, word_inds] = 0
    return alpha


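# Worked example for update_alpha_time_word: with alpha.shape[0] == 51
# (i.e. num_steps=50 below) and bounds=0.4, bounds becomes (0, 0.4), so
# start = int(0.0 * 51) = 0 and end = int(0.4 * 51) = 20; steps 0-19 get
# alpha 1 and steps 20-50 get alpha 0 for the selected words.

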
def get_time_words_attention_alpha(
    prompts,
    num_steps,
    cross_replace_steps: Union[
        float, Tuple[float, float], Dict[str, Tuple[float, float]]
    ],
    tokenizer,
    max_num_words=77,
):
    # Build a (steps + 1, prompts - 1, 1, 1, words) schedule deciding at which
    # diffusion steps cross-attention maps are replaced; `cross_replace_steps`
    # is either a single range applied to all words ("default_") or a dict of
    # per-word ranges.
    if type(cross_replace_steps) is not dict:
        cross_replace_steps = {"default_": cross_replace_steps}
    if "default_" not in cross_replace_steps:
        cross_replace_steps["default_"] = (0.0, 1.0)
    alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
    for i in range(len(prompts) - 1):
        alpha_time_words = update_alpha_time_word(
            alpha_time_words, cross_replace_steps["default_"], i
        )
    for key, item in cross_replace_steps.items():
        if key != "default_":
            inds = [
                get_word_inds(prompts[i], key, tokenizer)
                for i in range(1, len(prompts))
            ]
            for i, ind in enumerate(inds):
                if len(ind) > 0:
                    alpha_time_words = update_alpha_time_word(
                        alpha_time_words, item, i, ind
                    )
    alpha_time_words = alpha_time_words.reshape(
        num_steps + 1, len(prompts) - 1, 1, 1, max_num_words
    )
    return alpha_time_words


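# Usage sketch for get_time_words_attention_alpha (values are illustrative;
# the tokenizer object is assumed to come from the surrounding pipeline):
#
#     prompts = ["a silver jeep driving", "a red jeep driving"]
#     alphas = get_time_words_attention_alpha(
#         prompts, num_steps=50, cross_replace_steps={"default_": 0.8},
#         tokenizer=tokenizer,
#     )
#     # alphas.shape == (51, 1, 1, 1, 77): one schedule for the single edited
#     # prompt, broadcastable over attention heads and spatial positions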