# ZePo / utils/ptp_utils.py
# Retrieved from the Hugging Face Space file view: commit a6cec16 "Add application file" by Jinl (8.9 kB).
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from typing import Optional, Union, Tuple, Dict
from PIL import Image
from . import merge
from .utils import isinstance_str, init_generator
def save_images(images, dest, num_rows=1, offset_ratio=0.02):
    """Save the last image of ``images`` to ``dest`` with PIL.

    Args:
        images: a list/tuple of images, an array with ``ndim == 4`` whose first
            axis indexes images, or a single image array (any other ``ndim``).
            The selected image must be accepted by ``PIL.Image.fromarray``.
        dest: destination passed to ``PIL.Image.Image.save``.
        num_rows: accepted for interface compatibility; not used.
        offset_ratio: accepted for interface compatibility; not used.
    """
    # Sequences and 4-D stacks hold several images; only the last one is written.
    if isinstance(images, (list, tuple)) or images.ndim == 4:
        last = images[-1]
    else:
        # A single image: save it as-is.
        last = images
    Image.fromarray(last).save(dest)
def save_image(images, dest, num_rows=1, offset_ratio=0.02):
    """Save the first image of ``images`` to ``dest`` with PIL.

    Args:
        images: any indexable collection (list, tuple, array, ...); ``images[0]``
            must be accepted by ``PIL.Image.fromarray``.
        dest: destination passed to ``PIL.Image.Image.save``.
        num_rows: accepted for interface compatibility; not used.
        offset_ratio: accepted for interface compatibility; not used.
    """
    # No debug print of ``images.shape``: it cluttered stdout and failed for lists.
    Image.fromarray(images[0]).save(dest)
