# Source: UniCtrl / ptp_utils.py (uploaded by Xuweiyi, commit ee823b7 "Upload 94 files",
# verified; raw file size 7.23 kB).
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from typing import Optional, Union, Tuple, Dict
from einops import rearrange, repeat
def register_attention_control(model, controller):
    """Route every transformer block of ``model.unet`` through ``controller``.

    The children of ``model.unet`` whose names contain ``"down"``, ``"up"`` or
    ``"mid"`` are searched recursively, and each module whose class is named
    ``BasicTransformerBlock`` gets its ``forward`` attribute replaced. The
    replacement hands the normalized hidden states to ``controller`` before
    every attention call. Finally ``controller.num_att_layers`` is set to twice
    the number of patched blocks.

    Args:
        model: Object exposing ``unet.named_children()``.
        controller: Callable invoked as ``controller(states, length, place)``
            that returns a 3-tuple ``(states, k_input, v_input)``. The
            attribute ``num_att_layers`` is assigned on it.
    """

    def block_forward(self, place_in_unet):
        """Build the replacement ``forward`` for the block ``self``."""

        def apply_norm(norm_layer, states, timestep):
            # Adaptive layer norms additionally take the timestep.
            if self.use_ada_layer_norm:
                return norm_layer(states, timestep)
            return norm_layer(states)

        def forward(
            hidden_states,
            encoder_hidden_states=None,
            timestep=None,
            attention_mask=None,
            video_length=None,
        ):
            # SparseCausal-Attention: the controller supplies the key/value
            # inputs that are forwarded to attn1.
            normed = apply_norm(self.norm1, hidden_states, timestep)
            normed, k_input, v_input = controller(
                normed, video_length, place_in_unet
            )
            attn1_kwargs = {
                "k_input": k_input,
                "v_input": v_input,
                "attention_mask": attention_mask,
            }
            if self.unet_use_cross_frame_attention:
                # Only the cross-frame variant is given the video length.
                attn1_kwargs["video_length"] = video_length
            hidden_states = self.attn1(normed, **attn1_kwargs) + hidden_states

            if self.attn2 is not None:
                # Cross-Attention: the controller's key/value outputs are
                # discarded here.
                normed = apply_norm(self.norm2, hidden_states, timestep)
                normed, _, _ = controller(normed, video_length, place_in_unet)
                attended = self.attn2(
                    normed,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                )
                hidden_states = attended + hidden_states

            # Feed-forward
            hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

            # Temporal-Attention
            if self.unet_use_temporal_attention:
                tokens = hidden_states.shape[1]
                # Fold the token axis into the batch so attention runs over frames.
                hidden_states = rearrange(
                    hidden_states, "(b f) d c -> (b d) f c", f=video_length
                )
                normed = apply_norm(self.norm_temp, hidden_states, timestep)
                # The token count is passed in the slot where the other calls
                # pass video_length.
                normed, _, _ = controller(normed, tokens, place_in_unet)
                hidden_states = self.attn_temp(normed) + hidden_states
                hidden_states = rearrange(
                    hidden_states, "(b d) f c -> (b f) d c", d=tokens
                )

            return hidden_states

        return forward

    def register_recr(net_, count, place_in_unet):
        """Patch ``net_`` or its descendants; return the updated block count."""
        if net_.__class__.__name__ == "BasicTransformerBlock":
            net_.forward = block_forward(net_, place_in_unet)
            return count + 1
        if hasattr(net_, "children"):
            for child in net_.children():
                count = register_recr(child, count, place_in_unet)
        return count

    patched_blocks = 0
    for child_name, child_net in model.unet.named_children():
        # First matching tag wins, checked in the order down / up / mid.
        for place in ("down", "up", "mid"):
            if place in child_name:
                patched_blocks += register_recr(child_net, 0, place)
                break
    controller.num_att_layers = patched_blocks * 2
