Spaces:
Running
Running
| """Functions to help with searching codes using regex.""" | |
| import pickle | |
| import re | |
| import numpy as np | |
| import torch | |
| from tqdm import tqdm | |
def load_dataset_cache(cache_base_path):
    """Read the dataset cache arrays stored under `cache_base_path`.

    `cache_base_path` is used as a plain string prefix (the file names are
    appended directly), so it should normally end with a path separator.

    Args:
        cache_base_path: string prefix of the cache files.

    Returns:
        (tokens_str, tokens_text, token_byte_pos) tuple with the contents of
        `tokens_str.npy`, `tokens_text.npy` and `token_byte_pos.npy`.
    """
    file_names = ("tokens_str.npy", "tokens_text.npy", "token_byte_pos.npy")
    # The generator is consumed in order, so files are read in the listed order.
    tokens_str, tokens_text, token_byte_pos = (
        np.load(cache_base_path + file_name) for file_name in file_names
    )
    return tokens_str, tokens_text, token_byte_pos
def load_code_search_cache(cache_base_path):
    """Read the code-search cache files stored under `cache_base_path`.

    `cache_base_path` is used as a plain string prefix (the file names are
    appended directly), so it should normally end with a path separator.

    NOTE: both the pickle files and `metrics.npy` (loaded with
    `allow_pickle=True`) are unpickled, which can execute arbitrary code.
    Only point this at cache files from a trusted source.

    Args:
        cache_base_path: string prefix of the cache files.

    Returns:
        (cb_acts, act_count_ft_tkns, metrics) tuple with the contents of
        `cb_acts.pkl`, `act_count_ft_tkns.pkl` and the single object stored
        in `metrics.npy`.
    """

    def _unpickle(file_name):
        with open(cache_base_path + file_name, "rb") as handle:
            return pickle.load(handle)

    # `.item()` unwraps the single Python object stored in the array.
    metrics = np.load(cache_base_path + "metrics.npy", allow_pickle=True).item()
    cb_acts = _unpickle("cb_acts.pkl")
    act_count_ft_tkns = _unpickle("act_count_ft_tkns.pkl")
    return cb_acts, act_count_ft_tkns, metrics
def search_re(re_pattern, tokens_text, at_odd_even=-1):
    """Get list of (example_id, char_pos) where `re_pattern` matches in `tokens_text`.

    If the pattern contains at least one capturing group, the reported position
    is the start of group 1; otherwise it is the start of the whole match.
    Matches whose reported span is empty (including matches in which group 1
    did not participate) are skipped.

    Args:
        re_pattern: regex pattern (string) to search for.
        tokens_text: list of example texts.
        at_odd_even: limits matches by the parity of the match start offset.
            -1 (default): do not limit matches.
            0: keep only matches starting at even offsets (`pos % 2 == 0`).
            1: keep only matches starting at odd offsets (`pos % 2 == 1`).
            This is useful for the TokFSM dataset when searching for states
            since the first token of states are always at even positions.

    Returns:
        list of (example_id, char_pos) tuples, ordered by example and then by
        position within the example.
    """
    assert at_odd_even in [-1, 0, 1], f"Invalid at_odd_even: {at_odd_even}"
    compiled = re.compile(re_pattern)
    # Ask the compiled pattern how many capturing groups it has instead of
    # looking for a literal "(": escaped parentheses, non-capturing groups and
    # lookarounds contain "(" but define no group 1.
    group = 1 if compiled.groups else 0
    res = []
    for i, text in enumerate(tokens_text):
        for match in compiled.finditer(text):
            start, end = match.span(group)
            if start == end:
                continue
            if at_odd_even != -1 and start % 2 != at_odd_even:
                continue
            res.append((i, start))
    return res
def byte_id_to_token_pos_id(example_byte_id, token_byte_pos):
    """Map a character offset inside an example to a token position.

    Used to turn the start offset of a regex match into a token index.

    Args:
        example_byte_id: (example_id, byte_id) tuple where byte_id is a
            character's position in the example text.
        token_byte_pos: per-example sequences of token byte positions, indexed
            as `token_byte_pos[example_id][token_pos]`. Each row must be sorted
            in ascending order for the binary search to be meaningful.

    Returns:
        (example_id, token_pos_id) tuple, where token_pos_id is the number of
        entries in the example's row that are less than or equal to byte_id.
    """
    example_id, byte_id = example_byte_id
    boundaries = token_byte_pos[example_id]
    # side="right" places byte_id after any entries equal to it.
    token_pos_id = np.searchsorted(boundaries, byte_id, side="right")
    return example_id, token_pos_id
