File size: 10,432 Bytes
142035a
e967c14
6354ea8
142035a
6354ea8
142035a
250a4a2
142035a
 
 
e967c14
 
 
 
 
 
 
 
 
6354ea8
e967c14
 
 
 
 
 
 
6354ea8
e967c14
 
 
6354ea8
e967c14
142035a
e967c14
 
6354ea8
e967c14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6354ea8
e967c14
 
 
 
6354ea8
e967c14
 
 
6354ea8
 
 
 
e967c14
 
 
 
6354ea8
142035a
 
 
e967c14
 
 
d837caf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142035a
 
 
 
 
6354ea8
 
e967c14
6354ea8
e967c14
 
6354ea8
 
d837caf
6354ea8
142035a
d837caf
 
 
 
6354ea8
 
 
 
 
 
 
 
 
 
 
142035a
 
e967c14
 
142035a
6354ea8
142035a
e967c14
 
6354ea8
 
 
 
e967c14
6354ea8
e967c14
d837caf
142035a
 
 
 
e967c14
142035a
354bfc2
 
6354ea8
 
 
 
e967c14
 
6354ea8
 
 
 
 
 
 
 
142035a
 
 
 
 
 
 
 
 
 
354bfc2
 
6354ea8
142035a
e967c14
bf6f326
 
 
 
 
6354ea8
bf6f326
 
6354ea8
bf6f326
 
142035a
6354ea8
bf6f326
 
 
e967c14
6354ea8
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import logging
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from scipy.stats import pearsonr, spearmanr

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Maps raw score column names to the human-readable names used elsewhere in
# this module. The raw-name suffix identifies the source of the score:
#   *_boltz    -> "Boltz ..."
#   *_monomer  -> "AlphaFold2 GapTrick ..."
#   *_multimer -> "AlphaFold2 Multimer ..."
SCORE_COLUMN_NAMES = {
    "confidence_score_boltz": "Boltz Confidence Score",
    "ptm_boltz": "Boltz pTM Score",
    "iptm_boltz": "Boltz ipTM Score",
    "complex_plddt_boltz": "Boltz Complex pLDDT",
    "complex_iplddt_boltz": "Boltz Complex ipLDDT",
    "complex_pde_boltz": "Boltz Complex pDE",
    "complex_ipde_boltz": "Boltz Complex ipDE",
    "interchain_pae_monomer": "AlphaFold2 GapTrick Interchain PAE",
    "interface_pae_monomer": "AlphaFold2 GapTrick Interface PAE",
    "overall_pae_monomer": "AlphaFold2 GapTrick Overall PAE",
    "interface_plddt_monomer": "AlphaFold2 GapTrick Interface pLDDT",
    "average_plddt_monomer": "AlphaFold2 GapTrick Average pLDDT",
    "ptm_monomer": "AlphaFold2 GapTrick pTM Score",
    "interface_ptm_monomer": "AlphaFold2 GapTrick Interface pTM",
    "interchain_pae_multimer": "AlphaFold2 Multimer Interchain PAE",
    "interface_pae_multimer": "AlphaFold2 Multimer Interface PAE",
    "overall_pae_multimer": "AlphaFold2 Multimer Overall PAE",
    "interface_plddt_multimer": "AlphaFold2 Multimer Interface pLDDT",
    "average_plddt_multimer": "AlphaFold2 Multimer Average pLDDT",
    "ptm_multimer": "AlphaFold2 Multimer pTM Score",
    "interface_ptm_multimer": "AlphaFold2 Multimer Interface pTM",
}

# Display names only, in the same (insertion) order as SCORE_COLUMN_NAMES.
SCORE_COLUMNS = list(SCORE_COLUMN_NAMES.values())


# Fallback text returned for any score name without an entry below.
_NO_SCORE_DESCRIPTION = "No description available for this score."