def get_word_inds(text: str, word_place: Union[int, str], tokenizer):
    """Return the token positions that make up selected words of ``text``.

    ``text`` is split on single spaces. Each token produced by
    ``tokenizer.encode(text)`` (first and last token dropped) is decoded on its
    own, stripped of ``#`` characters, and assigned to the current word until
    the accumulated piece length reaches that word's length.

    Args:
        text: The prompt to inspect.
        word_place: A word (every occurrence is selected), a single word index
            into ``text.split(" ")``, or an already-built sequence of word
            indices, which is used as given.
        tokenizer: Object providing ``encode(text)`` and ``decode(ids)``.

    Returns:
        Integer ``np.ndarray`` of token positions. Positions are offset by one
        to account for the dropped leading token. The array is empty, but
        still integer-typed, when nothing matches.
    """
    split_text = text.split(" ")
    if isinstance(word_place, str):
        word_place = [i for i, word in enumerate(split_text) if word == word_place]
    elif isinstance(word_place, (int, np.integer)):
        word_place = [word_place]

    out = []
    # len() rather than truthiness: callers may pass an index array.
    if len(word_place) > 0:
        pieces = [
            tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)
        ][1:-1]
        cur_len, ptr = 0, 0
        for i, piece in enumerate(pieces):
            if ptr >= len(split_text):
                # More pieces than the words account for (e.g. padding
                # tokens): nothing left to align, so stop instead of
                # indexing past the end of split_text.
                break
            cur_len += len(piece)
            if ptr in word_place:
                out.append(i + 1)
            if cur_len >= len(split_text[ptr]):
                ptr += 1
                cur_len = 0
    # An explicit integer dtype keeps an empty result usable as an index array
    # (np.array([]) would default to float64).
    return np.array(out, dtype=int)
def update_alpha_time_word(
    alpha,
    bounds: Union[float, Tuple[float, float]],
    prompt_ind: int,
    word_inds: Optional[torch.Tensor] = None,
):
    """Set a step window of ``alpha`` to 1 (and the rest to 0) for some words.

    ``alpha`` is indexed as ``alpha[step, prompt, word]`` and is modified in
    place. For the given prompt and words, steps inside ``[start, end)`` are
    set to 1 and every other step is set to 0. Here ``start`` and ``end`` are
    the bounds scaled by ``alpha.shape[0]`` and truncated to ``int``.

    Args:
        alpha: Tensor to update in place.
        bounds: ``(start_fraction, end_fraction)``, or a single number that is
            treated as ``(0, number)``.
        prompt_ind: Index along the second axis of ``alpha``.
        word_inds: Indices along the third axis; every index when ``None``.

    Returns:
        The same ``alpha`` object that was passed in.
    """
    # Accept ints as well as floats: a bare ``1`` means "all steps" and would
    # otherwise fail on the ``bounds[0]`` lookup below.
    if isinstance(bounds, (int, float)):
        bounds = 0, bounds
    total_steps = alpha.shape[0]
    start = int(bounds[0] * total_steps)
    end = int(bounds[1] * total_steps)
    if word_inds is None:
        word_inds = torch.arange(alpha.shape[2])
    alpha[:start, prompt_ind, word_inds] = 0
    alpha[start:end, prompt_ind, word_inds] = 1
    alpha[end:, prompt_ind, word_inds] = 0
    return alpha
def get_time_words_attention_alpha(
    prompts,
    num_steps,
    cross_replace_steps: Union[
        float, Tuple[float, float], Dict[str, Tuple[float, float]]
    ],
    tokenizer,
    max_num_words=77,
):
    """Build the per-step, per-word 0/1 schedule for every prompt but the first.

    Args:
        prompts: Sequence of prompt strings. Row ``i`` of the result belongs to
            ``prompts[i + 1]``.
        num_steps: The result has ``num_steps + 1`` entries on its first axis.
        cross_replace_steps: Either bounds applied to all words, or a dict
            mapping a word to its own bounds. The key ``"default_"`` holds the
            bounds for all other words and defaults to ``(0.0, 1.0)``. The
            caller's dict is not modified.
        tokenizer: Passed through to ``get_word_inds``.
        max_num_words: Size of the word axis.

    Returns:
        Tensor of shape ``(num_steps + 1, len(prompts) - 1, 1, 1,
        max_num_words)``.
    """
    if isinstance(cross_replace_steps, dict):
        # Copy so the "default_" entry added below never leaks into the
        # caller's dict.
        cross_replace_steps = dict(cross_replace_steps)
    else:
        cross_replace_steps = {"default_": cross_replace_steps}
    cross_replace_steps.setdefault("default_", (0.0, 1.0))

    num_targets = len(prompts) - 1
    alpha_time_words = torch.zeros(num_steps + 1, num_targets, max_num_words)

    # Start every word of every target prompt on the default schedule.
    for i in range(num_targets):
        alpha_time_words = update_alpha_time_word(
            alpha_time_words, cross_replace_steps["default_"], i
        )

    # Then override the schedule for each explicitly listed word.
    for key, item in cross_replace_steps.items():
        if key == "default_":
            continue
        inds = [
            get_word_inds(prompts[i], key, tokenizer)
            for i in range(1, len(prompts))
        ]
        for i, ind in enumerate(inds):
            if len(ind) > 0:
                alpha_time_words = update_alpha_time_word(
                    alpha_time_words, item, i, ind
                )

    # time, batch, heads, pixels, words
    return alpha_time_words.reshape(
        num_steps + 1, num_targets, 1, 1, max_num_words
    )