# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image


def save_images(images, dest, num_rows=1, offset_ratio=0.02):
    """Save the last image of a list/batch to `dest`.

    Note: a grid layout is not implemented; `num_rows` and `offset_ratio` are
    kept for API compatibility, but only the last image is written.
    """
    if type(images) is list:
        num_empty = len(images) % num_rows
    elif images.ndim == 4:
        num_empty = images.shape[0] % num_rows
    else:
        images = [images]
        num_empty = 0
    pil_img = Image.fromarray(images[-1])
    pil_img.save(dest)
    # display(pil_img)


def save_image(images, dest, num_rows=1, offset_ratio=0.02):
    """Save the first image of a batch to `dest`."""
    print(images.shape)
    pil_img = Image.fromarray(images[0])
    pil_img.save(dest)


def register_attention_control(model, controller):
    """Replace the processor of every `Attention` module in `model.unet` with
    one that routes attention through `controller`."""

    class AttnProcessor:
        def __init__(self, place_in_unet):
            self.place_in_unet = place_in_unet

        def __call__(
            self,
            attn,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            temb=None,
            scale=1.0,
        ):
            # The `Attention` class can call different attention processors / attention functions.
            residual = hidden_states

            if attn.spatial_norm is not None:
                hidden_states = attn.spatial_norm(hidden_states, temb)

            input_ndim = hidden_states.ndim

            if input_ndim == 4:
                batch_size, channel, height, width = hidden_states.shape
                hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

            is_cross = encoder_hidden_states is not None

            if encoder_hidden_states is None:
                encoder_hidden_states = hidden_states
            elif attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

            batch_size, sequence_length, _ = (
                hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
            )
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

            q = attn.to_q(hidden_states)
            k = attn.to_k(encoder_hidden_states)
            v = attn.to_v(encoder_hidden_states)

            q = attn.head_to_batch_dim(q)
            k = attn.head_to_batch_dim(k)
            v = attn.head_to_batch_dim(v)

            # Let the controller rewrite queries/keys/values for self-attention layers.
            if not is_cross:
                q, k, v = controller.self_attn_forward(q, k, v, attn.heads)

            attention_probs = attn.get_attention_scores(q, k, attention_mask)
            # Let the controller inspect/edit the cross-attention maps.
            if is_cross:
                attention_probs = controller(attention_probs, is_cross, self.place_in_unet)

            hidden_states = torch.bmm(attention_probs, v)
            hidden_states = attn.batch_to_head_dim(hidden_states)

            # linear proj
            hidden_states = attn.to_out[0](hidden_states, scale=scale)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)

            if input_ndim == 4:
                hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

            if attn.residual_connection:
                hidden_states = hidden_states + residual

            hidden_states = hidden_states / attn.rescale_output_factor

            return hidden_states

    def register_recr(net_, count, place_in_unet):
        # Attach a fresh processor to every `Attention` module under `net_`.
        for m in net_.modules():
            if m.__class__.__name__ == "Attention":
                count += 1
                m.processor = AttnProcessor(place_in_unet)
        return count

    cross_att_count = 0
    sub_nets = model.unet.named_children()
    for net in sub_nets:
        if "down" in net[0]:
            cross_att_count += register_recr(net[1], 0, "down")
        elif "up" in net[0]:
            cross_att_count += register_recr(net[1], 0, "up")
        elif "mid" in net[0]:
            cross_att_count += register_recr(net[1], 0, "mid")

    controller.num_att_layers = cross_att_count


def get_word_inds(text: str, word_place: Union[int, str], tokenizer):
    """Map a word in `text` (given by position or by string) to the indices of
    its tokens in the encoded prompt (offset by 1 for the start-of-text token)."""
    split_text = text.split(" ")
    if type(word_place) is str:
        word_place = [i for i, word in enumerate(split_text) if word_place == word]
    elif type(word_place) is int:
        word_place = [word_place]
    out = []
    if len(word_place) > 0:
        words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
        cur_len, ptr = 0, 0
        for i in range(len(words_encode)):
            cur_len += len(words_encode[i])
            if ptr in word_place:
                out.append(i + 1)
            if cur_len >= len(split_text[ptr]):
                ptr += 1
                cur_len = 0
    return np.array(out)


def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int,
                           word_inds: Optional[torch.Tensor] = None):
    """Set `alpha` to 1 for the diffusion steps inside `bounds` (given as
    fractions of the total number of steps) and to 0 elsewhere."""
    if type(bounds) is float:
        bounds = 0, bounds
    start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
    if word_inds is None:
        word_inds = torch.arange(alpha.shape[2])
    alpha[:start, prompt_ind, word_inds] = 0
    alpha[start:end, prompt_ind, word_inds] = 1
    alpha[end:, prompt_ind, word_inds] = 0
    return alpha


def get_time_words_attention_alpha(prompts, num_steps,
                                   cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
                                   tokenizer, max_num_words=77):
    """Build per-step, per-token blending weights for cross-attention
    replacement, with optional word-specific schedules."""
    if type(cross_replace_steps) is not dict:
        cross_replace_steps = {"default_": cross_replace_steps}
    if "default_" not in cross_replace_steps:
        cross_replace_steps["default_"] = (0., 1.)
    alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
    for i in range(len(prompts) - 1):
        alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], i)
    for key, item in cross_replace_steps.items():
        if key != "default_":
            inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
            for i, ind in enumerate(inds):
                if len(ind) > 0:
                    alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
    # time, batch, heads, pixels, words
    alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
    return alpha_time_words
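

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It illustrates the
# controller interface that `register_attention_control` assumes: an object
# with a `num_att_layers` attribute, a `__call__(attention_probs, is_cross,
# place_in_unet)` hook for cross-attention maps, and a
# `self_attn_forward(q, k, v, num_heads)` hook for self-attention. The
# pipeline class and checkpoint name below are assumptions for illustration;
# any diffusers pipeline exposing `.unet` and `.tokenizer` should fit.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from diffusers import StableDiffusionPipeline

    class NoOpController:
        """Pass-through controller: records nothing and edits nothing."""

        def __init__(self):
            self.num_att_layers = 0  # filled in by register_attention_control

        def __call__(self, attention_probs, is_cross, place_in_unet):
            # Hook for cross-attention maps; return them unchanged.
            return attention_probs

        def self_attn_forward(self, q, k, v, num_heads):
            # Hook for self-attention queries/keys/values; return them unchanged.
            return q, k, v

    # Assumed checkpoint, used only as an example.
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    controller = NoOpController()
    register_attention_control(pipe, controller)
    print(f"hooked {controller.num_att_layers} attention layers")

    # Per-step, per-token blending weights for cross-attention replacement.
    prompts = ["a cat sitting on a chair", "a dog sitting on a chair"]
    alphas = get_time_words_attention_alpha(prompts, num_steps=50,
                                            cross_replace_steps=0.8,
                                            tokenizer=pipe.tokenizer)
    print(alphas.shape)  # torch.Size([51, 1, 1, 1, 77])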