""" This files includes a predict function for the Tox21. As an input it takes a list of SMILES and it outputs a nested dictionary with SMILES and target names as keys. """ # --------------------------------------------------------------------------------------- # Dependencies from collections import defaultdict import numpy as np import json import joblib import torch from src.model import Tox21SNNClassifier, SNNConfig from src.preprocess import create_descriptors, FeaturePreprocessor from src.utils import TASKS, normalize_config # --------------------------------------------------------------------------------------- CONFIG_FILE = "./config/config.json" def predict( smiles_list: list[str], default_prediction=0.5 ) -> dict[str, dict[str, float]]: """Applies the classifier to a list of SMILES strings. Returns prediction=0.0 for any molecule that could not be cleaned. Args: smiles_list (list[str]): list of SMILES strings Returns: dict: nested prediction dictionary, following {'': {'': }} """ print(f"Received {len(smiles_list)} SMILES strings") # preprocessing pipeline with open(CONFIG_FILE, "r") as f: config = json.load(f) config = normalize_config(config) features, is_clean = create_descriptors( smiles_list, config["descriptors"], **config["ecfp"] ) print(f"Created descriptors for {sum(is_clean)} molecules.") print(f"{len(is_clean) - sum(is_clean)} molecules removed during cleaning") # setup model preprocessor = FeaturePreprocessor( feature_selection_config=config["feature_selection"], feature_quantilization_config=config["feature_quantilization"], descriptors=config["descriptors"], max_samples=config["max_samples"], scaler=config["scaler"], ) preprocessor_ckpt = joblib.load(config["preprocessor_path"]) preprocessor.set_state(preprocessor_ckpt["preprocessor"]) print(f"Loaded preprocessor from {config['preprocessor_path']}") features = {descr: array[is_clean] for descr, array in features.items()} features = preprocessor.transform(features) dataset = torch.utils.data.TensorDataset(torch.FloatTensor(features)) loader = torch.utils.data.DataLoader( dataset, batch_size=256, shuffle=False, num_workers=0 ) # setup model cfg = SNNConfig( hidden_dim=512, n_layers=8, dropout=0.05, layer_form="rect", in_features=features.shape[1], out_features=12, ) model = Tox21SNNClassifier(cfg) model.load_model(config["ckpt_path"]) model.eval() print(f"Loaded model from {config['ckpt_path']}") predictions = defaultdict(dict) print(f"Create predictions:") preds = [] with torch.no_grad(): preds = np.concatenate([model.predict(batch[0]) for batch in loader], axis=0) for i, target in enumerate(model.tasks): target_preds = np.empty_like(is_clean, dtype=float) target_preds[~is_clean] = default_prediction target_preds[is_clean] = preds[:, i] for smiles, pred in zip(smiles_list, target_preds): predictions[smiles][target] = float(pred) return predictions