Spaces:
Runtime error
Runtime error
import transformers | |
import torch | |
import torch.nn as nn | |
import re | |
import logging | |
from nn import FixableDropout | |
from utils import scr | |
# Module-level logger, named after this module.
LOG = logging.getLogger(name=__name__)
class CastModule(nn.Module):
    """Wrap a module so that its inputs and outputs are cast to fixed dtypes.

    Floating-point tensor inputs (positional and keyword) are cast to
    ``in_cast`` before calling the wrapped module. Floating-point tensor
    outputs are cast to ``out_cast`` afterwards. Integer/bool tensors and
    non-tensor values pass through untouched.

    Args:
        module: The module to wrap. It is stored as ``self.underlying``, so its
            parameter names gain an ``underlying.`` component.
        in_cast: dtype for floating-point tensor inputs, or ``None`` to leave
            inputs unchanged.
        out_cast: dtype for floating-point tensor outputs, or ``None`` to leave
            outputs unchanged.
    """

    def __init__(self, module: nn.Module, in_cast: torch.dtype = torch.float32, out_cast: torch.dtype = None):
        super().__init__()
        self.underlying = module
        self.in_cast = in_cast
        self.out_cast = out_cast

    def cast(self, obj, dtype):
        """Return ``obj`` converted to ``dtype`` if it is a floating-point tensor.

        Anything else is returned as-is: non-tensors, and integer/bool tensors
        such as token ids, position ids or boolean masks. Converting those to
        a float dtype would break indexing and masking inside the wrapped
        module. A ``dtype`` of ``None`` disables casting.
        """
        if dtype is None:
            return obj
        if isinstance(obj, torch.Tensor) and obj.is_floating_point():
            return obj.to(dtype)
        return obj

    def forward(self, *args, **kwargs):
        """Call the wrapped module with cast inputs and return cast outputs.

        Raises:
            RuntimeError: if the wrapped module returns something other than a
                tensor, tuple or list.
        """
        args = tuple(self.cast(a, self.in_cast) for a in args)
        kwargs = {k: self.cast(v, self.in_cast) for k, v in kwargs.items()}
        outputs = self.underlying(*args, **kwargs)
        if isinstance(outputs, torch.Tensor):
            return self.cast(outputs, self.out_cast)
        # Only top-level elements are cast; nested containers pass through.
        if isinstance(outputs, tuple):
            return tuple(self.cast(o, self.out_cast) for o in outputs)
        if isinstance(outputs, list):
            return [self.cast(o, self.out_cast) for o in outputs]
        raise RuntimeError(f"Not sure how to cast type {type(outputs)}")

    def extra_repr(self):
        return f"in_cast: {self.in_cast}\nout_cast: {self.out_cast}"
class BertClassifier(torch.nn.Module):
    """Pretrained transformer encoder topped with a single-logit linear head.

    Args:
        model_name: Hugging Face model identifier. Names starting with
            ``"bert"`` are loaded with ``transformers.BertModel``; every other
            name goes through ``transformers.AutoModel``.
        hidden_dim: Input width of the linear head. It must equal the last
            dimension of the encoder features the head is applied to.
    """

    def __init__(self, model_name, hidden_dim=768):
        super().__init__()
        encoder_cls = transformers.BertModel if model_name.startswith("bert") else transformers.AutoModel
        self.model = encoder_cls.from_pretrained(model_name, cache_dir=scr())
        self.classifier = torch.nn.Linear(hidden_dim, 1)

    def config(self):
        """Return the wrapped encoder's ``config`` (a method, not a property)."""
        return self.model.config

    def forward(self, *args, **kwargs):
        """Score inputs with the linear head.

        ``labels`` is removed before calling the encoder; every other argument
        is forwarded unchanged. The head reads ``pooler_output`` when the
        encoder output has that key. Otherwise it reads index 0 along dim 1 of
        ``last_hidden_state``, presumably the first token of a
        (batch, seq, hidden) tensor.

        If ``output_hidden_states`` is passed and truthy, returns
        ``(logits, last_hidden_state)``; otherwise returns only the logits.
        """
        encoder_kwargs = dict(kwargs)
        encoder_kwargs.pop("labels", None)
        encoded = self.model(*args, **encoder_kwargs)

        if "pooler_output" in encoded.keys():
            features = encoded.pooler_output
        else:
            features = encoded.last_hidden_state[:, 0]
        logits = self.classifier(features)

        if not kwargs.get("output_hidden_states"):
            return logits
        return logits, encoded.last_hidden_state
