Gabriela Nicole Gonzalez Saez committed on
Commit
056bbdc
1 Parent(s): fc37a00
Files changed (4) hide show
  1. app.py +114 -0
  2. bertviz_gradio.py +248 -0
  3. plotsjs_bertviz.js +430 -0
  4. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import inseq
3
+ import captum
4
+
5
+ import torch
6
+ import os
7
+ # import nltk
8
+ import argparse
9
+ import random
10
+ import numpy as np
11
+
12
+ from argparse import Namespace
13
+ from tqdm.notebook import tqdm
14
+ from torch.utils.data import DataLoader
15
+ from functools import partial
16
+
17
+ from transformers import AutoTokenizer, MarianTokenizer, AutoModel, AutoModelForSeq2SeqLM, MarianMTModel
18
+
19
+ from bertviz import model_view, head_view
20
+ from bertviz_gradio import head_view_mod
21
+
22
+
23
def get_bertvis_data(input_text, lg_model):
    """Translate ``input_text`` with the model registered for ``lg_model``
    and build BertViz head-view data from the resulting attentions.

    Returns a tuple ``(html_attentions, tgt_text)`` where ``html_attentions``
    is the dict produced by ``head_view_mod(..., html_action='gradio')`` and
    ``tgt_text`` is the decoded translation string.

    NOTE(review): relies on module-level lookup tables ``dict_tokenizer_tr``
    and ``dict_models_tr`` keyed by language pair — they are not defined in
    this file as committed; confirm they exist before this is called.
    """
    tokenizer = dict_tokenizer_tr[lg_model]
    translator = dict_models_tr[lg_model]

    encoded = tokenizer(input_text, return_tensors="pt", padding=True)
    generated = translator.generate(
        **encoded,
        return_dict_in_generate=True,
        output_attentions=True,
        output_scores=True,
    )

    tgt_text = tokenizer.decode(generated.sequences[0], skip_special_tokens=True)
    print(tgt_text)

    # Re-run a forward pass over source + generated target so the model
    # returns encoder, decoder and cross attention tensors in one call.
    outputs = translator(
        input_ids=encoded.input_ids,
        decoder_input_ids=generated.sequences,
        output_attentions=True,
    )
    print(tokenizer.convert_ids_to_tokens(generated.sequences[0]))

    html_attentions = head_view_mod(
        encoder_attention=outputs.encoder_attentions,
        cross_attention=outputs.cross_attentions,
        decoder_attention=outputs.decoder_attentions,
        encoder_tokens=tokenizer.convert_ids_to_tokens(encoded.input_ids[0]),
        decoder_tokens=tokenizer.convert_ids_to_tokens(generated.sequences[0]),
        html_action='gradio'
    )
    return html_attentions, tgt_text
55
+
56
+
57
+
58
## First create html and divs
# Static scaffold injected into the Gradio page: BertViz renders into
# #bertviz and the beam-search plot into #d3_beam_search.
# Fix: the jQuery script src was missing its ".js" extension, so the CDN
# request returned 404 and jQuery never loaded.
html = """
<html>
<script async src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/2.0.0/jquery.min.js"></script>
<script async data-require="d3@3.5.3" data-semver="3.5.3" src="//cdnjs.cloudflare.com/ajax/libs/d3/3.5.3/d3.js"></script>

<body>
    <div id="bertviz"></div>
    <div id="d3_beam_search"></div>
</body>
</html>
"""
71
+
72
def sentence_maker(w1, model, var2=None):
    """Gradio callback: translate source text *w1* with language pair *model*.

    Returns ``[target_text, bertviz_params, bertviz_html]`` matching the
    three output components wired in the Blocks UI (``out_text``, ``var2``,
    ``out_html``).

    *var2* is accepted only for signature compatibility with the event
    wiring and is unused; its former mutable default (``{}``) is replaced
    with ``None`` to avoid the shared-mutable-default pitfall.
    """
    # Translate and collect attention data for the client-side visualization.
    params, tgt = get_bertvis_data(w1, model)
    return [tgt, params['params'], params['html2'].data]
