| """ |
| Evaluation script for the trained Q_theta scorer. |
| |
| Computes: |
| 1. Selectivity metrics (gap, ranking accuracy, AUC) |
| 2. DockQ correlation (Spearman/Pearson) |
| 3. Score distributions (violin plots) |
| 4. Best-of-K analysis (as function of K) |
| 5. Per-target breakdown |
| |
| Usage: |
| python code/scripts/evaluate.py \ |
| --target cam \ |
| --checkpoint checkpoints/Q_theta_phase2.pt \ |
| --data_dir data/processed \ |
| --gpu 7 |
| """ |
|
|
| import os |
| import sys |
| import argparse |
| import logging |
| import json |
| import numpy as np |
| import torch |
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
| from scipy.stats import spearmanr, pearsonr |
| from sklearn.metrics import roc_auc_score, roc_curve |
|
|
# Put the directory one level above this script's own directory on the import
# path, so the `models.*` / `data.*` imports below resolve when the file is run
# directly as a script.
_CODE_DIR = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir)
)
if sys.path.count(_CODE_DIR) == 0:
    sys.path.insert(0, _CODE_DIR)
|
|
| from models.scorer import build_model |
| from data.dataset import TwoStateComplexDataset, collate_fn |
| from torch.utils.data import DataLoader |
|
|
# Root logging configuration for the script: INFO level, timestamped records.
logging.basicConfig(
    format='%(asctime)s %(levelname)s %(message)s',
    level=logging.INFO,
)
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
|
|
|
|
def compute_best_of_k(pos_scores, K_values=None, threshold=0.7):
    """
    Best-of-K success rate for a pool of candidate scores.

    For each K, returns the probability that K candidates drawn uniformly at
    random *without replacement* from ``pos_scores`` contain at least one score
    ``>= threshold``. When K exceeds the pool size, the whole pool is drawn.

    The probability is computed exactly (hypergeometric "at least one hit")
    instead of being estimated by random trials, so the result is
    deterministic and reproducible across runs:

        P(success) = 1 - prod_{i=0}^{k-1} (n_bad - i) / (n - i),  k = min(K, n)

    Args:
        pos_scores: 1-D array-like of candidate scores (list or ndarray).
        K_values: iterable of candidate-set sizes; defaults to
            [1, 2, 5, 10, 20, 50, 100].
        threshold: a candidate counts as "good" when its score is >= this value.
            NaN scores never count as good.

    Returns:
        dict mapping each requested K to its success probability in [0, 1].
        An empty pool (or K <= 0) yields 0.0, since no good candidate can be
        selected.
    """
    if K_values is None:
        K_values = [1, 2, 5, 10, 20, 50, 100]

    scores = np.asarray(pos_scores, dtype=float).ravel()
    n = int(scores.size)
    n_bad = n - int(np.count_nonzero(scores >= threshold))

    results = {}
    for K in K_values:
        k = min(int(K), n)
        if k <= 0:
            results[K] = 0.0
            continue
        # Probability that every one of the k draws lands on a below-threshold
        # candidate; the factor hits 0 as soon as k exceeds n_bad.
        p_all_bad = 1.0
        for i in range(k):
            p_all_bad *= max(n_bad - i, 0) / (n - i)
        results[K] = 1.0 - p_all_bad

    return results
|
|
|
|
def compute_selectivity_margin(pos_scores, neg_scores):
    """
    Per-sample selectivity margin S_theta.

    Both inputs are converted to clipped logits,
    ``log(clip(p, 1e-6, 1 - 1e-6) / max(1 - p, 1e-6))``, and the margin is

        S = logit(pos) - log(exp(logit(neg)) + 1e-8)

    The inputs are combined element-wise, so they must be NumPy arrays of
    matching (or broadcastable) shape.
    """
    tiny = 1e-6

    def _clipped_logit(prob):
        # Numerator and denominator are clipped separately so neither the
        # ratio nor its log can blow up at prob == 0 or prob == 1.
        numer = prob.clip(tiny, 1 - tiny)
        denom = (1 - prob).clip(tiny)
        return np.log(numer / denom)

    logit_pos = _clipped_logit(pos_scores)
    logit_neg = _clipped_logit(neg_scores)
    neg_term = np.log(np.exp(logit_neg) + 1e-8)
    return logit_pos - neg_term
