| |
| from __future__ import annotations |
|
|
| import os |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
| from huggingface_hub import snapshot_download |
| from inference import ( |
| PeptiVersePredictor, |
| read_best_manifest_csv, |
| canon_model, |
| ) |
|
|
| |
| |
| |
# Hugging Face Hub repository id used as the default for the --repo option.
MODEL_REPO = "ChatterjeeLab/PeptiVerse"
# Default local directory that downloaded assets are written into (--assets).
DEFAULT_ASSETS_DIR = Path("./")
# Default manifest file that decides which model folders are fetched (--manifest).
DEFAULT_MANIFEST = Path("./basic_models.txt")


# Canonical model names that are never fetched: when the manifest selects one
# of these, build_allow_patterns_from_manifest substitutes "xgb" instead.
BANNED_MODELS = {"svm", "enet", "svm_gpu", "enet_gpu"}
|
|
|
|
| def _norm_prop_disk(prop_key: str) -> str: |
| return "half_life" if prop_key == "halflife" else prop_key |
|
|
def _resolve_expected_model_dir(prop_key: str, model_name: str, mode: str) -> str:
    """Return the repo-relative folder expected to hold a model's weights.

    The folder always lives under ``training_classifiers/<property>``, where
    ``<property>`` is the on-disk spelling of ``prop_key``. The leaf folder
    name is chosen as follows:

    * ``binding_affinity``: ``wt_<mode>_<model_name>``.
    * ``halflife``: a few dedicated folder names (see the branches below).
    * anything else, including unmatched halflife combinations:
      ``<model_name>_<mode>``.

    Args:
        prop_key: Property key (not yet translated to its on-disk spelling).
        model_name: Canonical model name.
        mode: Input mode, ``"wt"`` or ``"smiles"``.

    Returns:
        The folder path as a ``/``-separated string.
    """
    root = f"training_classifiers/{_norm_prop_disk(prop_key)}"
    is_halflife = prop_key == "halflife"

    if prop_key == "binding_affinity":
        # Here the model name itself is the trailing part of the folder name.
        leaf = f"wt_{mode}_{model_name}"
    elif is_halflife and model_name in {"xgb_wt_log", "xgb_smiles"}:
        # Already a complete folder name; use it verbatim.
        leaf = model_name
    elif is_halflife and mode == "wt" and model_name == "transformer":
        leaf = "transformer_wt_log"
    elif is_halflife and model_name == "xgb":
        # Plain "xgb" maps to the mode-specific halflife folder.
        leaf = "xgb_wt_log" if mode == "wt" else "xgb_smiles"
    else:
        leaf = f"{model_name}_{mode}"

    return f"{root}/{leaf}"
|
|
|
|
def build_allow_patterns_from_manifest(manifest_path: Path) -> List[str]:
    """Build the list of file patterns to download for the manifest's models.

    For every property in the manifest, the model chosen for ``wt`` and the
    model chosen for ``smiles`` are each resolved to a folder, and four file
    patterns are emitted for that folder. Entries whose label canonicalises
    to ``None`` are skipped, and models listed in ``BANNED_MODELS`` are
    replaced by ``"xgb"``.

    Args:
        manifest_path: Path of the manifest handed to ``read_best_manifest_csv``.

    Returns:
        The patterns in first-seen order, with duplicates removed.
    """
    file_patterns = (
        "best_model.json",
        "best_model.pt",
        "best_model*.joblib",
        "best_model*.json",
    )

    collected: List[str] = []
    for prop_key, row in read_best_manifest_csv(manifest_path).items():
        for mode, label in (("wt", row.best_wt), ("smiles", row.best_smiles)):
            model = canon_model(label)
            if model is None:
                continue

            # Banned models are never fetched; fall back to xgb instead.
            if model in BANNED_MODELS:
                model = "xgb"

            folder = _resolve_expected_model_dir(prop_key, model, mode)
            collected.extend(f"{folder}/{name}" for name in file_patterns)

    # dict keys keep insertion order, so this drops repeats but keeps ordering.
    return list(dict.fromkeys(collected))
