codebook-features / code_search_utils.py
taufeeque's picture
Update code
63b5bc1
"""Functions to help with searching codes using regex."""
import pickle
import re
import numpy as np
import torch
from tqdm import tqdm
def load_dataset_cache(cache_base_path):
    """Read the three dataset cache arrays stored under `cache_base_path`.

    `cache_base_path` is used as a plain string prefix (file names are
    appended with `+`), so it must already end with a path separator if it
    names a directory.

    Returns:
        (tokens_str, tokens_text, token_byte_pos) tuple holding the arrays
        loaded from `tokens_str.npy`, `tokens_text.npy` and
        `token_byte_pos.npy`, in that order.
    """
    file_names = ("tokens_str.npy", "tokens_text.npy", "token_byte_pos.npy")
    tokens_str, tokens_text, token_byte_pos = (
        np.load(cache_base_path + file_name) for file_name in file_names
    )
    return tokens_str, tokens_text, token_byte_pos
def load_code_search_cache(cache_base_path):
    """Read the code-search cache files stored under `cache_base_path`.

    `cache_base_path` is used as a plain string prefix (file names are
    appended with `+`), so it must already end with a path separator if it
    names a directory.

    NOTE: both `np.load(..., allow_pickle=True)` and `pickle.load` can execute
    arbitrary code while deserializing; only point this at cache files that
    come from a trusted source.

    Returns:
        (cb_acts, act_count_ft_tkns, metrics) tuple, where `cb_acts` and
        `act_count_ft_tkns` are the objects unpickled from `cb_acts.pkl` and
        `act_count_ft_tkns.pkl`, and `metrics` is the single object stored in
        `metrics.npy`.
    """
    # `metrics.npy` holds one pickled Python object in a 0-d array;
    # `.item()` unwraps it.
    metrics = np.load(cache_base_path + "metrics.npy", allow_pickle=True).item()

    unpickled = []
    for file_name in ("cb_acts.pkl", "act_count_ft_tkns.pkl"):
        with open(cache_base_path + file_name, "rb") as handle:
            unpickled.append(pickle.load(handle))
    cb_acts, act_count_ft_tkns = unpickled

    return cb_acts, act_count_ft_tkns, metrics
def search_re(re_pattern, tokens_text, at_odd_even=-1):
    """Get list of (example_id, char_pos) where `re_pattern` matches in `tokens_text`.

    If the pattern defines at least one capturing group, the reported position
    is the start of group 1; otherwise it is the start of the whole match.
    Matches whose reported span is empty (or whose group 1 did not take part
    in the match) are dropped.

    Args:
        re_pattern: regex pattern to search for.
        tokens_text: list of example texts.
        at_odd_even: restricts matches by the parity of their start offset;
            a match is kept only when `char_pos % 2 == at_odd_even`.
            -1 (default): keep all matches.
            0: keep matches starting at even offsets only.
            1: keep matches starting at odd offsets only.
            This is useful for the TokFSM dataset when searching for states
            since the first token of states are always at even positions.

    Returns:
        List of (example_id, char_pos) tuples, ordered by example and then by
        position within the example.
    """
    assert at_odd_even in [-1, 0, 1], f"Invalid at_odd_even: {at_odd_even}"

    compiled = re.compile(re_pattern)
    # Ask the compiled pattern whether a capturing group exists instead of
    # looking for a literal "(" in the string: escaped parentheses,
    # non-capturing groups, lookarounds and inline flags all contain "(" but
    # define no group 1, which used to make `span(1)` raise IndexError.
    group = 1 if compiled.groups else 0

    res = []
    for i, text in enumerate(tokens_text):
        for match in compiled.finditer(text):
            start, end = match.span(group)
            # An unmatched group reports (-1, -1), so this also skips it.
            if start != end:
                res.append((i, start))

    if at_odd_even != -1:
        res = [r for r in res if r[1] % 2 == at_odd_even]
    return res
