from __future__ import annotations import os import cv2 import abc from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch import torch.nn.functional as F from diffusers.models.attention import Attention from PIL import Image import random import matplotlib.pyplot as plt import pdb import math from PIL import Image class P2PCrossAttnProcessor: def __init__(self, controller, place_in_unet): super().__init__() self.controller = controller self.place_in_unet = place_in_unet def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) query = attn.to_q(hidden_states) is_cross = encoder_hidden_states is not None encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) # one line change self.controller(attention_probs, is_cross, self.place_in_unet) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) return hidden_states def create_controller( prompts: List[str], cross_attention_kwargs: Dict, num_inference_steps: int, tokenizer, device, attn_res ) -> AttentionControl: edit_type = cross_attention_kwargs.get("edit_type", None) local_blend_words = cross_attention_kwargs.get("local_blend_words", None) equalizer_words = cross_attention_kwargs.get("equalizer_words", None) equalizer_strengths = cross_attention_kwargs.get("equalizer_strengths", None) n_cross_replace = 
cross_attention_kwargs.get("n_cross_replace", 0.4) n_self_replace = cross_attention_kwargs.get("n_self_replace", 0.4) if edit_type == 'visualize': return AttentionStore(device=device) # only replace if edit_type == "replace" and local_blend_words is None: return AttentionReplace( prompts, num_inference_steps, n_cross_replace, n_self_replace, tokenizer=tokenizer, device=device, attn_res=attn_res ) # replace + localblend if edit_type == "replace" and local_blend_words is not None: lb = LocalBlend(prompts, local_blend_words, tokenizer=tokenizer, device=device, attn_res=attn_res) return AttentionReplace( prompts, num_inference_steps, n_cross_replace, n_self_replace, lb, tokenizer=tokenizer, device=device, attn_res=attn_res ) # only refine if edit_type == "refine" and local_blend_words is None: return AttentionRefine( prompts, num_inference_steps, n_cross_replace, n_self_replace, tokenizer=tokenizer, device=device, attn_res=attn_res ) # refine + localblend if edit_type == "refine" and local_blend_words is not None: lb = LocalBlend(prompts, local_blend_words, tokenizer=tokenizer, device=device, attn_res=attn_res) return AttentionRefine( prompts, num_inference_steps, n_cross_replace, n_self_replace, lb, tokenizer=tokenizer, device=device, attn_res=attn_res ) # only reweight if edit_type == "reweight" and local_blend_words is None: assert ( equalizer_words is not None and equalizer_strengths is not None ), "To use reweight edit, please specify equalizer_words and equalizer_strengths." assert len(equalizer_words) == len( equalizer_strengths ), "equalizer_words and equalizer_strengths must be of same length." 
equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer) return AttentionReweight( prompts, num_inference_steps, n_cross_replace, n_self_replace, tokenizer=tokenizer, device=device, equalizer=equalizer, attn_res=attn_res, ) # reweight and localblend if edit_type == "reweight" and local_blend_words: assert ( equalizer_words is not None and equalizer_strengths is not None ), "To use reweight edit, please specify equalizer_words and equalizer_strengths." assert len(equalizer_words) == len( equalizer_strengths ), "equalizer_words and equalizer_strengths must be of same length." equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer) lb = LocalBlend(prompts, local_blend_words, tokenizer=tokenizer, device=device, attn_res=attn_res) return AttentionReweight( prompts, num_inference_steps, n_cross_replace, n_self_replace, tokenizer=tokenizer, device=device, equalizer=equalizer, attn_res=attn_res, local_blend=lb, ) raise ValueError(f"Edit type {edit_type} not recognized. 
Use one of: replace, refine, reweight.") class AttentionControl(abc.ABC): def step_callback(self, x_t): return x_t def between_steps(self): return @property def num_uncond_att_layers(self): return 0 @abc.abstractmethod def forward(self, attn, is_cross: bool, place_in_unet: str): raise NotImplementedError def __call__(self, attn, is_cross: bool, place_in_unet: str): if self.cur_att_layer >= self.num_uncond_att_layers: h = attn.shape[0] attn[h // 2 :] = self.forward(attn[h // 2 :], is_cross, place_in_unet) self.cur_att_layer += 1 if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: self.cur_att_layer = 0 self.cur_step += 1 self.between_steps() return attn def reset(self): self.cur_step = 0 self.cur_att_layer = 0 def __init__(self, attn_res=None): self.cur_step = 0 self.num_att_layers = -1 self.cur_att_layer = 0 self.attn_res = attn_res class EmptyControl(AttentionControl): def forward(self, attn, is_cross: bool, place_in_unet: str): return attn class AttentionStore(AttentionControl): @staticmethod def get_empty_store(): return {"down_cross": [], "mid_cross": [], "up_cross": [], "down_self": [], "mid_self": [], "up_self": []} def forward(self, attn, is_cross: bool, place_in_unet: str): key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" if attn.shape[1] <= 32**2: # avoid memory overhead if self.device.type != 'cuda': attn = attn.cpu() self.step_store[key].append(attn) return attn def between_steps(self): if len(self.attention_store) == 0: self.attention_store = self.step_store else: for key in self.attention_store: for i in range(len(self.attention_store[key])): self.attention_store[key][i] += self.step_store[key][i] self.step_store = self.get_empty_store() def get_average_attention(self): average_attention = { key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store } return average_attention def reset(self): super(AttentionStore, self).reset() self.step_store = self.get_empty_store() 
self.attention_store = {} def __init__(self, attn_res=None, device='cuda'): super(AttentionStore, self).__init__(attn_res) self.step_store = self.get_empty_store() self.attention_store = {} self.device = device class LocalBlend: def __call__(self, x_t, attention_store): # note that this code works on the latent level! k = 1 # maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3] # These are the numbers because we want to take layers that are 256 x 256, I think this can be changed to something smarter...like, get all attentions where thesecond dim is self.attn_res[0] * self.attn_res[1] in up and down cross. maps = [m for m in attention_store["down_cross"] + attention_store["mid_cross"] + attention_store["up_cross"] if m.shape[1] == self.attn_res[0] * self.attn_res[1]] maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, self.attn_res[0], self.attn_res[1], self.max_num_words) for item in maps] maps = torch.cat(maps, dim=1) maps = (maps * self.alpha_layers).sum(-1).mean(1) # since alpha_layers is all 0s except where we edit, the product zeroes out all but what we change. Then, the sum adds the values of the original and what we edit. Then, we average across dim=1, which is the number of layers. mask = F.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k)) mask = F.interpolate(mask, size=(x_t.shape[2:])) mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0] mask = mask.gt(self.threshold) mask = mask[:1] + mask[1:] mask = mask.to(torch.float16) x_t = x_t[:1] + mask * (x_t - x_t[:1]) # x_t[:1] is the original image. mask*(x_t - x_t[:1]) zeroes out the original image and removes the difference between the original and each image we are generating (mostly just one). Then, it applies the mask on the image. That is, it's only keeping the cells we want to generate. 
return x_t def __init__( self, prompts: List[str], words: [List[List[str]]], tokenizer, device, threshold=0.3, attn_res=None ): self.max_num_words = 77 self.attn_res = attn_res alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, self.max_num_words) for i, (prompt, words_) in enumerate(zip(prompts, words)): if isinstance(words_, str): words_ = [words_] for word in words_: ind = get_word_inds(prompt, word, tokenizer) alpha_layers[i, :, :, :, :, ind] = 1 self.alpha_layers = alpha_layers.to(device) # a one-hot vector where the 1s are the words we modify (source and target) self.threshold = threshold class AttentionControlEdit(AttentionStore, abc.ABC): def step_callback(self, x_t): if self.local_blend is not None: x_t = self.local_blend(x_t, self.attention_store) return x_t def replace_self_attention(self, attn_base, att_replace): if att_replace.shape[2] <= self.attn_res[0]**2: return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape) else: return att_replace @abc.abstractmethod def replace_cross_attention(self, attn_base, att_replace): raise NotImplementedError def forward(self, attn, is_cross: bool, place_in_unet: str): super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet) if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]): h = attn.shape[0] // (self.batch_size) attn = attn.reshape(self.batch_size, h, *attn.shape[1:]) attn_base, attn_replace = attn[0], attn[1:] if is_cross: alpha_words = self.cross_replace_alpha[self.cur_step] attn_replace_new = ( self.replace_cross_attention(attn_base, attn_replace) * alpha_words + (1 - alpha_words) * attn_replace ) attn[1:] = attn_replace_new else: attn[1:] = self.replace_self_attention(attn_base, attn_replace) attn = attn.reshape(self.batch_size * h, *attn.shape[2:]) return attn def __init__( self, prompts, num_steps: int, cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]], self_replace_steps: Union[float, Tuple[float, 
float]], local_blend: Optional[LocalBlend], tokenizer, device, attn_res=None, ): super(AttentionControlEdit, self).__init__(attn_res=attn_res) # add tokenizer and device here self.tokenizer = tokenizer self.device = device self.batch_size = len(prompts) self.cross_replace_alpha = get_time_words_attention_alpha( prompts, num_steps, cross_replace_steps, self.tokenizer ).to(self.device) if isinstance(self_replace_steps, float): self_replace_steps = 0, self_replace_steps self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1]) self.local_blend = local_blend class AttentionReplace(AttentionControlEdit): def replace_cross_attention(self, attn_base, att_replace): return torch.einsum("hpw,bwn->bhpn", attn_base, self.mapper) def __init__( self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, local_blend: Optional[LocalBlend] = None, tokenizer=None, device=None, attn_res=None, ): super(AttentionReplace, self).__init__( prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device, attn_res ) self.mapper = get_replacement_mapper(prompts, self.tokenizer).to(self.device) class AttentionRefine(AttentionControlEdit): def replace_cross_attention(self, attn_base, att_replace): attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3) attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas) return attn_replace def __init__( self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, local_blend: Optional[LocalBlend] = None, tokenizer=None, device=None, attn_res=None ): super(AttentionRefine, self).__init__( prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device, attn_res ) self.mapper, alphas = get_refinement_mapper(prompts, self.tokenizer) self.mapper, alphas = self.mapper.to(self.device), alphas.to(self.device) self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1]) 
class AttentionReweight(AttentionControlEdit):
    """Edit that scales the cross attention of selected tokens.

    Args:
        equalizer: Per-token multipliers, indexed as ``equalizer[:, token]``.
        controller: Optional previous edit whose ``replace_cross_attention``
            is applied to the base attention first.
    """

    def __init__(
        self,
        prompts,
        num_steps: int,
        cross_replace_steps: float,
        self_replace_steps: float,
        equalizer,
        local_blend: Optional[LocalBlend] = None,
        controller: Optional[AttentionControlEdit] = None,
        tokenizer=None,
        device=None,
        attn_res=None,
    ):
        super(AttentionReweight, self).__init__(
            prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device, attn_res
        )
        self.equalizer = equalizer.to(self.device)
        self.prev_controller = controller

    def replace_cross_attention(self, attn_base, att_replace):
        """Return the base attention multiplied token-wise by the equalizer."""
        if self.prev_controller is not None:
            attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace)
        return attn_base[None, :, :, :] * self.equalizer[:, None, None, :]


