dar-tau commited on
Commit
0577153
1 Parent(s): 392fdcd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -9,8 +9,8 @@ dataset = load_dataset('dar-tau/grammar-attention-maps-opt-350m')['train']
9
  def analyze_sentence(index):
10
  row = dataset[index]
11
  text = row['text']
12
- _, seq_len, _ = row['attention_maps_shape']
13
-
14
  attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape)
15
  plot = sns.heatmap(attn_maps.sum(1).sum(0))
16
  plt.xticks(np.arange(seq_len - 1) + 0.5,
 
9
  def analyze_sentence(index):
10
  row = dataset[index]
11
  text = row['text']
12
+ attn_map_shape = row['attention_maps_shape'][1:]
13
+ seq_len = attn_map_shape[1]
14
  attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape)
15
  plot = sns.heatmap(attn_maps.sum(1).sum(0))
16
  plt.xticks(np.arange(seq_len - 1) + 0.5,