Gabriela Nicole Gonzalez Saez committed on
Commit
9e85aff
1 Parent(s): b114ef2
Files changed (2) hide show
  1. app.py +209 -0
  2. plotsjs.js +264 -0
app.py CHANGED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import inseq
3
+ import captum
4
+
5
+ import torch
6
+ import os
7
+ # import nltk
8
+ import argparse
9
+ import random
10
+ import numpy as np
11
+
12
+ from argparse import Namespace
13
+ from tqdm.notebook import tqdm
14
+ from torch.utils.data import DataLoader
15
+ from functools import partial
16
+
17
+ from transformers import AutoTokenizer, MarianTokenizer, AutoModel, AutoModelForSeq2SeqLM, MarianMTModel
18
+
19
+
20
+
21
# Hugging Face checkpoint names, one per supported English -> X language pair.
model_es, model_fr, model_zh = (
    "Helsinki-NLP/opus-mt-en-es",
    "Helsinki-NLP/opus-mt-en-fr",
    "Helsinki-NLP/opus-mt-en-zh",
)

# Tokenizers and plain Marian translation models, loaded in es / fr / zh order.
tokenizer_es, tokenizer_fr, tokenizer_zh = [
    AutoTokenizer.from_pretrained(checkpoint)
    for checkpoint in (model_es, model_fr, model_zh)
]
model_tr_es, model_tr_fr, model_tr_zh = [
    MarianMTModel.from_pretrained(checkpoint)
    for checkpoint in (model_es, model_fr, model_zh)
]

# NOTE: from here on the model_* names no longer hold checkpoint strings; they
# are rebound to the objects returned by inseq.load_model for the same
# checkpoints, each loaded with the "input_x_gradient" method name.
model_es, model_fr, model_zh = [
    inseq.load_model(checkpoint, "input_x_gradient")
    for checkpoint in (model_es, model_fr, model_zh)
]

# Lookup tables keyed by language pair; the keys match the radio-button values.
_LANGUAGE_PAIRS = ('en-es', 'en-fr', 'en-zh')

dict_models = dict(zip(_LANGUAGE_PAIRS, (model_es, model_fr, model_zh)))

dict_models_tr = dict(zip(_LANGUAGE_PAIRS, (model_tr_es, model_tr_fr, model_tr_zh)))

dict_tokenizer_tr = dict(zip(_LANGUAGE_PAIRS, (tokenizer_es, tokenizer_fr, tokenizer_zh)))
55
+
56
# English source sentences offered as ready-made inputs.
saliency_examples = [
    "Peace of Mind: Protection for consumers.",
    "The sustainable development goals report: towards a rescue plan for people and planet",
    "We will leave no stone unturned to hold those responsible to account.",
    (
        "The clock is now ticking on our work to finalise the remaining key legislative "
        "proposals presented by this Commission to ensure that citizens and businesses "
        "can reap the benefits of our policy actions."
    ),
    "Pumpkins, squash and gourds, fresh or chilled, excluding courgettes",
    (
        "The labour market participation of mothers with infants has even deteriorated "
        "over the past two decades, often impacting their career and incomes for years."
    ),
]

# Each entry is a list of three strings: an English source sentence followed by
# two alternative translations of it (Spanish, French and Chinese respectively).
contrastive_examples = [
    [
        "Peace of Mind: Protection for consumers.",
        "Paz mental: protección de los consumidores",
        "Paz de la mente: protección de los consumidores",
    ],
    [
        "the slaughterer has finished his work.",
        "l'abatteur a terminé son travail.",
        "l'abatteuse a terminé son travail.",
    ],
    [
        "A fundamental shift is needed - in commitment, solidarity, financing and action - to put the world on a better path.",
        "需要在承诺、团结、筹资和行动方面进行根本转变,使世界走上更美好的道路。",
        "我们需要从根本上转变承诺、团结、资助和行动,使世界走上更美好的道路。",
    ],
]
76
+
77
+
78
def split_token_from_sequences(sequences, model, tokenizer=None) -> list:
    """Convert generated beam hypotheses into a flat parent/child node list.

    Every hypothesis is decoded to text and split on whitespace. The words of
    all hypotheses are then merged into a prefix tree: hypotheses that start
    with the same words share the corresponding nodes.

    Args:
        sequences: 2-D array/tensor of generated token ids, one row per
            hypothesis. The first column is dropped before decoding
            (presumably the decoder start token - confirm against the model).
        model: Language-pair key (e.g. ``'en-es'``) used to look up the
            tokenizer in ``dict_tokenizer_tr`` when ``tokenizer`` is not given.
        tokenizer: Optional object exposing
            ``decode(ids, skip_special_tokens=True) -> str``. Defaults to
            ``dict_tokenizer_tr[model]``.

    Returns:
        A list of dicts with the keys ``id``, ``parentId``, ``text``, ``name``
        and ``prob``. The first element is the artificial root ``'bos--1'``
        whose ``parentId`` is ``None``; node ids are unique.
    """
    if tokenizer is None:
        tokenizer = dict_tokenizer_tr[model]

    # str.split() without an argument never yields empty strings, so repeated
    # spaces in the decoded text cannot create nodes with a missing parent.
    words_per_beam = [
        tokenizer.decode(beam_ids, skip_special_tokens=True).split()
        for beam_ids in sequences[:, 1:]
    ]
    return _beam_words_to_nodes(words_per_beam)


