Yusuf committed on
Commit ed657fc · 1 Parent(s): 84cfdfc

per class accuracy

dataPrep/helpers/transforms_loaders.py CHANGED
@@ -103,13 +103,15 @@ def make_dataset_loaders(dataset, seed, batch_size, test_size, aug_config, worke
         pin_memory=True,
         num_workers=workers
     )
+    class_names = dataset.features['label'].names

     print(f"\nWorkers used in DataLoaders: {workers}\n")

     dataset_loaders = {
         "train": train_loader,
         "val": val_loader,
-        "test": test_loader
+        "test": test_loader,
+        "classNames": class_names
     }

     return dataset_loaders
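
The `classNames` entry added above comes straight from the Hugging Face `datasets` ClassLabel feature, which stores the id-to-string mapping alongside the data. A minimal sketch of that lookup on a toy dataset (the column and class names here are illustrative, not from this repo):

```python
from datasets import ClassLabel, Dataset, Features, Value

# Toy dataset whose "label" column is a ClassLabel, mirroring what the loader expects
features = Features({"image": Value("string"), "label": ClassLabel(names=["cat", "dog"])})
ds = Dataset.from_dict({"image": ["a.jpg", "b.jpg"], "label": [0, 1]}, features=features)

print(ds.features["label"].names)       # ['cat', 'dog']
print(ds.features["label"].int2str(1))  # 'dog'
```

One design note: `dataset_loaders` now mixes DataLoaders with a plain list, so any caller that iterates its values as loaders needs to skip the `"classNames"` key.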
testingModel/helpers/evaluation.py CHANGED
@@ -1,43 +1,88 @@
-import torch
-from torch.nn import CrossEntropyLoss
-
-
-"""
-Evaluates a trained model on a dataloader that returns batches like:
-    batch["image"] -> Tensor [B, 3, 256, 256]
-    batch["label"] -> Tensor [B]
-
-Returns dict:
-    { "accuracy": float, "loss": float }
-"""
-def make_predictions(model, dataloader, device):
-
-    model.eval()
-    criterion = CrossEntropyLoss()
-
-    total_loss = 0
-    total_correct = 0
-    total_samples = 0
-
-    with torch.no_grad():
-        for batch in dataloader:
-
-            # Move tensors to device
-            images = batch["image"].to(device)
-            labels = batch["label"].to(device).long()
-
-            # Forward pass
-            outputs = model(images)
-            loss = criterion(outputs, labels)
-
-            total_loss += loss.item() * images.size(0)
-            total_correct += (outputs.argmax(dim=1) == labels).sum().item()
-            total_samples += labels.size(0)
-
-    accuracy = total_correct / total_samples
-    avg_loss = total_loss / total_samples
-
-    return {
-        "accuracy": accuracy,
-        "loss": avg_loss,
-    }
+import torch
+from torch.nn import CrossEntropyLoss
+import numpy as np
+import matplotlib.pyplot as plt
+
+
+"""
+Evaluates a trained model on a dataloader that returns batches like:
+    batch["image"] -> Tensor [B, 3, 256, 256]
+    batch["label"] -> Tensor [B]
+"""
+def make_predictions(model, dataloader, device):
+
+    model.eval()
+    criterion = CrossEntropyLoss()
+
+    total_loss = 0
+    total_correct = 0
+    total_samples = 0
+
+    all_preds = []
+    all_labels = []
+
+    with torch.no_grad():
+        for batch in dataloader:
+
+            # Move tensors to device
+            images = batch["image"].to(device)
+            labels = batch["label"].to(device).long()
+
+            # Forward pass
+            outputs = model(images)
+            loss = criterion(outputs, labels)
+            preds = outputs.argmax(dim=1)
+
+            total_loss += loss.item() * images.size(0)
+            total_correct += (preds == labels).sum().item()
+            total_samples += labels.size(0)
+
+            # Accumulate all predictions and labels
+            all_preds.extend(preds.tolist())
+            all_labels.extend(labels.tolist())
+
+    accuracy = total_correct / total_samples
+    avg_loss = total_loss / total_samples
+
+    return {
+        "accuracy": accuracy,
+        "loss": avg_loss,
+        "predictions": np.array(all_preds),
+        "labels": np.array(all_labels),
+    }
+
+
+# Computes per-class accuracies
+def class_accuracies(labels, preds, num_classes):
+    correct = np.zeros(num_classes, dtype=int)
+    counts = np.zeros(num_classes, dtype=int)
+    accuracies = np.zeros(num_classes, dtype=float)
+
+    for true, pred in zip(labels, preds):
+        counts[true] += 1
+        if true == pred:
+            correct[true] += 1
+
+    # Calculate accuracies
+    for i in range(num_classes):
+        if counts[i] > 0:
+            accuracies[i] = round(correct[i] / counts[i], 4)
+        else:
+            accuracies[i] = 0.0
+
+    return accuracies
+
+
+def plot_class_accuracies(accuracies, class_names):
+    fig, ax = plt.subplots(figsize=(12, 6))
+
+    ax.set_title("Per-Class Accuracy")
+    ax.set_xlabel("Class")
+    ax.set_ylabel("Accuracy")
+    ax.set_ylim(0, 1.0)
+    ax.bar(class_names, accuracies)
+
+    plt.xticks(rotation=90)
+    plt.tight_layout()
+
+    return fig
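
The per-sample Python loop in `class_accuracies` is correct, but since `make_predictions` already returns NumPy arrays, the same counts fall out of two `np.bincount` calls. An equivalent vectorized sketch, not part of the commit (it skips the 4-decimal rounding the original applies):

```python
import numpy as np

def class_accuracies_vectorized(labels, preds, num_classes):
    # Number of test samples per true class, padded out to num_classes bins
    counts = np.bincount(labels, minlength=num_classes)
    # Correct predictions per class: bincount the true labels where pred == label
    correct = np.bincount(labels[preds == labels], minlength=num_classes)
    # Classes absent from the split get 0.0 instead of a divide-by-zero warning
    return np.divide(correct, counts,
                     out=np.zeros(num_classes, dtype=float),
                     where=counts > 0)
```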
testingModel/run_testing.py CHANGED
@@ -4,7 +4,7 @@ from dataPrep.helpers.clearml_data import extract_latest_data_task
 import torch
 from models.modelOne import modelOne
 from models.modelTwo import BetterCNN
-from testingModel.helpers.evaluation import make_predictions
+from testingModel.helpers.evaluation import make_predictions, class_accuracies, plot_class_accuracies


 # -------------- Load Data --------------
@@ -66,6 +66,27 @@ subset_results = make_predictions(model, test_subset, device)
 testing_logger.report_single_value(name="Test Subset Accuracy", value=subset_results["accuracy"])
 testing_logger.report_single_value(name="Test Subset Loss", value=subset_results["loss"])

+# Compute per-class accuracy
+preds = subset_results["predictions"]
+labels = subset_results["labels"]
+class_acc = class_accuracies(
+    labels,
+    preds,
+    num_classes=testing_config["num_classes"]
+)
+
+# Plot with formatted class names
+class_names = subset_loaders['classNames']
+formatted_class_names = [" ".join(name.replace('_', ' ').split()) for name in class_names]
+acc_fig = plot_class_accuracies(class_acc, formatted_class_names)
+
+# Log accuracies plot to ClearML
+testing_logger.report_matplotlib_figure(
+    title="Subset Per-Class Accuracy",
+    series="Class Accuracy",
+    figure=acc_fig
+)
+

 # --------- Complete -----------------
 print("\n------ Testing Complete ------")
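
The new helpers can be smoke-tested without a trained model or ClearML; a quick synthetic run, with class names that are purely illustrative:

```python
import numpy as np
from testingModel.helpers.evaluation import class_accuracies, plot_class_accuracies

rng = np.random.default_rng(0)
num_classes = 3
labels = rng.integers(0, num_classes, size=200)
# Corrupt ~20% of the predictions to simulate an imperfect model
preds = np.where(rng.random(200) < 0.8, labels,
                 rng.integers(0, num_classes, size=200))

acc = class_accuracies(labels, preds, num_classes)
fig = plot_class_accuracies(acc, ["cat", "dog", "bird"])
fig.savefig("per_class_accuracy.png")
```

Since `run_testing.py` hands the figure to ClearML and never shows it, adding `plt.close(acc_fig)` after the `report_matplotlib_figure` call would keep long test sessions from accumulating open figures.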