### util functions for all Edits
def update_alpha_time_word(
    alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, word_inds: Optional[torch.Tensor] = None
):
    """Set ``alpha`` to 1 inside a step window and 0 outside it, in place.

    Args:
        alpha: Tensor indexed as ``alpha[step, prompt, word]``.
        bounds: Window as fractions of ``alpha.shape[0]``; a single number
            ``b`` means ``(0, b)``.
        prompt_ind: Index along the second dimension to update.
        word_inds: Word positions to update; all words when ``None``.

    Returns:
        The same ``alpha`` tensor.
    """
    # ints are accepted too: ``bounds=1`` used to crash on ``bounds[0]``
    if isinstance(bounds, (int, float)):
        bounds = 0, bounds
    start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
    if word_inds is None:
        word_inds = torch.arange(alpha.shape[2])
    alpha[:start, prompt_ind, word_inds] = 0
    alpha[start:end, prompt_ind, word_inds] = 1
    alpha[end:, prompt_ind, word_inds] = 0
    return alpha


def get_time_words_attention_alpha(
    prompts, num_steps, cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], tokenizer, max_num_words=77
):
    """Build the per-step, per-word cross-attention replacement weights.

    ``cross_replace_steps`` is either one bounds value applied to every word
    or a dict mapping words to bounds, with ``"default_"`` covering all other
    words (``(0.0, 1.0)`` when absent).

    Returns:
        Tensor of shape ``(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)``.
    """
    if isinstance(cross_replace_steps, dict):
        # work on a copy so the caller's dict does not gain a "default_" key
        cross_replace_steps = dict(cross_replace_steps)
    else:
        cross_replace_steps = {"default_": cross_replace_steps}
    cross_replace_steps.setdefault("default_", (0.0, 1.0))

    num_edits = len(prompts) - 1
    alpha_time_words = torch.zeros(num_steps + 1, num_edits, max_num_words)
    for i in range(num_edits):
        alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], i)
    for key, item in cross_replace_steps.items():
        if key == "default_":
            continue
        inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
        for i, ind in enumerate(inds):
            if len(ind) > 0:
                alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
    return alpha_time_words.reshape(num_steps + 1, num_edits, 1, 1, max_num_words)


