# ReCEP/src/bce/utils/metrics.py
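"""Metric utilities for ReCEP: graph-level metrics for recall prediction and
node-level metrics for epitope prediction, computed with scikit-learn."""
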
import numpy as np
from sklearn.metrics import (
mean_squared_error, mean_absolute_error, r2_score,
average_precision_score, roc_auc_score, f1_score,
precision_score, recall_score, matthews_corrcoef,
accuracy_score, confusion_matrix, roc_curve, precision_recall_curve
)


def calculate_graph_metrics(preds, labels, threshold=0.5):
"""
Calculate graph-level metrics for recall prediction.
Args:
preds: Predicted recall values (numpy array)
labels: True recall values (numpy array)
        threshold: Threshold for binarizing predictions and labels (default: 0.5)
Returns:
Dictionary of metrics
"""
    # Sanitize inputs: NaN -> 0.0, +inf -> 1.0, -inf -> 0.0
preds = np.nan_to_num(preds, nan=0.0, posinf=1.0, neginf=0.0)
labels = np.nan_to_num(labels, nan=0.0, posinf=1.0, neginf=0.0)
# Convert predictions to binary for classification metrics
pred_binary = (preds > threshold).astype(int)
label_binary = (labels > threshold).astype(int)
metrics = {}
# Classification metrics
if len(np.unique(label_binary)) > 1: # Check if both classes exist
metrics['recall'] = recall_score(label_binary, pred_binary, zero_division=0)
metrics['precision'] = precision_score(label_binary, pred_binary, zero_division=0)
metrics['mcc'] = matthews_corrcoef(label_binary, pred_binary)
metrics['f1'] = f1_score(label_binary, pred_binary, zero_division=0)
metrics['accuracy'] = accuracy_score(label_binary, pred_binary)
else:
metrics['recall'] = 0.0
metrics['precision'] = 0.0
metrics['mcc'] = 0.0
metrics['f1'] = 0.0
metrics['accuracy'] = 0.0
# Regression metrics
metrics['mse'] = mean_squared_error(labels, preds)
metrics['mae'] = mean_absolute_error(labels, preds)
metrics['r2'] = r2_score(labels, preds)
return metrics
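

# --- Illustrative usage (added sketch, not part of the original module) ------
# The helper below exercises calculate_graph_metrics on made-up recall values
# to show the expected input shapes and the keys of the returned dictionary.
# The numbers and the helper name are invented for demonstration only.
def _demo_graph_metrics():
    preds = np.array([0.92, 0.15, 0.66, 0.40])   # hypothetical predicted recalls
    labels = np.array([1.00, 0.05, 0.80, 0.30])  # hypothetical true recalls
    metrics = calculate_graph_metrics(preds, labels, threshold=0.5)
    # Returned keys: recall, precision, mcc, f1, accuracy, mse, mae, r2
    return metrics

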
def calculate_node_metrics(preds, labels, find_threshold=False, include_curves=False):
"""
Calculate node-level metrics for epitope prediction.
Args:
preds: Predicted probabilities (numpy array)
labels: True binary labels (numpy array)
        find_threshold: If True, use find_optimal_threshold to pick the
            threshold that maximizes MCC; otherwise a fixed 0.5 is used
include_curves: If True, include PR and ROC curves for visualization
Returns:
Dictionary of metrics including optimal threshold if find_threshold=True
"""
    # Sanitize inputs: NaN -> 0.0, +inf -> 1.0, -inf -> 0.0
preds = np.nan_to_num(preds, nan=0.0, posinf=1.0, neginf=0.0)
labels = np.nan_to_num(labels, nan=0.0, posinf=1.0, neginf=0.0)
metrics = {}
# Check if both classes exist
if len(np.unique(labels)) > 1:
# AUROC and AUPRC (threshold-independent metrics)
try:
metrics['auroc'] = roc_auc_score(labels, preds)
metrics['auprc'] = average_precision_score(labels, preds)
# Include curves for visualization if requested
if include_curves:
# Calculate PR curve
precision_curve, recall_curve, _ = precision_recall_curve(labels, preds)
metrics['pr_curve'] = {
'precision': precision_curve,
'recall': recall_curve
}
# Calculate ROC curve
fpr, tpr, _ = roc_curve(labels, preds)
metrics['roc_curve'] = {
'fpr': fpr,
'tpr': tpr
}
else:
metrics['pr_curve'] = None
metrics['roc_curve'] = None
        except Exception:  # fall back to zeros if curve metrics cannot be computed
metrics['auroc'] = 0.0
metrics['auprc'] = 0.0
metrics['pr_curve'] = None
metrics['roc_curve'] = None
# Find optimal threshold if requested
if find_threshold:
            best_threshold, _ = find_optimal_threshold(preds, labels)
metrics['best_threshold'] = best_threshold
threshold = best_threshold
else:
threshold = 0.5
metrics['best_threshold'] = 0.5
# Binary classification metrics using the determined threshold
pred_binary = (preds > threshold).astype(int)
metrics['f1'] = f1_score(labels, pred_binary, zero_division=0)
metrics['mcc'] = matthews_corrcoef(labels, pred_binary)
metrics['precision'] = precision_score(labels, pred_binary, zero_division=0)
metrics['recall'] = recall_score(labels, pred_binary, zero_division=0)
metrics['accuracy'] = accuracy_score(labels, pred_binary)
# Confusion matrix components
try:
tn, fp, fn, tp = confusion_matrix(labels, pred_binary).ravel()
metrics['true_positives'] = int(tp)
metrics['false_positives'] = int(fp)
metrics['true_negatives'] = int(tn)
metrics['false_negatives'] = int(fn)
        except Exception:  # fall back if the confusion matrix is not 2x2
metrics['true_positives'] = 0
metrics['false_positives'] = 0
metrics['true_negatives'] = 0
metrics['false_negatives'] = 0
# Store the threshold used for these metrics
metrics['threshold_used'] = threshold
else:
# All metrics are 0 if only one class exists
metrics['auroc'] = 0.0
metrics['auprc'] = 0.0
metrics['f1'] = 0.0
metrics['mcc'] = 0.0
metrics['precision'] = 0.0
metrics['recall'] = 0.0
metrics['accuracy'] = 0.0
metrics['best_threshold'] = 0.5
metrics['threshold_used'] = 0.5
metrics['true_positives'] = 0
metrics['false_positives'] = 0
metrics['true_negatives'] = 0
metrics['false_negatives'] = 0
metrics['pr_curve'] = None
metrics['roc_curve'] = None
return metrics
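

# --- Illustrative usage (added sketch, not part of the original module) ------
# A toy call to calculate_node_metrics showing find_threshold and
# include_curves; the per-residue probabilities, labels, and helper name are
# invented for demonstration only.
def _demo_node_metrics():
    preds = np.array([0.9, 0.8, 0.3, 0.2, 0.7, 0.1])  # hypothetical probabilities
    labels = np.array([1, 1, 0, 0, 1, 0])             # hypothetical epitope labels
    metrics = calculate_node_metrics(
        preds, labels, find_threshold=True, include_curves=True
    )
    # metrics['best_threshold'] comes from find_optimal_threshold (MCC-based);
    # metrics['pr_curve'] and metrics['roc_curve'] hold the raw curve arrays.
    return metrics

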
def find_optimal_threshold(preds, labels, num_thresholds=100):
"""
    Find the threshold that maximizes the Matthews correlation coefficient (MCC).
Args:
preds: Predicted probabilities (numpy array)
labels: True binary labels (numpy array)
num_thresholds: Number of thresholds to test
Returns:
        Tuple of (best_threshold, best_mcc)
"""
# Generate threshold candidates
thresholds = np.linspace(0.01, 0.99, num_thresholds)
best_mcc = 0.0
best_threshold = 0.5
for threshold in thresholds:
pred_binary = (preds > threshold).astype(int)
mcc = matthews_corrcoef(labels, pred_binary)
if mcc > best_mcc:
best_mcc = mcc
best_threshold = threshold
return best_threshold, best_mcc
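

# Minimal sanity check when run directly (added sketch; the _demo_* helpers
# above are illustrative additions rather than part of the original API).
if __name__ == "__main__":
    toy_preds = np.array([0.95, 0.70, 0.60, 0.40, 0.20, 0.05])
    toy_labels = np.array([1, 1, 0, 1, 0, 0])
    thr, mcc = find_optimal_threshold(toy_preds, toy_labels, num_thresholds=50)
    print(f"best threshold ~= {thr:.2f}, best MCC ~= {mcc:.3f}")
    print(_demo_graph_metrics())
    node_metrics = _demo_node_metrics()
    print({k: v for k, v in node_metrics.items()
           if k not in ('pr_curve', 'roc_curve')})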