File size: 10,432 Bytes
142035a e967c14 6354ea8 142035a 6354ea8 142035a 250a4a2 142035a e967c14 6354ea8 e967c14 6354ea8 e967c14 6354ea8 e967c14 142035a e967c14 6354ea8 e967c14 6354ea8 e967c14 6354ea8 e967c14 6354ea8 e967c14 6354ea8 142035a e967c14 d837caf 142035a 6354ea8 e967c14 6354ea8 e967c14 6354ea8 d837caf 6354ea8 142035a d837caf 6354ea8 142035a e967c14 142035a 6354ea8 142035a e967c14 6354ea8 e967c14 6354ea8 e967c14 d837caf 142035a e967c14 142035a 354bfc2 6354ea8 e967c14 6354ea8 142035a 354bfc2 6354ea8 142035a e967c14 bf6f326 6354ea8 bf6f326 6354ea8 bf6f326 142035a 6354ea8 bf6f326 e967c14 6354ea8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
import logging
from pathlib import Path
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from scipy.stats import pearsonr, spearmanr
logger = logging.getLogger(__name__)

# Raw score column name -> human-readable label used in tables and plots.
# The key suffix names the pipeline that produced the score:
#   "_boltz" -> Boltz, "_monomer" -> AlphaFold2 GapTrick, "_multimer" -> AlphaFold2 Multimer.
SCORE_COLUMN_NAMES = {
    # --- Boltz ---
    "confidence_score_boltz": "Boltz Confidence Score",
    "ptm_boltz": "Boltz pTM Score",
    "iptm_boltz": "Boltz ipTM Score",
    "complex_plddt_boltz": "Boltz Complex pLDDT",
    "complex_iplddt_boltz": "Boltz Complex ipLDDT",
    "complex_pde_boltz": "Boltz Complex pDE",
    "complex_ipde_boltz": "Boltz Complex ipDE",
    # --- AlphaFold2 GapTrick (monomer model) ---
    "interchain_pae_monomer": "AlphaFold2 GapTrick Interchain PAE",
    "interface_pae_monomer": "AlphaFold2 GapTrick Interface PAE",
    "overall_pae_monomer": "AlphaFold2 GapTrick Overall PAE",
    "interface_plddt_monomer": "AlphaFold2 GapTrick Interface pLDDT",
    "average_plddt_monomer": "AlphaFold2 GapTrick Average pLDDT",
    "ptm_monomer": "AlphaFold2 GapTrick pTM Score",
    "interface_ptm_monomer": "AlphaFold2 GapTrick Interface pTM",
    # --- AlphaFold2 Multimer ---
    "interchain_pae_multimer": "AlphaFold2 Multimer Interchain PAE",
    "interface_pae_multimer": "AlphaFold2 Multimer Interface PAE",
    "overall_pae_multimer": "AlphaFold2 Multimer Overall PAE",
    "interface_plddt_multimer": "AlphaFold2 Multimer Interface pLDDT",
    "average_plddt_multimer": "AlphaFold2 Multimer Average pLDDT",
    "ptm_multimer": "AlphaFold2 Multimer pTM Score",
    "interface_ptm_multimer": "AlphaFold2 Multimer Interface pTM",
}

