dar-tau commited on
Commit
60b5a72
·
verified ·
1 Parent(s): ee7b758

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -9
app.py CHANGED
@@ -24,7 +24,6 @@ def analyze_sentence(index, vis_type, vis_format):
24
  attn_map_shape = row['attention_maps_shape'][1:]
25
  seq_len = attn_map_shape[1]
26
  attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)
27
- fig = plt.figure(figsize=(0.1 + 0.3 * len(tokenized), 0.25 * len(tokenized)))
28
  attn_maps = attn_maps[:, 1:, 1:]
29
  if vis_type == VisType.SUM.value:
30
  plot_data = attn_maps.sum(0)
@@ -34,15 +33,26 @@ def analyze_sentence(index, vis_type, vis_format):
34
  else:
35
  print(vis_type)
36
  0/0
37
- sns.heatmap(plot_data)
38
- plt.xticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=90);
39
- plt.yticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=0);
40
- plt.ylabel('TARGET')
41
- plt.xlabel('SOURCE')
42
- plt.grid()
 
 
 
 
 
 
 
 
 
 
43
  metrics = {'Metrics': 1}
 
44
  metrics.update({k: v for k, v in row.items() if k not in ['text', 'attention_maps', 'attention_maps_shape']})
45
- return fig, metrics
46
 
47
  demo = gr.Blocks()
48
  with demo:
@@ -57,8 +67,9 @@ with demo:
57
  vis_format_checkbox = gr.Radio(['Plot', 'Graph'])
58
 
59
  output = gr.Plot(label="Plot", container=True)
 
60
  metrics = gr.Label("Metrics")
61
- btn.click(analyze_sentence, [sentence_dropdown, vis_dropdown, vis_format_checkbox], [output, metrics])
62
 
63
 
64
  if __name__ == "__main__":
 
24
  attn_map_shape = row['attention_maps_shape'][1:]
25
  seq_len = attn_map_shape[1]
26
  attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)
 
27
  attn_maps = attn_maps[:, 1:, 1:]
28
  if vis_type == VisType.SUM.value:
29
  plot_data = attn_maps.sum(0)
 
33
  else:
34
  print(vis_type)
35
  0/0
36
+ if vis_format == 'Plot':
37
+ fig = plt.figure(figsize=(0.1 + 0.3 * len(tokenized), 0.25 * len(tokenized)))
38
+ sns.heatmap(plot_data)
39
+ plt.xticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=90);
40
+ plt.yticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=0);
41
+ plt.ylabel('TARGET')
42
+ plt.xlabel('SOURCE')
43
+ plt.grid()
44
+ fig_output, graph_output = fig, ""
45
+ else:
46
+ ex = [{
47
+ "words": [{"text": x} for x in tokenized],
48
+ "arcs": [{"start": i, "end": j} for j in range(i) else range(len(tokenized))]
49
+ ]
50
+ graph_output = displacy.render(ex, style="dep", manual=True)
51
+ fig_output = None
52
  metrics = {'Metrics': 1}
53
+
54
  metrics.update({k: v for k, v in row.items() if k not in ['text', 'attention_maps', 'attention_maps_shape']})
55
+ return fig_output, graph_output, metrics
56
 
57
  demo = gr.Blocks()
58
  with demo:
 
67
  vis_format_checkbox = gr.Radio(['Plot', 'Graph'])
68
 
69
  output = gr.Plot(label="Plot", container=True)
70
+ graph_output = gr.HTML(label="Graph")
71
  metrics = gr.Label("Metrics")
72
+ btn.click(analyze_sentence, [sentence_dropdown, vis_dropdown, vis_format_checkbox], [output, graph_output, metrics])
73
 
74
 
75
  if __name__ == "__main__":