def _beam_words_to_nodes(words_per_beam, score=0):
    """Merge per-hypothesis word lists into a prefix tree stored as a flat list."""
    root_id = 'bos--1'
    nodes = [{'id': root_id, 'parentId': None, 'text': 'bos', 'name': 'bos', 'prob': score}]
    seen_ids = {root_id}
    n_steps = max((len(words) for words in words_per_beam), default=0)

    # Step-major order: all hypotheses at position 0, then position 1, ...
    # so a parent is always emitted before any of its children.
    for step in range(n_steps):
        for words in words_per_beam:
            if step >= len(words):
                continue  # this hypothesis is shorter than the current step

            # A node id encodes the whole word path from the root, e.g.
            # "la-0-paz-1", so equal prefixes in different hypotheses collide
            # on purpose and are stored only once.
            path = [f"{word}-{position}" for position, word in enumerate(words[:step + 1])]
            node_id = "-".join(path)
            if node_id in seen_ids:
                continue

            parent_id = "-".join(path[:-1]) if step > 0 else root_id
            word = words[step]
            nodes.append({'id': node_id, 'parentId': parent_id, 'text': word, 'name': word, 'prob': score})
            seen_ids.add(node_id)

    return nodes
139
+
140
+
141
+ import gradio as gr
142
+
143
# Static HTML passed to gr.HTML below. It provides the placeholder elements the
# page script looks up by id: "demo", "viz", "demo2" and the "d3_beam_search"
# container into which the beam-search tree is drawn.
html = """
<html>
<script async src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js"></script>
<body>

<p id="demo"></p>
<p id="viz"></p>

<p id="demo2"></p>


<div id="d3_beam_search"></div>

</body>
</html>
"""
159
+
160
+
161
def sentence_maker(w1, model, var2=None):
    """Translate ``w1`` with beam search and describe the returned hypotheses.

    Args:
        w1: Source text to translate.
        model: Language-pair key (e.g. ``'en-es'``) selecting the tokenizer in
            ``dict_tokenizer_tr`` and the model in ``dict_models_tr``.
        var2: Unused; kept so existing callers that pass a third argument keep
            working. Defaults to ``None`` instead of a shared mutable ``{}``.

    Returns:
        A two-element list ``[tgt_text, beam_dict]``: the decoded text of the
        first returned hypothesis, and the node list produced by
        ``split_token_from_sequences`` for all returned hypotheses.
    """
    tokenizer = dict_tokenizer_tr[model]
    inputs = tokenizer(w1, return_tensors="pt")

    # All beams are returned, so the two generation settings share one value.
    num_beams = 4
    num_ret_seq = num_beams
    translated = dict_models_tr[model].generate(
        **inputs,
        num_beams=num_beams,
        num_return_sequences=num_ret_seq,
        return_dict_in_generate=True,
        output_attentions=True,
        output_hidden_states=True,
        output_scores=True,
    )

    beam_dict = split_token_from_sequences(translated.sequences, model)

    tgt_text = tokenizer.decode(translated.sequences[0], skip_special_tokens=True)

    return [tgt_text, beam_dict]