### util functions for LocalBlend and ReplacementEdit
def get_word_inds(text: str, word_place: Union[int, str], tokenizer):
    """Return the token positions covered by a word of ``text``.

    Args:
        text: Prompt, split into words on single spaces.
        word_place: A word (all of its occurrences are used) or the index of
            a word in the split prompt.
        tokenizer: Object providing ``encode`` and ``decode``.

    Returns:
        ``np.ndarray`` of positions into the encoded prompt. The first and
        last encoded tokens are skipped and positions are offset by one
        accordingly.
    """
    split_text = text.split(" ")
    if isinstance(word_place, str):
        word_place = [i for i, word in enumerate(split_text) if word_place == word]
    elif isinstance(word_place, int):
        word_place = [word_place]
    out = []
    if len(word_place) > 0:
        words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
        cur_len, ptr = 0, 0

        # Walk the tokens, advancing ``ptr`` to the next word once the decoded
        # pieces add up to the length of the current word.
        for i in range(len(words_encode)):
            cur_len += len(words_encode[i])
            if ptr in word_place:
                out.append(i + 1)
            if cur_len >= len(split_text[ptr]):
                ptr += 1
                cur_len = 0
    return np.array(out)


### util functions for ReplacementEdit
def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
    """Build the token-to-token mapper between two prompts of equal word count.

    Returns:
        ``(max_len, max_len)`` float16 tensor.

    Raises:
        ValueError: if the prompts differ in their number of words.
    """
    words_x = x.split(" ")
    words_y = y.split(" ")
    if len(words_x) != len(words_y):
        raise ValueError(
            f"attention replacement edit can only be applied on prompts with the same length"
            f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words."
        )
    inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
    inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace]
    inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace]
    mapper = np.zeros((max_len, max_len))
    i = j = 0
    cur_inds = 0
    while i < max_len and j < max_len:
        if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i:
            inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds]
            if len(inds_source_) == len(inds_target_):
                mapper[inds_source_, inds_target_] = 1
            else:
                # spread each source token evenly over the target tokens
                ratio = 1 / len(inds_target_)
                for i_t in inds_target_:
                    mapper[inds_source_, i_t] = ratio
            cur_inds += 1
            i += len(inds_source_)
            j += len(inds_target_)
        elif cur_inds < len(inds_source):
            mapper[i, j] = 1
            i += 1
            j += 1
        else:
            mapper[j, j] = 1
            i += 1
            j += 1

    return torch.from_numpy(mapper).to(torch.float16)


