"""
Token position definitions (copied from MCQA task)
"""

import re

from CausalAbstraction.neural.LM_units import TokenPosition, get_last_token_index


def get_token_positions(pipeline, causal_model):
    """
    Get token positions for the simple MCQA task.

    Args:
        pipeline: The language model pipeline with tokenizer
        causal_model: The causal model for the task

    Returns:
        list[TokenPosition]: List of TokenPosition objects for intervention experiments
    """
    def get_correct_symbol_index(input, pipeline, causal_model):
        """
        Find the index of the correct answer symbol in the prompt.

        Args:
            input (Dict): The input dictionary to a causal model
            pipeline: The tokenizer pipeline
            causal_model: The causal model

        Returns:
            list[int]: List containing the index of the correct answer symbol token
        """
        # Run the causal model forward to determine which answer symbol is correct.
        output = causal_model.run_forward(input)
        pointer = output["answer_pointer"]
        correct_symbol = output[f"symbol{pointer}"]
        prompt = input["raw_input"]

        # Find all standalone capital letters in the prompt (the answer symbols).
        matches = list(re.finditer(r"\b[A-Z]\b", prompt))

        # Keep the occurrence that matches the correct answer symbol.
        symbol_match = None
        for match in matches:
            if match.group() == correct_symbol:
                symbol_match = match
                break

        if symbol_match is None:
            raise ValueError(f"Could not find correct symbol {correct_symbol} in prompt: {prompt}")

        # Tokenize the prompt up to and including the symbol; the symbol is then the last token.
        substring = prompt[:symbol_match.end()]
        tokenized_substring = list(pipeline.load(substring)["input_ids"][0])

        return [len(tokenized_substring) - 1]

    token_positions = [
        # The token of the correct answer symbol itself.
        TokenPosition(lambda x: get_correct_symbol_index(x, pipeline, causal_model), pipeline, id="correct_symbol"),
        # The token immediately following the correct symbol (the period after it).
        TokenPosition(lambda x: [get_correct_symbol_index(x, pipeline, causal_model)[0] + 1], pipeline, id="correct_symbol_period"),
        # The final token of the prompt.
        TokenPosition(lambda x: get_last_token_index(x, pipeline), pipeline, id="last_token")
    ]
    return token_positions
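

# A minimal, self-contained sketch (not part of the original task code) illustrating the
# prefix-tokenization trick used in get_correct_symbol_index: find the correct answer
# symbol with the same regex, take the prompt prefix ending at that symbol, tokenize it,
# and use the last token's index. The whitespace split below is a hypothetical stand-in
# for the pipeline's tokenizer; real experiments should go through pipeline.load as above.
if __name__ == "__main__":
    demo_prompt = "Question: Which fruit is yellow? A. apple B. banana Answer:"
    demo_symbol = "B"  # hypothetical correct answer symbol for this toy prompt

    # Locate the standalone capital letter that matches the correct symbol.
    demo_match = next(m for m in re.finditer(r"\b[A-Z]\b", demo_prompt) if m.group() == demo_symbol)

    # Tokenize the prefix ending at the symbol; its last "token" is the symbol's position.
    demo_tokens = demo_prompt[:demo_match.end()].split()
    print("correct symbol token index (toy whitespace tokenizer):", len(demo_tokens) - 1)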