| import os |
| import json |
| import argparse |
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader |
| from torch.cuda.amp import autocast, GradScaler |
| from datasets import load_from_disk, DatasetDict |
| from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score |
| from scipy.stats import spearmanr |
| from lightning.pytorch import seed_everything |
| from typing import Dict, Optional |
|
|
# Gradient scaler used by the training step in this module. It is switched on
# only when CUDA is available; otherwise it is built in disabled mode.
scaler_amp = GradScaler(
    enabled=torch.cuda.is_available(),
)
|
|
def load_split(dataset_path):
    """Load a saved dataset and return its ``(train, val)`` splits.

    Args:
        dataset_path: Path handed unchanged to ``load_from_disk``.

    Returns:
        The 2-tuple ``(ds["train"], ds["val"])``.

    Raises:
        ValueError: If the loaded object is not a ``DatasetDict``, or if the
            ``"train"`` or ``"val"`` split is absent from it.
    """
    ds = load_from_disk(dataset_path)
    if not isinstance(ds, DatasetDict):
        raise ValueError("Expected DatasetDict with 'train' and 'val' splits")

    # Report a missing split with the same exception type as the check above
    # instead of letting a bare KeyError escape from the lookup below.
    missing = [name for name in ("train", "val") if name not in ds]
    if missing:
        raise ValueError(
            "Expected DatasetDict with 'train' and 'val' splits; "
            f"missing {missing}, found {list(ds.keys())}"
        )
    return ds["train"], ds["val"]
|
|
def infer_in_dim(ds) -> int:
    """Return the length of the first row of the first example's ``"embedding"``."""
    first_example = ds[0]
    first_row = first_example["embedding"][0]
    return int(len(first_row))
|
|
def collate_unpooled_reg(batch):
    """Collate per-position embeddings into padded tensors for regression.

    Every item of ``batch`` is read through the keys ``"length"``,
    ``"embedding"``, ``"label"`` and, when present, ``"attention_mask"``.

    Returns:
        ``(X, M, y)``: ``X`` is a zero-padded float32 tensor of shape
        ``(len(batch), max "length", H)`` where ``H`` is the row width of the
        first item's embedding; ``M`` is a bool mask of shape
        ``(len(batch), max "length")``; ``y`` is a float32 vector of labels.
    """
    n_items = len(batch)
    max_len = max([int(item["length"]) for item in batch])
    feat_dim = len(batch[0]["embedding"][0])

    X = torch.zeros(n_items, max_len, feat_dim, dtype=torch.float32)
    M = torch.zeros(n_items, max_len, dtype=torch.bool)
    y = torch.tensor([float(item["label"]) for item in batch], dtype=torch.float32)

    for row, item in enumerate(batch):
        tokens = torch.tensor(item["embedding"], dtype=torch.float32)
        n_tok = tokens.shape[0]
        X[row, :n_tok] = tokens
        if "attention_mask" not in item:
            # No explicit mask supplied: every embedded position is valid.
            M[row, :n_tok] = True
            continue
        given_mask = torch.tensor(item["attention_mask"], dtype=torch.bool)
        M[row, :n_tok] = given_mask[:n_tok]
    return X, M, y
|
|
| |
|
|
class MaskedMeanPool(nn.Module):
    """Average ``X`` over dimension 1, counting only positions where ``M`` is set."""

    def forward(self, X, M):
        weights = M.unsqueeze(-1).float()
        masked_sum = (X * weights).sum(dim=1)
        # Clamp the count so a row whose mask is all False divides by 1, not 0.
        valid_count = weights.sum(dim=1).clamp(min=1.0)
        return masked_sum / valid_count
|
|
class MLPRegressor(nn.Module):
    """Masked mean pooling followed by a one-hidden-layer MLP with one output.

    Args:
        in_dim: Width of the pooled feature vector fed to the first linear layer.
        hidden: Width of the hidden layer.
        dropout: Dropout probability applied after the GELU activation.
    """

    def __init__(self, in_dim, hidden=512, dropout=0.1):
        super().__init__()
        self.pool = MaskedMeanPool()
        stack = [
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, X, M):
        """Pool ``X`` under mask ``M`` and return one value per batch row."""
        pooled = self.pool(X, M)
        return self.net(pooled).squeeze(-1)
