Spaces:
Sleeping
Sleeping
import csv
import os
import subprocess
from collections import defaultdict, deque

import torch

from src.preprocess import create_clean_smiles
# Tox21 assay targets; used as the keys of the placeholder predictions
# returned for SMILES that fail cleaning.
TOX21_TARGETS = (
    "NR-AhR", "NR-AR", "NR-AR-LBD", "NR-Aromatase", "NR-ER", "NR-ER-LBD",
    "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53",
)


def _write_smiles_csv(path, smiles):
    """Write *smiles* to *path* as a one-column CSV with a ``smiles`` header."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["smiles"])  # header
        writer.writerows([smi] for smi in smiles)


def _run_chemprop(test_path, model_path, preds_path):
    """Run ``chemprop predict`` on *test_path*, writing predictions to *preds_path*.

    Raises:
        subprocess.CalledProcessError: if chemprop exits with a non-zero status.
    """
    command = [
        "chemprop", "predict",
        "--test-path", test_path,
        "--model-path", model_path,
        "--smiles-columns", "smiles",
        "--preds-path", preds_path,
    ]
    subprocess.run(command, check=True)


def _read_predictions(preds_path, clean_smiles, originals):
    """Parse chemprop's prediction CSV into ``{original_smiles: {target: prob}}``.

    Each cleaned SMILES is mapped back to its original input string. A FIFO
    queue per cleaned SMILES lets several distinct inputs that clean to the
    same string each receive their own prediction row, instead of the last one
    overwriting the others. A row whose ``smiles`` value is not recognised is
    keyed by that value itself.
    """
    pending = defaultdict(deque)
    for clean, orig in zip(clean_smiles, originals):
        pending[clean].append(orig)

    results = {}
    with open(preds_path, "r", newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        # Every non-"smiles" column is treated as a model output.
        target_cols = [col for col in (reader.fieldnames or []) if col != "smiles"]
        for row in reader:
            clean = row["smiles"]
            queue = pending.get(clean)  # .get: do not create empty entries
            original = queue.popleft() if queue else clean
            results[original] = {t: float(row[t]) for t in target_cols}
    return results


def predict(smiles_list, *, model_path="checkpoints/model.pt", data_dir="data"):
    """
    Predict toxicity targets for a list of SMILES strings.

    Valid SMILES (as decided by ``create_clean_smiles``) are written to
    ``<data_dir>/smiles.csv`` and scored with the ``chemprop predict`` CLI.
    Predictions are read back from ``<data_dir>/preds.csv``. Invalid SMILES
    receive a placeholder probability of 0.5 for every Tox21 target.

    Args:
        smiles_list (list[str]): SMILES strings.
        model_path (str): chemprop checkpoint to load (keyword-only).
        data_dir (str): directory for the intermediate CSV files; created if
            missing (keyword-only).

    Returns:
        dict: {smiles: {target_name: prediction_prob}}, keyed by the original
        (uncleaned) input strings.

    Raises:
        ValueError: if the cleaner's output is inconsistent with its mask.
        subprocess.CalledProcessError: if the chemprop run fails.
    """
    clean_smiles, valid_mask = create_clean_smiles(smiles_list)
    # Originals that survived cleaning, in the same order as clean_smiles.
    originals_valid = [orig for orig, ok in zip(smiles_list, valid_mask) if ok]
    if len(originals_valid) != len(clean_smiles):
        raise ValueError(
            f"Mismatch: {len(originals_valid)} valid originals vs {len(clean_smiles)} cleaned SMILES"
        )

    print(f"Received {len(smiles_list)} SMILES strings")

    predictions = {}
    # Skip the external model entirely when nothing is valid; chemprop would
    # otherwise be run on a header-only CSV.
    if clean_smiles:
        os.makedirs(data_dir, exist_ok=True)
        test_path = os.path.join(data_dir, "smiles.csv")
        preds_path = os.path.join(data_dir, "preds.csv")
        _write_smiles_csv(test_path, clean_smiles)
        _run_chemprop(test_path, model_path, preds_path)
        predictions.update(_read_predictions(preds_path, clean_smiles, originals_valid))

    # Add placeholder predictions for invalid SMILES.
    for smi, is_valid in zip(smiles_list, valid_mask):
        if not is_valid:
            predictions[smi] = {t: 0.5 for t in TOX21_TARGETS}
    return predictions