# Spaces: Runtime error  (status text captured from the hosting page; not program code)
import html

from captum.attr import visualization as viz
from transformers_interpret import SequenceClassificationExplainer
class CustomExplainer(SequenceClassificationExplainer):
    """Sequence-classification explainer with additional HTML visualizations.

    On top of the inherited explainer this class provides:

    * ``visualize`` - the captum attribution table, optionally saved to disk;
    * ``merge_attributions`` - folds ``##``-prefixed sub-word pieces back into
      whole words and sums their scores;
    * ``visualize_wordwise`` - a standalone HTML page in which every word is
      shaded according to the sign and magnitude of its attribution.
    """

    def __init__(self, model, tokenizer):
        """Initialise the underlying explainer with ``model`` and ``tokenizer``."""
        super().__init__(model, tokenizer)

    def visualize(self, html_filepath: str = None, true_class: str = None):
        """Render the attributions of the last explained input with captum.

        Args:
            html_filepath: Optional destination file. ``.html`` is appended
                when the path does not already end with it. Nothing is
                written when the value is falsy.
            true_class: Known true label for the text, if any. When omitted,
                the rounded prediction (single-output models) or
                ``self.selected_index`` (otherwise) is shown instead.

        Returns:
            The object produced by ``viz.visualize_text``.
        """
        tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]
        attr_class = self.id2label[self.selected_index]

        if self._single_node_output:
            # A single output node carries one probability; the displayed
            # classes are that probability rounded to 0 or 1.
            rounded_prediction = round(float(self.pred_probs))
            if true_class is None:
                true_class = rounded_prediction
            predicted_class = rounded_prediction
            attr_class = rounded_prediction
        else:
            if true_class is None:
                true_class = self.selected_index
            predicted_class = self.predicted_class_name

        score_viz = self.attributions.visualize_attributions(  # type: ignore
            self.pred_probs,
            predicted_class,
            true_class,
            attr_class,
            tokens,
        )
        rendered = viz.visualize_text([score_viz])

        if html_filepath:
            if not html_filepath.endswith(".html"):
                html_filepath = html_filepath + ".html"
            # The document declares UTF-8, so the file must be written as
            # UTF-8 regardless of the platform's default encoding.
            with open(html_filepath, "w", encoding="utf-8") as html_file:
                html_file.write("<meta charset='UTF-8'>" + rendered.data)
        return rendered

    def merge_attributions(self, token_level_attributions):
        """Merge ``##`` continuation pieces into whole words.

        Args:
            token_level_attributions: Iterable of items whose first element is
                the token text and whose second element is its score.

        Returns:
            A list of ``(word, score)`` tuples. A token starting with ``##``
            is appended (without the prefix) to the preceding word and its
            score is added to that word's score. A ``##`` token with no
            preceding word is kept unchanged as its own entry.
        """
        words = []
        scores = []
        for item in token_level_attributions:
            token, score = item[0], item[1]
            # Only merge when there is a previous word to attach to;
            # otherwise indexing [-1] on an empty list would raise.
            if words and token.startswith("##"):
                words[-1] += token[2:]  # strip the prefix only, not inner "##"
                scores[-1] += score
            else:
                words.append(token)
                scores.append(score)
        return list(zip(words, scores))

    @staticmethod
    def _attribution_color(weight, min_weight, max_weight):
        """Return the CSS ``hsl(...)`` colour for one attribution weight.

        Negative weights use a red hue, non-negative weights a green hue.
        Lightness runs from 100% (smallest magnitude) to 50% (largest
        magnitude) using min-max normalisation of ``abs(weight)``.
        """
        if weight < 0:
            hue, saturation = 5, "100%"
        else:
            hue, saturation = 147, "50%"

        span = max_weight - min_weight
        if span > 0:
            intensity = (float(abs(weight)) - min_weight) / span
        else:
            # All magnitudes are identical: avoid dividing by zero and show
            # full colour unless every weight is exactly zero.
            intensity = 1.0 if max_weight > 0 else 0.0

        lightness = f"{100 - 50 * intensity}%"
        return f"hsl({hue},{saturation},{lightness})"

    def visualize_wordwise(self, sentence: str, path: str, true_class: str):
        """Explain ``sentence`` and write a colour-coded HTML page to ``path``.

        Args:
            sentence: Text to explain; passed to the explainer via
                ``self(sentence)``.
            path: Output file, written as UTF-8.
            true_class: Label shown as the true label and used in the legend.

        Returns:
            None. The result is the file written to ``path``.
        """
        # Compute attributions before reading the predicted class.
        # NOTE(review): the inherited ``predicted_class_name`` appears to be
        # derived from the most recently explained input, so reading it first
        # could report a previous sentence's prediction - confirm against the
        # transformers-interpret documentation.
        attribution_weights = self.merge_attributions(self(sentence))
        pred_class = self.predicted_class_name

        # Escape everything that ends up inside the markup so characters such
        # as "<" or "&" in tokens or labels cannot break the page.
        pred_label = html.escape(str(pred_class))
        true_label = html.escape(str(true_class))
        if pred_class == true_class:
            legend_sent = f"against {pred_label}"
        else:
            legend_sent = f"against {pred_label} and towards {true_label}"

        magnitudes = [float(abs(weight)) for _, weight in attribution_weights]
        min_weight = min(magnitudes, default=0.0)
        max_weight = max(magnitudes, default=0.0)

        # Join the spans into one string; interpolating the list itself would
        # put its Python repr (brackets, quotes, commas) into the page.
        attention_html = "".join(
            "<span class='word-box' style='background-color: "
            f"{self._attribution_color(weight, min_weight, max_weight)};'>"
            f"{html.escape(str(word))}</span><span> </span>"
            for word, weight in attribution_weights
        )

        final_html = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<title>Attention Visualization</title>
<style>
span {{
    font-family: sans-serif;
    font-size: 16px;
}}
/* Color legend */
.color-legend {{
    display: inline-block;
    margin: 10px 0;
    padding: 10px 15px;
    border: 1px solid #ccc;
    border-radius: 5px;
}}
.word-box {{
    display: inline-block;
    border-radius: 5px;
    padding: 0.2em;
}}
.color-legend span {{
    display: inline-block;
    margin: 0 5px;
}}
.positive-weight {{
    color: green;
}}
.negative-weight {{
    color: red;
}}
.color-legend span:first-child {{
    margin-left: 0;
}}
</style>
</head>
<body>
<div class="color-legend">
<p>PREDICTED LABEL: <b>{pred_label}</b><br>TRUE LABEL: <b>{true_label}</b></p>
<p><span class='word-box' style='background-color: hsl(5,100%,50%);'>Disagreement</span> ({legend_sent})</p>
<p><span class='word-box' style='background-color: hsl(147,50%,50%);'>Agreement</span> (towards {pred_label})</p>
</div>
<div>{attention_html}</div>
</body>
</html>
"""
        with open(path, "w", encoding="utf-8") as f:
            f.write(final_html)