|
|
class CNNRegressor(nn.Module):
    """Stack of 1-D convolutions, masked mean pooling, and a linear head.

    Args:
        in_ch: Number of input channels (last dimension of ``X``).
        c: Number of output channels of every convolution.
        k: Convolution kernel size.
        layers: Number of Conv1d -> GELU -> Dropout blocks.
        dropout: Dropout probability used in every block.
    """

    def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
        super().__init__()
        blocks, ch = [], in_ch
        for _ in range(layers):
            blocks += [
                # "same" padding keeps the sequence length unchanged for any
                # kernel size. The previous ``padding=k // 2`` is identical
                # for odd ``k`` but produced one extra position for even
                # ``k``, which then failed to broadcast against the mask.
                nn.Conv1d(ch, c, kernel_size=k, padding="same"),
                nn.GELU(),
                nn.Dropout(dropout),
            ]
            ch = c
        self.conv = nn.Sequential(*blocks)
        # ``ch`` equals ``c`` whenever at least one block exists; sizing the
        # head from it also keeps ``layers=0`` usable.
        self.head = nn.Linear(ch, 1)

    def forward(self, X, M):
        """Return one prediction per batch row.

        ``X`` is transposed so that its last dimension becomes the Conv1d
        channel dimension, convolved, and transposed back. The mask ``M`` is
        applied only at pooling time; nothing is masked between layers.
        """
        Y = self.conv(X.transpose(1, 2)).transpose(1, 2)
        Mf = M.unsqueeze(-1).float()
        # Clamp the count so an all-False mask row divides by 1, not 0.
        pooled = (Y * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)
        return self.head(pooled).squeeze(-1)
|
|
class TransformerRegressor(nn.Module):
    """Linear projection, Transformer encoder, masked mean pooling, linear head.

    Args:
        in_dim: Width of the input features (last dimension of ``X``).
        d_model: Encoder model width.
        nhead: Number of attention heads.
        layers: Number of encoder layers.
        ff: Feed-forward width inside each encoder layer.
        dropout: Dropout probability inside each encoder layer.
    """

    def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=ff,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
        )
        self.enc = nn.TransformerEncoder(encoder_layer, num_layers=layers)
        self.head = nn.Linear(d_model, 1)

    def forward(self, X, M):
        """Return one prediction per batch row.

        ``M`` marks valid positions with True, so its negation is passed as
        the encoder's key-padding mask.
        """
        projected = self.proj(X)
        encoded = self.enc(projected, src_key_padding_mask=~M)
        weights = M.unsqueeze(-1).float()
        # Clamp the count so an all-False mask row divides by 1, not 0.
        pooled = (encoded * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)
        return self.head(pooled).squeeze(-1)
|
|
| |
|
|
def safe_spearmanr(y_true, y_pred):
    """Return the Spearman correlation as a float, or 0.0 when it is None/NaN."""
    rho = spearmanr(y_true, y_pred).correlation
    if rho is None or np.isnan(rho):
        return 0.0
    return float(rho)
|
|
def eval_regression(y_true, y_pred) -> Dict[str, float]:
    """Compute regression metrics for a set of predictions.

    Args:
        y_true: Ground-truth target values.
        y_pred: Predicted values, aligned with ``y_true``.

    Returns:
        A dict with the keys ``"spearman_rho"``, ``"rmse"``, ``"mae"`` and
        ``"r2"``, each mapped to a Python float.
    """
    # Only the import is guarded: scikit-learn releases that lack
    # ``root_mean_squared_error`` fall back to sqrt(MSE). Errors raised by
    # the metric itself are no longer swallowed by a broad ``except``.
    try:
        from sklearn.metrics import root_mean_squared_error
    except ImportError:
        rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
    else:
        rmse = float(root_mean_squared_error(y_true, y_pred))

    return {
        "spearman_rho": safe_spearmanr(y_true, y_pred),
        "rmse": rmse,
        "mae": float(mean_absolute_error(y_true, y_pred)),
        "r2": float(r2_score(y_true, y_pred)),
    }
|
|
def score_from_metrics(metrics, objective):
    """Map a metrics dict to the scalar to maximise for ``objective``.

    ``objective`` must be ``"spearman"``, ``"neg_rmse"`` or ``"r2"``; any other
    value raises ``KeyError``. All three underlying metrics are read from
    ``metrics`` regardless of which objective is requested.
    """
    rho = metrics["spearman_rho"]
    rmse = metrics["rmse"]
    r2 = metrics["r2"]
    by_objective = {"spearman": rho, "neg_rmse": -rmse, "r2": r2}
    return by_objective[objective]