180
+
181
def sentence_maker2(w1, j2):
    """Log both arguments to stdout and return a fixed placeholder string."""
    print(w1, j2)
    placeholder = "in sentence22..."
    return placeholder
186
+
187
+
188
# Build the UI. The contents of plotsjs.js are handed to Gradio through the
# `js` argument; that script defines the testFn_out / testFn_out_json helpers
# referenced by the event handlers below.
with gr.Blocks(js="plotsjs.js") as demo:
    gr.Markdown(
        """
# MAKE NMT Workshop \t `BeamSearch`
"""
    )
    in_text = gr.Textbox(label="source text")
    out_text = gr.Textbox(label="target text")
    # Hidden components: out_text2 receives the result of sentence_maker2 and
    # var2 carries the beam-search node list to the browser.
    out_text2 = gr.Textbox(visible=False)
    var2 = gr.JSON(visible=False)
    radio_c = gr.Radio(
        choices=['en-zh', 'en-es', 'en-fr'],
        value="en-zh",
        label='',
        container=False,
    )
    btn = gr.Button("Translate")
    input_mic = gr.HTML(html)

    # Translate on click; the js hook forwards both inputs through testFn_out.
    btn.click(
        fn=sentence_maker,
        inputs=[in_text, radio_c],
        outputs=[out_text, var2],
        js="(in_text,radio_c) => testFn_out(in_text,radio_c)",
    )
    # When the target text changes, the js hook draws the tree from var2.
    out_text.change(
        fn=sentence_maker2,
        inputs=[out_text, var2],
        outputs=out_text2,
        js="(out_text,var2) => testFn_out_json(var2)",
    )

    # Previously tried: running the script on load via
    # demo.load(None,None,None,js="plotsjs.js")

if __name__ == "__main__":
    demo.launch()