# Display name of a score -> one-sentence explanation shown to the user.
# Built once at import time instead of on every lookup.
_SCORE_DESCRIPTIONS = {
    "Boltz Confidence Score": "The Boltz model confidence score provides an overall assessment of prediction quality (0-1, higher is better).",
    "Boltz pTM Score": "The Boltz model predicted TM-score (pTM) assesses the overall fold accuracy of the predicted structure (0-1, higher is better).",
    "Boltz ipTM Score": "The Boltz model interface pTM score (ipTM) specifically evaluates the accuracy of interface regions (0-1, higher is better).",
    "Boltz Complex pLDDT": "The Boltz model Complex pLDDT measures confidence in local structure predictions across the entire complex (0-100, higher is better).",
    "Boltz Complex ipLDDT": "The Boltz model Complex interface pLDDT (ipLDDT) focuses on confidence in interface region predictions (0-100, higher is better).",
    # NOTE(review): the two pDE entries describe a distance *error* as
    # "0-1, higher is better" -- confirm range and direction against the
    # Boltz output documentation before relying on this wording.
    "Boltz Complex pDE": "The Boltz model Complex predicted distance error (pDE) estimates the confidence in predicted distances between residues (0-1, higher is better).",
    "Boltz Complex ipDE": "The Boltz model Complex interface pDE (ipDE) estimates confidence in predicted distances specifically at interfaces (0-1, higher is better).",
    "AlphaFold2 GapTrick Interchain PAE": "The AlphaFold2 GapTrick model interchain predicted aligned error (PAE) estimates position errors between chains in monomeric predictions (lower is better).",
    "AlphaFold2 GapTrick Interface PAE": "The AlphaFold2 GapTrick model interface PAE estimates position errors specifically at interfaces in monomeric predictions (lower is better).",
    "AlphaFold2 GapTrick Overall PAE": "The AlphaFold2 GapTrick model overall PAE estimates position errors across the entire structure in monomeric predictions (lower is better).",
    "AlphaFold2 GapTrick Interface pLDDT": "The AlphaFold2 GapTrick model interface pLDDT measures confidence in interface region predictions for monomeric models (0-100, higher is better).",
    "AlphaFold2 GapTrick Average pLDDT": "The AlphaFold2 GapTrick model average pLDDT provides the mean confidence across all residues in monomeric predictions (0-100, higher is better).",
    "AlphaFold2 GapTrick pTM Score": "The AlphaFold2 GapTrick model pTM score assesses overall fold accuracy in monomeric predictions (0-1, higher is better).",
    "AlphaFold2 GapTrick Interface pTM": "The AlphaFold2 GapTrick model interface pTM specifically evaluates accuracy of interface regions in monomeric predictions (0-1, higher is better).",
    # This entry was missing, so the Multimer interchain PAE score fell back
    # to the "no description" text; wording mirrors the GapTrick counterpart.
    "AlphaFold2 Multimer Interchain PAE": "The AlphaFold2 Multimer model interchain predicted aligned error (PAE) estimates position errors between chains in multimeric predictions (lower is better).",
    "AlphaFold2 Multimer Interface PAE": "The AlphaFold2 Multimer model interface PAE estimates position errors specifically at interfaces in multimeric predictions (lower is better).",
    "AlphaFold2 Multimer Overall PAE": "The AlphaFold2 Multimer model overall PAE estimates position errors across the entire structure in multimeric predictions (lower is better).",
    "AlphaFold2 Multimer Interface pLDDT": "The AlphaFold2 Multimer model interface pLDDT measures confidence in interface region predictions for multimeric models (0-100, higher is better).",
    "AlphaFold2 Multimer Average pLDDT": "The AlphaFold2 Multimer model average pLDDT provides the mean confidence across all residues in multimeric predictions (0-100, higher is better).",
    "AlphaFold2 Multimer pTM Score": "The AlphaFold2 Multimer model pTM score assesses overall fold accuracy in multimeric predictions (0-1, higher is better).",
    "AlphaFold2 Multimer Interface pTM": "The AlphaFold2 Multimer model interface pTM specifically evaluates accuracy of interface regions in multimeric predictions (0-1, higher is better).",
}


def get_score_description(score: str) -> str:
    """Return the human-readable description of a score.

    Args:
        score: Display name of the score, e.g. ``"Boltz pTM Score"``.

    Returns:
        The description for ``score``, or
        ``"No description available for this score."`` when the name is
        unknown.
    """
    return _SCORE_DESCRIPTIONS.get(score, _NO_SCORE_DESCRIPTION)


