#!/usr/bin/env python3
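"""Migrate a legacy ``TunedLensOld`` checkpoint to the current ``TunedLens`` format.

The script loads an old-format lens, copies its per-layer translators into a
freshly initialized ``TunedLens``, fuzz-tests the two lenses against each other
on random hidden states, and saves the migrated lens to disk.
"""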
import argparse
import inspect
import json
from copy import deepcopy
from logging import warning as warn
from pathlib import Path
from typing import Generator, Optional, Union

import torch as th
from huggingface_hub import model_info
from tqdm import tqdm
from transformers import AutoModelForCausalLM, PreTrainedModel
from transformers.models.bloom.modeling_bloom import BloomBlock

from tuned_lens.load_artifacts import load_lens_artifacts
from tuned_lens.model_surgery import get_final_norm, get_transformer_layers
from tuned_lens.nn import TunedLens
from tuned_lens.stats.distance import js_divergence
def instantiate_layer(model_config, layer_idx: int, model_type: str) -> th.nn.Module:
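    """Construct a fresh, randomly initialized decoder block of the given model type."""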
if model_type == "bloom":
from transformers.models.bloom.modeling_bloom import BloomBlock
return _BloomBlockWrapper(BloomBlock(model_config)) # type: ignore[arg-type]
if model_type == "gpt_neo":
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoBlock
return GPTNeoBlock(model_config, layer_idx)
if model_type == "gpt_neox":
from transformers.models.gpt_neox.modeling_gpt_neox import (
GPTNeoXLayer,
)
return GPTNeoXLayer(model_config) # type: ignore[arg-type]
if model_type == "gpt2":
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
return GPT2Block(model_config, layer_idx) # type: ignore[arg-type]
if model_type == "opt":
from transformers.models.opt.modeling_opt import OPTDecoderLayer
return OPTDecoderLayer(model_config) # type: ignore[arg-type]
else:
raise ValueError(f"Unknown model type '{model_type}'")
def maybe_wrap(layer: th.nn.Module) -> th.nn.Module:
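    """Wrap BloomBlock layers so they can be called with a single hidden-state tensor."""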
return _BloomBlockWrapper(layer) if isinstance(layer, BloomBlock) else layer
# Very annoying that we have to do this. See https://bit.ly/3XSQ7W6 for context on
# what we're doing here.
class _BloomBlockWrapper(th.nn.Module):
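    """Adapter that calls a BloomBlock with just a hidden-state tensor.

    It builds the causal mask and ALiBi tensor internally, so the block can be
    used inside a plain ``nn.Sequential`` such as ``extra_layers``.
    """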
def __init__(self, block: BloomBlock):
super().__init__()
self.block = block
def forward(self, x: th.Tensor) -> th.Tensor:
from transformers.models.bloom.modeling_bloom import (
BloomModel,
build_alibi_tensor,
)
batch_size, seq_len, _ = x.shape
dummy_mask = x.new_ones([batch_size, seq_len])
# Causal mask isn't created inside the block itself, so we have to do it here.
# Weirdly _prepare_attn_mask doesn't depend on `self` at all but is still an
# instance method for some reason, so we pass `None` as the first argument.
causal_mask = BloomModel._prepare_attn_mask(
None, dummy_mask, (batch_size, seq_len), 0 # type: ignore[arg-type]
)
alibi = build_alibi_tensor(dummy_mask, self.block.num_heads, x.dtype)
h, *_ = self.block(x, alibi, causal_mask)
return h
class TunedLensOld(th.nn.Module):
"""A tuned lens for decoding hidden states into logits."""
layer_norm: th.nn.LayerNorm
unembedding: th.nn.Linear
extra_layers: th.nn.Sequential
layer_translators: th.nn.ModuleList
def __init__(
self,
model: Optional[PreTrainedModel] = None,
*,
bias: bool = True,
extra_layers: int = 0,
include_input: bool = True,
reuse_unembedding: bool = True,
# Used when saving and loading the lens
model_config: Optional[dict] = None,
d_model: Optional[int] = None,
num_layers: Optional[int] = None,
vocab_size: Optional[int] = None,
):
"""Create a TunedLensOld.
Args:
            model : A pretrained model from the transformers library you wish to inspect.
bias : Whether to include a bias term in the translator layers.
extra_layers : The number of extra layers to apply to the hidden states
before decoding into logits.
include_input : Whether to include a lens that decodes the word embeddings.
            reuse_unembedding : Whether to reuse the unembedding matrix from the model.
model_config : The config of the model. Used for saving and loading.
            d_model : The model's hidden size. Used for saving and loading.
num_layers : The number of layers in the model. Used for saving and loading.
vocab_size : The size of the vocabulary. Used for saving and loading.
        Raises:
            ValueError: If neither a model nor all of d_model, num_layers, and
                vocab_size are provided.
"""
super().__init__()
self.extra_layers = th.nn.Sequential()
        if (model is None) == (
            d_model is None or num_layers is None or vocab_size is None
        ):
raise ValueError(
"Must provide either a model or d_model, num_layers, and vocab_size"
)
# Initializing from scratch without a model
if not model:
assert d_model and num_layers and vocab_size
self.layer_norm = th.nn.LayerNorm(d_model)
self.unembedding = th.nn.Linear(d_model, vocab_size, bias=False)
# Use HuggingFace methods to get decoder layers
else:
assert not (d_model or num_layers or vocab_size)
d_model = model.config.hidden_size
num_layers = model.config.num_hidden_layers
vocab_size = model.config.vocab_size
assert isinstance(d_model, int) and isinstance(vocab_size, int)
            model_config = model.config.to_dict()  # noqa: F841 (used via locals() below)
# Currently we convert the decoder to full precision
self.unembedding = deepcopy(model.get_output_embeddings()).float()
if ln := get_final_norm(model):
self.layer_norm = deepcopy(ln).float()
else:
self.layer_norm = th.nn.Identity()
if extra_layers:
_, layers = get_transformer_layers(model)
self.extra_layers.extend(
[maybe_wrap(layer) for layer in layers[-extra_layers:]]
)
# Save config for later
config_keys = set(inspect.getfullargspec(TunedLensOld).kwonlyargs)
self.config = {k: v for k, v in locals().items() if k in config_keys}
del model_config
# Try to prevent finetuning the decoder
assert d_model and num_layers
self.layer_norm.requires_grad_(False)
self.unembedding.requires_grad_(False)
        out_features = d_model if reuse_unembedding else vocab_size
        translator = th.nn.Linear(d_model, out_features, bias=bias)
        if not reuse_unembedding:
            # Each layer gets its own unembedding, initialized from the model's.
            translator.weight.data = self.unembedding.weight.data.clone()
        else:
            # Zero-init so each translator starts out as the identity (residual) map.
            translator.weight.data.zero_()
        if translator.bias is not None:
            translator.bias.data.zero_()
        self.add_module("input_translator", translator if include_input else None)
# Don't include the final layer
num_layers -= 1
self.layer_translators = th.nn.ModuleList(
[deepcopy(translator) for _ in range(num_layers)]
)
def __getitem__(self, item: int) -> th.nn.Module:
"""Get the probe module at the given index."""
if isinstance(self.input_translator, th.nn.Module):
if item == 0:
return self.input_translator
else:
item -= 1
return self.layer_translators[item]
def __iter__(self) -> Generator[th.nn.Module, None, None]:
"""Get iterator over the translators within the lens."""
if isinstance(self.input_translator, th.nn.Module):
yield self.input_translator
yield from self.layer_translators
@classmethod
def load(cls, resource_id: str, **kwargs) -> "TunedLensOld":
"""Load a tuned lens from a or hugging face hub.
Args:
resource_id : The path to the directory containing the config and checkpoint
or the name of the model on the hugging face hub.
**kwargs : Additional arguments to pass to torch.load.
Returns:
A TunedLensOld instance.
"""
config_path, ckpt_path = load_lens_artifacts(resource_id)
# Load config
with open(config_path, "r") as f:
config = json.load(f)
# Load parameters
state = th.load(ckpt_path, **kwargs)
        # Backwards compatibility: we really need to stop renaming things.
keys = list(state.keys())
for key in keys:
for old_key in ["probe", "adapter"]:
if old_key in key:
warn(
f"Loading a checkpoint with a '{old_key}' key. "
"This is deprecated and may be removed in a future version. "
)
new_key = key.replace(old_key, "translator")
state[new_key] = state.pop(key)
# Drop unrecognized config keys
unrecognized = set(config) - set(inspect.getfullargspec(cls).kwonlyargs)
for key in unrecognized:
warn(f"Ignoring config key '{key}'")
del config[key]
lens = cls(**config)
if num_extras := config.get("extra_layers"):
# This is sort of a hack but AutoConfig doesn't appear to have a from_dict
# for some reason.
from transformers.models.auto import CONFIG_MAPPING
            model_conf_dict = config.get("model_config")
            assert model_conf_dict, "Need a 'model_config' entry to load extra layers"
            model_conf_dict.pop("torch_dtype", None)
model_type = model_conf_dict["model_type"]
config_cls = CONFIG_MAPPING[model_type]
model_config = config_cls.from_dict(model_conf_dict)
lens.extra_layers = th.nn.Sequential(
*[
instantiate_layer(
model_config, model_config.num_hidden_layers - i - 1, model_type
)
for i in range(num_extras)
]
)
lens.load_state_dict(state)
return lens
def save(
self,
path: Union[Path, str],
ckpt: str = "params.pt",
config: str = "config.json",
) -> None:
"""Save the lens to a directory.
Args:
path : The path to the directory to save the lens to.
ckpt : The name of the checkpoint file to save the parameters to.
config : The name of the config file to save the config to.
"""
path = Path(path)
path.mkdir(exist_ok=True, parents=True)
th.save(self.state_dict(), path / ckpt)
with open(path / config, "w") as f:
json.dump(self.config, f)
def normalize_(self):
"""Canonicalize the transforms by centering their weights and biases."""
for linear in self:
assert isinstance(linear, th.nn.Linear)
A, b = linear.weight.data, linear.bias.data
A -= A.mean(dim=0, keepdim=True)
b -= b.mean()
def transform_hidden(self, h: th.Tensor, idx: int) -> th.Tensor:
"""Transform hidden state from layer `idx`."""
if not self.config["reuse_unembedding"]:
raise RuntimeError("TunedLensOld.transform_hidden requires reuse_unembedding")
# Note that we add the translator output residually, in contrast to the formula
# in the paper. By parametrizing it this way we ensure that weight decay
# regularizes the transform toward the identity, not the zero transformation.
return h + self[idx](h)
def to_logits(self, h: th.Tensor) -> th.Tensor:
"""Decode a hidden state into logits."""
h = self.extra_layers(h)
while isinstance(h, tuple):
h, *_ = h
return self.unembedding(self.layer_norm(h))
def forward(self, h: th.Tensor, idx: int) -> th.Tensor:
"""Transform and then decode the hidden states into logits."""
# Sanity check to make sure we don't finetune the decoder
# if any(p.requires_grad for p in self.parameters(recurse=False)):
# raise RuntimeError("Make sure to freeze the decoder")
# We're learning a separate unembedding for each layer
if not self.config["reuse_unembedding"]:
h_ = self.layer_norm(h)
return self[idx](h_)
h = self.transform_hidden(h, idx)
return self.to_logits(h)
def __len__(self) -> int:
"""Return the number of layer translators in the lens."""
N = len(self.layer_translators)
if self.input_translator:
N += 1
return N
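# Example invocation (the values below are the argparse defaults; adjust as needed):
#   python lens_migration.py --model gpt2 --resource-id gpt2 --output-dir lens/gpt2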
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="gpt2")
parser.add_argument("--resource-id", type=str, default="gpt2")
parser.add_argument("--output-dir", type=str, default="lens/gpt2")
args = parser.parse_args()
model = AutoModelForCausalLM.from_pretrained(args.model)
revision = model_info(args.model).sha
model.eval()
model.requires_grad_(False)
device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
print("Loading old lens")
tuned_lens_old = TunedLensOld.load(args.resource_id, map_location=device)
print("Initializing new lens")
tuned_lens = TunedLens.from_model(
        model, bias=tuned_lens_old.config["bias"], revision=revision
)
for i in tqdm(range(len(tuned_lens_old)), desc="Copying parameters"):
tuned_lens[i].load_state_dict(tuned_lens_old[i].state_dict())
tuned_lens = tuned_lens.to(device)
tuned_lens_old = tuned_lens_old.to(device)
model = model.to(device)
    # Fuzz the new lens against the old one
with th.no_grad():
for i in tqdm(range(len(tuned_lens)), desc="Fuzzing layers"):
for _ in range(10):
a = th.randn(1, 1, tuned_lens.config.d_model, device=device)
logits_new = tuned_lens(a, i)
logits_old = tuned_lens_old(a, i)
log_ps_new = logits_new.log_softmax(-1)
log_ps_old = logits_old.log_softmax(-1)
print("js div", js_divergence(log_ps_new, log_ps_old))
                assert th.allclose(
                    log_ps_new, log_ps_old, atol=1e-7
                ), (log_ps_new - log_ps_old).abs().max()
print("Saving new lens to", args.output_dir)
tuned_lens.to(th.device("cpu")).save(args.output_dir)