CircleStar committed on
Commit
958eb86
·
verified ·
1 Parent(s): e74b30d

Update metrics_utils.py

Browse files
Files changed (1) hide show
  1. metrics_utils.py +30 -9
metrics_utils.py CHANGED
@@ -1,9 +1,8 @@
1
- from typing import List, Dict
2
-
3
  import os
 
 
4
  import matplotlib.pyplot as plt
5
  import pandas as pd
6
-
7
  from sklearn.metrics import (
8
  accuracy_score,
9
  f1_score,
@@ -12,6 +11,8 @@ from sklearn.metrics import (
12
  )
13
 
14
  from config import FIGURE_DIR
 
 
15
  def compute_classification_metrics(y_true, y_pred, class_names: List[str]) -> Dict:
16
  labels = list(range(len(class_names)))
17
 
@@ -43,11 +44,7 @@ def compute_classification_metrics(y_true, y_pred, class_names: List[str]) -> Di
43
  report_df = pd.DataFrame(report_dict).transpose().reset_index()
44
  report_df = report_df.rename(columns={"index": "classe"})
45
 
46
- cm = confusion_matrix(
47
- y_true,
48
- y_pred,
49
- labels=labels,
50
- )
51
  cm_df = pd.DataFrame(cm, index=class_names, columns=class_names)
52
 
53
  return {
@@ -56,4 +53,28 @@ def compute_classification_metrics(y_true, y_pred, class_names: List[str]) -> Di
56
  "f1_weighted": round(float(f1_weighted), 4),
57
  "classification_report": report_df,
58
  "confusion_matrix": cm_df,
59
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ from typing import Dict, List
3
+
4
  import matplotlib.pyplot as plt
5
  import pandas as pd
 
6
  from sklearn.metrics import (
7
  accuracy_score,
8
  f1_score,
 
11
  )
12
 
13
  from config import FIGURE_DIR
14
+
15
+
16
  def compute_classification_metrics(y_true, y_pred, class_names: List[str]) -> Dict:
17
  labels = list(range(len(class_names)))
18
 
 
44
  report_df = pd.DataFrame(report_dict).transpose().reset_index()
45
  report_df = report_df.rename(columns={"index": "classe"})
46
 
47
+ cm = confusion_matrix(y_true, y_pred, labels=labels)
 
 
 
 
48
  cm_df = pd.DataFrame(cm, index=class_names, columns=class_names)
49
 
50
  return {
 
53
  "f1_weighted": round(float(f1_weighted), 4),
54
  "classification_report": report_df,
55
  "confusion_matrix": cm_df,
56
+ }
57
+
58
+
59
def save_confusion_matrix_figure(cm_df: pd.DataFrame, model_name: str) -> str:
    """Render a confusion matrix as a heatmap and save it to FIGURE_DIR as PNG.

    Args:
        cm_df: Square confusion-matrix DataFrame; index holds the true class
            labels and columns hold the predicted class labels.
        model_name: Identifier used to build the output file name
            (``{model_name}_confusion_matrix.png``).

    Returns:
        The full path of the saved figure file.
    """
    # Create the output directory if needed; otherwise plt.savefig raises
    # FileNotFoundError on the first run when FIGURE_DIR does not exist yet.
    os.makedirs(FIGURE_DIR, exist_ok=True)
    fig_path = os.path.join(FIGURE_DIR, f"{model_name}_confusion_matrix.png")

    # Scale the figure with the number of classes, clamped to a sane range
    # so very small or very large label sets still render legibly.
    fig_width = max(8, min(24, 0.45 * len(cm_df.columns)))
    fig_height = max(6, min(24, 0.45 * len(cm_df.index)))

    plt.figure(figsize=(fig_width, fig_height))
    plt.imshow(cm_df.values, interpolation="nearest")
    plt.title("Matrice de confusion")
    plt.colorbar()

    # One tick per class, labelled with the class names; x labels rotated
    # so long names do not overlap.
    tick_marks = range(len(cm_df.columns))
    plt.xticks(tick_marks, cm_df.columns, rotation=90, fontsize=7)
    plt.yticks(tick_marks, cm_df.index, fontsize=7)

    plt.xlabel("Classe prédite")
    plt.ylabel("Classe réelle")
    plt.tight_layout()
    plt.savefig(fig_path, dpi=200)
    # Close the figure to free memory when saving many matrices in one run.
    plt.close()

    return fig_path