# interpret module that implements the SHAP interpretability method
# external imports
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from shap import models, maskers, plots, PartitionExplainer
import torch
# internal imports
from utils import formatting as fmt
from .markup import markup_text
# global variables
TEACHER_FORCING = None
TEXT_MASKER = None
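# NOTE: these globals are (re)initialized by wrap_shap(); they are not
# read anywhere within this module itself.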


# main explain function that returns a chat with explanations
def chat_explained(model, prompt):
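    """
    Runs the SHAP partition explainer on a single prompt and collects the
    explanation artifacts. Assumes `model` is a wrapper object exposing a
    transformers model as MODEL, its tokenizer as TOKENIZER, and a
    set_config() method (inferred from the usage below).

    Returns (response_text, html_graphic, heatmap_plot, marked_text).
    """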
model.set_config()
# create the shap explainer
shap_explainer = PartitionExplainer(model.MODEL, model.TOKENIZER)
# get the shap values for the prompt
shap_values = shap_explainer([prompt])
# create the explanation graphic and plot
graphic = create_graphic(shap_values)
plot = create_plot(
values=shap_values.values[0],
output_names=shap_values.output_names,
input_names=shap_values.data[0],
)
marked_text = markup_text(
shap_values.data[0], shap_values.values[0], variant="shap"
)
# create the response text
response_text = fmt.format_output_text(shap_values.output_names)
return response_text, graphic, plot, marked_text


def wrap_shap(model):
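    """
    (Re)initializes the module-level SHAP helpers, a teacher-forcing
    wrapper around the model and a text masker, for the given model.
    """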
global TEXT_MASKER, TEACHER_FORCING
# set the device to cuda if gpu is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # update the model settings before rebuilding the shap wrappers
model.set_config()
# (re)initialize the shap models and masker
text_generation = models.TextGeneration(model.MODEL, model.TOKENIZER)
TEACHER_FORCING = models.TeacherForcing(
text_generation,
model.TOKENIZER,
device=str(device),
similarity_model=model.MODEL,
similarity_tokenizer=model.TOKENIZER,
)
TEXT_MASKER = maskers.Text(model.TOKENIZER, " ", collapse_mask_token=True)


# graphic plotting function that creates an HTML graphic (as a string)
# for the explanation
def create_graphic(shap_values):
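    """Renders the SHAP text plot and returns it as an HTML string."""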
# create the html graphic using shap text plot function
graphic_html = plots.text(shap_values, display=False)
    # return the html graphic as a string
return str(graphic_html)


# creates a token attribution heatmap plot using matplotlib/seaborn
# CREDIT: adapted from the official Matplotlib documentation
# see https://matplotlib.org/stable/
def create_plot(values, output_names, input_names):
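    """
    Plots the SHAP values as an annotated heatmap with one row per input
    token (input_names) and one column per output token (output_names).
    """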
    # set the seaborn style to white
    sns.set(style="white")
fig, ax = plt.subplots()
    # scale the figure with the token counts (width: outputs, height: inputs)
    fig.set_size_inches(
        max(values.shape[1] * 2, 10),
        max(values.shape[0], 5),
    )
# Plotting the heatmap with Seaborn's color palette
im = ax.imshow(
values,
vmax=values.max(),
vmin=values.min(),
cmap=sns.color_palette("vlag_r", as_cmap=True),
aspect="auto",
)
# Creating colorbar
cbar = ax.figure.colorbar(im, ax=ax)
cbar.ax.set_ylabel("Token Attribution", rotation=-90, va="bottom")
cbar.ax.yaxis.set_tick_params(color="black")
plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="black")
    # set tick positions and labels for both axes
ax.set_yticks(np.arange(len(input_names)), labels=input_names)
ax.set_xticks(np.arange(len(output_names)), labels=output_names)
plt.setp(ax.get_xticklabels(), color="black", rotation=45, ha="right")
plt.setp(ax.get_yticklabels(), color="black")
    # place tick marks on top (drawn in white) and tick labels on the bottom
ax.tick_params(
top=True, bottom=False, labeltop=False, labelbottom=True, color="white"
)
    # add a value annotation to every cell, flipping the text color at half
    # of the maximum normalized value for contrast
for i in range(values.shape[0]):
for j in range(values.shape[1]):
val = values[i, j]
color = "white" if im.norm(values.max()) / 2 > im.norm(val) else "black"
ax.text(j, i, f"{val:.4f}", ha="center", va="center", color=color)
return plt
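

# Hypothetical usage sketch (an assumption, not part of the original module):
# `ModelWrapper` stands in for any object exposing MODEL, TOKENIZER and
# set_config(), matching how chat_explained() and wrap_shap() use it.
#
#   from model import ModelWrapper  # hypothetical import
#
#   mdl = ModelWrapper()
#   wrap_shap(mdl)
#   response, graphic, plot, marked = chat_explained(mdl, "Why is the sky blue?")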