Nigl / src /visual_eval.py
Jensen-holm's picture
Retrained the models on only tournament games and the ChalkSeedDiff
ef83bf7
raw
history blame
No virus
1.1 kB
from sklearn.metrics import roc_curve, precision_recall_curve
import matplotlib.pyplot as plt
import numpy as np
def eval_binary_classification(pred: np.ndarray, true: np.ndarray) -> None:
    """Show the ROC and precision-recall curves side by side in one figure.

    Args:
        pred: Model outputs, forwarded unchanged as the score argument of
            scikit-learn's ``roc_curve`` and ``precision_recall_curve``
            (presumably positive-class scores or probabilities; confirm
            against the caller).
        true: Ground-truth labels, forwarded unchanged as the ``y_true``
            argument of the same functions.

    Opens a new 12x6 inch figure, lets the two helpers draw into its
    subplots, then blocks on ``plt.show()``.
    """
    plt.figure(figsize=(12, 6))
    # Each helper selects its own slot of the 1x2 grid (PR left, ROC right),
    # so the call order here does not affect the layout.
    eval_roc_curve(pred, true)
    eval_pr_curve(pred, true)
    plt.tight_layout()
    plt.show()
def eval_pr_curve(pred: np.ndarray, true: np.ndarray) -> None:
    """Draw the precision-recall curve into the left cell of a 1x2 grid.

    Args:
        pred: Model outputs, passed as the score argument of
            ``precision_recall_curve``.
        true: Ground-truth labels, passed as the ``y_true`` argument.

    Draws on the current matplotlib figure; does not create or show one.
    """
    # scikit-learn expects (y_true, scores), the reverse of this function's
    # own parameter order.
    precision, recall, _ = precision_recall_curve(true, pred)

    plt.subplot(1, 2, 1)
    plt.plot(recall, precision, label="Precision-Recall Curve", color="red")
    # Anchor the y-axis at zero; the upper limit stays automatic.
    plt.ylim(bottom=0)
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("Precision-Recall Curve")
    plt.legend(loc="lower right")
def eval_roc_curve(pred: np.ndarray, true: np.ndarray) -> None:
    """Draw the ROC curve into the right cell of a 1x2 grid.

    Args:
        pred: Model outputs, passed as the score argument of ``roc_curve``.
        true: Ground-truth labels, passed as the ``y_true`` argument.

    Draws on the current matplotlib figure; does not create or show one.
    """
    # scikit-learn expects (y_true, scores), the reverse of this function's
    # own parameter order.
    false_pos_rate, true_pos_rate, _ = roc_curve(true, pred)

    plt.subplot(1, 2, 2)
    plt.plot(false_pos_rate, true_pos_rate, label="ROC Curve")
    # Diagonal reference line, labelled as the random-guessing baseline.
    plt.plot([0, 1], [0, 1], linestyle="--", label="Random Guessing Model")
    plt.title("ROC Curve vs. Random")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.legend(loc="lower right")