def get_replacement_mapper(prompts, tokenizer, max_len=77):
    """Stack the replacement mappers from ``prompts[0]`` to every other prompt."""
    x_seq = prompts[0]
    mappers = [get_replacement_mapper_(x_seq, prompt, tokenizer, max_len) for prompt in prompts[1:]]
    return torch.stack(mappers)


### util functions for ReweightEdit
def get_equalizer(
    text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float], Tuple[float, ...]], tokenizer
):
    """Build the per-token multipliers used by the reweight edit.

    Args:
        text: Prompt in which the words are looked up.
        word_select: Word, word index, or a sequence of them.
        values: One multiplier per entry of ``word_select``.
        tokenizer: Tokenizer handed to ``get_word_inds``.

    Returns:
        ``(len(values), 77)`` tensor of ones with ``values[i]`` written at
        the token positions of the i-th selected word, in every row.
    """
    if isinstance(word_select, (int, str)):
        word_select = (word_select,)
    equalizer = torch.ones(len(values), 77)
    values = torch.tensor(values, dtype=torch.float32)
    for i, word in enumerate(word_select):
        inds = get_word_inds(text, word, tokenizer)
        # values[i] is already a float32 scalar tensor; assign it directly
        # instead of wrapping it in the legacy torch.FloatTensor constructor.
        equalizer[:, inds] = values[i]
    return equalizer


### util functions for RefinementEdit
class ScoreParams:
    """Scores used by :func:`global_align`: gap, match and mismatch."""

    def __init__(self, gap, match, mismatch):
        self.gap = gap
        self.match = match
        self.mismatch = mismatch

    def mis_match_char(self, x, y):
        """Return the match score when ``x == y``, else the mismatch score."""
        return self.match if x == y else self.mismatch


