Martijn van Beers committed
Commit 9f74b46
1 Parent(s): 8821877

Initial implementation

Files changed (5)
  1. app.py +154 -0
  2. description.md +4 -0
  3. examples.csv +3 -0
  4. notice.md +2 -0
  5. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,154 @@
+ import re
+ import pandas
+ import seaborn
+ import gradio
+ import pathlib
+
+ import torch
+ import matplotlib
+ import matplotlib.pyplot as plt
+ import numpy
+ from sklearn.metrics.pairwise import cosine_distances
+
+ from transformers import (
+     AutoConfig,
+     AutoTokenizer,
+     AutoModelForSequenceClassification, AutoModelForMaskedLM
+ )
+
+ ## Rollout Helper Function
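+ # Attention rollout (Abnar & Zuidema, 2020): per-layer score matrices are combined by
+ # recursive matrix multiplication, so layer i's map reflects the mixing done by all
+ # layers up to i. With res=True an identity matrix is added first to account for the
+ # residual connections, and each row is renormalised to sum to one.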
+ def compute_joint_attention(att_mat, res=True):
+     if res:
+         residual_att = numpy.eye(att_mat.shape[1])[None, ...]
+         att_mat = att_mat + residual_att
+         att_mat = att_mat / att_mat.sum(axis=-1)[..., None]
+
+     joint_attentions = numpy.zeros(att_mat.shape)
+     layers = joint_attentions.shape[0]
+     joint_attentions[0] = att_mat[0]
+     for i in numpy.arange(1, layers):
+         joint_attentions[i] = att_mat[i].dot(joint_attentions[i-1])
+
+     return joint_attentions
+
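+ # Plot a 6x2 grid with one heatmap per layer; cell (i, j) is, roughly, how much the
+ # representation of token i in that layer relies on token j.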
+ def create_plot(all_tokens, score_data):
+     LAYERS = list(range(12))
+     fig, axs = plt.subplots(6, 2, figsize=(8, 24))
+     plt.subplots_adjust(top=0.98, bottom=0.05, hspace=0.5, wspace=0.5)
+     for layer in LAYERS:
+         a = layer // 2
+         b = layer % 2
+         seaborn.heatmap(
+             ax=axs[a, b],
+             data=pandas.DataFrame(score_data[layer], index=all_tokens, columns=all_tokens),
+             cmap="Blues",
+             annot=False,
+             cbar=False
+         )
+         axs[a, b].set_title(f"Layer: {layer+1}")
+     return fig
+
+ matplotlib.use('agg')
+
+ DISTANCE_FUNC = {
+     'cosine': cosine_distances
+ }
+ MODEL_PATH = {
+     'bert': 'bert-base-uncased',
+     'roberta': 'roberta-base',
+ }
+
+ MODEL_NAME = 'bert'
+ #MODEL_NAME = 'roberta'
+ METRIC = 'cosine'
+
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ config = AutoConfig.from_pretrained(MODEL_PATH[MODEL_NAME])
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH[MODEL_NAME])
+ model = AutoModelForMaskedLM.from_pretrained(MODEL_PATH[MODEL_NAME], config=config).to(device)
+
+
+ def run(mname, sent):
+     global MODEL_NAME, config, model, tokenizer
+     if mname != MODEL_NAME:
+         MODEL_NAME = mname
+         config = AutoConfig.from_pretrained(MODEL_PATH[MODEL_NAME])
+         tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH[MODEL_NAME])
+         model = AutoModelForMaskedLM.from_pretrained(MODEL_PATH[MODEL_NAME], config=config).to(device)
+     # replace the literal "[MASK]" placeholder with the model's own mask token
+     sent = re.sub(r"\[MASK\]", tokenizer.mask_token, sent)
+     inputs = tokenizer(sent, return_token_type_ids=True, return_tensors="pt")
+
+     ## Compute: layerwise value zeroing
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+     with torch.no_grad():
+         outputs = model(inputs['input_ids'],
+                         attention_mask=inputs['attention_mask'],
+                         token_type_ids=inputs['token_type_ids'],
+                         output_hidden_states=True, output_attentions=False)
+
+     org_hidden_states = torch.stack(outputs['hidden_states']).squeeze(1)
+     input_shape = inputs['input_ids'].size()
+     batch_size, seq_length = input_shape
+
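+     # Value Zeroing: for every layer l and token t, re-run layer l from the original
+     # layer-l hidden states with token t's value vector zeroed out (the
+     # `zero_value_index` argument provided by the forked transformers in
+     # requirements.txt), then score the effect on each position as the cosine
+     # distance to the original layer l+1 output. Column t of the score matrix thus
+     # holds token t's contribution to every position at that layer.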
+     score_matrix = numpy.zeros((config.num_hidden_layers, seq_length, seq_length))
+     for l, layer_module in enumerate(getattr(model, MODEL_NAME).encoder.layer):
+         for t in range(seq_length):
+             extended_blanking_attention_mask: torch.Tensor = getattr(model, MODEL_NAME).get_extended_attention_mask(inputs['attention_mask'], input_shape, device)
+             with torch.no_grad():
+                 layer_outputs = layer_module(org_hidden_states[l].unsqueeze(0),  # previous layer's original output
+                                              attention_mask=extended_blanking_attention_mask,
+                                              output_attentions=False,
+                                              zero_value_index=t,
+                                              )
+             hidden_states = layer_outputs[0].squeeze().detach().cpu().numpy()
+             # compute similarity between original and new outputs
+             # cosine
+             x = hidden_states
+             y = org_hidden_states[l+1].detach().cpu().numpy()
+
+             distances = DISTANCE_FUNC[METRIC](x, y).diagonal()
+             score_matrix[l, :, t] = distances
+
+     valuezeroing_scores = score_matrix / numpy.sum(score_matrix, axis=-1, keepdims=True)
+     rollout_valuezeroing_scores = compute_joint_attention(valuezeroing_scores, res=False)
+
+
+     # Plot:
+     all_tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
+     rollout_fig = create_plot(all_tokens, rollout_valuezeroing_scores)
+     value_fig = create_plot(all_tokens, valuezeroing_scores)
+
+     return rollout_fig, value_fig
+
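+ # Gradio UI: an input sentence box and a model dropdown on top, example sentences
+ # below them, and two plot columns underneath (rollout-aggregated scores on the left,
+ # raw value-zeroing scores on the right); the Submit button is wired to run().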
+ examples = pandas.read_csv("examples.csv").to_numpy().tolist()
+
+ with gradio.Blocks(
+         title="Differences with/without value zeroing",
+         css=".output-image > img {height: 2000px !important; max-height: none !important;} "
+ ) as iface:
+     gradio.Markdown(pathlib.Path("description.md").read_text())
+     with gradio.Row(equal_height=True):
+         with gradio.Column(scale=4):
+             sent = gradio.Textbox(label="Input sentence")
+         with gradio.Column(scale=1):
+             model_choice = gradio.Dropdown(choices=['bert', 'roberta'], value="bert")
+             but = gradio.Button("Submit")
+     gradio.Examples(examples, [sent])
+     with gradio.Row(equal_height=True):
+         with gradio.Column():
+             gradio.Markdown("### With Rollout")
+             rollout_result = gradio.Plot()
+         with gradio.Column():
+             gradio.Markdown("### Without Rollout")
+             value_result = gradio.Plot()
+     with gradio.Accordion("Some more details"):
+         gradio.Markdown(pathlib.Path("notice.md").read_text())
+
+     but.click(run,
+               inputs=[model_choice, sent],
+               outputs=[rollout_result, value_result]
+               )
+
+
+ iface.launch()
description.md ADDED
@@ -0,0 +1,4 @@
+ # Value Zeroing
+
+ A demo of the effect of Value Zeroing (Mohebbi et al., 2023), shown both with
+ Attention Rollout (Abnar & Zuidema, 2020) applied and without it.
examples.csv ADDED
@@ -0,0 +1,3 @@
+ sentence
+ "You either win the game or you [MASK] the game."
+ "The author talked to Sarah about [MASK] book."
notice.md ADDED
@@ -0,0 +1,2 @@
+ * Shown on the left are the results after applying attention rollout, as defined by Abnar & Zuidema (2020).
+ * Shown on the right are the results before rollout is applied.
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ git+https://github.com/martijnvanbeers/transformers@feature/transformer-explainability
+ pandas
+ seaborn
+ matplotlib
+ numpy
+ scikit-learn