File size: 1,104 Bytes
ef83bf7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from sklearn.metrics import roc_curve, precision_recall_curve
import matplotlib.pyplot as plt
import numpy as np


def eval_binary_classification(pred: np.ndarray, true: np.ndarray) -> None:
    """Plot precision-recall and ROC curves side by side, then show them.

    Creates a new 12x6-inch figure and draws two panels:
    - left: the precision-recall curve (via ``eval_pr_curve``);
    - right: the ROC curve (via ``eval_roc_curve``).
    It then applies ``plt.tight_layout()`` and calls ``plt.show()``.

    Args:
        pred: Predicted scores, forwarded to sklearn as ``y_score``.
            Presumably positive-class probabilities or decision values;
            confirm against callers.
        true: Ground-truth binary labels, forwarded to sklearn as ``y_true``.

    Note:
        The argument order is ``(pred, true)``, which is the reverse of
        sklearn's ``(y_true, y_score)`` convention.
    """
    plt.figure(figsize=(12, 6))
    # Each helper selects its own fixed slot with plt.subplot, so the call
    # order does not change the layout: PR is on the left, ROC on the right.
    eval_roc_curve(pred, true)
    eval_pr_curve(pred, true)
    plt.tight_layout()
    plt.show()


def eval_pr_curve(pred: np.ndarray, true: np.ndarray) -> None:
    """Draw the precision-recall curve into the left panel of a 1x2 grid.

    This function does not create or show a figure. It draws on the current
    figure after selecting ``plt.subplot(1, 2, 1)``.

    Args:
        pred: Predicted scores, passed to ``precision_recall_curve`` as
            ``y_score``.
        true: Ground-truth binary labels, passed as ``y_true``.
    """
    precision, recall, _ = precision_recall_curve(true, pred)
    plt.subplot(1, 2, 1)
    plt.plot(recall, precision, label="Precision-Recall Curve", color="red")
    # Fix only the bottom of the y-axis at 0 so precision is shown on an
    # absolute scale; the top is left to autoscale.
    plt.ylim(0)
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("Precision-Recall Curve")
    # A PR curve typically ends toward the lower-right corner, where
    # precision falls as recall approaches 1. A legend there would cover
    # that tail, so place it in the lower-left corner instead.
    plt.legend(loc="lower left")


def eval_roc_curve(pred: np.ndarray, true: np.ndarray) -> None:
    """Draw the ROC curve and a chance diagonal in the right panel of a 1x2 grid.

    This function does not create or show a figure. It draws on the current
    figure after selecting ``plt.subplot(1, 2, 2)``.

    Args:
        pred: Predicted scores, passed to ``roc_curve`` as ``y_score``.
        true: Ground-truth binary labels, passed as ``y_true``.
    """
    false_pos_rate, true_pos_rate, _ = roc_curve(true, pred)
    plt.subplot(1, 2, 2)
    plt.plot(false_pos_rate, true_pos_rate, label="ROC Curve")
    # The dashed y = x diagonal is the reference for a no-skill classifier.
    plt.plot([0, 1], [0, 1], linestyle="--", label="Random Guessing Model")
    plt.title("ROC Curve vs. Random")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.legend(loc="lower right")