# rich-text-to-image / utils / attention_utils.py
# (page-header residue from the hosting site: uploaded by songweig,
#  commit message "minor", revision ff191d6)
import numpy as np
import os
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torchvision
from utils.richtext_utils import seed_everything
from sklearn.cluster import KMeans, SpectralClustering
# Names of the self-attention (``attn1``) layers that are used.
# Only a subset of the attention blocks is enabled. The blocks that were
# listed but disabled in the original hand-written list are:
#   down_blocks.0.attentions.{0,1}, down_blocks.1.attentions.1,
#   up_blocks.2.attentions.{0,2}, up_blocks.3.attentions.{0,1,2}
SelfAttentionLayers = [
    '{}.transformer_blocks.0.attn1'.format(block_name)
    for block_name in (
        'down_blocks.1.attentions.0',
        'down_blocks.2.attentions.0',
        'down_blocks.2.attentions.1',
        'mid_block.attentions.0',
        'up_blocks.1.attentions.0',
        'up_blocks.1.attentions.1',
        'up_blocks.1.attentions.2',
        'up_blocks.2.attentions.1',
    )
]
# Names of the cross-attention (``attn2``) layers that are used; this list
# is what ``get_average_attention_maps`` filters layers against.
# Disabled blocks (kept here for reference, as in the original list):
#   down_blocks.0.attentions.{0,1}, down_blocks.1.attentions.1,
#   up_blocks.2.attentions.{0,2}, up_blocks.3.attentions.{0,1,2}
CrossAttentionLayers = [
    '{}.transformer_blocks.0.attn2'.format(block_name)
    for block_name in (
        'down_blocks.1.attentions.0',
        'down_blocks.2.attentions.0',
        'down_blocks.2.attentions.1',
        'mid_block.attentions.0',
        'up_blocks.1.attentions.0',
        'up_blocks.1.attentions.1',
        'up_blocks.1.attentions.2',
        'up_blocks.2.attentions.1',
    )
]
# Cross-attention (``attn2``) layer names for the "_XL" configuration
# (presumably the XL model variant -- confirm against the caller).
# Each entry pairs an attention block with the transformer-block indices
# taken from it. An earlier, disabled variant of this list used only
# up_blocks.0.attentions.0 with transformer blocks 1..7.
CrossAttentionLayers_XL = [
    '{}.transformer_blocks.{}.attn2'.format(block_name, block_idx)
    for block_name, block_indices in (
        ('down_blocks.2.attentions.1', (3, 4)),
        ('mid_block.attentions.0', (0, 1, 2, 3)),
        ('up_blocks.0.attentions.0', (1, 2, 3, 4, 5, 6, 7)),
        ('up_blocks.1.attentions.0', (0,)),
    )
    for block_idx in block_indices
]
def split_attention_maps_over_steps(attention_maps):
    r"""Function for regrouping attention maps by step and score type.

    Args:
        attention_maps (dict): Maps each layer name to a sequence holding
            one entry per step. Every entry is sliced along its first axis.

    Returns:
        tuple: ``(cond, uncond)``. Both are dicts keyed by step index whose
        values are dicts keyed by layer name. ``uncond`` stores the ``[:1]``
        slice of every entry (unconditional score) and ``cond`` stores the
        ``[1:2]`` slice (conditional score).
    """
    cond_by_step = {}
    uncond_by_step = {}
    for layer_name, maps_per_step in attention_maps.items():
        for step_idx, step_map in enumerate(maps_per_step):
            # Both outputs always share the same set of step keys.
            uncond_by_step.setdefault(step_idx, {})[layer_name] = step_map[:1]
            cond_by_step.setdefault(step_idx, {})[layer_name] = step_map[1:2]
    return cond_by_step, uncond_by_step