|
|
|
|
def plot_score_distributions(pos_scores, neg_scores, decoy_scores=None,
                             title='Score Distributions', outpath=None):
    """
    Violin plot of score distributions for the different complex types.

    Groups that are ``None`` or empty are left out of the figure (a violin
    cannot be drawn from zero samples); if every group is empty, nothing is
    plotted and a warning is logged.

    Args:
        pos_scores: scores of positive complexes (array-like, may be empty).
        neg_scores: scores of negative/apo complexes (array-like, may be empty).
        decoy_scores: optional scores of decoy complexes.
        title: figure title.
        outpath: if given, the figure is saved there (150 dpi).
    """
    candidates = [
        ('Positive\n(X+, Y)', '#2196F3', pos_scores),
        ('Negative\n(X0, Y)', '#F44336', neg_scores),
        ('Decoys\n(X+, Y~)', '#FF9800', decoy_scores),
    ]

    # Keep label/colour/data aligned while dropping groups with no samples.
    groups = []
    for label, color, values in candidates:
        if values is None:
            continue
        arr = np.asarray(values, dtype=float).ravel()
        if arr.size > 0:
            groups.append((label, color, arr))

    if not groups:
        logger.warning(f"No scores to plot for '{title}'; skipping figure")
        return

    labels = [g[0] for g in groups]
    colors = [g[1] for g in groups]
    data = [g[2] for g in groups]
    positions = list(range(len(data)))

    fig, ax = plt.subplots(figsize=(8, 6))

    parts = ax.violinplot(data, positions=positions, showmedians=True)
    for body, color in zip(parts['bodies'], colors):
        body.set_facecolor(color)
        body.set_alpha(0.7)

    ax.set_xticks(positions)
    ax.set_xticklabels(labels)
    ax.set_ylabel('Q_theta Score', fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.set_ylim(0, 1)
    ax.axhline(0.5, color='gray', linestyle='--', alpha=0.5, label='Decision boundary')
    ax.legend()

    # Annotate each violin with its mean and standard deviation.
    for pos, (values, color) in zip(positions, zip(data, colors)):
        ax.text(pos, 0.02, f'μ={values.mean():.2f}\nσ={values.std():.2f}',
                ha='center', fontsize=9, color=color)

    fig.tight_layout()
    if outpath:
        fig.savefig(outpath, dpi=150, bbox_inches='tight')
        logger.info(f"Saved plot to {outpath}")
    # Close this specific figure so it is released even if it is no longer
    # the "current" pyplot figure.
    plt.close(fig)
|
|
|
|
def plot_roc_curve(labels, scores, title='ROC Curve', outpath=None):
    """
    Draw the ROC curve of ``scores`` against binary ``labels``.

    The curve is plotted together with the chance diagonal and the AUC in the
    legend. The figure is saved to ``outpath`` when one is given (150 dpi) and
    is always closed afterwards.

    Returns:
        The ROC AUC as computed by ``roc_auc_score``.
    """
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(labels, scores)
    auc = roc_auc_score(labels, scores)

    figure, axes = plt.subplots(figsize=(6, 6))
    axes.plot(false_pos_rate, true_pos_rate, 'b-', lw=2, label=f'AUC = {auc:.3f}')
    # Chance-level reference line.
    axes.plot([0, 1], [0, 1], 'k--', lw=1)
    axes.set_xlabel('False Positive Rate')
    axes.set_ylabel('True Positive Rate')
    axes.set_title(title)
    axes.legend()
    figure.tight_layout()

    if outpath:
        figure.savefig(outpath, dpi=150, bbox_inches='tight')
    plt.close(figure)
    return auc
|
|
|
|
def plot_best_of_k(results, outpath=None, threshold=0.7):
    """
    Plot best-of-K success rate as a function of K (log-scaled x axis).

    Args:
        results: dict mapping K -> success rate.
        outpath: if given, the figure is saved there (150 dpi).
        threshold: score threshold the success rates were computed with. It is
            only used in the y-axis label, so the figure states the threshold
            actually applied instead of a fixed 0.7. The default reproduces the
            previous label exactly.
    """
    Ks = sorted(results)
    success_rates = [results[K] for K in Ks]

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.semilogx(Ks, success_rates, 'b-o', lw=2, markersize=8)
    ax.set_xlabel('K (number of candidates)', fontsize=12)
    ax.set_ylabel(f'Success rate (best score > {threshold:g})', fontsize=12)
    ax.set_title('Best-of-K Analysis', fontsize=14)
    ax.set_ylim(0, 1.05)
    ax.grid(True, alpha=0.3)
    # Reference line marking an 80% success rate.
    ax.axhline(0.8, color='red', linestyle='--', alpha=0.5, label='80% success')
    ax.legend()
    fig.tight_layout()
    if outpath:
        fig.savefig(outpath, dpi=150, bbox_inches='tight')
    # Close this specific figure rather than whichever one is current.
    plt.close(fig)
|
|
|
|
@torch.no_grad()
def evaluate(model, loader, device):
    """
    Score every batch of ``loader`` with ``model`` (in eval mode, no gradients).

    Each batch is a mapping with 'node_feats', 'edge_feats', 'node_mask',
    'label', 'type' and 'pdb' entries, plus an optional 'esm_feats' entry that
    is forwarded to the model as the ``esm_feats`` keyword (``None`` if absent).

    Returns:
        Tuple ``(scores, labels, types, pdbs)`` of NumPy arrays built from the
        values collected over all batches, in loader order.
    """
    model.eval()
    collected = {'score': [], 'label': [], 'type': [], 'pdb': []}

    for batch in loader:
        if 'esm_feats' in batch:
            esm = batch['esm_feats'].to(device)
        else:
            esm = None

        node_feats = batch['node_feats'].to(device)
        edge_feats = batch['edge_feats'].to(device)
        node_mask = batch['node_mask'].to(device)
        preds = model(node_feats, edge_feats, node_mask, esm_feats=esm)

        # NOTE(review): assumes the model returns one score per sample
        # (a 1-D tensor) — confirm against the scorer implementation.
        collected['score'] += preds.cpu().numpy().tolist()
        collected['label'] += batch['label'].numpy().tolist()
        collected['type'] += batch['type']
        collected['pdb'] += batch['pdb']

    return (
        np.array(collected['score']),
        np.array(collected['label']),
        np.array(collected['type']),
        np.array(collected['pdb']),
    )
|
|
|
|
def _parse_args():
    """Build the command-line parser and parse the process arguments."""
    parser = argparse.ArgumentParser(description='Evaluate Allo-Designer Q_theta scorer')
    parser.add_argument('--target', default='cam',
                        help='Target name (cam, abl, era, or any custom target with data in data/processed/)')
    parser.add_argument('--all_targets', action='store_true',
                        help='Evaluate on all available targets and produce aggregated results')
    parser.add_argument('--checkpoint', required=True, help='Path to model checkpoint')
    parser.add_argument('--data_dir', default='data/processed')
    parser.add_argument('--split', choices=['val', 'test'], default='test')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--gpu', type=int, default=7)
    parser.add_argument('--outdir', default='results')
    parser.add_argument('--bok_threshold', type=float, default=0.7,
                        help='Score threshold for best-of-K (default 0.7; use per-target value for calibrated results)')
    parser.add_argument('--esm_dir', default=None,
                        help='Path to ESM-2 embedding cache (auto-detected at <data_dir>/esm2_embeddings if omitted)')
    parser.add_argument('--no_wandb', action='store_true', help='(ignored; here for CLI compatibility)')
    return parser.parse_args()


def _select_device(gpu_index):
    """Return cuda:<gpu_index> when that GPU exists, else cuda:0, else CPU."""
    if not torch.cuda.is_available():
        return torch.device('cpu')
    n_gpus = torch.cuda.device_count()
    if not 0 <= gpu_index < n_gpus:
        # The CLI default (7) only exists on 8-GPU machines; fall back rather
        # than fail later with an invalid-device-ordinal error.
        logger.warning(f"GPU {gpu_index} not available ({n_gpus} visible); using cuda:0")
        gpu_index = 0
    return torch.device(f'cuda:{gpu_index}')


def _discover_targets(data_dir, split):
    """Sorted names of sub-directories of data_dir that contain <split>.pkl."""
    if not os.path.isdir(data_dir):
        return []
    return sorted(
        name for name in os.listdir(data_dir)
        if os.path.isfile(os.path.join(data_dir, name, f'{split}.pkl'))
    )


def _macro_mean(per_target):
    """Unweighted mean over targets of each scalar metric (p-values excluded)."""
    pooled = {}
    for metrics in per_target.values():
        for key, value in metrics.items():
            # Averaging p-values is not meaningful, so they are left out.
            if isinstance(value, float) and not key.endswith('_pval'):
                pooled.setdefault(key, []).append(value)
    return {key: float(np.mean(values)) for key, values in pooled.items()}


def _evaluate_target(model, device, args, target, data_path):
    """
    Evaluate one target: score its split, write figures and the metrics JSON.

    Returns the metrics dict, or None when the dataset contains no samples.
    """
    dataset = TwoStateComplexDataset(data_path, max_nodes=128,
                                     esm_dir=args.esm_dir, target_name=target)
    if len(dataset) == 0:
        logger.error(f"No samples in {data_path}")
        return None

    loader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=2, collate_fn=collate_fn
    )

    logger.info(f"Evaluating on {len(dataset)} samples...")
    scores, labels, types, _pdbs = evaluate(model, loader, device)

    # Split the predictions by complex type.
    pos_mask = (types == 'positive')
    neg_apo_mask = (types == 'negative_apo')
    # dtype=bool keeps this a valid index mask even if `types` is empty.
    decoy_mask = np.array(['decoy' in t for t in types], dtype=bool)

    pos_scores = scores[pos_mask]
    neg_scores = scores[neg_apo_mask]
    decoy_scores = scores[decoy_mask]
    have_pos_and_neg = len(pos_scores) > 0 and len(neg_scores) > 0

    logger.info(f"\n{'='*50}")
    logger.info(f"Results for {target} ({args.split})")
    logger.info(f"{'='*50}")
    logger.info(f"Positive samples: {pos_mask.sum()}")
    logger.info(f"Negative (apo) samples: {neg_apo_mask.sum()}")
    logger.info(f"Decoy samples: {decoy_mask.sum()}")

    metrics = {}

    # Correlation between predicted score and label over all samples.
    sp, p_val = spearmanr(scores, labels)
    metrics['spearman_all'] = float(sp)
    metrics['spearman_pval'] = float(p_val)
    logger.info(f"\nSpearman(Q_theta, DockQ): {sp:.3f} (p={p_val:.3e})")

    # pearsonr needs at least two samples and is undefined for constant input.
    if len(scores) >= 2 and np.ptp(scores) > 0 and np.ptp(labels) > 0:
        pr, pr_pval = pearsonr(scores, labels)
        metrics['pearson_all'] = float(pr)
        metrics['pearson_pval'] = float(pr_pval)
        logger.info(f"Pearson(Q_theta, DockQ): {pr:.3f} (p={pr_pval:.3e})")

    if have_pos_and_neg:
        # Selectivity: how far positives sit above apo negatives.
        gap = float(pos_scores.mean() - neg_scores.mean())
        # Fraction of apo negatives scoring below the mean positive score.
        ranking_acc = float((pos_scores.mean() > neg_scores).mean())
        metrics['selectivity_gap'] = gap
        metrics['ranking_accuracy'] = ranking_acc
        metrics['pos_score_mean'] = float(pos_scores.mean())
        metrics['neg_score_mean'] = float(neg_scores.mean())
        metrics['pos_score_std'] = float(pos_scores.std())
        metrics['neg_score_std'] = float(neg_scores.std())
        logger.info(f"Selectivity gap (pos - neg): {gap:.3f}")
        logger.info(f"  Pos: {pos_scores.mean():.3f} ± {pos_scores.std():.3f}")
        logger.info(f"  Neg: {neg_scores.mean():.3f} ± {neg_scores.std():.3f}")
        logger.info(f"Ranking accuracy (neg below mean pos): {ranking_acc:.3f}")

        # AUC for separating positives from apo negatives.
        pn_scores = np.concatenate([pos_scores, neg_scores])
        pn_labels = np.concatenate([np.ones(len(pos_scores)), np.zeros(len(neg_scores))])
        auc = roc_auc_score(pn_labels, pn_scores)
        metrics['auc_pos_vs_neg'] = float(auc)
        logger.info(f"AUC (pos vs neg_apo): {auc:.3f}")

        plot_roc_curve(
            pn_labels, pn_scores,
            title=f'ROC: Positive vs Negative Apo ({target.upper()})',
            outpath=f'{args.outdir}/figures/roc_{target}_{args.split}.png'
        )

    # AUC for recovering "label > 0.5"; needs both classes to be present.
    binary = (labels > 0.5).astype(int)
    if binary.sum() > 0 and binary.sum() < len(binary):
        auc_quality = roc_auc_score(binary, scores)
        metrics['auc_quality'] = float(auc_quality)
        logger.info(f"AUC (quality>0.5): {auc_quality:.3f}")

    if len(pos_scores) > 0:
        bok_results = compute_best_of_k(pos_scores, K_values=[1, 2, 5, 10, 20, 50],
                                        threshold=args.bok_threshold)
        metrics['best_of_k'] = {str(K): float(v) for K, v in bok_results.items()}
        logger.info(f"\nBest-of-K success rates:")
        for K, rate in bok_results.items():
            logger.info(f"  K={K:3d}: {rate:.3f}")
        plot_best_of_k(
            bok_results,
            outpath=f'{args.outdir}/figures/best_of_k_{target}_{args.split}.png'
        )

    # A violin cannot be drawn from an empty group, so only plot when both
    # main groups have samples.
    if have_pos_and_neg:
        plot_score_distributions(
            pos_scores,
            neg_scores,
            decoy_scores if len(decoy_scores) > 0 else None,
            title=f'Q_theta Score Distributions ({target.upper()})',
            outpath=f'{args.outdir}/figures/score_dist_{target}_{args.split}.png'
        )
    else:
        logger.warning("Skipping score-distribution plot: need both positive and negative_apo samples")

    out_json = f'{args.outdir}/tables/eval_{target}_{args.split}.json'
    with open(out_json, 'w') as f:
        json.dump(metrics, f, indent=2)
    logger.info(f"\nSaved metrics to {out_json}")

    logger.info(f"\n{'='*50}")
    logger.info("SUMMARY TABLE")
    logger.info(f"{'='*50}")
    logger.info(f"{'Metric':<30} {'Value':>10}")
    logger.info(f"{'-'*42}")
    for k, v in metrics.items():
        if isinstance(v, float):
            logger.info(f"{k:<30} {v:>10.4f}")
    logger.info(f"{'='*50}")

    return metrics


