Upload 11 files
Browse files- .gitattributes +6 -0
- 4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol)_featurizer +3 -0
- 4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol)_indices +1 -0
- 4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol)_inverse_featurizer +3 -0
- 4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol_period)_featurizer +3 -0
- 4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol_period)_indices +1 -0
- 4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol_period)_inverse_featurizer +3 -0
- 4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-last_token)_featurizer +3 -0
- 4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-last_token)_indices +1 -0
- 4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-last_token)_inverse_featurizer +3 -0
- featurizer.py +52 -0
- token_position.py +65 -0
.gitattributes
CHANGED
@@ -53,3 +53,9 @@ ioi_submission/ioi_task_Gemma2ForCausalLM_output_token/AttentionHead(Layer-7,Hea
|
|
53 |
ioi_submission/ioi_task_Gemma2ForCausalLM_output_token/AttentionHead(Layer-7,Head-6,Token-all)_inverse_featurizer filter=lfs diff=lfs merge=lfs -text
|
54 |
ioi_submission/ioi_task_Gemma2ForCausalLM_output_token/AttentionHead(Layer-8,Head-1,Token-all)_featurizer filter=lfs diff=lfs merge=lfs -text
|
55 |
ioi_submission/ioi_task_Gemma2ForCausalLM_output_token/AttentionHead(Layer-8,Head-1,Token-all)_inverse_featurizer filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
ioi_submission/ioi_task_Gemma2ForCausalLM_output_token/AttentionHead(Layer-7,Head-6,Token-all)_inverse_featurizer filter=lfs diff=lfs merge=lfs -text
|
54 |
ioi_submission/ioi_task_Gemma2ForCausalLM_output_token/AttentionHead(Layer-8,Head-1,Token-all)_featurizer filter=lfs diff=lfs merge=lfs -text
|
55 |
ioi_submission/ioi_task_Gemma2ForCausalLM_output_token/AttentionHead(Layer-8,Head-1,Token-all)_inverse_featurizer filter=lfs diff=lfs merge=lfs -text
|
56 |
+
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol_period)_featurizer filter=lfs diff=lfs merge=lfs -text
|
57 |
+
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol_period)_inverse_featurizer filter=lfs diff=lfs merge=lfs -text
|
58 |
+
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol)_featurizer filter=lfs diff=lfs merge=lfs -text
|
59 |
+
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol)_inverse_featurizer filter=lfs diff=lfs merge=lfs -text
|
60 |
+
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-last_token)_featurizer filter=lfs diff=lfs merge=lfs -text
|
61 |
+
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-last_token)_inverse_featurizer filter=lfs diff=lfs merge=lfs -text
|
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol)_featurizer
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dab9bccd2ea775eb56ad98fa9bac02c8d6a170d41d3a58b4b300c7e97eb80af8
|
3 |
+
size 21531300
|
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol)_indices
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
null
|
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol)_inverse_featurizer
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4a5f69e8af1b271494715d6d2cf3936a9f1897065b5cd7a1e35417c0eb19a665
|
3 |
+
size 21531356
|
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol_period)_featurizer
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0d3ca3b99e9badc80119a4d711f60f35caf610ae7a8bcf08689385b490a197c0
|
3 |
+
size 21531349
|
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol_period)_indices
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
null
|
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol_period)_inverse_featurizer
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:630c183b185a53d826f9e17f6932dfa7e7d1011d8fb8435bc29b42fb1ac45189
|
3 |
+
size 21531533
|
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-last_token)_featurizer
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8483d8ff87d3a188f542bcf17e545d63bb2039a644249982d82ef8c45a65964e
|
3 |
+
size 21531208
|
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-last_token)_indices
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
null
|
4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-last_token)_inverse_featurizer
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f02db1835842cbffc192c82973dc8d08dcc9b2f5667f57ac8f11c7af32c684b8
|
3 |
+
size 21531328
|
featurizer.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copy of the existing SubspaceFeaturizer implementation for submission.
|
3 |
+
This file provides the same SubspaceFeaturizer functionality in a self-contained format.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import pyvene as pv
|
9 |
+
from CausalAbstraction.neural.featurizers import Featurizer
|
10 |
+
|
11 |
+
|
12 |
+
class SubspaceFeaturizerModuleCopy(torch.nn.Module):
|
13 |
+
def __init__(self, rotate_layer):
|
14 |
+
super().__init__()
|
15 |
+
self.rotate = rotate_layer
|
16 |
+
|
17 |
+
def forward(self, x):
|
18 |
+
r = self.rotate.weight.T
|
19 |
+
f = x.to(r.dtype) @ r.T
|
20 |
+
error = x - (f @ r).to(x.dtype)
|
21 |
+
return f, error
|
22 |
+
|
23 |
+
|
24 |
+
class SubspaceInverseFeaturizerModuleCopy(torch.nn.Module):
|
25 |
+
def __init__(self, rotate_layer):
|
26 |
+
super().__init__()
|
27 |
+
self.rotate = rotate_layer
|
28 |
+
|
29 |
+
def forward(self, f, error):
|
30 |
+
r = self.rotate.weight.T
|
31 |
+
return (f.to(r.dtype) @ r).to(f.dtype) + error.to(f.dtype)
|
32 |
+
|
33 |
+
|
34 |
+
class SubspaceFeaturizerCopy(Featurizer):
|
35 |
+
def __init__(self, shape=None, rotation_subspace=None, trainable=True, id="subspace"):
|
36 |
+
assert shape is not None or rotation_subspace is not None, "Either shape or rotation_subspace must be provided."
|
37 |
+
if shape is not None:
|
38 |
+
self.rotate = pv.models.layers.LowRankRotateLayer(*shape, init_orth=True)
|
39 |
+
elif rotation_subspace is not None:
|
40 |
+
shape = rotation_subspace.shape
|
41 |
+
self.rotate = pv.models.layers.LowRankRotateLayer(*shape, init_orth=False)
|
42 |
+
self.rotate.weight.data.copy_(rotation_subspace)
|
43 |
+
self.rotate = torch.nn.utils.parametrizations.orthogonal(self.rotate)
|
44 |
+
|
45 |
+
if not trainable:
|
46 |
+
self.rotate.requires_grad_(False)
|
47 |
+
|
48 |
+
# Create module-based featurizer and inverse_featurizer
|
49 |
+
featurizer = SubspaceFeaturizerModuleCopy(self.rotate)
|
50 |
+
inverse_featurizer = SubspaceInverseFeaturizerModuleCopy(self.rotate)
|
51 |
+
|
52 |
+
super().__init__(featurizer, inverse_featurizer, n_features=self.rotate.weight.shape[1], id=id)
|
token_position.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Token position definitions for MCQA task submission.
|
3 |
+
This file provides token position functions that identify key tokens in MCQA prompts.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import re
|
7 |
+
from CausalAbstraction.neural.LM_units import TokenPosition, get_last_token_index
|
8 |
+
|
9 |
+
|
10 |
+
def get_token_positions(pipeline, causal_model):
|
11 |
+
"""
|
12 |
+
Get token positions for the simple MCQA task.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
pipeline: The language model pipeline with tokenizer
|
16 |
+
causal_model: The causal model for the task
|
17 |
+
|
18 |
+
Returns:
|
19 |
+
list[TokenPosition]: List of TokenPosition objects for intervention experiments
|
20 |
+
"""
|
21 |
+
def get_correct_symbol_index(input, pipeline, causal_model):
|
22 |
+
"""
|
23 |
+
Find the index of the correct answer symbol in the prompt.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
input (Dict): The input dictionary to a causal model
|
27 |
+
pipeline: The tokenizer pipeline
|
28 |
+
causal_model: The causal model
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
list[int]: List containing the index of the correct answer symbol token
|
32 |
+
"""
|
33 |
+
# Run the model to get the answer position
|
34 |
+
output = causal_model.run_forward(input)
|
35 |
+
pointer = output["answer_pointer"]
|
36 |
+
correct_symbol = output[f"symbol{pointer}"]
|
37 |
+
prompt = input["raw_input"]
|
38 |
+
|
39 |
+
# Find all single uppercase letters in the prompt
|
40 |
+
matches = list(re.finditer(r"\b[A-Z]\b", prompt))
|
41 |
+
|
42 |
+
# Find the match corresponding to our correct symbol
|
43 |
+
symbol_match = None
|
44 |
+
for match in matches:
|
45 |
+
if prompt[match.start():match.end()] == correct_symbol:
|
46 |
+
symbol_match = match
|
47 |
+
break
|
48 |
+
|
49 |
+
if not symbol_match:
|
50 |
+
raise ValueError(f"Could not find correct symbol {correct_symbol} in prompt: {prompt}")
|
51 |
+
|
52 |
+
# Get the substring up to the symbol match end
|
53 |
+
substring = prompt[:symbol_match.end()]
|
54 |
+
tokenized_substring = list(pipeline.load(substring)["input_ids"][0])
|
55 |
+
|
56 |
+
# The symbol token will be at the end of the substring
|
57 |
+
return [len(tokenized_substring) - 1]
|
58 |
+
|
59 |
+
# Create TokenPosition objects
|
60 |
+
token_positions = [
|
61 |
+
TokenPosition(lambda x: get_correct_symbol_index(x, pipeline, causal_model), pipeline, id="correct_symbol"),
|
62 |
+
TokenPosition(lambda x: [get_correct_symbol_index(x, pipeline, causal_model)[0]+1], pipeline, id="correct_symbol_period"),
|
63 |
+
TokenPosition(lambda x: get_last_token_index(x, pipeline), pipeline, id="last_token")
|
64 |
+
]
|
65 |
+
return token_positions
|