def get_code_precision_and_recall(
    token_pos_ids, codebook_acts, cb_act_counts=None, min_recall=0.01
):
    """Search for the codes that activate on the given `token_pos_ids`.

    Args:
        token_pos_ids: list of (example_id, token_pos_id) tuples.
        codebook_acts: activations of a codebook on a dataset, indexed as
            `codebook_acts[example_id][token_pos_id]` to get the code ids
            active at that token (documented upstream as a numpy array of
            shape (num_examples, seq_len, k_codebook)).
        cb_act_counts: optional lookup where `cb_act_counts[code]` is the
            number of times the code `code` is activated in the dataset.
            When omitted, precision cannot be computed.
        min_recall: codes whose recall is not strictly greater than this value
            are dropped. Defaults to 0.01.

    Returns:
        codes: numpy array of code ids, sorted by decreasing precision when
            `cb_act_counts` is given and by decreasing recall otherwise.
        prec: numpy array where `prec[i]` is the precision of the code
            `codes[i]` for the given `token_pos_ids` (all zeros when
            `cb_act_counts` is None).
        recall: numpy array where `recall[i]` is the recall of the code
            `codes[i]` for the given `token_pos_ids`.
        code_acts: numpy array where `code_acts[i]` is the number of times
            the code `codes[i]` is activated in the dataset (all zeros when
            `cb_act_counts` is None).
    """
    matched_codes = np.array(
        [
            codebook_acts[example_id][token_pos_id]
            for example_id, token_pos_id in token_pos_ids
        ]
    )
    # Count, for every code, the number of matched tokens it is active on.
    codes, counts = np.unique(matched_codes, return_counts=True)
    recall = counts / len(token_pos_ids)
    keep = recall > min_recall
    codes, counts, recall = codes[keep], counts[keep], recall[keep]
    if cb_act_counts is not None:
        code_acts = np.array([cb_act_counts[code] for code in codes])
        prec = counts / code_acts
        sort_idx = np.argsort(prec)[::-1]
    else:
        # Without dataset-wide counts there is no precision; rank by recall.
        code_acts = np.zeros_like(codes)
        prec = np.zeros_like(codes)
        sort_idx = np.argsort(recall)[::-1]
    codes, prec, recall = codes[sort_idx], prec[sort_idx], recall[sort_idx]
    code_acts = code_acts[sort_idx]
    return codes, prec, recall, code_acts
