# formatting util module: helper functions for formatting model input and output

# external imports
import re
import torch
import numpy as np
from numpy import ndarray


# globally defined tokens that are removed from the output
SPECIAL_TOKENS = [
    "[CLS]",
    "[SEP]",
    "[PAD]",
    "[UNK]",
    "[MASK]",
    "▁",
    "Ġ",
    "</w>",
    "<0x0A>",
    "<0x0D>",
    "<0x09>",
    "<s>",
    "</s>",
]


# function to format the model response nicely
# takes a list of token strings and returns a combined string
def format_output_text(output: list):

    # remove special tokens from the list using the helper below
    formatted_output = format_tokens(output)

    # start the string with the first list item if it is not empty
    if formatted_output[0] != "":
        output_str = formatted_output[0]
        remaining = formatted_output[1:]
    else:
        # alternatively start with the second list item
        output_str = formatted_output[1]
        remaining = formatted_output[2:]

    # add the remaining list items with a space in between
    for txt in remaining:
        # check if the token is a punctuation mark or other special character
        if txt in [
            ".",
            ",",
            "!",
            "?",
            ":",
            ";",
            ")",
            "]",
            "}",
            "'",
            '"',
            "[",
            "{",
            "(",
            "<",
        ]:
            # add punctuation mark without space
            output_str += txt
        # add token with space if not empty
        elif txt != "":
            output_str += " " + txt

    # return the combined string with multiple spaces removed
    return re.sub(r"\s+", " ", output_str)
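

# example usage (a minimal sketch; the token strings below are illustrative
# SentencePiece-style tokens, not output from a real model run):
#   format_output_text(["▁Hello", ",", "▁how", "▁are", "▁you", "?", "</s>"])
#   -> "Hello, how are you?"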


# format the tokens by removing special tokens and special characters
def format_tokens(tokens: list):

    # initialize empty list
    updated_tokens = []

    # loop through tokens
    for t in tokens:
        # loop through the special tokens list and remove every occurrence
        # from the current token (this also strips leading "▁" markers)
        for s in SPECIAL_TOKENS:
            t = t.replace(s, "")

        # add token to list
        updated_tokens.append(t)

    # return the list of tokens
    return updated_tokens
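

# quick example (tokens are hand-made for illustration):
#   format_tokens(["▁Hello", "[PAD]", "world</w>"]) -> ["Hello", "", "world"]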


# function to flatten SHAP attribution values into a 2d array by summing along the given axis
def flatten_attribution(values: ndarray, axis: int = 0):
    return np.sum(values, axis=axis)


# function to flatten attention values into a 2d array by averaging along the given axis
def flatten_attention(values: ndarray, axis: int = 0):
    return np.mean(values, axis=axis)
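

# example for both flatten helpers (a sketch with dummy data, not real
# SHAP or attention values): reducing a 3d array to 2d along axis 0
#   flatten_attribution(np.ones((4, 2, 3))).shape -> (2, 3), every entry 4.0
#   flatten_attention(np.ones((4, 2, 3))).shape   -> (2, 3), every entry 1.0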


# function to get averaged attention values from the model's attention output
def avg_attention(attention_values, model: str):

    # check if the model is GODEL
    if model == "godel":
        # get the first layer's encoder attentions for the first batch element
        attention = attention_values.encoder_attentions[0][0].detach().numpy()
        return np.mean(attention, axis=1)

    # extract attention values for Mistral, moving the tensor to the CPU first
    attention = attention_values.to(torch.device("cpu")).detach().numpy()

    # drop the last dimension by selecting its first entry to get the correct shape
    attention = attention[:, :, :, 0]

    # return the averaged attention values
    return np.mean(attention, axis=1)
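

# minimal self-test sketch: the mock objects and tensor shapes below are
# assumptions chosen to mirror what the two code paths index into, not the
# real GODEL/Mistral model outputs
if __name__ == "__main__":
    from types import SimpleNamespace

    # formatting a hand-made token list
    print(format_output_text(["▁Hello", ",", "▁world", "!", "</s>"]))

    # godel path: encoder_attentions is assumed to be a list of tensors of
    # shape (batch, heads, seq, seq)
    mock_godel = SimpleNamespace(encoder_attentions=[torch.rand(1, 4, 6, 6)])
    print(avg_attention(mock_godel, "godel").shape)  # (4, 6)

    # mistral path: a raw attention tensor of shape (layers, heads, seq, 1)
    # is assumed here
    print(avg_attention(torch.rand(3, 4, 6, 1), "mistral").shape)  # (3, 6)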