def get_matrix(size_x, size_y, gap):
    """Return the ``(size_x + 1, size_y + 1)`` score matrix with gap-initialised borders."""
    matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
    matrix[0, 1:] = (np.arange(size_y) + 1) * gap
    matrix[1:, 0] = (np.arange(size_x) + 1) * gap
    return matrix


def get_traceback_matrix(size_x, size_y):
    """Return the initial traceback matrix.

    Codes: 1 = move left, 2 = move up, 3 = diagonal, 4 = origin.
    """
    matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
    matrix[0, 1:] = 1
    matrix[1:, 0] = 2
    matrix[0, 0] = 4
    return matrix


def global_align(x, y, score):
    """Fill the score and traceback matrices for a global alignment of ``x`` and ``y``.

    Ties are resolved in the order left, up, diagonal.
    """
    matrix = get_matrix(len(x), len(y), score.gap)
    trace_back = get_traceback_matrix(len(x), len(y))
    for i in range(1, len(x) + 1):
        for j in range(1, len(y) + 1):
            left = matrix[i, j - 1] + score.gap
            up = matrix[i - 1, j] + score.gap
            diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1])
            matrix[i, j] = max(left, up, diag)
            if matrix[i, j] == left:
                trace_back[i, j] = 1
            elif matrix[i, j] == up:
                trace_back[i, j] = 2
            else:
                trace_back[i, j] = 3
    return matrix, trace_back


def get_aligned_sequences(x, y, trace_back):
    """Walk the traceback matrix from the bottom-right corner.

    Returns:
        ``(x_seq, y_seq, mapper_y_to_x)``. The two sequences are collected
        end-to-start with ``"-"`` marking gaps. ``mapper_y_to_x`` is an int64
        tensor of ``(y_index, x_index)`` pairs in forward order, with
        ``x_index == -1`` for elements of ``y`` that have no partner in ``x``.
    """
    x_seq = []
    y_seq = []
    i = len(x)
    j = len(y)
    mapper_y_to_x = []
    while i > 0 or j > 0:
        step = trace_back[i, j]
        if step == 3:
            x_seq.append(x[i - 1])
            y_seq.append(y[j - 1])
            i = i - 1
            j = j - 1
            mapper_y_to_x.append((j, i))
        elif step == 1:
            x_seq.append("-")
            y_seq.append(y[j - 1])
            j = j - 1
            mapper_y_to_x.append((j, -1))
        elif step == 2:
            x_seq.append(x[i - 1])
            y_seq.append("-")
            i = i - 1
        elif step == 4:
            break
    mapper_y_to_x.reverse()
    return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64)


def get_mapper(x: str, y: str, tokenizer, max_len=77):
    """Align the tokens of ``y`` with those of ``x``.

    Returns:
        ``(mapper, alphas)``, both of length ``max_len``. ``mapper[k]`` is
        the index in ``x`` aligned with token ``k`` of ``y`` (``-1`` when
        there is none) and ``alphas[k]`` is 1 where such a partner exists,
        0 otherwise. Positions beyond the alignment are filled with
        consecutive indices starting at ``len(y_seq)`` and alpha 1.
    """
    x_seq = tokenizer.encode(x)
    y_seq = tokenizer.encode(y)
    score = ScoreParams(0, 1, -1)
    matrix, trace_back = global_align(x_seq, y_seq, score)
    mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1]
    alphas = torch.ones(max_len)
    alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float()
    mapper = torch.zeros(max_len, dtype=torch.int64)
    mapper[: mapper_base.shape[0]] = mapper_base[:, 1]
    mapper[mapper_base.shape[0] :] = len(y_seq) + torch.arange(max_len - len(y_seq))
    return mapper, alphas


def get_refinement_mapper(prompts, tokenizer, max_len=77):
    """Stack the mappers and alphas from ``prompts[0]`` to every other prompt."""
    x_seq = prompts[0]
    mappers, alphas = [], []
    for prompt in prompts[1:]:
        mapper, alpha = get_mapper(x_seq, prompt, tokenizer, max_len)
        mappers.append(mapper)
        alphas.append(alpha)
    return torch.stack(mappers), torch.stack(alphas)