def get_neuron_precision_and_recall(
    token_pos_ids, recall, neuron_acts_by_ex, neuron_sorted_acts
):
    """Get the neuron with the highest precision at the given recall on `token_pos_ids`.

    For every neuron, the activation threshold is chosen so that the top
    `int(recall * num_matches)` activations on the matched tokens (at least
    one) are at or above it. Precision is that count divided by the number of
    dataset activations at or above the same threshold.

    Args:
        token_pos_ids: non-empty list of (example_id, token_pos_id) tuples from
            a dataset over which the best neuron is to be found.
        recall: recall threshold for the neurons (this determines their
            activation threshold). Expected to lie in (0, 1].
        neuron_acts_by_ex: numpy array or torch tensor of activations of all the
            attention and mlp output neurons on a dataset with shape
            (num_examples, seq_len, num_layers, 2, dim_size). The fourth
            dimension is 2 because neurons from both attention and mlp are
            considered.
        neuron_sorted_acts: numpy array or torch tensor of sorted activations
            with shape (num_layers, 2, dim_size, num_examples * seq_len),
            i.e. `neuron_acts_by_ex` with its first two dimensions flattened
            into the last one and sorted in ascending order along it.

    Returns:
        best_prec: highest precision amongst all the neurons (0-d torch tensor).
        best_neuron_acts: numpy array of shape (num_activations, 2) holding the
            (example_id, token_pos_id) positions in the whole dataset where the
            best neuron's activation is at or above its threshold.
        best_neuron_idx: tuple of ints (layer, is_mlp, neuron_id) where `layer`
            is the layer number, `is_mlp` is 0 if the neuron is from attention
            and 1 if the neuron is from mlp, and `neuron_id` is the neuron's
            index in the layer.
    """
    if isinstance(neuron_acts_by_ex, torch.Tensor):
        neuron_acts_on_pattern = torch.stack(
            [
                neuron_acts_by_ex[example_id, token_pos_id]
                for example_id, token_pos_id in token_pos_ids
            ],
            dim=-1,
        )  # (layers, 2, dim_size, matches)
        neuron_acts_on_pattern = torch.sort(neuron_acts_on_pattern, dim=-1).values
    else:
        neuron_acts_on_pattern = np.stack(
            [
                neuron_acts_by_ex[example_id, token_pos_id]
                for example_id, token_pos_id in token_pos_ids
            ],
            axis=-1,
        )  # (layers, 2, dim_size, matches)
        neuron_acts_on_pattern.sort(axis=-1)
        neuron_acts_on_pattern = torch.from_numpy(neuron_acts_on_pattern)

    # torch.searchsorted needs tensors; accept the numpy form as well.
    if isinstance(neuron_sorted_acts, np.ndarray):
        neuron_sorted_acts = torch.from_numpy(neuron_sorted_acts)

    num_matches = neuron_acts_on_pattern.shape[-1]
    # Clamp to at least one recalled token: with 0 the index below would be
    # `-0 == 0`, i.e. the *smallest* activation, and every precision would be 0.
    num_recalled = max(int(recall * num_matches), 1)
    act_thresh = neuron_acts_on_pattern[:, :, :, -num_recalled]
    assert neuron_sorted_acts.shape[:-1] == act_thresh.shape

    # Index of the first dataset activation >= threshold, per neuron.
    first_at_or_above = torch.searchsorted(
        neuron_sorted_acts, act_thresh.unsqueeze(-1).contiguous()
    ).squeeze(-1)
    # Number of dataset activations at or above the threshold.
    prec_den = neuron_sorted_acts.shape[-1] - first_at_or_above
    prec = num_recalled / prec_den
    assert (
        prec.shape == neuron_acts_on_pattern.shape[:-1]
    ), f"{prec.shape} != {neuron_acts_on_pattern.shape[:-1]}"

    best_neuron_idx = tuple(
        int(i) for i in np.unravel_index(int(prec.argmax()), tuple(prec.shape))
    )
    best_prec = prec[best_neuron_idx]
    best_neuron_act_thresh = act_thresh[best_neuron_idx].item()
    best_neuron_acts = neuron_acts_by_ex[
        :, :, best_neuron_idx[0], best_neuron_idx[1], best_neuron_idx[2]
    ]
    best_neuron_acts = best_neuron_acts >= best_neuron_act_thresh
    # Turn the boolean (num_examples, seq_len) mask into (example, token) rows.
    best_neuron_acts = np.stack(np.where(best_neuron_acts), axis=-1)
    return best_prec, best_neuron_acts, best_neuron_idx
def convert_to_adv_name(name, cb_at, gcb=""):
    """Convert a base codebook name to its advanced (full) name.

    With grouped codebooks, `layer0_head0` becomes `layer0_attn_preproj_gcb0`
    (for `cb_at="attn_preproj"`). Without grouped codebooks, `layer0` becomes
    `layer0_attn_preproj`.

    Args:
        name: base name, `layer{L}_head{H}` for grouped codebooks or
            `layer{L}` otherwise.
        cb_at: name of the component the codebook is attached to.
        gcb: truthy (e.g. "_gcb") for grouped codebooks, "" otherwise.

    Returns:
        the advanced codebook name as a string.
    """
    if gcb:
        layer, head = name.split("_")
        # Drop the "head" prefix and keep only the head number.
        return layer + f"_{cb_at}_gcb" + head[4:]
    # The layer is the first "_"-separated component; it used to be read from
    # a variable that was only bound in the grouped branch.
    layer = name.split("_")[0]
    return layer + "_" + cb_at
def convert_to_base_name(name, gcb=""):
    """Convert an advanced codebook name to its base name.

    `layer0_attn_preproj_gcb0` becomes `layer0_head0`; a name that does not
    contain "gcb" is reduced to its first "_"-separated component.

    Args:
        name: advanced codebook name.
        gcb: accepted for interface compatibility; not used.

    Returns:
        the base name as a string.
    """
    parts = name.split("_")
    layer = parts[0]
    if "gcb" not in name:
        return layer
    # The last component looks like "gcb<N>"; strip its three-letter prefix.
    head = parts[-1][3:]
    return f"{layer}_head{head}"
