Commit f0bc9a8 · Sonja Topf
Parent(s): e4f4cf0

big refactoring
Files changed:
- .example.env +1 -0
- .gitignore +8 -0
- {assets → checkpoints}/best_gin_model.pt +0 -0
- checkpoints/model.pt +3 -0
- config/config.json +12 -0
- predict.py +5 -5
- train.py +91 -0
- {src → utils}/model.py +1 -1
- {src → utils}/preprocess.py +33 -4
- {src → utils}/seed.py +0 -0
- utils/train_evaluate.py +116 -0
.example.env ADDED
@@ -0,0 +1 @@
+TOKEN=example_token
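The new .example.env documents the one secret the code expects. A minimal sketch of how it is consumed at runtime (mirroring the load_dotenv()/os.getenv("TOKEN") pattern used by train.py below; the error message is illustrative):

    import os
    from dotenv import load_dotenv

    load_dotenv()               # reads key=value pairs from a local .env file
    token = os.getenv("TOKEN")  # returns None if the variable is not set
    if token is None:
        raise RuntimeError("TOKEN missing: copy .example.env to .env and fill it in")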
.gitignore ADDED
@@ -0,0 +1,8 @@
+tox21_test.csv
+results.csv
+predict copy.py
+hp_search/logs/*
+hp_search/models/*
+__pycache__
+.env
+notes.txt
{assets → checkpoints}/best_gin_model.pt RENAMED
File without changes
checkpoints/model.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:39696815c68d20cae615f3e271f5d406d75b866a930851a43f6506f3a593282c
+size 628746
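Only a Git LFS pointer is committed here; the actual 628,746-byte state dict lives in LFS storage. A hypothetical way to fetch it programmatically, assuming huggingface_hub is available (the repo_id below is a placeholder, since the Space's id is not shown in this diff):

    from huggingface_hub import hf_hub_download

    path = hf_hub_download(
        repo_id="user/space-name",        # placeholder Space id
        repo_type="space",
        filename="checkpoints/model.pt",  # the LFS pointer resolves to the real file
    )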
config/config.json ADDED
@@ -0,0 +1,12 @@
+{
+    "lr": 0.0001,
+    "dropout": 0.1,
+    "hidden_dim": 128,
+    "num_layers": 5,
+    "patience": 100,
+    "max_epochs": 200,
+    "batch_size": 64,
+    "seed": 0,
+    "add_or_mean": "mean",
+    "window_size": 15
+}
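All hyperparameters are now centralized in one JSON file instead of being hardcoded. A small sketch of loading it defensively (the REQUIRED_KEYS set is illustrative; it simply mirrors the keys train.py reads):

    import json

    REQUIRED_KEYS = {"lr", "dropout", "hidden_dim", "num_layers", "patience",
                     "max_epochs", "batch_size", "seed", "add_or_mean", "window_size"}

    with open("./config/config.json") as f:
        config = json.load(f)

    missing = REQUIRED_KEYS - config.keys()
    if missing:
        raise KeyError(f"config.json is missing: {sorted(missing)}")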
predict.py CHANGED
@@ -2,9 +2,9 @@ from torch_geometric.data import Batch
 from torch_geometric.utils import from_rdmol
 import torch
 
-from
-from
-from
+from utils.model import GIN
+from utils.preprocess import create_clean_mol_objects
+from utils.seed import set_seed
 
 def predict(smiles_list):
     """
@@ -26,7 +26,7 @@ def predict(smiles_list):
 
     # setup model
     model = GIN(num_features=9, num_classes=12, dropout=0.1, hidden_dim=128, num_layers=5, add_or_mean="mean")
-    model_path = "./
+    model_path = "./checkpoints/model.pt"
     model.load_state_dict(torch.load(model_path, map_location=DEVICE))
     print(f"Loaded model from {model_path}")
     model.to(DEVICE)
@@ -54,4 +54,4 @@ def predict(smiles_list):
             pred_dict = {t: 0.5 for t in TARGET_NAMES}
             predictions[smiles] = pred_dict
 
-    return predictions
+    return predictions
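With the imports and checkpoint path now pointing at utils/ and checkpoints/, predict.py is self-contained again. A hypothetical call (the SMILES are illustrative; per the code above, the result maps each input SMILES to a dict of per-target probabilities, with 0.5 filled in for molecules that could not be processed):

    from predict import predict

    preds = predict(["CCO", "c1ccccc1"])  # ethanol, benzene
    for smiles, scores in preds.items():
        print(smiles, scores)             # one probability per entry of TARGET_NAMES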
train.py ADDED
@@ -0,0 +1,91 @@
+import torch
+from torch_geometric.loader import DataLoader
+import torch_geometric
+import numpy as np
+import json
+import os
+from dotenv import load_dotenv
+
+from utils.model import GIN
+from utils.preprocess import get_graph_datasets
+from utils.train_evaluate import train_model, evaluate, compute_roc_auc_avg_and_per_class
+from utils.seed import set_seed
+
+
+def train(config):
+    SEED = config["seed"]
+    set_seed(SEED)
+    best_model_path = "./checkpoints/model.pt"
+
+    # get dataloaders
+    print("Loading Datasets...")
+    torch_geometric.seed_everything(SEED)
+    token = os.getenv("TOKEN")
+    train_dataset, val_dataset = get_graph_datasets(token)
+    train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
+    val_loader = DataLoader(val_dataset, batch_size=config["batch_size"])
+
+    # initialize
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = GIN(num_features=9, num_classes=12, dropout=config["dropout"], hidden_dim=config["hidden_dim"], num_layers=config["num_layers"], add_or_mean=config["add_or_mean"]).to(device)
+    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
+
+    # training loop
+    best_mean_auc = -float("inf")
+    best_mean_epoch = 0
+    aucs = []
+    window_size = config["window_size"]
+    epoch_checkpoints = {}
+    print("Starting Training...")
+
+    for epoch in range(0, config["max_epochs"]):
+        train_loss = train_model(model, train_loader, optimizer, device)
+        val_loss = evaluate(model, val_loader, device)
+        val_auc_per_class, val_auc_avg = compute_roc_auc_avg_and_per_class(model, val_loader, device)
+        aucs.append(val_auc_avg)
+
+        # log
+        if epoch % 10 == 0:
+            print(f"Epoch {epoch:03d} | "
+                  f"Train Loss: {train_loss:.4f} | "
+                  f"Val Loss: {val_loss:.4f} | "
+                  f"Val ROC-AUC: {val_auc_avg:.4f}")
+
+        # store model parameters for this epoch in cache (on CPU to save GPU memory)
+        epoch_checkpoints[epoch] = {k: v.cpu() for k, v in model.state_dict().items()}
+
+        # keep cache size limited
+        if len(epoch_checkpoints) > window_size + 2:
+            oldest = min(epoch_checkpoints.keys())
+            del epoch_checkpoints[oldest]
+
+        # once we have enough epochs, compute rolling mean
+        if len(aucs) >= window_size:
+            current_window = aucs[-window_size:]
+            current_mean_auc = np.mean(current_window)
+            middle_epoch = epoch - window_size // 2
+
+            # check if current mean beats the best so far
+            if current_mean_auc > best_mean_auc:
+                best_mean_auc = current_mean_auc
+                best_mean_epoch = middle_epoch
+
+                # save only the best middle model
+                if middle_epoch in epoch_checkpoints:
+                    torch.save(epoch_checkpoints[middle_epoch], best_model_path)
+                    print(f"New best mean AUC = {best_mean_auc:.4f} "
+                          f"(center epoch {best_mean_epoch}) → model saved!")
+
+            # early stopping based on best mean epoch
+            if epoch - best_mean_epoch >= config["patience"]:
+                print(f"Early stopping at epoch {epoch}. "
+                      f"Best mean AUC = {best_mean_auc:.4f} (center epoch {best_mean_epoch})")
+                break
+
+    print("best_smoothed_val_auc" + str(best_mean_auc) + ", best_middle_epoch" + str(best_mean_epoch))
+
+if __name__ == "__main__":
+    with open("./config/config.json", "r") as f:
+        config = json.load(f)
+    load_dotenv()
+    train(config)
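The notable design choice in train.py is smoothed model selection: rather than saving the single best (and often noisiest) validation epoch, it tracks a rolling window of validation AUCs, keeps a small cache of recent state dicts on CPU, and saves the checkpoint from the window's center epoch whenever the window mean improves. A stripped-down sketch of just that selection logic with invented AUC values:

    import numpy as np

    aucs, window_size = [], 5
    best_mean, best_center = -float("inf"), 0

    for epoch, auc in enumerate([0.60, 0.72, 0.70, 0.74, 0.73, 0.75, 0.71]):  # fake val AUCs
        aucs.append(auc)
        if len(aucs) >= window_size:
            mean = np.mean(aucs[-window_size:])
            center = epoch - window_size // 2  # epoch at the middle of the window
            if mean > best_mean:
                best_mean, best_center = mean, center  # train.py saves this epoch's cached weights

    print(best_mean, best_center)  # 0.728, 3 for these fake values

The window_size + 2 bound on the checkpoint cache ensures the center epoch's weights are still cached whenever the window mean peaks.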
{src → utils}/model.py RENAMED
@@ -7,7 +7,7 @@ import numpy as np
 
 
 class GIN(torch.nn.Module):
-    def __init__(self, num_features, num_classes, dropout, hidden_dim=
+    def __init__(self, num_features, num_classes, dropout, hidden_dim=128, num_layers=5, add_or_mean="add"):
         super().__init__()
         self.num_layers = num_layers
         self.hidden_dim = hidden_dim
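The only in-file change in this rename gives the GIN constructor explicit defaults. For illustration, callers can now omit everything but the data dimensions (note that training passes add_or_mean="mean", overriding the "add" default):

    from utils.model import GIN

    model = GIN(num_features=9, num_classes=12, dropout=0.1)
    # equivalent to hidden_dim=128, num_layers=5, add_or_mean="add"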
{src → utils}/preprocess.py RENAMED
@@ -7,6 +7,33 @@ from rdkit.Chem.MolStandardize import rdMolStandardize
 from rdkit import Chem
 from torch_geometric.data import InMemoryDataset
 from torch_geometric.utils import from_rdmol
+from datasets import load_dataset
+
+def get_tox21_split(token, cvfold=None):
+    ds = load_dataset("tschouis/tox21", token=token)
+
+    train_df = ds["train"].to_pandas()
+    val_df = ds["validation"].to_pandas()
+
+    if cvfold is None:
+        return {
+            "train": train_df,
+            "validation": val_df
+        }
+
+    combined_df = pd.concat([train_df, val_df], ignore_index=True)
+    cvfold = float(cvfold)
+
+    # create new splits
+    cvfold = float(cvfold)
+    train_df = combined_df[combined_df.CVfold != cvfold]
+    val_df = combined_df[combined_df.CVfold == cvfold]
+
+    # exclude train mols that occur in the validation split
+    val_inchikeys = set(val_df["inchikey"])
+    train_df = train_df[~train_df["inchikey"].isin(val_inchikeys)]
+
+    return {"train": train_df.reset_index(drop=True), "validation": val_df.reset_index(drop=True)}
 
 def create_clean_mol_objects(smiles: list[str]) -> tuple[list[Chem.Mol], np.ndarray]:
     """Create cleaned RDKit Mol objects from SMILES.
@@ -87,7 +114,7 @@ class Tox21Dataset(InMemoryDataset):
         self.data, self.slices = self.collate(data_list)
 
 
-def get_graph_dataset(filepath:str):
+def get_graph_datasets(token):
     """returns an InMemoryDataset that can be used in dataloaders
 
     Args:
@@ -96,6 +123,8 @@ def get_graph_dataset(filepath:str):
     Returns:
        Tox21Dataset: dataset for dataloaders
     """
-
-
-
+    datasets = get_tox21_split(token, cvfold=4)
+    train_df, val_df = datasets["train"], datasets["validation"]
+    train_dataset = Tox21Dataset(train_df)
+    val_dataset = Tox21Dataset(val_df)
+    return train_dataset, val_dataset
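get_tox21_split's cross-validation path guards against leakage: after re-splitting on CVfold, it drops any training molecule whose InChIKey also appears in the validation fold. A toy reproduction of those pandas operations (rows invented; the real frames come from the tschouis/tox21 dataset):

    import pandas as pd

    df = pd.DataFrame({
        "inchikey": ["AAA", "BBB", "CCC", "AAA"],
        "CVfold":   [0.0,   1.0,   4.0,   4.0],
    })
    val_df = df[df.CVfold == 4.0]    # CCC, AAA
    train_df = df[df.CVfold != 4.0]  # AAA, BBB
    train_df = train_df[~train_df["inchikey"].isin(set(val_df["inchikey"]))]
    print(train_df["inchikey"].tolist())  # ['BBB']: the duplicated AAA is dropped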
{src → utils}/seed.py RENAMED
File without changes
utils/train_evaluate.py ADDED
@@ -0,0 +1,116 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from sklearn.metrics import roc_auc_score
+
+def masked_bce_loss(logits, labels, mask):
+    """
+    logits: [batch_size, num_classes] (raw outputs)
+    labels: [batch_size, num_classes] (0/1 with filler)
+    mask: [batch_size, num_classes] (True if label is valid)
+    """
+    criterion = nn.BCEWithLogitsLoss(reduction="none")
+    loss_raw = criterion(logits, labels)
+    loss = (loss_raw * mask.float()).sum() / mask.float().sum()
+    return loss
+
+def train_model(model, loader, optimizer, device):
+    model.train()
+    total_loss = 0
+    for batch in loader:
+        batch = batch.to(device)
+
+        optimizer.zero_grad()
+        out = model(batch.x, batch.edge_index, batch.batch)  # [num_graphs, num_classes]
+
+        loss = masked_bce_loss(out, batch.y, batch.mask)
+        loss.backward()
+        optimizer.step()
+
+        total_loss += loss.item() * batch.num_graphs
+    return total_loss / len(loader.dataset)
+
+
+@torch.no_grad()
+def evaluate(model, loader, device):
+    model.eval()
+    total_loss = 0
+    for batch in loader:
+        batch = batch.to(device)
+        out = model(batch.x, batch.edge_index, batch.batch)
+        loss = masked_bce_loss(out, batch.y, batch.mask)
+        total_loss += loss.item() * batch.num_graphs
+    return total_loss / len(loader.dataset)
+
+
+@torch.no_grad()
+def compute_roc_auc(model, loader, device):
+    model.eval()
+    y_true, y_pred, y_mask = [], [], []
+
+    for batch in loader:
+        batch = batch.to(device)
+        out = model(batch.x, batch.edge_index, batch.batch)
+
+        # Store predictions (sigmoid → probabilities)
+        y_pred.append(torch.sigmoid(out).cpu())
+        y_true.append(batch.y.cpu())
+        y_mask.append(batch.mask.cpu())
+
+    # Concatenate across all batches
+    y_true = torch.cat(y_true, dim=0).numpy()
+    y_pred = torch.cat(y_pred, dim=0).numpy()
+    y_mask = torch.cat(y_mask, dim=0).numpy()
+
+    auc_list = []
+    for i in range(y_true.shape[1]):  # per label
+        mask_i = y_mask[:, i].astype(bool)
+        if mask_i.sum() > 0:  # at least one valid label
+            try:
+                auc = roc_auc_score(y_true[mask_i, i], y_pred[mask_i, i])
+                auc_list.append(auc)
+            except ValueError:
+                # happens if only one class present (all 0 or all 1)
+                pass
+
+    return np.mean(auc_list) if len(auc_list) > 0 else float("nan")
+
+@torch.no_grad()
+def compute_roc_auc_avg_and_per_class(model, loader, device):
+    model.eval()
+    y_true, y_pred, y_mask = [], [], []
+
+    with torch.no_grad():
+        for batch in loader:
+            batch = batch.to(device)
+            out = model(batch.x, batch.edge_index, batch.batch)
+
+            # Store predictions (sigmoid → probabilities)
+            y_pred.append(torch.sigmoid(out).cpu())
+            y_true.append(batch.y.cpu())
+            y_mask.append(batch.mask.cpu())
+
+    # Concatenate across all batches
+    y_true = torch.cat(y_true, dim=0).numpy()
+    y_pred = torch.cat(y_pred, dim=0).numpy()
+    y_mask = torch.cat(y_mask, dim=0).numpy()
+
+    # Compute AUC per class
+    auc_list = []
+    for i in range(y_true.shape[1]):
+        mask_i = y_mask[:, i].astype(bool)
+        if mask_i.sum() > 0:
+            try:
+                auc = roc_auc_score(y_true[mask_i, i], y_pred[mask_i, i])
+            except ValueError:
+                auc = np.nan  # in case only one class present
+        else:
+            auc = np.nan
+        auc_list.append(auc)
+
+    # Convert to numpy array for easier manipulation
+    auc_array = np.array(auc_list, dtype=np.float32)
+    mean_auc = np.nanmean(auc_array)  # overall mean ignoring NaNs
+
+    # Return both per-class and mean
+    return auc_array, mean_auc
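A quick sanity check of masked_bce_loss on toy tensors (values invented): masked-out positions contribute nothing, and the loss is normalized by the number of valid labels rather than by batch_size * num_classes. Run from the repo root so utils is importable:

    import torch
    from utils.train_evaluate import masked_bce_loss

    logits = torch.tensor([[2.0, -1.0], [0.5, 0.0]])
    labels = torch.tensor([[1.0,  0.0], [0.0, 1.0]])    # the filler at masked slots is ignored
    mask   = torch.tensor([[True, True], [True, False]])
    print(masked_bce_loss(logits, labels, mask))        # mean BCE over the 3 valid entries only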