"""
MitoInteract Model Class Definition
Copy this file to load the model for inference.
"""
import torch
import torch.nn as nn
from transformers import EsmModel, EsmTokenizer, AutoModel, AutoTokenizer


class MitoInteract(nn.Module):
    """Protein-ligand binding affinity regressor.

    A frozen ESM-2 protein encoder and a frozen ChemBERTa molecule encoder
    feed a pair of cross-attention blocks (the pooled protein vector queries
    the molecule tokens, and vice versa); the two attended vectors are
    concatenated and passed through an MLP head that regresses pKd.
    """

    def __init__(
        self,
        esm_model_name="facebook/esm2_t33_650M_UR50D",
        mol_model_name="seyonec/ChemBERTa-zinc-base-v1",
        protein_dim=1280,
        mol_dim=768,
        proj_dim=256,
        n_heads=8,
        dropout=0.1,
        freeze_encoders=True,
    ):
        super().__init__()
        self.freeze_encoders = freeze_encoders

        # Pretrained encoders (frozen by default, so only the heads train).
        self.esm = EsmModel.from_pretrained(esm_model_name)
        self.protein_dim = protein_dim
        self.mol_encoder = AutoModel.from_pretrained(mol_model_name)
        self.mol_dim = mol_dim
        if freeze_encoders:
            for p in self.esm.parameters():
                p.requires_grad = False
            for p in self.mol_encoder.parameters():
                p.requires_grad = False

        # Project both modalities into a shared space for cross-attention.
        self.prot_proj = nn.Sequential(
            nn.Linear(protein_dim, proj_dim), nn.LayerNorm(proj_dim), nn.ReLU(), nn.Dropout(dropout))
        self.mol_proj = nn.Sequential(
            nn.Linear(mol_dim, proj_dim), nn.LayerNorm(proj_dim), nn.ReLU(), nn.Dropout(dropout))

        # Bidirectional cross-attention: each modality's pooled vector queries
        # the other modality's token sequence.
        self.cross_attn_mol2prot = nn.MultiheadAttention(proj_dim, n_heads, dropout=dropout, batch_first=True)
        self.cross_attn_prot2mol = nn.MultiheadAttention(proj_dim, n_heads, dropout=dropout, batch_first=True)
        self.ln_mol2prot = nn.LayerNorm(proj_dim)
        self.ln_prot2mol = nn.LayerNorm(proj_dim)

        # Regression head over the concatenated attended representations.
        fused_dim = proj_dim * 2
        self.mlp = nn.Sequential(
            nn.Linear(fused_dim, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(256, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(128, 1))
    def encode_protein(self, input_ids, attention_mask):
        # Skip gradient tracking when the encoder is frozen.
        ctx = torch.no_grad() if self.freeze_encoders else torch.enable_grad()
        with ctx:
            out = self.esm(input_ids=input_ids, attention_mask=attention_mask)
        # Masked mean pooling over real (non-padding) tokens.
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (out.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        return pooled, out.last_hidden_state

    def encode_molecule(self, input_ids, attention_mask):
        ctx = torch.no_grad() if self.freeze_encoders else torch.enable_grad()
        with ctx:
            out = self.mol_encoder(input_ids=input_ids, attention_mask=attention_mask)
        # ChemBERTa exposes a [CLS]-based pooler output; use it as the pooled vector.
        return out.pooler_output, out.last_hidden_state
    def forward(self, prot_input_ids, prot_attention_mask, mol_input_ids, mol_attention_mask):
        prot_pooled, prot_seq = self.encode_protein(prot_input_ids, prot_attention_mask)
        mol_pooled, mol_seq = self.encode_molecule(mol_input_ids, mol_attention_mask)

        # Project token sequences and pooled vectors into the shared space.
        prot_seq_proj = self.prot_proj(prot_seq)
        mol_seq_proj = self.mol_proj(mol_seq)
        prot_q = self.prot_proj(prot_pooled).unsqueeze(1)  # (B, 1, proj_dim)
        mol_q = self.mol_proj(mol_pooled).unsqueeze(1)

        # key_padding_mask convention: True marks padding positions to ignore.
        prot_pad_mask = (prot_attention_mask == 0)
        mol_pad_mask = (mol_attention_mask == 0)

        # Pooled protein queries molecule tokens, and vice versa.
        h_prot2mol, _ = self.cross_attn_prot2mol(prot_q, mol_seq_proj, mol_seq_proj, key_padding_mask=mol_pad_mask)
        h_mol2prot, _ = self.cross_attn_mol2prot(mol_q, prot_seq_proj, prot_seq_proj, key_padding_mask=prot_pad_mask)
        h_prot2mol = self.ln_prot2mol(h_prot2mol.squeeze(1))
        h_mol2prot = self.ln_mol2prot(h_mol2prot.squeeze(1))

        # Concatenate the two attended views and regress a scalar pKd per pair.
        fused = torch.cat([h_prot2mol, h_mol2prot], dim=-1)
        return self.mlp(fused).squeeze(-1)
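

# A minimal shape sanity check (a sketch, not part of the original training
# code): runs an untrained model on random token ids to confirm the forward
# pass returns one scalar per pair. Note that constructing MitoInteract
# downloads the pretrained ESM-2 and ChemBERTa weights.
def _smoke_test(device="cpu"):
    model = MitoInteract().to(device).eval()
    prot_ids = torch.randint(4, 20, (2, 50), device=device)  # fake protein tokens
    mol_ids = torch.randint(4, 20, (2, 30), device=device)   # fake SMILES tokens
    prot_mask = torch.ones_like(prot_ids)
    mol_mask = torch.ones_like(mol_ids)
    with torch.no_grad():
        out = model(prot_ids, prot_mask, mol_ids, mol_mask)
    assert out.shape == (2,), f"unexpected output shape: {out.shape}"
    return out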


def load_model(checkpoint_path, device="cpu"):
    """Load a trained MitoInteract model and its config from a checkpoint."""
    # weights_only=False because the checkpoint also stores the config dict;
    # only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    config = checkpoint["config"]
    model = MitoInteract(
        esm_model_name=config["esm_model"],
        mol_model_name=config["mol_model"],
        protein_dim=config["protein_dim"],
        mol_dim=config["mol_dim"],
        proj_dim=config["proj_dim"],
        n_heads=config["n_heads"],
        dropout=config["dropout"],
        freeze_encoders=True,
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return model, config
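
# For reference, a checkpoint compatible with load_model() would be saved
# along these lines (a sketch; the filename and the training script are not
# part of this file):
#
#   torch.save({
#       "config": {
#           "esm_model": "facebook/esm2_t33_650M_UR50D",
#           "mol_model": "seyonec/ChemBERTa-zinc-base-v1",
#           "protein_dim": 1280, "mol_dim": 768, "proj_dim": 256,
#           "n_heads": 8, "dropout": 0.1,
#       },
#       "model_state_dict": model.state_dict(),
#   }, "mitointeract_v1.pt")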


def predict_binding(model, protein_seq, smiles, device="cpu"):
    """Predict binding affinity (pKd) for a protein-molecule pair."""
    # Tokenizers must match the encoders the model was built with; the
    # defaults from MitoInteract.__init__ are assumed here.
    prot_tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    mol_tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
    prot_enc = prot_tokenizer(protein_seq, return_tensors="pt", padding=True, truncation=True, max_length=512)
    mol_enc = mol_tokenizer(smiles, return_tensors="pt", padding=True, truncation=True, max_length=200)
    model = model.to(device)
    with torch.no_grad():
        pKd = model(
            prot_enc["input_ids"].to(device), prot_enc["attention_mask"].to(device),
            mol_enc["input_ids"].to(device), mol_enc["attention_mask"].to(device),
        )
    pKd_val = pKd.item()
    # pKd = -log10(Kd in M), so Kd[M] = 10**(-pKd); convert to micromolar.
    Kd_uM = 10 ** (-pKd_val) * 1e6
    return {"pKd": pKd_val, "Kd_uM": Kd_uM}