def save_attention_heatmaps(attention_maps, tokens_vis, save_dir, prefix):
    r"""Function to plot heatmaps for attention maps.

    For every sample an html page ``<save_dir>/sample_<i>/<prefix>.html`` is
    written; it embeds one jpg per layer, stored at
    ``<save_dir>/sample_<i>/<prefix>/layer_<layer>.jpg``. Each jpg is a
    2 x 7 grid of heatmaps for token indices 0..13 of the last axis.

    Args:
        attention_maps (dict): Dictionary of attention maps per layer. Each
            value is indexed as ``[sample, :, :, token]`` and converted with
            ``.numpy()``, so it must live on the CPU and provide at least
            14 entries on its last axis. The batch size is read from the
            first layer and reused for all layers.
        tokens_vis (list): Axis labels; entries 0..13 are used.
        save_dir (str): Directory to save attention maps
        prefix (str): Filename prefix for html files (may contain ``/``)

    Returns:
        list: Html paths relative to ``save_dir``, one per sample. Empty if
        ``attention_maps`` is empty.
    """
    layers = list(attention_maps.keys())
    if not layers:
        return []

    nrows, ncols = 2, 7
    batch_size = attention_maps[layers[0]].shape[0]
    # Images are referenced relative to the html file, which sits next to
    # the directory named after the last component of ``prefix``.
    prefix_stem = prefix.split('/')[-1]
    cmap = plt.get_cmap('YlOrRd')

    html_names = []
    html_files = []
    try:
        # One html page per sample, opened up front and filled layer by layer.
        for sample_num in range(batch_size):
            html_rel_path = os.path.join('sample_{}'.format(
                sample_num), '{}.html'.format(prefix))
            html_names.append(html_rel_path)
            html_path = os.path.join(save_dir, html_rel_path)
            os.makedirs(os.path.dirname(html_path), exist_ok=True)
            html_file = open(html_path, 'wt')
            html_files.append(html_file)
            html_file.write('<html><head></head><body><table>\n')

        for layer in layers:
            layer_name = 'layer_{}'.format(layer)
            relative_image_path = os.path.join(
                prefix_stem, layer_name) + '.jpg'
            for sample_num in range(batch_size):
                save_path = os.path.join(save_dir, 'sample_{}'.format(sample_num),
                                         prefix, layer_name) + '.jpg'
                # ``os.makedirs`` replaces the former ``Path(...).mkdir``
                # call, which relied on a name that was never imported.
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                html_files[sample_num].write(
                    f'<tr><td><h1>{layer_name}</h1></td></tr>\n')
                html_files[sample_num].write(
                    f'<tr><td><img src=\"{relative_image_path}\"></td></tr>\n')

                fig, axs = plt.subplots(nrows=nrows, ncols=ncols)
                try:
                    fig.set_figheight(8)
                    fig.set_figwidth(28.5)
                    for tid in range(nrows * ncols):
                        rid, cid = divmod(tid, ncols)
                        attention_map_cur = attention_maps[layer][
                            sample_num, :, :, tid].numpy()
                        # Each panel is scaled to its own value range.
                        sns.heatmap(
                            attention_map_cur, annot=False, cbar=False,
                            ax=axs[rid, cid], cmap=cmap,
                            vmin=float(attention_map_cur.min()),
                            vmax=float(attention_map_cur.max())
                        )
                        axs[rid, cid].set_xlabel(tokens_vis[tid])
                    fig.tight_layout()
                    fig.savefig(save_path, dpi=64)
                finally:
                    plt.close(fig)

        for html_file in html_files:
            html_file.write('</table></body></html>')
    finally:
        # Close the pages even if plotting failed part-way through.
        for html_file in html_files:
            html_file.close()
    return html_names