# Human-readable labels only, in the same order as SCORE_COLUMN_NAMES.
SCORE_COLUMNS = [*SCORE_COLUMN_NAMES.values()]
def get_score_description(score: str) -> str:
    """Return a human-readable explanation of a score label.

    Args:
        score: A human-readable score label (e.g. "Boltz pTM Score").

    Returns:
        The description for ``score``. If the label is unknown, the generic text
        "No description available for this score." is returned instead.
    """
    descriptions = {
        "Boltz Confidence Score": "The Boltz model confidence score provides an overall assessment of prediction quality (0-1, higher is better).",
        "Boltz pTM Score": "The Boltz model predicted TM-score (pTM) assesses the overall fold accuracy of the predicted structure (0-1, higher is better).",
        "Boltz ipTM Score": "The Boltz model interface pTM score (ipTM) specifically evaluates the accuracy of interface regions (0-1, higher is better).",
        "Boltz Complex pLDDT": "The Boltz model Complex pLDDT measures confidence in local structure predictions across the entire complex (0-100, higher is better).",
        "Boltz Complex ipLDDT": "The Boltz model Complex interface pLDDT (ipLDDT) focuses on confidence in interface region predictions (0-100, higher is better).",
        # pDE/ipDE are predicted distance *errors*, so smaller values mean more confident predictions.
        "Boltz Complex pDE": "The Boltz model Complex predicted distance error (pDE) estimates the expected error in predicted distances between residues (lower is better).",
        "Boltz Complex ipDE": "The Boltz model Complex interface pDE (ipDE) estimates the expected error in predicted distances specifically at interfaces (lower is better).",
        "AlphaFold2 GapTrick Interchain PAE": "The AlphaFold2 GapTrick model interchain predicted aligned error (PAE) estimates position errors between chains in monomeric predictions (lower is better).",
        "AlphaFold2 GapTrick Interface PAE": "The AlphaFold2 GapTrick model interface PAE estimates position errors specifically at interfaces in monomeric predictions (lower is better).",
        "AlphaFold2 GapTrick Overall PAE": "The AlphaFold2 GapTrick model overall PAE estimates position errors across the entire structure in monomeric predictions (lower is better).",
        "AlphaFold2 GapTrick Interface pLDDT": "The AlphaFold2 GapTrick model interface pLDDT measures confidence in interface region predictions for monomeric models (0-100, higher is better).",
        "AlphaFold2 GapTrick Average pLDDT": "The AlphaFold2 GapTrick model average pLDDT provides the mean confidence across all residues in monomeric predictions (0-100, higher is better).",
        "AlphaFold2 GapTrick pTM Score": "The AlphaFold2 GapTrick model pTM score assesses overall fold accuracy in monomeric predictions (0-1, higher is better).",
        "AlphaFold2 GapTrick Interface pTM": "The AlphaFold2 GapTrick model interface pTM specifically evaluates accuracy of interface regions in monomeric predictions (0-1, higher is better).",
        # This label exists in the score-column mapping; it previously had no description.
        "AlphaFold2 Multimer Interchain PAE": "The AlphaFold2 Multimer model interchain predicted aligned error (PAE) estimates position errors between chains in multimeric predictions (lower is better).",
        "AlphaFold2 Multimer Interface PAE": "The AlphaFold2 Multimer model interface PAE estimates position errors specifically at interfaces in multimeric predictions (lower is better).",
        "AlphaFold2 Multimer Overall PAE": "The AlphaFold2 Multimer model overall PAE estimates position errors across the entire structure in multimeric predictions (lower is better).",
        "AlphaFold2 Multimer Interface pLDDT": "The AlphaFold2 Multimer model interface pLDDT measures confidence in interface region predictions for multimeric models (0-100, higher is better).",
        "AlphaFold2 Multimer Average pLDDT": "The AlphaFold2 Multimer model average pLDDT provides the mean confidence across all residues in multimeric predictions (0-100, higher is better).",
        "AlphaFold2 Multimer pTM Score": "The AlphaFold2 Multimer model pTM score assesses overall fold accuracy in multimeric predictions (0-1, higher is better).",
        "AlphaFold2 Multimer Interface pTM": "The AlphaFold2 Multimer model interface pTM specifically evaluates accuracy of interface regions in multimeric predictions (0-1, higher is better).",
    }
    return descriptions.get(score, "No description available for this score.")