def compute_correlation_data(
    spr_data_with_scores: pd.DataFrame, score_cols: list[str]
) -> pd.DataFrame:
    """Correlate every score column with the binding affinity.

    Spearman and Pearson correlations are computed between each column in
    ``score_cols`` and both the raw ``"KD (nM)"`` column and its base-10
    logarithm (reported under the name ``"log_kd"``).

    The result is cached in ``corr_data.csv`` in the current working
    directory. If that file already exists it is loaded and returned as-is,
    WITHOUT looking at the arguments, so it must be deleted by hand whenever
    the input data or the score list changes.

    The input frame is not modified: the log-transformed KD values are kept
    in a local variable rather than written back as a new column.

    Args:
        spr_data_with_scores: Table holding a ``"KD (nM)"`` column and every
            column named in ``score_cols``.
        score_cols: Names of the score columns to correlate.

    Returns:
        A frame with the columns ``correlation_type``, ``kd_col``, ``score``,
        ``correlation`` and ``p-value``, without rows whose correlation is
        NaN, sorted by ascending correlation. Empty (but with those columns)
        when nothing could be computed.
    """
    corr_data_file = Path("corr_data.csv")
    if corr_data_file.exists():
        logger.info(f"Loading correlation data from {corr_data_file}")
        return pd.read_csv(corr_data_file)

    kd_values = spr_data_with_scores["KD (nM)"]
    # Keyed by the label that ends up in the "kd_col" output column.
    kd_series = {"KD (nM)": kd_values, "log_kd": np.log10(kd_values)}
    corr_funcs = {"Spearman": spearmanr, "Pearson": pearsonr}

    rows = []
    for kd_col, kd in kd_series.items():
        for correlation_type, corr_func in corr_funcs.items():
            for score_col in score_cols:
                logger.info(
                    f"Computing {correlation_type} correlation between {score_col} and {kd_col}"
                )
                res = corr_func(kd, spr_data_with_scores[score_col])
                logger.info(f"Correlation function: {corr_func}")
                rows.append(
                    {
                        "correlation_type": correlation_type,
                        "kd_col": kd_col,
                        "score": score_col,
                        "correlation": res.statistic,
                        "p-value": res.pvalue,
                    }
                )

    # Passing the column names explicitly keeps the filtering and sorting
    # below valid even when no row was produced (e.g. empty score_cols).
    corr_data = pd.DataFrame(
        rows,
        columns=["correlation_type", "kd_col", "score", "correlation", "p-value"],
    )
    # Drop the rows whose correlation could not be computed (NaN).
    corr_data = corr_data[corr_data["correlation"].notna()]
    corr_data = corr_data.sort_values("correlation", ascending=True)

    # An empty cache file would be served for every later call, hiding real
    # results, so only persist when there is something to store.
    if not corr_data.empty:
        corr_data.to_csv(corr_data_file, index=False)

    return corr_data


def plot_correlation_ranking(
    corr_data: pd.DataFrame, correlation_type: str, kd_col: str
) -> go.Figure:
    """Draw a horizontal bar chart of score correlations.

    Only the rows of ``corr_data`` whose ``correlation_type`` and ``kd_col``
    columns equal the given arguments are plotted; rows keep the order they
    have in ``corr_data``.

    Args:
        corr_data: Table with ``correlation_type``, ``kd_col``, ``score`` and
            ``correlation`` columns.
        correlation_type: Value to match in the ``correlation_type`` column;
            also used as the trace name and the x-axis title.
        kd_col: Value to match in the ``kd_col`` column.

    Returns:
        The bar chart figure: correlation on the x-axis, one bar per score.
    """
    # Keep only the rows for the requested correlation method and KD column.
    is_requested_type = corr_data["correlation_type"] == correlation_type
    is_requested_kd = corr_data["kd_col"] == kd_col
    selection = corr_data.loc[is_requested_type & is_requested_kd]

    bars = go.Bar(
        x=selection["correlation"],
        y=selection["score"],
        name=correlation_type,
        orientation="h",
        hovertemplate="<i>Score:</i> %{y}<br><i>Correlation:</i> %{x:.3f}<br>",
    )

    figure = go.Figure(data=[bars])
    figure.update_layout(
        title="Correlation with Binding Affinity",
        yaxis_title="Score",
        xaxis_title=correlation_type,
        template="simple_white",
        showlegend=False,
    )
    return figure


