Sonja Topf committed on
Commit
f0bc9a8
Β·
1 Parent(s): e4f4cf0

big refactoring

Browse files
.example.env ADDED
@@ -0,0 +1 @@
 
 
1
+ TOKEN=example_token
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ tox21_test.csv
2
+ results.csv
3
+ predict copy.py
4
+ hp_search/logs/*
5
+ hp_search/models/*
6
+ __pycache__
7
+ .env
8
+ notes.txt
{assets β†’ checkpoints}/best_gin_model.pt RENAMED
File without changes
checkpoints/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:39696815c68d20cae615f3e271f5d406d75b866a930851a43f6506f3a593282c
3
+ size 628746
config/config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "lr": 0.0001,
3
+ "dropout": 0.1,
4
+ "hidden_dim": 128,
5
+ "num_layers": 5,
6
+ "patience": 100,
7
+ "max_epochs": 200,
8
+ "batch_size": 64,
9
+ "seed": 0,
10
+ "add_or_mean": "mean",
11
+ "window_size": 15
12
+ }
predict.py CHANGED
@@ -2,9 +2,9 @@ from torch_geometric.data import Batch
2
  from torch_geometric.utils import from_rdmol
3
  import torch
4
 
5
- from src.model import GIN
6
- from src.preprocess import create_clean_mol_objects
7
- from src.seed import set_seed
8
 
9
  def predict(smiles_list):
10
  """
@@ -26,7 +26,7 @@ def predict(smiles_list):
26
 
27
  # setup model
28
  model = GIN(num_features=9, num_classes=12, dropout=0.1, hidden_dim=128, num_layers=5, add_or_mean="mean")
29
- model_path = "./assets/best_gin_model.pt"
30
  model.load_state_dict(torch.load(model_path, map_location=DEVICE))
31
  print(f"Loaded model from {model_path}")
32
  model.to(DEVICE)
@@ -54,4 +54,4 @@ def predict(smiles_list):
54
  pred_dict = {t: 0.5 for t in TARGET_NAMES}
55
  predictions[smiles] = pred_dict
56
 
57
- return predictions
 
2
  from torch_geometric.utils import from_rdmol
3
  import torch
4
 
5
+ from utils.model import GIN
6
+ from utils.preprocess import create_clean_mol_objects
7
+ from utils.seed import set_seed
8
 
9
  def predict(smiles_list):
10
  """
 
26
 
27
  # setup model
28
  model = GIN(num_features=9, num_classes=12, dropout=0.1, hidden_dim=128, num_layers=5, add_or_mean="mean")
29
+ model_path = "./checkpoints/model.pt"
30
  model.load_state_dict(torch.load(model_path, map_location=DEVICE))
31
  print(f"Loaded model from {model_path}")
32
  model.to(DEVICE)
 
54
  pred_dict = {t: 0.5 for t in TARGET_NAMES}
55
  predictions[smiles] = pred_dict
56
 
57
+ return predictions
train.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch_geometric.loader import DataLoader
3
+ import torch_geometric
4
+ import numpy as np
5
+ import json
6
+ import os
7
+ from dotenv import load_dotenv
8
+
9
+ from utils.model import GIN
10
+ from utils.preprocess import get_graph_datasets
11
+ from utils.train_evaluate import train_model, evaluate, compute_roc_auc_avg_and_per_class
12
+ from utils.seed import set_seed
13
+
14
+
15
def train(config):
    """Train a GIN on Tox21 with rolling-mean AUC model selection.

    Each epoch's validation ROC-AUC is appended to a history; once a full
    `window_size`-wide window exists, its mean is compared against the best
    mean so far, and the checkpoint of the epoch at the *center* of the best
    window is saved. Early stopping triggers after `patience` epochs without
    the best center epoch improving.

    Args:
        config: dict with keys "seed", "batch_size", "dropout", "hidden_dim",
            "num_layers", "add_or_mean", "lr", "max_epochs", "patience",
            "window_size" (see config/config.json).
    """
    SEED = config["seed"]
    set_seed(SEED)
    best_model_path = "./checkpoints/model.pt"

    # get dataloaders
    print("Loading Datasets...")
    torch_geometric.seed_everything(SEED)
    token = os.getenv("TOKEN")  # HF token; load_dotenv() must have run first
    train_dataset, val_dataset = get_graph_datasets(token)
    train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config["batch_size"])

    # initialize model and optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GIN(num_features=9, num_classes=12, dropout=config["dropout"],
                hidden_dim=config["hidden_dim"], num_layers=config["num_layers"],
                add_or_mean=config["add_or_mean"]).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])

    # training loop
    best_mean_auc = -float("inf")
    best_mean_epoch = 0
    aucs = []
    window_size = config["window_size"]
    # epoch -> state_dict cached on CPU, so the center-of-window model can be
    # saved retroactively without holding checkpoints in GPU memory
    epoch_checkpoints = {}
    print("Starting Training...")

    for epoch in range(0, config["max_epochs"]):
        train_loss = train_model(model, train_loader, optimizer, device)
        val_loss = evaluate(model, val_loader, device)
        val_auc_per_class, val_auc_avg = compute_roc_auc_avg_and_per_class(model, val_loader, device)
        aucs.append(val_auc_avg)

        # log every 10th epoch
        if epoch % 10 == 0:
            print(f"Epoch {epoch:03d} | "
                  f"Train Loss: {train_loss:.4f} | "
                  f"Val Loss: {val_loss:.4f} | "
                  f"Val ROC-AUC: {val_auc_avg:.4f}")

        # store model parameters for this epoch in cache (on CPU to save GPU memory)
        epoch_checkpoints[epoch] = {k: v.cpu() for k, v in model.state_dict().items()}

        # keep cache size limited: only the last window (plus slack) is needed
        if len(epoch_checkpoints) > window_size + 2:
            oldest = min(epoch_checkpoints.keys())
            del epoch_checkpoints[oldest]

        # once we have enough epochs, compute rolling mean
        if len(aucs) >= window_size:
            current_window = aucs[-window_size:]
            current_mean_auc = np.mean(current_window)
            middle_epoch = epoch - window_size // 2

            # check if current mean beats the best so far
            if current_mean_auc > best_mean_auc:
                best_mean_auc = current_mean_auc
                best_mean_epoch = middle_epoch

                # save only the best middle model
                if middle_epoch in epoch_checkpoints:
                    torch.save(epoch_checkpoints[middle_epoch], best_model_path)
                    print(f"🟢 New best mean AUC = {best_mean_auc:.4f} "
                          f"(center epoch {best_mean_epoch}) — model saved!")

        # early stopping based on best mean epoch
        if epoch - best_mean_epoch >= config["patience"]:
            print(f"⛔ Early stopping at epoch {epoch}. "
                  f"Best mean AUC = {best_mean_auc:.4f} (center epoch {best_mean_epoch})")
            break

    # fix: the original string concatenation printed values with no separator,
    # e.g. "best_smoothed_val_auc0.85, best_middle_epoch42"
    print(f"best_smoothed_val_auc: {best_mean_auc}, best_middle_epoch: {best_mean_epoch}")
86
+
87
if __name__ == "__main__":
    # environment variables (TOKEN) must be in place before train() reads them
    load_dotenv()
    with open("./config/config.json", "r") as f:
        cfg = json.load(f)
    train(cfg)
{src β†’ utils}/model.py RENAMED
@@ -7,7 +7,7 @@ import numpy as np
7
 
8
 
9
  class GIN(torch.nn.Module):
10
- def __init__(self, num_features, num_classes, dropout, hidden_dim=64, num_layers=5, add_or_mean="add"):
11
  super().__init__()
12
  self.num_layers = num_layers
13
  self.hidden_dim = hidden_dim
 
7
 
8
 
9
  class GIN(torch.nn.Module):
10
+ def __init__(self, num_features, num_classes, dropout, hidden_dim=128, num_layers=5, add_or_mean="add"):
11
  super().__init__()
12
  self.num_layers = num_layers
13
  self.hidden_dim = hidden_dim
{src β†’ utils}/preprocess.py RENAMED
@@ -7,6 +7,33 @@ from rdkit.Chem.MolStandardize import rdMolStandardize
7
  from rdkit import Chem
8
  from torch_geometric.data import InMemoryDataset
9
  from torch_geometric.utils import from_rdmol
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  def create_clean_mol_objects(smiles: list[str]) -> tuple[list[Chem.Mol], np.ndarray]:
12
  """Create cleaned RDKit Mol objects from SMILES.
@@ -87,7 +114,7 @@ class Tox21Dataset(InMemoryDataset):
87
  self.data, self.slices = self.collate(data_list)
88
 
89
 
90
- def get_graph_dataset(filepath:str):
91
  """returns an InMemoryDataset that can be used in dataloaders
92
 
93
  Args:
@@ -96,6 +123,8 @@ def get_graph_dataset(filepath:str):
96
  Returns:
97
  Tox21Dataset: dataset for dataloaders
98
  """
99
- df = pd.read_csv(filepath)
100
- dataset = Tox21Dataset(df)
101
- return dataset
 
 
 
7
  from rdkit import Chem
8
  from torch_geometric.data import InMemoryDataset
9
  from torch_geometric.utils import from_rdmol
10
+ from datasets import load_dataset
11
+
12
def get_tox21_split(token, cvfold=None):
    """Load the Tox21 dataset and return train/validation DataFrames.

    Args:
        token: Hugging Face access token for the "tschouis/tox21" dataset.
        cvfold: optional fold id. When None, the dataset's native
            train/validation split is returned unchanged. Otherwise the two
            splits are pooled and re-split so rows with CVfold == cvfold form
            the validation set.

    Returns:
        dict: {"train": pd.DataFrame, "validation": pd.DataFrame}.
    """
    ds = load_dataset("tschouis/tox21", token=token)

    train_df = ds["train"].to_pandas()
    val_df = ds["validation"].to_pandas()

    if cvfold is None:
        return {
            "train": train_df,
            "validation": val_df
        }

    # create new splits from the combined pool
    # fix: cvfold was converted to float twice in the original
    combined_df = pd.concat([train_df, val_df], ignore_index=True)
    cvfold = float(cvfold)
    train_df = combined_df[combined_df.CVfold != cvfold]
    val_df = combined_df[combined_df.CVfold == cvfold]

    # exclude train mols that occur in the validation split
    val_inchikeys = set(val_df["inchikey"])
    train_df = train_df[~train_df["inchikey"].isin(val_inchikeys)]

    return {"train": train_df.reset_index(drop=True), "validation": val_df.reset_index(drop=True)}
37
 
38
  def create_clean_mol_objects(smiles: list[str]) -> tuple[list[Chem.Mol], np.ndarray]:
39
  """Create cleaned RDKit Mol objects from SMILES.
 
114
  self.data, self.slices = self.collate(data_list)
115
 
116
 
117
def get_graph_datasets(token):
    """Build the train and validation graph datasets for dataloaders.

    Fixed the stale docstring: this function no longer takes a filepath and
    returns a pair of datasets, not a single one.

    Args:
        token: Hugging Face access token forwarded to get_tox21_split.

    Returns:
        tuple[Tox21Dataset, Tox21Dataset]: (train_dataset, val_dataset),
        split using CV fold 4 as validation.
    """
    datasets = get_tox21_split(token, cvfold=4)
    train_df, val_df = datasets["train"], datasets["validation"]
    train_dataset = Tox21Dataset(train_df)
    val_dataset = Tox21Dataset(val_df)
    return train_dataset, val_dataset
{src β†’ utils}/seed.py RENAMED
File without changes
utils/train_evaluate.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from sklearn.metrics import roc_auc_score
5
+
6
def masked_bce_loss(logits, labels, mask):
    """Masked multi-label BCE, averaged over valid entries only.

    Args:
        logits: [batch_size, num_classes] raw model outputs.
        labels: [batch_size, num_classes] 0/1 targets (filler where invalid).
        mask: [batch_size, num_classes] True where the label is valid.

    Returns:
        Scalar tensor: mean BCE over masked-in entries. Returns 0.0 when the
        mask selects nothing (fix: previously divided 0/0 and returned NaN,
        which would poison the accumulated epoch loss).
    """
    # functional form avoids constructing a fresh criterion module per call
    per_entry = nn.functional.binary_cross_entropy_with_logits(
        logits, labels, reduction="none"
    )
    valid = mask.float()
    denom = valid.sum()
    if denom == 0:
        # no valid labels in this batch: contribute zero loss, keep grad graph
        return logits.sum() * 0.0
    return (per_entry * valid).sum() / denom
16
+
17
def train_model(model, loader, optimizer, device):
    """Run one optimization epoch; return the dataset-averaged masked BCE loss."""
    model.train()
    running = 0.0
    for graphs in loader:
        graphs = graphs.to(device)

        optimizer.zero_grad()
        logits = model(graphs.x, graphs.edge_index, graphs.batch)  # [num_graphs, num_classes]

        batch_loss = masked_bce_loss(logits, graphs.y, graphs.mask)
        batch_loss.backward()
        optimizer.step()

        # weight by graphs in the batch so the epoch average is per-graph
        running += batch_loss.item() * graphs.num_graphs
    return running / len(loader.dataset)
32
+
33
+
34
@torch.no_grad()
def evaluate(model, loader, device):
    """Return the dataset-averaged masked BCE loss, without gradient tracking."""
    model.eval()
    accumulated = 0.0
    for graphs in loader:
        graphs = graphs.to(device)
        logits = model(graphs.x, graphs.edge_index, graphs.batch)
        accumulated += masked_bce_loss(logits, graphs.y, graphs.mask).item() * graphs.num_graphs
    return accumulated / len(loader.dataset)
44
+
45
+
46
@torch.no_grad()
def compute_roc_auc(model, loader, device):
    """Mean ROC-AUC over labels that have at least one valid, two-class target.

    Labels whose valid entries contain only a single class are skipped
    (roc_auc_score raises ValueError there). Returns NaN if no label
    produced a defined AUC.
    """
    model.eval()
    preds, targets, masks = [], [], []

    for graphs in loader:
        graphs = graphs.to(device)
        logits = model(graphs.x, graphs.edge_index, graphs.batch)

        # sigmoid -> per-class probabilities, collected on CPU
        preds.append(torch.sigmoid(logits).cpu())
        targets.append(graphs.y.cpu())
        masks.append(graphs.mask.cpu())

    # stack all batches into flat arrays
    y_pred = torch.cat(preds, dim=0).numpy()
    y_true = torch.cat(targets, dim=0).numpy()
    valid = torch.cat(masks, dim=0).numpy()

    per_label = []
    for col in range(y_true.shape[1]):
        keep = valid[:, col].astype(bool)
        if not keep.any():  # no valid label for this class
            continue
        try:
            per_label.append(roc_auc_score(y_true[keep, col], y_pred[keep, col]))
        except ValueError:
            # only one class present (all 0 or all 1) -> AUC undefined
            continue

    return np.mean(per_label) if per_label else float("nan")
77
+
78
@torch.no_grad()
def compute_roc_auc_avg_and_per_class(model, loader, device):
    """Compute per-class ROC-AUCs and their NaN-ignoring mean over a loader.

    Classes with no valid labels, or whose valid labels contain only a single
    class, get NaN so the per-class array always has one slot per class.

    Args:
        model: module called as model(x, edge_index, batch) returning logits.
        loader: dataloader yielding batches with .x/.edge_index/.batch/.y/.mask.
        device: device to run inference on.

    Returns:
        tuple[np.ndarray, float]: (per-class AUC array [num_classes],
        mean AUC ignoring NaNs).
    """
    model.eval()
    y_true, y_pred, y_mask = [], [], []

    # fix: removed the redundant nested `with torch.no_grad():` -- the
    # decorator above already disables gradient tracking for the whole call
    for batch in loader:
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index, batch.batch)

        # Store predictions (sigmoid -> probabilities)
        y_pred.append(torch.sigmoid(out).cpu())
        y_true.append(batch.y.cpu())
        y_mask.append(batch.mask.cpu())

    # Concatenate across all batches
    y_true = torch.cat(y_true, dim=0).numpy()
    y_pred = torch.cat(y_pred, dim=0).numpy()
    y_mask = torch.cat(y_mask, dim=0).numpy()

    # Compute AUC per class
    auc_list = []
    for i in range(y_true.shape[1]):
        mask_i = y_mask[:, i].astype(bool)
        if mask_i.sum() > 0:
            try:
                auc = roc_auc_score(y_true[mask_i, i], y_pred[mask_i, i])
            except ValueError:
                auc = np.nan  # only one class present
        else:
            auc = np.nan  # no valid labels for this class
        auc_list.append(auc)

    auc_array = np.array(auc_list, dtype=np.float32)
    mean_auc = np.nanmean(auc_array)  # overall mean ignoring NaNs

    return auc_array, mean_auc