CircleStar commited on
Commit
22ca06d
·
verified ·
1 Parent(s): b6076a1

Update metrics_utils.py

Browse files
Files changed (1) hide show
  1. metrics_utils.py +23 -43
metrics_utils.py CHANGED
@@ -1,26 +1,26 @@
1
- import os
2
- from typing import Dict, List
3
-
4
- import matplotlib.pyplot as plt
5
- import pandas as pd
6
- from sklearn.metrics import (
7
- accuracy_score,
8
- f1_score,
9
- classification_report,
10
- confusion_matrix,
11
- )
12
-
13
- from config import FIGURE_DIR
14
-
15
-
16
  def compute_classification_metrics(y_true, y_pred, class_names: List[str]) -> Dict:
 
 
17
  acc = accuracy_score(y_true, y_pred)
18
- f1_macro = f1_score(y_true, y_pred, average="macro", zero_division=0)
19
- f1_weighted = f1_score(y_true, y_pred, average="weighted", zero_division=0)
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  report_dict = classification_report(
22
  y_true,
23
  y_pred,
 
24
  target_names=class_names,
25
  zero_division=0,
26
  output_dict=True,
@@ -29,7 +29,11 @@ def compute_classification_metrics(y_true, y_pred, class_names: List[str]) -> Di
29
  report_df = pd.DataFrame(report_dict).transpose().reset_index()
30
  report_df = report_df.rename(columns={"index": "classe"})
31
 
32
- cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))
 
 
 
 
33
  cm_df = pd.DataFrame(cm, index=class_names, columns=class_names)
34
 
35
  return {
@@ -38,28 +42,4 @@ def compute_classification_metrics(y_true, y_pred, class_names: List[str]) -> Di
38
  "f1_weighted": round(float(f1_weighted), 4),
39
  "classification_report": report_df,
40
  "confusion_matrix": cm_df,
41
- }
42
-
43
-
44
- def save_confusion_matrix_figure(cm_df: pd.DataFrame, model_name: str) -> str:
45
- fig_path = os.path.join(FIGURE_DIR, f"{model_name}_confusion_matrix.png")
46
-
47
- fig_width = max(8, min(20, 0.45 * len(cm_df.columns)))
48
- fig_height = max(6, min(20, 0.45 * len(cm_df.index)))
49
-
50
- plt.figure(figsize=(fig_width, fig_height))
51
- plt.imshow(cm_df.values, interpolation="nearest")
52
- plt.title("Matrice de confusion")
53
- plt.colorbar()
54
-
55
- tick_marks = range(len(cm_df.columns))
56
- plt.xticks(tick_marks, cm_df.columns, rotation=90, fontsize=7)
57
- plt.yticks(tick_marks, cm_df.index, fontsize=7)
58
-
59
- plt.xlabel("Classe prédite")
60
- plt.ylabel("Classe réelle")
61
- plt.tight_layout()
62
- plt.savefig(fig_path, dpi=200)
63
- plt.close()
64
-
65
- return fig_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  def compute_classification_metrics(y_true, y_pred, class_names: List[str]) -> Dict:
2
+ labels = list(range(len(class_names)))
3
+
4
  acc = accuracy_score(y_true, y_pred)
5
+ f1_macro = f1_score(
6
+ y_true,
7
+ y_pred,
8
+ labels=labels,
9
+ average="macro",
10
+ zero_division=0,
11
+ )
12
+ f1_weighted = f1_score(
13
+ y_true,
14
+ y_pred,
15
+ labels=labels,
16
+ average="weighted",
17
+ zero_division=0,
18
+ )
19
 
20
  report_dict = classification_report(
21
  y_true,
22
  y_pred,
23
+ labels=labels,
24
  target_names=class_names,
25
  zero_division=0,
26
  output_dict=True,
 
29
  report_df = pd.DataFrame(report_dict).transpose().reset_index()
30
  report_df = report_df.rename(columns={"index": "classe"})
31
 
32
+ cm = confusion_matrix(
33
+ y_true,
34
+ y_pred,
35
+ labels=labels,
36
+ )
37
  cm_df = pd.DataFrame(cm, index=class_names, columns=class_names)
38
 
39
  return {
 
42
  "f1_weighted": round(float(f1_weighted), 4),
43
  "classification_report": report_df,
44
  "confusion_matrix": cm_df,
45
+ }