def _aggregate_attention_maps(
    batch_size: int,
    attention_store: AttentionStore,
    map_height: int,
    map_width: int,
    from_where: List[str],
    is_cross: bool,
    select: int,
):
    """Average the stored maps of one resolution for one sample of the batch.

    ``map_height`` / ``map_width`` are given in attention-map pixels. Maps
    whose second dimension differs from ``map_height * map_width`` are skipped.

    Raises:
        ValueError: if no stored map has the requested resolution.
    """
    attention_maps = attention_store.get_average_attention()
    num_pixels = map_height * map_width
    kind = "cross" if is_cross else "self"
    out = []
    for location in from_where:
        for item in attention_maps[f"{location}_{kind}"]:
            if item.shape[1] == num_pixels:
                # Reshape rows first, then columns. The previous
                # (width, height) order gave the same result for square maps
                # only. Assumes the pixel axis is flattened row-major.
                maps = item.reshape(batch_size, -1, map_height, map_width, item.shape[-1])[select]
                out.append(maps)
    if not out:
        raise ValueError(
            f"No stored {kind} attention maps with {num_pixels} pixels found in {list(from_where)}."
        )
    out = torch.cat(out, dim=0)
    out = out.sum(0) / out.shape[0]
    return out.cpu()


def aggregate_attention(prompts, attention_store: AttentionStore, height: int, width: int, from_where: List[str], is_cross: bool, select: int):
    """Average the stored attention maps of size ``(height // 32, width // 32)``.

    Args:
        prompts: Prompts of the batch; only their count is used.
        attention_store: Store providing ``get_average_attention()``.
        height, width: Sizes that are divided by 32 to obtain the attention
            map resolution.
        from_where: Locations to read, e.g. ``"down"``, ``"mid"``, ``"up"``.
        is_cross: Select cross attention (``True``) or self attention.
        select: Index of the sample whose maps are returned.

    Returns:
        CPU tensor of shape ``(height // 32, width // 32, last_dim)``.
    """
    return _aggregate_attention_maps(
        len(prompts), attention_store, height // 32, width // 32, from_where, is_cross, select
    )


def show_cross_attention(tokenizer, prompts, attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0, t=0):
    """Save one labelled cross-attention image per token of ``prompts[select]``.

    ``res`` is taken to be the side length of the (square) attention maps,
    matching its use in ``show_self_attention_comp`` (NOTE(review): confirm
    against callers). The previous call into ``aggregate_attention`` passed
    too few arguments and always raised ``TypeError``.
    """
    tokens = tokenizer.encode(prompts[select])
    decoder = tokenizer.decode
    attention_maps = _aggregate_attention_maps(len(prompts), attention_store, res, res, from_where, True, select)
    images = []
    for i in range(len(tokens)):
        image = attention_maps[:, :, i]
        image = 255 * image / image.max()
        image = image.unsqueeze(-1).expand(*image.shape, 3)
        image = image.numpy().astype(np.uint8)
        image = np.array(Image.fromarray(image).resize((256, 256)))
        image = text_under_image(image, decoder(int(tokens[i])))
        images.append(image)
    view_images(np.stack(images, axis=0), t=t, from_where=from_where)


def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str], max_com=10, select: int = 0, prompts: Optional[List[str]] = None):
    """Save the leading SVD components of the averaged self-attention maps.

    Args:
        attention_store: Store holding the self-attention maps.
        res: Side length of the (square) attention maps.
        from_where: Locations to read.
        max_com: Number of components to render.
        select: Index of the sample to visualise.
        prompts: Prompts of the batch; only their count is used. When omitted
            the stored maps are treated as a single-sample batch, so
            ``select`` must then be 0.
    """
    batch_size = 1 if prompts is None else len(prompts)
    attention_maps = _aggregate_attention_maps(batch_size, attention_store, res, res, from_where, False, select)
    # np.linalg.svd has no float16 kernel, so promote before converting
    attention_maps = attention_maps.to(torch.float32).numpy().reshape((res ** 2, res ** 2))
    u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
    images = []
    for i in range(max_com):
        image = vh[i].reshape(res, res)
        image = image - image.min()
        image = 255 * image / image.max()
        image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
        image = Image.fromarray(image).resize((256, 256))
        image = np.array(image)
        images.append(image)
    view_images(np.concatenate(images, axis=1), from_where=from_where)


