import torch
import esm
from argparse import Namespace
import pathlib
import urllib.error  # fix: `import urllib` alone does not guarantee the `error` submodule is bound


def length_to_mask(length, max_len=None, dtype=None):
    """Build a boolean padding mask from a batch of sequence lengths.

    Args:
        length: 1-D integer tensor of shape (B,), one length per sequence.
        max_len: mask width; if None (or 0), falls back to ``length.max()``.
        dtype: optional dtype to cast the mask to; None keeps the boolean result.

    Returns:
        A (B, max_len) tensor where entry (i, j) is truthy iff j < length[i].
    """
    assert len(length.shape) == 1, 'Length shape should be 1 dimensional.'
    max_len = max_len or length.max().item()
    # Broadcast a [0, max_len) ramp against each row's length to mark valid positions.
    mask = torch.arange(max_len, device=length.device,
                        dtype=length.dtype).expand(len(length), max_len) < length.unsqueeze(1)
    if dtype is not None:
        mask = torch.as_tensor(mask, dtype=dtype, device=length.device)
    return mask


def load_model_and_alphabet_core(args_dict, regression_data=None):
    """Load an ESM ProteinBertModel and its alphabet from a checkpoint.

    Args:
        args_dict: path (or file-like object) accepted by ``torch.load``;
            the loaded checkpoint must contain an ``"args"`` namespace with
            an ``arch`` attribute.
        regression_data: unused here; kept for backward-compatible interface.

    Returns:
        ``(model, alphabet)`` tuple.
    """
    checkpoint = torch.load(args_dict)
    alphabet = esm.Alphabet.from_architecture(checkpoint["args"].arch)

    def _strip_decoder_prefix(name):
        # Upgrade older checkpoints: drop a leading "decoder_" from arg names.
        # When "decoder" is absent, "".join(name) re-joins the characters of
        # the string, leaving it unchanged.
        return "".join(name.split("decoder_")[1:] if "decoder" in name else name)

    model_args = {
        _strip_decoder_prefix(name): value
        for name, value in vars(checkpoint["args"]).items()
    }
    model = esm.ProteinBertModel(Namespace(**model_args), alphabet)
    return model, alphabet


def load_hub_workaround(url):
    """Fetch a state dict from ``url`` via torch.hub, with a cache fallback.

    If torch.hub raises RuntimeError (a known PyTorch version issue), the
    already-downloaded file in the hub checkpoint cache is loaded directly.
    HTTP failures are re-raised with a clearer message, chained to the cause.

    Returns:
        The loaded state dict, mapped to CPU.

    Raises:
        Exception: when the URL cannot be fetched due to an HTTP error.
    """
    try:
        data = torch.hub.load_state_dict_from_url(url, progress=False, map_location="cpu")
    except RuntimeError:
        # Pytorch version issue - see https://github.com/pytorch/pytorch/issues/43106
        fn = pathlib.Path(url).name
        data = torch.load(
            f"{torch.hub.get_dir()}/checkpoints/{fn}",
            map_location="cpu",
        )
    except urllib.error.HTTPError as e:
        # Chain the original HTTPError so the root cause stays in the traceback.
        raise Exception(f"Could not load {url}, check your network!") from e
    return data