File size: 5,603 Bytes
f388ec1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
from transformers_interpret import SequenceClassificationExplainer
from captum.attr import visualization as viz
import html


class CustomExplainer(SequenceClassificationExplainer):
    """SequenceClassificationExplainer with additional HTML visualisations.

    ``visualize`` renders the token-level Captum table, and
    ``visualize_wordwise`` merges ``##`` sub-tokens back into words and writes
    a standalone HTML page colouring each word by the sign and magnitude of
    its attribution.
    """

    def __init__(self, model, tokenizer):
        super().__init__(model, tokenizer)

    def visualize(self, html_filepath: str = None, true_class: str = None):
        """
        Visualizes word attributions. If in a notebook, the table will be displayed inline.

        Otherwise pass a valid path to `html_filepath` and the visualization will be saved
        as an HTML file (".html" is appended to the path if it is missing).

        If the true class of the text is known, it can be passed via `true_class`.

        Returns:
            The object produced by ``viz.visualize_text`` for the current attributions.
        """
        # Strip the "Ġ" space marker (presumably from byte-level BPE tokenizers
        # such as GPT-2/RoBERTa) so tokens display as plain text.
        tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]
        attr_class = self.id2label[self.selected_index]

        if self._single_node_output:
            if true_class is None:
                true_class = round(float(self.pred_probs))
            predicted_class = round(float(self.pred_probs))
            attr_class = round(float(self.pred_probs))

        else:
            if true_class is None:
                true_class = self.selected_index
            predicted_class = self.predicted_class_name

        score_viz = self.attributions.visualize_attributions(  # type: ignore
            self.pred_probs,
            predicted_class,
            true_class,
            attr_class,
            tokens,
        )

        print(score_viz)

        # Named html_viz so it does not shadow the module-level `html` import.
        html_viz = viz.visualize_text([score_viz])

        if html_filepath:
            if not html_filepath.endswith(".html"):
                html_filepath = html_filepath + ".html"
            # Write explicitly as UTF-8 so the bytes match the declared charset
            # regardless of the platform's default encoding.
            with open(html_filepath, "w", encoding="utf-8") as html_file:
                html_file.write("<meta charset='UTF-8'>" + html_viz.data)
        return html_viz

    def merge_attributions(self, token_level_attributions):
        """Merge ``##``-prefixed sub-tokens back into whole words.

        A token starting with ``##`` is treated as a continuation of the
        preceding word. Its text (minus the ``##`` prefix) is appended to that
        word and its score is added to that word's score. A ``##`` token with
        no preceding word is kept as a word of its own (prefix stripped).

        Args:
            token_level_attributions: iterable of ``(token, score)`` pairs,
                e.g. the value returned by calling the explainer.

        Returns:
            A list of ``(word, score)`` tuples, one per merged word.
        """
        words = []
        scores = []
        for elem in token_level_attributions:
            token, score = elem[0], elem[1]
            is_continuation = token.startswith("##")
            if is_continuation and words:
                words[-1] = words[-1] + token[2:]
                scores[-1] = scores[-1] + score
            else:
                words.append(token[2:] if is_continuation else token)
                scores.append(score)
        return list(zip(words, scores))

    @staticmethod
    def _weight_to_color(weight, min_weight: float, max_weight: float) -> str:
        """Return an HSL CSS colour for a single attribution weight.

        Negative weights are red (hue 5, 100% saturation); non-negative ones
        are green (hue 147, 50% saturation). ``abs(weight)`` is min-max
        normalised over ``[min_weight, max_weight]`` and mapped linearly to a
        lightness between 100% (weakest) and 50% (strongest).
        """
        weight = float(weight)
        hue = 5 if weight < 0 else 147
        sat = "100%" if weight < 0 else "50%"

        span = max_weight - min_weight
        if span > 0:
            scaled_weight = (abs(weight) - min_weight) / span
        else:
            # Every magnitude is identical, so there is no range to normalise over.
            # Use full intensity unless all weights are zero.
            scaled_weight = 1.0 if max_weight > 0 else 0.0

        lightness = f"{100 - 50 * scaled_weight}%"
        return f"hsl({hue},{sat},{lightness})"

    def visualize_wordwise(self, sentence: str, path: str,  true_class: str):
        """Write a word-level attribution visualisation of `sentence` to `path`.

        Runs the explainer on `sentence` and merges sub-tokens into words.
        Each word is rendered with a background colour: red for negative
        attributions, green for positive ones, darker for larger magnitudes.
        A legend shows the predicted and true labels.

        Args:
            sentence: text to explain.
            path: output HTML file path; written as UTF-8, overwriting any
                existing file.
            true_class: ground-truth label, shown in the legend and compared
                with the predicted label to choose the legend wording.
        """
        # Run the explainer before reading predicted_class_name. The library
        # derives that name from the most recently processed input, so reading
        # it first would report the previous sentence (or fail on first use).
        attribution_weights = self.merge_attributions(self(sentence))
        pred_class = self.predicted_class_name

        # Escape labels because they are interpolated into HTML.
        pred_label = html.escape(str(pred_class))
        true_label = html.escape(str(true_class))
        if pred_class == true_class:
            legend_sent = f"against {pred_label}"
        else:
            legend_sent = f"against {pred_label} and towards {true_label}"

        magnitudes = [abs(float(w)) for _, w in attribution_weights]
        # default= keeps an empty attribution list from raising.
        min_weight = min(magnitudes, default=0.0)
        max_weight = max(magnitudes, default=0.0)

        word_spans = []
        for word, weight in attribution_weights:
            color = self._weight_to_color(weight, min_weight, max_weight)
            # Escape tokens so values like "<s>" or "&" render as text, not markup.
            word_spans.append(
                f"<span class='word-box' style='background-color: {color};'>"
                f"{html.escape(str(word))}</span><span>&nbsp;</span>"
            )
        attention_html = "".join(word_spans)

        final_html = f"""
            <!DOCTYPE html>
            <html>
            <head>
                <title>Attention Visualization</title>
                <style>
                    span {{
                        font-family: sans-serif;
                        font-size: 16px;
                    }}
                </style>
                <style>
                    /* Color legend */
                    .color-legend {{
                        display: inline-block;
                        margin: 10px 0;
                        padding: 10px 15px;
                        border: 1px solid #ccc;
                        border-radius: 5px;
                    }}
    
                    .word-box {{
                    display: inline-block;
                    border-radius: 5px;
                    padding: 0.2em;
                    }}
    
                    .color-legend span {{
                        display: inline-block;
                        margin: 0 5px;
                    }}
    
                    .positive-weight {{
                        color: green;
                    }}
    
                    .negative-weight {{
                        color: red;
                    }}
    
                    .color-legend span:first-child {{
                        margin-left: 0;
                    }}
                </style>
                <meta charset="utf-8" />
            </head>
            <body>
                <div class="color-legend">
                    <p>PREDICTED LABEL: <b>{pred_label}</b><br>TRUE LABEL: <b>{true_label}</b></p>
                    <p><span class='word-box' style='background-color: hsl(5,100%,50%);'>Disagreement</span> ({legend_sent})</p>
                    <p><span class='word-box' style='background-color: hsl(147,50%,50%);'>Agreement</span> (towards {pred_label})</p>
                </div>
                <div>{attention_html}</div>
            </body>
            </html>
            """

        with open(path, "w", encoding="utf-8") as f:
            f.write(final_html)