plotsjs.js ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ async () => {
4
+ // set testFn() function on globalThis, so your HTML onclick can access it
5
+
6
+
7
// Smoke-test helper: writes a fixed greeting into the element with id "demo".
globalThis.testFn = () => {
    const demoElement = document.getElementById('demo');
    demoElement.innerHTML = "Hello?";
};
10
+
11
+ const d3 = await import("https://cdn.jsdelivr.net/npm/d3@7/+esm");
12
+
13
+ globalThis.d3 = d3;
14
+
15
// D3 smoke test: appends an <svg> with a 50x50 black square to #viz; the
// square turns red while hovered and black again when the pointer leaves.
globalThis.d3Fn = () => {
    const square = d3.select('#viz')
        .append('svg')
        .append('rect')
        .attr('width', 50)
        .attr('height', 50)
        .attr('fill', 'black');

    // Regular functions (not arrows) so that `this` is the <rect> element.
    square.on('mouseover', function () {
        d3.select(this).attr('fill', 'red');
    });
    square.on('mouseout', function () {
        d3.select(this).attr('fill', 'black');
    });
};
25
+
26
// Pass-through hook for the Translate button: logs the source text and
// returns both arguments unchanged as a two-element array.
globalThis.testFn_out = (val, radio_c) => {
    console.log(val);
    const passthrough = [val, radio_c];
    return passthrough;
};
32
+
33
+ // Is this function well commented???
34
+ // globalThis.testFn_out_json = (val) => {
35
+ // document.getElementById('demo2').innerHTML = JSON.stringify(val);
36
+ // console.log(val);
37
+ // globalThis.d3Fn();
38
+ // return(['string', {}])
39
+ // // return(JSON.stringify(val), JSON.stringify(val) );
40
+ // };
41
+
42
+
43
// Draw the beam-search tree into #d3_beam_search.
//
// `data` is expected to be a flat array of nodes shaped like
// {id, parentId, text, name, prob}, where the root has parentId === null.
// The nodes are linked into a nested {children: [...]} structure and rendered
// with Tree(). Always returns ['string', {}].
globalThis.testFn_out_json = (data) => {
    const container = d3.select('#d3_beam_search');
    // Clear the previous drawing first, so stale output never lingers.
    container.html("");

    // Nothing to draw (e.g. the JSON component is still empty).
    if (!Array.isArray(data) || data.length === 0) {
        return ['string', {}];
    }

    // Index shallow copies by id, so the caller's objects are not mutated.
    // The first occurrence of an id wins; repeated ids are ignored so a node
    // is never attached to its parent twice.
    const nodesById = new Map();
    data.forEach(el => {
        if (!nodesById.has(el.id)) nodesById.set(el.id, { ...el });
    });

    let root;
    nodesById.forEach(node => {
        // Handle the root element.
        if (node.parentId === null || node.parentId === undefined) {
            root = node;
            return;
        }
        const parentNode = nodesById.get(node.parentId);
        if (parentNode === undefined) {
            // A dangling parentId would otherwise throw; skip the orphan.
            console.warn('testFn_out_json: unknown parentId', node.parentId, 'for node', node.id);
            return;
        }
        // Add the current node to its parent's `children` array.
        parentNode.children = parentNode.children || [];
        parentNode.children.push(node);
    });

    if (root !== undefined) {
        container.append(() => Tree(root));
    }

    return ['string', {}];
};
71
+
72
+
73
+
74
+
75
+
76
+
77
+
78
+ // Copyright 2021 Observable, Inc.
79
+ // Released under the ISC license.
80
+ // https://observablehq.com/@d3/tree
81
// Render a tidy tree as an SVG element (adapted from the Observable D3 tree
// example referenced in the header comment).
//
// Returns the <svg> DOM node. Text labels are drawn red when the node's
// `prob` equals 1 and black otherwise.
function Tree(data, { // data is either tabular (array of objects) or hierarchy (nested objects)
    path, // as an alternative to id and parentId, returns an array identifier, imputing internal nodes
    id = Array.isArray(data) ? d => d.id : null, // if tabular data, given a d in data, returns a unique identifier (string)
    parentId = Array.isArray(data) ? d => d.parentId : null, // if tabular data, given a node d, returns its parent's identifier
    children, // if hierarchical data, given a d in data, returns its children
    tree = d3.tree, // layout algorithm (typically d3.tree or d3.cluster)
    sort, // how to sort nodes prior to layout (e.g., (a, b) => d3.descending(a.height, b.height))
    label = d => d.name, // given a node d, returns the display name
    title = d => (d.name + (d.prob)), // given a node d, returns its hover text (name directly followed by prob); pass null to disable
    link, // given a node d, its link (if any)
    linkTarget = "_blank", // the target attribute for links (if any)
    width = 800, // minimum outer width, in pixels; the drawing grows by 10px per node beyond it
    height, // outer height, in pixels; computed from the layout when omitted
    r = 3, // radius of nodes
    padding = 1, // horizontal padding for first and last column
    fill = "#999", // fill for leaf nodes
    fillOpacity, // accepted for compatibility; not applied anywhere in this function
    stroke = "#555", // stroke for links (also used as fill for internal nodes)
    strokeWidth = 2, // stroke width for links
    strokeOpacity = 0.4, // stroke opacity for links
    strokeLinejoin, // stroke line join for links
    strokeLinecap, // stroke line cap for links
    halo = "#fff", // color of label halo
    haloWidth = 3, // padding around the labels
    curve = d3.curveBumpX, // curve for the link
} = {}) {

    // If id and parentId options are specified, or the path option, use d3.stratify
    // to convert tabular data to a hierarchy; otherwise we assume that the data is
    // specified as an object {children} with nested objects (a.k.a. the "flare.json"
    // format), and use d3.hierarchy.
    const root = path != null ? d3.stratify().path(path)(data)
        : id != null || parentId != null ? d3.stratify().id(id).parentId(parentId)(data)
        : d3.hierarchy(data, children);

    // Sort the nodes.
    if (sort != null) root.sort(sort);

    // Compute labels.
    const descendants = root.descendants();
    const L = label == null ? null : descendants.map(d => label(d.data, d));

    // Compute the layout: reserve 10px of width per node, but never less than `width`.
    const descWidth = 10;
    const realWidth = descWidth * descendants.length;
    const totalWidth = (realWidth > width) ? realWidth : width;

    const dx = 25; // vertical distance between neighbouring nodes
    const dy = totalWidth / (root.height + padding); // horizontal distance between depth levels
    tree().nodeSize([dx, dy])(root);

    // Center the tree: find the vertical extent of the laid-out nodes.
    let x0 = Infinity;
    let x1 = -x0;
    root.each(d => {
        if (d.x > x1) x1 = d.x;
        if (d.x < x0) x0 = d.x;
    });

    // Compute the default height.
    if (height === undefined) height = x1 - x0 + dx * 2;

    // Use the required curve.
    if (typeof curve !== "function") throw new Error(`Unsupported curve`);

    // NOTE(review): only the <svg> node is returned below, so this scrollable
    // wrapper is never attached to the page and the scrollBy call at the end
    // appears to have no visible effect - confirm whether the wrapper itself
    // was meant to be returned.
    const parent = d3.create("div");

    const body = parent.append("div")
        .style("overflow-x", "scroll")
        .style("-webkit-overflow-scrolling", "touch");

    const svg = body.append("svg")
        .attr("viewBox", [-dy * padding / 2, x0 - dx, totalWidth, height])
        .attr("width", totalWidth)
        .attr("height", height)
        .attr("style", "max-width: 100%; height: auto; height: intrinsic;")
        .attr("font-family", "sans-serif")
        .attr("font-size", 12);

    // Links between parents and children. The layout's x/y are swapped on
    // purpose so the tree grows from left to right.
    svg.append("g")
        .attr("fill", "none")
        .attr("stroke", stroke)
        .attr("stroke-opacity", strokeOpacity)
        .attr("stroke-linecap", strokeLinecap)
        .attr("stroke-linejoin", strokeLinejoin)
        .attr("stroke-width", strokeWidth)
        .selectAll("path")
        .data(root.links())
        .join("path")
        .attr("d", d3.link(curve)
            .x(d => d.y)
            .y(d => d.x));

    // One <a> group per node, positioned at the node's layout coordinates.
    const node = svg.append("g")
        .selectAll("a")
        .data(root.descendants())
        .join("a")
        .attr("xlink:href", link == null ? null : d => link(d.data, d))
        .attr("target", link == null ? null : linkTarget)
        .attr("transform", d => `translate(${d.y},${d.x})`);

    node.append("circle")
        .attr("fill", d => d.children ? stroke : fill)
        .attr("r", r);

    // Hover text; the caller-supplied `title` option is respected.
    if (title != null) node.append("title")
        .text(d => title(d.data, d));

    // Labels: internal nodes are labelled to the left, leaves to the right.
    if (L) node.append("text")
        .attr("dy", "0.32em")
        .attr("x", d => d.children ? -6 : 6)
        .attr("text-anchor", d => d.children ? "end" : "start")
        .attr("paint-order", "stroke")
        .attr("stroke", halo)
        .attr("fill", d => d.data.prob == 1 ? ('red') : ('black'))
        .attr("stroke-width", haloWidth)
        .text((d, i) => L[i]);

    body.node().scrollBy(totalWidth, 0);
    return svg.node();
}
208
+
209
+
210
+
211
+
212
+
213
+ }
214
+
215
+
216
+
217
+
218
+
219
+ // define('viz', ['d3'], function (d3) {
220
+
221
+ // function draw(container) {
222
+ // d3.select(container).append("svg").append('rect').attr('id', 'viz_rect').attr('width', 50).attr('height', 50);
223
+ // }
224
+ // return draw;
225
+ // });
226
+
227
+ // console.log("HERE!")
228
+ // element.append('Loaded 😄 ');
229
+ // variable2='hello';
230
+
231
+ // draw('.gradio-container')
232
+
233
+
234
+ // function transform_beamsearch(data){
235
+ // const idMapping = data.reduce((acc, el, i) => {
236
+ // acc[el.id] = i;
237
+ // return acc;
238
+ // }, {});
239
+
240
+ // let root;
241
+ // data.forEach(el => {
242
+ // // Handle the root element
243
+ // if (el.parentId === null) {
244
+ // root = el;
245
+ // return;
246
+ // }
247
+ // // Use our mapping to locate the parent element in our data array
248
+ // const parentEl = data[idMapping[el.parentId]];
249
+ // // Add our current el to its parent's `children` array
250
+ // parentEl.children = [...(parentEl.children || []), el];
251
+ // });
252
+ // // console.log(Tree(root, { label: d => d.name,}));
253
+
254
+ // console.log(root);
255
+ // // $('#d3_beam_search').html(Tree(root)) ;
256
+ // return root;
257
+
258
+ // }
259
+
260
+
261
+ // var gradioContainer = document.querySelector('.gradio-container');
262
+
263
+
264
+ // gradioContainer.insertBefore(container, gradioContainer.firstChild);