# model-editing/models.py

import logging
import re

import torch
import torch.nn as nn
import transformers

from nn import FixableDropout  # local nn.py module, not torch.nn
from utils import scr  # local helper returning the cache directory

LOG = logging.getLogger(__name__)


class CastModule(nn.Module):
    """Wraps a module, casting its inputs and (optionally) outputs to fixed dtypes."""

    def __init__(self, module: nn.Module, in_cast: torch.dtype = torch.float32, out_cast: torch.dtype = None):
        super().__init__()

        self.underlying = module
        self.in_cast = in_cast
        self.out_cast = out_cast

    def cast(self, obj, dtype):
        if dtype is None:
            return obj

        if isinstance(obj, torch.Tensor):
            return obj.to(dtype)
        else:
            return obj

    def forward(self, *args, **kwargs):
        args = tuple(self.cast(a, self.in_cast) for a in args)
        kwargs = {k: self.cast(v, self.in_cast) for k, v in kwargs.items()}
        outputs = self.underlying(*args, **kwargs)
        if isinstance(outputs, torch.Tensor):
            outputs = self.cast(outputs, self.out_cast)
        elif isinstance(outputs, tuple):
            outputs = tuple(self.cast(o, self.out_cast) for o in outputs)
        else:
            raise RuntimeError(f"Not sure how to cast type {type(outputs)}")
        return outputs

    def extra_repr(self):
        return f"in_cast: {self.in_cast}\nout_cast: {self.out_cast}"


class BertClassifier(torch.nn.Module):
    def __init__(self, model_name, hidden_dim=768):
        super().__init__()
        if model_name.startswith("bert"):
            self.model = transformers.BertModel.from_pretrained(model_name, cache_dir=scr())
        else:
            self.model = transformers.AutoModel.from_pretrained(model_name, cache_dir=scr())
        self.classifier = torch.nn.Linear(hidden_dim, 1)

    @property
    def config(self):
        return self.model.config

    def forward(self, *args, **kwargs):
        # The underlying encoder doesn't accept a "labels" kwarg, so drop it.
        filtered_kwargs = {k: v for k, v in kwargs.items() if k != "labels"}
        model_output = self.model(*args, **filtered_kwargs)
        if "pooler_output" in model_output.keys():
            pred = self.classifier(model_output.pooler_output)
        else:
            # Fall back to the [CLS] position for encoders without a pooler.
            pred = self.classifier(model_output.last_hidden_state[:, 0])

        if "output_hidden_states" in kwargs and kwargs["output_hidden_states"]:
            last_hidden_state = model_output.last_hidden_state
            return pred, last_hidden_state
        else:
            return pred
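

# A minimal usage sketch (not part of the original file), assuming the standard
# transformers tokenizer for "bert-base-uncased":
def _demo_bert_classifier():
    tok = transformers.AutoTokenizer.from_pretrained("bert-base-uncased", cache_dir=scr())
    clf = BertClassifier("bert-base-uncased")
    batch = tok(["an example sentence"], return_tensors="pt")
    logits = clf(**batch)  # shape (1, 1): one scalar score per input
    # With output_hidden_states=True, forward also returns the last hidden state.
    logits, last_hidden = clf(**batch, output_hidden_states=True)
    return logits, last_hidden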


def replace_dropout(model):
    # Swap every nn.Dropout for a FixableDropout so masks can be frozen and resampled.
    for m in model.modules():
        for n, c in m.named_children():
            if isinstance(c, nn.Dropout):
                setattr(m, n, FixableDropout(c.p))

    def resample(m, seed=None):
        for c in m.children():
            if hasattr(c, "resample"):
                c.resample(seed)
            else:
                resample(c, seed)

    # Bind `resample` to the model so callers can write model.resample_dropout().
    model.resample_dropout = resample.__get__(model)
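

# A minimal sketch (not part of the original file) of the intended workflow,
# assuming the local FixableDropout holds its mask fixed until resample() is
# called: two passes share one mask, then a fresh mask is drawn.
def _demo_fixable_dropout():
    model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
    replace_dropout(model)
    model.train()
    model.resample_dropout(seed=0)   # fix a dropout mask
    x = torch.randn(2, 4)
    y1, y2 = model(x), model(x)      # identical outputs under a fixed mask
    model.resample_dropout(seed=1)   # draw a new mask
    return y1, y2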


def get_model(config):
    if config.model.class_name == "BertClassifier":
        model = BertClassifier(config.model.name)
    else:
        ModelClass = getattr(transformers, config.model.class_name)
        LOG.info(f"Loading model class {ModelClass} with name {config.model.name} from cache dir {scr()}")
        model = ModelClass.from_pretrained(config.model.name, cache_dir=scr())

    if config.model.pt is not None:
        LOG.info(f"Loading model initialization from {config.model.pt}")
        state_dict = torch.load(config.model.pt, map_location="cpu")

        try:
            model.load_state_dict(state_dict)
        except RuntimeError:
            LOG.info("Default load failed; stripping prefix and trying again.")
            state_dict = {re.sub(r"^model\.", "", k): v for k, v in state_dict.items()}  # escape the dot
            model.load_state_dict(state_dict)

        LOG.info("Loaded model initialization")

    if config.dropout is not None:
        n_reset = 0
        for m in model.modules():
            if isinstance(m, nn.Dropout):
                m.p = config.dropout
                n_reset += 1

            if hasattr(m, "dropout"):  # Required for BART, which applies F.dropout with a float rate attribute
                if isinstance(m.dropout, float):
                    m.dropout = config.dropout
                    n_reset += 1

            if hasattr(m, "activation_dropout"):  # Required for BART, which applies F.dropout with a float rate attribute
                if isinstance(m.activation_dropout, float):
                    m.activation_dropout = config.dropout
                    n_reset += 1

        LOG.info(f"Set {n_reset} dropout modules to p={config.dropout}")

    param_names = [n for n, _ in model.named_parameters()]
    bad_inner_params = [p for p in config.model.inner_params if p not in param_names]
    if len(bad_inner_params) != 0:
        raise ValueError(f"Params {bad_inner_params} do not exist in model of type {type(model)}.")

    if config.no_grad_layers is not None:
        if config.half:
            model.bfloat16()

        def upcast(mod):
            modlist = None
            for child in mod.children():
                if isinstance(child, nn.ModuleList):
                    assert modlist is None, f"Found multiple modlists for {mod}"
                    modlist = child

            if modlist is None:
                raise RuntimeError("Couldn't find a ModuleList child")

            LOG.info(f"Setting {len(modlist) - config.no_grad_layers} modules to full precision, with autocasting")
            modlist[config.no_grad_layers:].to(torch.float32)
            modlist[config.no_grad_layers] = CastModule(modlist[config.no_grad_layers])
            modlist[-1] = CastModule(modlist[-1], in_cast=torch.float32, out_cast=torch.bfloat16)

        parents = []
        if hasattr(model, "transformer"):
            parents.append(model.transformer)
        if hasattr(model, "encoder"):
            parents.append(model.encoder)
        if hasattr(model, "decoder"):
            parents.append(model.decoder)
        if hasattr(model, "model"):
            parents.extend([model.model.encoder, model.model.decoder])

        for t in parents:
            t.no_grad_layers = config.no_grad_layers
            if config.half and config.alg != "rep":
                upcast(t)

        if config.half and config.alg != "rep":
            idxs = []
            for p in config.model.inner_params:
                for comp in p.split('.'):
                    if comp.isdigit():
                        idxs.append(int(comp))

            # CastModule wraps the boundary layers, so their parameters now live
            # under an extra ".underlying" attribute; rewrite inner_params to match.
            max_idx, min_idx = str(max(idxs)), str(config.no_grad_layers)
            for pidx, p in enumerate(config.model.inner_params):
                comps = p.split('.')
                if max_idx in comps or min_idx in comps:
                    index = comps.index(max_idx) if max_idx in comps else comps.index(min_idx)
                    comps.insert(index + 1, 'underlying')
                    new_p = '.'.join(comps)
                    LOG.info(f"Replacing config.model.inner_params[{pidx}] '{p}' -> '{new_p}'")
                    config.model.inner_params[pidx] = new_p

    return model
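

# A small illustration (not part of the original file) of the inner_params
# rewrite above: once CastModule wraps a layer, that layer's parameters sit
# under an extra ".underlying" attribute. The parameter path is hypothetical.
def _demo_underlying_rename():
    comps = "transformer.h.23.mlp.c_fc.weight".split('.')
    comps.insert(comps.index("23") + 1, "underlying")
    assert '.'.join(comps) == "transformer.h.23.underlying.mlp.c_fc.weight"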


def get_tokenizer(config):
    tok_name = config.model.tokenizer_name if config.model.tokenizer_name is not None else config.model.name
    return getattr(transformers, config.model.tokenizer_class).from_pretrained(tok_name, cache_dir=scr())
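

# A minimal sketch (not part of the original file) of the config fields that
# get_model() and get_tokenizer() read, using SimpleNamespace stand-ins for the
# project's real config objects; all values here are illustrative assumptions.
def _demo_get_model_and_tokenizer():
    from types import SimpleNamespace
    model_cfg = SimpleNamespace(
        class_name="GPT2LMHeadModel",  # any transformers class name, or "BertClassifier"
        name="gpt2",                   # checkpoint passed to from_pretrained
        pt=None,                       # optional path to a saved state_dict
        inner_params=["transformer.h.11.mlp.c_fc.weight"],
        tokenizer_name=None,           # falls back to model.name
        tokenizer_class="GPT2Tokenizer",
    )
    config = SimpleNamespace(model=model_cfg, dropout=0.1,
                             no_grad_layers=None, half=False, alg="mend")
    return get_model(config), get_tokenizer(config)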


if __name__ == '__main__':
    m = BertClassifier("bert-base-uncased")
    m(torch.arange(5)[None, :])
    import pdb; pdb.set_trace()