# bounded-attention / bounded_attention.py
# Uploaded by omer11a ("Upload app", revision de34da3, 22.4 kB).
import nltk
import einops
import torch
import torch.nn.functional as F
import torchvision.utils
from torch_kmeans import KMeans
import os
import injection_utils
import utils
class BoundedAttention(injection_utils.AttentionBase):
    """Attention editor that confines each subject's attention to its region.

    While ``cur_step < max_guidance_iter`` the regions are the user boxes;
    afterwards they are masks refined by clustering the saved self-attention
    maps. While ``update_loss`` is running (``self.optimized`` is True) the
    hook also accumulates foreground/background attention sums which
    ``_compute_loss`` turns into a guidance loss.

    NOTE(review): ``self.cur_step``, ``self.cur_att_layer`` and ``self.model``
    are never assigned in this class; presumably ``injection_utils.AttentionBase``
    or the registration code provides them -- confirm.
    """

    # Added to denominators to avoid division by zero.
    EPSILON = 1e-5
    # POS tags whose tokens are masked out of cross-attention altogether.
    FILTER_TAGS = {
        'CC', 'CD', 'DT', 'EX', 'IN', 'LS', 'MD', 'PDT', 'POS', 'PRP', 'PRP$', 'RP', 'TO', 'UH', 'WDT', 'WP', 'WRB'}
    # Tokens whose POS tag is overridden regardless of what nltk returns.
    TAG_RULES = {'left': 'IN', 'right': 'IN', 'top': 'IN', 'bottom': 'IN'}

    def __init__(
        self,
        boxes,
        prompts,
        subject_token_indices,
        cross_loss_layers,
        self_loss_layers,
        cross_mask_layers=None,
        self_mask_layers=None,
        eos_token_index=None,
        filter_token_indices=None,
        leading_token_indices=None,
        mask_cross_during_guidance=True,
        mask_eos=True,
        cross_loss_coef=1,
        self_loss_coef=1,
        max_guidance_iter=15,
        max_guidance_iter_per_step=5,
        start_step_size=30,
        end_step_size=10,
        loss_stopping_value=0.2,
        cross_mask_threshold=0.2,
        self_mask_threshold=0.2,
        delta_refine_mask_steps=5,
        pca_rank=None,
        num_clusters=None,
        num_clusters_per_box=3,
        map_dir=None,
        debug=False,
        delta_debug_attention_steps=20,
        delta_debug_mask_steps=5,
        debug_layers=None,
        saved_resolution=64,
    ):
        """Store the configuration and initialise the per-run buffers.

        Args:
            boxes: One box per subject, indexed as ``(x0, y0, x1, y1)``; each
                coordinate is multiplied by the map resolution, so values are
                fractions of the image side.
            prompts: Prompts; only ``prompts[0]`` is tokenized here.
            subject_token_indices: For each subject, the token indices that
                describe it.
            cross_loss_layers: Attention-layer indices whose cross-attention
                feeds the loss.
            self_loss_layers: Attention-layer indices whose self-attention
                feeds the loss.
            cross_mask_layers: Layers whose cross-attention maps are saved for
                mask refinement; defaults to ``cross_loss_layers``.
            self_mask_layers: Layers whose self-attention maps are saved for
                mask refinement; defaults to ``self_loss_layers``.
            eos_token_index: Index of the EOS token; inferred from the prompt
                when None, validated against it otherwise.
            filter_token_indices: Token indices hidden from every pixel;
                inferred from POS tags when None.
            leading_token_indices: One representative token per subject;
                inferred when None.
            mask_cross_during_guidance: When False, background pixels may also
                attend to subject tokens during early steps.
            mask_eos: Hide the EOS token from background pixels.
            cross_loss_coef: Weight of the cross-attention loss.
            self_loss_coef: Weight of the self-attention loss.
            max_guidance_iter: Number of denoising steps that get guidance.
            max_guidance_iter_per_step: Maximum gradient updates per step.
            start_step_size: Gradient step size at guidance step 0.
            end_step_size: Step size the linear schedule reaches at
                ``max_guidance_iter``.
            loss_stopping_value: Normalized loss below which the per-step
                optimisation stops early.
            cross_mask_threshold: Centre of the sigmoid applied to cross maps.
            self_mask_threshold: Minimum cluster score to keep a pixel in a
                refined mask.
            delta_refine_mask_steps: Refined masks are rebuilt every this many
                steps.
            pca_rank: If set, self maps are projected to this rank before
                clustering.
            num_clusters: Number of k-means clusters; defaults to
                ``len(boxes) * num_clusters_per_box``.
            num_clusters_per_box: See ``num_clusters``.
            map_dir: Directory for debug images; nothing is saved when None.
            debug: Enables loss printing and attention-map display.
            delta_debug_attention_steps: Attention maps are displayed every
                this many steps.
            delta_debug_mask_steps: Mask images are saved every this many steps.
            debug_layers: Layers to display; defaults to the union of both
                loss-layer sets.
            saved_resolution: Side length debug images are resized to.
        """
        super().__init__()
        self.boxes = boxes
        self.prompts = prompts
        self.subject_token_indices = subject_token_indices
        self.cross_loss_layers = set(cross_loss_layers)
        self.self_loss_layers = set(self_loss_layers)
        self.cross_mask_layers = self.cross_loss_layers if cross_mask_layers is None else set(cross_mask_layers)
        self.self_mask_layers = self.self_loss_layers if self_mask_layers is None else set(self_mask_layers)
        self.eos_token_index = eos_token_index
        self.filter_token_indices = filter_token_indices
        self.leading_token_indices = leading_token_indices
        self.mask_cross_during_guidance = mask_cross_during_guidance
        self.mask_eos = mask_eos
        self.cross_loss_coef = cross_loss_coef
        self.self_loss_coef = self_loss_coef
        self.max_guidance_iter = max_guidance_iter
        self.max_guidance_iter_per_step = max_guidance_iter_per_step
        self.start_step_size = start_step_size
        # With max_guidance_iter == 0 guidance never runs (update_loss returns
        # immediately), so the slope is irrelevant; avoid dividing by zero.
        self.step_size_coef = (
            (end_step_size - start_step_size) / max_guidance_iter if max_guidance_iter else 0
        )
        self.loss_stopping_value = loss_stopping_value
        self.cross_mask_threshold = cross_mask_threshold
        self.self_mask_threshold = self_mask_threshold
        self.delta_refine_mask_steps = delta_refine_mask_steps
        self.pca_rank = pca_rank
        num_clusters = len(boxes) * num_clusters_per_box if num_clusters is None else num_clusters
        self.clustering = KMeans(n_clusters=num_clusters, num_init=100)
        self.centers = None
        self.map_dir = map_dir
        self.debug = debug
        self.delta_debug_attention_steps = delta_debug_attention_steps
        self.delta_debug_mask_steps = delta_debug_mask_steps
        self.debug_layers = self.cross_loss_layers | self.self_loss_layers if debug_layers is None else debug_layers
        self.saved_resolution = saved_resolution
        # True only while update_loss is optimising the latents.
        self.optimized = False
        self.cross_foreground_values = []
        self.self_foreground_values = []
        self.cross_background_values = []
        self.self_background_values = []
        self.cross_maps = []
        self.self_maps = []
        self.self_masks = None

    def clear_values(self, include_maps=False):
        """Empty the accumulated loss values, and optionally the saved maps.

        Args:
            include_maps: Also clear ``cross_maps`` and ``self_maps``.
        """
        lists = (
            self.cross_foreground_values,
            self.self_foreground_values,
            self.cross_background_values,
            self.self_background_values,
        )
        if include_maps:
            # The original unpacked an undefined name here, which made
            # reset() raise NameError.
            lists = (
                *lists,
                self.cross_maps,
                self.self_maps,
            )
        for values in lists:
            values.clear()

    def before_step(self):
        """Clear loss values and, on step 0, resolve the token indices."""
        self.clear_values()
        if self.cur_step == 0:
            self._determine_tokens()

    def reset(self):
        """Clear all accumulated values and maps, then reset the base class."""
        self.clear_values(include_maps=True)
        super().reset()

    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        """Mask the attention logits, recompute attention and return the output.

        Args:
            q, k: Passed through to ``_debug_hook`` only.
            v: Values, multiplied by the recomputed attention.
            sim: Attention logits, reshaped here to ``(-1, num_heads, n, d)``
                and split into two halves (``sim_u`` / ``sim_c``; presumably
                unconditional / conditional -- confirm with the caller).
            attn: Incoming attention; only displayed, then replaced by the
                softmax of the masked logits.
            is_cross: True for cross-attention, False for self-attention.
            place_in_unet: Passed through to ``_debug_hook`` only.
            num_heads: Number of attention heads.

        Returns:
            The attention output rearranged to ``b n (h d)``.
        """
        self._display_attention_maps(attn, is_cross, num_heads)
        _, n, d = sim.shape
        sim_u, sim_c = sim.reshape(-1, num_heads, n, d).chunk(2)  # b h n d
        if is_cross:
            # Only the second half has its token logits masked.
            sim_c = self._hide_other_subjects_from_tokens(sim_c)
        else:
            sim_u = self._hide_other_subjects_from_subjects(sim_u)
            sim_c = self._hide_other_subjects_from_subjects(sim_c)
        sim = torch.cat((sim_u, sim_c)).reshape(-1, n, d)
        attn = sim.softmax(-1)
        self._save(attn, is_cross, num_heads)
        self._display_attention_maps(attn, is_cross, num_heads, prefix='masked')
        self._debug_hook(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
        out = torch.bmm(attn, v)
        out = einops.rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
        return out

    def update_loss(self, forward_pass, latents, i):
        """Run up to ``max_guidance_iter_per_step`` gradient updates on latents.

        Args:
            forward_pass: Callable taking the doubled latents; calling it must
                trigger ``forward`` so that loss values are collected.
            latents: Current latents.
            i: Index of the denoising step; no update is done once
                ``i >= max_guidance_iter``.

        Returns:
            The updated latents (the input unchanged when guidance is over).
        """
        if i >= self.max_guidance_iter:
            return latents
        # Step size changes linearly from start_step_size with slope step_size_coef.
        step_size = self.start_step_size + self.step_size_coef * i
        normalized_loss = torch.tensor(10000)
        self.optimized = True
        try:
            with torch.enable_grad():
                latents = updated_latents = latents.clone().detach().requires_grad_(True)
                for guidance_iter in range(self.max_guidance_iter_per_step):
                    if normalized_loss < self.loss_stopping_value:
                        break
                    latent_model_input = torch.cat([latents] * 2)
                    # Restore cur_step afterwards so the extra forward pass
                    # does not count as a step.
                    cur_step = self.cur_step
                    forward_pass(latent_model_input)
                    self.cur_step = cur_step
                    loss, normalized_loss = self._compute_loss()
                    grad_cond = torch.autograd.grad(loss, [updated_latents])[0]
                    latents = updated_latents = updated_latents - step_size * grad_cond
                    if self.debug:
                        print(f'Loss at step={i}, iter={guidance_iter}: {normalized_loss}')
                        grad_norms = grad_cond.flatten(start_dim=2).norm(dim=1)
                        grad_norms = grad_norms / grad_norms.max(dim=1, keepdim=True)[0]
                        self._save_maps(grad_norms, 'grad_norms')
        finally:
            # Always leave guidance mode, even if the forward pass raised;
            # otherwise later passes would keep collecting loss values and
            # skip saving mask maps.
            self.optimized = False
        return latents

    def _tokenize(self):
        """Return the tokens of ``prompts[0]`` without the ``</w>`` suffix."""
        ids = self.model.tokenizer.encode(self.prompts[0])
        tokens = self.model.tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=True)
        # Strip the suffix only where present: unconditional slicing chopped
        # four real characters off tokens that do not end in '</w>'.
        suffix = '</w>'
        return [token[:-len(suffix)] if token.endswith(suffix) else token for token in tokens]

    def _tag_tokens(self):
        """Return one POS tag per token, with ``TAG_RULES`` overrides applied."""
        tagged_tokens = nltk.pos_tag(self._tokenize())
        return [type(self).TAG_RULES.get(token, tag) for token, tag in tagged_tokens]

    def _determine_eos_token(self):
        """Infer the EOS token index, or validate the one supplied.

        Raises:
            ValueError: If a supplied ``eos_token_index`` differs from the
                inferred one.
        """
        tokens = self._tokenize()
        # Token positions are offset by one relative to ``tokens`` (the same
        # ``i + 1`` / ``i - 1`` offset is used for tags below), so EOS sits at
        # len(tokens) + 1.
        eos_token_index = len(tokens) + 1
        if self.eos_token_index is None:
            self.eos_token_index = eos_token_index
        elif eos_token_index != self.eos_token_index:
            raise ValueError(f'Wrong EOS token index. Tokens are: {tokens}.')

    def _determine_filter_tokens(self):
        """Infer ``filter_token_indices`` from POS tags unless already set."""
        if self.filter_token_indices is not None:
            return
        tags = self._tag_tokens()
        self.filter_token_indices = [i + 1 for i, tag in enumerate(tags) if tag in type(self).FILTER_TAGS]

    def _determine_leading_tokens(self):
        """Pick one leading token per subject unless already set.

        The last noun token of each subject is used, falling back to the
        subject's last token when it has no noun.
        """
        if self.leading_token_indices is not None:
            return
        tags = self._tag_tokens()
        leading_token_indices = []
        for indices in self.subject_token_indices:
            subject_noun_indices = [i for i in indices if tags[i - 1].startswith('NN')]
            leading_token_candidates = subject_noun_indices if len(subject_noun_indices) > 0 else indices
            leading_token_indices.append(leading_token_candidates[-1])
        self.leading_token_indices = leading_token_indices

    def _determine_tokens(self):
        """Resolve the EOS, filter and leading token indices."""
        self._determine_eos_token()
        self._determine_filter_tokens()
        self._determine_leading_tokens()

    def _split_references(self, tensor, num_heads):
        """Split off the last ``len(self.boxes)`` entries of each half.

        Returns:
            ``(batch, references)``, both flattened back over the head axis.
        """
        tensor = tensor.reshape(-1, num_heads, *tensor.shape[1:])
        unconditional, conditional = tensor.chunk(2)
        num_subjects = len(self.boxes)
        batch_unconditional = unconditional[:-num_subjects]
        references_unconditional = unconditional[-num_subjects:]
        batch_conditional = conditional[:-num_subjects]
        references_conditional = conditional[-num_subjects:]
        batch = torch.cat((batch_unconditional, batch_conditional))
        references = torch.cat((references_unconditional, references_conditional))
        batch = batch.reshape(-1, *batch_unconditional.shape[2:])
        references = references.reshape(-1, *references_unconditional.shape[2:])
        return batch, references

    def _hide_other_subjects_from_tokens(self, sim):  # b h i j
        """Mask cross-attention logits so pixels see only their own subject tokens.

        Subject and filter tokens are blocked everywhere, then each subject's
        tokens are re-enabled for the pixels in that subject's mask.
        """
        dtype = sim.dtype
        device = sim.device
        batch_size = sim.size(0)
        resolution = int(sim.size(2) ** 0.5)
        subject_masks, background_masks = self._obtain_masks(resolution, batch_size=batch_size, device=device)  # b s n
        # NOTE(review): cur_step is compared with max_guidance_iter_per_step
        # here; max_guidance_iter may have been intended -- confirm.
        include_background = self.optimized or (not self.mask_cross_during_guidance and self.cur_step < self.max_guidance_iter_per_step)
        subject_masks = torch.logical_or(subject_masks, background_masks.unsqueeze(1)) if include_background else subject_masks
        min_value = torch.finfo(dtype).min
        sim_masks = torch.zeros_like(sim[:, 0, :, :])  # b i j
        for token_indices in (*self.subject_token_indices, self.filter_token_indices):
            sim_masks[:, :, token_indices] = min_value
        for batch_index in range(batch_size):
            for subject_mask, token_indices in zip(subject_masks[batch_index], self.subject_token_indices):
                for token_index in token_indices:
                    sim_masks[batch_index, subject_mask, token_index] = 0
        if self.mask_eos and not include_background:
            for batch_index, background_mask in zip(range(batch_size), background_masks):
                sim_masks[batch_index, background_mask, self.eos_token_index] = min_value
        return sim + sim_masks.unsqueeze(1)

    def _hide_other_subjects_from_subjects(self, sim):  # b h i j
        """Mask self-attention logits between pixels of different subjects.

        Foreground queries are first blocked from every foreground key; each
        subject's queries are then re-allowed to attend to that subject's own
        keys. Keys outside all subjects are never blocked.
        """
        dtype = sim.dtype
        device = sim.device
        batch_size = sim.size(0)
        resolution = int(sim.size(2) ** 0.5)
        subject_masks, background_masks = self._obtain_masks(resolution, batch_size=batch_size, device=device)  # b s n
        min_value = torch.finfo(dtype).min
        sim_masks = torch.zeros_like(sim[:, 0, :, :])  # b i j
        for batch_index, background_mask in zip(range(batch_size), background_masks):
            foreground_mask = ~background_mask
            # Explicit outer product: indexing with the same boolean mask on
            # both axes pairs the indices element-wise and therefore only
            # touches the diagonal, which left the whole mask at zero.
            foreground_pairs = foreground_mask.unsqueeze(1) & foreground_mask.unsqueeze(0)  # i j
            sim_masks[batch_index].masked_fill_(foreground_pairs, min_value)
        for batch_index in range(batch_size):
            for subject_mask in subject_masks[batch_index]:
                subject_sim_mask = sim_masks[batch_index, subject_mask]
                # Keep what is already allowed and allow the subject's own keys.
                condition = torch.logical_or(subject_sim_mask == 0, subject_mask.unsqueeze(0))
                sim_masks[batch_index, subject_mask] = torch.where(condition, 0, min_value).to(dtype=dtype)
        return sim + sim_masks.unsqueeze(1)

    def _save(self, attn, is_cross, num_heads):
        """Record mask maps and loss values from the second half of ``attn``."""
        _, attn = attn.chunk(2)
        attn = attn.reshape(-1, num_heads, *attn.shape[-2:])  # b h n k
        self._save_mask_maps(attn, is_cross)
        self._save_loss_values(attn, is_cross)

    def _save_mask_maps(self, attn, is_cross):
        """Store the head-averaged map of the current layer for mask refinement.

        Nothing is stored during guidance or for layers outside the relevant
        mask-layer set. After step 0 the oldest stored map is dropped for each
        new one, so the list length stays constant.
        """
        if (
            (self.optimized) or
            (is_cross and self.cur_att_layer not in self.cross_mask_layers) or
            ((not is_cross) and (self.cur_att_layer not in self.self_mask_layers))
        ):
            return
        if is_cross:
            attn = attn[..., self.leading_token_indices]
            mask_maps = self.cross_maps
        else:
            mask_maps = self.self_maps
        mask_maps.append(attn.mean(dim=1))  # mean over heads
        if self.cur_step > 0:
            # Drop in place: rebinding the local name to a slice left the
            # stored list untouched, so it grew on every step.
            del mask_maps[0]  # throw away old maps

    def _save_loss_values(self, attn, is_cross):
        """Accumulate per-subject foreground/background attention sums.

        Only active during guidance and for layers in the relevant loss-layer
        set.
        """
        if (
            (not self.optimized) or
            (is_cross and (self.cur_att_layer not in self.cross_loss_layers)) or
            ((not is_cross) and (self.cur_att_layer not in self.self_loss_layers))
        ):
            return
        resolution = int(attn.size(2) ** 0.5)
        boxes = self._convert_boxes_to_masks(resolution, device=attn.device)  # s n
        background_mask = boxes.sum(dim=0) == 0
        if is_cross:
            saved_foreground_values = self.cross_foreground_values
            saved_background_values = self.cross_background_values
            # Unpacking accepts tuples as well as lists of indices.
            contexts = [[*indices, self.eos_token_index] for indices in self.subject_token_indices]  # TODO: fix EOS loss term
        else:
            saved_foreground_values = self.self_foreground_values
            saved_background_values = self.self_background_values
            contexts = boxes
        foreground_values = []
        background_values = []
        for box, context in zip(boxes, contexts):
            context_attn = attn[:, :, :, context]
            # sum over heads, pixels and contexts
            foreground_values.append(context_attn[:, :, box].sum(dim=(1, 2, 3)))
            background_values.append(context_attn[:, :, background_mask].sum(dim=(1, 2, 3)))
        saved_foreground_values.append(torch.stack(foreground_values, dim=1))
        saved_background_values.append(torch.stack(background_values, dim=1))

    def _compute_loss(self):
        """Return ``(loss, normalized_loss)`` from the accumulated values.

        ``normalized_loss`` is the loss divided by the number of samples and
        subjects.
        """
        cross_losses = self._compute_loss_term(self.cross_foreground_values, self.cross_background_values)
        self_losses = self._compute_loss_term(self.self_foreground_values, self.self_background_values)
        b, s = cross_losses.shape
        # sum over samples and subjects
        total_cross_loss = cross_losses.sum()
        total_self_loss = self_losses.sum()
        loss = self.cross_loss_coef * total_cross_loss + self.self_loss_coef * total_self_loss
        normalized_loss = loss / b / s
        return loss, normalized_loss

    def _compute_loss_term(self, foreground_values, background_values):
        """Return ``(1 - iou) ** 2`` per sample and subject."""
        # mean over layers
        mean_foreground = torch.stack(foreground_values).mean(dim=0)
        mean_background = torch.stack(background_values).mean(dim=0)
        iou = mean_foreground / (mean_foreground + len(self.boxes) * mean_background)
        return (1 - iou) ** 2

    def _obtain_masks(self, resolution, return_boxes=False, return_existing=False, batch_size=None, device=None):
        """Return ``(subject_masks, background_mask)`` at ``resolution``.

        Box masks are used when ``return_boxes`` is set, when existing masks
        are requested but none exist yet, or while
        ``cur_step < max_guidance_iter``; refined self masks are used otherwise.

        Returns:
            ``subject_masks`` (b s n) and ``background_mask`` (b n), the latter
            True where no subject mask is set.
        """
        return_boxes = return_boxes or (return_existing and self.self_masks is None)
        if return_boxes or self.cur_step < self.max_guidance_iter:
            masks = self._convert_boxes_to_masks(resolution, device=device).unsqueeze(0)
            if batch_size is not None:
                masks = masks.expand(batch_size, *masks.shape[1:])
        else:
            masks = self._obtain_self_masks(resolution, return_existing=return_existing)
            if device is not None:
                masks = masks.to(device=device)
        background_mask = masks.sum(dim=1) == 0
        return masks, background_mask

    def _convert_boxes_to_masks(self, resolution, device=None):  # s n
        """Rasterise the boxes into flattened boolean masks."""
        boxes = torch.zeros(len(self.boxes), resolution, resolution, dtype=bool, device=device)
        for i, box in enumerate(self.boxes):
            x0, x1 = box[0] * resolution, box[2] * resolution
            y0, y1 = box[1] * resolution, box[3] * resolution
            boxes[i, round(y0) : round(y1), round(x0) : round(x1)] = True
        return boxes.flatten(start_dim=1)

    def _obtain_self_masks(self, resolution, return_existing=False):
        """Return the refined masks resized to ``resolution``.

        The masks are rebuilt when none exist, or on layer 0 of every
        ``delta_refine_mask_steps``-th step unless ``return_existing`` is set.
        """
        if (
            (self.self_masks is None) or
            (
                (self.cur_step % self.delta_refine_mask_steps == 0) and
                (self.cur_att_layer == 0) and
                (not return_existing)
            )
        ):
            self.self_masks = self._fix_zero_masks(self._build_self_masks())
        b, s, n = self.self_masks.shape
        mask_resolution = int(n ** 0.5)
        self_masks = self.self_masks.reshape(b, s, mask_resolution, mask_resolution).float()
        self_masks = F.interpolate(self_masks, resolution, mode='nearest-exact')
        return self_masks.flatten(start_dim=2).bool()

    def _cluster_self_maps(self):  # b s n
        """Cluster the aggregated self-attention maps.

        Returns:
            ``(num_clusters, labels)``.
        """
        self_maps = self._aggregate_maps(self.self_maps)  # b n m
        if self.pca_rank is not None:
            dtype = self_maps.dtype
            _, _, eigen_vectors = torch.pca_lowrank(self_maps.float(), self.pca_rank)
            self_maps = torch.matmul(self_maps, eigen_vectors.to(dtype=dtype))
        clustering_results = self.clustering(self_maps, centers=self.centers)
        self.clustering.num_init = 1  # clustering is deterministic after the first time
        # Reuse the centers as the starting point of the next clustering.
        self.centers = clustering_results.centers
        clusters = clustering_results.labels
        num_clusters = self.clustering.n_clusters
        self._save_maps(clusters / num_clusters, 'clusters')
        return num_clusters, clusters

    def _build_self_masks(self):
        """Assign each cluster to its best-overlapping subject.

        Returns:
            Boolean masks (b s n); pixels whose cluster scores below
            ``self_mask_threshold`` belong to no subject.
        """
        c, clusters = self._cluster_self_maps()  # b n
        cluster_masks = torch.stack([(clusters == cluster_index) for cluster_index in range(c)], dim=2)  # b n c
        cluster_area = cluster_masks.sum(dim=1, keepdim=True)  # b 1 c
        n = clusters.size(1)
        resolution = int(n ** 0.5)
        cross_masks = self._obtain_cross_masks(resolution)  # b s n
        cross_mask_area = cross_masks.sum(dim=2, keepdim=True)  # b s 1
        intersection = torch.bmm(cross_masks.float(), cluster_masks.float())  # b s c
        min_area = torch.minimum(cross_mask_area, cluster_area)  # b s c
        # Clamp so an empty cluster or empty cross mask scores 0 instead of
        # 0/0 = NaN, which torch.max would propagate into the assignment.
        min_area = min_area.clamp(min=type(self).EPSILON)
        score_per_cluster, subject_per_cluster = torch.max(intersection / min_area, dim=1)  # b c
        subjects = torch.gather(subject_per_cluster, 1, clusters)  # b n
        scores = torch.gather(score_per_cluster, 1, clusters)  # b n
        s = cross_masks.size(1)
        self_masks = torch.stack([(subjects == subject_index) for subject_index in range(s)], dim=1)  # b s n
        scores = scores.unsqueeze(1).expand(-1, s, n)  # b s n
        self_masks[scores < self.self_mask_threshold] = False
        self._save_maps(self_masks, 'self_masks')
        return self_masks

    def _obtain_cross_masks(self, resolution, scale=10):
        """Return soft per-subject masks (b k n) from the saved cross maps.

        The maps are sharpened with a sigmoid centred on
        ``cross_mask_threshold``, min-max normalised and restricted to the
        existing subject masks.
        """
        maps = self._aggregate_maps(self.cross_maps, resolution=resolution)  # b n k
        maps = torch.sigmoid(scale * (maps - self.cross_mask_threshold))
        maps = self._normalize_maps(maps, reduce_min=True)
        maps = maps.transpose(1, 2)  # b k n
        existing_masks, _ = self._obtain_masks(
            resolution, return_existing=True, batch_size=maps.size(0), device=maps.device)
        maps = maps * existing_masks.to(dtype=maps.dtype)
        self._save_maps(maps, 'cross_masks')
        return maps

    def _fix_zero_masks(self, masks):
        """Replace every empty subject mask by its box, in place.

        The box region is also removed from the other subjects' masks of the
        same sample.
        """
        b, s, n = masks.shape
        resolution = int(n ** 0.5)
        boxes = self._convert_boxes_to_masks(resolution, device=masks.device)  # s n
        for i in range(b):
            for j in range(s):
                if masks[i, j].sum() == 0:
                    print('******Found a zero mask!******')
                    for k in range(s):
                        masks[i, k] = boxes[j] if (k == j) else masks[i, k].logical_and(~boxes[j])
        return masks

    def _aggregate_maps(self, maps, resolution=None):  # b n k
        """Average the maps over layers, optionally resize, then normalise."""
        maps = torch.stack(maps).mean(0)  # mean over layers
        if resolution is not None:
            b, n, k = maps.shape
            original_resolution = int(n ** 0.5)
            maps = maps.transpose(1, 2).reshape(b, k, original_resolution, original_resolution)
            maps = F.interpolate(maps, resolution, mode='bilinear', antialias=True)
            maps = maps.reshape(b, k, -1).transpose(1, 2)
        maps = self._normalize_maps(maps)
        return maps

    @classmethod
    def _normalize_maps(cls, maps, reduce_min=False):  # b n k
        """Scale the maps along dim 1 by their maximum.

        With ``reduce_min`` the minimum is subtracted first (min-max scaling).
        ``EPSILON`` is added to the denominator to avoid division by zero.
        """
        max_values = maps.max(dim=1, keepdim=True)[0]
        min_values = maps.min(dim=1, keepdim=True)[0] if reduce_min else 0
        numerator = maps - min_values
        denominator = max_values - min_values + cls.EPSILON
        return numerator / denominator

    def _save_maps(self, maps, prefix):
        """Save ``maps`` as an image in ``map_dir``.

        Nothing is saved when ``map_dir`` is None or on steps that are not a
        multiple of ``delta_debug_mask_steps``.
        """
        if self.map_dir is None or self.cur_step % self.delta_debug_mask_steps != 0:
            return
        resolution = int(maps.size(-1) ** 0.5)
        maps = maps.reshape(-1, 1, resolution, resolution).float()
        maps = F.interpolate(maps, self.saved_resolution, mode='bilinear', antialias=True)
        path = os.path.join(self.map_dir, f'map_{prefix}_{self.cur_step}_{self.cur_att_layer}.png')
        torchvision.utils.save_image(maps, path)

    def _display_attention_maps(self, attention_maps, is_cross, num_heads, prefix=None):
        """Hand the attention maps to ``utils.display_attention_maps``.

        Only runs in debug mode, after step 0, on every
        ``delta_debug_attention_steps``-th step and for layers in
        ``debug_layers``. With ``prefix`` the last component of ``map_dir``
        is prefixed with ``<prefix>_``.
        """
        if (not self.debug) or (self.cur_step == 0) or (self.cur_step % self.delta_debug_attention_steps > 0) or (self.cur_att_layer not in self.debug_layers):
            return
        dir_name = self.map_dir
        # Skip the prefixing when no directory is configured:
        # os.path.split(None) would raise TypeError.
        if prefix is not None and dir_name is not None:
            splits = list(os.path.split(dir_name))
            splits[-1] = '_'.join((prefix, splits[-1]))
            dir_name = os.path.join(*splits)
        resolution = int(attention_maps.size(-2) ** 0.5)
        if is_cross:
            attention_maps = einops.rearrange(attention_maps, 'b (r1 r2) k -> b k r1 r2', r1=resolution)
            attention_maps = F.interpolate(attention_maps, self.saved_resolution, mode='bilinear', antialias=True)
            attention_maps = einops.rearrange(attention_maps, 'b k r1 r2 -> b (r1 r2) k')
        utils.display_attention_maps(
            attention_maps,
            is_cross,
            num_heads,
            self.model.tokenizer,
            self.prompts,
            dir_name,
            self.cur_step,
            self.cur_att_layer,
            resolution,
        )

    def _debug_hook(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        """No-op extension point called from ``forward`` after masking."""
        pass