Tags: Feature Extraction (pipeline) · Transformers · Safetensors · protenrich · proteins · bioinformatics · drug-discovery · custom_code
Instructions for getting started with SaeedLab/ProtEnrich-ESM2-T36 using the Transformers library.
How to use SaeedLab/ProtEnrich-ESM2-T36 with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="SaeedLab/ProtEnrich-ESM2-T36", trust_remote_code=True)
```

```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("SaeedLab/ProtEnrich-ESM2-T36", trust_remote_code=True, dtype="auto")
```

Both snippets require `trust_remote_code=True` because the architecture ships as custom code alongside the checkpoint (the `custom_code` tag above).
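For a concrete starting point, here is a minimal end-to-end sketch. It assumes the model consumes a precomputed per-protein ESM2-T36 embedding (hidden size 2560, here mean-pooled over residues) matching `config.seq_dim`; the backbone checkpoint name, the pooling recipe, and the example sequence are illustrative assumptions, so check the model card for the exact feature pipeline:

```python
# Minimal sketch, assuming ProtEnrich consumes a mean-pooled ESM2-T36 embedding.
# The ESM2 checkpoint, pooling choice, and sequence below are illustrative.
import torch
from transformers import AutoModel, AutoTokenizer

esm_name = "facebook/esm2_t36_3B_UR50D"  # ESM2 T36 backbone (hidden size 2560)
tokenizer = AutoTokenizer.from_pretrained(esm_name)
esm = AutoModel.from_pretrained(esm_name)

protenrich = AutoModel.from_pretrained("SaeedLab/ProtEnrich-ESM2-T36", trust_remote_code=True)

sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # toy protein sequence
inputs = tokenizer(sequence, return_tensors="pt")
with torch.no_grad():
    residue_states = esm(**inputs).last_hidden_state  # (1, seq_len, 2560)
    seq_embedding = residue_states.mean(dim=1)        # (1, 2560), mean pooling
    out = protenrich(seq=seq_embedding)

print(out.h_enrich.shape)  # enriched embedding: (1, config.embed_dim)
```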
The custom modeling code loaded via `trust_remote_code=True`:

```python
import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Optional

from transformers import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.utils import ModelOutput

from .configuration_protenrich import ProtEnrichConfig


@dataclass
class ProtEnrichModelOutput(ModelOutput):
    """Outputs of ProtEnrichModel: the enriched embedding plus its components."""

    h_enrich: torch.FloatTensor = None
    h_anchor: Optional[torch.FloatTensor] = None
    h_algn: Optional[torch.FloatTensor] = None
    struct: Optional[torch.FloatTensor] = None
    dyn: Optional[torch.FloatTensor] = None


class MLPEncoder(nn.Module):
    """MLP with (n_layers - 1) blocks of Linear -> LayerNorm -> GELU -> Dropout,
    followed by a final Linear projection to out_dim."""

    def __init__(self, in_dim, out_dim, hidden_dim=1024, n_layers=2, dropout=0.1):
        super().__init__()
        layers = []
        d = in_dim
        for _ in range(n_layers - 1):
            layers += [
                nn.Linear(d, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
            ]
            d = hidden_dim
        layers.append(nn.Linear(d, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class ProtEnrichModel(PreTrainedModel):
    config_class = ProtEnrichConfig
    base_model_prefix = "protenrich"

    def __init__(self, config: ProtEnrichConfig):
        super().__init__(config)
        # Two sequence branches over the same input embedding: an anchor branch
        # and an alignment branch.
        self.seq_anchor = MLPEncoder(config.seq_dim, config.embed_dim)
        self.seq_algn = MLPEncoder(config.seq_dim, config.embed_dim)
        # Structure and dynamics encoders are frozen: their parameters receive
        # no gradients.
        self.struct_encoder = MLPEncoder(config.struct_dim, config.embed_dim)
        self.dyn_encoder = MLPEncoder(config.dyn_dim, config.embed_dim)
        for p in self.struct_encoder.parameters():
            p.requires_grad = False
        for p in self.dyn_encoder.parameters():
            p.requires_grad = False
        # Projection heads into a shared space (defined here, but not called in
        # this forward pass).
        self.seq_projector = nn.Linear(config.embed_dim, config.project_dim)
        self.struct_projector = nn.Linear(config.embed_dim, config.project_dim)
        self.dyn_projector = nn.Linear(config.embed_dim, config.project_dim)
        # Decoders that reconstruct each modality's features from the embedding.
        self.seq_decoder = MLPEncoder(config.embed_dim, config.seq_dim)
        self.struct_decoder = MLPEncoder(config.embed_dim, config.struct_dim)
        self.dyn_decoder = MLPEncoder(config.embed_dim, config.dyn_dim)
        # Learned gate: sigmoid(alpha_logit) * alpha_max bounds alpha to
        # (0, alpha_max); it controls how much of the aligned embedding is
        # mixed into the anchor embedding.
        self.alpha_logit = nn.Parameter(torch.tensor(-2.0))
        self.alpha_max = config.alpha_max
        self.norm_anchor = nn.LayerNorm(config.embed_dim)
        self.norm_algn = nn.LayerNorm(config.embed_dim)
        self.post_init()

    def forward(self, seq: torch.Tensor, return_dict: Optional[bool] = None):
        # `return_dict` is accepted for API compatibility; a ModelOutput is
        # always returned.
        h_anchor = self.norm_anchor(self.seq_anchor(seq))
        h_algn = self.norm_algn(self.seq_algn(seq))
        # Reconstruct structure and dynamics features from the aligned embedding.
        struct = self.struct_decoder(h_algn)
        dyn = self.dyn_decoder(h_algn)
        alpha = torch.sigmoid(self.alpha_logit) * self.alpha_max
        h_enrich = h_anchor + alpha * h_algn  # enriched sequence embedding
        return ProtEnrichModelOutput(
            h_enrich=h_enrich,
            h_anchor=h_anchor,
            h_algn=h_algn,
            struct=struct,
            dyn=dyn,
        )


class ProtEnrichForSequenceClassification(PreTrainedModel):
    config_class = ProtEnrichConfig

    def __init__(self, config: ProtEnrichConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.protenrich = ProtEnrichModel(config)
        # Linear classifier on top of the enriched embedding.
        self.classifier = nn.Linear(config.embed_dim, config.num_labels)
        self.post_init()

    def forward(self, seq: torch.Tensor, labels: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None):
        outputs = self.protenrich(seq=seq, return_dict=return_dict)
        pooled = outputs.h_enrich
        logits = self.classifier(pooled)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=(pooled,),  # tuple, as SequenceClassifierOutput expects
        )
```
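The companion `configuration_protenrich.py` is imported above but not shown. Below is a minimal sketch of what it must provide, inferred only from the attributes the modeling code reads (`seq_dim`, `struct_dim`, `dyn_dim`, `embed_dim`, `project_dim`, `alpha_max`, `num_labels`); every default value is an illustrative assumption, not the shipped configuration:

```python
from transformers import PretrainedConfig


class ProtEnrichConfig(PretrainedConfig):
    """Sketch of the config consumed by the modeling code above."""

    model_type = "protenrich"

    def __init__(
        self,
        seq_dim=2560,     # assumption: matches the ESM2-T36 hidden size
        struct_dim=512,   # placeholder: structure-feature dimensionality
        dyn_dim=512,      # placeholder: dynamics-feature dimensionality
        embed_dim=1024,   # placeholder: shared embedding size
        project_dim=256,  # placeholder: projection-head output size
        alpha_max=1.0,    # placeholder: upper bound on the mixing gate
        num_labels=2,
        **kwargs,
    ):
        self.seq_dim = seq_dim
        self.struct_dim = struct_dim
        self.dyn_dim = dyn_dim
        self.embed_dim = embed_dim
        self.project_dim = project_dim
        self.alpha_max = alpha_max
        super().__init__(num_labels=num_labels, **kwargs)
```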
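With those sketched pieces in scope, a quick shape check of the classification head; the batch size, label count, and random embeddings are placeholders, and a freshly constructed model is untrained:

```python
import torch

config = ProtEnrichConfig(num_labels=3)
model = ProtEnrichForSequenceClassification(config)
model.eval()  # disable dropout for the smoke test

seq = torch.randn(4, config.seq_dim)  # placeholder batch of 4 embeddings
labels = torch.tensor([0, 2, 1, 0])   # placeholder class labels

with torch.no_grad():
    out = model(seq=seq, labels=labels)

print(out.logits.shape)  # torch.Size([4, 3])
print(out.loss)          # cross-entropy loss against the labels
```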