Spaces:
Sleeping
Sleeping
File size: 2,081 Bytes
f484830 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
from torch_geometric.data import Batch
from torch_geometric.utils import from_rdmol
import torch
from src.model import GIN
from src.preprocess import create_clean_mol_objects
from src.seed import set_seed
def predict_from_smiles(smiles_list):
    """
    Predict toxicity targets for a list of SMILES strings.

    The GIN model is rebuilt and its weights are loaded from
    ``./assets/best_gin_model.pt`` (relative to the current working
    directory) on every call. Each SMILES string is converted to a graph and
    scored on its own, as a batch containing a single graph.

    Args:
        smiles_list (list[str]): SMILES strings

    Returns:
        dict: {smiles: {target_name: prediction_prob}}. Probabilities are the
        sigmoid of the model output. A SMILES string that cannot be processed
        maps to 0.0 for every target; the failure is printed so it can be
        told apart from a genuine prediction. Repeated SMILES strings share
        one entry because the SMILES string itself is the key.
    """
    set_seed(42)  # fixed seed; presumably for reproducible runs (see src.seed)

    # tox21 targets; paired positionally with the model outputs below
    target_names = [
        "NR-AR",
        "NR-AR-LBD",
        "NR-AhR",
        "NR-Aromatase",
        "NR-ER",
        "NR-ER-LBD",
        "NR-PPAR-gamma",
        "SR-ARE",
        "SR-ATAD5",
        "SR-HSE",
        "SR-MMP",
        "SR-p53",
    ]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Received {len(smiles_list)} SMILES strings")

    # setup model
    model = GIN(
        num_features=9,
        num_classes=12,
        dropout=0.1,
        hidden_dim=128,
        num_layers=5,
        add_or_mean="mean",
    )
    model_path = "./assets/best_gin_model.pt"
    # NOTE(review): torch.load unpickles the checkpoint, so only trusted files
    # should be loaded here; consider weights_only=True if the installed torch
    # version supports it.
    model.load_state_dict(torch.load(model_path, map_location=device))
    print(f"Loaded model from {model_path}")
    model.to(device)
    model.eval()

    predictions = {}
    for smiles in smiles_list:
        try:
            # Convert SMILES to graph
            mols, _ = create_clean_mol_objects([smiles])
            data = from_rdmol(mols[0]).to(device)
            batch = Batch.from_data_list([data])

            # Forward pass
            with torch.no_grad():
                logits = model(batch.x, batch.edge_index, batch.batch)
                probs = torch.sigmoid(logits).cpu().numpy().flatten()

            # Map predictions to targets
            predictions[smiles] = {
                name: float(prob) for name, prob in zip(target_names, probs)
            }
        except Exception as exc:
            # Best-effort contract: one bad SMILES must not abort the whole
            # request, so fall back to zeros -- but report the failure instead
            # of hiding it, since zeros alone look like a valid prediction.
            print(f"Failed to predict for SMILES {smiles!r}: {exc!r}")
            predictions[smiles] = dict.fromkeys(target_names, 0.0)

    return predictions
|