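"""Refit a regression head for one random seed using tuned hyperparameters.

Loads ``best_params`` from ``<base_out_dir>/best_model.pt`` (produced by an
earlier Optuna search), retrains the chosen model (MLP / CNN / Transformer)
on unpooled per-residue embeddings with early stopping on a validation
objective, and writes per-seed predictions, weights, and metrics. A second
mode (``--aggregate``) pools the per-seed metrics into 95% CIs.
"""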
import os
import json
import argparse
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from datasets import load_from_disk, DatasetDict
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from scipy.stats import spearmanr
from lightning.pytorch import seed_everything
from typing import Dict, Optional
# One global AMP loss scaler shared across epochs; disabled on CPU-only runs.
scaler_amp = GradScaler(enabled=torch.cuda.is_available())
def load_split(dataset_path):
    ds = load_from_disk(dataset_path)
    if isinstance(ds, DatasetDict) and "train" in ds and "val" in ds:
        return ds["train"], ds["val"]
    raise ValueError("Expected DatasetDict with 'train' and 'val' splits")
def infer_in_dim(ds) -> int:
    # Hidden width H, read from the first residue embedding of the first record.
    return int(len(ds[0]["embedding"][0]))
def collate_unpooled_reg(batch):
lengths = [int(x["length"]) for x in batch]
Lmax = max(lengths)
H = len(batch[0]["embedding"][0])
X = torch.zeros(len(batch), Lmax, H, dtype=torch.float32)
M = torch.zeros(len(batch), Lmax, dtype=torch.bool)
y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
for i, x in enumerate(batch):
emb = torch.tensor(x["embedding"], dtype=torch.float32)
L = emb.shape[0]
X[i, :L] = emb
if "attention_mask" in x:
m = torch.tensor(x["attention_mask"], dtype=torch.bool)
M[i, :L] = m[:L]
else:
M[i, :L] = True
return X, M, y
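# A hedged sketch of the record layout collate_unpooled_reg assumes (field names
# come from the code above; the concrete sizes below are illustrative):
#
#   {"embedding": [[0.1, ...], ...],   # (L, H) per-residue floats
#    "length": L,                      # number of valid residues
#    "label": 0.73,                    # regression target
#    "attention_mask": [1] * L}        # optional; defaults to all-True
#
# It pads the batch to (B, Lmax, H) and returns tensors (X, M, y).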
# ======================== Models =========================================
class MaskedMeanPool(nn.Module):
def forward(self, X, M):
Mf = M.unsqueeze(-1).float()
return (X * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)
class MLPRegressor(nn.Module):
def __init__(self, in_dim, hidden=512, dropout=0.1):
super().__init__()
self.pool = MaskedMeanPool()
self.net = nn.Sequential(
nn.Linear(in_dim, hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden, 1)
)
def forward(self, X, M):
return self.net(self.pool(X, M)).squeeze(-1)
class CNNRegressor(nn.Module):
def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
super().__init__()
blocks, ch = [], in_ch
for _ in range(layers):
blocks += [nn.Conv1d(ch, c, kernel_size=k, padding=k//2), nn.GELU(), nn.Dropout(dropout)]
ch = c
self.conv = nn.Sequential(*blocks)
self.head = nn.Linear(c, 1)
def forward(self, X, M):
Y = self.conv(X.transpose(1, 2)).transpose(1, 2)
Mf = M.unsqueeze(-1).float()
return self.head((Y * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)).squeeze(-1)
class TransformerRegressor(nn.Module):
def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
super().__init__()
self.proj = nn.Linear(in_dim, d_model)
self.enc = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=ff,
dropout=dropout, batch_first=True, activation="gelu"),
num_layers=layers
)
self.head = nn.Linear(d_model, 1)
def forward(self, X, M):
Z = self.enc(self.proj(X), src_key_padding_mask=~M)
Mf = M.unsqueeze(-1).float()
return self.head((Z * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)).squeeze(-1)
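# A minimal shape smoke test for the three heads (a sketch; in_dim=64 and the
# batch/length sizes are illustrative, not values used by the training loop):
#
#   X = torch.randn(8, 16, 64)                  # (B, L, H)
#   M = torch.ones(8, 16, dtype=torch.bool)     # every position valid
#   for m in (MLPRegressor(64), CNNRegressor(64), TransformerRegressor(64)):
#       assert m(X, M).shape == (8,)            # one scalar per sequence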
# ======================== Utils =========================================
def safe_spearmanr(y_true, y_pred):
rho = spearmanr(y_true, y_pred).correlation
return 0.0 if (rho is None or np.isnan(rho)) else float(rho)
def eval_regression(y_true, y_pred) -> Dict[str, float]:
try:
from sklearn.metrics import root_mean_squared_error
rmse = float(root_mean_squared_error(y_true, y_pred))
except Exception:
rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
return {
"spearman_rho": safe_spearmanr(y_true, y_pred),
"rmse": rmse,
"mae": float(mean_absolute_error(y_true, y_pred)),
"r2": float(r2_score(y_true, y_pred)),
}
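# Illustrative output (the numbers are made up): eval_regression returns, e.g.,
#   {"spearman_rho": 0.81, "rmse": 0.42, "mae": 0.31, "r2": 0.66}
# and score_from_metrics below turns it into a single maximization target.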
def score_from_metrics(metrics, objective):
return {"spearman": metrics["spearman_rho"],
"neg_rmse": -metrics["rmse"],
"r2": metrics["r2"]}[objective]
@torch.no_grad()
def eval_preds(model, loader, device):
model.eval()
ys, ps = [], []
for X, M, y in loader:
X, M = X.to(device), M.to(device)
ps.append(model(X, M).cpu().numpy())
ys.append(y.numpy())
return np.concatenate(ys), np.concatenate(ps)
def train_one_epoch(model, loader, optim, criterion, device):
model.train()
for X, M, y in loader:
X, M, y = X.to(device), M.to(device), y.to(device)
optim.zero_grad(set_to_none=True)
with autocast(enabled=torch.cuda.is_available()):
loss = criterion(model(X, M), y)
        # Scaled backward/step/update keeps fp16 gradients from underflowing.
        scaler_amp.scale(loss).backward()
scaler_amp.step(optim)
scaler_amp.update()
def build_model(model_name, in_dim, params):
dropout = float(params.get("dropout", 0.1))
if model_name == "mlp":
return MLPRegressor(in_dim=in_dim, hidden=int(params["hidden"]), dropout=dropout)
elif model_name == "cnn":
return CNNRegressor(in_ch=in_dim, c=int(params["channels"]), k=int(params["kernel"]),
layers=int(params["layers"]), dropout=dropout)
elif model_name == "transformer":
return TransformerRegressor(in_dim=in_dim, d_model=int(params["d_model"]),
nhead=int(params["nhead"]), layers=int(params["layers"]),
ff=int(params["ff"]), dropout=dropout)
    raise ValueError(f"Unknown model: {model_name!r}")
# ======================== Refit Loop =========================================
def refit_with_seed(dataset_path, base_out_dir, model_name, seed,
objective="spearman", device="cuda:0"):
model_path = os.path.join(base_out_dir, "best_model.pt")
if not os.path.exists(model_path):
raise FileNotFoundError(f"No best_model.pt at {model_path}. Run Optuna first.")
checkpoint = torch.load(model_path, map_location="cpu")
best_params = checkpoint["best_params"]
print(f"Loaded best_params from {model_path}")
print(json.dumps(best_params, indent=2))
seed_everything(seed)
out_dir = os.path.join(base_out_dir, f"seed_{seed}")
os.makedirs(out_dir, exist_ok=True)
train_ds, val_ds = load_split(dataset_path)
print(f"[Data] Train: {len(train_ds)}, Val: {len(val_ds)}")
batch_size = int(best_params.get("batch_size", 32))
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)
in_dim = infer_in_dim(train_ds)
model = build_model(model_name, in_dim, best_params).to(device)
# Loss
loss_name = best_params.get("loss", "mse")
if loss_name == "mse":
criterion = nn.MSELoss()
else:
criterion = nn.HuberLoss(delta=float(best_params.get("huber_delta", 1.0)))
optim = torch.optim.AdamW(model.parameters(),
lr=float(best_params["lr"]),
weight_decay=float(best_params["weight_decay"]))
best_score, bad, patience = -1e18, 0, 15
best_state, best_metrics = None, {}
for epoch in range(1, 201):
train_one_epoch(model, train_loader, optim, criterion, device)
y_true, y_pred = eval_preds(model, val_loader, device)
metrics = eval_regression(y_true, y_pred)
score = score_from_metrics(metrics, objective)
if score > best_score + 1e-6:
best_score = score
best_metrics = metrics
bad = 0
best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
else:
bad += 1
if bad >= patience:
print(f"Early stopping at epoch {epoch}")
break
    if best_state is not None:
model.load_state_dict(best_state)
y_true_val, y_pred_val = eval_preds(model, val_loader, device)
final_metrics = eval_regression(y_true_val, y_pred_val)
df_val = pd.DataFrame({
"y_true": y_true_val.astype(float),
"y_pred": y_pred_val.astype(float),
"residual": (y_true_val - y_pred_val).astype(float),
"abs_error": np.abs(y_true_val - y_pred_val).astype(float),
})
if "sequence" in val_ds.column_names:
df_val.insert(0, "sequence", np.asarray(val_ds["sequence"]))
df_val.to_csv(os.path.join(out_dir, "val_predictions.csv"), index=False)
torch.save({"state_dict": model.state_dict(), "best_params": best_params, "seed": seed},
os.path.join(out_dir, "model.pt"))
summary = {"model": model_name, "seed": seed, **{k: round(v, 6) for k, v in final_metrics.items()}}
with open(os.path.join(out_dir, "metrics.json"), "w") as f:
json.dump(summary, f, indent=2)
print(f"\n[Seed {seed}] rho={final_metrics['spearman_rho']:.4f} "
f"RMSE={final_metrics['rmse']:.4f} R2={final_metrics['r2']:.4f}")
return summary
# ======================== CI aggregation =========================================
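# A hedged sketch (not wired into aggregate_seed_results below): one way to get
# a Fisher z-transform CI for per-seed Spearman rhos, which is better behaved
# than a plain t-CI when rho sits near the [-1, 1] boundary. The name
# fisher_z_ci is illustrative, not part of the original pipeline.
def fisher_z_ci(rhos, alpha=0.05):
    from scipy import stats
    z = np.arctanh(np.asarray(rhos, dtype=float))  # rho -> z, variance-stabilizing
    n = len(z)
    half = stats.t.ppf(1 - alpha / 2, df=n - 1) * z.std(ddof=1) / np.sqrt(n)
    lo, hi = np.tanh(z.mean() - half), np.tanh(z.mean() + half)  # back to rho
    return float(np.tanh(z.mean())), (float(lo), float(hi))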
def aggregate_seed_results(base_out_dir, seeds):
"""
Aggregates across seed runs using:
- t-distribution 95% CI for Spearman rho, RMSE, R2, MAE
For rho > 0.9, use Fisher z-transform CI instead.
"""
from scipy import stats
records = []
for seed in seeds:
p = os.path.join(base_out_dir, f"seed_{seed}", "metrics.json")
if os.path.exists(p):
            with open(p) as f:
                records.append(json.load(f))
else:
print(f"Warning: missing seed {seed}")
if not records:
raise ValueError("No seed results found.")
df = pd.DataFrame(records)
print("\nPer-seed results:")
print(df.to_string(index=False))
summary_rows = []
for metric in ["spearman_rho", "rmse", "mae", "r2"]:
vals = df[metric].values
n = len(vals)
mean = vals.mean()
std = vals.std(ddof=1)
se = std / np.sqrt(n)
t_crit = stats.t.ppf(0.975, df=n - 1)
ci = t_crit * se
row = {
"metric": metric,
"mean": round(mean, 4),
"std": round(std, 4),
"ci_95": round(ci, 4),
"report": f"{mean:.4f} ± {ci:.4f}",
"n_seeds": n,
}
        # Flag when the t-CI pushes rho toward the [-1, 1] boundary (0.95 is a
        # conservative threshold, not 1.0 itself); Fisher z is preferable there.
        if metric == "spearman_rho" and (mean + ci > 0.95 or mean - ci < -0.95):
            row["note"] = "rho near boundary; consider Fisher z-transform CI"
summary_rows.append(row)
summary_df = pd.DataFrame(summary_rows)
out_path = os.path.join(base_out_dir, "seed_aggregated_metrics.csv")
summary_df.to_csv(out_path, index=False)
print("\n=== Aggregated Metrics (95% CI, t-distribution) ===")
for _, row in summary_df.iterrows():
note = f" ← {row['note']}" if "note" in row and pd.notna(row.get("note")) else ""
print(f" {row['metric']:15s}: {row['report']}{note}")
print(f"\nSaved to {out_path}")
return summary_df
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_path", type=str, required=True)
parser.add_argument("--base_out_dir", type=str, required=True)
parser.add_argument("--model", type=str, choices=["mlp", "cnn", "transformer"], required=True)
parser.add_argument("--seed", type=int, required=True)
parser.add_argument("--objective", type=str, default="spearman",
choices=["spearman", "neg_rmse", "r2"])
parser.add_argument("--aggregate", action="store_true")
parser.add_argument("--all_seeds", type=int, nargs="+", default=[1986, 42, 0, 123, 12345])
args = parser.parse_args()
if args.aggregate:
aggregate_seed_results(args.base_out_dir, args.all_seeds)
else:
refit_with_seed(
dataset_path=args.dataset_path,
base_out_dir=args.base_out_dir,
model_name=args.model,
seed=args.seed,
objective=args.objective,
)
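# Example invocations (a sketch; dataset and output paths are hypothetical):
#
#   # refit one seed with the tuned hyperparameters:
#   python refit_regression_seed.py --dataset_path data/affinity_hf \
#       --base_out_dir runs/mlp --model mlp --seed 42
#
#   # once every seed in --all_seeds has finished, pool the metrics
#   # (--dataset_path/--model/--seed are still required by argparse):
#   python refit_regression_seed.py --dataset_path data/affinity_hf \
#       --base_out_dir runs/mlp --model mlp --seed 0 --aggregate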