# NOTE: removed stray platform banner residue ("Spaces:" / "Runtime error") — not Python source
# visualization module that creates an attention visualization using BERTViz
# external imports
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
# internal imports
from utils import formatting as fmt
from .markup import markup_text
# chat function that generates a response and explains it via an attention heatmap
def chat_explained(model, prompt):
    """Generate a model response for *prompt* and build attention explanations.

    :param model: project model wrapper exposing TOKENIZER, MODEL and set_config()
    :param prompt: input text to run through the encoder/decoder model
    :return: tuple of (formatted response text, empty placeholder string,
        attention heatmap plot, attention-marked input text)
    """
    model.set_config()
    # tokenize the prompt into encoder input ids (special tokens included)
    enc_ids = model.TOKENIZER(
        prompt, return_tensors="pt", add_special_tokens=True
    ).input_ids
    # generate the decoder ids, requesting attention tensors along the way
    dec_ids = model.MODEL.generate(enc_ids, output_attentions=True)
    # map both id sequences back to human-readable, formatted token text
    enc_tokens = model.TOKENIZER.convert_ids_to_tokens(enc_ids[0])
    encoder_text = fmt.format_tokens(enc_tokens)
    dec_tokens = model.TOKENIZER.convert_ids_to_tokens(dec_ids[0])
    decoder_text = fmt.format_tokens(dec_tokens)
    # forward pass over the input/output pair to collect cross-attention values
    attention_output = model.MODEL(
        input_ids=enc_ids,
        decoder_input_ids=dec_ids,
        output_attentions=True,
    )
    # average the cross-attention over the heads of the first layer
    averaged_attention = avg_attention(attention_output)
    # assemble the response text, the heatmap plot and the marked-up input
    response_text = fmt.format_output_text(decoder_text)
    plot = create_plot(averaged_attention, (encoder_text, decoder_text))
    marked_text = markup_text(encoder_text, averaged_attention, variant="visualizer")
    return response_text, "", plot, marked_text
# creating an attention heatmap plot using matplotlib/seaborn
# CREDIT: adapted from the official Matplotlib documentation
## see https://matplotlib.org/stable/
def create_plot(averaged_attention_weights, enc_dec_texts: tuple):
    """Render head-averaged cross-attention weights as an annotated heatmap.

    :param averaged_attention_weights: 2D array of attention weights
        (decoder tokens x encoder tokens; transposed internally so encoder
        tokens run down the y axis).
    :param enc_dec_texts: tuple of (encoder tokens, decoder tokens) text lists
        used as axis tick labels.
    :return: the matplotlib.pyplot module holding the rendered figure.
    """
    # transpose so encoder tokens index rows (y) and decoder tokens columns (x)
    averaged_attention_weights = np.transpose(averaged_attention_weights)
    # get the encoder and decoder tokens in text form
    encoder_tokens = enc_dec_texts[0]
    decoder_tokens = enc_dec_texts[1]
    # set seaborn style and initialize figure and axis
    sns.set(style="white")
    fig, ax = plt.subplots()
    # scale the figure with the token counts (minimum 10in x 5in)
    fig.set_size_inches(
        max(averaged_attention_weights.shape[1] * 2, 10),
        max(averaged_attention_weights.shape[0] * 1, 5),
    )
    # Plotting the heatmap with seaborn's color palette.
    # BUGFIX: vmin was previously `-averaged_attention_weights.min()`; since
    # attention weights are non-negative, the negated minimum pushed the color
    # scale below the data range and flattened contrast — use the minimum as-is.
    im = ax.imshow(
        averaged_attention_weights,
        vmax=averaged_attention_weights.max(),
        vmin=averaged_attention_weights.min(),
        cmap=sns.color_palette("rocket", as_cmap=True),
        aspect="auto",
    )
    # Creating colorbar with a rotated label on the right-hand side
    cbar = ax.figure.colorbar(im, ax=ax)
    cbar.ax.set_ylabel("Attention Weight Scale", rotation=-90, va="bottom")
    cbar.ax.yaxis.set_tick_params(color="black")
    plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="black")
    # Setting ticks and labels with black color for visibility
    ax.set_yticks(np.arange(len(encoder_tokens)), labels=encoder_tokens)
    ax.set_xticks(np.arange(len(decoder_tokens)), labels=decoder_tokens)
    ax.set_title("Attention Weights by Token")
    plt.setp(ax.get_xticklabels(), color="black", rotation=45, ha="right")
    plt.setp(ax.get_yticklabels(), color="black")
    # Adding per-cell value annotations; text color flips to white on the
    # darker (lower-valued, "rocket" colormap) half of the normalized scale
    for i in range(averaged_attention_weights.shape[0]):
        for j in range(averaged_attention_weights.shape[1]):
            val = averaged_attention_weights[i, j]
            color = (
                "white"
                if im.norm(averaged_attention_weights.max()) / 2 > im.norm(val)
                else "black"
            )
            ax.text(j, i, f"{val:.4f}", ha="center", va="center", color=color)
    # return the plot
    return plt
def avg_attention(attention_values):
    """Average the first-layer cross-attention over all heads.

    :param attention_values: model output exposing ``cross_attentions``;
        ``cross_attentions[0][0]`` is the (heads, dec_len, enc_len) tensor for
        the first layer's first batch element.
    :return: numpy array of head-averaged attention weights.
    """
    # first layer, first batch element of the cross-attentions
    layer_attention = attention_values.cross_attentions[0][0]
    # detach from the autograd graph, move to numpy and average over heads
    return layer_attention.detach().numpy().mean(axis=0)