def view_images(images, num_rows=1, offset_ratio=0.02, t=0, from_where: Optional[List[str]] = None):
    """Tile ``images`` into a grid and save it as ``./visualization/<name>/<t>.png``.

    Args:
        images: A list/tuple of ``(h, w, c)`` arrays, a 4-D array of images,
            or a single 3-D image.
        num_rows: Number of grid rows; white images pad the last row.
        offset_ratio: Gap between tiles as a fraction of the tile height.
        t: File name stem of the saved image.
        from_where: Location names joined with ``"_"`` to form the
            sub-directory; a plain string is used as is. When ``None`` the
            image is written directly into ``./visualization``.
    """
    if not isinstance(images, (list, tuple)) and images.ndim != 4:
        images = [images]
    # number of blank tiles needed to complete the last row
    num_empty = -len(images) % num_rows

    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)

    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    image_ = np.ones(
        (h * num_rows + offset * (num_rows - 1), w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8
    ) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            top = i * (h + offset)
            left = j * (w + offset)
            image_[top: top + h, left: left + w] = images[i * num_cols + j]

    pil_img = Image.fromarray(image_)

    if from_where is None:
        save_path = './visualization'
    else:
        # Join every list (including one-element lists, which used to be
        # interpolated as "['up']"); leave a plain string untouched.
        if not isinstance(from_where, str):
            from_where = '_'.join(from_where)
        save_path = f'./visualization/{from_where}'
    # also creates ./visualization itself when it does not exist yet
    os.makedirs(save_path, exist_ok=True)
    pil_img.save(f"{save_path}/{t}.png")


def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
    """Return ``image`` with a white strip (20% of its height) holding ``text``."""
    h, w, c = image.shape
    offset = int(h * .2)
    img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
    font = cv2.FONT_HERSHEY_SIMPLEX
    img[:h] = image
    textsize = cv2.getTextSize(text, font, 1, 2)[0]
    # centre horizontally, place inside the strip below the image
    text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
    cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2)
    return img


