Paul Engstler
Initial commit
92f0e98
raw
history blame contribute delete
915 Bytes
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import numpy as np
import numpy as np
def get_confusion_matrix_figure(y_true, y_pred, labels=None, title="Confusion matrix", dpi=600):
    """
    Return a matplotlib figure plotting the confusion matrix of predictions.

    Parameters
    ----------
    y_true, y_pred : array-like
        Ground-truth and predicted class labels. When ``labels`` is given,
        these are expected to be integer class codes ``0 .. len(labels) - 1``.
    labels : sequence, optional
        Display names for the classes, in class-code order. If given, the
        matrix rows/columns are fixed to codes ``0 .. len(labels) - 1`` so
        classes absent from the data still get a row/column. If ``None``,
        scikit-learn uses the classes that occur in ``y_true``/``y_pred``.
    title : str
        Figure-level title (set with ``suptitle``).
    dpi : int
        Resolution of the created figure. Defaults to 600.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the plot. It is created through pyplot, so it
        stays registered there until the caller closes it (``plt.close(fig)``).
    """
    # Compute the matrix before creating any figure, so a failure here
    # (e.g. bad labels) does not leave an orphaned figure open in pyplot.
    label_codes = np.arange(len(labels)) if labels is not None else None
    cm = confusion_matrix(y_true, y_pred, labels=label_codes)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)

    # Create the figure/axes explicitly and hand the axes to the display.
    # With ax=None, ConfusionMatrixDisplay.plot creates its own new figure,
    # which would ignore the requested dpi and leave any figure made
    # beforehand empty and unclosed.
    fig, ax = plt.subplots(dpi=dpi)
    disp.plot(
        include_values=True,
        cmap=plt.cm.Blues,
        ax=ax,
        xticks_rotation="horizontal",
        values_format=None,  # None -> 'd' or '.2g', whichever is shorter
    )
    fig.suptitle(title)
    return fig