def byte_id_to_token_pos_id(example_byte_id, token_byte_pos):
    """Map a character offset inside an example to a token position.

    Used to translate the start of a regex match into a token index.

    Args:
        example_byte_id: (example_id, byte_id) tuple, where `byte_id` is a
            character's position in the text of that example.
        token_byte_pos: per-example sorted byte positions of the tokens,
            indexed as `token_byte_pos[example_id][token_pos]`.
            NOTE(review): the `side="right"` lookup suggests these are
            cumulative end offsets of the tokens -- confirm against the code
            that builds the cache.

    Returns:
        (example_id, token_pos_id) tuple, where `token_pos_id` is the number
        of entries in `token_byte_pos[example_id]` that are <= `byte_id`.
    """
    example_id, char_offset = example_byte_id
    boundaries = token_byte_pos[example_id]
    return (example_id, np.searchsorted(boundaries, char_offset, side="right"))
def get_code_precision_and_recall(token_pos_ids, codebook_acts, cb_act_counts=None):
    """Rank the codes that fire on the given `token_pos_ids`.

    Codes whose recall on `token_pos_ids` is not above 0.01 are discarded.

    Args:
        token_pos_ids: list of (example_id, token_pos_id) tuples.
        codebook_acts: activations of one codebook on a dataset, indexed as
            `codebook_acts[example_id][token_pos_id]`; each entry holds the
            code ids active at that position.
        cb_act_counts: optional lookup where `cb_act_counts[code]` is the
            number of times `code` is activated in the whole dataset.
            Without it precision cannot be computed.

    Returns:
        codes: numpy array of code ids, sorted by descending precision when
            `cb_act_counts` is given and by descending recall otherwise.
        prec: numpy array where `prec[i]` is the precision of `codes[i]`
            (all zeros when `cb_act_counts` is None).
        recall: numpy array where `recall[i]` is the fraction of
            `token_pos_ids` on which `codes[i]` is active.
        code_acts: numpy array where `code_acts[i]` is the dataset-wide
            activation count of `codes[i]` (all zeros when `cb_act_counts`
            is None).
    """
    active_codes = np.array([codebook_acts[ex][pos] for ex, pos in token_pos_ids])
    # np.unique flattens, so every active code at every matched position is counted.
    unique_codes, hits = np.unique(active_codes, return_counts=True)
    hit_rate = hits / len(token_pos_ids)

    keep = hit_rate > 0.01
    unique_codes, hits, hit_rate = unique_codes[keep], hits[keep], hit_rate[keep]

    if cb_act_counts is None:
        totals = np.zeros_like(unique_codes)
        precision = np.zeros_like(unique_codes)
        order = np.argsort(hit_rate)[::-1]
    else:
        totals = np.array([cb_act_counts[code] for code in unique_codes])
        precision = hits / totals
        order = np.argsort(precision)[::-1]

    return unique_codes[order], precision[order], hit_rate[order], totals[order]
