import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from rdkit import Chem
from rdkit.Chem.MolStandardize import rdMolStandardize
from torch_geometric.data import InMemoryDataset
from torch_geometric.utils import from_rdmol

def get_tox21_split(token, cvfold=None):
    """Load the ml-jku/tox21 dataset and return train/validation dataframes.

    If `cvfold` is given, the published splits are merged and re-split along
    the CVfold column; otherwise the published split is returned unchanged.
    """
    ds = load_dataset("ml-jku/tox21", token=token)
    train_df = ds["train"].to_pandas()
    val_df = ds["validation"].to_pandas()
    if cvfold is None:
        return {
            "train": train_df,
            "validation": val_df,
        }
    # Create new splits along the requested cross-validation fold.
    combined_df = pd.concat([train_df, val_df], ignore_index=True)
    cvfold = float(cvfold)
    train_df = combined_df[combined_df.CVfold != cvfold]
    val_df = combined_df[combined_df.CVfold == cvfold]
    # Exclude training molecules that also occur in the validation split.
    val_inchikeys = set(val_df["inchikey"])
    train_df = train_df[~train_df["inchikey"].isin(val_inchikeys)]
    return {"train": train_df.reset_index(drop=True), "validation": val_df.reset_index(drop=True)}

def create_clean_mol_objects(smiles: list[str]) -> tuple[list[Chem.Mol], np.ndarray]:
    """Create cleaned RDKit Mol objects from SMILES.

    Returns (list of mols, boolean mask marking which inputs survived cleaning).
    """
    clean_mol_mask = []
    mols = []
    # Standardizer components
    cleanup_params = rdMolStandardize.CleanupParameters()
    tautomer_enumerator = rdMolStandardize.TautomerEnumerator()
    for smi in smiles:
        try:
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                clean_mol_mask.append(False)
                continue
            # Cleanup and canonicalize the tautomer.
            mol = rdMolStandardize.Cleanup(mol, cleanup_params)
            mol = tautomer_enumerator.Canonicalize(mol)
            # Round-trip through canonical SMILES to get a freshly parsed Mol.
            can_smi = Chem.MolToSmiles(mol)
            mol = Chem.MolFromSmiles(can_smi)
            if mol is not None:
                mols.append(mol)
                clean_mol_mask.append(True)
            else:
                clean_mol_mask.append(False)
        except Exception as e:
            print(f"Failed to standardize {smi}: {e}")
            clean_mol_mask.append(False)
    return mols, np.array(clean_mol_mask, dtype=bool)
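
# Example (a sketch): standardize a small batch of SMILES. The boolean mask
# is aligned with the input list, so it can be used to filter parallel
# arrays such as label matrices.
#
#   mols, mask = create_clean_mol_objects(["CCO", "c1ccccc1O", "not_a_smiles"])
#   assert len(mols) == int(mask.sum())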

class Tox21Dataset(InMemoryDataset):
    """In-memory PyTorch Geometric dataset built from a Tox21 dataframe."""

    def __init__(self, dataframe):
        super().__init__()
        data_list = []
        # Clean molecules and filter the dataframe accordingly.
        mols, clean_mask = create_clean_mol_objects(dataframe["smiles"].tolist())
        self.clean_mask = torch.tensor(clean_mask, dtype=torch.bool)
        drop_cols = ["ID", "smiles", "inchikey", "sdftitle", "order", "set", "CVfold"]
        # Labels and label masks for *all* input rows, including molecules
        # that are dropped by standardization below.
        labels_df = dataframe.drop(columns=drop_cols)
        numeric_labels = labels_df.apply(pd.to_numeric, errors="coerce").fillna(0.0)
        self.all_labels = torch.tensor(numeric_labels.values, dtype=torch.float)
        self.all_label_masks = torch.tensor(~labels_df.isna().values, dtype=torch.bool)
        dataframe = dataframe[clean_mask].reset_index(drop=True)
        # Now mols and dataframe are aligned, so we can zip them.
        for mol, (_, row) in zip(mols, dataframe.iterrows()):
            try:
                data = from_rdmol(mol)
                # Extract the label columns as a pandas Series.
                labels = row.drop(drop_cols)
                # Mask marking which labels are actually measured.
                mask = ~labels.isna()
                # Explicit numeric conversion; replaces NaN with 0.0 safely.
                labels = pd.to_numeric(labels, errors="coerce").fillna(0.0).astype(float).values
                # Convert to tensors with a leading batch dimension.
                data.y = torch.tensor(labels, dtype=torch.float).unsqueeze(0)
                data.mask = torch.tensor(mask.values, dtype=torch.bool).unsqueeze(0)
                data_list.append(data)
            except Exception as e:
                print(f"Skipping molecule {row['smiles']} due to error: {e}")
        # Collate the individual graphs into the dataset's storage format.
        self.data, self.slices = self.collate(data_list)
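
# A minimal sketch of how the per-graph `y`/`mask` attributes produced above
# can be consumed during training: BCE-with-logits evaluated only on the
# labelled entries. This helper is illustrative, not part of the original
# pipeline.
def masked_bce_loss(logits: torch.Tensor, y: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean binary cross-entropy over the labelled (mask == True) entries."""
    per_entry = torch.nn.functional.binary_cross_entropy_with_logits(logits, y, reduction="none")
    return per_entry[mask].mean()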

def get_graph_datasets(token):
    """Build train/validation InMemoryDatasets for use with dataloaders.

    Args:
        token (str): Hugging Face access token for the ml-jku/tox21 dataset.

    Returns:
        tuple[Tox21Dataset, Tox21Dataset]: train and validation datasets,
        split along cross-validation fold 4.
    """
    datasets = get_tox21_split(token, cvfold=4)
    train_df, val_df = datasets["train"], datasets["validation"]
    train_dataset = Tox21Dataset(train_df)
    val_dataset = Tox21Dataset(val_df)
    return train_dataset, val_dataset
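
# Quick smoke test (a sketch; assumes a valid token in the HF_TOKEN
# environment variable). Iterates one mini-batch to confirm that graphs,
# labels, and label masks batch together correctly.
if __name__ == "__main__":
    import os

    from torch_geometric.loader import DataLoader

    train_dataset, val_dataset = get_graph_datasets(os.environ["HF_TOKEN"])
    loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    batch = next(iter(loader))
    print(batch)  # e.g. DataBatch(..., y=[32, 12], mask=[32, 12])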