| |
| import pathlib |
| import numpy as np |
| import torch |
| import streamlit as st |
| from torch import nn |
| from transformers import BertModel, BertTokenizer |
|
|
# Width of the pooled ProtBERT embedding fed to the classifier head.
MODEL_INPUT_DIM = 1024
MODEL_ARCH = "FastMLP"
PROTBERT_MODEL_NAME = "Rostlab/prot_bert"


class FastMLP(nn.Module):
    """Small MLP head mapping a pooled ProtBERT embedding to a single logit.

    Architecture: input_dim -> 512 -> 128 -> 1, with ReLU activations and
    dropout (p=0.3) after the first hidden layer.
    """

    def __init__(self, input_dim=MODEL_INPUT_DIM):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        """Return raw (pre-sigmoid) logits; callers apply sigmoid themselves."""
        return self.layers(x)
|
|
|
|
| def _load_checkpoint(path: pathlib.Path): |
| |
| obj = torch.load(str(path), map_location="cpu") |
| if isinstance(obj, dict) and "state_dict" in obj: |
| return obj["state_dict"], obj.get("meta", {}) |
| if isinstance(obj, dict): |
| return obj, {} |
| raise ValueError( |
| f"Unsupported model checkpoint format at '{path}'. " |
| "Expected a PyTorch state_dict or {'state_dict': ..., 'meta': ...}." |
| ) |
|
|
|
|
| def _infer_first_layer_input_dim(state_dict: dict) -> int | None: |
| |
| w = state_dict.get("layers.0.weight") |
| if w is None: |
| return None |
| if hasattr(w, "shape") and len(w.shape) == 2: |
| return int(w.shape[1]) |
| return None |
|
|
|
|
| def _normalize_sequence(sequence: str) -> str: |
| |
| return "".join(c for c in str(sequence).upper() if not c.isspace()) |
|
|
|
|
@st.cache_resource
def load_model():
    """Load and cache the AMP classifier plus the ProtBERT embedding stack.

    Returns:
        dict with keys:
            classifier: FastMLP in eval mode (kept on CPU).
            tokenizer:  ProtBERT BertTokenizer (case-preserving).
            encoder:    BertModel moved to ``device``, in eval mode.
            device:     torch.device used for the encoder (CUDA if available).
            classifier_path: string path the checkpoint was loaded from.

    Raises:
        FileNotFoundError: if no checkpoint exists at any candidate path.
        ValueError: if the checkpoint's first-layer width conflicts with the
            ProtBERT embedding size, or its format is unsupported.
    """
    # Layout assumption: this module sits two levels below the repo root,
    # one below the streamlit app dir — TODO confirm against repo layout.
    streamlitapp_dir = pathlib.Path(__file__).resolve().parent.parent
    repo_root = streamlitapp_dir.parent

    # Checked in priority order; the first existing path wins.
    candidates = [
        repo_root / "MLModels" / "ampMLModel.pt",
        repo_root / "MLModels" / "fast_mlp_amp.pt",
        repo_root / "models" / "ampMLModel.pt",
        streamlitapp_dir / "models" / "ampMLModel.pt",
    ]
    model_path = next((p for p in candidates if p.exists()), None)
    if model_path is None:
        # Build the message from the one candidates list so the list and the
        # error text can never drift apart.
        listing = "\n".join(f"- {p}" for p in candidates)
        raise FileNotFoundError(
            f"Classifier checkpoint not found in any of:\n{listing}\n"
        )

    state_dict, _meta = _load_checkpoint(model_path)
    inferred_input_dim = _infer_first_layer_input_dim(state_dict)
    # Only enforce the width check when it could actually be inferred; for an
    # unrecognized key layout, load_state_dict's strict validation below will
    # report missing/unexpected keys with a much clearer message than
    # "expects None input features".
    if inferred_input_dim is not None and inferred_input_dim != MODEL_INPUT_DIM:
        raise ValueError(
            "Model/input mismatch. Loaded classifier expects "
            f"{inferred_input_dim} input features; ProtBERT pooled embeddings are {MODEL_INPUT_DIM}-dim."
        )

    classifier = FastMLP(input_dim=MODEL_INPUT_DIM)
    classifier.load_state_dict(state_dict)
    classifier.eval()  # classifier stays on CPU; only the encoder is moved

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ProtBERT's vocabulary is uppercase amino-acid letters; lower-casing
    # would map every residue out of vocabulary.
    tokenizer = BertTokenizer.from_pretrained(PROTBERT_MODEL_NAME, do_lower_case=False)

    encoder = BertModel.from_pretrained(PROTBERT_MODEL_NAME).to(device)
    encoder.eval()

    return {
        "classifier": classifier,
        "tokenizer": tokenizer,
        "encoder": encoder,
        "device": device,
        "classifier_path": str(model_path),
    }
|
|
|
|
def encode_sequence(seq, model_bundle):
    """Embed a protein sequence as a mean-pooled ProtBERT vector.

    Args:
        seq: amino-acid sequence; case and whitespace are normalized away.
        model_bundle: dict from ``load_model`` providing "tokenizer",
            "encoder", and "device".

    Returns:
        np.ndarray: 1-D float32 embedding (hidden-size wide).
    """
    device = model_bundle["device"]
    # ProtBERT expects residues separated by single spaces, e.g. "M K V".
    spaced_seq = " ".join(_normalize_sequence(seq))
    batch = model_bundle["tokenizer"](
        spaced_seq,
        return_tensors="pt",
        truncation=True,
        padding=True,
    ).to(device)
    with torch.no_grad():
        hidden = model_bundle["encoder"](**batch).last_hidden_state
    # Mean over the token axis (NOTE(review): this includes special tokens
    # such as [CLS]/[SEP] in the average), then drop the batch axis.
    pooled = hidden.mean(dim=1).squeeze(0)
    return pooled.detach().cpu().numpy().astype(np.float32)
|
|
|
|
def get_embedding_extractor(model_bundle):
    """Return the classifier with its final Linear layer removed.

    The resulting module maps a classifier input to its penultimate
    representation (for the FastMLP above: the 128-dim output of the last
    ReLU), suitable for use as a feature extractor.
    """
    all_layers = list(model_bundle["classifier"].layers.children())
    feature_extractor = torch.nn.Sequential(*all_layers[:-1])
    feature_extractor.eval()
    return feature_extractor
|
|
|
|
def predict_amp(sequence, model_bundle):
    """Classify *sequence* as antimicrobial ("AMP") or not ("Non-AMP").

    Args:
        sequence: amino-acid sequence string.
        model_bundle: dict from ``load_model``.

    Returns:
        tuple: ``(label, probability)`` where probability is the sigmoid of
        the classifier logit, rounded to three decimal places, and label is
        "AMP" when probability >= 0.5, else "Non-AMP".
    """
    features = encode_sequence(sequence, model_bundle)
    batch = torch.tensor(features, dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():
        probability = torch.sigmoid(model_bundle["classifier"](batch)).item()
    label = "AMP" if probability >= 0.5 else "Non-AMP"
    return label, round(probability, 3)