File size: 5,592 Bytes
f388ec1 9329025 f388ec1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
from transformers_interpret import SequenceClassificationExplainer
from captum.attr import visualization as viz
class CustomExplainer(SequenceClassificationExplainer):
    """Sequence-classification explainer with extra HTML visualizations.

    Extends ``SequenceClassificationExplainer`` with:

    * ``visualize``: the captum table visualization, saved as UTF-8 HTML.
    * ``merge_attributions``: merges WordPiece ("##"-prefixed) sub-token
      attributions into whole-word attributions.
    * ``visualize_wordwise``: a standalone word-level heat-map HTML page.
    """

    def __init__(self, model, tokenizer):
        super().__init__(model, tokenizer)

    def visualize(self, html_filepath: str = None, true_class: str = None):
        """Visualize the word attributions of the most recent explainer call.

        If running in a notebook the table is displayed inline. Otherwise pass
        a path to ``html_filepath`` and the visualization is also saved as an
        HTML file (".html" is appended when the path lacks it).

        Args:
            html_filepath: Optional output path for the HTML file.
            true_class: Optional known true class of the text. When omitted,
                the rounded probability (single-output models) or the selected
                index is used.

        Returns:
            The object returned by captum's ``viz.visualize_text``.
        """
        tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]
        attr_class = self.id2label[self.selected_index]

        if self._single_node_output:
            # Single-logit model: every class slot is the rounded probability.
            rounded = round(float(self.pred_probs))
            if true_class is None:
                true_class = rounded
            predicted_class = rounded
            attr_class = rounded
        else:
            if true_class is None:
                true_class = self.selected_index
            predicted_class = self.predicted_class_name

        score_viz = self.attributions.visualize_attributions(  # type: ignore
            self.pred_probs,
            predicted_class,
            true_class,
            attr_class,
            tokens,
        )
        html = viz.visualize_text([score_viz])

        if html_filepath:
            if not html_filepath.endswith(".html"):
                html_filepath = html_filepath + ".html"
            # Encode explicitly as UTF-8 so the bytes match the declared
            # charset regardless of the platform's default encoding.
            with open(html_filepath, "w", encoding="utf-8") as html_file:
                html_file.write("<meta charset='UTF-8'>" + html.data)
        return html

    def merge_attributions(self, token_level_attributions):
        """Merge WordPiece sub-token attributions into word-level attributions.

        A token starting with "##" (BERT-style continuation piece) is appended,
        without its prefix, to the preceding word, and its score is added to
        that word's score. A continuation token with no preceding word is kept
        as its own entry instead of raising.

        Args:
            token_level_attributions: Iterable of sequences whose first two
                items are ``(token, score)``, e.g. the explainer's output.

        Returns:
            List of ``(word, summed_score)`` tuples in input order.
        """
        words = []
        scores = []
        for elem in token_level_attributions:
            token, score = elem[0], elem[1]
            if token.startswith("##") and words:
                # Strip only the leading continuation marker.
                words[-1] = words[-1] + token[2:]
                # Non in-place add: avoids mutating caller-owned score objects.
                scores[-1] = scores[-1] + score
            else:
                words.append(token)
                scores.append(score)
        return list(zip(words, scores))

    def visualize_wordwise(self, sentence: str, path: str, true_class: str):
        """Write a word-level attribution heat map for ``sentence`` as HTML.

        Runs the explainer on ``sentence`` and merges WordPiece sub-tokens into
        words. Each word is coloured red (negative weight) or green
        (non-negative weight). Its lightness ranges from 100% down to 50% as
        its min-max normalized absolute weight grows. The result is written as
        a standalone UTF-8 HTML page to ``path``.

        Args:
            sentence: Text to explain.
            path: Output file path (used as given).
            true_class: Known true label, shown in the legend.
        """
        from html import escape

        # Explain first: ``predicted_class_name`` reflects the most recently
        # calculated input_ids, so reading it earlier would report the
        # previous sentence's prediction (or fail on first use).
        attribution_weights = self.merge_attributions(self(sentence))
        pred_class = self.predicted_class_name

        pred_label = escape(str(pred_class))
        true_label = escape(str(true_class))
        if pred_class == true_class:
            legend_sent = f"against {pred_label}"
        else:
            legend_sent = f"against {pred_label} and towards {true_label}"

        magnitudes = [float(abs(w)) for _, w in attribution_weights]
        min_weight = min(magnitudes, default=0.0)
        max_weight = max(magnitudes, default=0.0)
        weight_range = max_weight - min_weight

        word_spans = []
        for (word, weight), magnitude in zip(attribution_weights, magnitudes):
            hue = 5 if weight < 0 else 147
            sat = "100%" if weight < 0 else "50%"
            # Linear min-max normalization of |weight| into [0, 1]. When all
            # magnitudes are equal there is no range to scale against, so any
            # non-zero weight is shown at full strength.
            if weight_range > 0:
                scaled_weight = (magnitude - min_weight) / weight_range
            else:
                scaled_weight = 1.0 if magnitude > 0 else 0.0
            # Lightness in [50%, 100%]: stronger weights are darker.
            lightness = f"{100 - 50 * scaled_weight}%"
            color = f"hsl({hue},{sat},{lightness})"
            word_spans.append(
                f"<span class='word-box' style='background-color: {color};'>"
                f"{escape(str(word))}</span><span> </span>"
            )
        attention_html = "".join(word_spans)

        final_html = f"""
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<title>Attention Visualization</title>
<style>
span {{
font-family: sans-serif;
font-size: 16px;
}}
/* Color legend */
.color-legend {{
display: inline-block;
margin: 10px 0;
padding: 10px 15px;
border: 1px solid #ccc;
border-radius: 5px;
}}
.word-box {{
display: inline-block;
border-radius: 5px;
padding: 0.2em;
}}
.color-legend span {{
display: inline-block;
margin: 0 5px;
}}
.positive-weight {{
color: green;
}}
.negative-weight {{
color: red;
}}
.color-legend span:first-child {{
margin-left: 0;
}}
</style>
</head>
<body>
<div class="color-legend">
<p>PREDICTED LABEL: <b>{pred_label}</b><br>TRUE LABEL: <b>{true_label}</b></p>
<p><span class='word-box' style='background-color: hsl(5,100%,50%);'>Disagreement</span> ({legend_sent})</p>
<p><span class='word-box' style='background-color: hsl(147,50%,50%);'>Agreement</span> (towards {pred_label})</p>
</div>
<div>{attention_html}</div>
</body>
</html>
"""
        with open(path, "w", encoding="utf-8") as f:
            f.write(final_html)
|