# MultiTaskTox / predict.py
# mschuh's picture
# Added first version
# 94b1553 verified
from __future__ import annotations
import json
from functools import lru_cache
from pathlib import Path
from typing import Dict, List
import joblib
import numpy as np
from src.constants import TARGET_NAMES
from src.features import FingerprintFeaturizer
from src.seed import set_seed
# Fallback score used throughout this module: prediction matrices are
# pre-filled with it (so targets without a stage-1 model keep this value),
# and predict() reports it for every target of an input whose
# featurizer mask entry is falsy.
BASE_PREDICTION = 0.5
@lru_cache(maxsize=1)
def _load_manifest() -> Dict:
    """Read the training manifest written next to the checkpoints.

    The file is looked up at ``./checkpoints/training_manifest.json``, i.e.
    relative to the current working directory. The parsed JSON is memoised
    by ``lru_cache``, so repeated calls return the very same object without
    touching the disk again.

    Returns:
        The decoded JSON content of the manifest.

    Raises:
        FileNotFoundError: If the manifest file does not exist.
    """
    location = Path("./checkpoints/training_manifest.json")
    if location.exists():
        with location.open("r", encoding="utf-8") as handle:
            return json.load(handle)
    raise FileNotFoundError("Missing checkpoints/training_manifest.json. Run train.py first.")
@lru_cache(maxsize=2)
def _load_stage_models(stage: str):
    """Load the per-target models recorded in the manifest for one stage.

    Args:
        stage: Manifest key of the stage (this module uses ``"stage1"``
            and ``"stage2"``).

    Returns:
        A dict mapping target name to the object loaded from
        ``<model_dir>/<target>.pkl``. Targets whose file is absent are
        left out. The dict is empty when the manifest has no non-empty
        ``model_dir`` entry for ``stage``. Results are memoised per stage.
    """
    manifest = _load_manifest()
    model_dir = manifest.get(stage, {}).get("model_dir")
    if not model_dir:
        return {}

    root = Path(model_dir)
    # Lazily pair each target with its expected file so the existence check
    # and the load happen target by target, in manifest order.
    candidates = (
        (name, root / f"{name}.pkl")
        for name in manifest.get("target_names", TARGET_NAMES)
    )
    # NOTE: joblib.load unpickles the file, so checkpoints must come from a
    # trusted source.
    return {name: joblib.load(path) for name, path in candidates if path.exists()}
def _compute_stage1_predictions(features: np.ndarray, target_names: List[str]) -> np.ndarray:
    """Score the featurized molecules with the stage-1 models.

    Args:
        features: Feature matrix with one row per molecule.
        target_names: Target names; column ``i`` of the result belongs to
            ``target_names[i]``.

    Returns:
        A float32 array of shape ``(n_rows, n_targets)``. Columns of targets
        that have no stage-1 model keep ``BASE_PREDICTION``. With zero input
        rows an empty ``(0, n_targets)`` array is returned.
    """
    models = _load_stage_models("stage1")
    n_rows = features.shape[0]
    n_targets = len(target_names)
    if n_rows == 0:
        return np.zeros((0, n_targets), dtype=np.float32)

    scores = np.full((n_rows, n_targets), BASE_PREDICTION, dtype=np.float32)
    for column, name in enumerate(target_names):
        model = models.get(name)
        if model is not None:
            # Forward the model's recorded best iteration, when it has one,
            # so prediction uses that iteration count.
            extra = {}
            iteration = getattr(model, "best_iteration_", None)
            if iteration is not None:
                extra["num_iteration"] = iteration
            # Column 1 of predict_proba is taken as the score.
            scores[:, column] = model.predict_proba(features, **extra)[:, 1]
    return scores
def _compute_stage2_predictions(
    base_features: np.ndarray,
    stage1_preds: np.ndarray,
    target_names: List[str],
) -> np.ndarray:
    """Refine stage-1 predictions with the per-target stage-2 models.

    For each target that has a stage-2 model, the model is fed the base
    features concatenated with the stage-1 predictions of all *other*
    targets (the target's own stage-1 column is removed).

    Args:
        base_features: Feature matrix with one row per molecule.
        stage1_preds: Stage-1 predictions, one column per entry of
            ``target_names``.
        target_names: Target names; column ``i`` of the result belongs to
            ``target_names[i]``.

    Returns:
        A float32 array of shape ``(n_samples, n_targets)``. If no stage-2
        models are available at all, ``stage1_preds`` is returned unchanged.
        A target without a stage-2 model keeps its stage-1 column. With zero
        input rows an empty ``(0, n_targets)`` array is returned.
    """
    n_samples = base_features.shape[0]
    n_targets = len(target_names)
    if n_samples == 0:
        # Same guard as stage 1: with nothing to score, skip model loading
        # and avoid calling predict_proba on a zero-row matrix, which
        # scikit-learn-style estimators typically reject with an error.
        return np.zeros((0, n_targets), dtype=np.float32)

    stage2_models = _load_stage_models("stage2")
    if not stage2_models:
        return stage1_preds

    results = np.full((n_samples, n_targets), BASE_PREDICTION, dtype=np.float32)
    for idx, target in enumerate(target_names):
        model = stage2_models.get(target)
        if model is None:
            # No stage-2 model for this target: fall back to stage 1.
            results[:, idx] = stage1_preds[:, idx]
            continue
        # Drop the target's own stage-1 column so the model only sees the
        # other targets' predictions alongside the base features.
        augmented = np.concatenate(
            [base_features, np.delete(stage1_preds, idx, axis=1)],
            axis=1,
        )
        best_iter = getattr(model, "best_iteration_", None)
        kwargs = {"num_iteration": best_iter} if best_iter is not None else {}
        # Column 1 of predict_proba is taken as the score.
        results[:, idx] = model.predict_proba(augmented, **kwargs)[:, 1]
    return results
def predict(smiles_list: List[str]) -> Dict[str, Dict[str, float]]:
    """
    Predict toxicity targets for a list of SMILES strings.

    The inputs are featurized with the configuration stored in the training
    manifest, scored by the stage-1 models and then refined by the stage-2
    models. Inputs whose featurizer mask entry is falsy receive
    ``BASE_PREDICTION`` for every target.

    Args:
        smiles_list (list[str]): SMILES strings

    Returns:
        dict: {smiles: {target_name: prediction_prob}}
    """
    set_seed(0)
    manifest = _load_manifest()
    target_names = manifest.get("target_names", TARGET_NAMES)
    feature_config = manifest.get("feature_config", {"type": "ecfp"})

    batch, features = FingerprintFeaturizer(feature_config).featurize_smiles(smiles_list)
    first_pass = _compute_stage1_predictions(features, target_names)
    final_scores = _compute_stage2_predictions(features, first_pass, target_names)

    results: Dict[str, Dict[str, float]] = {}
    # The score matrix only has rows for valid inputs, so it needs its own
    # cursor that advances independently of the position in smiles_list.
    cursor = 0
    for smiles, ok in zip(smiles_list, batch.mask):
        if ok:
            if final_scores.size:
                row = final_scores[cursor]
            else:
                row = np.full(len(target_names), BASE_PREDICTION)
            results[smiles] = dict(zip(target_names, map(float, row)))
            cursor += 1
        else:
            results[smiles] = dict.fromkeys(target_names, BASE_PREDICTION)
    return results