from transformers_interpret import SequenceClassificationExplainer
from captum.attr import visualization as viz


class CustomExplainer(SequenceClassificationExplainer):
    def __init__(self, model, tokenizer):
        super().__init__(model, tokenizer)

    def visualize(self, html_filepath: str = None, true_class: str = None):
        """
        Visualizes word attributions. If in a notebook, the table is displayed inline.
        Otherwise, pass a valid path to `html_filepath` and the visualization is saved
        as an HTML file. If the true class of the text is known, it can be passed via
        `true_class`.
        """
        tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]
        attr_class = self.id2label[self.selected_index]

        if self._single_node_output:
            if true_class is None:
                true_class = round(float(self.pred_probs))
            predicted_class = round(float(self.pred_probs))
            attr_class = round(float(self.pred_probs))
        else:
            if true_class is None:
                true_class = self.selected_index
            predicted_class = self.predicted_class_name

        score_viz = self.attributions.visualize_attributions(  # type: ignore
            self.pred_probs,
            predicted_class,
            true_class,
            attr_class,
            tokens,
        )
        print(score_viz)
        html = viz.visualize_text([score_viz])

        if html_filepath:
            if not html_filepath.endswith(".html"):
                html_filepath = html_filepath + ".html"
            with open(html_filepath, "w") as html_file:
                html_file.write("<meta charset='utf8' />" + html.data)
        return html

    def merge_attributions(self, token_level_attributions):
        """Merge WordPiece sub-tokens ("##...") back into full words, summing their scores."""
        final = []
        scores = []
        for token, score in token_level_attributions:
            if token.startswith("##"):
                final[-1] = final[-1] + token.replace("##", "")
                scores[-1] = scores[-1] + score
            else:
                final.append(token)
                scores.append(score)
        return list(zip(final, scores))

    def visualize_wordwise(self, sentence: str, path: str, true_class: str):
        # Run the explainer first so the predicted class and attributions refer to this sentence.
        attribution_weights = self.merge_attributions(self(sentence))

        pred_class = self.predicted_class_name
        if pred_class == true_class:
            legend_sent = f"against {pred_class}"
        else:
            legend_sent = f"against {pred_class} and towards {true_class}"

        min_weight = min(float(abs(w)) for _, w in attribution_weights)
        max_weight = max(float(abs(w)) for _, w in attribution_weights)
        # Avoid division by zero when all attributions have the same magnitude.
        weight_range = (max_weight - min_weight) or 1.0

        attention_html = []
        for word, weight in attribution_weights:
            # Red hues for negative attributions, green hues for positive ones.
            hue = 5 if weight < 0 else 147
            sat = "100%" if weight < 0 else "50%"
            # Scale the absolute weight so stronger attributions get darker highlights.
            scaled_weight = (min_weight + abs(weight)) / weight_range
            lightness = f"{100 - 50 * scaled_weight}%"
            color = f"hsl({hue},{sat},{lightness})"
            attention_html.append(
                f'<span style="background-color: {color}">{word}</span> '
            )
        # attention_html = html.unescape("".join(attention_html))
        attention_html = "".join(attention_html)

        final_html = f"""<html>
        <head><title>Attention Visualization</title></head>
        <body>
            <div>PREDICTED LABEL: {pred_class}</div>
            <div>TRUE LABEL: {true_class}</div>
            <div>Disagreement ({legend_sent})</div>
            <div>Agreement (towards {pred_class})</div>
            <p>{attention_html}</p>
        </body>
        </html>"""

        with open(path, "w", encoding="utf-8") as f:
            f.write(final_html)