def create_recursive_html_link(html_path, save_dir):
    r"""Function for creating chained html index pages.

    For a path ``dir1/dir2/dir3/name.html`` the following pages are written
    (their directories must already exist):
        ``<save_dir>/dir1.html``            links to ``dir1/dir2.html``
        ``<save_dir>/dir1/dir2.html``       links to ``dir2/dir3.html``
        ``<save_dir>/dir1/dir2/dir3.html``  links to ``dir3/name.html``
    A path without any ``/`` produces nothing.

    If an index page already exists, a new ``<html>...</html>`` section with
    the link is appended unless the link line is already present.

    Args:
        html_path (str): Path to html file, using ``/`` as separator.
        save_dir (str): Save directory.
    """
    page_open = '<html><head></head><body><table>\n'
    page_close = '</table></body></html>'

    components = os.path.splitext(html_path)[0].split('/')
    # Walk over consecutive (parent, child) pairs, descending one level
    # into ``save_dir`` after each page.
    for parent, child in zip(components, components[1:]):
        index_path = os.path.join(save_dir, '{}.html'.format(parent))
        href = os.path.join(parent, f'{child}.html')
        link_row = f'<tr><td><a href=\"{href}\">{child}</a></td></tr>\n'

        if os.path.exists(index_path):
            with open(index_path, 'r') as index_file:
                existing_rows = index_file.readlines()
            with open(index_path, 'a+') as index_file:
                if link_row not in existing_rows:
                    index_file.write(page_open)
                    index_file.write(link_row)
                    index_file.write(page_close)
        else:
            with open(index_path, 'w') as index_file:
                index_file.write(page_open)
                index_file.write(link_row)
                index_file.write(page_close)

        save_dir = os.path.join(save_dir, parent)
