dar-tau commited on
Commit
17a2383
1 Parent(s): dae383f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -23,7 +23,7 @@ def analyze_sentence(index, vis_type):
23
  attn_map_shape = row['attention_maps_shape'][1:]
24
  seq_len = attn_map_shape[1]
25
  attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)
26
- fig = plt.figure(figsize=(0.5 + 0.4 * len(tokenized), 0.35 * len(tokenized)))
27
  attn_maps = attn_maps[:, 1:, 1:]
28
  if vis_type == VisType.SUM.value:
29
  plot_data = attn_maps.sum(0)
@@ -39,7 +39,7 @@ def analyze_sentence(index, vis_type):
39
  plt.ylabel('TARGET')
40
  plt.xlabel('SOURCE')
41
  plt.grid()
42
- metrics = {'Metrics': 0}
43
  metrics.update({k: v for k, v in row.items() if k not in ['text', 'attention_maps', 'attention_maps_shape']})
44
  return fig, metrics
45
 
@@ -48,9 +48,9 @@ with demo:
48
  with gr.Row():
49
  sentence_dropdown = gr.Dropdown(label="Sentence",
50
  choices=[x.split('</s> ')[1] for x in dataset['text']],
51
- value=0, min_width=500, type='index')
52
  vis_dropdown = gr.Dropdown(label="Visualization", choices=[x.value for x in VisType] + [f'Layer #{i}' for i in range(num_layers)],
53
- min_width=150, value=VisType.SUM, type='value')
54
  btn = gr.Button("Run", min_width=30)
55
  output = gr.Plot(label="Plot", container=True)
56
  metrics = gr.Label("Metrics")
 
23
  attn_map_shape = row['attention_maps_shape'][1:]
24
  seq_len = attn_map_shape[1]
25
  attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)
26
+ fig = plt.figure(figsize=(0.3 + 0.35 * len(tokenized), 0.3 * len(tokenized)))
27
  attn_maps = attn_maps[:, 1:, 1:]
28
  if vis_type == VisType.SUM.value:
29
  plot_data = attn_maps.sum(0)
 
39
  plt.ylabel('TARGET')
40
  plt.xlabel('SOURCE')
41
  plt.grid()
42
+ metrics = {'Metrics': 1}
43
  metrics.update({k: v for k, v in row.items() if k not in ['text', 'attention_maps', 'attention_maps_shape']})
44
  return fig, metrics
45
 
 
48
  with gr.Row():
49
  sentence_dropdown = gr.Dropdown(label="Sentence",
50
  choices=[x.split('</s> ')[1] for x in dataset['text']],
51
+ value=0, min_width=300, type='index')
52
  vis_dropdown = gr.Dropdown(label="Visualization", choices=[x.value for x in VisType] + [f'Layer #{i}' for i in range(num_layers)],
53
+ min_width=100, value=VisType.SUM, type='value')
54
  btn = gr.Button("Run", min_width=30)
55
  output = gr.Plot(label="Plot", container=True)
56
  metrics = gr.Label("Metrics")