Gabriela Nicole Gonzalez Saez committed on
Commit
e4bccbf
1 Parent(s): 9e85aff
Files changed (2) hide show
  1. app.py +35 -7
  2. plotsjs.js +140 -4
app.py CHANGED
@@ -16,8 +16,6 @@ from functools import partial
16
 
17
  from transformers import AutoTokenizer, MarianTokenizer, AutoModel, AutoModelForSeq2SeqLM, MarianMTModel
18
 
19
-
20
-
21
  model_es = "Helsinki-NLP/opus-mt-en-es"
22
  model_fr = "Helsinki-NLP/opus-mt-en-fr"
23
  model_zh = "Helsinki-NLP/opus-mt-en-zh"
@@ -75,6 +73,28 @@ contrastive_examples = [
75
  ]
76
 
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  def split_token_from_sequences(sequences, model) -> dict :
79
  n_sentences = len(sequences)
80
 
@@ -138,7 +158,8 @@ def split_token_from_sequences(sequences, model) -> dict :
138
  return dict_parent
139
 
140
 
141
- import gradio as gr
 
142
 
143
  html = """
144
  <html>
@@ -149,9 +170,13 @@ html = """
149
  <p id="viz"></p>
150
 
151
  <p id="demo2"></p>
 
 
 
 
 
152
 
153
 
154
- <div id="d3_beam_search"></div>
155
 
156
  </body>
157
  </html>
@@ -175,16 +200,19 @@ def sentence_maker(w1, model, var2={}):
175
  beam_dict = split_token_from_sequences(translated.sequences,model )
176
 
177
  tgt_text = dict_tokenizer_tr[model].decode(translated.sequences[0], skip_special_tokens=True)
 
 
 
 
178
 
179
- return [tgt_text,beam_dict]
180
 
181
  def sentence_maker2(w1,j2):
182
- # json_value = {'one':1}
183
- # return f"{w1['two']} in sentence22..."
184
  print(w1,j2)
185
  return "in sentence22..."
186
 
187
 
 
188
  with gr.Blocks(js="plotsjs.js") as demo:
