Sonja Topf
renamed model file
0d7dfdb
import csv
import os
import subprocess

import torch

from src.preprocess import create_clean_smiles
# Tox21 target names; also the keys of the placeholder predictions returned
# for SMILES that fail cleaning.
_TARGET_NAMES = [
    "NR-AhR", "NR-AR", "NR-AR-LBD", "NR-Aromatase", "NR-ER", "NR-ER-LBD",
    "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53",
]
# Probability reported for every target of a SMILES that failed cleaning.
_PLACEHOLDER_PROB = 0.5
# Working files exchanged with the chemprop CLI (paths relative to the CWD).
_DATA_DIR = "data"
_SMILES_CSV = "data/smiles.csv"
_PREDS_CSV = "data/preds.csv"
_MODEL_PATH = "checkpoints/model.pt"


def _write_smiles_csv(smiles, path):
    """Write *smiles* to *path* as a one-column CSV with the header ``smiles``."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["smiles"])  # header, matched by --smiles-columns below
        writer.writerows([smi] for smi in smiles)


def _run_chemprop(input_path, preds_path, model_path):
    """Run ``chemprop predict`` on *input_path*, writing predictions to *preds_path*.

    Raises:
        subprocess.CalledProcessError: if chemprop exits with a non-zero status.
    """
    command = [
        "chemprop", "predict",
        "--test-path", input_path,
        "--model-path", model_path,
        "--smiles-columns", "smiles",
        "--preds-path", preds_path,
    ]
    # Argument list (no shell) so SMILES/paths are never shell-interpreted.
    subprocess.run(command, check=True)


def _read_predictions(preds_path):
    """Read chemprop's prediction CSV.

    Returns:
        list of ``(smiles, {column: float})`` tuples in file order, one per row;
        each dict holds every column except ``smiles``, parsed as float.
    """
    with open(preds_path, "r", newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        rows = list(reader)
        # fieldnames is None when the file is empty.
        target_names = [col for col in (reader.fieldnames or []) if col != "smiles"]
    return [(row["smiles"], {t: float(row[t]) for t in target_names}) for row in rows]


def predict(smiles_list):
    """
    Predict toxicity targets for a list of SMILES strings.

    SMILES accepted by ``create_clean_smiles`` are written (in cleaned form) to
    ``data/smiles.csv`` and scored by running ``chemprop predict`` with the
    checkpoint ``checkpoints/model.pt``; the predictions are read back from
    ``data/preds.csv``. SMILES rejected by cleaning get a placeholder value of
    0.5 for every Tox21 target. If no SMILES survive cleaning, chemprop is
    not run at all.

    Args:
        smiles_list (list[str]): SMILES strings

    Returns:
        dict: {smiles: {target_name: prediction_prob}}, keyed by the ORIGINAL
        (uncleaned) SMILES. Repeated input strings share a single entry.

    Raises:
        ValueError: if cleaning returns a different number of SMILES than it
            marks as valid.
        subprocess.CalledProcessError: if ``chemprop predict`` fails.
    """
    clean_smiles, valid_mask = create_clean_smiles(smiles_list)
    # Originals that survived cleaning, in the same order as clean_smiles.
    originals_valid = [orig for orig, ok in zip(smiles_list, valid_mask) if ok]
    if len(originals_valid) != len(clean_smiles):
        raise ValueError(
            f"Mismatch: {len(originals_valid)} valid originals vs {len(clean_smiles)} cleaned SMILES"
        )

    print(f"Received {len(smiles_list)} SMILES strings")

    predictions = {}
    # len() instead of truthiness in case create_clean_smiles returns an
    # array-like (unverified). Skipping avoids running chemprop on an empty CSV.
    if len(clean_smiles) > 0:
        os.makedirs(_DATA_DIR, exist_ok=True)
        _write_smiles_csv(clean_smiles, _SMILES_CSV)
        _run_chemprop(_SMILES_CSV, _PREDS_CSV, _MODEL_PATH)
        rows = _read_predictions(_PREDS_CSV)

        row_smiles = [smi for smi, _ in rows]
        if row_smiles == list(clean_smiles):
            # chemprop echoed the inputs in order: map back by position so that
            # distinct originals which clean to the same SMILES each keep an entry
            # (a cleaned->original dict would keep only the last of them).
            originals = originals_valid
        else:
            # Fall back to lookup by cleaned SMILES; unknown ones are kept as-is.
            cleaned_to_original = dict(zip(clean_smiles, originals_valid))
            originals = [cleaned_to_original.get(smi, smi) for smi in row_smiles]

        for original_smi, (_, pred_dict) in zip(originals, rows):
            predictions[original_smi] = pred_dict

    # Placeholder predictions for SMILES rejected by cleaning.
    for smi, is_valid in zip(smiles_list, valid_mask):
        if not is_valid:
            predictions[smi] = {t: _PLACEHOLDER_PROB for t in _TARGET_NAMES}
    return predictions