78
+
79
def sentence_maker2(w1, j2):
    """Debug callback fired when the target textbox changes.

    Logs both inputs server-side and returns a fixed placeholder string;
    the real rendering happens client-side in ``testFn_out_json``.
    """
    print(w1, j2)
    return "in sentence22..."
84
+
85
+
86
# UI for the MAKE NMT workshop: type a source sentence, pick a language
# pair, and inspect the translation's attention in an embedded BertViz view.
# The js= hook loads plotsjs_bertviz.js, which defines the client-side
# functions referenced by the event listeners below.
with gr.Blocks(js="plotsjs_bertviz.js") as demo:
    gr.Markdown("""
    # MAKE NMT Workshop \t `BertViz` \n
    https://github.com/jessevig/bertviz
    """)
    with gr.Row():
        with gr.Column(scale=1):
            in_text = gr.Textbox(label="Source Text")
            out_text = gr.Textbox(label="Target Text")
            out_text2 = gr.Textbox(visible=False)  # hidden sink for the debug callback's return value
            var2 = gr.JSON(visible=False)  # hidden carrier for the BertViz params passed to the JS side
            btn = gr.Button("Create sentence.")
            radio_c = gr.Radio(choices=['en-zh', 'en-es', 'en-fr'], value="en-zh", label='', container=False)

        with gr.Column(scale=4):
            gr.Markdown("Attentions: ")
            input_mic = gr.HTML(html)  # static scaffold containing the #bertviz target div
            out_html = gr.HTML()
    # Server-side translation first; once out_text updates, the change event
    # forwards the params JSON to the BertViz renderer (testFn_out_json).
    btn.click(sentence_maker, [in_text, radio_c], [out_text, var2, out_html], js="(in_text,radio_c) => testFn_out(in_text,radio_c)")  # should return an output comp.
    out_text.change(sentence_maker2, [out_text, var2], out_text2, js="(out_text,var2) => testFn_out_json(var2)")  #
    # out_text.change(sentence_maker2, [out_text, var2], out_text2, js="(out_text,var2) => testFn_out_json(var2)") #
108
+
109
+
110
# run script function on load,
# demo.load(None,None,None,js="plotsjs.js")

if __name__ == "__main__":
    # Launch the Gradio server when executed as a script.
    demo.launch()
