File size: 2,081 Bytes
f484830
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from torch_geometric.data import Batch
from torch_geometric.utils import from_rdmol
import torch

from src.model import GIN
from src.preprocess import create_clean_mol_objects
from src.seed import set_seed

def predict_from_smiles(smiles_list):
    """
    Predict toxicity targets for a list of SMILES strings.

    A GIN model is built, its weights are loaded from
    ``./assets/best_gin_model.pt``, and each SMILES string is scored
    individually (one graph per forward pass).

    Args:
        smiles_list (list[str]): SMILES strings

    Returns:
        dict: {smiles: {target_name: prediction_prob}}. Duplicate SMILES
        collapse to a single key. If processing a SMILES string raises
        any exception, its entry maps every target to 0.0 and the failure
        is printed, so an all-zero row may mean "failed" rather than
        "predicted non-toxic".

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise when the
        checkpoint is missing or does not match the model; only the
        per-SMILES work is guarded by the fallback.
    """
    # NOTE(review): presumably fixes RNG seeds for reproducibility; confirm
    # against src.seed.
    set_seed(42)
    # tox21 targets
    TARGET_NAMES = [
        "NR-AR",
        "NR-AR-LBD",
        "NR-AhR",
        "NR-Aromatase",
        "NR-ER",
        "NR-ER-LBD",
        "NR-PPAR-gamma",
        "SR-ARE",
        "SR-ATAD5",
        "SR-HSE",
        "SR-MMP",
        "SR-p53",
    ]
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Received {len(smiles_list)} SMILES strings")

    # setup model
    model = GIN(num_features=9, num_classes=12, dropout=0.1, hidden_dim=128, num_layers=5, add_or_mean="mean")
    model_path = "./assets/best_gin_model.pt"
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    print(f"Loaded model from {model_path}")
    model.to(DEVICE)
    model.eval()
    predictions = {}

    for smiles in smiles_list:
        try:
            # Convert SMILES to graph. Only the first element of the first
            # return value is used; an empty result raises IndexError and
            # falls through to the zero fallback below.
            mol, _ = create_clean_mol_objects([smiles])
            data = from_rdmol(mol[0]).to(DEVICE)
            batch = Batch.from_data_list([data])

            # Forward pass
            with torch.no_grad():
                logits = model(batch.x, batch.edge_index, batch.batch)
                probs = torch.sigmoid(logits).cpu().numpy().flatten()

            # Map predictions to targets
            pred_dict = {t: float(p) for t, p in zip(TARGET_NAMES, probs)}

        except Exception as e:
            # Best-effort: keep the zero fallback so one bad SMILES does
            # not abort the whole request, but report the failure instead
            # of hiding it.
            print(f"Failed to predict for SMILES {smiles!r}: {type(e).__name__}: {e}")
            pred_dict = {t: 0.0 for t in TARGET_NAMES}

        predictions[smiles] = pred_dict

    return predictions