dar-tau committed on
Commit
c75ae48
1 Parent(s): f7076c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -4
app.py CHANGED
@@ -4,13 +4,17 @@ from datasets import load_dataset
4
  from transformers import AutoTokenizer
5
  import matplotlib.pyplot as plt
6
  import seaborn as sns
 
 
 
 
7
 
8
 
9
  dataset = load_dataset('dar-tau/grammar-attention-maps-opt-350m')['train']
10
  tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m', add_prefix_space=True)
11
 
12
 
13
- def analyze_sentence(index):
14
  row = dataset[index]
15
  text = row['text']
16
  tokenized = tokenizer.batch_decode(tokenizer.encode(text, add_special_tokens=False))
@@ -18,21 +22,29 @@ def analyze_sentence(index):
18
  seq_len = attn_map_shape[1]
19
  attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)
20
  fig = plt.figure(figsize=(0.5 + 0.5 * len(tokenized), 0.4 * len(tokenized)))
21
- sns.heatmap(attn_maps.sum(0)[1:, 1:])
 
 
 
22
  plt.xticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=90);
23
  plt.yticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=0);
24
  plt.ylabel('TARGET')
25
  plt.xlabel('SOURCE')
26
  plt.grid()
 
27
  return fig
28
 
29
  demo = gr.Blocks()
30
  with demo:
31
  with gr.Row():
32
- dropdown = gr.Dropdown(choices=dataset['text'], value=0, min_width=750, type='index')
 
 
 
33
  btn = gr.Button("Run")
34
  output = gr.Plot(label="Plot", container=True)
35
- btn.click(analyze_sentence, [dropdown], [output])
 
36
 
37
 
38
  if __name__ == "__main__":
 
4
  from transformers import AutoTokenizer
5
  import matplotlib.pyplot as plt
6
  import seaborn as sns
7
+ from enum import Enum
8
+
9
class VisType(Enum):
    """Closed set of visualization modes offered by the UI dropdown."""
    ALL = 'ALL'  # single mode for now: aggregate the maps over every head
11
 
12
 
13
  dataset = load_dataset('dar-tau/grammar-attention-maps-opt-350m')['train']
14
  tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m', add_prefix_space=True)
15
 
16
 
17
def analyze_sentence(index, vis_type):
    """Render the attention heatmap (and scalar metrics) for one dataset row.

    Parameters
    ----------
    index : int
        Row index into ``dataset`` (the sentence dropdown uses type='index').
    vis_type : int or VisType
        Visualization mode. The Gradio dropdown is declared with
        type='index', so an integer position into ``list(VisType)`` is
        accepted and normalized here.

    Returns
    -------
    tuple
        ``(fig, metrics)`` — the matplotlib figure and a dict of the row's
        remaining columns. Two values are required because ``btn.click``
        wires this function to ``[output, metrics]``.
    """
    # type='index' delivers an int position, not an enum member — normalize,
    # otherwise `vis_type == VisType.ALL` below could never be True.
    if not isinstance(vis_type, VisType):
        vis_type = list(VisType)[vis_type]
    row = dataset[index]
    text = row['text']
    tokenized = tokenizer.batch_decode(tokenizer.encode(text, add_special_tokens=False))
    # NOTE(review): this line is reconstructed from usage (seq_len and the
    # metrics filter below reference 'attention_maps_shape') — confirm the
    # exact column access against the dataset schema.
    attn_map_shape = row['attention_maps_shape']
    seq_len = attn_map_shape[1]
    attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)
    fig = plt.figure(figsize=(0.5 + 0.5 * len(tokenized), 0.4 * len(tokenized)))
    # Drop the BOS row/column so the grid matches the seq_len - 1 tick labels;
    # sum over axis 0 (heads) so the data is 2-D for sns.heatmap.
    if vis_type == VisType.ALL:
        plot_data = attn_maps.sum(0)[1:, 1:]
    else:
        # Dead branch today (VisType has one member); kept 2-D and safe.
        plot_data = attn_maps[:, 1:, 1:].sum(0)
    sns.heatmap(plot_data)
    plt.xticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=90)
    plt.yticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=0)
    plt.ylabel('TARGET')
    plt.xlabel('SOURCE')
    plt.grid()
    # Everything except the text/attention columns feeds the Label output.
    # Was `record.items()` / `if x not in ...` — both undefined names.
    metrics = {k: v for k, v in row.items()
               if k not in ['text', 'attention_maps', 'attention_maps_shape']}
    # btn.click maps this call onto [output, metrics]: return both values.
    return fig, metrics
36
 
37
  demo = gr.Blocks()
38
  with demo:
39
  with gr.Row():
40
+ sentence_dropdown = gr.Dropdown(label="Sentence",
41
+ choices=[x.split('</s> ')[1] for x in dataset['text']],
42
+ value=0, min_width=500, type='index')
43
+ vis_dropdown = gr.Dropdown(label="Visualization", choices=list(VisType), value=0, type='index')
44
  btn = gr.Button("Run")
45
  output = gr.Plot(label="Plot", container=True)
46
+ metrics = gr.Label("Metrics")
47
+ btn.click(analyze_sentence, [sentence_dropdown, vis_dropdown], [output, metrics])
48
 
49
 
50
  if __name__ == "__main__":