def fake_predict_and_correlate(
    spr_data_with_scores: pd.DataFrame, score_cols: list[str], main_cols: list[str]
) -> tuple[pd.DataFrame, go.Figure, go.Figure]:
    """Fake predict structures of all complexes and correlate the results.

    No prediction is run here: the scores are expected to be present in
    ``spr_data_with_scores`` already. The function only derives the
    correlation table and the two default plots from them.

    Args:
        spr_data_with_scores: Table holding ``"KD (nM)"``, the columns in
            ``main_cols`` and the columns in ``score_cols``.
        score_cols: Score columns to correlate; the first one is used for the
            initial regression plot, so the list must not be empty.
        main_cols: Non-score columns to keep in the returned table.

    Returns:
        A 3-tuple of
        * the ``main_cols`` + ``score_cols`` columns of the input, rounded to
          two decimals,
        * the Spearman correlation ranking bar chart against ``"KD (nM)"``,
        * the linear-axis regression plot for ``score_cols[0]``.
    """
    corr_data = compute_correlation_data(spr_data_with_scores, score_cols)
    corr_ranking_plot = plot_correlation_ranking(corr_data, "Spearman", kd_col="KD (nM)")

    # New list, so the caller's main_cols is left untouched.
    cols_to_show = [*main_cols, *score_cols]

    corr_plot = make_regression_plot(spr_data_with_scores, score_cols[0], use_log=False)

    return spr_data_with_scores[cols_to_show].round(2), corr_ranking_plot, corr_plot


def make_regression_plot(
    spr_data_with_scores: pd.DataFrame, score: str, use_log: bool
) -> go.Figure:
    """Scatter one score against the binding affinity with a fitted line.

    Args:
        spr_data_with_scores: Table holding ``"KD (nM)"`` and the ``score``
            column.
        score: Name of the score column plotted on the y-axis.
        use_log: If true, the x-axis is logarithmic and the line is fitted
            against ``log10(KD)``; otherwise both are linear.

    Returns:
        A figure with the sample markers and, when at least two samples have
        finite x and y values, a least-squares regression line.
    """
    kd_values = spr_data_with_scores["KD (nM)"]
    score_values = spr_data_with_scores[score]

    scatter = go.Scatter(
        x=kd_values,
        y=score_values,
        name="Samples",
        mode="markers",  # Only show markers/dots, no lines
        hovertemplate="<i>Score:</i> %{y}<br><i>KD:</i> %{x:.2f}<br>",
        marker=dict(color="#1f77b4"),  # Set color to match default first color
    )
    corr_plot = go.Figure(data=scatter)
    corr_plot.update_layout(
        xaxis_title="KD (nM)",
        yaxis_title=score,
        template="simple_white",
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        ),
        xaxis_type="log" if use_log else "linear",
    )

    # Fit in the same space the x-axis is drawn in.
    x_vals = np.asarray(np.log10(kd_values) if use_log else kd_values, dtype=float)
    y_vals = np.asarray(score_values, dtype=float)

    # Missing scores (NaN) or log10 of a non-positive KD (-inf / NaN) would
    # make the fit fail or come out as NaN, so fit on the finite pairs only.
    usable = np.isfinite(x_vals) & np.isfinite(y_vals)
    x_fit = x_vals[usable]
    y_fit = y_vals[usable]
    if x_fit.size < 2:
        # A line needs two points; return the scatter without a regression.
        logger.warning(
            f"Not enough finite samples to fit a regression line for {score}"
        )
        return corr_plot

    slope, intercept = np.polyfit(x_fit, y_fit, 1)

    # Sample the fitted line over the observed x range.
    corr_line_x = np.linspace(x_fit.min(), x_fit.max(), 100)
    corr_line_y = slope * corr_line_x + intercept

    # Convert back from log space so the line matches the KD axis values.
    if use_log:
        corr_line_x = 10**corr_line_x

    corr_plot.add_trace(
        go.Scatter(
            x=corr_line_x,
            y=corr_line_y,
            mode="lines",
            name="Regression line",
            line=dict(color="#1f77b4"),  # Set same color as scatter points
        )
    )
    return corr_plot