def get_views(height, width, window_size=32, stride=16, random_jitter=False):
    """Enumerate sliding-window views covering a ``height`` x ``width`` area.

    Args:
        height, width: Size of the area to cover.
        window_size: Side length of each square window.
        stride: Step between windows; may be a float.
        random_jitter: Shift every window randomly. All coordinates are then
            additionally offset by ``jitter_range`` so they index into an
            area padded by that amount on each side.

    Returns:
        List of ``(h_start, h_end, w_start, w_end)`` integer tuples.
    """
    num_blocks_height = int((height - window_size) / stride - 1e-6) + 2 if height > window_size else 1
    num_blocks_width = int((width - window_size) / stride - 1e-6) + 2 if width > window_size else 1
    total_num_blocks = int(num_blocks_height * num_blocks_width)
    # random.randint needs integers; a float stride would otherwise make
    # jitter_range a float, which recent Python versions reject.
    jitter_range = int((window_size - stride) // 4)

    views = []
    for i in range(total_num_blocks):
        h_start = int((i // num_blocks_width) * stride)
        h_end = h_start + window_size
        w_start = int((i % num_blocks_width) * stride)
        w_end = w_start + window_size

        # pull windows that stick out back inside the area
        if h_end > height:
            h_start = int(h_start + height - h_end)
            h_end = int(height)
        if w_end > width:
            w_start = int(w_start + width - w_end)
            w_end = int(width)
        if h_start < 0:
            h_end = int(h_end - h_start)
            h_start = 0
        if w_start < 0:
            w_end = int(w_end - w_start)
            w_start = 0

        if random_jitter:
            w_jitter = 0
            h_jitter = 0
            if (w_start != 0) and (w_end != width):
                w_jitter = random.randint(-jitter_range, jitter_range)
            elif (w_start == 0) and (w_end != width):
                w_jitter = random.randint(-jitter_range, 0)
            elif (w_start != 0) and (w_end == width):
                w_jitter = random.randint(0, jitter_range)

            if (h_start != 0) and (h_end != height):
                h_jitter = random.randint(-jitter_range, jitter_range)
            elif (h_start == 0) and (h_end != height):
                h_jitter = random.randint(-jitter_range, 0)
            elif (h_start != 0) and (h_end == height):
                h_jitter = random.randint(0, jitter_range)

            h_start += (h_jitter + jitter_range)
            h_end += (h_jitter + jitter_range)
            w_start += (w_jitter + jitter_range)
            w_end += (w_jitter + jitter_range)

        views.append((int(h_start), int(h_end), int(w_start), int(w_end)))
    return views


def get_multidiffusion_prompts(tokenizer, prompts, threthod, attention_store:AttentionStore, height:int, width:int, from_where: List[str], scale_num=4, random_jitter=False):
    """Derive one prompt per view and scale from the cross-attention maps.

    For each token of ``prompts[0]`` a binary mask of above-average attention
    is built and cleaned with erosion followed by dilation. For every scale
    in ``2..scale_num`` the masks are upsampled, views are generated with
    ``get_views``, and a token's word is kept for a view when the masked
    share of the view reaches ``threthod``.

    Returns:
        ``(prompt_dict, view_dict)``, both keyed by scale: the prompts (one
        string per view) and the views returned by ``get_views``.
    """
    tokens = tokenizer.encode(prompts[0])
    decoder = tokenizer.decode

    # get cross_attention_maps
    attention_maps = aggregate_attention(prompts, attention_store, height, width, from_where, True, 0)

    # get high attention regions
    kernel = np.ones((3, 3), np.uint8)
    masks = []
    for i in range(len(tokens)):
        attention_map = attention_maps[:, :, i].to(torch.float32)
        mask = torch.where(attention_map > attention_map.mean(), 1, 0).numpy().astype(np.uint8)
        mask = mask * 255
        # erosion followed by dilation removes small isolated regions
        iterations = mask.shape[0] // 16
        eroded_mask = cv2.erode(mask, kernel, iterations=iterations)
        dilated_mask = cv2.dilate(eroded_mask, kernel, iterations=iterations)
        masks.append(dilated_mask)

    # dict for prompts and views
    prompt_dict = {}
    view_dict = {}

    # masks are (rows, cols) == (height, width) arrays
    ori_h, ori_w = masks[0].shape
    window_size = max(ori_h, ori_w)
    stride = window_size / 2
    # must equal the jitter offset get_views adds to the view coordinates
    jitter_range = int((window_size - stride) // 4)

    for scale in range(2, scale_num + 1):
        # current height and width
        cur_w = ori_w * scale
        cur_h = ori_h * scale
        views = get_views(height=cur_h, width=cur_w, window_size=window_size, stride=stride, random_jitter=random_jitter)

        words_in_patch = []
        for token_idx, token_mask in enumerate(masks):
            # skip the first and last token masks (begin/end-of-text)
            if token_idx == 0 or token_idx == len(masks) - 1:
                continue

            # upscale masks; cv2.resize takes (width, height)
            token_mask = cv2.resize(token_mask, (cur_w, cur_h), interpolation=cv2.INTER_NEAREST)
            if random_jitter:
                token_mask = np.pad(
                    token_mask,
                    ((jitter_range, jitter_range), (jitter_range, jitter_range)),
                    'constant',
                    constant_values=(0, 0),
                )

            word = decoder(int(tokens[token_idx]))
            word_in_patch = []
            for h_start, h_end, w_start, w_end in views:
                view_mask = token_mask[h_start:h_end, w_start:w_end]
                if (view_mask / 255).sum() / (ori_h * ori_w) >= threthod:
                    word_in_patch.append(word)  # word in patch
                else:
                    word_in_patch.append('')  # word not in patch
            words_in_patch.append(word_in_patch)

        # get prompts for each view; split/join collapses the gaps left by
        # words that are absent from a view
        result = []
        prompts_for_each_views = [' '.join(strings) for strings in zip(*words_in_patch)]
        for prompt in prompts_for_each_views:
            result.append(" ".join(prompt.split()))

        # save prompts and views in each scale
        prompt_dict[scale] = result
        view_dict[scale] = views

    return prompt_dict, view_dict


class ScaledAttnProcessor:
    r"""
    Wrap an attention processor and rescale ``attn.scale`` for the current sequence length.

    The scale is multiplied by ``sqrt(log(sequence_length, train_sequence_length))`` where
    ``train_sequence_length = sequence_length / (test_res ** 2 / train_res ** 2)``.

    Args:
        processor: Wrapped processor, called as
            ``processor(attn, hidden_states, encoder_hidden_states, attention_mask, temb, scale=...)``.
        test_res: Resolution used at inference time.
        train_res: Resolution the model was trained at.
    """

    def __init__(self, processor, test_res, train_res):
        self.processor = processor
        self.test_res = test_res
        self.train_res = train_res

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Call the wrapped processor with a temporarily rescaled ``attn.scale``."""
        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            sequence_length = height * width
        else:
            batch_size, sequence_length, _ = hidden_states.shape

        test_train_ratio = (self.test_res ** 2.0) / (self.train_res ** 2.0)
        train_sequence_length = sequence_length / test_train_ratio
        scale_factor = math.log(sequence_length, train_sequence_length) ** 0.5

        original_scale = attn.scale
        attn.scale = attn.scale * scale_factor
        try:
            hidden_states = self.processor(
                attn, hidden_states, encoder_hidden_states, attention_mask, temb, scale=attn.scale
            )
        finally:
            # restore the module's scale even if the wrapped processor raises
            attn.scale = original_scale
        return hidden_states