def replace_dropout(model):
    """Swap every ``nn.Dropout`` child in ``model`` for a ``FixableDropout``.

    Each replacement keeps the original drop probability ``p``. Afterwards
    ``model.resample_dropout(seed=None)`` is attached as a bound method.
    It walks the module tree depth-first and calls ``resample(seed)`` on
    every child that defines one, without descending into those children.
    """

    def _resample(module, seed=None):
        for child in module.children():
            if not hasattr(child, "resample"):
                _resample(child, seed)
                continue
            child.resample(seed)

    for parent in model.modules():
        for child_name, child in parent.named_children():
            if not isinstance(child, nn.Dropout):
                continue
            setattr(parent, child_name, FixableDropout(child.p))

    # Bind via the descriptor protocol (a real bound method, not a closure) so
    # copy.deepcopy(model) rebinds it to the copied model.
    model.resample_dropout = _resample.__get__(model)
def _set_dropout(model, p):
    """Set every dropout probability in ``model`` to ``p``; return how many were set.

    Covers ``nn.Dropout`` modules plus any float ``dropout`` /
    ``activation_dropout`` attribute on a module. Architectures such as BART
    keep those floats and call ``F.dropout`` directly rather than owning an
    ``nn.Dropout`` module.
    """
    n_reset = 0
    for m in model.modules():
        if isinstance(m, nn.Dropout):
            m.p = p
            n_reset += 1
        for attr in ("dropout", "activation_dropout"):
            if isinstance(getattr(m, attr, None), float):
                setattr(m, attr, p)
                n_reset += 1
    return n_reset


def _layer_stacks(model):
    """Return the submodules of ``model`` that hold its layer lists.

    Looks at ``model.transformer``, ``model.encoder`` and ``model.decoder``,
    then at ``encoder``/``decoder`` on a nested ``model.model`` (e.g. BART,
    or ``BertClassifier``'s wrapped encoder). Only attributes that exist are
    returned.
    """
    parents = [getattr(model, name) for name in ("transformer", "encoder", "decoder") if hasattr(model, name)]
    if hasattr(model, "model"):
        inner = model.model
        # Not every wrapped model has both halves (BertModel has no decoder).
        parents.extend(getattr(inner, name) for name in ("encoder", "decoder") if hasattr(inner, name))
    if not parents:
        LOG.warning(f"no_grad_layers is set but no layer stack was found on {type(model).__name__}")
    return parents


def _upcast_trailing_layers(stack, no_grad_layers):
    """Run layers ``no_grad_layers:`` of ``stack``'s only ``ModuleList`` in float32.

    The rest of the model is expected to be bfloat16 already. The first
    upcast layer is wrapped in a ``CastModule`` that casts its inputs to
    float32. The last layer is wrapped in one that casts its outputs back to
    bfloat16.

    Raises:
        RuntimeError: if ``stack`` has no ``ModuleList`` child.
    """
    modlist = None
    for child in stack.children():
        if isinstance(child, nn.ModuleList):
            assert modlist is None, f"Found multiple modlists for {stack}"
            modlist = child
    if modlist is None:
        raise RuntimeError("Couldn't find a ModuleList child")

    LOG.info(f"Setting {len(modlist) - no_grad_layers} modules to full precision, with autocasting")
    modlist[no_grad_layers:].to(torch.float32)
    if modlist[no_grad_layers] is modlist[-1]:
        # A single trainable layer: wrap it once so its parameter names gain
        # exactly one ``underlying`` component, matching the inner-param rename.
        modlist[-1] = CastModule(modlist[-1], in_cast=torch.float32, out_cast=torch.bfloat16)
    else:
        modlist[no_grad_layers] = CastModule(modlist[no_grad_layers])
        modlist[-1] = CastModule(modlist[-1], in_cast=torch.float32, out_cast=torch.bfloat16)