def get_layer_head_from_base_name(name):
    """Parse the layer and head numbers out of a base name.

    `layer0_head0` gives `(0, 0)`; a name without a "_" (e.g. `layer3`)
    gives `(3, None)`.

    Args:
        name: base codebook name.

    Returns:
        (layer, head) tuple where head is None if the name has no head part.
    """
    parts = name.split("_")
    # Skip the "layer" / "head" prefixes and parse the trailing numbers.
    layer = int(parts[0][len("layer"):])
    head = int(parts[-1][len("head"):]) if len(parts) > 1 else None
    return layer, head
def get_layer_head_from_adv_name(name):
    """Parse the layer and head numbers out of an advanced codebook name.

    For example `layer0_attn_preproj_gcb0` gives `(0, 0)`. The name is first
    reduced with `convert_to_base_name` and then parsed with
    `get_layer_head_from_base_name`.

    Args:
        name: advanced codebook name.

    Returns:
        (layer, head) tuple as returned by `get_layer_head_from_base_name`.
    """
    return get_layer_head_from_base_name(convert_to_base_name(name))
def get_codes_from_pattern(
    re_pattern,
    tokens_text,
    token_byte_pos,
    cb_acts,
    act_count_ft_tkns,
    gcb="",
    topk=5,
    prec_threshold=0.5,
    at_odd_even=-1,
):
    """Find, per codebook, the codes that fire on tokens matching a regex.

    For every codebook, only the `topk` highest-precision codes are considered
    and, of those, only the ones with precision strictly above
    `prec_threshold` are returned.

    Args:
        re_pattern: regex pattern to search for.
        tokens_text: list of example texts of a dataset.
        token_byte_pos: numpy array of shape (num_examples, seq_len) where
            `token_byte_pos[example_id][token_pos]` is the byte position of
            the token at `token_pos` in the example with `example_id`.
        cb_acts: dict mapping codebook name to that codebook's activations.
        act_count_ft_tkns: dict mapping base codebook name to the per-code
            activation counts on the dataset.
        gcb: "_gcb" for grouped codebooks and "" for non-grouped codebooks.
        topk: maximum number of codes to return per codebook.
        prec_threshold: minimum precision required for a code to be returned.
        at_odd_even: forwarded to `search_re`. -1 (default) keeps all matches;
            otherwise only matches whose start offset modulo 2 equals this
            value are kept. This is useful for the TokFSM dataset when
            searching for states.

    Returns:
        codebook_wise_codes: dict of base codebook name to list of
            (code, prec, recall, code_acts) tuples.
        re_token_matches: number of distinct tokens that match the regex
            pattern.
    """
    matches = search_re(re_pattern, tokens_text, at_odd_even=at_odd_even)
    # Several character-level matches can fall in the same token; deduplicate.
    token_pos_ids = np.unique(
        [byte_id_to_token_pos_id(match, token_byte_pos) for match in matches],
        axis=0,
    )
    re_token_matches = len(token_pos_ids)

    codebook_wise_codes = {}
    for cb_name, cb in tqdm(cb_acts.items()):
        base_cb_name = convert_to_base_name(cb_name, gcb=gcb)
        codes, prec, recall, code_acts = get_code_precision_and_recall(
            token_pos_ids,
            cb,
            cb_act_counts=act_count_ft_tkns[base_cb_name],
        )
        # Codes arrive sorted by precision: take the first `topk`, then drop
        # those that do not clear the precision threshold.
        keep = np.arange(min(topk, len(codes)))
        keep = keep[prec[:topk] > prec_threshold]
        codebook_wise_codes[base_cb_name] = list(
            zip(codes[keep], prec[keep], recall[keep], code_acts[keep])
        )
    return codebook_wise_codes, re_token_matches