bertviz_gradio.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import json
3
+ import os
4
+ import uuid
5
+
6
+ from IPython.core.display import display, HTML, Javascript
7
+
8
+ from bertviz.util import format_special_chars, format_attention, num_layers
9
+
10
+
11
def head_view_mod(
        attention=None,
        tokens=None,
        sentence_b_start=None,
        prettify_tokens=True,
        layer=None,
        heads=None,
        encoder_attention=None,
        decoder_attention=None,
        cross_attention=None,
        encoder_tokens=None,
        decoder_tokens=None,
        include_layers=None,
        html_action='view'
):
    """Render head view

    Args:
        For self-attention models:
            attention: list of ``torch.FloatTensor``(one for each layer) of shape
                ``(batch_size(must be 1), num_heads, sequence_length, sequence_length)``
            tokens: list of tokens
            sentence_b_start: index of first wordpiece in sentence B if input text is sentence pair (optional)
        For encoder-decoder models:
            encoder_attention: list of ``torch.FloatTensor``(one for each layer) of shape
                ``(batch_size(must be 1), num_heads, encoder_sequence_length, encoder_sequence_length)``
            decoder_attention: list of ``torch.FloatTensor``(one for each layer) of shape
                ``(batch_size(must be 1), num_heads, decoder_sequence_length, decoder_sequence_length)``
            cross_attention: list of ``torch.FloatTensor``(one for each layer) of shape
                ``(batch_size(must be 1), num_heads, decoder_sequence_length, encoder_sequence_length)``
            encoder_tokens: list of tokens for encoder input
            decoder_tokens: list of tokens for decoder input
        For all models:
            prettify_tokens: indicates whether to remove special characters in wordpieces, e.g. Ġ
            layer: index (zero-based) of initial selected layer in visualization. Defaults to layer 0.
            heads: Indices (zero-based) of initial selected heads in visualization. Defaults to all heads.
            include_layers: Indices (zero-based) of layers to include in visualization. Defaults to all layers.
                Note: filtering layers may improve responsiveness of the visualization for long inputs.
            html_action: Specifies the action to be performed with the generated HTML object
                - 'view' (default): Displays the generated HTML representation as a notebook cell output
                - 'return' : Returns an HTML object containing the generated view for further processing or custom visualization
                - 'gradio' : Returns a dict with keys 'html1' (require.js script tag), 'html2' (container
                  markup) and 'params' (the JS parameters), for embedding the view in a Gradio app
    """

    # Collect one entry per attention "view" (Encoder/Decoder/Cross, or the
    # sentence-pair slices for self-attention models).
    attn_data = []
    if attention is not None:
        if tokens is None:
            raise ValueError("'tokens' is required")
        if encoder_attention is not None or decoder_attention is not None or cross_attention is not None \
                or encoder_tokens is not None or decoder_tokens is not None:
            raise ValueError("If you specify 'attention' you may not specify any encoder-decoder arguments. This"
                             " argument is only for self-attention models.")
        if include_layers is None:
            include_layers = list(range(num_layers(attention)))
        attention = format_attention(attention, include_layers)
        if sentence_b_start is None:
            attn_data.append(
                {
                    'name': None,
                    'attn': attention.tolist(),
                    'left_text': tokens,
                    'right_text': tokens
                }
            )
        else:
            slice_a = slice(0, sentence_b_start)  # Positions corresponding to sentence A in input
            slice_b = slice(sentence_b_start, len(tokens))  # Position corresponding to sentence B in input
            attn_data.append(
                {
                    'name': 'All',
                    'attn': attention.tolist(),
                    'left_text': tokens,
                    'right_text': tokens
                }
            )
            attn_data.append(
                {
                    'name': 'Sentence A -> Sentence A',
                    'attn': attention[:, :, slice_a, slice_a].tolist(),
                    'left_text': tokens[slice_a],
                    'right_text': tokens[slice_a]
                }
            )
            attn_data.append(
                {
                    'name': 'Sentence B -> Sentence B',
                    'attn': attention[:, :, slice_b, slice_b].tolist(),
                    'left_text': tokens[slice_b],
                    'right_text': tokens[slice_b]
                }
            )
            attn_data.append(
                {
                    'name': 'Sentence A -> Sentence B',
                    'attn': attention[:, :, slice_a, slice_b].tolist(),
                    'left_text': tokens[slice_a],
                    'right_text': tokens[slice_b]
                }
            )
            attn_data.append(
                {
                    'name': 'Sentence B -> Sentence A',
                    'attn': attention[:, :, slice_b, slice_a].tolist(),
                    'left_text': tokens[slice_b],
                    'right_text': tokens[slice_a]
                }
            )
    elif encoder_attention is not None or decoder_attention is not None or cross_attention is not None:
        if encoder_attention is not None:
            if encoder_tokens is None:
                raise ValueError("'encoder_tokens' required if 'encoder_attention' is not None")
            if include_layers is None:
                include_layers = list(range(num_layers(encoder_attention)))
            encoder_attention = format_attention(encoder_attention, include_layers)
            attn_data.append(
                {
                    'name': 'Encoder',
                    'attn': encoder_attention.tolist(),
                    'left_text': encoder_tokens,
                    'right_text': encoder_tokens
                }
            )
        if decoder_attention is not None:
            if decoder_tokens is None:
                raise ValueError("'decoder_tokens' required if 'decoder_attention' is not None")
            if include_layers is None:
                include_layers = list(range(num_layers(decoder_attention)))
            decoder_attention = format_attention(decoder_attention, include_layers)
            attn_data.append(
                {
                    'name': 'Decoder',
                    'attn': decoder_attention.tolist(),
                    'left_text': decoder_tokens,
                    'right_text': decoder_tokens
                }
            )
        if cross_attention is not None:
            if encoder_tokens is None:
                raise ValueError("'encoder_tokens' required if 'cross_attention' is not None")
            if decoder_tokens is None:
                raise ValueError("'decoder_tokens' required if 'cross_attention' is not None")
            if include_layers is None:
                include_layers = list(range(num_layers(cross_attention)))
            cross_attention = format_attention(cross_attention, include_layers)
            attn_data.append(
                {
                    'name': 'Cross',
                    'attn': cross_attention.tolist(),
                    'left_text': decoder_tokens,
                    'right_text': encoder_tokens
                }
            )
    else:
        raise ValueError("You must specify at least one attention argument.")

    if layer is not None and layer not in include_layers:
        raise ValueError(f"Layer {layer} is not in include_layers: {include_layers}")

    # Generate unique div id to enable multiple visualizations in one notebook
    # vis_id = 'bertviz-%s'%(uuid.uuid4().hex)
    # NOTE: a fixed id is used here so the Gradio-side JS can locate the div.
    vis_id = 'bertviz'  # -%s'%(uuid.uuid4().hex)

    # Compose html
    if len(attn_data) > 1:
        options = '\n'.join(
            f'<option value="{i}">{d["name"]}</option>'
            for i, d in enumerate(attn_data)
        )
        select_html = f'Attention: <select id="filter">{options}</select>'
    else:
        select_html = ""
    vis_html = f"""
        <div id="{vis_id}" style="font-family:'Helvetica Neue', Helvetica, Arial, sans-serif;">
            <span style="user-select:none">
                Layer: <select id="layer"></select>
                {select_html}
            </span>
            <div id='vis'></div>
        </div>
    """

    # Sanity-check that each attention matrix matches its token lists, then
    # optionally prettify wordpiece markers for display.
    for d in attn_data:
        attn_seq_len_left = len(d['attn'][0][0])
        if attn_seq_len_left != len(d['left_text']):
            raise ValueError(
                f"Attention has {attn_seq_len_left} positions, while number of tokens is {len(d['left_text'])} "
                f"for tokens: {' '.join(d['left_text'])}"
            )
        attn_seq_len_right = len(d['attn'][0][0][0])
        if attn_seq_len_right != len(d['right_text']):
            raise ValueError(
                f"Attention has {attn_seq_len_right} positions, while number of tokens is {len(d['right_text'])} "
                f"for tokens: {' '.join(d['right_text'])}"
            )
        if prettify_tokens:
            d['left_text'] = format_special_chars(d['left_text'])
            d['right_text'] = format_special_chars(d['right_text'])
    # Parameters consumed by the head_view JS renderer.
    params = {
        'attention': attn_data,
        'default_filter': "0",
        'root_div_id': vis_id,
        'layer': layer,
        'heads': heads,
        'include_layers': include_layers
    }

    # require.js must be imported for Colab or JupyterLab:

    if html_action == 'gradio':
        # Hand the pieces back to the caller; the Gradio front-end injects
        # them and runs the renderer from plotsjs_bertviz.js.
        html1 = HTML('<script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js"></script>')
        html2 = HTML(vis_html)

        return {'html1': html1, 'html2': html2, 'params': params}

    if html_action == 'view':
        display(HTML('<script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js"></script>'))
        display(HTML(vis_html))
        __location__ = os.path.realpath(
            os.path.join(os.getcwd(), os.path.dirname(__file__)))
        # Use a context manager so the file handle is not leaked.
        with open(os.path.join(__location__, 'head_view.js')) as f:
            vis_js = f.read().replace("PYTHON_PARAMS", json.dumps(params))
        display(Javascript(vis_js))

    elif html_action == 'return':
        html1 = HTML('<script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js"></script>')

        html2 = HTML(vis_html)

        __location__ = os.path.realpath(
            os.path.join(os.getcwd(), os.path.dirname(__file__)))
        with open(os.path.join(__location__, 'head_view.js')) as f:
            vis_js = f.read().replace("PYTHON_PARAMS", json.dumps(params))
        html3 = Javascript(vis_js)
        script = '\n<script type="text/javascript">\n' + html3.data + '\n</script>\n'

        head_html = HTML(html1.data + html2.data + script)
        return head_html

    else:
        # Fixed: the original message had an unterminated quote and omitted
        # the supported 'gradio' action.
        raise ValueError("'html_action' parameter must be 'gradio', 'view' or 'return'")
