dar-tau commited on
Commit
392fdcd
1 Parent(s): fac5648

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -8,11 +8,14 @@ dataset = load_dataset('dar-tau/grammar-attention-maps-opt-350m')['train']
8
 
9
  def analyze_sentence(index):
10
  row = dataset[index]
11
- attn_maps = np.array(row['attention_maps']).reshape(*row['attention_maps_shape'])
 
 
 
12
  plot = sns.heatmap(attn_maps.sum(1).sum(0))
13
- plt.xticks(np.arange(len(tokenized)-1) + 0.5,
14
  tokenizer.tokenize(text, add_special_tokens=False), rotation=90);
15
- plt.yticks(np.arange(len(tokenized)-1) + 0.5,
16
  tokenizer.tokenize(text, add_special_tokens=False), rotation=0);
17
  plt.ylabel('TARGET')
18
  plt.xlabel('SOURCE')
 
8
 
9
  def analyze_sentence(index):
10
  row = dataset[index]
11
+ text = row['text']
12
+ _, seq_len, _ = row['attention_maps_shape']
13
+
14
+ attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape)
15
  plot = sns.heatmap(attn_maps.sum(1).sum(0))
16
+ plt.xticks(np.arange(seq_len - 1) + 0.5,
17
  tokenizer.tokenize(text, add_special_tokens=False), rotation=90);
18
+ plt.yticks(np.arange(seq_len - 1) + 0.5,
19
  tokenizer.tokenize(text, add_special_tokens=False), rotation=0);
20
  plt.ylabel('TARGET')
21
  plt.xlabel('SOURCE')