CircleStar commited on
Commit
14b719f
·
verified ·
1 Parent(s): 63e305e

Update train_utils.py

Browse files
Files changed (1) hide show
  1. train_utils.py +130 -35
train_utils.py CHANGED
@@ -8,9 +8,10 @@ import torch
8
  import torch.nn as nn
9
  import torch.optim as optim
10
 
11
- from config import MODEL_DIR, META_DIR
12
- from model import SimpleCNN
13
  from data_utils import make_loaders
 
 
14
 
15
 
16
  def model_weight_path(model_name: str) -> str:
@@ -29,6 +30,10 @@ def list_saved_models() -> List[str]:
29
  return sorted(names, reverse=True)
30
 
31
 
 
 
 
 
32
  def save_model(model: nn.Module, model_name: str, config: dict, training_summary: dict):
33
  cpu_state_dict = {k: v.detach().cpu() for k, v in model.state_dict().items()}
34
  torch.save(cpu_state_dict, model_weight_path(model_name))
@@ -49,22 +54,20 @@ def load_model(model_name: str, device: torch.device) -> Tuple[nn.Module, dict]:
49
  weight_file = model_weight_path(model_name)
50
 
51
  if not os.path.exists(meta_file):
52
- raise FileNotFoundError(f"Metadata not found for model: {model_name}")
53
  if not os.path.exists(weight_file):
54
- raise FileNotFoundError(f"Weights not found for model: {model_name}")
55
 
56
  with open(meta_file, "r", encoding="utf-8") as f:
57
  meta = json.load(f)
58
 
59
  cfg = meta["config"]
60
 
61
- model = SimpleCNN(
62
  num_classes=cfg["num_classes"],
63
- conv1_channels=cfg["conv1_channels"],
64
- conv2_channels=cfg["conv2_channels"],
65
- kernel_size=cfg["kernel_size"],
66
  dropout=cfg["dropout"],
67
  fc_dim=cfg["fc_dim"],
 
68
  )
69
 
70
  state_dict = torch.load(weight_file, map_location="cpu")
@@ -75,12 +78,9 @@ def load_model(model_name: str, device: torch.device) -> Tuple[nn.Module, dict]:
75
  return model, meta
76
 
77
 
78
- def get_runtime_device() -> torch.device:
79
- return torch.device("cuda" if torch.cuda.is_available() else "cpu")
80
-
81
-
82
- def evaluate(model, loader, criterion, device):
83
  model.eval()
 
84
  total_loss = 0.0
85
  total = 0
86
  correct = 0
@@ -94,21 +94,42 @@ def evaluate(model, loader, criterion, device):
94
 
95
  total_loss += loss.item() * images.size(0)
96
  preds = outputs.argmax(dim=1)
 
97
  correct += (preds == labels).sum().item()
98
  total += labels.size(0)
99
 
100
- return total_loss / total if total else 0.0, correct / total if total else 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
 
103
  def train_model(
104
- conv1_channels: int,
105
- conv2_channels: int,
106
- kernel_size: int,
107
  dropout: float,
108
  fc_dim: int,
109
  learning_rate: float,
 
110
  batch_size: int,
111
  epochs: int,
 
112
  model_tag: str,
113
  ):
114
  device = get_runtime_device()
