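"""Gradient-rollout explainer for a RoBERTa sequence classifier.

Builds Captum-style token attribution visualizations, one row per output class
(NEGATIVE / POSITIVE), using attention rollout weighted by Grad-CAM.
"""
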
import torch
from transformers import AutoTokenizer
from captum.attr import visualization

from roberta2 import RobertaForSequenceClassification
from util import visualize_text, PyTMinMaxScalerVectorized
from ExplanationGenerator import Generator

classifications = ["NEGATIVE", "POSITIVE"]

class GradientRolloutExplainer(Generator):
    def __init__(self, model, tokenizer):
        super().__init__(model, key="roberta.encoder.layer")
        self.device = model.device
        self.tokenizer = tokenizer

    def build_visualization(self, input_ids, attention_mask, start_layer=8):
        """Generate a Captum visualization record for each output class."""
        vis_data_records = []

        # build one attribution row per class: 0 = NEGATIVE, 1 = POSITIVE
        for index in range(2):
            output, expl = self.generate_rollout_attn_gradcam(
                input_ids, attention_mask, index=index, start_layer=start_layer
            )
            # normalize relevance scores to [0, 1] for visualization
            scaler = PyTMinMaxScalerVectorized()
            norm = scaler(expl)
            # convert the model's logits to class probabilities
            output = torch.nn.functional.softmax(output, dim=-1)

            for record in range(input_ids.size(0)):
                classification = output[record].argmax(dim=-1).item()
                class_name = classifications[classification]
                nrm = norm[record]

                # index 0 explains the NEGATIVE class: negate its scores so that
                # supporting tokens render as negative (red) in the visualization
                if index == 0:
                    nrm *= -1
                # drop the leading <s> plus the trailing </s> and any padding tokens
                tokens = self.tokens_from_ids(input_ids[record].flatten())[
                    1 : 0 - ((attention_mask[record] == 0).sum().item() + 1)
                ]
                vis_data_records.append(
                    visualization.VisualizationDataRecord(
                        nrm,                             # word attributions
                        output[record][classification],  # predicted probability
                        class_name,                      # predicted class
                        class_name,                      # "true" class (ground truth unknown; reuse prediction)
                        classifications[index],          # class the attributions explain
                        1,                               # attribution score (placeholder)
                        tokens,                          # displayed tokens
                        1,                               # convergence score (placeholder)
                    )
                )
        return visualize_text(vis_data_records)

    def __call__(self, input_text, start_layer=8):
        """Tokenize a single input string and return its per-class attribution visualization."""
        text_batch = [input_text]
        encoding = self.tokenizer(text_batch, return_tensors="pt")
        input_ids = encoding["input_ids"].to(self.device)
        attention_mask = encoding["attention_mask"].to(self.device)

        return self.build_visualization(input_ids, attention_mask, start_layer=int(start_layer))
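

# Minimal usage sketch (not part of the original module). The checkpoint name below
# is an assumption, and roberta2.RobertaForSequenceClassification is assumed to expose
# the same `from_pretrained` interface as the stock transformers class.
if __name__ == "__main__":
    model_name = "textattack/roberta-base-SST-2"  # assumed SST-2 sentiment checkpoint
    model = RobertaForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.eval()

    explainer = GradientRolloutExplainer(model, tokenizer)
    # Returns an HTML visualization with one token-attribution row per class.
    html = explainer("This movie was surprisingly good!", start_layer=8)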