|
|
@torch.no_grad()
def eval_preds(model, loader, device):
    """Run ``model`` over ``loader`` in eval mode and collect targets and outputs.

    Each batch from ``loader`` is unpacked as ``(X, M, y)``; ``X`` and ``M``
    are moved to ``device`` before the forward pass.

    Returns:
        ``(y_true, y_pred)`` as NumPy arrays concatenated across batches.
    """
    model.eval()
    targets, outputs = [], []
    for X, M, y in loader:
        X = X.to(device)
        M = M.to(device)
        batch_out = model(X, M)
        outputs.append(batch_out.cpu().numpy())
        targets.append(y.numpy())
    y_true = np.concatenate(targets)
    y_pred = np.concatenate(outputs)
    return y_true, y_pred
|
|
def train_one_epoch(model, loader, optim, criterion, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Module called as ``model(X, M)``; put into train mode here.
        loader: Iterable yielding ``(X, M, y)`` batches.
        optim: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(model(X, M), y)``.
        device: Device (string or ``torch.device``) the batches are moved to.
    """
    # Mixed precision is tied to the device actually used for training, not
    # merely to CUDA being installed: an enabled CUDA GradScaler expects CUDA
    # tensors, so CPU training on a CUDA-capable host must bypass it.
    use_amp = torch.cuda.is_available() and torch.device(device).type == "cuda"

    model.train()
    for X, M, y in loader:
        X, M, y = X.to(device), M.to(device), y.to(device)
        optim.zero_grad(set_to_none=True)
        with autocast(enabled=use_amp):
            loss = criterion(model(X, M), y)
        if use_amp:
            # Scale the loss before backward and let the module-level scaler
            # drive the optimizer step and its own scale update.
            scaler_amp.scale(loss).backward()
            scaler_amp.step(optim)
            scaler_amp.update()
        else:
            loss.backward()
            optim.step()
|
|
def build_model(model_name, in_dim, params):
    """Instantiate the regressor named by ``model_name`` from ``params``.

    ``params["dropout"]`` is optional (default 0.1). The remaining keys read
    depend on the model: ``hidden`` for ``"mlp"``; ``channels``, ``kernel``
    and ``layers`` for ``"cnn"``; ``d_model``, ``nhead``, ``layers`` and
    ``ff`` for ``"transformer"``.

    Raises:
        ValueError: If ``model_name`` is none of the three supported names.
    """
    dropout = float(params.get("dropout", 0.1))

    if model_name == "mlp":
        hidden = int(params["hidden"])
        return MLPRegressor(in_dim=in_dim, hidden=hidden, dropout=dropout)

    if model_name == "cnn":
        channels = int(params["channels"])
        kernel = int(params["kernel"])
        n_layers = int(params["layers"])
        return CNNRegressor(in_ch=in_dim, c=channels, k=kernel,
                            layers=n_layers, dropout=dropout)

    if model_name == "transformer":
        d_model = int(params["d_model"])
        nhead = int(params["nhead"])
        n_layers = int(params["layers"])
        ff = int(params["ff"])
        return TransformerRegressor(in_dim=in_dim, d_model=d_model, nhead=nhead,
                                    layers=n_layers, ff=ff, dropout=dropout)

    raise ValueError(model_name)
|
|
| |
|
|
def refit_with_seed(dataset_path, base_out_dir, model_name, seed,
                    objective="spearman", device="cuda:0",
                    max_epochs=200, patience=15):
    """Retrain a model with previously tuned hyper-parameters under one seed.

    Hyper-parameters are read from ``<base_out_dir>/best_model.pt`` (key
    ``"best_params"``). The model is trained with early stopping on the
    validation split, and the best-scoring weights are restored before the
    final evaluation. Note that the reported metrics come from the same
    validation split that drives early stopping.

    Files written to ``<base_out_dir>/seed_<seed>/``:
        - ``val_predictions.csv``: per-example targets, predictions, errors
          (plus ``sequence`` when the validation set has that column).
        - ``model.pt``: state dict, ``best_params`` and the seed.
        - ``metrics.json``: model name, seed and rounded final metrics.

    Args:
        dataset_path: Location of the saved dataset (see ``load_split``).
        base_out_dir: Directory containing ``best_model.pt``; outputs go in
            a ``seed_<seed>`` sub-directory.
        model_name: ``"mlp"``, ``"cnn"`` or ``"transformer"``.
        seed: Seed passed to ``seed_everything``.
        objective: Key understood by ``score_from_metrics``; higher is better.
        device: Training device. A CUDA device falls back to CPU, with a
            printed warning, when CUDA is unavailable.
        max_epochs: Upper bound on the number of training epochs.
        patience: Number of consecutive non-improving epochs tolerated before
            stopping early.

    Returns:
        The summary dict that is also written to ``metrics.json``.

    Raises:
        FileNotFoundError: If ``best_model.pt`` does not exist.
    """
    model_path = os.path.join(base_out_dir, "best_model.pt")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"No best_model.pt at {model_path}. Run Optuna first.")

    # NOTE(review): torch.load unpickles the checkpoint; only point this at
    # files produced by your own tuning run.
    checkpoint = torch.load(model_path, map_location="cpu")
    best_params = checkpoint["best_params"]
    print(f"Loaded best_params from {model_path}")
    print(json.dumps(best_params, indent=2))

    # The default device is a GPU; degrade to CPU rather than failing later
    # at ``model.to(device)`` when no CUDA runtime is present.
    if torch.device(device).type == "cuda" and not torch.cuda.is_available():
        print(f"Warning: CUDA is not available; using CPU instead of {device}")
        device = "cpu"
    use_cuda = torch.device(device).type == "cuda"

    seed_everything(seed)
    out_dir = os.path.join(base_out_dir, f"seed_{seed}")
    os.makedirs(out_dir, exist_ok=True)

    train_ds, val_ds = load_split(dataset_path)
    print(f"[Data] Train: {len(train_ds)}, Val: {len(val_ds)}")

    batch_size = int(best_params.get("batch_size", 32))
    # Pinned host memory only helps host-to-GPU copies, so request it only
    # when training on CUDA.
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_unpooled_reg, num_workers=4,
                              pin_memory=use_cuda)
    val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
                            collate_fn=collate_unpooled_reg, num_workers=4,
                            pin_memory=use_cuda)

    in_dim = infer_in_dim(train_ds)
    model = build_model(model_name, in_dim, best_params).to(device)

    # Any loss name other than "mse" selects the Huber loss.
    loss_name = best_params.get("loss", "mse")
    if loss_name == "mse":
        criterion = nn.MSELoss()
    else:
        criterion = nn.HuberLoss(delta=float(best_params.get("huber_delta", 1.0)))

    optim = torch.optim.AdamW(model.parameters(),
                              lr=float(best_params["lr"]),
                              weight_decay=float(best_params["weight_decay"]))

    best_score, bad = -1e18, 0
    best_state = None

    for epoch in range(1, max_epochs + 1):
        train_one_epoch(model, train_loader, optim, criterion, device)
        y_true, y_pred = eval_preds(model, val_loader, device)
        metrics = eval_regression(y_true, y_pred)
        score = score_from_metrics(metrics, objective)

        # Require a small margin so numerical noise does not reset patience.
        if score > best_score + 1e-6:
            best_score = score
            bad = 0
            # Snapshot on CPU so later epochs cannot mutate the saved weights.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            bad += 1
            if bad >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    y_true_val, y_pred_val = eval_preds(model, val_loader, device)
    final_metrics = eval_regression(y_true_val, y_pred_val)

    df_val = pd.DataFrame({
        "y_true": y_true_val.astype(float),
        "y_pred": y_pred_val.astype(float),
        "residual": (y_true_val - y_pred_val).astype(float),
        "abs_error": np.abs(y_true_val - y_pred_val).astype(float),
    })
    if "sequence" in val_ds.column_names:
        df_val.insert(0, "sequence", np.asarray(val_ds["sequence"]))
    df_val.to_csv(os.path.join(out_dir, "val_predictions.csv"), index=False)

    torch.save({"state_dict": model.state_dict(), "best_params": best_params, "seed": seed},
               os.path.join(out_dir, "model.pt"))

    summary = {"model": model_name, "seed": seed,
               **{k: round(v, 6) for k, v in final_metrics.items()}}
    with open(os.path.join(out_dir, "metrics.json"), "w") as f:
        json.dump(summary, f, indent=2)

    print(f"\n[Seed {seed}] rho={final_metrics['spearman_rho']:.4f} "
          f"RMSE={final_metrics['rmse']:.4f} R2={final_metrics['r2']:.4f}")
    return summary
|
|
| |
|
|
def aggregate_seed_results(base_out_dir, seeds):
    """Aggregate per-seed ``metrics.json`` files into mean and 95% CI rows.

    For each of ``spearman_rho``, ``rmse``, ``mae`` and ``r2`` the mean, the
    sample standard deviation (``ddof=1``) and a two-sided 95% confidence
    half-width from the t-distribution are computed. When the Spearman
    interval reaches beyond +/-0.95 the row only carries a note suggesting a
    Fisher z-transform CI; no such interval is computed here. With a single
    seed the standard deviation and CI are NaN and the row says so.

    The table is written to ``<base_out_dir>/seed_aggregated_metrics.csv``.

    Args:
        base_out_dir: Directory holding the ``seed_<seed>`` sub-directories.
        seeds: Seeds to look for; missing ones are reported and skipped.

    Returns:
        The summary ``DataFrame`` (one row per metric).

    Raises:
        ValueError: If no ``metrics.json`` is found for any seed.
    """
    from scipy import stats

    records = []
    for seed in seeds:
        p = os.path.join(base_out_dir, f"seed_{seed}", "metrics.json")
        if not os.path.exists(p):
            print(f"Warning: missing seed {seed}")
            continue
        # Context manager so the file handle is closed deterministically.
        with open(p) as f:
            records.append(json.load(f))

    if not records:
        raise ValueError("No seed results found.")

    df = pd.DataFrame(records)
    print("\nPer-seed results:")
    print(df.to_string(index=False))

    summary_rows = []
    for metric in ["spearman_rho", "rmse", "mae", "r2"]:
        vals = df[metric].to_numpy(dtype=float)
        n = len(vals)
        mean = vals.mean()
        if n > 1:
            std = vals.std(ddof=1)
            se = std / np.sqrt(n)
            t_crit = stats.t.ppf(0.975, df=n - 1)
            ci = t_crit * se
        else:
            # A sample of one has no spread estimate; report NaN explicitly
            # instead of reaching it through division-by-zero warnings.
            std = ci = float("nan")
        row = {
            "metric": metric,
            "mean": round(mean, 4),
            "std": round(std, 4),
            "ci_95": round(ci, 4),
            "report": f"{mean:.4f} ± {ci:.4f}",
            "n_seeds": n,
        }
        if n == 1:
            row["note"] = "single seed — CI undefined"
        elif metric == "spearman_rho" and (mean + ci > 0.95 or mean - ci < -0.95):
            row["note"] = "rho near boundary — consider Fisher z-transform CI"
        summary_rows.append(row)

    summary_df = pd.DataFrame(summary_rows)
    out_path = os.path.join(base_out_dir, "seed_aggregated_metrics.csv")
    summary_df.to_csv(out_path, index=False)

    print("\n=== Aggregated Metrics (95% CI, t-distribution) ===")
    for _, row in summary_df.iterrows():
        # The "note" column exists only if at least one row set it.
        note = f" ← {row['note']}" if "note" in row and pd.notna(row.get("note")) else ""
        print(f"  {row['metric']:15s}: {row['report']}{note}")
    print(f"\nSaved to {out_path}")
    return summary_df
|
|
|
|
if __name__ == "__main__":
    # Two modes: refit one seed (default) or, with --aggregate, combine the
    # metrics already written for --all_seeds. Aggregation only needs
    # --base_out_dir, so the refit-only options are validated after parsing
    # instead of being unconditionally required.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default=None,
                        help="saved dataset location (required unless --aggregate)")
    parser.add_argument("--base_out_dir", type=str, required=True)
    parser.add_argument("--model", type=str, choices=["mlp", "cnn", "transformer"],
                        default=None, help="model to refit (required unless --aggregate)")
    parser.add_argument("--seed", type=int, default=None,
                        help="seed for this refit (required unless --aggregate)")
    parser.add_argument("--objective", type=str, default="spearman",
                        choices=["spearman", "neg_rmse", "r2"])
    parser.add_argument("--aggregate", action="store_true")
    parser.add_argument("--all_seeds", type=int, nargs="+", default=[1986, 42, 0, 123, 12345])
    args = parser.parse_args()

    if args.aggregate:
        aggregate_seed_results(args.base_out_dir, args.all_seeds)
    else:
        refit_only = (
            ("--dataset_path", args.dataset_path),
            ("--model", args.model),
            ("--seed", args.seed),
        )
        missing = [flag for flag, value in refit_only if value is None]
        if missing:
            # parser.error prints usage and exits with status 2, matching
            # what argparse does for a missing required option.
            parser.error("the following arguments are required: " + ", ".join(missing))
        refit_with_seed(
            dataset_path=args.dataset_path,
            base_out_dir=args.base_out_dir,
            model_name=args.model,
            seed=args.seed,
            objective=args.objective,
        )