|
|
|
|
def download_assets(
    repo_id: str,
    manifest_path: Path,
    out_dir: Path,
) -> Path:
    """Download the model files selected by the manifest into ``out_dir``.

    The output directory is resolved to an absolute path and created if it is
    missing. The download is then restricted to the patterns produced by
    ``build_allow_patterns_from_manifest``.

    Args:
        repo_id: Hugging Face Hub repository id to download from.
        manifest_path: Manifest that determines which files are requested.
        out_dir: Directory to place the files in.

    Returns:
        The resolved output directory.
    """
    target_dir = out_dir.resolve()
    target_dir.mkdir(parents=True, exist_ok=True)

    request = {
        "repo_id": repo_id,
        "local_dir": str(target_dir),
        "local_dir_use_symlinks": False,
        "allow_patterns": build_allow_patterns_from_manifest(manifest_path),
    }
    snapshot_download(**request)

    return target_dir
|
|
|
|
| |
| |
| |
def main():
    """Command-line entry point: parse options and download the model assets.

    The manifest path must exist; the files it selects are then downloaded
    into the assets directory. The ``--device``, ``--property``, ``--mode``,
    ``--input``, ``--target_seq`` and ``--binder`` options are parsed but only
    consumed by the disabled sample prediction code at the end of the body.

    Raises:
        FileNotFoundError: If the manifest file does not exist.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Lightweight PeptiVerse inference with on-demand model download.")

    # (flag, add_argument keyword arguments), in the order they are registered.
    option_table = [
        ("--repo", {"default": MODEL_REPO, "help": "HF repo id containing weights/assets."}),
        ("--manifest", {"default": str(DEFAULT_MANIFEST), "help": "Path to best_models.txt"}),
        ("--assets", {"default": str(DEFAULT_ASSETS_DIR), "help": "Where to store downloaded assets"}),
        ("--device", {"default": None, "help": "cuda / cpu / cuda:0, etc"}),
        ("--property", {"default": "hemolysis", "help": "Property key (e.g. hemolysis, solubility, ...)"}),
        (
            "--mode",
            {
                "default": "wt",
                "choices": ["wt", "smiles"],
                "help": "Input type: wt=AA sequence, smiles=SMILES",
            },
        ),
        ("--input", {"default": "GIGAVLKVLTTGLPALISWIKRKRQQ", "help": "Sequence or SMILES string"}),
        ("--target_seq", {"default": None, "help": "Target WT sequence for binding_affinity"}),
        ("--binder", {"default": None, "help": "Binder string (AA or SMILES) for binding_affinity"}),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    manifest_path = Path(args.manifest)
    if not manifest_path.exists():
        raise FileNotFoundError(f"Manifest not found: {manifest_path}")

    # Kept in a variable because the disabled sample code below refers to it.
    assets_dir = download_assets(
        args.repo,
        manifest_path=manifest_path,
        out_dir=Path(args.assets),
    )

    # The string below is a no-op expression holding disabled sample code.
    """ OPTIONAL TEST CODE
    predictor = PeptiVersePredictor(
        manifest_path="basic_models.txt", # use the downloaded copy to be consistent
        classifier_weight_root=str(assets_dir),
        device=args.device,
    )

    if args.property == "binding_affinity":
        if not args.target_seq or not args.binder:
            raise ValueError("For binding_affinity, provide --target_seq and --binder.")
        out = predictor.predict_binding_affinity(args.mode, target_seq=args.target_seq, binder_str=args.binder)
    else:
        out = predictor.predict_property(args.property, args.mode, args.input)

    print(out)
    """
|
|
# Run the command-line entry point only when this file is executed directly,
# so importing the module does not trigger argument parsing or a download.
if __name__ == "__main__":
    main()
|
|