def get_neuron_precision_and_recall(
    token_pos_ids, recall, neuron_acts_by_ex, neuron_sorted_acts
):
    """Get the neuron with the highest precision for the given `token_pos_ids`.

    For every neuron, the activation threshold is chosen as the
    `int(recall * num_matches)`-th largest activation over `token_pos_ids`
    (at least the single largest one). Its precision is that hit count
    divided by the number of dataset positions whose activation reaches the
    threshold.

    Args:
        token_pos_ids: list of token (example_id, token_pos_id) tuples from a dataset over which
            the neuron with the highest precision is to be found.
        recall: recall threshold for the neurons (this determines their activation threshold).
        neuron_acts_by_ex: numpy array or torch tensor of activations of all the attention and
            mlp output neurons on a dataset with shape
            (num_examples, seq_len, num_layers, 2, dim_size).
            The fourth dimension has size 2 because we consider neurons from both: attention and mlp.
        neuron_sorted_acts: torch tensor (it is passed to `torch.searchsorted`) of sorted
            activations of all the attention and mlp output neurons on a dataset with shape
            (num_layers, 2, dim_size, num_examples * seq_len).
            This should be obtained using the `neuron_acts_by_ex` array by rearranging the first two
            dimensions to the last dimensions and then sorting the last dimension.

    Returns:
        best_prec: highest precision amongst all the neurons for the given `token_pos_ids`.
        best_neuron_acts: array of shape (num_activations, 2) listing the
            (example_id, token_pos_id) positions in the dataset where the best neuron's
            activation is at or above its threshold.
        best_neuron_idx: tuple of (layer, is_mlp, neuron_id) where `layer` is the layer number,
            `is_mlp` is 0 if the neuron is from attention and 1 if the neuron is from mlp,
            and `neuron_id` is the neuron's index in the layer.
    """
    acts_per_match = [
        neuron_acts_by_ex[example_id, token_pos_id]
        for example_id, token_pos_id in token_pos_ids
    ]
    if isinstance(neuron_acts_by_ex, torch.Tensor):
        neuron_acts_on_pattern = torch.stack(
            acts_per_match, dim=-1
        )  # (layers, 2, dim_size, matches)
        neuron_acts_on_pattern = torch.sort(neuron_acts_on_pattern, dim=-1).values
    else:
        neuron_acts_on_pattern = np.stack(
            acts_per_match, axis=-1
        )  # (layers, 2, dim_size, matches)
        neuron_acts_on_pattern.sort(axis=-1)
        neuron_acts_on_pattern = torch.from_numpy(neuron_acts_on_pattern)

    # Number of matched positions a neuron must fire on. Clamp to at least 1:
    # with a count of 0 the index `-0` would select the *smallest* matched
    # activation and the precision numerator would be 0 for every neuron.
    num_hits = max(1, int(recall * neuron_acts_on_pattern.shape[-1]))
    act_thresh = neuron_acts_on_pattern[:, :, :, -num_hits]
    assert neuron_sorted_acts.shape[:-1] == act_thresh.shape

    # Count, per neuron, the dataset positions with activation >= threshold.
    prec_den = torch.searchsorted(neuron_sorted_acts, act_thresh.unsqueeze(-1))
    prec_den = prec_den.squeeze(-1)
    prec_den = neuron_sorted_acts.shape[-1] - prec_den
    prec = num_hits / prec_den
    assert (
        prec.shape == neuron_acts_on_pattern.shape[:-1]
    ), f"{prec.shape} != {neuron_acts_on_pattern.shape[:-1]}"

    # Hand numpy plain Python values rather than torch objects.
    best_neuron_idx = np.unravel_index(int(prec.argmax()), tuple(prec.shape))
    best_prec = prec[best_neuron_idx]
    best_neuron_act_thresh = act_thresh[best_neuron_idx].item()
    best_neuron_acts = neuron_acts_by_ex[
        :, :, best_neuron_idx[0], best_neuron_idx[1], best_neuron_idx[2]
    ]
    best_neuron_acts = best_neuron_acts >= best_neuron_act_thresh
    best_neuron_acts = np.stack(np.where(best_neuron_acts), axis=-1)
    return best_prec, best_neuron_acts, best_neuron_idx
def convert_to_adv_name(name, cb_at, gcb=""):
    """Convert a base codebook name to its advanced (fully qualified) name.

    Grouped codebooks: `layer0_head0` -> `layer0_attn_preproj_gcb0`.
    Non-grouped codebooks: `layer0` -> `layer0_attn_preproj`.

    Args:
        name: base name, `layer<L>_head<H>` for grouped codebooks or
            `layer<L>` otherwise.
        cb_at: where the codebook is applied (e.g. `attn_preproj`); inserted
            after the layer part.
        gcb: truthy (e.g. "_gcb") for grouped codebooks, "" for non-grouped
            codebooks.

    Returns:
        The advanced name string.

    Raises:
        ValueError: if `gcb` is truthy and `name` does not consist of exactly
            two `_`-separated parts.
    """
    if gcb:
        layer, head = name.split("_")
        # `head` looks like "head<H>"; strip the 4-character "head" prefix.
        return layer + f"_{cb_at}_gcb" + head[4:]
    # `layer` used to be read here without ever being assigned
    # (UnboundLocalError); take the layer part of the name directly.
    layer = name.split("_")[0]
    return layer + "_" + cb_at
def convert_to_base_name(name, gcb=""):
    """Convert an advanced codebook name back to its base name.

    `layer0_attn_preproj_gcb0` -> `layer0_head0`; names that do not contain
    "gcb" are reduced to their layer part (`layer0_mlp` -> `layer0`).

    Args:
        name: advanced codebook name.
        gcb: accepted for signature symmetry with `convert_to_adv_name`;
            not consulted -- grouping is detected from `name` itself.

    Returns:
        The base name string.
    """
    layer = name.partition("_")[0]
    if "gcb" not in name:
        return layer
    # Last `_`-separated part looks like "gcb<H>"; drop the 3-character prefix.
    head_id = name.rpartition("_")[2][3:]
    return layer + "_head" + head_id
