thesis / explanation /visualize.py
LennardZuendorf's picture
feat: implementing new explanation ui
d2116db unverified
raw
history blame
3.81 kB
# visualization module that creates an attention heatmap visualization using matplotlib/seaborn
# external imports
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
# internal imports
from utils import formatting as fmt
from .markup import markup_text
# main explanation function: generates the model response and plots the attention values in a heatmap
def chat_explained(model, prompt):
    """Generate an answer for *prompt* and build its attention explanation.

    The model wrapper is configured via ``model.set_config()``, the prompt is
    tokenized and passed to ``model.MODEL.generate``, and a second forward
    pass with ``output_attentions=True`` provides the attention values that
    are averaged by ``avg_attention``.

    Args:
        model: wrapper object exposing ``set_config()``, ``TOKENIZER`` and
            ``MODEL`` (project type; not defined in this module).
        prompt: input text handed to the tokenizer.

    Returns:
        A 4-tuple ``(response_text, "", plot, marked_text)``: the formatted
        response text, an empty string, the object returned by
        ``create_plot`` and the result of ``markup_text``.
    """
    model.set_config()

    tokenizer = model.TOKENIZER
    network = model.MODEL

    # encode the prompt and let the model generate the output ids
    tokenized_prompt = tokenizer(
        prompt, return_tensors="pt", add_special_tokens=True
    )
    encoder_input_ids = tokenized_prompt.input_ids
    decoder_input_ids = network.generate(encoder_input_ids, output_attentions=True)

    # turn both id sequences (first item only) back into formatted tokens
    source_tokens = tokenizer.convert_ids_to_tokens(encoder_input_ids[0])
    encoder_text = fmt.format_tokens(source_tokens)
    target_tokens = tokenizer.convert_ids_to_tokens(decoder_input_ids[0])
    decoder_text = fmt.format_tokens(target_tokens)

    # second forward pass over input and generated output to get attentions
    forward_result = network(
        input_ids=encoder_input_ids,
        decoder_input_ids=decoder_input_ids,
        output_attentions=True,
    )
    averaged_attention = avg_attention(forward_result)

    # assemble response text, heatmap plot and marked-up input text
    token_pair = (encoder_text, decoder_text)
    return (
        fmt.format_output_text(decoder_text),
        "",
        create_plot(averaged_attention, token_pair),
        markup_text(encoder_text, averaged_attention, variant="visualizer"),
    )
# creating an attention heatmap plot using matplotlib/seaborn
# CREDIT: adopted from official Matplotlib documentation
## see https://matplotlib.org/stable/
def create_plot(averaged_attention_weights, enc_dec_texts: tuple):
    """Render averaged attention weights as an annotated heatmap.

    Args:
        averaged_attention_weights: 2-D array of attention weights. It is
            transposed before plotting, so after the transpose rows are
            labelled with the encoder tokens and columns with the decoder
            tokens (presumably the input is (decoder_len, encoder_len);
            confirm against ``avg_attention``).
        enc_dec_texts: pair ``(encoder_tokens, decoder_tokens)`` used as the
            y-axis and x-axis tick labels respectively.

    Returns:
        The ``matplotlib.pyplot`` module, whose current figure is the newly
        created heatmap. The figure is not closed here; the caller owns it.
    """
    # transpose the attention weights: rows -> encoder, columns -> decoder
    weights = np.transpose(averaged_attention_weights)

    # get the encoder and decoder tokens in text form
    encoder_tokens = enc_dec_texts[0]
    decoder_tokens = enc_dec_texts[1]

    # set seaborn style to white (note: this changes the global theme)
    # and initialize figure and axis
    sns.set(style="white")
    fig, ax = plt.subplots()

    # scale the figure with the matrix, but never below 10x5 inches
    n_rows, n_cols = weights.shape[0], weights.shape[1]
    fig.set_size_inches(max(n_cols * 2, 10), max(n_rows * 1, 5))

    # The colour scale spans exactly the data range. The lower bound used to
    # be the *negated* minimum, which did not match the data and would clip
    # values if the minimum were ever negative.
    lowest = weights.min()
    highest = weights.max()
    im = ax.imshow(
        weights,
        vmax=highest,
        vmin=lowest,
        cmap=sns.color_palette("rocket", as_cmap=True),
        aspect="auto",
    )

    # creating colorbar
    cbar = ax.figure.colorbar(im, ax=ax)
    cbar.ax.set_ylabel("Attention Weight Scale", rotation=-90, va="bottom")
    cbar.ax.yaxis.set_tick_params(color="black")
    plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="black")

    # setting ticks and labels with black color for visibility
    ax.set_yticks(np.arange(len(encoder_tokens)), labels=encoder_tokens)
    ax.set_xticks(np.arange(len(decoder_tokens)), labels=decoder_tokens)
    ax.set_title("Attention Weights by Token")
    plt.setp(ax.get_xticklabels(), color="black", rotation=45, ha="right")
    plt.setp(ax.get_yticklabels(), color="black")

    # adding text annotations with appropriate contrast: cells in the lower
    # half of the normalized scale get white text, the rest black text
    half_scale = im.norm(highest) / 2
    for (i, j), val in np.ndenumerate(weights):
        color = "white" if half_scale > im.norm(val) else "black"
        ax.text(j, i, f"{val:.4f}", ha="center", va="center", color=color)

    # return the plot
    return plt
def avg_attention(attention_values):
    """Average cross-attention values over the leading axis.

    Takes ``attention_values.cross_attentions[0][0]`` (presumably the first
    decoder layer and the first batch item, giving (heads, decoder_len,
    encoder_len); confirm against the model's output format), converts it to
    a NumPy array and returns the mean over axis 0.

    Args:
        attention_values: model output exposing a ``cross_attentions``
            sequence of tensors.

    Returns:
        NumPy array with the first axis of the selected tensor averaged out.
    """
    selected = attention_values.cross_attentions[0][0]
    # .cpu() is required before .numpy() when the tensor lives on an
    # accelerator; for CPU tensors it is a no-op
    attention = selected.detach().cpu().numpy()
    return np.mean(attention, axis=0)