def main():
    """
    Command-line entry point.

    Loads the checkpoint, evaluates either ``--target`` or (with
    ``--all_targets``) every sub-directory of ``--data_dir`` that has a
    ``<split>.pkl`` file, and writes figures to ``<outdir>/figures`` and
    metrics to ``<outdir>/tables``. With ``--all_targets`` an additional
    ``eval_all_targets_<split>.json`` holds the per-target metrics and their
    unweighted mean. Exits with status 1 when the requested data is missing
    or empty.
    """
    args = _parse_args()

    # Auto-detect the ESM-2 embedding cache next to the processed data.
    if args.esm_dir is None:
        cand = os.path.join(args.data_dir, 'esm2_embeddings')
        if os.path.isdir(cand):
            args.esm_dir = cand

    device = _select_device(args.gpu)
    os.makedirs(args.outdir, exist_ok=True)
    os.makedirs(f'{args.outdir}/figures', exist_ok=True)
    os.makedirs(f'{args.outdir}/tables', exist_ok=True)

    # The checkpoint carries the model weights and, optionally, its config.
    state = torch.load(args.checkpoint, map_location=device)
    config = state.get('config', {})
    model = build_model(config).to(device)
    model.load_state_dict(state['model_state'])
    logger.info(f"Loaded model from {args.checkpoint}")

    if args.all_targets:
        targets = _discover_targets(args.data_dir, args.split)
        if not targets:
            logger.error(f"No targets with {args.split}.pkl found under {args.data_dir}")
            sys.exit(1)
    else:
        targets = [args.target]

    per_target = {}
    for target in targets:
        data_path = os.path.join(args.data_dir, target, f'{args.split}.pkl')
        if not os.path.exists(data_path):
            logger.error(f"Data not found: {data_path}")
            sys.exit(1)

        metrics = _evaluate_target(model, device, args, target, data_path)
        if metrics is None:
            # An empty dataset is fatal for a single target, skippable in a sweep.
            if not args.all_targets:
                sys.exit(1)
            continue
        per_target[target] = metrics

    if args.all_targets:
        summary = {'per_target': per_target, 'macro_mean': _macro_mean(per_target)}
        out_json = f'{args.outdir}/tables/eval_all_targets_{args.split}.json'
        with open(out_json, 'w') as f:
            json.dump(summary, f, indent=2)
        logger.info(f"Saved aggregated metrics for {len(per_target)} targets to {out_json}")


if __name__ == '__main__':
    main()
|
|