| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
| from typing import Dict, Optional, Tuple |
|
|
| from src.predictor import SingleTaskEnsemblePredictor |
| from src.predictor_multitask import MultiTaskEnsemblePredictor |
|
|
|
|
class RouterPredictor:
    """
    Route each property to one of two ensemble predictors:

    - the single-task ensemble (``models/single_models``), or
    - the multitask ensemble (``models/multitask_models/{task}_*``),

    according to the routing table in ``models/best_model_map.json``.

    The routing table is a JSON object mapping a property name to a config
    dict such as ``{"family": "multitask", "task": "all"}``. Property lookup
    is case-insensitive. A property that is absent from the table, or whose
    ``family`` is anything other than ``"multitask"``, is served by the
    single-task ensemble.
    """

    def __init__(
        self,
        map_path: str = "models/best_model_map.json",
        single_dir: str = "models/single_models",
        multitask_dir: str = "models/multitask_models",
        device: str = "cpu",
    ):
        """
        Load the routing table and build both ensemble predictors.

        Args:
            map_path: Path to the JSON routing table.
            single_dir: Directory passed as ``models_dir`` to
                ``SingleTaskEnsemblePredictor``.
            multitask_dir: Directory passed as ``models_dir`` to
                ``MultiTaskEnsemblePredictor``.
            device: Device string forwarded to both predictors.

        Raises:
            OSError: If ``map_path`` cannot be opened.
            ValueError: If the file is not valid JSON (``json.JSONDecodeError``)
                or its top-level value is not a JSON object.
        """
        self.map_path = Path(map_path)
        self.map: Dict[str, dict] = self._load_map(self.map_path)
        # Both predictors are constructed up front, regardless of which
        # families the routing table actually uses.
        self.single = SingleTaskEnsemblePredictor(models_dir=single_dir, device=device)
        self.multi = MultiTaskEnsemblePredictor(models_dir=multitask_dir, device=device)

    @staticmethod
    def _load_map(path: Path) -> Dict[str, dict]:
        """Read the JSON routing table at *path*, always closing the file."""
        with open(path, encoding="utf-8") as fh:
            table = json.load(fh)
        if not isinstance(table, dict):
            raise ValueError(
                f"{path}: routing table must be a JSON object, "
                f"got {type(table).__name__}"
            )
        return table

    def _route(self, prop: str) -> Tuple[str, str]:
        """
        Return ``(family, task)`` for *prop*, both lowercased.

        ``family`` defaults to ``"single"`` and ``task`` to ``"all"`` when the
        property is missing from the table or the field is missing/null/empty.

        Raises:
            ValueError: If the table entry for *prop* is not a JSON object.
        """
        key = prop.lower()
        cfg = self.map.get(key)
        if cfg is None:
            # Case-insensitive fallback: the lookup key is lowercased, but the
            # JSON file may use mixed-case keys such as "LogP".
            cfg = next(
                (value for name, value in self.map.items() if str(name).lower() == key),
                None,
            )
        if cfg is None:
            return "single", "all"
        if not isinstance(cfg, dict):
            raise ValueError(
                f"routing entry for {prop!r} must be a JSON object, got {cfg!r}"
            )
        # `or` (rather than a .get default) also covers explicit nulls/empties,
        # which would otherwise crash (.lower() on None) or yield task "none".
        family = str(cfg.get("family") or "single").strip().lower()
        task = str(cfg.get("task") or "all").strip().lower()
        return family, task

    def predict_mean_std(
        self, smiles: str, prop: str
    ) -> Tuple[Optional[float], Optional[float], dict, str]:
        """
        Predict *prop* for *smiles* with whichever ensemble the table selects.

        Args:
            smiles: Input string, forwarded unchanged to the chosen predictor.
            prop: Property name. It is lowercased before the table lookup and
                the lowercased form is what gets forwarded to the predictor.

        Returns:
            ``(mean, std, per_seed, label)``: the first three values are
            returned as-is from the chosen predictor's ``predict_mean_std``;
            ``label`` is ``"single"`` or ``"multitask:<task>"``.
        """
        prop = prop.lower()
        family, task = self._route(prop)
        if family == "multitask":
            mean, std, per_seed = self.multi.predict_mean_std(smiles, prop_key=prop, task=task)
            return mean, std, per_seed, f"multitask:{task}"
        # Every family other than "multitask" (including unrecognised values)
        # is served by the single-task ensemble.
        mean, std, per_seed = self.single.predict_mean_std(smiles, prop)
        return mean, std, per_seed, "single"
|
|