|
|
|
|
|
|
|
|
|
|
|
import re |
|
import urllib.error
|
import warnings |
|
from argparse import Namespace |
|
from pathlib import Path |
|
|
|
import torch |
|
|
|
import esm |
|
from esm.model.esm2 import ESM2 |
|
|
|
|
|
def _has_regression_weights(model_name):
    """Return whether we expect / require regression weights;
    currently that is all models except ESM-1v, ESM-IF, and partially trained ESM-2 models.
    """
    return not ("esm1v" in model_name or "esm_if" in model_name or "270K" in model_name or "500K" in model_name)
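
# Illustrative expectations for the helper above (examples only, not exhaustive):
#   _has_regression_weights("esm1b_t33_650M_UR50S")   -> True  (contact-regression weights are released)
#   _has_regression_weights("esm1v_t33_650M_UR90S_1") -> False (no regression weights for ESM-1v)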
|
|
|
|
|
def load_model_and_alphabet(model_name): |
|
if model_name.endswith(".pt"): |
|
return load_model_and_alphabet_local(model_name) |
|
else: |
|
return load_model_and_alphabet_hub(model_name) |
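
# Example usage (illustrative): hub loading requires network access, or a
# checkpoint already cached under torch.hub.get_dir().
#
#     model, alphabet = load_model_and_alphabet("esm2_t12_35M_UR50D")
#     batch_converter = alphabet.get_batch_converter()
#     model.eval()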
|
|
|
|
|
def load_hub_workaround(url):
    try:
        data = torch.hub.load_state_dict_from_url(url, progress=False, map_location="cpu")
    except RuntimeError:
        # Workaround: some PyTorch versions raise a RuntimeError here; fall back to
        # loading the file already downloaded into the hub checkpoints directory.
        fn = Path(url).name
        data = torch.load(
            f"{torch.hub.get_dir()}/checkpoints/{fn}",
            map_location="cpu",
        )
    except urllib.error.HTTPError as e:
        raise Exception(f"Could not load {url}. Did you specify a correct model name?") from e
    return data
|
|
|
|
|
def load_regression_hub(model_name): |
|
url = f"https://dl.fbaipublicfiles.com/fair-esm/regression/{model_name}-contact-regression.pt" |
|
regression_data = load_hub_workaround(url) |
|
return regression_data |
|
|
|
|
|
def _download_model_and_regression_data(model_name): |
|
url = f"https://dl.fbaipublicfiles.com/fair-esm/models/{model_name}.pt" |
|
model_data = load_hub_workaround(url) |
|
if _has_regression_weights(model_name): |
|
regression_data = load_regression_hub(model_name) |
|
else: |
|
regression_data = None |
|
return model_data, regression_data |
|
|
|
|
|
def load_model_and_alphabet_hub(model_name): |
|
model_data, regression_data = _download_model_and_regression_data(model_name) |
|
return load_model_and_alphabet_core(model_name, model_data, regression_data) |
|
|
|
|
|
def load_model_and_alphabet_local(model_location): |
|
"""Load from local path. The regression weights need to be co-located""" |
|
model_location = Path(model_location) |
|
model_data = torch.load(str(model_location), map_location="cpu") |
|
model_name = model_location.stem |
|
if _has_regression_weights(model_name): |
|
regression_location = str(model_location.with_suffix("")) + "-contact-regression.pt" |
|
regression_data = torch.load(regression_location, map_location="cpu") |
|
else: |
|
regression_data = None |
|
return load_model_and_alphabet_core(model_name, model_data, regression_data) |
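
# Example usage (illustrative; the path is hypothetical). For a model that has
# regression weights, e.g. "<dir>/esm1b_t33_650M_UR50S.pt", the file
# "<dir>/esm1b_t33_650M_UR50S-contact-regression.pt" must sit alongside it:
#
#     model, alphabet = load_model_and_alphabet_local("/path/to/esm1b_t33_650M_UR50S.pt")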
|
|
|
|
|
def has_emb_layer_norm_before(model_state):
    """Determine whether layer norm needs to be applied before the encoder"""
    return any(k.startswith("emb_layer_norm_before") for k in model_state)
|
|
|
|
|
def _load_model_and_alphabet_core_v1(model_data): |
|
import esm |
|
|
|
alphabet = esm.Alphabet.from_architecture(model_data["args"].arch) |
|
|
|
if model_data["args"].arch == "roberta_large": |
|
|
|
        # strip "encoder_" / "encoder." / "sentence_encoder." prefixes from
        # checkpoint arg names and state-dict keys
        pra = lambda s: "".join(s.split("encoder_")[1:] if "encoder" in s else s)
|
prs1 = lambda s: "".join(s.split("encoder.")[1:] if "encoder" in s else s) |
|
prs2 = lambda s: "".join( |
|
s.split("sentence_encoder.")[1:] if "sentence_encoder" in s else s |
|
) |
|
model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()} |
|
model_state = {prs1(prs2(arg[0])): arg[1] for arg in model_data["model"].items()} |
|
model_state["embed_tokens.weight"][alphabet.mask_idx].zero_() |
|
model_args["emb_layer_norm_before"] = has_emb_layer_norm_before(model_state) |
|
model_type = esm.ProteinBertModel |
|
|
|
elif model_data["args"].arch == "protein_bert_base": |
|
|
|
|
|
pra = lambda s: "".join(s.split("decoder_")[1:] if "decoder" in s else s) |
|
prs = lambda s: "".join(s.split("decoder.")[1:] if "decoder" in s else s) |
|
model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()} |
|
model_state = {prs(arg[0]): arg[1] for arg in model_data["model"].items()} |
|
model_type = esm.ProteinBertModel |
|
elif model_data["args"].arch == "msa_transformer": |
|
|
|
|
|
pra = lambda s: "".join(s.split("encoder_")[1:] if "encoder" in s else s) |
|
prs1 = lambda s: "".join(s.split("encoder.")[1:] if "encoder" in s else s) |
|
prs2 = lambda s: "".join( |
|
s.split("sentence_encoder.")[1:] if "sentence_encoder" in s else s |
|
) |
|
        # swap "row" and "column" in parameter names (checkpoint and codebase use opposite conventions)
        prs3 = lambda s: s.replace("row", "column") if "row" in s else s.replace("column", "row")
|
model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()} |
|
model_state = {prs1(prs2(prs3(arg[0]))): arg[1] for arg in model_data["model"].items()} |
|
if model_args.get("embed_positions_msa", False): |
|
emb_dim = model_state["msa_position_embedding"].size(-1) |
|
model_args["embed_positions_msa_dim"] = emb_dim |
|
|
|
model_type = esm.MSATransformer |
|
|
|
elif "invariant_gvp" in model_data["args"].arch: |
|
import esm.inverse_folding |
|
|
|
model_type = esm.inverse_folding.gvp_transformer.GVPTransformerModel |
|
model_args = vars(model_data["args"]) |
|
|
|
        def update_name(s):
            # map parameter names from the released checkpoint to the module
            # names used by the GVP-Transformer model in this codebase
            s = s.replace("W_v", "embed_graph.embed_node")
|
s = s.replace("W_e", "embed_graph.embed_edge") |
|
s = s.replace("embed_scores.0", "embed_confidence") |
|
s = s.replace("embed_score.", "embed_graph.embed_confidence.") |
|
s = s.replace("seq_logits_projection.", "") |
|
s = s.replace("embed_ingraham_features", "embed_dihedrals") |
|
s = s.replace("embed_gvp_in_local_frame.0", "embed_gvp_output") |
|
s = s.replace("embed_features_in_local_frame.0", "embed_gvp_input_features") |
|
return s |
|
|
|
model_state = { |
|
update_name(sname): svalue |
|
for sname, svalue in model_data["model"].items() |
|
if "version" not in sname |
|
} |
|
|
|
else: |
|
raise ValueError("Unknown architecture selected") |
|
|
|
model = model_type( |
|
Namespace(**model_args), |
|
alphabet, |
|
) |
|
|
|
return model, alphabet, model_state |
|
|
|
|
|
def _load_model_and_alphabet_core_v2(model_data): |
|
    def upgrade_state_dict(state_dict):
        """Removes prefixes 'encoder.sentence_encoder.' and 'encoder.'."""
        prefixes = ["encoder.sentence_encoder.", "encoder."]
        pattern = re.compile("^" + "|".join(prefixes))
        state_dict = {pattern.sub("", name): param for name, param in state_dict.items()}
        return state_dict
|
|
|
cfg = model_data["cfg"]["model"] |
|
state_dict = model_data["model"] |
|
state_dict = upgrade_state_dict(state_dict) |
|
alphabet = esm.data.Alphabet.from_architecture("ESM-1b") |
|
model = ESM2( |
|
num_layers=cfg.encoder_layers, |
|
embed_dim=cfg.encoder_embed_dim, |
|
attention_heads=cfg.encoder_attention_heads, |
|
alphabet=alphabet, |
|
token_dropout=cfg.token_dropout, |
|
) |
|
return model, alphabet, state_dict |
|
|
|
|
|
def load_model_and_alphabet_core(model_name, model_data, regression_data=None): |
|
if regression_data is not None: |
|
model_data["model"].update(regression_data["model"]) |
|
|
|
if model_name.startswith("esm2"): |
|
model, alphabet, model_state = _load_model_and_alphabet_core_v2(model_data) |
|
else: |
|
model, alphabet, model_state = _load_model_and_alphabet_core_v1(model_data) |
|
|
|
expected_keys = set(model.state_dict().keys()) |
|
found_keys = set(model_state.keys()) |
|
|
|
if regression_data is None: |
|
expected_missing = {"contact_head.regression.weight", "contact_head.regression.bias"} |
|
error_msgs = [] |
|
missing = (expected_keys - found_keys) - expected_missing |
|
if missing: |
|
error_msgs.append(f"Missing key(s) in state_dict: {missing}.") |
|
unexpected = found_keys - expected_keys |
|
if unexpected: |
|
error_msgs.append(f"Unexpected key(s) in state_dict: {unexpected}.") |
|
|
|
if error_msgs: |
|
raise RuntimeError( |
|
"Error(s) in loading state_dict for {}:\n\t{}".format( |
|
model.__class__.__name__, "\n\t".join(error_msgs) |
|
) |
|
) |
|
if expected_missing - found_keys: |
|
warnings.warn( |
|
"Regression weights not found, predicting contacts will not produce correct results." |
|
) |
|
|
|
model.load_state_dict(model_state, strict=regression_data is not None) |
|
|
|
return model, alphabet |
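
# Minimal sketch of calling the core loader directly, assuming the checkpoint
# (and, where applicable, the regression data) has already been obtained, e.g.
# via _download_model_and_regression_data:
#
#     model_data, regression_data = _download_model_and_regression_data("esm2_t6_8M_UR50D")
#     model, alphabet = load_model_and_alphabet_core("esm2_t6_8M_UR50D", model_data, regression_data)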
|
|
|
|
|
def esm1_t34_670M_UR50S(): |
|
"""34 layer transformer model with 670M params, trained on Uniref50 Sparse. |
|
|
|
Returns a tuple of (Model, Alphabet). |
|
""" |
|
return load_model_and_alphabet_hub("esm1_t34_670M_UR50S") |
|
|
|
|
|
def esm1_t34_670M_UR50D(): |
|
"""34 layer transformer model with 670M params, trained on Uniref50 Dense. |
|
|
|
Returns a tuple of (Model, Alphabet). |
|
""" |
|
return load_model_and_alphabet_hub("esm1_t34_670M_UR50D") |
|
|
|
|
|
def esm1_t34_670M_UR100(): |
|
"""34 layer transformer model with 670M params, trained on Uniref100. |
|
|
|
Returns a tuple of (Model, Alphabet). |
|
""" |
|
return load_model_and_alphabet_hub("esm1_t34_670M_UR100") |
|
|
|
|
|
def esm1_t12_85M_UR50S(): |
|
"""12 layer transformer model with 85M params, trained on Uniref50 Sparse. |
|
|
|
Returns a tuple of (Model, Alphabet). |
|
""" |
|
return load_model_and_alphabet_hub("esm1_t12_85M_UR50S") |
|
|
|
|
|
def esm1_t6_43M_UR50S(): |
|
"""6 layer transformer model with 43M params, trained on Uniref50 Sparse. |
|
|
|
Returns a tuple of (Model, Alphabet). |
|
""" |
|
return load_model_and_alphabet_hub("esm1_t6_43M_UR50S") |
|
|
|
|
|
def esm1b_t33_650M_UR50S(): |
|
"""33 layer transformer model with 650M params, trained on Uniref50 Sparse. |
|
    This was our best performing ESM-1 model; it is described in Rives et al. (2021).
|
|
|
Returns a tuple of (Model, Alphabet). |
|
""" |
|
return load_model_and_alphabet_hub("esm1b_t33_650M_UR50S") |
|
|
|
|
|
def esm_msa1_t12_100M_UR50S(): |
|
warnings.warn( |
|
"This model had a minor bug in the positional embeddings, " |
|
"please use ESM-MSA-1b: esm.pretrained.esm_msa1b_t12_100M_UR50S()", |
|
) |
|
return load_model_and_alphabet_hub("esm_msa1_t12_100M_UR50S") |
|
|
|
|
|
def esm_msa1b_t12_100M_UR50S():
    """12 layer MSA Transformer model with 100M params, trained on UniRef50 MSAs.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm_msa1b_t12_100M_UR50S")
|
|
|
|
|
def esm1v_t33_650M_UR90S(): |
|
"""33 layer transformer model with 650M params, trained on Uniref90. |
|
    This is model 1 of a 5 model ensemble (alias of esm1v_t33_650M_UR90S_1).
|
|
|
Returns a tuple of (Model, Alphabet). |
|
""" |
|
return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_1") |
|
|
|
|
|
def esm1v_t33_650M_UR90S_1(): |
|
"""33 layer transformer model with 650M params, trained on Uniref90. |
|
This is model 1 of a 5 model ensemble. |
|
|
|
Returns a tuple of (Model, Alphabet). |
|
""" |
|
return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_1") |
|
|
|
|
|
def esm1v_t33_650M_UR90S_2(): |
|
"""33 layer transformer model with 650M params, trained on Uniref90. |
|
This is model 2 of a 5 model ensemble. |
|
|
|
Returns a tuple of (Model, Alphabet). |
|
""" |
|
return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_2") |
|
|
|
|
|
def esm1v_t33_650M_UR90S_3(): |
|
"""33 layer transformer model with 650M params, trained on Uniref90. |
|
This is model 3 of a 5 model ensemble. |
|
|
|
Returns a tuple of (Model, Alphabet). |
|
""" |
|
return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_3") |
|
|
|
|
|
def esm1v_t33_650M_UR90S_4(): |
|
"""33 layer transformer model with 650M params, trained on Uniref90. |
|
This is model 4 of a 5 model ensemble. |
|
|
|
Returns a tuple of (Model, Alphabet). |
|
""" |
|
return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_4") |
|
|
|
|
|
def esm1v_t33_650M_UR90S_5(): |
|
"""33 layer transformer model with 650M params, trained on Uniref90. |
|
This is model 5 of a 5 model ensemble. |
|
|
|
Returns a tuple of (Model, Alphabet). |
|
""" |
|
return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_5") |
|
|
|
|
|
def esm_if1_gvp4_t16_142M_UR50(): |
|
"""Inverse folding model with 142M params, with 4 GVP-GNN layers, 8 |
|
Transformer encoder layers, and 8 Transformer decoder layers, trained on |
|
    CATH structures and 12 million AlphaFold2-predicted structures from UniRef50
|
sequences. |
|
|
|
Returns a tuple of (Model, Alphabet). |
|
""" |
|
return load_model_and_alphabet_hub("esm_if1_gvp4_t16_142M_UR50") |
|
|
|
|
|
def esm2_t6_8M_UR50D(): |
|
"""6 layer ESM-2 model with 8M params, trained on UniRef50. |
|
|
|
Returns a tuple of (Model, Alphabet). |
|
""" |
|
return load_model_and_alphabet_hub("esm2_t6_8M_UR50D") |
|
|
|
|
|
def esm2_t12_35M_UR50D(): |
|
"""12 layer ESM-2 model with 35M params, trained on UniRef50. |
|
|
|
Returns a tuple of (Model, Alphabet). |
|
""" |
|
return load_model_and_alphabet_hub("esm2_t12_35M_UR50D") |
|
|
|
|
|
def esm2_t30_150M_UR50D(): |
|
"""30 layer ESM-2 model with 150M params, trained on UniRef50. |
|
|
|
Returns a tuple of (Model, Alphabet). |
|
""" |
|
return load_model_and_alphabet_hub("esm2_t30_150M_UR50D") |
|
|
|
|
|
def esm2_t33_650M_UR50D(): |
|
"""33 layer ESM-2 model with 650M params, trained on UniRef50. |
|
|
|
Returns a tuple of (Model, Alphabet). |
|
""" |
|
return load_model_and_alphabet_hub("esm2_t33_650M_UR50D") |
|
|
|
|
|
def esm2_t36_3B_UR50D(): |
|
"""36 layer ESM-2 model with 3B params, trained on UniRef50. |
|
|
|
Returns a tuple of (Model, Alphabet). |
|
""" |
|
return load_model_and_alphabet_hub("esm2_t36_3B_UR50D") |
|
|
|
|
|
def esm2_t48_15B_UR50D(): |
|
"""48 layer ESM-2 model with 15B params, trained on UniRef50. |
|
    If you run out of memory (OOM) while loading this model, please refer to the README
    for how to employ FSDP and ZeRO CPU offloading.
|
|
|
Returns a tuple of (Model, Alphabet). |
|
""" |
|
return load_model_and_alphabet_hub("esm2_t48_15B_UR50D") |
|
|
|
|
|
def esmfold_v0(): |
|
""" |
|
ESMFold v0 model with 3B ESM-2, 48 folding blocks. |
|
    This version was used for the paper (Lin et al., 2022). It was trained
    on all PDB chains released up to 2020-05, to ensure temporal holdout from
    CASP14 and the CAMEO validation and test sets reported there.
|
""" |
|
import esm.esmfold.v1.pretrained |
|
return esm.esmfold.v1.pretrained.esmfold_v0() |
|
|
|
|
|
def esmfold_v1(): |
|
""" |
|
ESMFold v1 model using 3B ESM-2, 48 folding blocks. |
|
    ESMFold provides fast, high-accuracy, atomic-level structure prediction
    directly from an individual protein sequence. It uses the ESM-2
    protein language model to extract meaningful representations from the
    protein sequence.
|
""" |
|
import esm.esmfold.v1.pretrained |
|
return esm.esmfold.v1.pretrained.esmfold_v1() |
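
# Example usage (illustrative; requires the optional ESMFold dependencies, e.g.
# openfold, to be installed). infer_pdb is the convenience method shown in the
# repository README; the sequence below is a toy example.
#
#     model = esmfold_v1().eval().cuda()
#     sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"
#     with torch.no_grad():
#         pdb_str = model.infer_pdb(sequence)
#     with open("result.pdb", "w") as f:
#         f.write(pdb_str)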
|
|
|
def esmfold_structure_module_only_8M(): |
|
""" |
|
ESMFold baseline model using 8M ESM-2, 0 folding blocks. |
|
ESM-2 here is trained out to 500K updates. |
|
    This baseline is designed to test the capabilities of the language model
    as the number of parameters in the language model is ablated.
    See Table S1 in (Lin et al., 2022).
|
""" |
|
import esm.esmfold.v1.pretrained |
|
return esm.esmfold.v1.pretrained.esmfold_structure_module_only_8M() |
|
|
|
|
|
def esmfold_structure_module_only_8M_270K(): |
|
""" |
|
ESMFold baseline model using 8M ESM-2, 0 folding blocks. |
|
ESM-2 here is trained out to 270K updates. |
|
    This baseline is designed to test the capabilities of the language model
    as the number of parameters in the language model is ablated.
    See Table S1 in (Lin et al., 2022).
|
""" |
|
import esm.esmfold.v1.pretrained |
|
return esm.esmfold.v1.pretrained.esmfold_structure_module_only_8M_270K() |
|
|
|
|
|
def esmfold_structure_module_only_35M(): |
|
""" |
|
ESMFold baseline model using 35M ESM-2, 0 folding blocks. |
|
ESM-2 here is trained out to 500K updates. |
|
    This baseline is designed to test the capabilities of the language model
    as the number of parameters in the language model is ablated.
    See Table S1 in (Lin et al., 2022).
|
""" |
|
import esm.esmfold.v1.pretrained |
|
return esm.esmfold.v1.pretrained.esmfold_structure_module_only_35M() |
|
|
|
|
|
def esmfold_structure_module_only_35M_270K(): |
|
""" |
|
ESMFold baseline model using 35M ESM-2, 0 folding blocks. |
|
ESM-2 here is trained out to 270K updates. |
|
    This baseline is designed to test the capabilities of the language model
    as the number of parameters in the language model is ablated.
    See Table S1 in (Lin et al., 2022).
|
""" |
|
import esm.esmfold.v1.pretrained |
|
return esm.esmfold.v1.pretrained.esmfold_structure_module_only_35M_270K() |
|
|
|
|
|
def esmfold_structure_module_only_150M(): |
|
""" |
|
ESMFold baseline model using 150M ESM-2, 0 folding blocks. |
|
ESM-2 here is trained out to 500K updates. |
|
    This baseline is designed to test the capabilities of the language model
    as the number of parameters in the language model is ablated.
    See Table S1 in (Lin et al., 2022).
|
""" |
|
import esm.esmfold.v1.pretrained |
|
return esm.esmfold.v1.pretrained.esmfold_structure_module_only_150M() |
|
|
|
|
|
def esmfold_structure_module_only_150M_270K(): |
|
""" |
|
ESMFold baseline model using 150M ESM-2, 0 folding blocks. |
|
ESM-2 here is trained out to 270K updates. |
|
    This baseline is designed to test the capabilities of the language model
    as the number of parameters in the language model is ablated.
    See Table S1 in (Lin et al., 2022).
|
""" |
|
import esm.esmfold.v1.pretrained |
|
return esm.esmfold.v1.pretrained.esmfold_structure_module_only_150M_270K() |
|
|
|
|
|
def esmfold_structure_module_only_650M(): |
|
""" |
|
ESMFold baseline model using 650M ESM-2, 0 folding blocks. |
|
ESM-2 here is trained out to 500K updates. |
|
    This baseline is designed to test the capabilities of the language model
    as the number of parameters in the language model is ablated.
    See Table S1 in (Lin et al., 2022).
|
""" |
|
import esm.esmfold.v1.pretrained |
|
return esm.esmfold.v1.pretrained.esmfold_structure_module_only_650M() |
|
|
|
|
|
def esmfold_structure_module_only_650M_270K(): |
|
""" |
|
ESMFold baseline model using 650M ESM-2, 0 folding blocks. |
|
ESM-2 here is trained out to 270K updates. |
|
    This baseline is designed to test the capabilities of the language model
    as the number of parameters in the language model is ablated.
    See Table S1 in (Lin et al., 2022).
|
""" |
|
import esm.esmfold.v1.pretrained |
|
return esm.esmfold.v1.pretrained.esmfold_structure_module_only_650M_270K() |
|
|
|
|
|
def esmfold_structure_module_only_3B(): |
|
""" |
|
ESMFold baseline model using 3B ESM-2, 0 folding blocks. |
|
ESM-2 here is trained out to 500K updates. |
|
    This baseline is designed to test the capabilities of the language model
    as the number of parameters in the language model is ablated.
    See Table S1 in (Lin et al., 2022).
|
""" |
|
import esm.esmfold.v1.pretrained |
|
return esm.esmfold.v1.pretrained.esmfold_structure_module_only_3B() |
|
|
|
|
|
def esmfold_structure_module_only_3B_270K(): |
|
""" |
|
ESMFold baseline model using 3B ESM-2, 0 folding blocks. |
|
ESM-2 here is trained out to 270K updates. |
|
    This baseline is designed to test the capabilities of the language model
    as the number of parameters in the language model is ablated.
    See Table S1 in (Lin et al., 2022).
|
""" |
|
import esm.esmfold.v1.pretrained |
|
return esm.esmfold.v1.pretrained.esmfold_structure_module_only_3B_270K() |
|
|
|
|
|
def esmfold_structure_module_only_15B(): |
|
""" |
|
ESMFold baseline model using 15B ESM-2, 0 folding blocks. |
|
    ESM-2 here is trained out to 270K updates; the 15B-parameter ESM-2 was not
    trained out to 500K updates.
    This baseline is designed to test the capabilities of the language model
    as the number of parameters in the language model is ablated.
    See Table S1 in (Lin et al., 2022).
|
""" |
|
import esm.esmfold.v1.pretrained |
|
return esm.esmfold.v1.pretrained.esmfold_structure_module_only_15B() |
|
|