Spaces:
Sleeping
Sleeping
| from torch_geometric.data import Batch | |
| from torch_geometric.utils import from_rdmol | |
| import torch | |
| from src.model import GIN | |
| from src.preprocess import create_clean_mol_objects | |
| from src.seed import set_seed | |
def predict_from_smiles(smiles_list):
    """Predict toxicity targets for a list of SMILES strings.

    The GIN checkpoint at ``./assets/best_gin_model.pt`` is loaded on every
    call and each SMILES string is run through the model on its own, so a
    failure on one molecule does not affect the others.

    Args:
        smiles_list (list[str]): SMILES strings.

    Returns:
        dict: ``{smiles: {target_name: prediction_prob}}`` with one entry per
        distinct SMILES string (duplicates share a single key). A SMILES
        string that fails anywhere between preprocessing and the forward
        pass is mapped to ``0.0`` for every target; the failure is printed.
    """
    set_seed(42)

    # tox21 targets
    TARGET_NAMES = [
        "NR-AR",
        "NR-AR-LBD",
        "NR-AhR",
        "NR-Aromatase",
        "NR-ER",
        "NR-ER-LBD",
        "NR-PPAR-gamma",
        "SR-ARE",
        "SR-ATAD5",
        "SR-HSE",
        "SR-MMP",
        "SR-p53",
    ]
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Received {len(smiles_list)} SMILES strings")

    # setup model
    model = GIN(
        num_features=9,
        num_classes=12,
        dropout=0.1,
        hidden_dim=128,
        num_layers=5,
        add_or_mean="mean",
    )
    model_path = "./assets/best_gin_model.pt"
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    print(f"Loaded model from {model_path}")
    model.to(DEVICE)
    model.eval()

    predictions = {}
    for smiles in smiles_list:
        try:
            # Convert SMILES to graph
            mol, _ = create_clean_mol_objects([smiles])
            data = from_rdmol(mol[0]).to(DEVICE)
            batch = Batch.from_data_list([data])

            # Forward pass
            with torch.no_grad():
                logits = model(batch.x, batch.edge_index, batch.batch)
            probs = torch.sigmoid(logits).cpu().numpy().flatten()

            # Map predictions to targets
            predictions[smiles] = {
                name: float(prob) for name, prob in zip(TARGET_NAMES, probs)
            }
        except Exception as e:
            # Deliberate best-effort: a failing SMILES still gets an entry so
            # the output covers every input. The all-zero row looks the same
            # as a confident "non-toxic" prediction, so report the failure
            # instead of swallowing it silently.
            print(f"Prediction failed for SMILES {smiles!r}: {type(e).__name__}: {e}")
            predictions[smiles] = {name: 0.0 for name in TARGET_NAMES}
    return predictions