Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| from functools import lru_cache | |
| from pathlib import Path | |
| from typing import Dict, List | |
| import joblib | |
| import numpy as np | |
| from src.constants import TARGET_NAMES | |
| from src.features import FingerprintFeaturizer | |
| from src.seed import set_seed | |
# Fallback probability used wherever no model output is available: it pre-fills
# the stage-1/stage-2 prediction arrays and is reported for every target of a
# SMILES string that the featurizer marks as invalid.
BASE_PREDICTION = 0.5
def _load_manifest() -> Dict:
    """Read and parse the training manifest JSON.

    Returns:
        The decoded content of ``./checkpoints/training_manifest.json``.

    Raises:
        FileNotFoundError: If the manifest file does not exist.
    """
    location = Path("./checkpoints/training_manifest.json")
    if location.exists():
        with location.open("r", encoding="utf-8") as handle:
            return json.load(handle)
    raise FileNotFoundError("Missing checkpoints/training_manifest.json. Run train.py first.")
@lru_cache(maxsize=None)
def _load_stage_models(stage: str):
    """Load the per-target models recorded for ``stage`` in the training manifest.

    The result is memoised per stage name so that repeated calls within one
    process do not re-read the manifest and re-deserialize every model file.
    Call ``_load_stage_models.cache_clear()`` if the checkpoints on disk change
    while the process is running.

    Args:
        stage: Manifest key of the stage whose models should be loaded.

    Returns:
        A dict mapping target name to the loaded model. Targets without a
        ``<target>.pkl`` file in the stage's ``model_dir`` are omitted, and an
        empty dict is returned when the manifest has no ``model_dir`` for the
        stage. The dict is shared between callers because of the cache, so it
        must be treated as read-only.

    Raises:
        FileNotFoundError: Propagated from ``_load_manifest`` when the
            manifest is missing (failures are not cached).
    """
    manifest = _load_manifest()
    model_dir = manifest.get(stage, {}).get("model_dir")
    if not model_dir:
        return {}
    model_path = Path(model_dir)
    models = {}
    for target in manifest.get("target_names", TARGET_NAMES):
        model_file = model_path / f"{target}.pkl"
        if model_file.exists():
            # NOTE(security): joblib.load unpickles arbitrary objects; the
            # manifest must only point at trusted checkpoint directories.
            models[target] = joblib.load(model_file)
    return models
def _compute_stage1_predictions(features: np.ndarray, target_names: List[str]) -> np.ndarray:
    """Return predictions for the valid molecules from stage-1 models.

    Args:
        features: Feature matrix with one row per molecule (the row count is
            taken from ``features.shape[0]``).
        target_names: Target names; column ``i`` of the result belongs to
            ``target_names[i]``.

    Returns:
        A ``float32`` array of shape ``(features.shape[0], len(target_names))``.
        Columns whose target has no stage-1 model keep ``BASE_PREDICTION``.
    """
    # Bail out before touching the disk: an empty batch needs no models.
    if features.shape[0] == 0:
        return np.zeros((0, len(target_names)), dtype=np.float32)

    stage1_models = _load_stage_models("stage1")
    predictions = np.full((features.shape[0], len(target_names)), BASE_PREDICTION, dtype=np.float32)
    for idx, target in enumerate(target_names):
        booster = stage1_models.get(target)
        if booster is None:
            continue
        # Models exposing ``best_iteration_`` are asked to predict with that
        # iteration count; other models are called without extra arguments.
        best_iter = getattr(booster, "best_iteration_", None)
        kwargs = {"num_iteration": best_iter} if best_iter is not None else {}
        # Column 1 of predict_proba is used as the reported probability.
        predictions[:, idx] = booster.predict_proba(features, **kwargs)[:, 1]
    return predictions
def _compute_stage2_predictions(
    base_features: np.ndarray,
    stage1_preds: np.ndarray,
    target_names: List[str],
) -> np.ndarray:
    """Refine stage-1 predictions with the stage-2 models, where available.

    For each target, the stage-2 model receives the base features concatenated
    with the stage-1 predictions of all *other* targets.

    Args:
        base_features: Feature matrix with one row per molecule.
        stage1_preds: Stage-1 predictions, one column per entry of
            ``target_names``.
        target_names: Target names; column ``i`` of the result belongs to
            ``target_names[i]``.

    Returns:
        An array with one row per molecule and one column per target. An empty
        ``float32`` array of shape ``(0, len(target_names))`` is returned for
        an empty batch. ``stage1_preds`` itself is returned when no stage-2
        models exist; targets without a stage-2 model keep their stage-1
        column.
    """
    n_samples = base_features.shape[0]
    # Mirror the stage-1 guard: with nothing to score, skip model loading and
    # never hand a 0-row matrix to predict_proba.
    if n_samples == 0:
        return np.zeros((0, len(target_names)), dtype=np.float32)

    stage2_models = _load_stage_models("stage2")
    if not stage2_models:
        return stage1_preds

    results = np.full((n_samples, len(target_names)), BASE_PREDICTION, dtype=np.float32)
    for idx, target in enumerate(target_names):
        model = stage2_models.get(target)
        if model is None:
            results[:, idx] = stage1_preds[:, idx]
            continue
        # Drop the target's own stage-1 column so the model only sees the
        # predictions for the other targets next to the base features.
        augmented = np.concatenate(
            [base_features, np.delete(stage1_preds, idx, axis=1)],
            axis=1,
        )
        best_iter = getattr(model, "best_iteration_", None)
        kwargs = {"num_iteration": best_iter} if best_iter is not None else {}
        results[:, idx] = model.predict_proba(augmented, **kwargs)[:, 1]
    return results
def predict(smiles_list: List[str]) -> Dict[str, Dict[str, float]]:
    """
    Predict toxicity targets for a list of SMILES strings.

    SMILES flagged as invalid by the featurizer mask receive ``BASE_PREDICTION``
    for every target. Because the result is keyed by SMILES string, a repeated
    input string keeps only the entry written last.

    Args:
        smiles_list (list[str]): SMILES strings

    Returns:
        dict: {smiles: {target_name: prediction_prob}}
    """
    set_seed(0)

    manifest = _load_manifest()
    target_names = manifest.get("target_names", TARGET_NAMES)
    featurizer = FingerprintFeaturizer(manifest.get("feature_config", {"type": "ecfp"}))
    batch, features = featurizer.featurize_smiles(smiles_list)

    first_pass = _compute_stage1_predictions(features, target_names)
    final_scores = _compute_stage2_predictions(features, first_pass, target_names)
    has_scores = bool(final_scores.size)

    output: Dict[str, Dict[str, float]] = {}
    # Score rows exist only for valid molecules, so they are tracked with a
    # separate cursor rather than the position in smiles_list.
    cursor = 0
    for smiles, ok in zip(smiles_list, batch.mask):
        if ok:
            if has_scores:
                scores = final_scores[cursor]
            else:
                scores = np.full(len(target_names), BASE_PREDICTION)
            output[smiles] = {name: float(value) for name, value in zip(target_names, scores)}
            cursor += 1
        else:
            output[smiles] = {name: BASE_PREDICTION for name in target_names}
    return output