def _rename_cast_inner_params(inner_params, no_grad_layers):
    """Insert ``underlying`` into inner-param names that live in a wrapped layer.

    Mutates ``inner_params`` in place. The wrapped layers are taken to be
    index ``no_grad_layers`` and the highest layer index mentioned in
    ``inner_params``.
    """
    idxs = [int(comp) for p in inner_params for comp in p.split('.') if comp.isdigit()]
    if not idxs:
        # No name carries a layer index, so none can refer to a wrapped layer.
        return
    max_idx, min_idx = str(max(idxs)), str(no_grad_layers)
    for pidx, p in enumerate(inner_params):
        comps = p.split('.')
        if max_idx in comps or min_idx in comps:
            index = comps.index(max_idx) if max_idx in comps else comps.index(min_idx)
            comps.insert(index + 1, 'underlying')
            new_p = '.'.join(comps)
            LOG.info(f"Replacing config.model.inner_params[{pidx}] '{p}' -> '{new_p}'")
            inner_params[pidx] = new_p


def get_model(config):
    """Build the model described by ``config`` and prepare it for editing.

    Steps:
      1. Instantiate ``BertClassifier``, or the ``transformers`` class named
         by ``config.model.class_name``, from ``config.model.name``.
      2. If ``config.model.pt`` is set, load that state dict. If the first
         load fails, retry with a leading literal ``model.`` stripped from
         every key.
      3. If ``config.dropout`` is set, override every dropout probability.
      4. Check that every name in ``config.model.inner_params`` is a
         parameter of the model.
      5. If ``config.no_grad_layers`` is set, record it on each layer stack.
         When ``config.half`` is also set, cast the model to bfloat16. Unless
         ``config.alg == "rep"``, the layers from ``no_grad_layers`` on are
         then kept in float32 behind ``CastModule`` wrappers, and
         ``config.model.inner_params`` is rewritten in place to the wrapped
         names.

    Returns:
        The constructed ``torch.nn.Module``.

    Raises:
        ValueError: if an inner param does not exist in the model.
        RuntimeError: if the state dict still fails to load after stripping
            the prefix, or a layer stack has no ``ModuleList`` child.
    """
    if config.model.class_name == "BertClassifier":
        model = BertClassifier(config.model.name)
    else:
        ModelClass = getattr(transformers, config.model.class_name)
        LOG.info(f"Loading model class {ModelClass} with name {config.model.name} from cache dir {scr()}")
        model = ModelClass.from_pretrained(config.model.name, cache_dir=scr())

    if config.model.pt is not None:
        LOG.info(f"Loading model initialization from {config.model.pt}")
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(config.model.pt, map_location="cpu")
        try:
            model.load_state_dict(state_dict)
        except RuntimeError:
            LOG.info("Default load failed; stripping prefix and trying again.")
            prefix = "model."
            state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}
            model.load_state_dict(state_dict)
        LOG.info("Loaded model initialization")

    if config.dropout is not None:
        n_reset = _set_dropout(model, config.dropout)
        LOG.info(f"Set {n_reset} dropout modules to p={config.dropout}")

    # Validated before any CastModule wrapping changes parameter names.
    param_names = {n for n, _ in model.named_parameters()}
    bad_inner_params = [p for p in config.model.inner_params if p not in param_names]
    if bad_inner_params:
        raise ValueError(f"Params {bad_inner_params} do not exist in model of type {type(model)}.")

    if config.no_grad_layers is not None:
        if config.half:
            model.bfloat16()
        upcast = config.half and config.alg != "rep"
        for t in _layer_stacks(model):
            t.no_grad_layers = config.no_grad_layers
            if upcast:
                _upcast_trailing_layers(t, config.no_grad_layers)
        if upcast:
            _rename_cast_inner_params(config.model.inner_params, config.no_grad_layers)

    return model
def get_tokenizer(config):
    """Load the tokenizer described by ``config.model``.

    The tokenizer class is looked up by name on ``transformers``
    (``config.model.tokenizer_class``). It is loaded from
    ``config.model.tokenizer_name``, or from ``config.model.name`` when no
    tokenizer name is configured.
    """
    model_cfg = config.model
    tok_name = model_cfg.tokenizer_name
    if tok_name is None:
        tok_name = model_cfg.name
    tokenizer_cls = getattr(transformers, model_cfg.tokenizer_class)
    return tokenizer_cls.from_pretrained(tok_name, cache_dir=scr())
if __name__ == '__main__':
    # Manual smoke test: run a 5-token dummy sequence through a pretrained
    # BERT classifier and print the logits. This replaces a leftover
    # interactive debugger breakpoint that halted the script.
    m = BertClassifier("bert-base-uncased")
    with torch.no_grad():
        pred = m(torch.arange(5)[None, :])
    print(tuple(pred.shape), pred)