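# RA-BART / attention_viz.py
# Author: MrVicente
# Last commit: c56dde4 - "fixed gradio plot issue"
# (Copied from the Hugging Face file viewer: raw | history | blame |
#  contribute | delete; scan result "No virus"; file size 9.27 kB)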
#############################
# Imports
#############################
# Python modules
# Remote modules
import matplotlib.pyplot as plt
import numpy as np
import torch
# Local modules
#############################
# Constants
#############################
class AttentionVisualizer:
    """Matplotlib helpers for drawing attention scores as heatmaps and line plots."""

    def __init__(self, device):
        # Stored for callers; not used by the plotting methods below.
        self.device = device

    # ------------------------------------------------------------------ #
    # Private helpers (pure numpy / Python, no plotting)
    # ------------------------------------------------------------------ #
    @staticmethod
    def _normalize_rows(attn):
        """Return a float copy of *attn* whose last-axis rows each sum to 1.

        Rows that sum to zero are left as zeros instead of becoming NaN
        (matplotlib rejects NaN as an ``alpha`` value).
        """
        attn = np.array(attn, dtype=float)
        row_sums = attn.sum(axis=-1, keepdims=True)
        return np.divide(attn, row_sums, out=np.zeros_like(attn),
                         where=row_sums != 0)

    @staticmethod
    def _prepare_attention(attn, hide_sep):
        """Return *attn* as a 2-D float array usable as line alphas.

        When ``hide_sep`` is true, the first and last columns are zeroed and
        the rows are renormalised, as the original plotting code did. The
        result is always clipped to [0, 1] because matplotlib requires
        ``alpha`` to lie in that range.
        """
        attn = np.array(attn, dtype=float)
        if hide_sep:
            attn[:, 0] = 0
            attn[:, -1] = 0
            attn = AttentionVisualizer._normalize_rows(attn)
        return np.clip(attn, 0.0, 1.0)

    @staticmethod
    def _vertical_offset(examples, idx, word_height):
        """Vertical shift that centres the first example against the second.

        Only the first example is shifted, and only when a second example
        exists to compare against; otherwise the offset is 0.
        """
        if idx != 0 or len(examples) < 2:
            return 0
        return (len(examples[0]["words"]) -
                len(examples[1]["words"])) * word_height / 2

    @staticmethod
    def _relation_mask(rel, n_words):
        """Return an ``n_words`` x ``n_words`` nested list: True where ``rel[i, j] > 0``."""
        return [[float(rel[i, j]) > 0 for j in range(n_words)]
                for i in range(n_words)]

    @staticmethod
    def _relation_color(related, i, j):
        """Line colour for edge i->j: 'r' if related and i < j, 'g' if related and i >= j, else 'k'."""
        if not related[i][j]:
            return "k"
        return "r" if i < j else "g"

    # ------------------------------------------------------------------ #
    # Heatmaps
    # ------------------------------------------------------------------ #
    def visualize_token2token_scores(self, all_tokens,
                                     scores_mat,
                                     useful_indeces,
                                     x_label_name='Head',
                                     apply_normalization=True):
        """Show one token-by-token heatmap per entry of ``scores_mat`` in a 4x4 grid.

        Args:
            all_tokens: token strings; filtered by ``useful_indeces`` for the
                tick labels.
            scores_mat: iterable of 2-D score matrices, one subplot each
                (at most 16 fit the fixed 4x4 grid).
            useful_indeces: indices of the rows/columns (tokens) to keep.
            x_label_name: prefix of each subplot's x-axis label.
            apply_normalization: if true, each score is replaced by the norm
                over a length-1 axis, i.e. its absolute value.
        """
        fig = plt.figure(figsize=(20, 20))
        all_tokens = np.array(all_tokens)[useful_indeces]
        fontdict = {'fontsize': 10}
        for idx, scores in enumerate(scores_mat):
            if apply_normalization:
                # as_tensor accepts numpy arrays (sharing memory, like
                # from_numpy) as well as tensors.
                scores = torch.as_tensor(scores)
                shape = scores.shape
                scores = scores.reshape((shape[0], shape[1], 1))
                scores = torch.linalg.norm(scores, dim=2)
            scores_np = np.array(scores)
            scores_np = scores_np[useful_indeces, :]
            scores_np = scores_np[:, useful_indeces]
            ax = fig.add_subplot(4, 4, idx + 1)
            im = ax.imshow(scores_np, cmap='viridis')
            ax.set_xticks(range(len(all_tokens)))
            ax.set_yticks(range(len(all_tokens)))
            ax.set_xticklabels(all_tokens, fontdict=fontdict, rotation=90)
            ax.set_yticklabels(all_tokens, fontdict=fontdict)
            ax.set_xlabel('{} {}'.format(x_label_name, idx + 1))
            fig.colorbar(im, fraction=0.046, pad=0.04)
        plt.tight_layout()
        plt.show()

    def visualize_matrix(self,
                         scores_mat,
                         label_name='heads_layers'):
        """Save a layer-by-head heatmap of ``scores_mat`` to ``figs/<label_name>.png``.

        Rows are labelled ``layer-1..L`` and columns ``head-1..H``. The output
        directory is created if missing, and the figure is closed after saving
        so that repeated calls do not accumulate open figures.
        """
        import os

        scores_np = np.array(scores_mat)
        fig, ax = plt.subplots()
        im = ax.imshow(scores_np, cmap='viridis')
        fontdict = {'fontsize': 10}
        n_layers = len(scores_mat)
        n_heads = len(scores_mat[0])
        ax.set_xticks(range(n_heads))
        ax.set_yticks(range(n_layers))
        x_labels = [f'head-{i}' for i in range(1, n_heads + 1)]
        y_labels = [f'layer-{i}' for i in range(1, n_layers + 1)]
        ax.set_xticklabels(x_labels, fontdict=fontdict, rotation=90)
        ax.set_yticklabels(y_labels, fontdict=fontdict)
        ax.set_xlabel('{}'.format(label_name))
        fig.colorbar(im, fraction=0.046, pad=0.04)
        fig.tight_layout()
        out_path = f'figs/{label_name}.png'
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        fig.savefig(out_path, dpi=fig.dpi)
        plt.close(fig)

    def visualize_token2head_scores(self, all_tokens, scores_mat):
        """Show one heatmap per layer in a 6x3 grid: rows on y, tokens on x.

        Each entry of ``scores_mat`` is a 2-D matrix whose columns are
        labelled with ``all_tokens`` and whose rows are labelled 0..R-1.
        """
        fig = plt.figure(figsize=(30, 50))
        fontdict = {'fontsize': 20}
        for idx, scores in enumerate(scores_mat):
            scores_np = np.array(scores)
            ax = fig.add_subplot(6, 3, idx + 1)
            im = ax.matshow(scores_np, cmap='viridis')
            ax.set_xticks(range(len(all_tokens)))
            ax.set_yticks(range(len(scores)))
            # One label per row tick (previously used the column count, which
            # matplotlib rejects whenever the matrix is not square).
            ax.set_yticklabels(range(len(scores)), fontdict=fontdict)
            ax.set_xticklabels(all_tokens, fontdict=fontdict, rotation=90)
            ax.set_xlabel('Layer {}'.format(idx + 1))
            fig.colorbar(im, fraction=0.046, pad=0.04)
        plt.tight_layout()
        plt.show()

    # ------------------------------------------------------------------ #
    # Line plots (token on the left attends to token on the right)
    # ------------------------------------------------------------------ #
    def plot_attn_lines(self, data, heads):
        """Plots attention maps for the given example and attention heads.

        Draws on the current matplotlib axes. ``data`` must provide
        ``"tokens"`` and ``"attns"[layer][head]``. The first token is shown
        as "..." and excluded from the lines; ``data["tokens"]`` itself is
        left unmodified.
        """
        width = 3
        example_sep = 3
        word_height = 1
        pad = 0.1
        # Work on a copy so the caller's token list is not mutated.
        words = list(data["tokens"])
        if words:
            words[0] = "..."
        n_words = len(words)
        for ei, (layer, head) in enumerate(heads):
            yoffset = 1
            xoffset = ei * width * example_sep
            attn = np.clip(self._normalize_rows(data["attns"][layer][head]),
                           0.0, 1.0)
            for position, word in enumerate(words):
                y = yoffset - position * word_height
                plt.text(xoffset + 0, y, word, ha="right", va="center")
                plt.text(xoffset + width, y, word, ha="left", va="center")
            for i in range(1, n_words):
                for j in range(1, n_words):
                    plt.plot([xoffset + pad, xoffset + width - pad],
                             [yoffset - word_height * i, yoffset - word_height * j],
                             color="blue", linewidth=1, alpha=attn[i, j])

    def plot_attn_lines_concepts(self, title, examples, layer, head, color_words,
                                 color_from=True, width=3, example_sep=3,
                                 word_height=1, pad=0.1, hide_sep=False):
        """Show attention lines for each example, highlighting ``color_words``.

        ``examples`` is a sequence of dicts ``{'words': tokens,
        'attentions': [layer][head]}``. Words in ``color_words`` are drawn in
        red on the side selected by ``color_from`` (True = source side), and
        lines touching them on that side are red instead of blue. With
        ``hide_sep`` the first/last columns are zeroed and the rows are
        renormalised.
        """
        fig = plt.figure(figsize=(4, 4))
        ax = fig.add_subplot()
        for ex_idx, example in enumerate(examples):
            yoffset = self._vertical_offset(examples, ex_idx, word_height)
            xoffset = ex_idx * width * example_sep
            attn = self._prepare_attention(example["attentions"][layer][head],
                                           hide_sep)
            words = example["words"]
            n_words = len(words)
            for position, word in enumerate(words):
                y = yoffset - (position * word_height)
                for x, from_word in [(xoffset, True), (xoffset + width, False)]:
                    color = "k"
                    if from_word == color_from and word in color_words:
                        color = "#cc0000"
                    ax.text(x, y, word,
                            ha="right" if from_word else "left", va="center",
                            color=color)
            for i in range(n_words):
                for j in range(n_words):
                    color = "r" if words[i if color_from else j] in color_words else "b"
                    ax.plot([xoffset + pad, xoffset + width - pad],
                            [yoffset - word_height * i, yoffset - word_height * j],
                            color=color, linewidth=1, alpha=attn[i, j])
        ax.axis("off")
        ax.set_title(title)
        plt.show()

    def plot_attn_lines_concepts_ids(self, title, examples, layer, head,
                                     relations_total, width=3, example_sep=3,
                                     word_height=1, pad=0.1, hide_sep=False):
        """Build (and return) an attention-line figure coloured by relations.

        ``examples`` is a sequence of dicts ``{'words': tokens,
        'attentions': [layer][head]}``. ``relations_total[idx]`` is a 2-D
        matrix indexable as ``rel[i, j]`` for example ``idx``; entries > 0
        mark a relation i->j. Source-side words with an outgoing relation
        are red, target-side words with an incoming relation are green; a
        related line is red when i < j and green otherwise.

        Returns:
            The matplotlib ``Figure`` (not shown), e.g. for a gradio plot.
        """
        plt.clf()
        fig = plt.figure(figsize=(10, 5))
        ax = fig.add_subplot()
        for idx, example in enumerate(examples):
            # Previously compared examples[0] with itself (always 0).
            yoffset = self._vertical_offset(examples, idx, word_height)
            xoffset = idx * width * example_sep
            attn = self._prepare_attention(example["attentions"][layer][head],
                                           hide_sep)
            words = example["words"]
            n_words = len(words)
            related = self._relation_mask(relations_total[idx], n_words)
            for position, word in enumerate(words):
                y = yoffset - (position * word_height)
                has_outgoing = any(related[position])
                has_incoming = any(row[position] for row in related)
                ax.text(xoffset, y, word, ha="right", va="center",
                        color="r" if has_outgoing else "k")
                ax.text(xoffset + width, y, word, ha="left", va="center",
                        color="g" if has_incoming else "k")
            for i in range(n_words):
                for j in range(n_words):
                    ax.plot([xoffset + pad, xoffset + width - pad],
                            [yoffset - word_height * i, yoffset - word_height * j],
                            color=self._relation_color(related, i, j),
                            linewidth=1, alpha=attn[i, j])
        ax.axis("off")
        ax.set_title(title)
        return fig