File size: 2,720 Bytes
d4dd3c5
 
 
 
 
 
 
 
67a34bd
d4dd3c5
 
 
 
 
 
 
 
67a34bd
d4dd3c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7e16d0
d4dd3c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30049a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67a34bd
 
 
30049a9
67a34bd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# interpret module that implements the interpretability method

# external imports
from shap import models, maskers, plots, PartitionExplainer
import torch

# internal imports
from utils import formatting as fmt
from .plotting import plot_seq
from .markup import markup_text

# global variables
TEACHER_FORCING = None
TEXT_MASKER = None


# function to extract summarized sequence wise attribution
def shap_extract_seq_att(shap_values):
    """Summarize the SHAP attribution of each input token.

    Collapses the attribution matrix of the first (and only) explained
    sample into a single score per input token and pairs every token
    with its summed value.
    """
    # collapse the per-output attribution values along axis 1
    summed_values = fmt.flatten_attribution(shap_values.values[0], 1)

    # pair each input token with its aggregated attribution score
    return [(token, score) for token, score in zip(shap_values.data[0], summed_values)]


# function used to wrap the model with a shap model
def wrap_shap(model):
    """(Re)initialize the module-level SHAP wrappers for *model*.

    Populates the TEACHER_FORCING and TEXT_MASKER globals with a
    teacher-forcing SHAP model and a text masker built from the given
    model's underlying MODEL and TOKENIZER attributes.
    """
    # rebind the module-level globals below
    global TEXT_MASKER, TEACHER_FORCING

    # prefer the GPU when one is available
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # reset the model to its default configuration
    model.set_config({})

    # plain text-generation wrapper around the underlying model
    generator = models.TextGeneration(model.MODEL, model.TOKENIZER)

    # teacher-forcing wrapper so SHAP can score generated sequences
    TEACHER_FORCING = models.TeacherForcing(
        generator,
        model.TOKENIZER,
        device=str(run_device),
        similarity_model=model.MODEL,
        similarity_tokenizer=model.TOKENIZER,
    )

    # masker that substitutes masked-out tokens with a blank string
    TEXT_MASKER = maskers.Text(model.TOKENIZER, " ", collapse_mask_token=True)


# graphic plotting function that creates a html graphic (as string) for the explanation
def create_graphic(shap_values):
    """Render the SHAP text plot for *shap_values* as an HTML string.

    The returned string is intended to be displayed inside an iFrame.
    """
    # plots.text with display=False returns the HTML instead of showing it
    return str(plots.text(shap_values, display=False))


# main explain function that returns a chat with explanations
def chat_explained(model, prompt):
    """Explain a single prompt with PartitionSHAP.

    Returns a 4-tuple of (response text, HTML explanation graphic,
    marked-up text array, sequence-attribution plot).
    """
    # reset the model to its default configuration before explaining
    model.set_config({})

    # build a partition explainer directly over the model and tokenizer
    # NOTE(review): the TEACHER_FORCING / TEXT_MASKER globals prepared by
    # wrap_shap are not used here — confirm whether they should be.
    explainer = PartitionExplainer(model.MODEL, model.TOKENIZER)

    # compute attributions for the single prompt
    shap_values = explainer([prompt])

    # HTML graphic and marked-text array derived from the attributions
    graphic = create_graphic(shap_values)
    marked_text = markup_text(
        shap_values.data[0], shap_values.values[0], variant="shap"
    )

    # plain-text response assembled from the generated output tokens
    response_text = fmt.format_output_text(shap_values.output_names)

    # per-token sequence attribution plot
    plot = plot_seq(shap_extract_seq_att(shap_values), "PartitionSHAP")

    return response_text, graphic, marked_text, plot