| """PETIMOT inference utilities for custom proteins.""" |
| import os, sys |
| import numpy as np |
| from pathlib import Path |
|
|
| |
# Root of the repository: three directory levels above this file. It is put
# at the front of sys.path so the `petimot` package imported by
# run_inference() resolves from this checkout (presumably without requiring
# an installed copy).
PETIMOT_ROOT = str(Path(os.path.abspath(__file__)).parent.parent.parent)
if PETIMOT_ROOT not in sys.path:
    sys.path.insert(0, PETIMOT_ROOT)

# Embedding width per model name; the keys appear to name protein
# language-model embedding backbones (not used in this section of the file).
EMBEDDING_DIM_MAP = {
    "prostt5": 1024,
    "esmc_300m": 960,
    "esmc_600m": 1152,
}
|
|
|
|
def _load_modes(pred_dir: str, basename: str, max_modes: int) -> dict[int, np.ndarray]:
    """Load ``{prefix}_mode_{k}.txt`` prediction files from *pred_dir*.

    For each ``k`` in ``range(max_modes)`` the prefix ``extracted_{basename}``
    is tried before the bare ``basename``; indices with no file are omitted.
    """
    modes = {}
    for k in range(max_modes):
        for prefix in (f"extracted_{basename}", basename):
            mode_file = os.path.join(pred_dir, f"{prefix}_mode_{k}.txt")
            if os.path.exists(mode_file):
                modes[k] = np.loadtxt(mode_file)
                break
    return modes


def run_inference(pdb_path: str, weights_path: str, config_path: str | None = None,
                  output_dir: str = "/tmp/petimot_pred", *, max_modes: int = 10) -> dict:
    """Run PETIMOT inference on a single PDB file and collect the results.

    Args:
        pdb_path: Path to the input PDB file.
        weights_path: Path to the model weights (.pt).
        config_path: Path to the config YAML; defaults to
            ``<PETIMOT_ROOT>/configs/default.yaml``.
        output_dir: Directory passed to inference as its output path
            (created if missing).
        max_modes: Number of mode indices (0 .. max_modes-1) to look for when
            reading predictions back. Defaults to 10.

    Returns:
        dict with keys ``name`` (PDB file stem), ``ca_coords`` (CA
        coordinates as a NumPy array), ``seq`` (sequence string, or a run of
        ``"X"`` when the loader provides none), ``modes`` (mode index -> array
        loaded from the prediction file), ``pdb_text`` (raw PDB contents),
        ``pred_dir`` (directory predictions were read from) and ``n_res``
        (number of residues).

    Raises:
        ImportError: If PyTorch is not installed.
        FileNotFoundError: If *pdb_path* does not exist.
    """
    try:
        import torch  # noqa: F401  -- probed only so the failure message is actionable
    except ImportError as err:
        raise ImportError(
            "PyTorch is required to run inference. Install it with: pip install torch"
        ) from err

    import warnings

    from petimot.data.pdb_utils import load_backbone_coordinates
    from petimot.infer.infer import infer

    # Fail before the (potentially expensive) model run rather than after it.
    if not os.path.isfile(pdb_path):
        raise FileNotFoundError(f"PDB file not found: {pdb_path}")

    if config_path is None:
        config_path = os.path.join(PETIMOT_ROOT, "configs", "default.yaml")
    os.makedirs(output_dir, exist_ok=True)

    infer(model_path=weights_path, config_file=config_path,
          input_list=[pdb_path], output_path=output_dir)

    # Predictions are read back from <output_dir>/<weights file stem>/ -- the
    # layout this function expects infer() to produce.
    weights_stem = os.path.splitext(os.path.basename(weights_path))[0]
    pred_subdir = os.path.join(output_dir, weights_stem)
    basename = os.path.splitext(os.path.basename(pdb_path))[0]

    bb_data = load_backbone_coordinates(pdb_path, allow_hetatm=True)
    # Atom index 1 of each residue is taken as CA; assumes the loader orders
    # backbone atoms N, CA, C, ... -- TODO confirm against pdb_utils.
    ca = bb_data["bb"][:, 1].numpy()
    seq = bb_data.get("seq")
    if not isinstance(seq, str):
        seq = "X" * len(ca)

    modes = _load_modes(pred_subdir, basename, max_modes)
    if not modes:
        # An empty result usually means the prediction-dir layout assumption
        # above did not hold; make that visible instead of failing silently.
        warnings.warn(
            f"No mode files found for '{basename}' in {pred_subdir}",
            RuntimeWarning,
            stacklevel=2,
        )

    with open(pdb_path, encoding="utf-8", errors="replace") as fh:
        pdb_text = fh.read()

    return {
        "name": basename,
        "ca_coords": ca,
        "seq": seq,
        "modes": modes,
        "pdb_text": pdb_text,
        "pred_dir": pred_subdir,
        "n_res": len(ca),
    }
|
|
|
|
| def download_pdb(pdb_id: str, output_dir: str = "/tmp/petimot_pdbs") -> str | None: |
| """Download PDB from RCSB.""" |
| import requests |
|
|
| os.makedirs(output_dir, exist_ok=True) |
| code4 = pdb_id[:4].lower() |
| chain = pdb_id[4:].upper() if len(pdb_id) > 4 else "" |
| out_path = os.path.join(output_dir, f"{pdb_id}.pdb") |
|
|
| if os.path.exists(out_path): |
| return out_path |
|
|
| r = requests.get(f"https://files.rcsb.org/download/{code4}.pdb", timeout=30) |
| if not r.ok: |
| return None |
|
|
| lines = r.text.split("\n") |
| if chain: |
| lines = [l for l in lines |
| if (l.startswith("ATOM") and len(l) > 21 and l[21] == chain) |
| or not l.startswith(("ATOM", "HETATM"))] |
|
|
| with open(out_path, "w") as f: |
| f.write("\n".join(lines)) |
| return out_path |
|
|