File size: 2,210 Bytes
d4dd3c5
 
 
 
 
 
67a34bd
d4dd3c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67a34bd
d4dd3c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7e16d0
a597c76
226ad46
67a34bd
d4dd3c5
 
67a34bd
 
 
d4dd3c5
67a34bd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# external imports
from captum.attr import LLMAttribution, TextTokenInput, KernelShap
import torch

# internal imports
from utils import formatting as fmt
from .plotting import plot_seq
from .markup import markup_text


# helper: pair every formatted input token with its attribution score
def cpt_extract_seq_att(attr):
    """Extract per-token sequence attributions from a captum attribution result.

    Moves the attribution scores to the CPU, formats the raw input tokens
    for display, and returns them zipped together as a list of
    (token, value) tuples. Raises RuntimeError when the number of scores
    does not match the number of input tokens.
    """

    # pull the attribution scores off the accelerator into a numpy array
    scores = attr.seq_attr.to(torch.device("cpu")).numpy()

    # human-readable token strings for display
    tokens = fmt.format_tokens(attr.input_tokens)

    # sanity check: exactly one score per raw input token
    if len(attr.input_tokens) != len(scores):
        raise RuntimeError("values and input len mismatch")

    # pair each token with its score
    return list(zip(tokens, scores))


# main explain function that returns a chat with explanations
def chat_explained(model, prompt):
    """Generate a model response for *prompt* with KernelSHAP attributions.

    Runs captum's LLMAttribution with KernelSHAP over the model, then
    returns a 4-tuple of (response_text, graphic_html, marked_text, plot).

    :param model: project model wrapper exposing MODEL, TOKENIZER, CONFIG
        and set_config() — assumed, confirm against caller.
    :param prompt: input prompt string to attribute and answer.
    :raises RuntimeError: when the number of attribution values does not
        match the number of input tokens.
    """
    model.set_config({})

    # creating llm attribution class with KernelSHAP and Mistral Model, Tokenizer
    llm_attribution = LLMAttribution(KernelShap(model.MODEL), model.TOKENIZER)

    # generation attribution
    attribution_input = TextTokenInput(prompt, model.TOKENIZER)
    attribution_result = llm_attribution.attribute(
        attribution_input, gen_args=model.CONFIG.to_dict()
    )

    # extracting values and input tokens exactly once
    # (previously this work was repeated inside cpt_extract_seq_att for the plot)
    values = attribution_result.seq_attr.to(torch.device("cpu")).numpy()
    input_tokens = fmt.format_tokens(attribution_result.input_tokens)

    # raising error if mismatch occurs
    if len(attribution_result.input_tokens) != len(values):
        raise RuntimeError("values and input len mismatch")

    # getting response text, graphic placeholder and marked text object
    response_text = fmt.format_output_text(attribution_result.output_tokens)
    graphic = """<div style='text-align: center; font-family:arial;'><h4>
        Interpretation with Captum doesn't support an interactive graphic.</h4></div>
        """
    # create the explanation marked text array
    marked_text = markup_text(input_tokens, values, variant="captum")

    # creating sequence attribution plot from the already-extracted pairs
    plot = plot_seq(list(zip(input_tokens, values)), "KernelSHAP")

    # return response, graphic and marked_text array
    return response_text, graphic, marked_text, plot