def compute_correlation_data(
    spr_data_with_scores: pd.DataFrame,
    score_cols: list[str],
    cache_file: "str | Path | None" = "corr_data.csv",
) -> pd.DataFrame:
    """Correlate every score column with binding affinity (KD and log10 KD).

    For each KD column ("KD (nM)" and "log_kd") and each correlation type (Spearman,
    Pearson), the correlation with every column in ``score_cols`` is computed. Rows
    whose correlation is NaN are dropped, and the result is sorted by ascending
    correlation.

    Side effect: a "log_kd" column (log10 of "KD (nM)") is added to
    ``spr_data_with_scores`` in place. This is kept for backward compatibility.

    Args:
        spr_data_with_scores: Table with a "KD (nM)" column and every column in
            ``score_cols``.
        score_cols: Names of the score columns to correlate against KD.
        cache_file: CSV path used as a cache. If the file exists it is loaded and
            returned as-is. The cache is NOT keyed on the inputs, so a stale file is
            returned even when ``score_cols`` or the data differ. Otherwise the
            result is written there. Pass ``None`` to disable caching.

    Returns:
        DataFrame with columns correlation_type, kd_col, score, correlation, p-value.
    """
    cache_path = Path(cache_file) if cache_file is not None else None
    if cache_path is not None and cache_path.exists():
        logger.info(f"Loading correlation data from {cache_path}")
        return pd.read_csv(cache_path)

    spr_data_with_scores["log_kd"] = np.log10(spr_data_with_scores["KD (nM)"])

    corr_funcs = {"Spearman": spearmanr, "Pearson": pearsonr}
    records = []
    for kd_col in ("KD (nM)", "log_kd"):
        for correlation_type, corr_func in corr_funcs.items():
            for score_col in score_cols:
                logger.info(
                    f"Computing {correlation_type} correlation between {score_col} and {kd_col}"
                )
                res = corr_func(
                    spr_data_with_scores[kd_col], spr_data_with_scores[score_col]
                )
                logger.debug(f"Correlation function: {corr_func}")
                records.append(
                    {
                        "correlation_type": correlation_type,
                        "kd_col": kd_col,
                        "score": score_col,
                        "correlation": res.statistic,
                        "p-value": res.pvalue,
                    }
                )

    # Explicit columns keep the schema (and the filter below) valid even when
    # score_cols is empty and no records were produced.
    corr_data = pd.DataFrame(
        records,
        columns=["correlation_type", "kd_col", "score", "correlation", "p-value"],
    )
    # Undefined correlations (e.g. constant or NaN inputs) come back as NaN; drop them.
    corr_data = corr_data[corr_data["correlation"].notna()]
    corr_data = corr_data.sort_values("correlation", ascending=True)
    if cache_path is not None:
        corr_data.to_csv(cache_path, index=False)
    return corr_data
def plot_correlation_ranking(
    corr_data: pd.DataFrame, correlation_type: str, kd_col: str
) -> go.Figure:
    """Draw a horizontal bar chart of score-vs-affinity correlations.

    Only the rows of ``corr_data`` matching both ``correlation_type`` and ``kd_col``
    are plotted. They are passed to plotly in the order they appear in ``corr_data``.
    """
    is_selected = (corr_data["correlation_type"] == correlation_type) & (
        corr_data["kd_col"] == kd_col
    )
    selected = corr_data[is_selected]

    bar_trace = go.Bar(
        x=selected["correlation"],
        y=selected["score"],
        name=correlation_type,
        orientation="h",
        hovertemplate="<i>Score:</i> %{y}<br><i>Correlation:</i> %{x:.3f}<br>",
    )
    figure = go.Figure(data=[bar_trace])
    figure.update_layout(
        title="Correlation with Binding Affinity",
        yaxis_title="Score",
        xaxis_title=correlation_type,
        template="simple_white",
        showlegend=False,
    )
    return figure
