# yiksiu's picture
# upload submission files
# f118c1e verified
"""
Token position definitions (copied from MCQA task)
"""
import re
from CausalAbstraction.neural.LM_units import TokenPosition, get_last_token_index
def get_token_positions(pipeline, causal_model):
    """
    Build the token positions used for the simple MCQA task.

    Args:
        pipeline: The language model pipeline with tokenizer
        causal_model: The causal model for the task

    Returns:
        list[TokenPosition]: Three TokenPosition objects, in this order:
            "correct_symbol" (the correct answer symbol),
            "correct_symbol_period" (the token right after that symbol),
            "last_token" (whatever get_last_token_index returns).
    """

    def get_correct_symbol_index(model_input, pipeline, causal_model):
        """
        Locate the correct answer symbol inside the tokenized prompt.

        Args:
            model_input (Dict): The input dictionary to a causal model;
                its "raw_input" entry is read as the prompt text
            pipeline: The tokenizer pipeline
            causal_model: The causal model

        Returns:
            list[int]: Single-element list holding the token index of the
                correct answer symbol

        Raises:
            ValueError: If no standalone uppercase letter in the prompt
                equals the correct symbol
        """
        # The causal model's forward pass says which choice is correct
        output = causal_model.run_forward(model_input)
        pointer = output["answer_pointer"]
        correct_symbol = output[f"symbol{pointer}"]
        prompt = model_input["raw_input"]

        # Take the FIRST standalone uppercase letter equal to the symbol.
        # NOTE(review): assumes that letter does not appear earlier in the
        # prompt as a standalone word -- confirm against the prompt template.
        symbol_match = next(
            (
                candidate
                for candidate in re.finditer(r"\b[A-Z]\b", prompt)
                if candidate.group() == correct_symbol
            ),
            None,
        )
        if symbol_match is None:
            raise ValueError(f"Could not find correct symbol {correct_symbol} in prompt: {prompt}")

        # Tokenize only the prompt prefix that ends with the symbol; the
        # symbol is then taken to be the final token of that prefix.
        prefix = prompt[:symbol_match.end()]
        prefix_token_ids = list(pipeline.load(prefix)["input_ids"][0])
        return [len(prefix_token_ids) - 1]

    def symbol_index(x):
        # Index of the correct answer symbol itself
        return get_correct_symbol_index(x, pipeline, causal_model)

    def symbol_period_index(x):
        # Index of the token immediately following the correct symbol
        return [get_correct_symbol_index(x, pipeline, causal_model)[0] + 1]

    def last_index(x):
        # Delegates to the library helper for the final token
        return get_last_token_index(x, pipeline)

    # Create TokenPosition objects
    return [
        TokenPosition(symbol_index, pipeline, id="correct_symbol"),
        TokenPosition(symbol_period_index, pipeline, id="correct_symbol_period"),
        TokenPosition(last_index, pipeline, id="last_token"),
    ]