Spaces: Running on Zero
import torch
import numpy as np
def process_attn(attention, rng, attn_func):
    """Build a (num_layers, num_heads) heatmap of last-token attention to an instruction span.

    Args:
        attention: sequence of per-layer attention tensors; each is expected to be
            shaped (1, num_heads, seq_len, seq_len) — assumed from the indexing
            pattern `attn_layer[0, :, -1, a:b]`; TODO confirm with caller.
        rng: pair of (start, end) index tuples — rng[0] is the instruction token
            span, rng[1] the data token span.
        attn_func: string selecting the aggregation; must contain "sum" or "max",
            and may additionally contain "normalize" to divide by the total
            attention mass over both spans.

    Returns:
        np.ndarray of shape (num_layers, num_heads); NaNs are replaced with 0.

    Raises:
        NotImplementedError: if attn_func names neither "sum" nor "max".
    """
    heatmap = np.zeros((len(attention), attention[0].shape[1]))
    for i, attn_layer in enumerate(attention):
        # detach/cpu so tensors with grad or on an accelerator convert cleanly
        attn_layer = attn_layer.detach().cpu().to(torch.float32).numpy()
        # Per-head total attention from the last token onto each span.
        inst_sum = np.sum(attn_layer[0, :, -1, rng[0][0]:rng[0][1]], axis=1)
        data_sum = np.sum(attn_layer[0, :, -1, rng[1][0]:rng[1][1]], axis=1)
        if "sum" in attn_func:
            attn = inst_sum  # reuse the already-computed span sum
        elif "max" in attn_func:
            attn = np.max(attn_layer[0, :, -1, rng[0][0]:rng[0][1]], axis=1)
        else:
            raise NotImplementedError(f"Unsupported attn_func: {attn_func!r}")
        if "normalize" in attn_func:
            # Normalize by total mass over both spans; epsilon guards /0.
            epsilon = 1e-8
            heatmap[i, :] = attn / (inst_sum + data_sum + epsilon)
        else:
            heatmap[i, :] = attn
    heatmap = np.nan_to_num(heatmap, nan=0.0)
    return heatmap
def calc_attn_score(heatmap, heads):
    """Average the heatmap entries at the given (layer, head) coordinates.

    Args:
        heatmap: 2-D array indexed as heatmap[layer, head].
        heads: iterable of (layer, head) index pairs to average over.

    Returns:
        The mean of the selected entries (np.mean with axis=0).
    """
    selected = [heatmap[layer_idx, head_idx] for layer_idx, head_idx in heads]
    return np.mean(selected, axis=0)