@@ -116,24 +137,33 @@ def train_model(
116
  train_loader, val_loader, test_loader, class_names = make_loaders(batch_size)
117
  num_classes = len(class_names)
118
 
119
- model = SimpleCNN(
120
  num_classes=num_classes,
121
- conv1_channels=conv1_channels,
122
- conv2_channels=conv2_channels,
123
- kernel_size=kernel_size,
124
  dropout=dropout,
125
  fc_dim=fc_dim,
 
126
  ).to(device)
127
 
 
 
 
128
  criterion = nn.CrossEntropyLoss()
129
- optimizer = optim.Adam(model.parameters(), lr=learning_rate)
 
 
 
 
130
 
131
  history = []
132
  logs = []
133
  start_time = time.time()
134
 
 
 
 
135
  for epoch in range(1, epochs + 1):
136
  model.train()
 
137
  running_loss = 0.0
138
  total = 0
139
  correct = 0
@@ -143,18 +173,28 @@ def train_model(
143
 
144
  optimizer.zero_grad()
145
  outputs = model(images)
 
146
  loss = criterion(outputs, labels)
147
  loss.backward()
148
  optimizer.step()
149
 
150
  running_loss += loss.item() * images.size(0)
 
151
  preds = outputs.argmax(dim=1)
152
  correct += (preds == labels).sum().item()
153
  total += labels.size(0)
154
 
155
  train_loss = running_loss / total if total else 0.0
156
  train_acc = correct / total if total else 0.0
157
- val_loss, val_acc = evaluate(model, val_loader, criterion, device)
 
 
 
 
 
 
 
 
158
 
159
  row = {
160
  "epoch": epoch,
@@ -163,6 +203,7 @@ def train_model(
163
  "val_loss": round(val_loss, 4),
164
  "val_acc": round(val_acc, 4),
165
  }
 
166
  history.append(row)
167
 
168
  logs.append(
@@ -171,36 +212,49 @@ def train_model(
171
  f"perte validation={val_loss:.4f}, précision validation={val_acc:.4f}"
172
  )
173
 
174
- test_loss, test_acc = evaluate(model, test_loader, criterion, device)
 
 
 
 
 
 
 
175
  elapsed = time.time() - start_time
176
 
177
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
178
- safe_tag = model_tag.strip().replace(" ", "_") if model_tag.strip() else "charcoal"
179
  model_name = f"{safe_tag}_{timestamp}"
180
 
 
 
181
  config = {
182
- "dataset_name": "Charbons de bois microscopiques",
 
183
  "num_classes": num_classes,
184
  "class_names": class_names,
185
- "conv1_channels": conv1_channels,
186
- "conv2_channels": conv2_channels,
187
- "kernel_size": kernel_size,
188
  "dropout": dropout,
189
  "fc_dim": fc_dim,
190
  "learning_rate": learning_rate,
 
191
  "batch_size": batch_size,
192
  "epochs": epochs,
 
193
  }
194
 
195
  training_summary = {
196
  "final_train_loss": history[-1]["train_loss"] if history else None,
197
  "final_train_acc": history[-1]["train_acc"] if history else None,
198
- "final_val_loss": history[-1]["val_loss"] if history else None,
199
  "final_val_acc": history[-1]["val_acc"] if history else None,
200
- "test_loss": round(test_loss, 4),
201
- "test_acc": round(test_acc, 4),
 
 
202
  "elapsed_seconds": round(elapsed, 2),
203
  "device": str(device),
 
 
204
  }
205
 
206
  save_model(model, model_name, config, training_summary)
@@ -209,8 +263,49 @@ def train_model(
209
  logs.append("Entraînement terminé.")
210
  logs.append(f"Modèle sauvegardé : {model_name}")
211
  logs.append(f"Appareil utilisé : {device}")
212
- logs.append(f"Perte test : {test_loss:.4f}")
213
- logs.append(f"Précision test : {test_acc:.4f}")
 
 
 
 
214
  logs.append(f"Temps écoulé : {elapsed:.1f}s")
215
 
216
- return "\n".join(logs), history, training_summary, model_name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  import torch.nn as nn
9
  import torch.optim as optim
10
 
11
+ from config import MODEL_DIR, META_DIR, DATASET_DISPLAY_NAME
 
12
  from data_utils import make_loaders
13
+ from metrics_utils import compute_classification_metrics, save_confusion_matrix_figure
14
+ from model import ResNet18Classifier
15
 
16
 
17
  def model_weight_path(model_name: str) -> str:
 
30
  return sorted(names, reverse=True)
31
 
32
 
33
+ def get_runtime_device() -> torch.device:
34
+ return torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
+
36
+
37
  def save_model(model: nn.Module, model_name: str, config: dict, training_summary: dict):
38
  cpu_state_dict = {k: v.detach().cpu() for k, v in model.state_dict().items()}
39
  torch.save(cpu_state_dict, model_weight_path(model_name))
 
54
  weight_file = model_weight_path(model_name)
55
 
56
  if not os.path.exists(meta_file):
57
+ raise FileNotFoundError(f"Métadonnées introuvables pour le modèle : {model_name}")
58
  if not os.path.exists(weight_file):
59
+ raise FileNotFoundError(f"Poids introuvables pour le modèle : {model_name}")
60
 
61
  with open(meta_file, "r", encoding="utf-8") as f:
62
  meta = json.load(f)
63
 
64
  cfg = meta["config"]
65
 
66
+ model = ResNet18Classifier(
67
  num_classes=cfg["num_classes"],
 
 
 
68
  dropout=cfg["dropout"],
69
  fc_dim=cfg["fc_dim"],
70
+ freeze_backbone=cfg.get("freeze_backbone", True),
71
  )
72
 
73
  state_dict = torch.load(weight_file, map_location="cpu")
 
78
  return model, meta
79
 
80
 
81
+ def evaluate_loss_acc(model, loader, criterion, device):
 
 
 
 
82
  model.eval()
83
+
84
  total_loss = 0.0
85
  total = 0
86
  correct = 0
 
94
 
95
  total_loss += loss.item() * images.size(0)
96
  preds = outputs.argmax(dim=1)
97
+
98
  correct += (preds == labels).sum().item()
99
  total += labels.size(0)
100
 
101
+ avg_loss = total_loss / total if total else 0.0
102
+ acc = correct / total if total else 0.0
103
+
104
+ return avg_loss, acc
105
+
106
+
107
+ def collect_predictions(model, loader, device):
108
+ model.eval()
109
+
110
+ y_true = []
111
+ y_pred = []
112
+
113
+ with torch.no_grad():
114
+ for images, labels in loader:
115
+ images = images.to(device)
116
+ outputs = model(images)
117
+ preds = outputs.argmax(dim=1).detach().cpu().tolist()
118
+
119
+ y_pred.extend(preds)
120
+ y_true.extend(labels.tolist())
121
+
122
+ return y_true, y_pred
123
 
124
 
125
  def train_model(
 
 
 
126
  dropout: float,
127
  fc_dim: int,
128
  learning_rate: float,
129
+ weight_decay: float,
130
  batch_size: int,
131
  epochs: int,
132
+ freeze_backbone: bool,
133
  model_tag: str,
134
  ):
135
  device = get_runtime_device()
 
137
  train_loader, val_loader, test_loader, class_names = make_loaders(batch_size)
138
  num_classes = len(class_names)
139
 
140
+ model = ResNet18Classifier(
141
  num_classes=num_classes,
 
 
 
142
  dropout=dropout,
143
  fc_dim=fc_dim,
144
+ freeze_backbone=freeze_backbone,
145
  ).to(device)
146
 
147
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
148
+ total_params = sum(p.numel() for p in model.parameters())
149
+
150
  criterion = nn.CrossEntropyLoss()
151
+ optimizer = optim.AdamW(
152
+ filter(lambda p: p.requires_grad, model.parameters()),
153
+ lr=learning_rate,
154
+ weight_decay=weight_decay,
155
+ )
156
 
157
  history = []
158
  logs = []
159
  start_time = time.time()
160
 
161
+ best_val_loss = float("inf")
162
+ best_state_dict = None
163
+
164
  for epoch in range(1, epochs + 1):
165
  model.train()
166
+
167
  running_loss = 0.0
168
  total = 0
169
  correct = 0
 
173
 
174
  optimizer.zero_grad()
175
  outputs = model(images)
176
+
177
  loss = criterion(outputs, labels)
178
  loss.backward()
179
  optimizer.step()
180
 
181
  running_loss += loss.item() * images.size(0)
182
+
183
  preds = outputs.argmax(dim=1)
184
  correct += (preds == labels).sum().item()
185
  total += labels.size(0)
186
 
187
  train_loss = running_loss / total if total else 0.0
188
  train_acc = correct / total if total else 0.0
189
+
190
+ val_loss, val_acc = evaluate_loss_acc(model, val_loader, criterion, device)
191
+
192
+ if val_loss < best_val_loss:
193
+ best_val_loss = val_loss
194
+ best_state_dict = {
195
+ k: v.detach().cpu().clone()
196
+ for k, v in model.state_dict().items()
197
+ }
198
 
199
  row = {
200
  "epoch": epoch,
 
203
  "val_loss": round(val_loss, 4),
204
  "val_acc": round(val_acc, 4),
205
  }
206
+
207
  history.append(row)
208
 
209
  logs.append(
 
212
  f"perte validation={val_loss:.4f}, précision validation={val_acc:.4f}"
213
  )
214
 
215
+ if best_state_dict is not None:
216
+ model.load_state_dict(best_state_dict)
217
+
218
+ test_loss, test_acc = evaluate_loss_acc(model, test_loader, criterion, device)
219
+ y_true, y_pred = collect_predictions(model, test_loader, device)
220
+
221
+ metrics = compute_classification_metrics(y_true, y_pred, class_names)
222
+
223
  elapsed = time.time() - start_time
224
 
225
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
226
+ safe_tag = model_tag.strip().replace(" ", "_") if model_tag.strip() else "charcoal_resnet18"
227
  model_name = f"{safe_tag}_{timestamp}"
228
 
229
+ cm_path = save_confusion_matrix_figure(metrics["confusion_matrix"], model_name)
230
+
231
  config = {
232
+ "dataset_name": DATASET_DISPLAY_NAME,
233
+ "architecture": "ResNet18 pretrained + classifier head",
234
  "num_classes": num_classes,
235
  "class_names": class_names,
 
 
 
236
  "dropout": dropout,
237
  "fc_dim": fc_dim,
238
  "learning_rate": learning_rate,
239
+ "weight_decay": weight_decay,
240
  "batch_size": batch_size,
241
  "epochs": epochs,
242
+ "freeze_backbone": freeze_backbone,
243
  }
244
 
245
  training_summary = {
246
  "final_train_loss": history[-1]["train_loss"] if history else None,
247
  "final_train_acc": history[-1]["train_acc"] if history else None,
248
+ "best_val_loss": round(best_val_loss, 4),
249
  "final_val_acc": history[-1]["val_acc"] if history else None,
250
+ "test_cross_entropy_loss": round(test_loss, 4),
251
+ "test_accuracy": round(test_acc, 4),
252
+ "test_f1_macro": metrics["f1_macro"],
253
+ "test_f1_weighted": metrics["f1_weighted"],
254
  "elapsed_seconds": round(elapsed, 2),
255
  "device": str(device),
256
+ "total_params": total_params,
257
+ "trainable_params": trainable_params,
258
  }
259
 
260
  save_model(model, model_name, config, training_summary)
 
263
  logs.append("Entraînement terminé.")
264
  logs.append(f"Modèle sauvegardé : {model_name}")
265
  logs.append(f"Appareil utilisé : {device}")
266
+ logs.append(f"Nombre total de paramètres : {total_params}")
267
+ logs.append(f"Paramètres entraînables : {trainable_params}")
268
+ logs.append(f"Perte test cross-entropy : {test_loss:.4f}")
269
+ logs.append(f"Accuracy test : {test_acc:.4f}")
270
+ logs.append(f"F1 macro test : {metrics['f1_macro']:.4f}")
271
+ logs.append(f"F1 pondéré test : {metrics['f1_weighted']:.4f}")
272
  logs.append(f"Temps écoulé : {elapsed:.1f}s")
273
 
274
+ return {
275
+ "logs": "\n".join(logs),
276
+ "history": history,
277
+ "summary": training_summary,
278
+ "model_name": model_name,
279
+ "classification_report": metrics["classification_report"],
280
+ "confusion_matrix": metrics["confusion_matrix"],
281
+ "confusion_matrix_path": cm_path,
282
+ }
283
+
284
+
285
+ def evaluate_saved_model(model_name: str):
286
+ if not model_name:
287
+ raise ValueError("Aucun modèle sélectionné.")
288
+
289
+ device = get_runtime_device()
290
+ model, meta = load_model(model_name, device)
291
+
292
+ batch_size = int(meta["config"].get("batch_size", 32))
293
+ _, _, test_loader, class_names = make_loaders(batch_size)
294
+
295
+ criterion = nn.CrossEntropyLoss()
296
+
297
+ test_loss, test_acc = evaluate_loss_acc(model, test_loader, criterion, device)
298
+ y_true, y_pred = collect_predictions(model, test_loader, device)
299
+
300
+ metrics = compute_classification_metrics(y_true, y_pred, class_names)
301
+ cm_path = save_confusion_matrix_figure(metrics["confusion_matrix"], model_name)
302
+
303
+ summary = {
304
+ "test_cross_entropy_loss": round(test_loss, 4),
305
+ "test_accuracy": round(test_acc, 4),
306
+ "test_f1_macro": metrics["f1_macro"],
307
+ "test_f1_weighted": metrics["f1_weighted"],
308
+ "device": str(device),
309
+ }
310
+
311
+ return summary, metrics["classification_report"], metrics["confusion_matrix"], cm_path