def visualize_attention_maps(attention_maps_all, save_dir, width, height, tokens_vis):
    r"""Function to visualize attention maps for every fifth step.

    For steps 0, 5, 10, ... the maps of all layers are reshaped to a 2-D
    layout, written as heatmap pages by ``save_attention_heatmaps`` and then
    linked together with ``create_recursive_html_link``.

    Args:
        attention_maps_all (dict): Maps layer name to a per-step sequence of
            tensors that unpack as ``(batch, positions, tokens)``.
        save_dir (str): Path to save attention maps (``/``-separated). Its
            last component is used as the top-level name of the html links,
            its parent receives the index pages.
        width (int): Reference width; ``width * height`` is compared with
            the number of positions to derive each layer's down-scaling.
        height (int): Reference height.
        tokens_vis (list): Token labels forwarded to the heatmap plots.
    """
    first_layer = list(attention_maps_all.keys())[0]
    total_steps = len(attention_maps_all[first_layer])
    full_area = width * height

    dir_parts = save_dir.split('/')
    run_name = dir_parts[-1]
    parent_dir = '/'.join(dir_parts[0:-1])

    collected_pages = []
    for step_num in range(0, total_steps, 5):
        maps_2d = {}
        for layer_name, per_step in attention_maps_all.items():
            flat_map = per_step[step_num].cpu()
            # The flattened position axis is folded back into
            # (height, width) using the ratio to the reference area.
            batch, n_positions, n_tokens = flat_map.shape
            scale = np.sqrt(full_area // n_positions)
            layer_width = int(width // scale)
            layer_height = int(height // scale)
            maps_2d[layer_name] = flat_map.reshape(
                batch, layer_height, layer_width, n_tokens)

        step_prefix = 'step_{}/attention_maps_cond'.format(step_num)
        page_names = save_attention_heatmaps(
            maps_2d, tokens_vis, save_dir=save_dir, prefix=step_prefix)
        collected_pages.extend(
            os.path.join(run_name, page_name) for page_name in page_names)

    # Build the chain of index pages that leads to every heatmap page.
    for page_path in collected_pages:
        create_recursive_html_link(page_path, parent_dir)
def plot_attention_maps(atten_map_list, obj_tokens, save_dir, seed, tokens_vis=None):
    r"""Function to plot groups of attention maps side by side.

    One figure is produced per entry of ``atten_map_list`` and saved as
    ``<save_dir>/average_seed<seed>_attn<i>.jpg``. All maps of a figure share
    one color scale, shown in a narrow colorbar column on the right.

    Args:
        atten_map_list (list): List of map groups. Each group is a sequence
            of maps; ``map[0]`` of every entry is drawn as a 2-D heatmap.
        obj_tokens (list): Token ids per map, only read when ``tokens_vis``
            is given. Ids are looked up as ``tokens_vis[token_id - 1]``.
        save_dir (str): Directory to save the figures
        seed (int): Seed, used in the file name only.
        tokens_vis (list, optional): Token strings for the titles. The last
            map of a group is always titled ``other tokens``.

    Returns:
        numpy.ndarray or None: RGB image (height, width, 3) of the last
        figure, or ``None`` when ``atten_map_list`` is empty.
    """
    if not atten_map_list:
        # Nothing to draw; previously this fell through to an unbound name.
        return None

    word_end = '</w>'
    cmap = plt.get_cmap('YlOrRd')
    img = None
    for i, attn_map in enumerate(atten_map_list):
        n_obj = len(attn_map)
        # squeeze=False keeps ``axs`` indexable even for a single column.
        fig, axs = plt.subplots(
            ncols=n_obj + 1, squeeze=False,
            gridspec_kw=dict(width_ratios=[1] * n_obj + [0.1]))
        axs = axs[0]
        try:
            fig.set_figheight(3)
            fig.set_figwidth(3 * n_obj + 0.1)

            # Shared color range over the whole group; the range always
            # contains 0 as upper-bound floor and 1 as lower-bound ceiling.
            vmax = max([0] + [float(cur_map.max()) for cur_map in attn_map])
            vmin = min([1] + [float(cur_map.min()) for cur_map in attn_map])

            for tid in range(n_obj):
                sns.heatmap(
                    attn_map[tid][0], annot=False, cbar=False, ax=axs[tid],
                    cmap=cmap, vmin=vmin, vmax=vmax
                )
                axs[tid].set_axis_off()
                if tokens_vis is not None:
                    if tid == n_obj - 1:
                        axs_xlabel = 'other tokens'
                    else:
                        axs_xlabel = ''
                        for token_id in obj_tokens[tid]:
                            token_str = tokens_vis[int(token_id) - 1]
                            # Strip the end-of-word marker only when present.
                            if token_str.endswith(word_end):
                                token_str = token_str[:-len(word_end)]
                            axs_xlabel += ' ' + token_str
                    axs[tid].set_title(axs_xlabel)

            norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
            sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
            fig.colorbar(sm, cax=axs[-1])
            fig.tight_layout()

            canvas = fig.canvas
            canvas.draw()
            # ``tostring_rgb`` was removed in Matplotlib 3.10; read the RGBA
            # buffer instead and drop the alpha channel.
            img = np.asarray(canvas.buffer_rgba())[..., :3].copy()
            fig.savefig(os.path.join(
                save_dir, 'average_seed%d_attn%d.jpg' % (seed, i)), dpi=100)
        finally:
            plt.close(fig)
    return img
def get_average_attention_maps(attention_maps, save_dir, width, height, obj_tokens, seed=0, tokens_vis=None,
                               preprocess=False):
    r"""Function to average cross-attention maps into one mask per token group.

    Only the conditional maps of steps >= 10 and of layers listed in
    ``CrossAttentionLayers`` are used. For every token group the maximum over
    its tokens is taken, resized to ``(height, width)`` and averaged over all
    selected steps and layers. The averaged maps are plotted and returned as
    masks repeated 4 times along a new axis 1.

    Args:
        attention_maps (dict): Maps layer name to a per-step sequence of
            tensors; after the cond/uncond split each tensor must unpack as
            ``(batch, positions, tokens)``.
        save_dir (str): Path to save attention maps
        width (int): Target width of the masks.
        height (int): Target height of the masks.
        obj_tokens (list): Token-id tensors, one per group. A group whose
            first id is ``-1`` receives the complement of the groups before
            it (their sum subtracted from its maximum).
        seed (int): Seed, used for the plot file names only.
        tokens_vis (list, optional): Token strings for the plot titles.
        preprocess (bool): If set, the first two normalized maps are
            binarized at 0.8 and a third mask ``1 - sum`` is appended.

    Returns:
        list: One tensor per mask, moved to CUDA when available (CPU
        otherwise).
    """
    # Split attention maps over steps
    attention_maps_cond, _ = split_attention_maps_over_steps(
        attention_maps
    )
    nsteps = len(attention_maps_cond)
    hw_ori = width * height
    resize = torchvision.transforms.functional.resize
    bicubic = torchvision.transforms.InterpolationMode.BICUBIC

    maps_per_obj = [[] for _ in obj_tokens]
    # The first 10 steps are skipped.
    for step_num in range(10, nsteps):
        for layer, layer_map in attention_maps_cond[step_num].items():
            if layer not in CrossAttentionLayers:
                continue
            attention_ind = layer_map.cpu()
            # Fold the flattened position axis back into (height, width),
            # using the ratio to the reference area as down-scaling factor.
            bs, hw, nclip = attention_ind.shape
            down_ratio = np.sqrt(hw_ori // hw)
            width_cur = int(width // down_ratio)
            height_cur = int(height // down_ratio)
            attention_ind = attention_ind.reshape(
                bs, height_cur, width_cur, nclip)
            for obj_id, obj_token in enumerate(obj_tokens):
                if obj_token[0] == -1:
                    # Complement of everything collected for earlier groups
                    # at this step/layer.
                    attention_map_prev = torch.stack(
                        [maps_per_obj[i][-1] for i in range(obj_id)]).sum(0)
                    maps_per_obj[obj_id].append(
                        attention_map_prev.max() - attention_map_prev)
                else:
                    obj_attention_map = attention_ind[:, :, :, obj_token].max(
                        -1, True)[0].permute([3, 0, 1, 2])
                    obj_attention_map = resize(
                        obj_attention_map, (height, width),
                        interpolation=bicubic, antialias=True)
                    maps_per_obj[obj_id].append(obj_attention_map)

    attention_maps_averaged = [
        torch.cat(obj_maps).mean(0) for obj_maps in maps_per_obj]

    attention_maps_averaged_sum = torch.cat(attention_maps_averaged).sum(0)
    attention_maps_averaged_normalized = [
        averaged / attention_maps_averaged_sum
        for averaged in attention_maps_averaged]
    if obj_tokens[-1][0] != -1:
        # Without a complement group the maps compete through a sharp
        # softmax (temperature 0.001) instead of the plain normalization.
        soft_maps = (torch.cat(attention_maps_averaged) / 0.001).softmax(0)
        attention_maps_averaged_normalized = [
            soft_maps[i:i + 1] for i in range(soft_maps.shape[0])]

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if preprocess:
        # The original code eroded the maps with a 1x1 structuring element
        # through scikit-image helpers that were never imported. A 1x1
        # erosion is the identity, so the maps are thresholded directly.
        attention_maps_averaged_eroded = [
            (norm_map[0].unsqueeze(0) > 0.8).float()
            for norm_map in attention_maps_averaged_normalized[:2]]
        attention_maps_averaged_eroded.append(
            1 - torch.cat(attention_maps_averaged_eroded).sum(0, True))
        plot_attention_maps([attention_maps_averaged, attention_maps_averaged_normalized,
                             attention_maps_averaged_eroded], obj_tokens, save_dir, seed, tokens_vis)
        return [attn_mask.unsqueeze(1).repeat([1, 4, 1, 1]).to(device)
                for attn_mask in attention_maps_averaged_eroded]

    plot_attention_maps([attention_maps_averaged, attention_maps_averaged_normalized],
                        obj_tokens, save_dir, seed, tokens_vis)
    return [attn_mask.unsqueeze(1).repeat([1, 4, 1, 1]).to(device)
            for attn_mask in attention_maps_averaged_normalized]
def get_average_attention_maps_threshold(attention_maps, save_dir, width, height, obj_tokens, seed=0, threshold=0.02):
    r"""Function to build thresholded masks from averaged attention maps.

    The conditional maps of all steps and layers are averaged per token
    group (mean over the group's tokens), binarized with ``threshold`` and
    normalized so that overlapping masks share their pixels. A group whose
    first token id is ``-1`` receives the remaining region
    (``1 - sum of the masks before it``).

    Args:
        attention_maps (dict): Maps layer name to a per-step sequence of
            tensors; after the cond/uncond split each tensor must unpack as
            ``(batch, positions, tokens)``.
        save_dir (str): Path to save attention maps
        width (int): Target width of the masks.
        height (int): Target height of the masks.
        obj_tokens (list): Token-id tensors, one per group.
        seed (int): Seed, used for the plot file names only.
        threshold (float): Averaged attention above this value is foreground.

    Returns:
        list: One mask tensor per group, repeated 4 times along a new axis 1
        and moved to CUDA when available (CPU otherwise).
    """
    _EPS = 1e-8
    # Split attention maps over steps
    attention_maps_cond, _ = split_attention_maps_over_steps(
        attention_maps
    )
    nsteps = len(attention_maps_cond)
    hw_ori = width * height
    resize = torchvision.transforms.functional.resize
    bicubic = torchvision.transforms.InterpolationMode.BICUBIC

    maps_per_obj = [[] for _ in obj_tokens]
    # for each side prompt, get attention maps for all steps and all layers
    for step_num in range(nsteps):
        for layer_map in attention_maps_cond[step_num].values():
            attention_ind = layer_map.cpu()
            bs, hw, nclip = attention_ind.shape
            down_ratio = np.sqrt(hw_ori // hw)
            width_cur = int(width // down_ratio)
            height_cur = int(height // down_ratio)
            # Maps whose height exceeds half the target width are skipped.
            if height_cur > width // 2:
                continue
            attention_ind = attention_ind.reshape(
                bs, height_cur, width_cur, nclip)
            for obj_id, obj_token in enumerate(obj_tokens):
                if obj_token[0] == -1:
                    continue
                obj_attention_map = attention_ind[:, :, :, obj_token].mean(
                    -1, True).permute([3, 0, 1, 2])
                obj_attention_map = resize(
                    obj_attention_map, (height, width),
                    interpolation=bicubic, antialias=True)
                maps_per_obj[obj_id].append(obj_attention_map)

    # average of all steps and layers, thresholding
    # (keyed by group index so a ``-1`` group in any position stays aligned)
    thresholded = {}
    attention_maps_averaged = []
    for obj_id, obj_token in enumerate(obj_tokens):
        if obj_token[0] != -1:
            average_map = torch.cat(maps_per_obj[obj_id]).mean(0)
            attention_maps_averaged.append(average_map)
            thresholded[obj_id] = (average_map > threshold).float()

    # get the remaining region except for the original prompt
    attention_maps_averaged_normalized = []
    thresholded_sum = torch.cat(list(thresholded.values())).sum(0) + _EPS
    for obj_id, obj_token in enumerate(obj_tokens):
        if obj_token[0] != -1:
            attention_maps_averaged_normalized.append(
                thresholded[obj_id] / thresholded_sum)
        else:
            attention_map_prev = torch.stack(
                attention_maps_averaged_normalized).sum(0)
            attention_maps_averaged_normalized.append(1. - attention_map_prev)

    # ``obj_tokens`` is a required positional argument of the plot helper;
    # it was missing from this call before.
    plot_attention_maps(
        [attention_maps_averaged, attention_maps_averaged_normalized],
        obj_tokens, save_dir, seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return [attn_mask.unsqueeze(1).repeat([1, 4, 1, 1]).to(device)
            for attn_mask in attention_maps_averaged_normalized]
def get_token_maps(selfattn_maps, crossattn_maps, n_maps, save_dir, width, height, obj_tokens, kmeans_seed=0, tokens_vis=None,
                   preprocess=False, segment_threshold=0.3, num_segments=5, return_vis=False, save_attn=False):
    r"""Function to turn attention maps into one spatial mask per token group.

    Steps:
        1. Self-attention maps at 32 x 32 resolution are averaged and used
           as a precomputed affinity for spectral clustering into
           ``num_segments`` segments.
        2. Cross-attention maps are resized to 32 x 32 and averaged; each
           token channel is normalized per map.
        3. A segment is assigned to a token group when the mean normalized
           attention of any of the group's tokens inside the segment
           exceeds ``segment_threshold``; unassigned segments form the
           background mask, which is appended last.
        4. Masks are resized to ``(height, width)``, normalized to sum to
           one per pixel, plotted and returned.

    Args:
        selfattn_maps (dict): Self-attention tensors per layer; the size of
            axis 1 determines the resolution.
        crossattn_maps (dict): Cross-attention tensors per layer.
        n_maps: Unused.
        save_dir (str): Path to save attention maps
        width (int): Target width of the masks.
        height (int): Target height of the masks.
        obj_tokens (list): Token ids per group (tensors or sequences of
            ints). Ids >= 77 are dropped (presumably the text context
            length -- confirm against the caller).
        kmeans_seed (int): Seed passed to ``seed_everything`` and used in
            file names.
        tokens_vis (list, optional): Token strings for the plot titles.
        preprocess (bool): If set, foreground masks are eroded with a 5 x 5
            square footprint and renormalized.
        segment_threshold (float): Score above which a segment is assigned.
        num_segments (int): Number of clusters.
        return_vis (bool): Also return the visualizations.
        save_attn (bool): Save the averaged self-attention matrix to
            ``results/maps/selfattn_maps.pth``.

    Returns:
        list, or tuple when ``return_vis`` is set: The masks (float32,
        repeated 4 times along a new axis 1, on CUDA when available),
        optionally followed by the segmentation image and the image
        returned by ``plot_attention_maps``.
    """
    resolution = 32
    max_token_id = 77

    # --- 1. segment the image via self-attention ------------------------
    selfattn_at_res = []
    for attn_map in selfattn_maps.values():
        resolution_map = int(np.sqrt(attn_map.shape[1]))
        if resolution_map != resolution:
            continue
        attn_map = attn_map.reshape(
            1, resolution_map, resolution_map, resolution_map**2).permute([3, 0, 1, 2]).float()
        attn_map = torch.nn.functional.interpolate(attn_map, (resolution, resolution),
                                                   mode='bicubic', antialias=True)
        selfattn_at_res.append(attn_map.permute([1, 2, 3, 0]).reshape(
            1, resolution**2, resolution_map**2))
    attn_maps_1024 = torch.cat(selfattn_at_res).mean(0).cpu().numpy()
    if save_attn:
        print('saving self-attention maps...', attn_maps_1024.shape)
        os.makedirs('results/maps', exist_ok=True)
        torch.save(torch.from_numpy(attn_maps_1024),
                   'results/maps/selfattn_maps.pth')
    seed_everything(kmeans_seed)
    sc = SpectralClustering(num_segments, affinity='precomputed', n_init=100,
                            assign_labels='kmeans')
    clusters = sc.fit_predict(attn_maps_1024)
    clusters = clusters.reshape(resolution, resolution)

    fig = plt.figure()
    try:
        plt.imshow(clusters)
        plt.axis('off')
        plt.savefig(os.path.join(save_dir, 'segmentation_k%d_seed%d.jpg' % (num_segments, kmeans_seed)),
                    bbox_inches='tight', pad_inches=0)
        if return_vis:
            canvas = fig.canvas
            canvas.draw()
            # ``tostring_rgb`` was removed in Matplotlib 3.10; read the RGBA
            # buffer instead and drop the alpha channel.
            segments_vis = np.asarray(canvas.buffer_rgba())[..., :3].copy()
    finally:
        plt.close(fig)

    # --- 2. label the segmentation mask using cross-attention maps ------
    cross_attn_maps_1024 = []
    for attn_map in crossattn_maps.values():
        resolution_map = int(np.sqrt(attn_map.shape[1]))
        attn_map = attn_map.reshape(
            1, resolution_map, resolution_map, -1).permute([0, 3, 1, 2]).float()
        attn_map = torch.nn.functional.interpolate(attn_map, (resolution, resolution),
                                                   mode='bicubic', antialias=True)
        cross_attn_maps_1024.append(attn_map.permute([0, 2, 3, 1]))
    cross_attn_maps_1024 = torch.cat(
        cross_attn_maps_1024).mean(0).cpu().numpy()

    normalized_span_maps = []
    for token_ids in obj_tokens:
        # Filter with a tensor mask; the former list comprehension produced
        # a list, on which ``.numpy()`` does not exist.
        token_ids = torch.as_tensor(token_ids).reshape(-1).cpu()
        token_index = token_ids[token_ids < max_token_id].numpy()
        span_token_maps = cross_attn_maps_1024[:, :, token_index]
        normalized_span_map = np.zeros_like(span_token_maps)
        for i in range(span_token_maps.shape[-1]):
            curr_noun_map = span_token_maps[:, :, i]
            # NOTE(review): the shift uses abs(min), which equals min only
            # for non-negative maps -- kept as in the original.
            normalized_span_map[:, :, i] = (
                curr_noun_map - np.abs(curr_noun_map.min())) / (curr_noun_map.max() - curr_noun_map.min())
        normalized_span_maps.append(normalized_span_map)

    # --- 3. assign segments to token groups -----------------------------
    map_shape = [clusters.shape[0], clusters.shape[1]]
    foreground_token_maps = [np.zeros(map_shape).squeeze()
                             for _ in normalized_span_maps]
    background_map = np.zeros(map_shape).squeeze()
    for c in range(num_segments):
        cluster_mask = (clusters == c).astype(np.float64)
        cluster_size = cluster_mask.sum()
        if cluster_size == 0:
            # Empty segment: nothing to assign.
            continue
        is_foreground = False
        for normalized_span_map, foreground_nouns_map in zip(normalized_span_maps, foreground_token_maps):
            # Iterate over the channels that survived the id filter rather
            # than over the unfiltered token count.
            scores = [(cluster_mask * normalized_span_map[:, :, i]).sum() / cluster_size
                      for i in range(normalized_span_map.shape[-1])]
            if scores and max(scores) > segment_threshold:
                foreground_nouns_map += cluster_mask
                is_foreground = True
        if not is_foreground:
            background_map += cluster_mask
    foreground_token_maps.append(background_map)

    # --- 4. resize the token maps and visualization ---------------------
    resized_token_maps = torch.cat([torch.nn.functional.interpolate(torch.from_numpy(token_map).unsqueeze(0).unsqueeze(
        0), (height, width), mode='bicubic', antialias=True)[0] for token_map in foreground_token_maps]).clamp(0, 1)
    resized_token_maps = resized_token_maps / \
        (resized_token_maps.sum(0, True) + 1e-8)
    resized_token_maps = [token_map.unsqueeze(0)
                          for token_map in resized_token_maps]
    foreground_token_maps = [token_map[None, :, :]
                             for token_map in foreground_token_maps]
    if preprocess:
        # The scikit-image helpers used here before were never imported.
        # Grey erosion with a flat 5 x 5 footprint is the same operation.
        from scipy import ndimage
        footprint = np.ones((5, 5), dtype=bool)
        eroded_token_maps = torch.stack([
            torch.from_numpy(ndimage.grey_erosion(
                token_map[0].numpy() * 255, footprint=footprint)) / 255.
            for token_map in resized_token_maps[:-1]]).clamp(0, 1)
        eroded_background_maps = (
            1 - eroded_token_maps.sum(0, True)).clamp(0, 1)
        eroded_token_maps = torch.cat(
            [eroded_token_maps, eroded_background_maps])
        eroded_token_maps = eroded_token_maps / \
            (eroded_token_maps.sum(0, True) + 1e-8)
        resized_token_maps = [token_map.unsqueeze(0)
                              for token_map in eroded_token_maps]

    token_maps_vis = plot_attention_maps([foreground_token_maps, resized_token_maps], obj_tokens,
                                         save_dir, kmeans_seed, tokens_vis)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # The original cast to the dtype of the last processed cross-attention
    # map, which is always float32 at this point.
    resized_token_maps = [token_map.unsqueeze(1).repeat(
        [1, 4, 1, 1]).float().to(device) for token_map in resized_token_maps]
    if return_vis:
        return resized_token_maps, segments_vis, token_maps_vis
    return resized_token_maps