CircleStar committed on
Commit
a5f8c02
·
verified ·
1 Parent(s): 586661b

Create train_utils.py

Browse files
Files changed (1) hide show
  1. train_utils.py +216 -0
train_utils.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import time
4
+ from datetime import datetime
5
+ from typing import List, Tuple
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.optim as optim
10
+
11
+ from config import MODEL_DIR, META_DIR
12
+ from model import SimpleCNN
13
+ from data_utils import make_loaders
14
+
15
+
16
def model_weight_path(model_name: str) -> str:
    """Return the path of the weights file for ``model_name``.

    The file is named ``<model_name>.pt`` and is located directly under
    ``MODEL_DIR``.
    """
    weight_filename = "{}.pt".format(model_name)
    return os.path.join(MODEL_DIR, weight_filename)
18
+
19
+
20
def model_meta_path(model_name: str) -> str:
    """Return the path of the JSON metadata file for ``model_name``.

    The file is named ``<model_name>.json`` and is located directly under
    ``META_DIR``.
    """
    meta_filename = "{}.json".format(model_name)
    return os.path.join(META_DIR, meta_filename)
22
+
23
+
24
def list_saved_models() -> List[str]:
    """Return the names of all saved models, sorted in descending order.

    A model counts as saved when a ``<name>.json`` metadata file exists in
    ``META_DIR``; the returned names have the ``.json`` suffix removed.

    Returns:
        The model names in reverse lexicographic order, or an empty list
        when ``META_DIR`` does not exist yet (nothing has been saved).
    """
    try:
        entries = os.listdir(META_DIR)
    except FileNotFoundError:
        # No metadata directory means no model has been saved so far.
        return []

    suffix = ".json"
    return sorted(
        (entry[: -len(suffix)] for entry in entries if entry.endswith(suffix)),
        reverse=True,
    )
