Stefan Heimersheim commited on
Commit
ba97e60
1 Parent(s): eef2607

Fix x-axis bug

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -114,7 +114,8 @@ def imshow(tensor, xlabel="X", ylabel="Y", zlabel=None, xticks=None, yticks=None
114
 
115
  def plot_residual_stream_patch(clean_prompt=None, answer=None, corrupt_prompt=None, corrupt_answer=None):
116
  layers = ["blocks.0.hook_resid_pre", *[f"blocks.{i}.hook_resid_post" for i in range(model.cfg.n_layers)]]
117
- token_labels = model.to_str_tokens(clean_prompt)
 
118
  patching_effect = compute_residual_stream_patch(clean_prompt=clean_prompt, answer=answer, corrupt_prompt=corrupt_prompt, corrupt_answer=corrupt_answer, layers=layers)
119
  fig = imshow(patching_effect, xticks=token_labels, yticks=layers, xlabel="Position", ylabel="Layer",
120
  zlabel="Logit Difference", title="Patching residual stream at specific layer and position")
114
 
115
  def plot_residual_stream_patch(clean_prompt=None, answer=None, corrupt_prompt=None, corrupt_answer=None):
116
  layers = ["blocks.0.hook_resid_pre", *[f"blocks.{i}.hook_resid_post" for i in range(model.cfg.n_layers)]]
117
+ clean_tokens = model.to_str_tokens(clean_prompt)
118
+ token_labels = [f"(pos {i:2}) {t}" for i, t in enumerate(clean_tokens)]
119
  patching_effect = compute_residual_stream_patch(clean_prompt=clean_prompt, answer=answer, corrupt_prompt=corrupt_prompt, corrupt_answer=corrupt_answer, layers=layers)
120
  fig = imshow(patching_effect, xticks=token_labels, yticks=layers, xlabel="Position", ylabel="Layer",
121
  zlabel="Logit Difference", title="Patching residual stream at specific layer and position")