plotsjs_bertviz.js ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
async () => {
    // Gradio `js=` entry point: runs once on page load. Sets up global
    // helper functions (called from the Python-side event listeners) and
    // defines the BertViz head-view renderer (attViz).
    // set testFn() function on globalThis, so your html onclick can access it


    globalThis.testFn = () => {
        document.getElementById('demo').innerHTML = "Hello-bertviz?"
    };

    // await import * as mod from "/my-module.js";

    // Dynamically import d3 v5 (attViz below uses the v5 selection API)
    // and jQuery, and expose jQuery globally for the renderer.
    const d3 = await import("https://cdn.jsdelivr.net/npm/d3@5/+esm");
    const $ = await import("https://cdn.jsdelivr.net/npm/jquery@3.7.1/dist/jquery.min.js");

    globalThis.$ = $;

    // const $ = await import("https://cdn.jsdelivr.net/npm/jquery@2/+esm");
    // import $ from "jquery";
    // import * as d3 from "https://cdn.jsdelivr.net/npm/d3@7/+esm";
    // await import("https://cdn.jsdelivr.net/npm/jquery@2/+esm");

    // export for others scripts to use
    // window.$ = window.jQuery = jQuery;

    // const d3 = await import("https://cdnjs.cloudflare.com/ajax/libs/d3/5.7.0/d3.min");
    // const $ = await import("https://cdnjs.cloudflare.com/ajax/libs/jquery/2.0.0/jquery.min");

    // Small demo helper: draws a hover-reactive square into #viz.
    globalThis.d3Fn = () => {
        d3.select('#viz').append('svg')
            .append('rect')
            .attr('width', 50)
            .attr('height', 50)
            .attr('fill', 'black')
            .on('mouseover', function(){d3.select(this).attr('fill', 'red')})
            .on('mouseout', function(){d3.select(this).attr('fill', 'black')});

    };

    //

    // Pass-through hook for the btn.click listener: returns its inputs
    // unchanged so the server-side callback receives them.
    globalThis.testFn_out = (val,model) => {
        // document.getElementById('demo').innerHTML = val
        console.log(val);
        // globalThis.d3Fn();
        return([val,model]);
    };

    // Hook for the out_text.change listener: receives the BertViz params
    // JSON produced server-side and triggers the rendering.
    globalThis.testFn_out_json = (data) => {
        console.log(data);
        var $ = jQuery;
        console.log($('#viz'));

        attViz(data);

        return(['string', {}])

    };

    // BertViz head-view renderer, adapted from bertviz's head_view.js.
    // Instead of the PYTHON_PARAMS template substitution used in notebooks,
    // the params object is passed in directly as an argument.
    function attViz(PYTHON_PARAMS) {
        var $ = jQuery;
        const params = PYTHON_PARAMS; // HACK: PYTHON_PARAMS is a template marker that is replaced by actual params.
        const TEXT_SIZE = 15;
        const BOXWIDTH = 110;
        const BOXHEIGHT = 22.5;
        const MATRIX_WIDTH = 115;
        const CHECKBOX_SIZE = 20;
        const TEXT_TOP = 30;

        console.log("d3 version in ffuntions", d3.version)
        let headColors;
        try {
            headColors = d3.scaleOrdinal(d3.schemeCategory10);
        } catch (err) {
            console.log('Older d3 version')
            headColors = d3.scale.category10();
        }
        let config = {};
        // globalThis.
        initialize();
        renderVis();

        // Populate `config` from params and wire up the layer / filter
        // dropdown listeners.
        function initialize() {
        // globalThis.initialize = () => {

            console.log("init")
            config.attention = params['attention'];
            config.filter = params['default_filter'];
            config.rootDivId = params['root_div_id'];
            config.nLayers = config.attention[config.filter]['attn'].length;
            config.nHeads = config.attention[config.filter]['attn'][0].length;
            config.layers = params['include_layers']

            if (params['heads']) {
                config.headVis = new Array(config.nHeads).fill(false);
                params['heads'].forEach(x => config.headVis[x] = true);
            } else {
                config.headVis = new Array(config.nHeads).fill(true);
            }
            config.initialTextLength = config.attention[config.filter].right_text.length;
            config.layer_seq = (params['layer'] == null ? 0 : config.layers.findIndex(layer => params['layer'] === layer));
            config.layer = config.layers[config.layer_seq]

            // '#' + temp1.root_div_id+ ' #layer'
            // Clear stale options before repopulating (the view may be
            // re-rendered several times in the same page).
            $('#' + config.rootDivId+ ' #layer').empty();

            let layerEl = $('#' + config.rootDivId+ ' #layer');
            console.log(layerEl)
            for (const layer of config.layers) {
                layerEl.append($("<option />").val(layer).text(layer));
            }
            layerEl.val(config.layer).change();
            layerEl.on('change', function (e) {
                config.layer = +e.currentTarget.value;
                config.layer_seq = config.layers.findIndex(layer => config.layer === layer);
                renderVis();
            });

            $('#'+config.rootDivId+' #filter').on('change', function (e) {
            // $(`#${config.rootDivId} #filter`).on('change', function (e) {

                config.filter = e.currentTarget.value;
                renderVis();
            });
        }

        // Redraw the whole visualization for the current filter/layer.
        function renderVis() {

            // Load parameters
            const attnData = config.attention[config.filter];
            const leftText = attnData.left_text;
            const rightText = attnData.right_text;

            // Select attention for given layer
            const layerAttention = attnData.attn[config.layer_seq];

            // Clear vis
            $('#'+config.rootDivId+' #vis').empty();

            // Determine size of visualization
            const height = Math.max(leftText.length, rightText.length) * BOXHEIGHT + TEXT_TOP;
            const svg = d3.select('#'+ config.rootDivId +' #vis')
                .append('svg')
                .attr("width", "100%")
                .attr("height", height + "px");

            // Display tokens on left and right side of visualization
            renderText(svg, leftText, true, layerAttention, 0);
            renderText(svg, rightText, false, layerAttention, MATRIX_WIDTH + BOXWIDTH);

            // Render attention arcs
            renderAttention(svg, layerAttention);

            // Draw squares at top of visualization, one for each head
            // NOTE(review): drawCheckboxes is declared with two parameters;
            // the third argument here is ignored.
            drawCheckboxes(0, svg, layerAttention);
        }

        // Draw one column of tokens plus the per-head highlight rects and
        // hover behavior. isLeft selects the source (left) vs target (right)
        // column; leftPos is the x offset of the column.
        function renderText(svg, text, isLeft, attention, leftPos) {

            const textContainer = svg.append("svg:g")
                .attr("id", isLeft ? "left" : "right");

            // Add attention highlights superimposed over words
            textContainer.append("g")
                .classed("attentionBoxes", true)
                .selectAll("g")
                .data(attention)
                .enter()
                .append("g")
                .attr("head-index", (d, i) => i)
                .selectAll("rect")
                .data(d => isLeft ? d : transpose(d)) // if right text, transpose attention to get right-to-left weights
                .enter()
                .append("rect")
                .attr("x", function () {
                    var headIndex = +this.parentNode.getAttribute("head-index");
                    return leftPos + boxOffsets(headIndex);
                })
                .attr("y", (+1) * BOXHEIGHT)
                .attr("width", BOXWIDTH / activeHeads())
                .attr("height", BOXHEIGHT)
                .attr("fill", function () {
                    return headColors(+this.parentNode.getAttribute("head-index"))
                })
                .style("opacity", 0.0);

            const tokenContainer = textContainer.append("g").selectAll("g")
                .data(text)
                .enter()
                .append("g");

            // Add gray background that appears when hovering over text
            tokenContainer.append("rect")
                .classed("background", true)
                .style("opacity", 0.0)
                .attr("fill", "lightgray")
                .attr("x", leftPos)
                .attr("y", (d, i) => TEXT_TOP + i * BOXHEIGHT)
                .attr("width", BOXWIDTH)
                .attr("height", BOXHEIGHT);

            // Add token text
            const textEl = tokenContainer.append("text")
                .text(d => d)
                .attr("font-size", TEXT_SIZE + "px")
                .style("cursor", "default")
                .style("-webkit-user-select", "none")
                .attr("x", leftPos)
                .attr("y", (d, i) => TEXT_TOP + i * BOXHEIGHT);

            if (isLeft) {
                textEl.style("text-anchor", "end")
                    .attr("dx", BOXWIDTH - 0.5 * TEXT_SIZE)
                    .attr("dy", TEXT_SIZE);
            } else {
                textEl.style("text-anchor", "start")
                    .attr("dx", +0.5 * TEXT_SIZE)
                    .attr("dy", TEXT_SIZE);
            }

            // d3 v5 handler signature: (datum, index).
            tokenContainer.on("mouseover", function (d, index) {

                // Show gray background for moused-over token
                textContainer.selectAll(".background")
                    .style("opacity", (d, i) => i === index ? 1.0 : 0.0)

                // Reset visibility attribute for any previously highlighted attention arcs
                svg.select("#attention")
                    .selectAll("line[visibility='visible']")
                    .attr("visibility", null)

                // Hide group containing attention arcs
                svg.select("#attention").attr("visibility", "hidden");

                // Set to visible appropriate attention arcs to be highlighted
                if (isLeft) {
                    svg.select("#attention").selectAll("line[left-token-index='" + index + "']").attr("visibility", "visible");
                } else {
                    svg.select("#attention").selectAll("line[right-token-index='" + index + "']").attr("visibility", "visible");
                }

                // Update color boxes superimposed over tokens
                const id = isLeft ? "right" : "left";
                const leftPos = isLeft ? MATRIX_WIDTH + BOXWIDTH : 0;
                svg.select("#" + id)
                    .selectAll(".attentionBoxes")
                    .selectAll("g")
                    .attr("head-index", (d, i) => i)
                    .selectAll("rect")
                    .attr("x", function () {
                        const headIndex = +this.parentNode.getAttribute("head-index");
                        return leftPos + boxOffsets(headIndex);
                    })
                    .attr("y", (d, i) => TEXT_TOP + i * BOXHEIGHT)
                    .attr("width", BOXWIDTH / activeHeads())
                    .attr("height", BOXHEIGHT)
                    .style("opacity", function (d) {
                        const headIndex = +this.parentNode.getAttribute("head-index");
                        if (config.headVis[headIndex])
                            if (d) {
                                return d[index];
                            } else {
                                return 0.0;
                            }
                        else
                            return 0.0;
                    });
            });

            textContainer.on("mouseleave", function () {

                // Unhighlight selected token
                d3.select(this).selectAll(".background")
                    .style("opacity", 0.0);

                // Reset visibility attributes for previously selected lines
                svg.select("#attention")
                    .selectAll("line[visibility='visible']")
                    .attr("visibility", null) ;
                svg.select("#attention").attr("visibility", "visible");

                // Reset highlights superimposed over tokens
                svg.selectAll(".attentionBoxes")
                    .selectAll("g")
                    .selectAll("rect")
                    .style("opacity", 0.0);
            });
        }

        // Draw one line per (head, left-token, right-token) triple; opacity
        // encodes the attention weight (set in updateAttention).
        function renderAttention(svg, attention) {

            // Remove previous dom elements
            svg.select("#attention").remove();

            // Add new elements
            svg.append("g")
                .attr("id", "attention") // Container for all attention arcs
                .selectAll(".headAttention")
                .data(attention)
                .enter()
                .append("g")
                .classed("headAttention", true) // Group attention arcs by head
                .attr("head-index", (d, i) => i)
                .selectAll(".tokenAttention")
                .data(d => d)
                .enter()
                .append("g")
                .classed("tokenAttention", true) // Group attention arcs by left token
                .attr("left-token-index", (d, i) => i)
                .selectAll("line")
                .data(d => d)
                .enter()
                .append("line")
                .attr("x1", BOXWIDTH)
                .attr("y1", function () {
                    const leftTokenIndex = +this.parentNode.getAttribute("left-token-index")
                    return TEXT_TOP + leftTokenIndex * BOXHEIGHT + (BOXHEIGHT / 2)
                })
                .attr("x2", BOXWIDTH + MATRIX_WIDTH)
                .attr("y2", (d, rightTokenIndex) => TEXT_TOP + rightTokenIndex * BOXHEIGHT + (BOXHEIGHT / 2))
                .attr("stroke-width", 2)
                .attr("stroke", function () {
                    const headIndex = +this.parentNode.parentNode.getAttribute("head-index");
                    return headColors(headIndex)
                })
                .attr("left-token-index", function () {
                    return +this.parentNode.getAttribute("left-token-index")
                })
                .attr("right-token-index", (d, i) => i)
            ;
            updateAttention(svg)
        }

        // Re-apply stroke opacities after a head is toggled.
        function updateAttention(svg) {
            svg.select("#attention")
                .selectAll("line")
                .attr("stroke-opacity", function (d) {
                    const headIndex = +this.parentNode.parentNode.getAttribute("head-index");
                    // If head is selected
                    if (config.headVis[headIndex]) {
                        // Set opacity to attention weight divided by number of active heads
                        return d / activeHeads()
                    } else {
                        return 0.0;
                    }
                })
        }

        // X offset of head i's highlight stripe: counts active heads with a
        // lower index (reduce's third arg is the element index).
        function boxOffsets(i) {
            const numHeadsAbove = config.headVis.reduce(
                function (acc, val, cur) {
                    return val && cur < i ? acc + 1 : acc;
                }, 0);
            return numHeadsAbove * (BOXWIDTH / activeHeads());
        }

        // Number of currently selected heads.
        function activeHeads() {
            return config.headVis.reduce(function (acc, val) {
                return val ? acc + 1 : acc;
            }, 0);
        }

        // One colored square per head; click toggles the head, double-click
        // solos it (or resets if it is already the only active head).
        function drawCheckboxes(top, svg) {
            const checkboxContainer = svg.append("g");
            const checkbox = checkboxContainer.selectAll("rect")
                .data(config.headVis)
                .enter()
                .append("rect")
                .attr("fill", (d, i) => headColors(i))
                .attr("x", (d, i) => i * CHECKBOX_SIZE)
                .attr("y", top)
                .attr("width", CHECKBOX_SIZE)
                .attr("height", CHECKBOX_SIZE);

            function updateCheckboxes() {
                checkboxContainer.selectAll("rect")
                    .data(config.headVis)
                    .attr("fill", (d, i) => d ? headColors(i): lighten(headColors(i)));
            }

            updateCheckboxes();

            checkbox.on("click", function (d, i) {
                if (config.headVis[i] && activeHeads() === 1) return;
                config.headVis[i] = !config.headVis[i];
                updateCheckboxes();
                updateAttention(svg);
            });

            checkbox.on("dblclick", function (d, i) {
                // If we double click on the only active head then reset
                if (config.headVis[i] && activeHeads() === 1) {
                    config.headVis = new Array(config.nHeads).fill(true);
                } else {
                    config.headVis = new Array(config.nHeads).fill(false);
                    config.headVis[i] = true;
                }
                updateCheckboxes();
                updateAttention(svg);
            });
        }

        // Washed-out variant of a head color, used for de-selected heads.
        function lighten(color) {
            const c = d3.hsl(color);
            const increment = (1 - c.l) * 0.6;
            c.l += increment;
            c.s -= increment;
            return c;
        }

        // Matrix transpose for the right-hand column's weight lookup.
        function transpose(mat) {
            return mat[0].map(function (col, i) {
                return mat.map(function (row) {
                    return row[i];
                });
            });
        }

    }
    // );

}
430
+
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ inseq
2
+ bertviz