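"""Evaluation metrics for graph-level recall prediction and node-level epitope prediction.

Wraps scikit-learn metrics: regression and thresholded classification scores at the
graph level, and AUROC/AUPRC plus threshold-based scores (with an optional MCC-driven
threshold search) at the node level.
"""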
import numpy as np

from sklearn.metrics import (
    mean_squared_error, mean_absolute_error, r2_score, 
    average_precision_score, roc_auc_score, f1_score, 
    precision_score, recall_score, matthews_corrcoef,
    accuracy_score, confusion_matrix, roc_curve, precision_recall_curve
)

def calculate_graph_metrics(preds, labels, threshold=0.5):
    """
    Calculate graph-level metrics for recall prediction.
    
    Args:
        preds: Predicted recall values (numpy array)
        labels: True recall values (numpy array)  
        threshold: Threshold used to binarize recall values for the classification metrics (default: 0.5)
        
    Returns:
        Dictionary of metrics
    """
    # Check for NaN values and replace with zeros
    preds = np.nan_to_num(preds, nan=0.0, posinf=1.0, neginf=0.0)
    labels = np.nan_to_num(labels, nan=0.0, posinf=1.0, neginf=0.0)
    
    # Convert predictions to binary for classification metrics
    pred_binary = (preds > threshold).astype(int)
    label_binary = (labels > threshold).astype(int)
    
    metrics = {}
    
    # Classification metrics
    if len(np.unique(label_binary)) > 1:  # Check if both classes exist
        metrics['recall'] = recall_score(label_binary, pred_binary, zero_division=0)
        metrics['precision'] = precision_score(label_binary, pred_binary, zero_division=0)
        metrics['mcc'] = matthews_corrcoef(label_binary, pred_binary)
        metrics['f1'] = f1_score(label_binary, pred_binary, zero_division=0)
        metrics['accuracy'] = accuracy_score(label_binary, pred_binary)
    else:
        metrics['recall'] = 0.0
        metrics['precision'] = 0.0
        metrics['mcc'] = 0.0
        metrics['f1'] = 0.0
        metrics['accuracy'] = 0.0
    
    # Regression metrics
    metrics['mse'] = mean_squared_error(labels, preds)
    metrics['mae'] = mean_absolute_error(labels, preds)
    metrics['r2'] = r2_score(labels, preds)
    
    return metrics

def calculate_node_metrics(preds, labels, find_threshold=False, include_curves=False):
    """
    Calculate node-level metrics for epitope prediction.
    
    Args:
        preds: Predicted probabilities (numpy array)
        labels: True binary labels (numpy array)
        find_threshold: If True, find the threshold that maximizes the Matthews correlation coefficient (MCC)
        include_curves: If True, include PR and ROC curves for visualization
        
    Returns:
        Dictionary of metrics including optimal threshold if find_threshold=True
    """
    # Check for NaN values and replace with zeros
    preds = np.nan_to_num(preds, nan=0.0, posinf=1.0, neginf=0.0)
    labels = np.nan_to_num(labels, nan=0.0, posinf=1.0, neginf=0.0)
    
    metrics = {}
    
    # Check if both classes exist
    if len(np.unique(labels)) > 1:
        # AUROC and AUPRC (threshold-independent metrics)
        try:
            metrics['auroc'] = roc_auc_score(labels, preds)
            metrics['auprc'] = average_precision_score(labels, preds)
            
            # Include curves for visualization if requested
            if include_curves:
                # Calculate PR curve
                precision_curve, recall_curve, _ = precision_recall_curve(labels, preds)
                metrics['pr_curve'] = {
                    'precision': precision_curve,
                    'recall': recall_curve
                }
                
                # Calculate ROC curve
                fpr, tpr, _ = roc_curve(labels, preds)
                metrics['roc_curve'] = {
                    'fpr': fpr,
                    'tpr': tpr
                }
            else:
                metrics['pr_curve'] = None
                metrics['roc_curve'] = None
                
        except ValueError:  # metric undefined for degenerate inputs
            metrics['auroc'] = 0.0
            metrics['auprc'] = 0.0
            metrics['pr_curve'] = None
            metrics['roc_curve'] = None
        
        # Find optimal threshold if requested
        if find_threshold:
            best_threshold, _ = find_optimal_threshold(preds, labels)
            metrics['best_threshold'] = best_threshold
            threshold = best_threshold
        else:
            threshold = 0.5
            metrics['best_threshold'] = 0.5
        
        # Binary classification metrics using the determined threshold
        pred_binary = (preds > threshold).astype(int)
        metrics['f1'] = f1_score(labels, pred_binary, zero_division=0)
        metrics['mcc'] = matthews_corrcoef(labels, pred_binary)
        metrics['precision'] = precision_score(labels, pred_binary, zero_division=0)
        metrics['recall'] = recall_score(labels, pred_binary, zero_division=0)
        metrics['accuracy'] = accuracy_score(labels, pred_binary)
        
        # Confusion matrix components
        try:
            tn, fp, fn, tp = confusion_matrix(labels, pred_binary).ravel()
            metrics['true_positives'] = int(tp)
            metrics['false_positives'] = int(fp)
            metrics['true_negatives'] = int(tn)
            metrics['false_negatives'] = int(fn)
        except ValueError:  # confusion matrix does not unpack into four values
            metrics['true_positives'] = 0
            metrics['false_positives'] = 0
            metrics['true_negatives'] = 0
            metrics['false_negatives'] = 0
        
        # Store the threshold used for these metrics
        metrics['threshold_used'] = threshold
        
    else:
        # All metrics are 0 if only one class exists
        metrics['auroc'] = 0.0
        metrics['auprc'] = 0.0
        metrics['f1'] = 0.0
        metrics['mcc'] = 0.0
        metrics['precision'] = 0.0
        metrics['recall'] = 0.0
        metrics['accuracy'] = 0.0
        metrics['best_threshold'] = 0.5
        metrics['threshold_used'] = 0.5
        metrics['true_positives'] = 0
        metrics['false_positives'] = 0
        metrics['true_negatives'] = 0
        metrics['false_negatives'] = 0
        metrics['pr_curve'] = None
        metrics['roc_curve'] = None
    
    return metrics

def find_optimal_threshold(preds, labels, num_thresholds=100):
    """
    Find the threshold that maximizes the Matthews correlation coefficient (MCC).
    
    Args:
        preds: Predicted probabilities (numpy array)
        labels: True binary labels (numpy array)
        num_thresholds: Number of thresholds to test
        
    Returns:
        Tuple of (best_threshold, best_mcc)
    """
    # Generate threshold candidates
    thresholds = np.linspace(0.01, 0.99, num_thresholds)
    
    best_mcc = 0.0
    best_threshold = 0.5
    
    for threshold in thresholds:
        pred_binary = (preds > threshold).astype(int)
        mcc = matthews_corrcoef(labels, pred_binary)
        
        if mcc > best_mcc:
            best_mcc = mcc
            best_threshold = threshold
    
    return best_threshold, best_mcc
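

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): exercises the helpers
    # above on synthetic data; array sizes and noise levels are illustrative only.
    rng = np.random.default_rng(0)

    # Graph-level recall prediction: continuous targets in [0, 1].
    graph_labels = rng.uniform(0.0, 1.0, size=64)
    graph_preds = np.clip(graph_labels + rng.normal(0.0, 0.1, size=64), 0.0, 1.0)
    print("graph metrics:", calculate_graph_metrics(graph_preds, graph_labels))

    # Node-level epitope prediction: binary labels scored with probabilities.
    node_labels = rng.integers(0, 2, size=500)
    node_preds = np.clip(node_labels * 0.6 + rng.uniform(0.0, 0.4, size=500), 0.0, 1.0)
    node_metrics = calculate_node_metrics(node_preds, node_labels, find_threshold=True)
    print("node AUROC:", node_metrics["auroc"], "best threshold:", node_metrics["best_threshold"])

    # The threshold search can also be used on its own.
    thr, mcc = find_optimal_threshold(node_preds, node_labels)
    print(f"optimal threshold {thr:.2f} with MCC {mcc:.3f}")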