"""Aggregate structure-prediction evaluation results for AF3, Chai-1, Boltz-1, and
Boltz-1x, and plot per-metric comparisons with bootstrap confidence intervals."""

import json
import traceback
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm

# Polymer metrics read directly from the per-model evaluation JSON files.
METRICS = ["lddt", "bb_lddt", "tm_score", "rmsd"]


def compute_af3_metrics(preds, evals, name):
    """Collect per-model metrics for an AF3 prediction of target `name`.

    Returns a dict mapping each metric to its oracle (best of 5 models), average,
    top-1 (highest ranking score) value, and the number of scored items.
    """
    metrics = {}

    top_model = None
    top_confidence = -1000
    for model_id in range(5):
        # Rank the sampled models by AF3's ranking score.
        confidence_file = (
            Path(preds) / f"seed-1_sample-{model_id}" / "summary_confidences.json"
        )
        with confidence_file.open("r") as f:
            confidence_data = json.load(f)
        confidence = confidence_data["ranking_score"]
        if confidence > top_confidence:
            top_model = model_id
            top_confidence = confidence

        # Polymer metrics (lDDT, TM-score, RMSD) and DockQ.
        eval_file = Path(evals) / f"{name}_model_{model_id}.json"
        with eval_file.open("r") as f:
            eval_data = json.load(f)
        for metric_name in METRICS:
            if metric_name in eval_data:
                metrics.setdefault(metric_name, []).append(eval_data[metric_name])

        if "dockq" in eval_data and eval_data["dockq"] is not None:
            metrics.setdefault("dockq_>0.23", []).append(
                np.mean(
                    [float(v > 0.23) for v in eval_data["dockq"] if v is not None]
                )
            )
            metrics.setdefault("dockq_>0.49", []).append(
                np.mean(
                    [float(v > 0.49) for v in eval_data["dockq"] if v is not None]
                )
            )
            metrics.setdefault("len_dockq_", []).append(
                len([v for v in eval_data["dockq"] if v is not None])
            )

        # Ligand metrics (lDDT-PLI and ligand RMSD).
        eval_file = Path(evals) / f"{name}_model_{model_id}_ligand.json"
        with eval_file.open("r") as f:
            eval_data = json.load(f)
        if "lddt_pli" in eval_data:
            lddt_plis = [
                x["score"] for x in eval_data["lddt_pli"]["assigned_scores"]
            ]
            # Every unassigned model ligand is scored as 0.
            for _ in eval_data["lddt_pli"][
                "model_ligand_unassigned_reason"
            ].items():
                lddt_plis.append(0)
            if not lddt_plis:
                # No ligands were scored for this model; skip its ligand metrics.
                continue
            lddt_pli = np.mean(lddt_plis)
            metrics.setdefault("lddt_pli", []).append(lddt_pli)
            metrics.setdefault("len_lddt_pli", []).append(len(lddt_plis))

        if "rmsd" in eval_data:
            rmsds = [x["score"] for x in eval_data["rmsd"]["assigned_scores"]]
            # Every unassigned model ligand is penalized with a 100 A RMSD.
            for _ in eval_data["rmsd"]["model_ligand_unassigned_reason"].items():
                rmsds.append(100)
            if not rmsds:
                continue
            rmsd2 = np.mean([x < 2.0 for x in rmsds])
            rmsd5 = np.mean([x < 5.0 for x in rmsds])
            metrics.setdefault("rmsd<2", []).append(rmsd2)
            metrics.setdefault("rmsd<5", []).append(rmsd5)
            metrics.setdefault("len_rmsd", []).append(len(rmsds))

    # Aggregate across the 5 models: best (oracle), average, and top-1 by confidence.
    oracle = {k: min(v) if k == "rmsd" else max(v) for k, v in metrics.items()}
    avg = {k: sum(v) / len(v) for k, v in metrics.items()}
    top1 = {k: v[top_model] for k, v in metrics.items()}

    results = {}
    for metric_name in metrics:
        if metric_name.startswith("len_"):
            continue
        if metric_name == "lddt_pli":
            num_scored = metrics["len_lddt_pli"][0]
        elif metric_name in ("rmsd<2", "rmsd<5"):
            num_scored = metrics["len_rmsd"][0]
        elif metric_name in ("dockq_>0.23", "dockq_>0.49"):
            num_scored = metrics["len_dockq_"][0]
        else:
            num_scored = 1
        results[metric_name] = {
            "oracle": oracle[metric_name],
            "average": avg[metric_name],
            "top1": top1[metric_name],
            "len": num_scored,
        }

    return results


def compute_chai_metrics(preds, evals, name):
    """Collect per-model metrics for a Chai-1 prediction of target `name`.

    Returns a dict mapping each metric to its oracle (best of 5 models), average,
    top-1 (highest aggregate score) value, and the number of scored items.
    """
    metrics = {}

    top_model = None
    top_confidence = 0
    for model_id in range(5):
        # Rank the sampled models by Chai-1's aggregate score.
        confidence_file = Path(preds) / f"scores.model_idx_{model_id}.npz"
        confidence_data = np.load(confidence_file)
        confidence = confidence_data["aggregate_score"].item()
        if confidence > top_confidence:
            top_model = model_id
            top_confidence = confidence

        # Polymer metrics (lDDT, TM-score, RMSD) and DockQ.
        eval_file = Path(evals) / f"{name}_model_{model_id}.json"
        with eval_file.open("r") as f:
            eval_data = json.load(f)
        for metric_name in METRICS:
            if metric_name in eval_data:
                metrics.setdefault(metric_name, []).append(eval_data[metric_name])

        if "dockq" in eval_data and eval_data["dockq"] is not None:
            metrics.setdefault("dockq_>0.23", []).append(
                np.mean(
                    [float(v > 0.23) for v in eval_data["dockq"] if v is not None]
                )
            )
            metrics.setdefault("dockq_>0.49", []).append(
                np.mean(
                    [float(v > 0.49) for v in eval_data["dockq"] if v is not None]
                )
            )
            metrics.setdefault("len_dockq_", []).append(
                len([v for v in eval_data["dockq"] if v is not None])
            )

        # Ligand metrics (lDDT-PLI and ligand RMSD).
        eval_file = Path(evals) / f"{name}_model_{model_id}_ligand.json"
        with eval_file.open("r") as f:
            eval_data = json.load(f)
        if "lddt_pli" in eval_data:
            lddt_plis = [
                x["score"] for x in eval_data["lddt_pli"]["assigned_scores"]
            ]
            # Every unassigned model ligand is scored as 0.
            for _ in eval_data["lddt_pli"][
                "model_ligand_unassigned_reason"
            ].items():
                lddt_plis.append(0)
            if not lddt_plis:
                # No ligands were scored for this model; skip its ligand metrics.
                continue
            lddt_pli = np.mean(lddt_plis)
            metrics.setdefault("lddt_pli", []).append(lddt_pli)
            metrics.setdefault("len_lddt_pli", []).append(len(lddt_plis))

        if "rmsd" in eval_data:
            rmsds = [x["score"] for x in eval_data["rmsd"]["assigned_scores"]]
            # Every unassigned model ligand is penalized with a 100 A RMSD.
            for _ in eval_data["rmsd"]["model_ligand_unassigned_reason"].items():
                rmsds.append(100)
            if not rmsds:
                continue
            rmsd2 = np.mean([x < 2.0 for x in rmsds])
            rmsd5 = np.mean([x < 5.0 for x in rmsds])
            metrics.setdefault("rmsd<2", []).append(rmsd2)
            metrics.setdefault("rmsd<5", []).append(rmsd5)
            metrics.setdefault("len_rmsd", []).append(len(rmsds))

    # Aggregate across the 5 models: best (oracle), average, and top-1 by confidence.
    oracle = {k: min(v) if k == "rmsd" else max(v) for k, v in metrics.items()}
    avg = {k: sum(v) / len(v) for k, v in metrics.items()}
    top1 = {k: v[top_model] for k, v in metrics.items()}

    results = {}
    for metric_name in metrics:
        if metric_name.startswith("len_"):
            continue
        if metric_name == "lddt_pli":
            num_scored = metrics["len_lddt_pli"][0]
        elif metric_name in ("rmsd<2", "rmsd<5"):
            num_scored = metrics["len_rmsd"][0]
        elif metric_name in ("dockq_>0.23", "dockq_>0.49"):
            num_scored = metrics["len_dockq_"][0]
        else:
            num_scored = 1
        results[metric_name] = {
            "oracle": oracle[metric_name],
            "average": avg[metric_name],
            "top1": top1[metric_name],
            "len": num_scored,
        }

    return results


def compute_boltz_metrics(preds, evals, name):
    """Collect per-model metrics for a Boltz prediction of target `name`.

    Returns a dict mapping each metric to its oracle (best of 5 models), average,
    top-1 (highest confidence score) value, and the number of scored items.
    """
    metrics = {}

    top_model = None
    top_confidence = 0
    for model_id in range(5):
        # Rank the sampled models by Boltz's confidence score.
        confidence_file = (
            Path(preds) / f"confidence_{Path(preds).name}_model_{model_id}.json"
        )
        with confidence_file.open("r") as f:
            confidence_data = json.load(f)
        confidence = confidence_data["confidence_score"]
        if confidence > top_confidence:
            top_model = model_id
            top_confidence = confidence

        # Polymer metrics (lDDT, TM-score, RMSD) and DockQ.
        eval_file = Path(evals) / f"{name}_model_{model_id}.json"
        with eval_file.open("r") as f:
            eval_data = json.load(f)
        for metric_name in METRICS:
            if metric_name in eval_data:
                metrics.setdefault(metric_name, []).append(eval_data[metric_name])

        if "dockq" in eval_data and eval_data["dockq"] is not None:
            metrics.setdefault("dockq_>0.23", []).append(
                np.mean(
                    [float(v > 0.23) for v in eval_data["dockq"] if v is not None]
                )
            )
            metrics.setdefault("dockq_>0.49", []).append(
                np.mean(
                    [float(v > 0.49) for v in eval_data["dockq"] if v is not None]
                )
            )
            metrics.setdefault("len_dockq_", []).append(
                len([v for v in eval_data["dockq"] if v is not None])
            )

        # Ligand metrics (lDDT-PLI and ligand RMSD).
        eval_file = Path(evals) / f"{name}_model_{model_id}_ligand.json"
        with eval_file.open("r") as f:
            eval_data = json.load(f)
        if "lddt_pli" in eval_data:
            lddt_plis = [
                x["score"] for x in eval_data["lddt_pli"]["assigned_scores"]
            ]
            # Every unassigned model ligand is scored as 0.
            for _ in eval_data["lddt_pli"][
                "model_ligand_unassigned_reason"
            ].items():
                lddt_plis.append(0)
            if not lddt_plis:
                # No ligands were scored for this model; skip its ligand metrics.
                continue
            lddt_pli = np.mean(lddt_plis)
            metrics.setdefault("lddt_pli", []).append(lddt_pli)
            metrics.setdefault("len_lddt_pli", []).append(len(lddt_plis))

        if "rmsd" in eval_data:
            rmsds = [x["score"] for x in eval_data["rmsd"]["assigned_scores"]]
            # Every unassigned model ligand is penalized with a 100 A RMSD.
            for _ in eval_data["rmsd"]["model_ligand_unassigned_reason"].items():
                rmsds.append(100)
            if not rmsds:
                continue
            rmsd2 = np.mean([x < 2.0 for x in rmsds])
            rmsd5 = np.mean([x < 5.0 for x in rmsds])
            metrics.setdefault("rmsd<2", []).append(rmsd2)
            metrics.setdefault("rmsd<5", []).append(rmsd5)
            metrics.setdefault("len_rmsd", []).append(len(rmsds))

    # Aggregate across the 5 models: best (oracle), average, and top-1 by confidence.
    oracle = {k: min(v) if k == "rmsd" else max(v) for k, v in metrics.items()}
    avg = {k: sum(v) / len(v) for k, v in metrics.items()}
    top1 = {k: v[top_model] for k, v in metrics.items()}

    results = {}
    for metric_name in metrics:
        if metric_name.startswith("len_"):
            continue
        if metric_name == "lddt_pli":
            num_scored = metrics["len_lddt_pli"][0]
        elif metric_name in ("rmsd<2", "rmsd<5"):
            num_scored = metrics["len_rmsd"][0]
        elif metric_name in ("dockq_>0.23", "dockq_>0.49"):
            num_scored = metrics["len_dockq_"][0]
        else:
            num_scored = 1
        results[metric_name] = {
            "oracle": oracle[metric_name],
            "average": avg[metric_name],
            "top1": top1[metric_name],
            "len": num_scored,
        }

    return results


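# Note: compute_af3_metrics, compute_chai_metrics, and compute_boltz_metrics differ
# only in how each model's confidence is read. A minimal sketch of how the ranking
# step could be shared (hypothetical helpers, not used by this script): each tool
# supplies a callable mapping (preds, model_id) -> confidence, and the caller keeps
# whichever model scores highest.


def _af3_confidence(preds, model_id):
    """Hypothetical example: read AF3's ranking score for one sampled model."""
    path = Path(preds) / f"seed-1_sample-{model_id}" / "summary_confidences.json"
    with path.open("r") as f:
        return json.load(f)["ranking_score"]


def _pick_top_model(preds, confidence_fn, num_models=5):
    """Hypothetical example: return (best model index, best confidence)."""
    top_model, top_confidence = None, float("-inf")
    for model_id in range(num_models):
        confidence = confidence_fn(preds, model_id)
        if confidence > top_confidence:
            top_model, top_confidence = model_id, confidence
    return top_model, top_confidence

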
def eval_models(
    chai_preds,
    chai_evals,
    af3_preds,
    af3_evals,
    boltz_preds,
    boltz_evals,
    boltz_preds_x,
    boltz_evals_x,
):
    # Index each tool's prediction directories by lowercased target name.
    chai_preds_names = {
        x.name.lower(): x
        for x in Path(chai_preds).iterdir()
        if not x.name.lower().startswith(".")
    }
    af3_preds_names = {
        x.name.lower(): x
        for x in Path(af3_preds).iterdir()
        if not x.name.lower().startswith(".")
    }
    boltz_preds_names = {
        x.name.lower(): x
        for x in Path(boltz_preds).iterdir()
        if not x.name.lower().startswith(".")
    }
    boltz_preds_names_x = {
        x.name.lower(): x
        for x in Path(boltz_preds_x).iterdir()
        if not x.name.lower().startswith(".")
    }

    print("Chai preds", len(chai_preds_names))
    print("Af3 preds", len(af3_preds_names))
    print("Boltz preds", len(boltz_preds_names))
    print("Boltzx preds", len(boltz_preds_names_x))

    # Only targets predicted by all four tools are compared.
    common = (
        set(chai_preds_names.keys())
        & set(af3_preds_names.keys())
        & set(boltz_preds_names.keys())
        & set(boltz_preds_names_x.keys())
    )

    # Targets explicitly excluded from the comparison.
    keys_to_remove = ["t1133", "h1134", "r1134s1", "t1134s2", "t1121", "t1123", "t1159"]
    for key in keys_to_remove:
        if key in common:
            common.remove(key)
    print("Common", len(common))

    results = []
    for name in tqdm(common):
        try:
            af3_results = compute_af3_metrics(
                af3_preds_names[name],
                af3_evals,
                name,
            )
        except Exception as e:
            traceback.print_exc()
            print(f"Error evaluating AF3 {name}: {e}")
            continue
        try:
            chai_results = compute_chai_metrics(
                chai_preds_names[name],
                chai_evals,
                name,
            )
        except Exception as e:
            traceback.print_exc()
            print(f"Error evaluating Chai {name}: {e}")
            continue
        try:
            boltz_results = compute_boltz_metrics(
                boltz_preds_names[name],
                boltz_evals,
                name,
            )
        except Exception as e:
            traceback.print_exc()
            print(f"Error evaluating Boltz {name}: {e}")
            continue
        try:
            boltz_results_x = compute_boltz_metrics(
                boltz_preds_names_x[name],
                boltz_evals_x,
                name,
            )
        except Exception as e:
            traceback.print_exc()
            print(f"Error evaluating Boltzx {name}: {e}")
            continue

        for metric_name in af3_results:
            if (
                metric_name in chai_results
                and metric_name in boltz_results
                and metric_name in boltz_results_x
            ):
                # Only keep a metric when every tool scored the same number of
                # interfaces/ligands, so the comparison stays apples to apples.
                if (
                    af3_results[metric_name]["len"]
                    == chai_results[metric_name]["len"]
                    == boltz_results[metric_name]["len"]
                    == boltz_results_x[metric_name]["len"]
                ):
                    for tool_label, tool_results in [
                        ("AF3", af3_results),
                        ("Chai-1", chai_results),
                        ("Boltz-1", boltz_results),
                        ("Boltz-1x", boltz_results_x),
                    ]:
                        for mode_key, mode_label in [
                            ("oracle", "oracle"),
                            ("top1", "top-1"),
                        ]:
                            results.append(
                                {
                                    "tool": f"{tool_label} {mode_label}",
                                    "target": name,
                                    "metric": metric_name,
                                    "value": tool_results[metric_name][mode_key],
                                }
                            )
                else:
                    print(
                        "Different lengths",
                        name,
                        metric_name,
                        af3_results[metric_name]["len"],
                        chai_results[metric_name]["len"],
                        boltz_results[metric_name]["len"],
                        boltz_results_x[metric_name]["len"],
                    )
            else:
                print(
                    "Missing metric",
                    name,
                    metric_name,
                    metric_name in chai_results,
                    metric_name in boltz_results,
                    metric_name in boltz_results_x,
                )

    df = pd.DataFrame(results)
    return df


def eval_validity_checks(df):
    """Reshape physical-validity checks into the (tool, target, metric, value) format."""
    # Top-1: model_idx 0 is taken as each tool's top-ranked prediction.
    name_mapping = {
        "af3": "AF3 top-1",
        "chai": "Chai-1 top-1",
        "boltz1": "Boltz-1 top-1",
        "boltz1x": "Boltz-1x top-1",
    }
    top1 = df[df["model_idx"] == 0]
    top1 = top1[["tool", "pdb_id", "valid"]]
    top1["tool"] = top1["tool"].apply(lambda x: name_mapping[x])
    top1 = top1.rename(columns={"pdb_id": "target", "valid": "value"})
    top1["metric"] = "physical validity"
    top1["target"] = top1["target"].apply(lambda x: x.lower())
    top1 = top1[["tool", "target", "metric", "value"]]

    # Oracle: a target passes if any of its models passes the checks.
    name_mapping = {
        "af3": "AF3 oracle",
        "chai": "Chai-1 oracle",
        "boltz1": "Boltz-1 oracle",
        "boltz1x": "Boltz-1x oracle",
    }
    oracle = df[["tool", "model_idx", "pdb_id", "valid"]]
    oracle = oracle.groupby(["tool", "pdb_id"])["valid"].max().reset_index()
    oracle = oracle.rename(columns={"pdb_id": "target", "valid": "value"})
    oracle["tool"] = oracle["tool"].apply(lambda x: name_mapping[x])
    oracle["metric"] = "physical validity"
    oracle = oracle[["tool", "target", "metric", "value"]]
    oracle["target"] = oracle["target"].apply(lambda x: x.lower())

    out = pd.concat([top1, oracle])
    return out


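# eval_validity_checks assumes a physical-checks CSV with (at least) the columns it
# reads above: "tool" (af3 / chai / boltz1 / boltz1x), "model_idx" (with 0 taken as
# the top-ranked model), "pdb_id", and a boolean/0-1 "valid" flag. A minimal
# illustrative row (values are made up):
#
#     tool,model_idx,pdb_id,valid
#     boltz1x,0,7xyz,1

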
def bootstrap_ci(series, n_boot=1000, alpha=0.05):
    """Return the mean of `series` with a (1 - alpha) bootstrap confidence interval."""
    n = len(series)
    boot_means = []
    for _ in range(n_boot):
        # Resample with replacement and record the resampled mean.
        sample = series.sample(n, replace=True)
        boot_means.append(sample.mean())

    boot_means = np.array(boot_means)
    mean_val = np.mean(series)
    lower = np.percentile(boot_means, 100 * alpha / 2)
    upper = np.percentile(boot_means, 100 * (1 - alpha / 2))
    return mean_val, lower, upper


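# A vectorized alternative to bootstrap_ci (hypothetical sketch, not used above):
# drawing all bootstrap resamples in a single NumPy call avoids the per-iteration
# pandas sampling while keeping the same (mean, lower, upper) interface.


def bootstrap_ci_vectorized(values, n_boot=1000, alpha=0.05, seed=None):
    """Hypothetical example: bootstrap CI for the mean using one NumPy draw."""
    values = np.asarray(values, dtype=float)
    rng = np.random.default_rng(seed)
    # Shape (n_boot, n): each row is one resample with replacement.
    samples = rng.choice(values, size=(n_boot, len(values)), replace=True)
    boot_means = samples.mean(axis=1)
    lower = np.percentile(boot_means, 100 * alpha / 2)
    upper = np.percentile(boot_means, 100 * (1 - alpha / 2))
    return values.mean(), lower, upper

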
def plot_data(desired_tools, desired_metrics, df, dataset, filename):
    filtered_df = df[
        df["tool"].isin(desired_tools) & df["metric"].isin(desired_metrics)
    ]

    # Bootstrap mean and confidence interval per (tool, metric).
    boot_stats = filtered_df.groupby(["tool", "metric"])["value"].apply(bootstrap_ci)

    # Expand the (mean, lower, upper) tuples into columns.
    boot_stats = boot_stats.apply(pd.Series)
    boot_stats.columns = ["mean", "lower", "upper"]

    # Pivot to metrics x tools, preserving the requested metric order.
    plot_data = boot_stats["mean"].unstack("tool")
    plot_data = plot_data.reindex(desired_metrics)

    lower_data = boot_stats["lower"].unstack("tool")
    lower_data = lower_data.reindex(desired_metrics)

    upper_data = boot_stats["upper"].unstack("tool")
    upper_data = upper_data.reindex(desired_metrics)

    # Fixed bar order within each metric group.
    tool_order = [
        "AF3 oracle",
        "AF3 top-1",
        "Chai-1 oracle",
        "Chai-1 top-1",
        "Boltz-1 oracle",
        "Boltz-1 top-1",
        "Boltz-1x oracle",
        "Boltz-1x top-1",
    ]
    plot_data = plot_data[tool_order]
    lower_data = lower_data[tool_order]
    upper_data = upper_data[tool_order]

    # Human-readable metric labels.
    renaming = {
        "lddt_pli": "Mean LDDT-PLI",
        "rmsd<2": "L-RMSD < 2A",
        "lddt": "Mean LDDT",
        "dockq_>0.23": "DockQ > 0.23",
        "physical validity": "Physical Validity",
    }
    plot_data = plot_data.rename(index=renaming)
    lower_data = lower_data.rename(index=renaming)
    upper_data = upper_data.rename(index=renaming)
    mean_vals = plot_data.values

    # One color per tool; oracle/top-1 pairs share a hue.
    tool_colors = [
        "#994C00",
        "#FFB55A",
        "#931652",
        "#FC8AD9",
        "#188F52",
        "#86E935",
        "#004D80",
        "#55C2FF",
    ]

    fig, ax = plt.subplots(figsize=(10, 5))

    x = np.arange(len(plot_data.index))
    bar_spacing = 0.015
    total_width = 0.7
    width = (total_width - (len(tool_order) - 1) * bar_spacing) / len(tool_order)

    for i, tool in enumerate(tool_order):
        # Center the group of bars on each metric tick.
        offsets = x - (total_width - width) / 2 + i * (width + bar_spacing)

        tool_means = plot_data[tool].values
        # Asymmetric error bars from the bootstrap interval.
        tool_yerr_lower = mean_vals[:, i] - lower_data.values[:, i]
        tool_yerr_upper = upper_data.values[:, i] - mean_vals[:, i]
        tool_yerr = np.vstack([tool_yerr_lower, tool_yerr_upper])

        ax.bar(
            offsets,
            tool_means,
            width=width,
            color=tool_colors[i],
            label=tool,
            yerr=tool_yerr,
            capsize=2,
            error_kw={"elinewidth": 0.75},
        )

    ax.set_xticks(x)
    ax.set_xticklabels(plot_data.index, rotation=0)
    ax.set_ylabel("Value")
    ax.set_title(f"Performances on {dataset} with 95% CI (Bootstrap)")

    plt.tight_layout()
    ax.legend(loc="lower center", bbox_to_anchor=(0.5, 0.85), ncols=4, frameon=False)

    plt.savefig(filename)
    plt.show()


def main():
    eval_folder = "../../boltz_results_final/"
    output_folder = "../../boltz_results_final/"

    # PDB test set.
    chai_preds = eval_folder + "outputs/test/chai"
    chai_evals = eval_folder + "evals/test/chai"

    af3_preds = eval_folder + "outputs/test/af3"
    af3_evals = eval_folder + "evals/test/af3"

    boltz_preds = eval_folder + "outputs/test/boltz/predictions"
    boltz_evals = eval_folder + "evals/test/boltz"

    boltz_preds_x = eval_folder + "outputs/test/boltzx/predictions"
    boltz_evals_x = eval_folder + "evals/test/boltzx"

    validity_checks = eval_folder + "physical_checks_test.csv"

    df_validity_checks = pd.read_csv(validity_checks)
    df_validity_checks = eval_validity_checks(df_validity_checks)

    df = eval_models(
        chai_preds,
        chai_evals,
        af3_preds,
        af3_evals,
        boltz_preds,
        boltz_evals,
        boltz_preds_x,
        boltz_evals_x,
    )

    df = pd.concat([df, df_validity_checks]).reset_index(drop=True)
    df.to_csv(output_folder + "results_test.csv", index=False)

    desired_tools = [
        "AF3 oracle",
        "AF3 top-1",
        "Chai-1 oracle",
        "Chai-1 top-1",
        "Boltz-1 oracle",
        "Boltz-1 top-1",
        "Boltz-1x oracle",
        "Boltz-1x top-1",
    ]
    desired_metrics = ["lddt", "dockq_>0.23", "lddt_pli", "rmsd<2", "physical validity"]
    plot_data(
        desired_tools, desired_metrics, df, "PDB Test", output_folder + "plot_test.pdf"
    )

    # CASP15 set.
    chai_preds = eval_folder + "outputs/casp15/chai"
    chai_evals = eval_folder + "evals/casp15/chai"

    af3_preds = eval_folder + "outputs/casp15/af3"
    af3_evals = eval_folder + "evals/casp15/af3"

    boltz_preds = eval_folder + "outputs/casp15/boltz/predictions"
    boltz_evals = eval_folder + "evals/casp15/boltz"

    boltz_preds_x = eval_folder + "outputs/casp15/boltzx/predictions"
    boltz_evals_x = eval_folder + "evals/casp15/boltzx"

    validity_checks = eval_folder + "physical_checks_casp.csv"

    df_validity_checks = pd.read_csv(validity_checks)
    df_validity_checks = eval_validity_checks(df_validity_checks)

    df = eval_models(
        chai_preds,
        chai_evals,
        af3_preds,
        af3_evals,
        boltz_preds,
        boltz_evals,
        boltz_preds_x,
        boltz_evals_x,
    )

    df = pd.concat([df, df_validity_checks]).reset_index(drop=True)
    df.to_csv(output_folder + "results_casp.csv", index=False)

    plot_data(
        desired_tools, desired_metrics, df, "CASP15", output_folder + "plot_casp.pdf"
    )


if __name__ == "__main__":
    main()