def register_attention_control(model, controller, tome, ratio, sx, sy, de_bug):
    """Install a controllable attention processor on the UNet of ``model``.

    Every module whose class is named ``Attention`` inside a child of
    ``model.unet`` whose name contains ``"down"``, ``"up"`` or ``"mid"`` gets its
    ``processor`` attribute replaced by an ``AttnProcessor``. Cross-attention
    (``encoder_hidden_states`` given) is computed without modification. On
    self-attention, ``controller`` is called and its result may overwrite part
    of the attention output.

    Args:
        model: object exposing ``.unet`` (a module providing ``named_children()``).
        controller: callable invoked on every self-attention call as
            ``controller(q, k, v, num_heads, attention_probs, attn)``. It returns
            None (no edit), a tensor, or a pair (see ``AttnProcessor.__call__``).
            ``controller.num_att_layers`` is set to the number of processors installed.
        tome: when truthy, the keys/values handed to ``controller`` are computed
            from a token-merged copy of the layer input.
        ratio: fraction of tokens to merge, ``r = int(num_tokens * ratio)``.
        sx, sy: forwarded unchanged to ``merge.bipartite_soft_matching_random2d``.
        de_bug: stored on each processor; not otherwise used.
    """

    class AttnProcessor():
        """Attention processor that exposes self-attention to ``controller``."""

        def __init__(self, place_in_unet, de_bug):
            self.place_in_unet = place_in_unet  # "down", "mid" or "up"
            self.de_bug = de_bug

        def __call__(self,
                     attn,
                     hidden_states,
                     encoder_hidden_states=None,
                     attention_mask=None,
                     temb=None,
                     scale=1.0,):
            # ``scale`` is accepted for call compatibility but is unused here.
            residual = hidden_states
            if attn.spatial_norm is not None:
                hidden_states = attn.spatial_norm(hidden_states, temb)

            input_ndim = hidden_states.ndim
            if input_ndim == 4:
                # (B, C, H, W) feature map -> (B, H*W, C) token sequence.
                batch_size, channel, height, width = hidden_states.shape
                hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

            is_cross = encoder_hidden_states is not None
            if encoder_hidden_states is None:
                encoder_hidden_states = hidden_states
            elif attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

            # encoder_hidden_states is never None past this point.
            batch_size, sequence_length, _ = encoder_hidden_states.shape
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

            q = attn.to_q(hidden_states)
            k = attn.to_k(encoder_hidden_states)
            v = attn.to_v(encoder_hidden_states)
            q = attn.head_to_batch_dim(q)
            k = attn.head_to_batch_dim(k)
            v = attn.head_to_batch_dim(v)

            attention_probs = attn.get_attention_scores(q, k, attention_mask)
            x = hidden_states  # layer input tokens, reused for token merging below
            hidden_states = torch.bmm(attention_probs, v)

            if not is_cross:
                if tome:
                    # Keys/values for the controller come from a merged (shorter) token set.
                    # NOTE(review): H = W assumes a square token grid -- confirm.
                    r = int(x.shape[1] * ratio)
                    H = W = int(np.sqrt(x.shape[1]))
                    generator = init_generator(x.device)
                    m, _ = merge.bipartite_soft_matching_random2d(x, W, H, sx, sy, r,
                                                                  no_rand=False, generator=generator)
                    x = m(x)
                    m_k = attn.to_k(x)
                    m_v = attn.to_v(x)
                    m_k = attn.head_to_batch_dim(m_k)
                    m_v = attn.head_to_batch_dim(m_v)
                else:
                    m_k, m_v = k, v

                h_s_re = controller(q, m_k, m_v, attn.heads, attention_probs, attn)
                if h_s_re is not None:
                    # Rows are grouped per sample in blocks of ``attn.heads``.
                    if hidden_states.shape[0] // attn.heads == 3:
                        # Three samples: the controller output replaces the third.
                        hidden_states[2 * attn.heads:] = h_s_re
                    else:
                        # Otherwise the controller returns a pair, presumably the
                        # unconditional / conditional halves of a 6-sample batch -- confirm.
                        u_h_s_re, c_h_s_re = h_s_re
                        if u_h_s_re is not None:
                            hidden_states[2 * attn.heads:3 * attn.heads] = u_h_s_re
                        # NOTE(review): the source's indentation was lost; this write is
                        # applied whether or not u_h_s_re is None -- confirm intent.
                        hidden_states[5 * attn.heads:] = c_h_s_re

            hidden_states = attn.batch_to_head_dim(hidden_states)
            hidden_states = attn.to_out[0](hidden_states)  # linear projection
            hidden_states = attn.to_out[1](hidden_states)  # dropout
            if input_ndim == 4:
                hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
            if attn.residual_connection:
                hidden_states = hidden_states + residual
            hidden_states = hidden_states / attn.rescale_output_factor
            return hidden_states

    def register_recr(net_, count, place_in_unet):
        """Install ``AttnProcessor`` on every ``Attention`` module under ``net_``; return the updated count."""
        for module in net_.modules():
            if module.__class__.__name__ == "Attention":
                count += 1
                module.processor = AttnProcessor(place_in_unet, de_bug)
        return count

    cross_att_count = 0
    for name, net in model.unet.named_children():
        if "down" in name:
            cross_att_count += register_recr(net, 0, "down")
        elif "up" in name:
            cross_att_count += register_recr(net, 0, "up")
        elif "mid" in name:
            cross_att_count += register_recr(net, 0, "mid")
    # Counts every processor installed (self- and cross-attention alike).
    controller.num_att_layers = cross_att_count