189
  gr.Markdown(
190
  """
 
16
 
17
  from transformers import AutoTokenizer, MarianTokenizer, AutoModel, AutoModelForSeq2SeqLM, MarianMTModel
18
 
 
 
19
  model_es = "Helsinki-NLP/opus-mt-en-es"
20
  model_fr = "Helsinki-NLP/opus-mt-en-fr"
21
  model_zh = "Helsinki-NLP/opus-mt-en-zh"
 
73
  ]
74
 
75
 
76
def get_k_prob_tokens(transition_scores, result, model, k_values=5):
    """Collect the top-k candidate tokens at every generation step.

    Only row 0 of the generation output is inspected.

    Args:
        transition_scores: Per-sequence transition scores; only the length
            of row 0 is used, to bound the number of steps visited.
        result: Generation output exposing ``sequences`` (token ids, first
            column skipped) and ``scores`` (one score tensor per step).
        model: Key into the module-level ``dict_tokenizer_tr`` mapping.
        k_values: Number of candidates to keep per step. Clamped to the
            size of the score vector so ``topk`` cannot raise.

    Returns:
        A list with one entry per step, each entry being
        ``[token_ids, probabilities, decoded_tokens]`` where the first two
        are NumPy arrays and the last is a list of strings.
    """
    tokenizer_tr = dict_tokenizer_tr[model]
    gen_sequences = result.sequences[:, 1:]

    # First beam only...
    # NOTE(review): row 0 of each step's scores is read; under beam search
    # that row may not be the beam the final sequence came from - confirm
    # against ``beam_indices``.
    bs = 0

    # Same bound as the original zip(), plus a guard so a shorter
    # ``result.scores`` cannot trigger an IndexError.
    n_steps = min(
        len(gen_sequences[bs]),
        len(transition_scores[bs]),
        len(result.scores),
    )

    result_output = []
    for i_step in range(n_steps):
        step_scores = result.scores[i_step][bs]
        k = min(k_values, step_scores.shape[-1])
        # Compute the top-k once per step instead of three times.
        top = step_scores.topk(k)

        alt_tokens = [tokenizer_tr.decode(token_id) for token_id in top.indices]
        # Scores are exponentiated as in the original implementation.
        alt_probs = np.exp(np.array(top.values))
        result_output.append([np.array(top.indices), alt_probs, alt_tokens])

    return result_output
96
+
97
+
98
  def split_token_from_sequences(sequences, model) -> dict :
99
  n_sentences = len(sequences)
100
 
 
158
  return dict_parent
159
 
160
 
161
+
162
+
163
 
164
  html = """
165
  <html>
 
170
  <p id="viz"></p>
171
 
172
  <p id="demo2"></p>
173
+ <h4> Exploring top-k probable tokens </h4>
174
+ <div id="d3_text_grid">... top 10 tokens generated at each step ...</div>
175
+
176
+ <h4> Exploring the Beam Search sequence generation</h4>
177
+ <div id="d3_beam_search">... top 4 generated sequences using Beam Search...</div>
178
 
179
 
 
180
 
181
  </body>
182
  </html>
 
200
  beam_dict = split_token_from_sequences(translated.sequences,model )
201
 
202
  tgt_text = dict_tokenizer_tr[model].decode(translated.sequences[0], skip_special_tokens=True)
203
+ transition_scores = dict_models_tr[model].compute_transition_scores(
204
+ translated.sequences, translated.scores, translated.beam_indices , normalize_logits=True
205
+ )
206
+ prob_tokens = get_k_prob_tokens(transition_scores, translated, model, k_values=10)
207
 
208
+ return [tgt_text,[beam_dict,prob_tokens]]
209
 
210
  def sentence_maker2(w1,j2):
 
 
211
  print(w1,j2)
212
  return "in sentence22..."
213
 
214
 
215
+
216
  with gr.Blocks(js="plotsjs.js") as demo:
217
  gr.Markdown(
218
  """
plotsjs.js CHANGED
@@ -41,20 +41,24 @@ async () => {
41
 
42
 
43
  globalThis.testFn_out_json = (data) => {
44
- const idMapping = data.reduce((acc, el, i) => {
 
 
 
 
45
  acc[el.id] = i;
46
  return acc;
47
  }, {});
48
 
49
  let root;
50
- data.forEach(el => {
51
  // Handle the root element
52
  if (el.parentId === null) {
53
  root = el;
54
  return;
55
  }
56
- // Use our mapping to locate the parent element in our data array
57
- const parentEl = data[idMapping[el.parentId]];
58
  // Add our current el to its parent's `children` array
59
  parentEl.children = [...(parentEl.children || []), el];
60
  });
@@ -63,6 +67,14 @@ async () => {
63
  // document.getElementById('d3_beam_search').innerHTML = Tree(root)
64
  d3.select('#d3_beam_search').html("");
65
  d3.select('#d3_beam_search').append(function(){return Tree(root);});
 
 
 
 
 
 
 
 
66
  // $('#d3_beam_search').html(Tree(root)) ;
67
 
68
  return(['string', {}])
@@ -206,6 +218,130 @@ function Tree(data, { // data is either tabular (array of objects) or hierarchy
206
  return svg.node();
207
  }
208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
 
211
 
 
41
 
42
 
43
  globalThis.testFn_out_json = (data) => {
44
+ console.log(data);
45
+ data_beam = data[0];
46
+ data_probs = data[1];
47
+
48
+ const idMapping = data_beam.reduce((acc, el, i) => {
49
  acc[el.id] = i;
50
  return acc;
51
  }, {});
52
 
53
  let root;
54
+ data_beam.forEach(el => {
55
  // Handle the root element
56
  if (el.parentId === null) {
57
  root = el;
58
  return;
59
  }
60
+ // Use our mapping to locate the parent element in our data_beam array
61
+ const parentEl = data_beam[idMapping[el.parentId]];
62
  // Add our current el to its parent's `children` array
63
  parentEl.children = [...(parentEl.children || []), el];
64
  });
 
67
  // document.getElementById('d3_beam_search').innerHTML = Tree(root)
68
  d3.select('#d3_beam_search').html("");
69
  d3.select('#d3_beam_search').append(function(){return Tree(root);});
70
+
71
+ //probabilities;
72
+ //
73
+ d3.select('#d3_text_grid').html("");
74
+ d3.select('#d3_text_grid').append(function(){return TextGrid(data_probs);});
75
+ // $('#d3_text_grid').html(TextGrid(data)) ;
76
+
77
+
78
  // $('#d3_beam_search').html(Tree(root)) ;
79
 
80
  return(['string', {}])
 
218
  return svg.node();
219
  }
220
 
221
/**
 * Render the per-step top-k tokens as a horizontally scrollable grid.
 *
 * Each entry of `data` becomes one column; each candidate token becomes a
 * gray rectangle whose fill opacity is its probability, with the token text
 * drawn on top and a "token:probability" tooltip.
 *
 * @param {Array} data - One entry per generation step. Index 1 of an entry
 *   holds the probabilities and index 2 the decoded tokens (index 0 is not
 *   read here).
 * @param {*} div_name - Unused; kept so existing callers keep working.
 * @param {Object} [options]
 * @param {number} [options.width=640] - Minimum SVG width in pixels.
 * @param {number} [options.height] - SVG height in pixels; derived from the
 *   tallest column when omitted.
 * @param {number} [options.r=3] - Currently unused.
 * @param {number} [options.padding=1] - Currently unused.
 * @returns {HTMLDivElement} Detached wrapper div containing the SVG.
 */
function TextGrid(data, div_name, {
  width = 640, // outer width, in pixels
  height, // outer height, in pixels
  r = 3, // currently unused
  padding = 1, // currently unused
} = {}) {
  // Layout constants. These were implicit globals before; they are now
  // block-scoped so they cannot leak into or clash with the page scope.
  const defaultTopK = 10;
  const rowHeight = 20;
  const rectWidth = 60;
  const rectTotal = 70; // column pitch: rectangle width plus gap
  const rectHeight = 15;

  const steps = data || [];

  const realWidth = rectTotal * steps.length;
  const totalWidth = (realWidth > width) ? realWidth : width;

  // Size the SVG for the tallest column instead of assuming 10 rows, so
  // a different k does not clip the grid. Falls back to 10 rows
  // when there is no data.
  const maxRows = steps.reduce(
    (acc, step) => Math.max(acc, (step[2] || []).length), 0) || defaultTopK;
  if (height === undefined) height = maxRows * rowHeight + 10;

  const parent = d3.create("div");

  const body = parent.append("div")
    .style("overflow-x", "scroll")
    .style("-webkit-overflow-scrolling", "touch");

  const svg = body.append("svg")
    .attr("width", totalWidth)
    .attr("height", height)
    .style("display", "block")
    .attr("font-family", "sans-serif")
    .attr("font-size", 10);

  steps.forEach((step, column) => {
    const xOffset = column * rectTotal;
    const probabilities = step[1] || [];
    const tokens = step[2] || [];
    const cells = tokens.map((t, i) => ({t: t, p: probabilities[i]}));

    // One group per column and a regular data join inside it. The earlier
    // `selectAll("text").enter().data(...)` chain only worked because the
    // index-based join never matched existing nodes.
    const cell = svg.append("g")
      .selectAll("g")
      .data(cells)
      .join("g");

    cell.append("rect")
      .attr("x", xOffset)
      .attr("y", (d, i) => 10 + i * rowHeight)
      .attr("width", rectWidth)
      .attr("height", rectHeight)
      .attr("fill", "gray")
      // Clamp so out-of-range or missing values still give a valid opacity.
      .attr("fill-opacity", d => Math.max(0, Math.min(1, Number(d.p) || 0)))
      .attr("stroke-opacity", 0.8)
      .append("svg:title")
      .text(d => d.t + ":" + d.p);

    cell.append("text")
      .text(d => d.t)
      .attr("x", xOffset)
      .attr("y", (d, i) => 20 + i * rowHeight)
      .attr("font-weight", 700);
  });

  // NOTE(review): the wrapper is not attached to the document yet when this
  // runs, so the scroll likely has no visible effect - confirm and move to
  // the caller if the initial scroll position matters.
  body.node().scrollBy(totalWidth, 0);

  return parent.node();
}
343
+
344
+
345
 
346
 
347