def get_layer_head_from_base_name(name):
    """Parse a base name into numeric ids: `layer0_head0` -> (0, 0).

    Returns:
        (layer, head) tuple of ints; `head` is None when `name` has no
        `_`-separated part after the layer (e.g. `layer3` -> (3, None)).
    """
    parts = name.split("_")
    # Strip the "layer" (5 chars) and "head" (4 chars) prefixes before parsing.
    layer = int(parts[0][5:])
    head = int(parts[-1][4:]) if len(parts) > 1 else None
    return layer, head
def get_layer_head_from_adv_name(name):
    """Parse an advanced name into numeric ids: `layer0_attn_preproj_gcb0` -> (0, 0).

    The name is first reduced with `convert_to_base_name` and then parsed
    with `get_layer_head_from_base_name`.

    Returns:
        (layer, head) tuple as produced by `get_layer_head_from_base_name`.
    """
    return get_layer_head_from_base_name(convert_to_base_name(name))
def get_codes_from_pattern(
    re_pattern,
    tokens_text,
    token_byte_pos,
    cb_acts,
    act_count_ft_tkns,
    gcb="",
    topk=5,
    prec_threshold=0.5,
    at_odd_even=-1,
):
    """Fetch codes that activate on a given regex pattern.

    For every codebook, keeps at most `topk` codes and only those whose
    precision is strictly above `prec_threshold`.

    Args:
        re_pattern: regex pattern to search for.
        tokens_text: list of example texts of a dataset.
        token_byte_pos: numpy array of shape (num_examples, seq_len) where
            `token_byte_pos[example_id][token_pos]` is the byte position of
            the token at `token_pos` in the example with `example_id`.
        cb_acts: dict mapping codebook name to that codebook's activations.
        act_count_ft_tkns: dict mapping base codebook name to the per-code
            number of token activations on the dataset.
        gcb: "_gcb" for grouped codebooks and "" for non-grouped codebooks.
        topk: maximum number of codes to return per codebook.
        prec_threshold: precision a code must exceed to be returned.
        at_odd_even: forwarded to `search_re`; -1 (default) keeps every
            match, otherwise only matches whose start offset satisfies
            `offset % 2 == at_odd_even` are kept.

    Returns:
        codebook_wise_codes: dict of base codebook name to list of
            (code, prec, recall, code_acts) tuples.
        re_token_matches: number of distinct (example_id, token_pos_id)
            positions that match the regex pattern.
    """
    matches = search_re(re_pattern, tokens_text, at_odd_even=at_odd_even)
    # Several character-level matches can land on the same token; deduplicate.
    token_pos_ids = np.unique(
        [byte_id_to_token_pos_id(match, token_byte_pos) for match in matches],
        axis=0,
    )
    re_token_matches = len(token_pos_ids)

    codebook_wise_codes = {}
    for cb_name, cb in tqdm(cb_acts.items()):
        base_cb_name = convert_to_base_name(cb_name, gcb=gcb)
        codes, prec, recall, code_acts = get_code_precision_and_recall(
            token_pos_ids,
            cb,
            cb_act_counts=act_count_ft_tkns[base_cb_name],
        )
        # Restrict to the first `topk` ranked codes, then apply the precision cut.
        candidates = np.arange(min(topk, len(codes)))
        selected = candidates[prec[:topk] > prec_threshold]
        codebook_wise_codes[base_cb_name] = list(
            zip(codes[selected], prec[selected], recall[selected], code_acts[selected])
        )
    return codebook_wise_codes, re_token_matches
