Aksel Lenes committed on
Commit
909a82d
1 Parent(s): 77c9ae7

Add fixed_scale parameter for grid plot view

Browse files

To see patterns in longer sequences.
Added a warning that with a non-fixed scale each subplot in the grid has
its own individual color scale, so you cannot compare
attention intensities between grid cells.

hexviz/pages/1_🗺️Identify_Interesting_Heads.py CHANGED
@@ -85,6 +85,9 @@ truncated_sequence = sequence[slice_start - 1 : slice_end]
85
  remove_special_tokens = st.sidebar.checkbox(
86
  "Hide attention to special tokens", key="remove_special_tokens"
87
  )
 
 
 
88
 
89
 
90
  layer_sequence, head_sequence = select_heads_and_layers(st.sidebar, selected_model)
@@ -104,7 +107,7 @@ attention, tokens = get_attention(
104
  ec_number=ec_number,
105
  )
106
 
107
- fig = plot_tiled_heatmap(attention, layer_sequence=layer_sequence, head_sequence=head_sequence)
108
 
109
 
110
  st.pyplot(fig)
@@ -143,5 +146,5 @@ if len(tokens_to_label) > 0:
143
  tokens = [token if token in tokens_to_label else "" for token in tokens]
144
 
145
 
146
- single_head_fig = plot_single_heatmap(attention, layer, head, tokens=tokens)
147
  st.pyplot(single_head_fig)
 
85
  remove_special_tokens = st.sidebar.checkbox(
86
  "Hide attention to special tokens", key="remove_special_tokens"
87
  )
88
+ if "fixed_scale" not in st.session_state:
89
+ st.session_state.fixed_scale = True
90
+ fixed_scale = st.sidebar.checkbox("Fixed scale", help="For long sequences the default fixed 0 to 1 scale can have very low contrast heatmaps, consider using a relative scale to increase the contrast between high attention and low attention areas. Note that each subplot will have separate color scales so don't compare colors between attention heads if using a non-fixed scale.", key="fixed_scale")
91
 
92
 
93
  layer_sequence, head_sequence = select_heads_and_layers(st.sidebar, selected_model)
 
107
  ec_number=ec_number,
108
  )
109
 
110
+ fig = plot_tiled_heatmap(attention, layer_sequence=layer_sequence, head_sequence=head_sequence, fixed_scale=fixed_scale)
111
 
112
 
113
  st.pyplot(fig)
 
146
  tokens = [token if token in tokens_to_label else "" for token in tokens]
147
 
148
 
149
+ single_head_fig = plot_single_heatmap(attention, layer, head, tokens=tokens, fixed_scale=fixed_scale)
150
  st.pyplot(single_head_fig)
hexviz/plot.py CHANGED
@@ -6,7 +6,7 @@ from matplotlib.ticker import FixedLocator
6
  from mpl_toolkits.axes_grid1 import make_axes_locatable
7
 
8
 
9
- def plot_tiled_heatmap(tensor, layer_sequence: List[int], head_sequence: List[int]):
10
  tensor = tensor[layer_sequence, :][
11
  :, head_sequence, :, :
12
  ] # Slice the tensor according to the provided sequences and sequence_count
@@ -18,9 +18,14 @@ def plot_tiled_heatmap(tensor, layer_sequence: List[int], head_sequence: List[in
18
  fig, axes = plt.subplots(num_layers, num_heads, figsize=(x_size, y_size), squeeze=False)
19
  for i in range(num_layers):
20
  for j in range(num_heads):
21
- axes[i, j].imshow(
22
- tensor[i, j].detach().numpy(), cmap="viridis", aspect="equal", vmin=0, vmax=1
23
- )
 
 
 
 
 
24
  axes[i, j].axis("off")
25
 
26
  # Enumerate the axes
@@ -33,7 +38,7 @@ def plot_tiled_heatmap(tensor, layer_sequence: List[int], head_sequence: List[in
33
  row_label = f"{layer_sequence[i]+1}"
34
  row_pos = ax_row[num_heads - 1].get_position()
35
  fig.text(row_pos.x1 + offset, (row_pos.y1 + row_pos.y0) / 2, row_label, va="center")
36
-
37
  plt.subplots_adjust(wspace=0.1, hspace=0.1)
38
  return fig
39
 
@@ -43,11 +48,15 @@ def plot_single_heatmap(
43
  layer: int,
44
  head: int,
45
  tokens: list[str],
 
46
  ):
47
  single_heatmap = tensor[layer, head, :, :].detach().numpy()
48
 
49
  fig, ax = plt.subplots(figsize=(10, 10))
50
- heatmap = ax.imshow(single_heatmap, cmap="viridis", aspect="equal", vmin=0, vmax=1)
 
 
 
51
 
52
  # Function to adjust font size based on the number of labels
53
  def get_font_size(labels):
 
6
  from mpl_toolkits.axes_grid1 import make_axes_locatable
7
 
8
 
9
+ def plot_tiled_heatmap(tensor, layer_sequence: List[int], head_sequence: List[int], fixed_scale: bool = True):
10
  tensor = tensor[layer_sequence, :][
11
  :, head_sequence, :, :
12
  ] # Slice the tensor according to the provided sequences and sequence_count
 
18
  fig, axes = plt.subplots(num_layers, num_heads, figsize=(x_size, y_size), squeeze=False)
19
  for i in range(num_layers):
20
  for j in range(num_heads):
21
+ if fixed_scale:
22
+ im = axes[i, j].imshow(
23
+ tensor[i, j].detach().numpy(), cmap="viridis", aspect="equal", vmin=0, vmax=1
24
+ )
25
+ else:
26
+ im = axes[i, j].imshow(
27
+ tensor[i, j].detach().numpy(), cmap="viridis", aspect="equal"
28
+ )
29
  axes[i, j].axis("off")
30
 
31
  # Enumerate the axes
 
38
  row_label = f"{layer_sequence[i]+1}"
39
  row_pos = ax_row[num_heads - 1].get_position()
40
  fig.text(row_pos.x1 + offset, (row_pos.y1 + row_pos.y0) / 2, row_label, va="center")
41
+
42
  plt.subplots_adjust(wspace=0.1, hspace=0.1)
43
  return fig
44
 
 
48
  layer: int,
49
  head: int,
50
  tokens: list[str],
51
+ fixed_scale : bool = True
52
  ):
53
  single_heatmap = tensor[layer, head, :, :].detach().numpy()
54
 
55
  fig, ax = plt.subplots(figsize=(10, 10))
56
+ if fixed_scale:
57
+ heatmap = ax.imshow(single_heatmap, cmap="viridis", aspect="equal", vmin=0, vmax=1)
58
+ else:
59
+ heatmap = ax.imshow(single_heatmap, cmap="viridis", aspect="equal")
60
 
61
  # Function to adjust font size based on the number of labels
62
  def get_font_size(labels):