def get_word_inds(text: str, word_place: Union[str, int, list], tokenizer):
    """Return the token positions covering selected words of ``text``.

    Words are obtained with ``text.split(" ")``. Each token is attributed to a
    word by accumulating the lengths of the decoded tokens until they reach
    the current word's length.

    Args:
        text: the prompt.
        word_place: a word (every occurrence is selected), a single word index
            (Python or NumPy integer), or a list of word indices.
        tokenizer: object with ``encode(text)`` returning token ids and
            ``decode(ids)`` returning a string. The first and last encoded ids
            are dropped; presumably these are BOS/EOS markers -- confirm for
            the tokenizer in use.

    Returns:
        ``np.ndarray`` of token positions. They are offset by 1, so they index
        the full encoded sequence including the dropped leading id.
    """
    split_text = text.split(" ")
    if isinstance(word_place, str):
        word_place = [i for i, word in enumerate(split_text) if word == word_place]
    elif isinstance(word_place, (int, np.integer)):
        word_place = [word_place]
    out = []
    if len(word_place) > 0:
        # "#" is stripped so WordPiece-style "##" continuation markers do not count as characters.
        words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
        cur_len, ptr = 0, 0
        for i, piece in enumerate(words_encode):
            cur_len += len(piece)
            if ptr in word_place:
                out.append(i + 1)
            # Once the pieces span the whole current word, move on to the next word.
            if cur_len >= len(split_text[ptr]):
                ptr += 1
                cur_len = 0
    return np.array(out)
def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int,
                           word_inds: Optional[torch.Tensor] = None):
    """Mark a step window of ``alpha`` as active (1) for one prompt and set the rest to 0.

    Args:
        alpha: tensor indexed as ``alpha[step, prompt, word]``; modified in place.
        bounds: ``(start, end)`` as fractions of ``alpha.shape[0]``, or a single
            number ``end`` meaning ``(0, end)``. Python and NumPy ints/floats
            are accepted as a single number.
        prompt_ind: index along the second axis to update.
        word_inds: positions along the third axis to update; all of them when None.

    Returns:
        ``alpha`` itself (the same tensor, returned for chaining).
    """
    # isinstance rather than ``type(...) is float`` so ``1`` or a NumPy scalar
    # is treated as a bare end point instead of failing on ``bounds[0]``.
    if isinstance(bounds, (int, float, np.integer, np.floating)):
        bounds = 0, bounds
    num_steps = alpha.shape[0]
    start, end = int(bounds[0] * num_steps), int(bounds[1] * num_steps)
    if word_inds is None:
        word_inds = torch.arange(alpha.shape[2])
    alpha[:start, prompt_ind, word_inds] = 0
    alpha[start:end, prompt_ind, word_inds] = 1
    alpha[end:, prompt_ind, word_inds] = 0
    return alpha
def get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
                                   tokenizer, max_num_words=77):
    """Build the per-step, per-word replacement mask for ``prompts[1:]``.

    Args:
        prompts: list of prompts. ``prompts[0]`` is not searched; one mask row
            is built for each of ``prompts[1:]``.
        num_steps: number of steps; the mask has ``num_steps + 1`` step rows.
        cross_replace_steps: either bounds applied to every word (a number or a
            ``(start, end)`` pair of fractions), or a dict mapping words to
            bounds. In the dict form the optional ``"default_"`` key gives the
            bounds for all words and defaults to ``(0., 1.)``. The caller's
            dict is never modified.
        tokenizer: forwarded to ``get_word_inds`` to locate dict-keyed words.
        max_num_words: size of the word axis.

    Returns:
        Tensor of shape ``(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)``
        (time, batch, heads, pixels, words).
    """
    if isinstance(cross_replace_steps, dict):
        # Copy so adding "default_" below never leaks into the caller's dict.
        cross_replace_steps = dict(cross_replace_steps)
    else:
        cross_replace_steps = {"default_": cross_replace_steps}
    cross_replace_steps.setdefault("default_", (0., 1.))

    num_edits = len(prompts) - 1
    alpha_time_words = torch.zeros(num_steps + 1, num_edits, max_num_words)
    for i in range(num_edits):
        alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], i)
    # Word-specific bounds override the default for the tokens of that word.
    for key, item in cross_replace_steps.items():
        if key == "default_":
            continue
        inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
        for i, ind in enumerate(inds):
            if len(ind) > 0:
                alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
    return alpha_time_words.reshape(num_steps + 1, num_edits, 1, 1, max_num_words)