"""Util functions for codebook features.""" import pathlib import re import typing from dataclasses import dataclass from functools import partial from typing import Optional import numpy as np import torch import torch.nn.functional as F from termcolor import colored from tqdm import tqdm @dataclass class CodeInfo: """Dataclass for codebook info.""" code: int layer: int head: Optional[int] cb_at: Optional[str] = None # for patching interventions pos: Optional[int] = None code_pos: Optional[int] = -1 # for description & regex-based interpretation description: Optional[str] = None regex: Optional[str] = None prec: Optional[float] = None recall: Optional[float] = None num_acts: Optional[int] = None def __post_init__(self): """Convert to appropriate types.""" self.code = int(self.code) self.layer = int(self.layer) if self.head: self.head = int(self.head) if self.pos: self.pos = int(self.pos) if self.code_pos: self.code_pos = int(self.code_pos) if self.prec: self.prec = float(self.prec) assert 0 <= self.prec <= 1 if self.recall: self.recall = float(self.recall) assert 0 <= self.recall <= 1 if self.num_acts: self.num_acts = int(self.num_acts) def check_description_info(self): """Check if the regex info is present.""" assert self.num_acts is not None and self.description is not None if self.regex is not None: assert self.prec is not None and self.recall is not None def __repr__(self): """Return the string representation.""" repr = f"CodeInfo(code={self.code}, layer={self.layer}, head={self.head}, cb_at={self.cb_at}" if self.pos is not None or self.code_pos is not None: repr += f", pos={self.pos}, code_pos={self.code_pos}" if self.description is not None: repr += f", description={self.description}" if self.regex is not None: repr += f", regex={self.regex}, prec={self.prec}, recall={self.recall}" if self.num_acts is not None: repr += f", num_acts={self.num_acts}" repr += ")" return repr @classmethod def from_str(cls, code_txt, *args, **kwargs): """Extract code info fields from string.""" code_txt = code_txt.strip().lower() code_txt = code_txt.split(", ") code_txt = dict(txt.split(": ") for txt in code_txt) return cls(*args, **code_txt, **kwargs) @dataclass class ModelInfoForWebapp: """Model info for webapp.""" model_name: str pretrained_path: str dataset_name: str num_codes: int cb_at: str gcb: str n_layers: int n_heads: Optional[int] = None seed: int = 42 max_samples: int = 2000 def __post_init__(self): """Convert to correct types.""" self.num_codes = int(self.num_codes) self.n_layers = int(self.n_layers) if self.n_heads == "None": self.n_heads = None elif self.n_heads is not None: self.n_heads = int(self.n_heads) self.seed = int(self.seed) self.max_samples = int(self.max_samples) @classmethod def load(cls, path): """Parse model info from path.""" path = pathlib.Path(path) with open(path / "info.txt", "r") as f: lines = f.readlines() lines = dict(line.strip().split(": ") for line in lines) return cls(**lines) def save(self, path): """Save model info to path.""" path = pathlib.Path(path) with open(path / "info.txt", "w") as f: for k, v in self.__dict__.items(): f.write(f"{k}: {v}\n") def logits_to_pred(logits, tokenizer, k=5): """Convert logits to top-k predictions.""" sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) probs = sorted_logits.softmax(dim=-1) topk_preds = [tokenizer.convert_ids_to_tokens(e) for e in sorted_indices[:, -1, :k]] topk_preds = [ tokenizer.convert_tokens_to_string([e]) for batch in topk_preds for e in batch ] return [(topk_preds[i], probs[:, -1, i].item()) for i in range(len(topk_preds))] def features_to_tokens(cb_key, cb_acts, num_codes, code=None): """Return the set of token ids each codebook feature activates on.""" codebook_ids = cb_acts[cb_key] if code is None: features_tokens = [[] for _ in range(num_codes)] for i in tqdm(range(codebook_ids.shape[0])): for j in range(codebook_ids.shape[1]): for k in range(codebook_ids.shape[2]): features_tokens[codebook_ids[i, j, k]].append((i, j)) else: idx0, idx1, _ = np.where(codebook_ids == code) features_tokens = list(zip(idx0, idx1)) return features_tokens def color_str(s: str, html: bool, color: Optional[str] = None): """Color the string for html or terminal.""" if html: color = "DeepSkyBlue" if color is None else color return f"{s}" else: color = "light_cyan" if color is None else color return colored(s, color) def color_tokens_tokfsm(tokens, color_idx, html=False): """Separate states with a dash and color red the tokens in color_idx.""" ret_string = "" itr_over_color_idx = 0 tokens_enumerate = enumerate(tokens) if tokens[0] == "<|endoftext|>": next(tokens_enumerate) if color_idx[0] == 0: itr_over_color_idx += 1 for i, c in tokens_enumerate: if i % 2 == 1: ret_string += "-" if itr_over_color_idx < len(color_idx) and i == color_idx[itr_over_color_idx]: ret_string += color_str(c, html) itr_over_color_idx += 1 else: ret_string += c return ret_string def color_tokens(tokens, color_idx, n=3, html=False): """Color the tokens in color_idx.""" ret_string = "" last_colored_token_idx = -1 for i in color_idx: c_str = tokens[i] if i <= last_colored_token_idx + 2 * n + 1: ret_string += "".join(tokens[last_colored_token_idx + 1 : i]) else: ret_string += "".join( tokens[last_colored_token_idx + 1 : last_colored_token_idx + n + 1] ) ret_string += " ... " ret_string += "".join(tokens[i - n : i]) ret_string += color_str(c_str, html) last_colored_token_idx = i ret_string += "".join( tokens[ last_colored_token_idx + 1 : min(last_colored_token_idx + n, len(tokens)) ] ) return ret_string def prepare_example_print( example_id, example_tokens, tokens_to_color, html, color_fn=color_tokens, ): """Format example to print.""" example_output = color_str(example_id, html, "green") example_output += ( ": " + color_fn(example_tokens, tokens_to_color, html=html) + ("
" if html else "\n") ) return example_output def print_token_activations_of_code( code_act_by_pos, tokens, is_fsm=False, n=3, max_examples=100, randomize=False, html=False, return_example_list=False, ): """Print the context with the tokens that a code activates on. Args: code_act_by_pos: list of (example_id, token_pos_id) tuples specifying the token positions that a code activates on in a dataset. tokens: list of tokens of a dataset. is_fsm: whether the dataset is the TokFSM dataset. n: context to print around each side of a token that the code activates on. max_examples: maximum number of examples to print. randomize: whether to randomize the order of examples. html: Format the printing style for html or terminal. return_example_list: whether to return the printed string by examples or as a single string. Returns: string of all examples formatted if `return_example_list` is False otherwise list of (example_string, num_tokens_colored) tuples for each example. """ if randomize: raise NotImplementedError("Randomize not yet implemented.") indices = range(len(code_act_by_pos)) print_output = [] if return_example_list else "" curr_ex = code_act_by_pos[0][0] total_examples = 0 tokens_to_color = [] color_fn = color_tokens_tokfsm if is_fsm else partial(color_tokens, n=n) for idx in indices: if total_examples > max_examples: break i, j = code_act_by_pos[idx] if i != curr_ex and curr_ex >= 0: # got new example so print the previous one curr_ex_output = prepare_example_print( curr_ex, tokens[curr_ex], tokens_to_color, html, color_fn, ) total_examples += 1 if return_example_list: print_output.append((curr_ex_output, len(tokens_to_color))) else: print_output += curr_ex_output curr_ex = i tokens_to_color = [] tokens_to_color.append(j) curr_ex_output = prepare_example_print( curr_ex, tokens[curr_ex], tokens_to_color, html, color_fn, ) if return_example_list: print_output.append((curr_ex_output, len(tokens_to_color))) else: print_output += curr_ex_output print_output += color_str("*" * 50, html, "green") total_examples += 1 return print_output def print_token_activations_of_codes( ft_tkns, tokens, is_fsm=False, n=3, start=0, stop=1000, indices=None, max_examples=100, freq_filter=None, randomize=False, html=False, return_example_list=False, ): """Print the tokens for the codebook features.""" indices = list(range(start, stop)) if indices is None else indices num_tokens = len(tokens) * len(tokens[0]) codes, token_act_freqs, token_acts = [], [], [] for i in indices: tkns_of_code = ft_tkns[i] freq = (len(tkns_of_code), 100 * len(tkns_of_code) / num_tokens) if freq_filter is not None and freq[1] > freq_filter: continue codes.append(i) token_act_freqs.append(freq) if len(tkns_of_code) > 0: tkn_acts = print_token_activations_of_code( tkns_of_code, tokens, is_fsm, n=n, max_examples=max_examples, randomize=randomize, html=html, return_example_list=return_example_list, ) token_acts.append(tkn_acts) else: token_acts.append("") return codes, token_act_freqs, token_acts def patch_in_codes(run_cb_ids, hook, pos, code, code_pos=None): """Patch in the `code` at `run_cb_ids`.""" pos = slice(None) if pos is None else pos code_pos = slice(None) if code_pos is None else code_pos if code_pos == "append": assert pos == slice(None) run_cb_ids = F.pad(run_cb_ids, (0, 1), mode="constant", value=code) if isinstance(pos, typing.Iterable) or isinstance(pos, typing.Iterable): for p in pos: run_cb_ids[:, p, code_pos] = code else: run_cb_ids[:, pos, code_pos] = code return run_cb_ids def get_cb_hook_key(cb_at: str, layer_idx: int, gcb_idx: Optional[int] = None): """Get the layer name used to store hooks/cache.""" comp_name = "attn" if "attn" in cb_at else "mlp" if gcb_idx is None: return f"blocks.{layer_idx}.{comp_name}.codebook_layer.hook_codebook_ids" else: return f"blocks.{layer_idx}.{comp_name}.codebook_layer.codebook.{gcb_idx}.hook_codebook_ids" def run_model_fn_with_codes( input, cb_model, fn_name, fn_kwargs=None, list_of_code_infos=(), ): """Run the `cb_model`'s `fn_name` method while activating the codes in `list_of_code_infos`. Common use case includes running the `run_with_cache` method while activating the codes. For running the `generate` method, use `generate_with_codes` instead. """ if fn_kwargs is None: fn_kwargs = {} hook_fns = [ partial(patch_in_codes, pos=tupl.pos, code=tupl.code, code_pos=tupl.code_pos) for tupl in list_of_code_infos ] fwd_hooks = [ (get_cb_hook_key(tupl.cb_at, tupl.layer, tupl.head), hook_fns[i]) for i, tupl in enumerate(list_of_code_infos) ] cb_model.reset_hook_kwargs() with cb_model.hooks(fwd_hooks, [], True, False) as hooked_model: ret = hooked_model.__getattribute__(fn_name)(input, **fn_kwargs) return ret def generate_with_codes( input, cb_model, list_of_code_infos=(), tokfsm=None, generate_kwargs=None, ): """Sample from the language model while activating the codes in `list_of_code_infos`.""" gen = run_model_fn_with_codes( input, cb_model, "generate", generate_kwargs, list_of_code_infos, ) return tokfsm.seq_to_traj(gen) if tokfsm is not None else gen def JSD(logits1, logits2, pos=-1, reduction="batchmean"): """Compute the Jensen-Shannon divergence between two distributions.""" if len(logits1.shape) == 3: logits1, logits2 = logits1[:, pos, :], logits2[:, pos, :] probs1 = F.softmax(logits1, dim=-1) probs2 = F.softmax(logits2, dim=-1) total_m = (0.5 * (probs1 + probs2)).log() loss = 0.0 loss += F.kl_div( total_m, F.log_softmax(logits1, dim=-1), log_target=True, reduction=reduction, ) loss += F.kl_div( total_m, F.log_softmax(logits2, dim=-1), log_target=True, reduction=reduction, ) return 0.5 * loss def cb_hook_key_to_info(layer_hook_key: str): """Get the layer info from the codebook layer hook key. Args: layer_hook_key: the hook key of the codebook layer. E.g. `blocks.3.attn.codebook_layer.hook_codebook_ids` Returns: comp_name: the name of the component codebook is appied at. layer_idx: the layer index. gcb_idx: the codebook index if the codebook layer is grouped, otherwise None. """ layer_search = re.search(r"blocks\.(\d+)\.(\w+)\.", layer_hook_key) assert layer_search is not None layer_idx, comp_name = int(layer_search.group(1)), layer_search.group(2) gcb_idx_search = re.search(r"codebook\.(\d+)", layer_hook_key) if gcb_idx_search is not None: gcb_idx = int(gcb_idx_search.group(1)) else: gcb_idx = None return comp_name, layer_idx, gcb_idx def find_code_changes(cache1, cache2, pos=None): """Find the codebook codes that are different between the two caches.""" for k in cache1.keys(): if "codebook" in k: c1 = cache1[k][0, pos] c2 = cache2[k][0, pos] if not torch.all(c1 == c2): print(cb_hook_key_to_info(k), c1.tolist(), c2.tolist()) print(cb_hook_key_to_info(k), c1.tolist(), c2.tolist()) def common_codes_in_cache(cache_codes, threshold=0.0): """Get the common code in the cache.""" codes, counts = torch.unique(cache_codes, return_counts=True, sorted=True) counts = counts.float() * 100 counts /= cache_codes.shape[1] counts, indices = torch.sort(counts, descending=True) codes = codes[indices] indices = counts > threshold codes, counts = codes[indices], counts[indices] return codes, counts def parse_topic_codes_string( info_str: str, pos: Optional[int] = None, code_append: Optional[bool] = False, **code_info_kwargs, ): """Parse the topic codes string.""" code_info_strs = info_str.strip().split("\n") code_info_strs = [e.strip() for e in code_info_strs if e] topic_codes = [] layer, head = None, None if code_append is None: code_pos = None else: code_pos = "append" if code_append else -1 for code_info_str in code_info_strs: topic_codes.append( CodeInfo.from_str( code_info_str, pos=pos, code_pos=code_pos, **code_info_kwargs, ) ) if code_append is None or code_append: continue if layer == topic_codes[-1].layer and head == topic_codes[-1].head: code_pos -= 1 # type: ignore else: code_pos = -1 topic_codes[-1].code_pos = code_pos layer, head = topic_codes[-1].layer, topic_codes[-1].head return topic_codes def find_similar_codes(cb_model, code_info, n=8): """Find the `n` most similar codes to the given code using cosine similarity. Useful for finding related codes for interpretability. """ codebook = cb_model.get_codebook(code_info) device = codebook.weight.device code = codebook(torch.tensor(code_info.code).to(device)) code = code.to(device) logits = torch.matmul(code, codebook.weight.T) _, indices = torch.topk(logits, n) assert indices[0] == code_info.code assert torch.allclose(logits[indices[0]], torch.tensor(1.0)) return indices[1:], logits[indices[1:]].tolist()