30
+
31
+
32
def save_model(model: nn.Module, model_name: str, config: dict, training_summary: dict):
    """Persist a model's weights and its JSON metadata.

    The state dict is copied to CPU and written to the path given by
    ``model_weight_path(model_name)``; a metadata document is written to
    ``model_meta_path(model_name)``.

    Args:
        model: Module whose ``state_dict()`` is saved.
        model_name: Base name shared by the weights and metadata files.
        config: Stored under the ``"config"`` key; must be
            JSON-serialisable.
        training_summary: Stored under the ``"training_summary"`` key; must
            be JSON-serialisable.

    Raises:
        OSError: If a directory or file cannot be created or written.
        TypeError: If ``config`` or ``training_summary`` is not
            JSON-serialisable.
    """
    # Create the output directories on demand so that the first save on a
    # fresh checkout does not fail after training has already completed.
    for directory in (MODEL_DIR, META_DIR):
        if directory:
            os.makedirs(directory, exist_ok=True)

    # Detach and move every tensor to CPU so the file does not hold
    # references to the device the model was trained on.
    cpu_state_dict = {k: v.detach().cpu() for k, v in model.state_dict().items()}
    torch.save(cpu_state_dict, model_weight_path(model_name))

    payload = {
        "model_name": model_name,
        "config": config,
        "training_summary": training_summary,
        "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }

    # ensure_ascii=False keeps non-ASCII text readable in the JSON file.
    with open(model_meta_path(model_name), "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)
45
+
46
+
47
def load_model(model_name: str, device: torch.device) -> Tuple[nn.Module, dict]:
    """Rebuild a saved model from its metadata and weights.

    Args:
        model_name: Base name of the saved model.
        device: Device the restored model is moved to.

    Returns:
        A ``(model, meta)`` tuple: the ``SimpleCNN`` in eval mode on
        ``device``, and the metadata dict read from the JSON file.

    Raises:
        FileNotFoundError: If the metadata file or the weights file is
            missing (metadata is checked first).
        KeyError: If the metadata lacks ``"config"`` or one of the
            architecture keys.
    """
    meta_file = model_meta_path(model_name)
    weight_file = model_weight_path(model_name)

    for label, path in (("Metadata", meta_file), ("Weights", weight_file)):
        if not os.path.exists(path):
            raise FileNotFoundError(f"{label} not found for model: {model_name}")

    with open(meta_file, "r", encoding="utf-8") as handle:
        meta = json.load(handle)

    # The architecture is rebuilt from the hyper-parameters stored at save time.
    cfg = meta["config"]
    arch_keys = (
        "num_classes",
        "conv1_channels",
        "conv2_channels",
        "kernel_size",
        "dropout",
        "fc_dim",
    )
    model = SimpleCNN(**{key: cfg[key] for key in arch_keys})

    # Weights are loaded on CPU first, then the module is moved to `device`.
    # NOTE(review): torch.load unpickles the file; only load weight files
    # from trusted sources (consider `weights_only=True` if the installed
    # torch version supports it).
    model.load_state_dict(torch.load(weight_file, map_location="cpu"))
    model.to(device)
    model.eval()

    return model, meta
76
+
77
+
78
def get_runtime_device() -> torch.device:
    """Return the CUDA device when CUDA is available, otherwise the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
80
+
81
+
82
def evaluate(model, loader, criterion, device):
    """Compute the average loss and accuracy of ``model`` over ``loader``.

    The model is switched to eval mode and gradients are disabled. Each
    batch loss is weighted by the batch size before averaging, which
    presumably expects ``criterion`` to return a per-batch mean.

    Args:
        model: Module called on each batch of images.
        loader: Iterable yielding ``(images, labels)`` pairs.
        criterion: Loss callable taking ``(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)``; both are ``0.0`` when the loader yields
        no samples.
    """
    model.eval()
    loss_sum = 0.0
    n_seen = 0
    n_correct = 0

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)

            logits = model(images)
            batch_loss = criterion(logits, labels).item()

            loss_sum += batch_loss * images.size(0)
            # The predicted class is the index of the largest output.
            n_correct += (logits.argmax(dim=1) == labels).sum().item()
            n_seen += labels.size(0)

    if not n_seen:
        return 0.0, 0.0
    return loss_sum / n_seen, n_correct / n_seen
101
+
102
+
103
def _sanitize_model_tag(model_tag, default="charcoal"):
    """Turn a free-form tag into a fragment that is safe inside a file name."""
    import re

    # Anything that is not a word character, dot or dash (spaces, path
    # separators, colons, ...) becomes "_" so the tag can neither escape the
    # model directory nor make the save fail after training has finished.
    cleaned = re.sub(r"[^\w.\-]", "_", (model_tag or "").strip())
    return cleaned or default


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``; return ``(loss, accuracy)``."""
    model.train()
    running_loss = 0.0
    total = 0
    correct = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size to get a per-sample average.
        running_loss += loss.item() * images.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)

    if not total:
        return 0.0, 0.0
    return running_loss / total, correct / total


def train_model(
    conv1_channels: int,
    conv2_channels: int,
    kernel_size: int,
    dropout: float,
    fc_dim: int,
    learning_rate: float,
    batch_size: int,
    epochs: int,
    model_tag: str,
):
    """Train a ``SimpleCNN``, evaluate it on the test split and save it.

    Data comes from ``make_loaders(batch_size)``, which supplies the train,
    validation and test loaders plus the class names. Training uses Adam and
    cross-entropy loss on the device chosen by ``get_runtime_device()``.
    The model is saved under ``<tag>_<YYYYmmdd_HHMMSS>``.

    Args:
        conv1_channels: Forwarded to ``SimpleCNN``.
        conv2_channels: Forwarded to ``SimpleCNN``.
        kernel_size: Forwarded to ``SimpleCNN``.
        dropout: Forwarded to ``SimpleCNN``.
        fc_dim: Forwarded to ``SimpleCNN``.
        learning_rate: Learning rate of the Adam optimiser.
        batch_size: Batch size passed to ``make_loaders``.
        epochs: Number of training epochs; with ``0`` the untrained model is
            still evaluated and saved, and the ``final_*`` summary fields
            are ``None``.
        model_tag: Free-form prefix for the saved model name. Surrounding
            whitespace is stripped and every character other than word
            characters, ``.`` and ``-`` is replaced with ``_``; an empty or
            ``None`` tag falls back to ``"charcoal"``.

    Returns:
        A ``(log_text, history, training_summary, model_name)`` tuple:
        the newline-joined log, one metrics dict per epoch, the summary
        dict stored with the model, and the name the model was saved under.
    """
    device = get_runtime_device()

    train_loader, val_loader, test_loader, class_names = make_loaders(batch_size)
    num_classes = len(class_names)

    model = SimpleCNN(
        num_classes=num_classes,
        conv1_channels=conv1_channels,
        conv2_channels=conv2_channels,
        kernel_size=kernel_size,
        dropout=dropout,
        fc_dim=fc_dim,
    ).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    history = []
    logs = []
    start_time = time.time()

    for epoch in range(1, epochs + 1):
        train_loss, train_acc = _train_one_epoch(
            model, train_loader, criterion, optimizer, device
        )
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)

        history.append(
            {
                "epoch": epoch,
                "train_loss": round(train_loss, 4),
                "train_acc": round(train_acc, 4),
                "val_loss": round(val_loss, 4),
                "val_acc": round(val_acc, 4),
            }
        )

        logs.append(
            f"Époque {epoch}/{epochs} | "
            f"perte entraînement={train_loss:.4f}, précision entraînement={train_acc:.4f}, "
            f"perte validation={val_loss:.4f}, précision validation={val_acc:.4f}"
        )

    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    # Elapsed time covers training plus the final test evaluation.
    elapsed = time.time() - start_time

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_name = f"{_sanitize_model_tag(model_tag)}_{timestamp}"

    config = {
        "dataset_name": "Charbons de bois microscopiques",
        "num_classes": num_classes,
        "class_names": class_names,
        "conv1_channels": conv1_channels,
        "conv2_channels": conv2_channels,
        "kernel_size": kernel_size,
        "dropout": dropout,
        "fc_dim": fc_dim,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "epochs": epochs,
    }

    # With zero epochs there is no history row; the final_* fields are None.
    last_epoch = history[-1] if history else {}
    training_summary = {
        "final_train_loss": last_epoch.get("train_loss"),
        "final_train_acc": last_epoch.get("train_acc"),
        "final_val_loss": last_epoch.get("val_loss"),
        "final_val_acc": last_epoch.get("val_acc"),
        "test_loss": round(test_loss, 4),
        "test_acc": round(test_acc, 4),
        "elapsed_seconds": round(elapsed, 2),
        "device": str(device),
    }

    save_model(model, model_name, config, training_summary)

    logs.extend(
        [
            "",
            "Entraînement terminé.",
            f"Modèle sauvegardé : {model_name}",
            f"Appareil utilisé : {device}",
            f"Perte test : {test_loss:.4f}",
            f"Précision test : {test_acc:.4f}",
            f"Temps écoulé : {elapsed:.1f}s",
        ]
    )

    return "\n".join(logs), history, training_summary, model_name