# thesis/explanation/interpret_shap.py
# interpret module that implements the SHAP-based interpretability method

# external imports
from shap import models, maskers, plots, PartitionExplainer
import torch

# internal imports
from utils import formatting as fmt
from .plotting import plot_seq
from .markup import markup_text

# global variables
TEACHER_FORCING = None
TEXT_MASKER = None


# function to extract the summarized sequence-wise attribution
def shap_extract_seq_att(shap_values):
    # extract the summed-up shap values
    values = fmt.flatten_attribution(shap_values.values[0], 1)

    # return a list of (token, value) tuples
    return list(zip(shap_values.data[0], values))


# function used to wrap the model with the shap helper models
def wrap_shap(model):
    # declare the global masker and teacher forcing model
    global TEXT_MASKER, TEACHER_FORCING

    # set the device to cuda if a gpu is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # update the model settings
    model.set_config({})

    # (re)initialize the shap models and masker
    # creating a shap text generation model
    text_generation = models.TextGeneration(model.MODEL, model.TOKENIZER)

    # wrapping the text generation model in a teacher forcing model
    TEACHER_FORCING = models.TeacherForcing(
        text_generation,
        model.TOKENIZER,
        device=str(device),
        similarity_model=model.MODEL,
        similarity_tokenizer=model.TOKENIZER,
    )

    # creating the text masker with a whitespace mask token
    TEXT_MASKER = maskers.Text(model.TOKENIZER, " ", collapse_mask_token=True)


# graphic plotting function that creates an html graphic (as a string) for the explanation
def create_graphic(shap_values):
    # create the html graphic using the shap text plot function
    graphic_html = plots.text(shap_values, display=False)

    # return the html graphic as a string to display in an iFrame
    return str(graphic_html)


# main explain function that returns a chat response with explanations
def chat_explained(model, prompt):
    # update the model settings and (re)initialize the shap wrappers and masker
    model.set_config({})
    wrap_shap(model)

    # create the shap partition explainer from the wrapped model and masker
    shap_explainer = PartitionExplainer(TEACHER_FORCING, TEXT_MASKER)

    # get the shap values for the prompt
    shap_values = shap_explainer([prompt])

    # create the explanation graphic and marked text array
    graphic = create_graphic(shap_values)
    marked_text = markup_text(
        shap_values.data[0], shap_values.values[0], variant="shap"
    )

    # create the response text
    response_text = fmt.format_output_text(shap_values.output_names)

    # create the sequence attribution plot
    plot = plot_seq(shap_extract_seq_att(shap_values), "PartitionSHAP")

    # return response, graphic, marked_text array and plot
    return response_text, graphic, marked_text, plot
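

# Usage sketch (illustrative only): how this module would typically be driven
# from the app code. The import path below is an assumption for demonstration;
# any wrapper object exposing MODEL, TOKENIZER and set_config() as used above
# would work the same way.
#
#     from model import mistral as model  # hypothetical import path
#
#     response_text, graphic, marked_text, plot = chat_explained(
#         model, "How does partition SHAP attribute tokens?"
#     )
#     print(response_text)  # decoded model response
#     # graphic     -> html string to embed in an iFrame
#     # marked_text -> token/value pairs for highlighted output
#     # plot        -> sequence-level attribution plot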