Spaces:
Runtime error
Runtime error
| # ########################################################################### | |
| # | |
| # CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) | |
| # (C) Cloudera, Inc. 2022 | |
| # All rights reserved. | |
| # | |
| # Applicable Open Source License: Apache 2.0 | |
| # | |
| # NOTE: Cloudera open source products are modular software products | |
| # made up of hundreds of individual components, each of which was | |
| # individually copyrighted. Each Cloudera open source product is a | |
| # collective work under U.S. Copyright Law. Your license to use the | |
| # collective work is as provided in your written agreement with | |
| # Cloudera. Used apart from the collective work, this file is | |
| # licensed for your use pursuant to the open source license | |
| # identified above. | |
| # | |
| # This code is provided to you pursuant a written agreement with | |
| # (i) Cloudera, Inc. or (ii) a third-party authorized to distribute | |
| # this code. If you do not have a written agreement with Cloudera nor | |
| # with an authorized and properly licensed third party, you do not | |
| # have any rights to access nor to use this code. | |
| # | |
| # Absent a written agreement with Cloudera, Inc. (โClouderaโ) to the | |
| # contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY | |
| # KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED | |
| # WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO | |
| # IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND | |
| # FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, | |
| # AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS | |
| # ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE | |
| # OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY | |
| # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR | |
| # CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES | |
| # RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF | |
| # BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF | |
| # DATA. | |
| # | |
| # ########################################################################### | |
| import torch | |
| from transformers_interpret import SequenceClassificationExplainer | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForSequenceClassification, | |
| ) | |
| from apps.visualization_utils import visualize_text | |
| class CustomSequenceClassificationExplainer(SequenceClassificationExplainer): | |
| """ | |
| Subclassing to replace `visualize()` method with custom styling. | |
| Namely, removing a few columns, styling fonts, and re-arrangning legend position. | |
| """ | |
| def visualize(self, html_filepath: str = None, true_class: str = None): | |
| """ | |
| Visualizes word attributions. If in a notebook table will be displayed inline. | |
| Otherwise pass a valid path to `html_filepath` and the visualization will be saved | |
| as a html file. | |
| If the true class is known for the text that can be passed to `true_class` | |
| """ | |
| tokens = [token.replace("ฤ ", "") for token in self.decode(self.input_ids)] | |
| attr_class = self.id2label[self.selected_index] | |
| if self._single_node_output: | |
| if true_class is None: | |
| true_class = round(float(self.pred_probs)) | |
| predicted_class = round(float(self.pred_probs)) | |
| attr_class = round(float(self.pred_probs)) | |
| else: | |
| if true_class is None: | |
| true_class = self.selected_index | |
| predicted_class = self.predicted_class_name | |
| score_viz = self.attributions.visualize_attributions( # type: ignore | |
| self.pred_probs, | |
| predicted_class, | |
| true_class, | |
| attr_class, | |
| tokens, | |
| ) | |
| # NOTE: here is the overwritten function | |
| html = visualize_text([score_viz]) | |
| if html_filepath: | |
| if not html_filepath.endswith(".html"): | |
| html_filepath = html_filepath + ".html" | |
| with open(html_filepath, "w") as html_file: | |
| html_file.write(html.data) | |
| return html | |
| class InterpretTransformer: | |
| """ | |
| Utility for visualizing word attribution scores from Transformer models. | |
| This class utilizes the [Transformers Interpret](https://github.com/cdpierse/transformers-interpret) | |
| libary to calculate word attributions using a techinique called Integrated Gradients. | |
| Attributes: | |
| cls_model_identifier (str) | |
| """ | |
| def __init__(self, cls_model_identifier: str): | |
| self.cls_model_identifier = cls_model_identifier | |
| self.device = ( | |
| torch.cuda.current_device() if torch.cuda.is_available() else "cpu" | |
| ) | |
| self._initialize_hf_artifacts() | |
| def _initialize_hf_artifacts(self): | |
| """ | |
| Initialize a HuggingFace artifacts (tokenizer and model) according | |
| to the provided identifiers for both SBert and the classification model. | |
| Then initialize the word attribution explainer with the HF model+tokenizer. | |
| """ | |
| # classifer | |
| self.cls_tokenizer = AutoTokenizer.from_pretrained(self.cls_model_identifier) | |
| self.cls_model = AutoModelForSequenceClassification.from_pretrained( | |
| self.cls_model_identifier | |
| ) | |
| self.cls_model.to(self.device) | |
| # transformers interpret | |
| self.explainer = CustomSequenceClassificationExplainer( | |
| self.cls_model, self.cls_tokenizer | |
| ) | |
| def visualize_feature_attribution_scores(self, text: str, class_index: int = 0): | |
| """ | |
| Calculates and visualizes feature attributions using integrated gradients. | |
| Args: | |
| text (str) - text to get attributions for | |
| class_index (int) - Optional output index to provide attributions for | |
| """ | |
| self.explainer(text, index=class_index) | |
| return self.explainer.visualize() | |