import sys

sys.path.append('BERT')

import torch
from transformers import AutoTokenizer
from captum.attr import visualization
from BERT_explainability.modules.BERT.ExplanationGenerator import Generator
from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification

# Load the fine-tuned classifier and its tokenizer from the local checkpoint.
model = BertForSequenceClassification.from_pretrained("./BERT/BERT_weight")
model.eval()
tokenizer = AutoTokenizer.from_pretrained("./BERT/BERT_weight")

# Initialize the explanations generator.
explanations = Generator(model)

classifications = ["NEGATIVE", "POSITIVE"]
true_class = 1


def generate_visual(text_batch, target_class):
    encoding = tokenizer(text_batch, return_tensors='pt')
    input_ids = encoding['input_ids']
    attention_mask = encoding['attention_mask']

    # Per-token relevance via LRP, propagated from the last encoder block
    # (start_layer=11 assumes a 12-layer BERT-base model).
    expl = explanations.generate_LRP(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     start_layer=11,
                                     index=target_class)[0]
    # Normalize relevance scores to [0, 1].
    expl = (expl - expl.min()) / (expl.max() - expl.min())

    # Model prediction for the same input.
    output = torch.nn.functional.softmax(
        model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)
    classification = output.argmax(dim=-1).item()

    # Flip the sign for the NEGATIVE class so captum renders the relevance
    # in red rather than green.
    class_name = classifications[target_class]
    if class_name == "NEGATIVE":
        expl *= (-1)

    # Map each token to its rounded relevance score. Note that the dict is
    # keyed by token string, so repeated tokens overwrite earlier entries.
    tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())
    token_importance = {}
    for i in range(len(tokens)):
        token_importance[tokens[i]] = round(expl[i].item(), 3)

    # Arguments in order: word attributions, predicted probability, predicted
    # class, true class, attributed class, attribution score, tokens,
    # convergence score.
    vis_data_records = [visualization.VisualizationDataRecord(
        expl,
        output[0][classification],
        classification,
        true_class,
        true_class,
        1,
        tokens,
        1)]
    html_page = visualization.visualize_text(vis_data_records)
    return token_importance, html_page.data
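
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original script). The sample sentence and the
# output filename "explanation.html" are hypothetical choices; any short text
# works. visualize_text() returns an HTML object whose .data attribute is a
# plain string, so it can be written to disk and opened in a browser.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    sample_text = "The film was a delight from start to finish."  # hypothetical input
    importance, html = generate_visual(sample_text, target_class=true_class)
    print(importance)  # mapping of token -> rounded relevance score
    with open("explanation.html", "w") as f:
        f.write(html)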