def fake_predict_and_correlate(
    spr_data_with_scores: pd.DataFrame, score_cols: list[str], main_cols: list[str]
) -> tuple[pd.DataFrame, go.Figure, go.Figure]:
    """Fake predict structures of all complexes and correlate the results.

    No prediction is run here: the scores are read from ``spr_data_with_scores``.

    Args:
        spr_data_with_scores: Table with a "KD (nM)" column, ``main_cols`` and
            ``score_cols``.
        score_cols: Score columns to correlate. The first one is used for the
            regression plot.
        main_cols: Leading columns to show in the returned table.

    Returns:
        A 3-tuple containing:
        - the displayed table (``main_cols`` followed by ``score_cols``, rounded to
          2 decimals);
        - the Spearman-vs-KD correlation ranking plot;
        - the regression plot for the first score column.

    Raises:
        ValueError: If ``score_cols`` is empty.
    """
    # Fail early, before any correlation work or cache write, rather than with an
    # IndexError on score_cols[0] at the end.
    if not score_cols:
        raise ValueError("score_cols must contain at least one score column")
    corr_data = compute_correlation_data(spr_data_with_scores, score_cols)
    corr_ranking_plot = plot_correlation_ranking(corr_data, "Spearman", kd_col="KD (nM)")
    cols_to_show = [*main_cols, *score_cols]
    corr_plot = make_regression_plot(spr_data_with_scores, score_cols[0], use_log=False)
    return spr_data_with_scores[cols_to_show].round(2), corr_ranking_plot, corr_plot
def make_regression_plot(
    spr_data_with_scores: pd.DataFrame, score: str, use_log: bool
) -> go.Figure:
    """Scatter ``score`` against KD and overlay a least-squares regression line.

    Args:
        spr_data_with_scores: Table with a "KD (nM)" column and the ``score`` column.
        score: Name of the score column plotted on the y axis.
        use_log: If True, the x axis uses a log scale and the line is fitted
            against log10(KD). Otherwise both use KD directly.

    Returns:
        A plotly figure containing the sample markers. It also contains the fitted
        line when at least two distinct finite points are available.
    """
    kd_values = spr_data_with_scores["KD (nM)"]
    score_values = spr_data_with_scores[score]

    scatter = go.Scatter(
        x=kd_values,
        y=score_values,
        name="Samples",
        mode="markers",  # Only show markers/dots, no lines
        hovertemplate="<i>Score:</i> %{y}<br><i>KD:</i> %{x:.2f}<br>",
        marker=dict(color="#1f77b4"),  # Set color to match default first color
    )
    corr_plot = go.Figure(data=scatter)
    corr_plot.update_layout(
        xaxis_title="KD (nM)",
        yaxis_title=score,
        template="simple_white",
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        ),
        # Logarithmic x axis only when use_log is set.
        xaxis_type="log" if use_log else "linear",
    )

    # Fit in log10(KD) space when the axis is logarithmic, so the line is straight on the plot.
    x_all = kd_values.to_numpy(dtype=float)
    y_all = score_values.to_numpy(dtype=float)
    if use_log:
        with np.errstate(divide="ignore", invalid="ignore"):
            x_all = np.log10(x_all)

    # np.polyfit cannot fit NaN/inf values (missing scores, or log10 of KD <= 0),
    # so fit only the finite points.
    finite = np.isfinite(x_all) & np.isfinite(y_all)
    x_fit, y_fit = x_all[finite], y_all[finite]
    # A degree-1 fit needs at least two distinct x values. Otherwise show the scatter alone.
    if np.unique(x_fit).size < 2:
        logger.warning(
            f"Not enough finite data points to fit a regression line for {score}"
        )
        return corr_plot

    slope, intercept = np.polyfit(x_fit, y_fit, 1)
    corr_line_x = np.linspace(x_fit.min(), x_fit.max(), 100)
    corr_line_y = slope * corr_line_x + intercept
    # Map x back from log10 space so the line lines up with the raw-KD scatter.
    if use_log:
        corr_line_x = 10**corr_line_x

    corr_plot.add_trace(
        go.Scatter(
            x=corr_line_x,
            y=corr_line_y,
            mode="lines",
            name="Regression line",
            line=dict(color="#1f77b4"),  # Set same color as scatter points
        )
    )
    return corr_plot
|