Martijn van Beers committed
Commit 9e7d7f8
1 Parent(s): 5b3ff3f

Add title and description

Files changed (1)
  1. app.py +65 -35
app.py CHANGED
@@ -9,9 +9,7 @@ from BERT_explainability.ExplanationGenerator import Generator
 from BERT_explainability.roberta2 import RobertaForSequenceClassification
 from transformers import AutoTokenizer
 
-from captum.attr import (
-    visualization
-)
+from captum.attr import visualization
 import torch
 
 # from https://discuss.pytorch.org/t/using-scikit-learns-scalers-for-torchvision/53455
@@ -19,11 +17,15 @@ class PyTMinMaxScalerVectorized(object):
     """
     Transforms each channel to the range [0, 1].
     """
+
     def __init__(self, dimension=-1):
         self.d = dimension
+
     def __call__(self, tensor):
         d = self.d
-        scale = 1.0 / (tensor.max(dim=d, keepdim=True)[0] - tensor.min(dim=d, keepdim=True)[0])
+        scale = 1.0 / (
+            tensor.max(dim=d, keepdim=True)[0] - tensor.min(dim=d, keepdim=True)[0]
+        )
         tensor.mul_(scale).sub_(tensor.min(dim=d, keepdim=True)[0])
         return tensor
 
@@ -33,7 +35,9 @@ if torch.cuda.is_available():
 else:
     device = torch.device("cpu")
 
-model = RobertaForSequenceClassification.from_pretrained("textattack/roberta-base-SST-2").to(device)
+model = RobertaForSequenceClassification.from_pretrained(
+    "textattack/roberta-base-SST-2"
+).to(device)
 model.eval()
 tokenizer = AutoTokenizer.from_pretrained("textattack/roberta-base-SST-2")
 # initialize the explanations generator
@@ -43,33 +47,33 @@ classifications = ["NEGATIVE", "POSITIVE"]
 
 # rule 5 from paper
 def avg_heads(cam, grad):
-    cam = (
-        (grad * cam)
-        .clamp(min=0)
-        .mean(dim=-3)
-    )
+    cam = (grad * cam).clamp(min=0).mean(dim=-3)
     # set negative values to 0, then average
-#    cam = cam.clamp(min=0).mean(dim=0)
+    # cam = cam.clamp(min=0).mean(dim=0)
     return cam
 
+
 # rule 6 from paper
 def apply_self_attention_rules(R_ss, cam_ss):
     R_ss_addition = torch.matmul(cam_ss, R_ss)
     return R_ss_addition
 
+
 def generate_relevance(model, input_ids, attention_mask, index=None, start_layer=0):
     output = model(input_ids=input_ids, attention_mask=attention_mask)[0]
     if index == None:
-        #index = np.expand_dims(np.arange(input_ids.shape[1])
+        # index = np.expand_dims(np.arange(input_ids.shape[1])
         # by default explain the class with the highest score
         index = output.argmax(axis=-1).detach().cpu().numpy()
 
     # create a one-hot vector selecting class we want explanations for
-    one_hot = (torch.nn.functional
-        .one_hot(torch.tensor(index, dtype=torch.int64), num_classes=output.size(-1))
-        .to(torch.float)
-        .requires_grad_(True)
-    ).to(device)
+    one_hot = (
+        torch.nn.functional.one_hot(
+            torch.tensor(index, dtype=torch.int64), num_classes=output.size(-1)
+        )
+        .to(torch.float)
+        .requires_grad_(True)
+    ).to(device)
     print("ONE_HOT", one_hot.size(), one_hot)
     one_hot = torch.sum(one_hot * output)
     model.zero_grad()
@@ -90,6 +94,7 @@ def generate_relevance(model, input_ids, attention_mask, index=None, start_layer
         R += joint
     return output, R[:, 0, 1:-1]
 
+
 def visualize_text(datarecords, legend=True):
     dom = ["<table width: 100%>"]
     rows = [
@@ -111,7 +116,9 @@ def visualize_text(datarecords, legend=True):
                        )
                    ),
                    visualization.format_classname(datarecord.attr_class),
-                    visualization.format_classname("{0:.2f}".format(datarecord.attr_score)),
+                    visualization.format_classname(
+                        "{0:.2f}".format(datarecord.attr_score)
+                    ),
                    visualization.format_word_importances(
                        datarecord.raw_input_ids, datarecord.word_attributions
                    ),
@@ -143,9 +150,12 @@ def visualize_text(datarecords, legend=True):
 
     return html
 
+
 def show_explanation(model, input_ids, attention_mask, index=None, start_layer=0):
     # generate an explanation for the input
-    output, expl = generate_relevance(model, input_ids, attention_mask, index=index, start_layer=start_layer)
+    output, expl = generate_relevance(
+        model, input_ids, attention_mask, index=index, start_layer=start_layer
+    )
     print(output.shape, expl.shape)
     # normalize scores
     scaler = PyTMinMaxScalerVectorized()
@@ -154,7 +164,6 @@ def show_explanation(model, input_ids, attention_mask, index=None, start_layer=0
     # get the model classification
     output = torch.nn.functional.softmax(output, dim=-1)
 
-
     vis_data_records = []
     for record in range(input_ids.size(0)):
         classification = output[record].argmax(dim=-1).item()
@@ -164,25 +173,31 @@ def show_explanation(model, input_ids, attention_mask, index=None, start_layer=0
         # if the classification is negative, higher explanation scores are more negative
         # flip for visualization
         if class_name == "NEGATIVE":
-            nrm *= (-1)
-        tokens = tokenizer.convert_ids_to_tokens(input_ids[record].flatten())[1:0 - ((attention_mask[record] == 0).sum().item() + 1)]
+            nrm *= -1
+        tokens = tokenizer.convert_ids_to_tokens(input_ids[record].flatten())[
+            1 : 0 - ((attention_mask[record] == 0).sum().item() + 1)
+        ]
         print([(tokens[i], nrm[i].item()) for i in range(len(tokens))])
-        vis_data_records.append(visualization.VisualizationDataRecord(
-            nrm,
-            output[record][classification],
-            classification,
-            classification,
-            index,
-            1,
-            tokens,
-            1))
+        vis_data_records.append(
+            visualization.VisualizationDataRecord(
+                nrm,
+                output[record][classification],
+                classification,
+                classification,
+                index,
+                1,
+                tokens,
+                1,
+            )
+        )
     return visualize_text(vis_data_records)
 
+
 def run(input_text):
     text_batch = [input_text]
-    encoding = tokenizer(text_batch, return_tensors='pt')
-    input_ids = encoding['input_ids'].to(device)
-    attention_mask = encoding['attention_mask'].to(device)
+    encoding = tokenizer(text_batch, return_tensors="pt")
+    input_ids = encoding["input_ids"].to(device)
+    attention_mask = encoding["attention_mask"].to(device)
 
     # true class is positive - 1
     true_class = 1
@@ -190,5 +205,20 @@ def run(input_text):
     html = show_explanation(model, input_ids, attention_mask)
     return html
 
-iface = gradio.Interface(fn=run, inputs="text", outputs="html", examples=[["This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great"], ["I really didn't like this movie. Some of the actors were good, but overall the movie was boring"]])
+
+iface = gradio.Interface(
+    fn=run,
+    inputs="text",
+    outputs="html",
+    title="RoBERTa Explainability",
+    description="Quick demo of a version of [Hila Chefer's](https://github.com/hila-chefer) [Transformer-Explainability](https://github.com/hila-chefer/Transformer-Explainability/) but without the layerwise relevance propagation (as in [Transformer-MM-Explainability](https://github.com/hila-chefer/Transformer-MM-Explainability/)) for a RoBERTa model.",
+    examples=[
+        [
+            "This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great"
+        ],
+        [
+            "I really didn't like this movie. Some of the actors were good, but overall the movie was boring"
+        ],
+    ],
+)
 iface.launch()
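
For context on the two helpers touched by the reformatting above, here is a minimal standalone sketch (not part of this commit) of how avg_heads ("rule 5") and apply_self_attention_rules ("rule 6") compose into the relevance rollout that generate_relevance performs. The shapes, layer count, and random tensors are illustrative assumptions; in app.py the cam/grad pairs come from the model's attention maps and their gradients.

import torch

# Assumed shapes: attention maps and their gradients are (batch, heads, seq, seq).

def avg_heads(cam, grad):
    # rule 5: weight attention by its gradient, clip negatives, average over heads
    return (grad * cam).clamp(min=0).mean(dim=-3)

def apply_self_attention_rules(R_ss, cam_ss):
    # rule 6: propagate the accumulated relevance through one self-attention layer
    return torch.matmul(cam_ss, R_ss)

batch, heads, seq, num_layers = 1, 12, 8, 12
R = torch.eye(seq).expand(batch, seq, seq)        # start from identity relevance
for _ in range(num_layers):
    cam = torch.rand(batch, heads, seq, seq)      # stand-in for a layer's attention
    grad = torch.rand(batch, heads, seq, seq)     # stand-in for its gradient
    R = R + apply_self_attention_rules(R, avg_heads(cam, grad))

# Relevance of the ordinary tokens with respect to the <s> (CLS) position,
# mirroring the R[:, 0, 1:-1] that generate_relevance returns for visualization.
print(R[:, 0, 1:-1])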