|
from sklearn.metrics import roc_curve, precision_recall_curve |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
|
|
|
|
def eval_binary_classification(pred: np.ndarray, true: np.ndarray) -> None:
    """Render a side-by-side ROC curve and precision-recall curve figure.

    Draws both diagnostic plots on a single 12x6 figure and displays it
    with ``plt.show()`` (blocking in interactive backends).

    Args:
        pred: Predicted scores/probabilities for the positive class.
        true: Ground-truth binary labels, same length as ``pred``.
    """
    # One shared figure; the helpers each target their own subplot slot
    # (PR curve on the left, ROC curve on the right).
    plt.figure(figsize=(12, 6))
    eval_roc_curve(pred, true)
    eval_pr_curve(pred, true)
    plt.tight_layout()
    plt.show()
|
|
|
|
|
def eval_pr_curve(pred: np.ndarray, true: np.ndarray) -> None:
    """Plot a precision-recall curve into the left subplot (1, 2, 1).

    Expects an active matplotlib figure (e.g. created by the caller);
    draws onto the current figure rather than opening a new one.

    Args:
        pred: Predicted scores/probabilities for the positive class.
        true: Ground-truth binary labels, same length as ``pred``.
    """
    precision, recall, _ = precision_recall_curve(true, pred)
    plt.subplot(1, 2, 1)
    plt.plot(recall, precision, label="Precision-Recall Curve", color="red")
    # Anchor the y-axis at 0 so precision differences are not exaggerated.
    plt.ylim(0)
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("Precision-Recall Curve")
    plt.legend(loc="lower right")
|
|
|
|
|
def eval_roc_curve(pred: np.ndarray, true: np.ndarray) -> None:
    """Plot a ROC curve into the right subplot (1, 2, 2).

    Also draws the y = x diagonal as a random-guessing baseline for
    visual reference. Expects an active matplotlib figure; draws onto
    the current figure rather than opening a new one.

    Args:
        pred: Predicted scores/probabilities for the positive class.
        true: Ground-truth binary labels, same length as ``pred``.
    """
    false_pos_rate, true_pos_rate, _ = roc_curve(true, pred)
    plt.subplot(1, 2, 2)
    plt.plot(false_pos_rate, true_pos_rate, label="ROC Curve")
    # Chance-level reference: a random classifier traces the diagonal.
    plt.plot([0, 1], [0, 1], linestyle="--", label="Random Guessing Model")
    plt.title("ROC Curve vs. Random")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.legend(loc="lower right")
|
|