CircleStar commited on
Commit
63e305e
·
verified ·
1 Parent(s): 2e34d29

Create metrics_utils.py

Browse files
Files changed (1) hide show
  1. metrics_utils.py +65 -0
metrics_utils.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Dict, List
3
+
4
+ import matplotlib.pyplot as plt
5
+ import pandas as pd
6
+ from sklearn.metrics import (
7
+ accuracy_score,
8
+ f1_score,
9
+ classification_report,
10
+ confusion_matrix,
11
+ )
12
+
13
+ from config import FIGURE_DIR
14
+
15
+
16
+ def compute_classification_metrics(y_true, y_pred, class_names: List[str]) -> Dict:
17
+ acc = accuracy_score(y_true, y_pred)
18
+ f1_macro = f1_score(y_true, y_pred, average="macro", zero_division=0)
19
+ f1_weighted = f1_score(y_true, y_pred, average="weighted", zero_division=0)
20
+
21
+ report_dict = classification_report(
22
+ y_true,
23
+ y_pred,
24
+ target_names=class_names,
25
+ zero_division=0,
26
+ output_dict=True,
27
+ )
28
+
29
+ report_df = pd.DataFrame(report_dict).transpose().reset_index()
30
+ report_df = report_df.rename(columns={"index": "classe"})
31
+
32
+ cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))
33
+ cm_df = pd.DataFrame(cm, index=class_names, columns=class_names)
34
+
35
+ return {
36
+ "accuracy": round(float(acc), 4),
37
+ "f1_macro": round(float(f1_macro), 4),
38
+ "f1_weighted": round(float(f1_weighted), 4),
39
+ "classification_report": report_df,
40
+ "confusion_matrix": cm_df,
41
+ }
42
+
43
+
44
+ def save_confusion_matrix_figure(cm_df: pd.DataFrame, model_name: str) -> str:
45
+ fig_path = os.path.join(FIGURE_DIR, f"{model_name}_confusion_matrix.png")
46
+
47
+ fig_width = max(8, min(20, 0.45 * len(cm_df.columns)))
48
+ fig_height = max(6, min(20, 0.45 * len(cm_df.index)))
49
+
50
+ plt.figure(figsize=(fig_width, fig_height))
51
+ plt.imshow(cm_df.values, interpolation="nearest")
52
+ plt.title("Matrice de confusion")
53
+ plt.colorbar()
54
+
55
+ tick_marks = range(len(cm_df.columns))
56
+ plt.xticks(tick_marks, cm_df.columns, rotation=90, fontsize=7)
57
+ plt.yticks(tick_marks, cm_df.index, fontsize=7)
58
+
59
+ plt.xlabel("Classe prédite")
60
+ plt.ylabel("Classe réelle")
61
+ plt.tight_layout()
62
+ plt.savefig(fig_path, dpi=200)
63
+ plt.close()
64
+
65
+ return fig_path