def get_neurons_from_pattern(
    re_pattern,
    tokens_text,
    token_byte_pos,
    neuron_acts_by_ex,
    neuron_sorted_acts,
    recall_threshold,
    at_odd_even=-1,
):
    """Find the highest-precision neuron on the tokens matching a regex.

    The regex matches are converted to distinct token positions and handed to
    `get_neuron_precision_and_recall`; the neurons' activation threshold is
    determined by `recall_threshold`.

    Args:
        re_pattern: regex pattern to search for.
        tokens_text: list of example texts of a dataset.
        token_byte_pos: numpy array of shape (num_examples, seq_len) where
            `token_byte_pos[example_id][token_pos]` is the byte position of
            the token at `token_pos` in the example with `example_id`.
        neuron_acts_by_ex: activations of all the attention and mlp output
            neurons on a dataset with shape
            (num_examples, seq_len, num_layers, 2, dim_size).
        neuron_sorted_acts: sorted activations of the same neurons with shape
            (num_layers, 2, dim_size, num_examples * seq_len).
        recall_threshold: recall threshold for the neurons (this determines
            their activation threshold).
        at_odd_even: forwarded to `search_re`. -1 (default) keeps all matches;
            otherwise only matches whose start offset modulo 2 equals this
            value are kept. This is useful for the TokFSM dataset when
            searching for states.

    Returns:
        best_prec: highest precision amongst all the neurons on the matched
            tokens.
        best_neuron_acts: array of (example_id, token_pos_id) rows where the
            best neuron's activation reaches its threshold.
        best_neuron_idx: tuple of (layer, is_mlp, neuron_id) identifying the
            best neuron.
        re_token_matches: number of distinct tokens that match the regex
            pattern.
    """
    matches = search_re(re_pattern, tokens_text, at_odd_even=at_odd_even)
    # Several character-level matches can fall in the same token; deduplicate.
    token_pos_ids = np.unique(
        [byte_id_to_token_pos_id(match, token_byte_pos) for match in matches],
        axis=0,
    )
    best_prec, best_neuron_acts, best_neuron_idx = get_neuron_precision_and_recall(
        token_pos_ids,
        recall_threshold,
        neuron_acts_by_ex,
        neuron_sorted_acts,
    )
    re_token_matches = len(token_pos_ids)
    return best_prec, best_neuron_acts, best_neuron_idx, re_token_matches
def compare_codes_with_neurons(
    best_codes_info,
    tokens_text,
    token_byte_pos,
    neuron_acts_by_ex,
    neuron_sorted_acts,
    at_odd_even=-1,
):
    """Compare each code against the best neuron on that code's regex pattern.

    For every entry of `best_codes_info`, the highest-precision neuron is
    searched on the entry's `regex` at the entry's `recall`, and its precision
    is compared with the entry's `prec`.

    Args:
        best_codes_info: list of CodeInfo objects; their `regex`, `recall` and
            `prec` attributes are read.
        tokens_text: list of example texts of a dataset.
        token_byte_pos: numpy array of shape (num_examples, seq_len) where
            `token_byte_pos[example_id][token_pos]` is the byte position of
            the token at `token_pos` in the example with `example_id`.
        neuron_acts_by_ex: numpy array of activations of all the attention and
            mlp output neurons on a dataset with shape
            (num_examples, seq_len, num_layers, 2, dim_size).
        neuron_sorted_acts: sorted activations of the same neurons with shape
            (num_layers, 2, dim_size, num_examples * seq_len).
        at_odd_even: forwarded to the regex search. -1 (default) keeps all
            matches; otherwise only matches whose start offset modulo 2 equals
            this value are kept. This is useful for the TokFSM dataset when
            searching for states.

    Returns:
        codes_better_than_neurons: fraction of codes whose precision is
            strictly higher than that of the best neuron on the code's regex.
        code_best_precs: array of the precision of each code in
            `best_codes_info`.
        neuron_best_prec: array of the best neuron precision for each code's
            regex pattern.
    """
    assert isinstance(neuron_acts_by_ex, np.ndarray)
    per_code_results = [
        get_neurons_from_pattern(
            code_info.regex,
            tokens_text,
            token_byte_pos,
            neuron_acts_by_ex,
            neuron_sorted_acts,
            code_info.recall,
            at_odd_even=at_odd_even,
        )
        for code_info in tqdm(best_codes_info)
    ]
    # Transpose the list of 4-tuples; only the precisions are used below.
    neuron_best_prec, _, _, _ = zip(*per_code_results, strict=True)
    neuron_best_prec = np.array(neuron_best_prec)
    code_best_precs = np.array([code_info.prec for code_info in best_codes_info])
    codes_better_than_neurons = (code_best_precs > neuron_best_prec).mean()
    return codes_better_than_neurons, code_best_precs, neuron_best_prec