|
""" |
|
Token position definitions for MCQA task submission. |
|
This file provides token position functions that identify key tokens in MCQA prompts. |
|
""" |
|
|
|
import re |
|
from CausalAbstraction.model_units.LM_units import TokenPosition |
|
|
|
|
|
def get_last_token_index(prompt, pipeline):

    """

    Return the position of the final token in the tokenized prompt.



    Args:

        prompt (str): The input prompt text.

        pipeline: Tokenizer pipeline; ``load`` must return a dict whose
            "input_ids" entry is a batch of token-id sequences.



    Returns:

        list[int]: Single-element list holding the last token's index.

    """

    # Index of the last token is simply (sequence length - 1).
    token_ids = list(pipeline.load(prompt)["input_ids"][0])

    return [len(token_ids) - 1]
|
|
|
|
|
def get_correct_symbol_index(prompt, pipeline, task):

    """

    Locate the token index of the correct answer symbol in the prompt.



    Runs the task's causal model on the prompt to learn which answer
    symbol is correct, finds that symbol as a standalone capital letter
    in the prompt text, then converts the character position into a
    token index by tokenizing the prefix that ends at the symbol.



    Args:

        prompt (str): The prompt text.

        pipeline: Tokenizer pipeline.

        task: Task object exposing ``causal_model`` and ``input_loader``.



    Returns:

        list[int]: Single-element list with the correct symbol's token index.



    Raises:

        ValueError: If the correct symbol does not appear as a standalone
            capital letter anywhere in the prompt.

    """

    # Ask the causal model which answer symbol is correct for this prompt.
    model_output = task.causal_model.run_forward(task.input_loader(prompt))

    pointer = model_output["answer_pointer"]

    correct_symbol = model_output[f"symbol{pointer}"]

    # Scan standalone capital letters and stop at the first one that
    # equals the correct symbol.
    symbol_match = next(
        (m for m in re.finditer(r"\b[A-Z]\b", prompt) if m.group() == correct_symbol),
        None,
    )

    if not symbol_match:

        raise ValueError(f"Could not find correct symbol {correct_symbol} in prompt: {prompt}")

    # The symbol's token index is the last token of the prefix ending at it.
    # NOTE(review): tokenizing the prefix assumes token boundaries match
    # those of the full prompt — presumably true for this tokenizer.
    prefix_ids = list(pipeline.load(prompt[:symbol_match.end()])["input_ids"][0])

    return [len(prefix_ids) - 1]
|
|
|
|
|
def get_token_positions(pipeline, task):

    """

    Build the TokenPosition objects used in MCQA intervention experiments.



    Two positions are tracked:

    - "correct_symbol": where the correct answer letter (A, B, C, or D)
      appears in the prompt.

    - "last_token": the final token of the prompt.



    Args:

        pipeline: The language model pipeline with tokenizer.

        task: The MCQA task object.



    Returns:

        list[TokenPosition]: TokenPosition objects for intervention
            experiments.

    """

    # Named closures bind pipeline/task so TokenPosition only needs a
    # single-argument callable over the prompt.
    def locate_correct_symbol(prompt):

        return get_correct_symbol_index(prompt, pipeline, task)

    def locate_last_token(prompt):

        return get_last_token_index(prompt, pipeline)

    return [

        TokenPosition(locate_correct_symbol, pipeline, id="correct_symbol"),

        TokenPosition(locate_last_token, pipeline, id="last_token"),

    ]