"""Util functions for codebook features."""
import pathlib
import re
import typing
from dataclasses import dataclass
from functools import partial
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
from termcolor import colored
from tqdm import tqdm
@dataclass
class CodeInfo:
"""Dataclass for codebook info."""
code: int
layer: int
head: Optional[int]
cb_at: Optional[str] = None
# for patching interventions
pos: Optional[int] = None
code_pos: Optional[int] = -1
# for description & regex-based interpretation
description: Optional[str] = None
regex: Optional[str] = None
prec: Optional[float] = None
recall: Optional[float] = None
num_acts: Optional[int] = None
def __post_init__(self):
"""Convert to appropriate types."""
self.code = int(self.code)
self.layer = int(self.layer)
        if self.head is not None:
            self.head = int(self.head)
        if self.pos is not None:
            self.pos = int(self.pos)
        if self.code_pos is not None and self.code_pos != "append":
            # "append" is a sentinel used by `parse_topic_codes_string` and
            # `patch_in_codes`, so it is left as a string here
            self.code_pos = int(self.code_pos)
        if self.prec is not None:
            self.prec = float(self.prec)
            assert 0 <= self.prec <= 1
        if self.recall is not None:
            self.recall = float(self.recall)
            assert 0 <= self.recall <= 1
        if self.num_acts is not None:
            self.num_acts = int(self.num_acts)
def check_description_info(self):
"""Check if the regex info is present."""
assert self.num_acts is not None and self.description is not None
if self.regex is not None:
assert self.prec is not None and self.recall is not None
def __repr__(self):
"""Return the string representation."""
repr = f"CodeInfo(code={self.code}, layer={self.layer}, head={self.head}, cb_at={self.cb_at}"
if self.pos is not None or self.code_pos is not None:
repr += f", pos={self.pos}, code_pos={self.code_pos}"
if self.description is not None:
repr += f", description={self.description}"
if self.regex is not None:
repr += f", regex={self.regex}, prec={self.prec}, recall={self.recall}"
if self.num_acts is not None:
repr += f", num_acts={self.num_acts}"
repr += ")"
return repr
@classmethod
def from_str(cls, code_txt, *args, **kwargs):
"""Extract code info fields from string."""
code_txt = code_txt.strip().lower()
code_txt = code_txt.split(", ")
code_txt = dict(txt.split(": ") for txt in code_txt)
return cls(*args, **code_txt, **kwargs)
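
# Illustrative use of `CodeInfo.from_str`; the assumed input is a comma-separated
# "field: value" string, matching what `parse_topic_codes_string` below passes in:
#
#     info = CodeInfo.from_str("code: 123, layer: 4, head: 2", cb_at="attn")
#     # -> CodeInfo(code=123, layer=4, head=2, cb_at=attn, pos=None, code_pos=-1)
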
@dataclass
class ModelInfoForWebapp:
"""Model info for webapp."""
model_name: str
pretrained_path: str
dataset_name: str
num_codes: int
cb_at: str
gcb: str
n_layers: int
n_heads: Optional[int] = None
seed: int = 42
max_samples: int = 2000
def __post_init__(self):
"""Convert to correct types."""
self.num_codes = int(self.num_codes)
self.n_layers = int(self.n_layers)
if self.n_heads == "None":
self.n_heads = None
elif self.n_heads is not None:
self.n_heads = int(self.n_heads)
self.seed = int(self.seed)
self.max_samples = int(self.max_samples)
@classmethod
def load(cls, path):
"""Parse model info from path."""
path = pathlib.Path(path)
with open(path / "info.txt", "r") as f:
lines = f.readlines()
lines = dict(line.strip().split(": ") for line in lines)
return cls(**lines)
def save(self, path):
"""Save model info to path."""
path = pathlib.Path(path)
with open(path / "info.txt", "w") as f:
for k, v in self.__dict__.items():
f.write(f"{k}: {v}\n")
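
# The `info.txt` written by `ModelInfoForWebapp.save` is a plain "key: value" file,
# one field per line (values below are illustrative only):
#
#     model_name: gpt2
#     pretrained_path: <path to the codebook model checkpoint>
#     dataset_name: wikitext
#     num_codes: 10000
#     ...
#
# `ModelInfoForWebapp.load(path)` parses the same file back into the dataclass.
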
def logits_to_pred(logits, tokenizer, k=5):
"""Convert logits to top-k predictions."""
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
probs = sorted_logits.softmax(dim=-1)
topk_preds = [tokenizer.convert_ids_to_tokens(e) for e in sorted_indices[:, -1, :k]]
topk_preds = [
tokenizer.convert_tokens_to_string([e]) for batch in topk_preds for e in batch
]
return [(topk_preds[i], probs[:, -1, i].item()) for i in range(len(topk_preds))]
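
# Illustrative call (the `.item()` above assumes batch size 1); `cb_model` and
# `input_ids` are assumed names for the codebook model and a tokenised prompt:
#
#     logits = cb_model(input_ids)  # shape (1, seq_len, vocab_size)
#     top5 = logits_to_pred(logits, tokenizer, k=5)
#     # -> [(token_string, probability), ...] for the last position
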
def features_to_tokens(cb_key, cb_acts, num_codes, code=None):
    """Return the (example index, token position) pairs each codebook feature activates on."""
codebook_ids = cb_acts[cb_key]
if code is None:
features_tokens = [[] for _ in range(num_codes)]
for i in tqdm(range(codebook_ids.shape[0])):
for j in range(codebook_ids.shape[1]):
for k in range(codebook_ids.shape[2]):
features_tokens[codebook_ids[i, j, k]].append((i, j))
else:
idx0, idx1, _ = np.where(codebook_ids == code)
features_tokens = list(zip(idx0, idx1))
return features_tokens
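
# Illustrative usage, assuming `cb_acts` maps a codebook hook key to an array of
# codebook ids with shape (num_examples, seq_len, codes_per_position):
#
#     hook_key = get_cb_hook_key("attn", layer_idx=3)
#     acts_of_code = features_to_tokens(hook_key, cb_acts, num_codes=10_000, code=123)
#     # -> [(example_idx, token_pos), ...] positions where code 123 fired
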
def color_str(s: str, html: bool, color: Optional[str] = None):
"""Color the string for html or terminal."""
if html:
color = "DeepSkyBlue" if color is None else color
return f"<span style='color:{color}'>{s}</span>"
else:
color = "light_cyan" if color is None else color
return colored(s, color)
def color_tokens_tokfsm(tokens, color_idx, html=False):
    """Separate states with a dash and highlight the tokens in `color_idx`."""
ret_string = ""
itr_over_color_idx = 0
tokens_enumerate = enumerate(tokens)
if tokens[0] == "<|endoftext|>":
next(tokens_enumerate)
if color_idx[0] == 0:
itr_over_color_idx += 1
for i, c in tokens_enumerate:
if i % 2 == 1:
ret_string += "-"
if itr_over_color_idx < len(color_idx) and i == color_idx[itr_over_color_idx]:
ret_string += color_str(c, html)
itr_over_color_idx += 1
else:
ret_string += c
return ret_string
def color_tokens(tokens, color_idx, n=3, html=False):
"""Color the tokens in color_idx."""
ret_string = ""
last_colored_token_idx = -1
for i in color_idx:
c_str = tokens[i]
if i <= last_colored_token_idx + 2 * n + 1:
ret_string += "".join(tokens[last_colored_token_idx + 1 : i])
else:
ret_string += "".join(
tokens[last_colored_token_idx + 1 : last_colored_token_idx + n + 1]
)
ret_string += " ... "
ret_string += "".join(tokens[i - n : i])
ret_string += color_str(c_str, html)
last_colored_token_idx = i
ret_string += "".join(
tokens[
last_colored_token_idx + 1 : min(last_colored_token_idx + n, len(tokens))
]
)
return ret_string
def prepare_example_print(
example_id,
example_tokens,
tokens_to_color,
html,
color_fn=color_tokens,
):
"""Format example to print."""
example_output = color_str(example_id, html, "green")
example_output += (
": "
+ color_fn(example_tokens, tokens_to_color, html=html)
+ ("<br>" if html else "\n")
)
return example_output
def print_token_activations_of_code(
code_act_by_pos,
tokens,
is_fsm=False,
n=3,
max_examples=100,
randomize=False,
html=False,
return_example_list=False,
):
"""Print the context with the tokens that a code activates on.
Args:
code_act_by_pos: list of (example_id, token_pos_id) tuples specifying
the token positions that a code activates on in a dataset.
tokens: list of tokens of a dataset.
is_fsm: whether the dataset is the TokFSM dataset.
n: context to print around each side of a token that the code activates on.
max_examples: maximum number of examples to print.
randomize: whether to randomize the order of examples.
html: Format the printing style for html or terminal.
return_example_list: whether to return the printed string by examples or as a single string.
Returns:
string of all examples formatted if `return_example_list` is False otherwise
list of (example_string, num_tokens_colored) tuples for each example.
"""
if randomize:
raise NotImplementedError("Randomize not yet implemented.")
indices = range(len(code_act_by_pos))
print_output = [] if return_example_list else ""
curr_ex = code_act_by_pos[0][0]
total_examples = 0
tokens_to_color = []
color_fn = color_tokens_tokfsm if is_fsm else partial(color_tokens, n=n)
for idx in indices:
if total_examples > max_examples:
break
i, j = code_act_by_pos[idx]
if i != curr_ex and curr_ex >= 0:
# got new example so print the previous one
curr_ex_output = prepare_example_print(
curr_ex,
tokens[curr_ex],
tokens_to_color,
html,
color_fn,
)
total_examples += 1
if return_example_list:
print_output.append((curr_ex_output, len(tokens_to_color)))
else:
print_output += curr_ex_output
curr_ex = i
tokens_to_color = []
tokens_to_color.append(j)
curr_ex_output = prepare_example_print(
curr_ex,
tokens[curr_ex],
tokens_to_color,
html,
color_fn,
)
if return_example_list:
print_output.append((curr_ex_output, len(tokens_to_color)))
else:
print_output += curr_ex_output
print_output += color_str("*" * 50, html, "green")
total_examples += 1
return print_output
def print_token_activations_of_codes(
ft_tkns,
tokens,
is_fsm=False,
n=3,
start=0,
stop=1000,
indices=None,
max_examples=100,
freq_filter=None,
randomize=False,
html=False,
return_example_list=False,
):
"""Print the tokens for the codebook features."""
indices = list(range(start, stop)) if indices is None else indices
num_tokens = len(tokens) * len(tokens[0])
codes, token_act_freqs, token_acts = [], [], []
for i in indices:
tkns_of_code = ft_tkns[i]
freq = (len(tkns_of_code), 100 * len(tkns_of_code) / num_tokens)
if freq_filter is not None and freq[1] > freq_filter:
continue
codes.append(i)
token_act_freqs.append(freq)
if len(tkns_of_code) > 0:
tkn_acts = print_token_activations_of_code(
tkns_of_code,
tokens,
is_fsm,
n=n,
max_examples=max_examples,
randomize=randomize,
html=html,
return_example_list=return_example_list,
)
token_acts.append(tkn_acts)
else:
token_acts.append("")
return codes, token_act_freqs, token_acts
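
# Illustrative pipeline tying the helpers above together; `cb_acts`, `tokens` and
# `hook_key` are assumed to hold the cached codebook ids, the tokenised dataset
# and a codebook hook key respectively:
#
#     ft_tkns = features_to_tokens(hook_key, cb_acts, num_codes=10_000)
#     codes, freqs, acts = print_token_activations_of_codes(
#         ft_tkns, tokens, n=3, start=0, stop=50, max_examples=5
#     )
#     print(acts[0])  # activating contexts of the first reported code
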
def patch_in_codes(run_cb_ids, hook, pos, code, code_pos=None):
    """Patch `code` into the codebook activations `run_cb_ids` at position `pos`."""
pos = slice(None) if pos is None else pos
code_pos = slice(None) if code_pos is None else code_pos
    if code_pos == "append":
        assert pos == slice(None)
        # padding the code dimension appends `code` at every position,
        # so no further index assignment is needed
        return F.pad(run_cb_ids, (0, 1), mode="constant", value=code)
    if isinstance(pos, typing.Iterable):
        for p in pos:
            run_cb_ids[:, p, code_pos] = code
    else:
        run_cb_ids[:, pos, code_pos] = code
return run_cb_ids
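
# `patch_in_codes` is registered as a forward hook on a codebook layer; the hook
# function is built with `partial` and keyed by `get_cb_hook_key`, which is exactly
# how `run_model_fn_with_codes` below wires it up, e.g.:
#
#     hook_fn = partial(patch_in_codes, pos=None, code=123, code_pos=-1)
#     fwd_hooks = [(get_cb_hook_key("attn", layer_idx=3), hook_fn)]
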
def get_cb_hook_key(cb_at: str, layer_idx: int, gcb_idx: Optional[int] = None):
"""Get the layer name used to store hooks/cache."""
comp_name = "attn" if "attn" in cb_at else "mlp"
if gcb_idx is None:
return f"blocks.{layer_idx}.{comp_name}.codebook_layer.hook_codebook_ids"
else:
return f"blocks.{layer_idx}.{comp_name}.codebook_layer.codebook.{gcb_idx}.hook_codebook_ids"
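
# Example hook keys produced by `get_cb_hook_key`:
#
#     get_cb_hook_key("attn", 3)    -> "blocks.3.attn.codebook_layer.hook_codebook_ids"
#     get_cb_hook_key("mlp", 5, 2)  -> "blocks.5.mlp.codebook_layer.codebook.2.hook_codebook_ids"
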
def run_model_fn_with_codes(
input,
cb_model,
fn_name,
fn_kwargs=None,
list_of_code_infos=(),
):
"""Run the `cb_model`'s `fn_name` method while activating the codes in `list_of_code_infos`.
Common use case includes running the `run_with_cache` method while activating the codes.
For running the `generate` method, use `generate_with_codes` instead.
"""
if fn_kwargs is None:
fn_kwargs = {}
hook_fns = [
partial(patch_in_codes, pos=tupl.pos, code=tupl.code, code_pos=tupl.code_pos)
for tupl in list_of_code_infos
]
fwd_hooks = [
(get_cb_hook_key(tupl.cb_at, tupl.layer, tupl.head), hook_fns[i])
for i, tupl in enumerate(list_of_code_infos)
]
cb_model.reset_hook_kwargs()
with cb_model.hooks(fwd_hooks, [], True, False) as hooked_model:
ret = hooked_model.__getattribute__(fn_name)(input, **fn_kwargs)
return ret
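
# Illustrative call mirroring the `run_with_cache` use case mentioned in the
# docstring; `cb_model` and `input_ids` are assumed names:
#
#     info = CodeInfo(code=123, layer=3, head=None, cb_at="mlp")
#     out, cache = run_model_fn_with_codes(
#         input_ids, cb_model, "run_with_cache", list_of_code_infos=[info]
#     )
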
def generate_with_codes(
input,
cb_model,
list_of_code_infos=(),
tokfsm=None,
generate_kwargs=None,
):
"""Sample from the language model while activating the codes in `list_of_code_infos`."""
gen = run_model_fn_with_codes(
input,
cb_model,
"generate",
generate_kwargs,
list_of_code_infos,
)
return tokfsm.seq_to_traj(gen) if tokfsm is not None else gen
def JSD(logits1, logits2, pos=-1, reduction="batchmean"):
"""Compute the Jensen-Shannon divergence between two distributions."""
if len(logits1.shape) == 3:
logits1, logits2 = logits1[:, pos, :], logits2[:, pos, :]
probs1 = F.softmax(logits1, dim=-1)
probs2 = F.softmax(logits2, dim=-1)
total_m = (0.5 * (probs1 + probs2)).log()
loss = 0.0
loss += F.kl_div(
total_m,
F.log_softmax(logits1, dim=-1),
log_target=True,
reduction=reduction,
)
loss += F.kl_div(
total_m,
F.log_softmax(logits2, dim=-1),
log_target=True,
reduction=reduction,
)
return 0.5 * loss
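
# Quick sanity check of the JSD above: it is symmetric and zero for identical logits.
#
#     a, b = torch.randn(2, 50), torch.randn(2, 50)
#     assert torch.isclose(JSD(a, a), torch.tensor(0.0), atol=1e-6)
#     assert torch.isclose(JSD(a, b), JSD(b, a))
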
def cb_hook_key_to_info(layer_hook_key: str):
"""Get the layer info from the codebook layer hook key.
Args:
layer_hook_key: the hook key of the codebook layer.
E.g. `blocks.3.attn.codebook_layer.hook_codebook_ids`
Returns:
        comp_name: the name of the component the codebook is applied at.
layer_idx: the layer index.
gcb_idx: the codebook index if the codebook layer is grouped, otherwise None.
"""
layer_search = re.search(r"blocks\.(\d+)\.(\w+)\.", layer_hook_key)
assert layer_search is not None
layer_idx, comp_name = int(layer_search.group(1)), layer_search.group(2)
gcb_idx_search = re.search(r"codebook\.(\d+)", layer_hook_key)
if gcb_idx_search is not None:
gcb_idx = int(gcb_idx_search.group(1))
else:
gcb_idx = None
return comp_name, layer_idx, gcb_idx
def find_code_changes(cache1, cache2, pos=None):
"""Find the codebook codes that are different between the two caches."""
for k in cache1.keys():
if "codebook" in k:
c1 = cache1[k][0, pos]
c2 = cache2[k][0, pos]
            if not torch.all(c1 == c2):
                print(cb_hook_key_to_info(k), c1.tolist(), c2.tolist())
def common_codes_in_cache(cache_codes, threshold=0.0):
    """Get the common codes in the cache along with their frequency (in %)."""
codes, counts = torch.unique(cache_codes, return_counts=True, sorted=True)
counts = counts.float() * 100
counts /= cache_codes.shape[1]
counts, indices = torch.sort(counts, descending=True)
codes = codes[indices]
indices = counts > threshold
codes, counts = codes[indices], counts[indices]
return codes, counts
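
# Illustrative usage, assuming `cache` comes from `run_with_cache` on the codebook
# model and `hook_key` names a codebook layer:
#
#     codes, freqs = common_codes_in_cache(cache[hook_key], threshold=5.0)
#     # codes sorted by how often they fire; freqs in percent, normalised by the
#     # sequence dimension as above
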
def parse_topic_codes_string(
info_str: str,
pos: Optional[int] = None,
code_append: Optional[bool] = False,
**code_info_kwargs,
):
"""Parse the topic codes string."""
code_info_strs = info_str.strip().split("\n")
code_info_strs = [e.strip() for e in code_info_strs if e]
topic_codes = []
layer, head = None, None
if code_append is None:
code_pos = None
else:
code_pos = "append" if code_append else -1
for code_info_str in code_info_strs:
topic_codes.append(
CodeInfo.from_str(
code_info_str,
pos=pos,
code_pos=code_pos,
**code_info_kwargs,
)
)
if code_append is None or code_append:
continue
if layer == topic_codes[-1].layer and head == topic_codes[-1].head:
code_pos -= 1 # type: ignore
else:
code_pos = -1
topic_codes[-1].code_pos = code_pos
layer, head = topic_codes[-1].layer, topic_codes[-1].head
return topic_codes
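
# Illustrative input, matching the "field: value" lines parsed by `CodeInfo.from_str`:
#
#     topic_str = """code: 123, layer: 4, head: 2
#     code: 456, layer: 4, head: 2"""
#     codes = parse_topic_codes_string(topic_str, cb_at="attn")
#     # consecutive codes on the same (layer, head) get decreasing code_pos
#     # (-1, -2, ...) unless `code_append` is set
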
def find_similar_codes(cb_model, code_info, n=8):
"""Find the `n` most similar codes to the given code using cosine similarity.
Useful for finding related codes for interpretability.
"""
codebook = cb_model.get_codebook(code_info)
    device = codebook.weight.device
    code = codebook(torch.tensor(code_info.code, device=device))
    logits = torch.matmul(code, codebook.weight.T)
    _, indices = torch.topk(logits, n)
    assert indices[0] == code_info.code
    assert torch.allclose(logits[indices[0]], torch.tensor(1.0, device=device))
return indices[1:], logits[indices[1:]].tolist()
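
# Illustrative usage; the dot products above are cosine similarities only if the
# codebook vectors are unit-normalised, which the final assert relies on:
#
#     info = CodeInfo(code=123, layer=3, head=None, cb_at="mlp")
#     similar_codes, sims = find_similar_codes(cb_model, info, n=8)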