def get_neurons_from_pattern(
    re_pattern,
    tokens_text,
    token_byte_pos,
    neuron_acts_by_ex,
    neuron_sorted_acts,
    recall_threshold,
    at_odd_even=-1,
):
    """Fetch the highest precision neuron that activates on a given regex pattern.

    The regex matches are converted to token positions and handed to
    `get_neuron_precision_and_recall`; the activation threshold for the
    neurons is determined by `recall_threshold`.

    Args:
        re_pattern: regex pattern to search for.
        tokens_text: list of example texts of a dataset.
        token_byte_pos: numpy array of shape (num_examples, seq_len) where
            `token_byte_pos[example_id][token_pos]` is the byte position of
            the token at `token_pos` in the example with `example_id`.
        neuron_acts_by_ex: activations of all the attention and mlp output
            neurons on a dataset with shape
            (num_examples, seq_len, num_layers, 2, dim_size); the size-2
            dimension separates attention from mlp neurons.
        neuron_sorted_acts: the same activations rearranged to
            (num_layers, 2, dim_size, num_examples * seq_len) and sorted
            along the last dimension.
        recall_threshold: recall threshold for the neurons (this determines
            their activation threshold).
        at_odd_even: forwarded to `search_re`; -1 (default) keeps every
            match, otherwise only matches whose start offset satisfies
            `offset % 2 == at_odd_even` are kept.

    Returns:
        best_prec, best_neuron_acts, best_neuron_idx: exactly as returned by
            `get_neuron_precision_and_recall` for the matched positions.
        re_token_matches: number of distinct (example_id, token_pos_id)
            positions that match the regex pattern.
    """
    matches = search_re(re_pattern, tokens_text, at_odd_even=at_odd_even)
    # Several character-level matches can land on the same token; deduplicate.
    token_pos_ids = np.unique(
        [byte_id_to_token_pos_id(match, token_byte_pos) for match in matches],
        axis=0,
    )
    neuron_result = get_neuron_precision_and_recall(
        token_pos_ids,
        recall_threshold,
        neuron_acts_by_ex,
        neuron_sorted_acts,
    )
    best_prec, best_neuron_acts, best_neuron_idx = neuron_result
    return best_prec, best_neuron_acts, best_neuron_idx, len(token_pos_ids)
def compare_codes_with_neurons(
    best_codes_info,
    tokens_text,
    token_byte_pos,
    neuron_acts_by_ex,
    neuron_sorted_acts,
    at_odd_even=-1,
):
    """Compare codes with the highest precision neurons on the regex pattern of the code.

    For every entry of `best_codes_info`, the best neuron for the entry's
    regex is searched at the entry's recall, and its precision is compared
    with the entry's own precision.

    Args:
        best_codes_info: list of CodeInfo objects; the `regex`, `recall` and
            `prec` attributes of each are read.
        tokens_text: list of example texts of a dataset.
        token_byte_pos: numpy array of shape (num_examples, seq_len) where
            `token_byte_pos[example_id][token_pos]` is the byte position of
            the token at `token_pos` in the example with `example_id`.
        neuron_acts_by_ex: numpy array (asserted) of activations of all the
            attention and mlp output neurons on a dataset with shape
            (num_examples, seq_len, num_layers, 2, dim_size); the size-2
            dimension separates attention from mlp neurons.
        neuron_sorted_acts: the same activations rearranged to
            (num_layers, 2, dim_size, num_examples * seq_len) and sorted
            along the last dimension.
        at_odd_even: forwarded to the regex search; -1 (default) keeps every
            match, otherwise only matches whose start offset satisfies
            `offset % 2 == at_odd_even` are kept.

    Returns:
        codes_better_than_neurons: fraction of codes that have strictly
            higher precision than the highest precision neuron on the regex
            pattern of the code.
        code_best_precs: array of the precision of each code in `best_codes_info`.
        neuron_best_prec: array of the highest neuron precision found for
            each code's regex pattern.
    """
    assert isinstance(neuron_acts_by_ex, np.ndarray)

    per_code_results = [
        get_neurons_from_pattern(
            code_info.regex,
            tokens_text,
            token_byte_pos,
            neuron_acts_by_ex,
            neuron_sorted_acts,
            code_info.recall,
            at_odd_even=at_odd_even,
        )
        for code_info in tqdm(best_codes_info)
    ]
    # Transpose the per-code 4-tuples; only the precisions are needed below.
    neuron_best_prec, _neuron_acts, _neuron_idxs, _match_counts = zip(
        *per_code_results, strict=True
    )

    neuron_best_prec = np.array(neuron_best_prec)
    code_best_precs = np.array([code_info.prec for code_info in best_codes_info])
    codes_better_than_neurons = code_best_precs > neuron_best_prec
    return codes_better_than_neurons.mean(), code_best_precs, neuron_best_prec