functionNormally Claude Sonnet 4.6 commited on
Commit
f14a2ff
·
1 Parent(s): 81c6237

Remplacer ResNet18 par un CNN simple configurable

Browse files

- model.py : nouvelle classe SimpleCNN (blocs Conv→BN→ReLU→MaxPool,
pooling global adaptatif, classifieur FC)
- train_utils.py : paramètres num_conv_blocks, base_filters, kernel_size,
use_batchnorm ; lr par défaut 0.001, batch 32
- app.py : interface mise à jour avec les nouveaux contrôles CNN

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (3) hide show
  1. app.py +45 -21
  2. model.py +30 -35
  3. train_utils.py +23 -11
app.py CHANGED
@@ -53,24 +53,30 @@ def refresh_gallery_callback(split_name, class_name, max_images):
53
 
54
  @spaces.GPU(duration=300)
55
  def train_callback(
 
 
 
 
56
  dropout,
57
  fc_dim,
58
  learning_rate,
59
  weight_decay,
60
  batch_size,
61
  epochs,
62
- fine_tune_mode,
63
  model_tag,
64
  ):
65
  try:
66
  result = train_model(
 
 
 
 
67
  dropout=float(dropout),
68
  fc_dim=int(fc_dim),
69
  learning_rate=float(learning_rate),
70
  weight_decay=float(weight_decay),
71
  batch_size=int(batch_size),
72
  epochs=int(epochs),
73
- fine_tune_mode=str(fine_tune_mode),
74
  model_tag=model_tag,
75
  )
76
 
@@ -199,14 +205,40 @@ with gr.Blocks(title="Classification d’images microscopiques") as demo:
199
  )
200
 
201
  with gr.Tab("2. Entraîner un modèle"):
202
- gr.Markdown("## Entraînement avec ResNet18 pré-entraîné")
203
  gr.Markdown(
204
- "Paramètres par défaut recommandés : fine-tuning de la dernière couche convolutionnelle "
205
- "du ResNet18, faible taux d’apprentissage, augmentation légère des données."
206
  )
207
 
208
  with gr.Row():
209
  with gr.Column():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  dropout = gr.Slider(
211
  minimum=0.0,
212
  maximum=0.8,
@@ -218,11 +250,11 @@ with gr.Blocks(title="Classification d’images microscopiques") as demo:
218
  fc_dim = gr.Dropdown(
219
  choices=[64, 128, 256, 512],
220
  value=256,
221
- label="Dimension de la couche cachée",
222
  )
223
 
224
  learning_rate = gr.Number(
225
- value=0.00001,
226
  label="Taux d’apprentissage",
227
  )
228
 
@@ -233,7 +265,7 @@ with gr.Blocks(title="Classification d’images microscopiques") as demo:
233
 
234
  batch_size = gr.Dropdown(
235
  choices=[8, 16, 32, 64],
236
- value=16,
237
  label="Taille du batch",
238
  )
239
 
@@ -245,20 +277,9 @@ with gr.Blocks(title="Classification d’images microscopiques") as demo:
245
  label="Nombre d’époques",
246
  )
247
 
248
- fine_tune_mode = gr.Dropdown(
249
- choices=["frozen", "layer4", "full"],
250
- value="layer4",
251
- label="Mode de fine-tuning",
252
- info=(
253
- "frozen = seul le classifieur est entraîné ; "
254
- "layer4 = dernière partie du ResNet18 + classifieur ; "
255
- "full = tout le réseau est ajusté."
256
- ),
257
- )
258
-
259
  model_tag = gr.Textbox(
260
  label="Nom court du modèle",
261
- placeholder="ex. charbon_resnet18_layer4",
262
  )
263
 
264
  train_btn = gr.Button("Lancer l’entraînement", variant="primary")
@@ -360,13 +381,16 @@ with gr.Blocks(title="Classification d’images microscopiques") as demo:
360
  train_btn.click(
361
  fn=train_callback,
362
  inputs=[
 
 
 
 
363
  dropout,
364
  fc_dim,
365
  learning_rate,
366
  weight_decay,
367
  batch_size,
368
  epochs,
369
- fine_tune_mode,
370
  model_tag,
371
  ],
372
  outputs=[
 
53
 
54
  @spaces.GPU(duration=300)
55
  def train_callback(
56
+ num_conv_blocks,
57
+ base_filters,
58
+ kernel_size,
59
+ use_batchnorm,
60
  dropout,
61
  fc_dim,
62
  learning_rate,
63
  weight_decay,
64
  batch_size,
65
  epochs,
 
66
  model_tag,
67
  ):
68
  try:
69
  result = train_model(
70
+ num_conv_blocks=int(num_conv_blocks),
71
+ base_filters=int(base_filters),
72
+ kernel_size=int(kernel_size),
73
+ use_batchnorm=bool(use_batchnorm),
74
  dropout=float(dropout),
75
  fc_dim=int(fc_dim),
76
  learning_rate=float(learning_rate),
77
  weight_decay=float(weight_decay),
78
  batch_size=int(batch_size),
79
  epochs=int(epochs),
 
80
  model_tag=model_tag,
81
  )
82
 
 
205
  )
206
 
207
  with gr.Tab("2. Entraîner un modèle"):
208
+ gr.Markdown("## Entraînement d’un CNN simple (entraîné de zéro)")
209
  gr.Markdown(
210
+ "Configurez librement l’architecture du CNN : nombre de blocs convolutionnels, "
211
+ "nombre de filtres, taille du noyau, etc. Tous les paramètres sont entraînables."
212
  )
213
 
214
  with gr.Row():
215
  with gr.Column():
216
+ num_conv_blocks = gr.Slider(
217
+ minimum=2,
218
+ maximum=5,
219
+ value=3,
220
+ step=1,
221
+ label="Nombre de blocs convolutionnels",
222
+ info="Chaque bloc enchaîne Conv2d → (BN) → ReLU → MaxPool2d.",
223
+ )
224
+
225
+ base_filters = gr.Dropdown(
226
+ choices=[16, 32, 64, 128],
227
+ value=32,
228
+ label="Filtres du premier bloc (doublent à chaque bloc)",
229
+ )
230
+
231
+ kernel_size = gr.Dropdown(
232
+ choices=[3, 5],
233
+ value=3,
234
+ label="Taille du noyau de convolution",
235
+ )
236
+
237
+ use_batchnorm = gr.Checkbox(
238
+ value=True,
239
+ label="Normalisation par lots (BatchNorm)",
240
+ )
241
+
242
  dropout = gr.Slider(
243
  minimum=0.0,
244
  maximum=0.8,
 
250
  fc_dim = gr.Dropdown(
251
  choices=[64, 128, 256, 512],
252
  value=256,
253
+ label="Dimension de la couche cachée (classifieur)",
254
  )
255
 
256
  learning_rate = gr.Number(
257
+ value=0.001,
258
  label="Taux d’apprentissage",
259
  )
260
 
 
265
 
266
  batch_size = gr.Dropdown(
267
  choices=[8, 16, 32, 64],
268
+ value=32,
269
  label="Taille du batch",
270
  )
271
 
 
277
  label="Nombre d’époques",
278
  )
279
 
 
 
 
 
 
 
 
 
 
 
 
280
  model_tag = gr.Textbox(
281
  label="Nom court du modèle",
282
+ placeholder="ex. cnn_3blocs_32filtres",
283
  )
284
 
285
  train_btn = gr.Button("Lancer l’entraînement", variant="primary")
 
381
  train_btn.click(
382
  fn=train_callback,
383
  inputs=[
384
+ num_conv_blocks,
385
+ base_filters,
386
+ kernel_size,
387
+ use_batchnorm,
388
  dropout,
389
  fc_dim,
390
  learning_rate,
391
  weight_decay,
392
  batch_size,
393
  epochs,
 
394
  model_tag,
395
  ],
396
  outputs=[
model.py CHANGED
@@ -1,52 +1,47 @@
1
  import torch.nn as nn
2
- from torchvision import models
3
 
4
 
5
- class ResNet18Classifier(nn.Module):
6
  def __init__(
7
  self,
8
  num_classes: int,
 
 
 
 
9
  dropout: float = 0.4,
10
  fc_dim: int = 256,
11
- fine_tune_mode: str = "layer4",
12
  ):
13
  super().__init__()
14
 
15
- weights = models.ResNet18_Weights.DEFAULT
16
- self.backbone = models.resnet18(weights=weights)
17
-
18
- in_features = self.backbone.fc.in_features
19
-
20
- # Freeze everything first
21
- for param in self.backbone.parameters():
22
- param.requires_grad = False
23
-
24
- # Fine-tuning strategy
25
- if fine_tune_mode == "frozen":
26
- pass
27
-
28
- elif fine_tune_mode == "layer4":
29
- for param in self.backbone.layer4.parameters():
30
- param.requires_grad = True
31
-
32
- elif fine_tune_mode == "full":
33
- for param in self.backbone.parameters():
34
- param.requires_grad = True
35
-
36
- else:
37
- raise ValueError(f"Unsupported fine_tune_mode: {fine_tune_mode}")
38
-
39
- self.backbone.fc = nn.Sequential(
40
  nn.Dropout(dropout),
41
- nn.Linear(in_features, fc_dim),
42
- nn.ReLU(),
43
  nn.Dropout(dropout),
44
  nn.Linear(fc_dim, num_classes),
45
  )
46
 
47
- # Always train classifier head
48
- for param in self.backbone.fc.parameters():
49
- param.requires_grad = True
50
-
51
  def forward(self, x):
52
- return self.backbone(x)
 
 
 
 
1
  import torch.nn as nn
 
2
 
3
 
4
+ class SimpleCNN(nn.Module):
5
  def __init__(
6
  self,
7
  num_classes: int,
8
+ num_conv_blocks: int = 3,
9
+ base_filters: int = 32,
10
+ kernel_size: int = 3,
11
+ use_batchnorm: bool = True,
12
  dropout: float = 0.4,
13
  fc_dim: int = 256,
 
14
  ):
15
  super().__init__()
16
 
17
+ padding = kernel_size // 2
18
+ layers = []
19
+ in_channels = 3
20
+
21
+ for i in range(num_conv_blocks):
22
+ # Les filtres doublent à chaque bloc, plafonnés à 512
23
+ out_channels = min(base_filters * (2 ** i), 512)
24
+ layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding))
25
+ if use_batchnorm:
26
+ layers.append(nn.BatchNorm2d(out_channels))
27
+ layers.append(nn.ReLU(inplace=True))
28
+ layers.append(nn.MaxPool2d(2, 2))
29
+ in_channels = out_channels
30
+
31
+ self.features = nn.Sequential(*layers)
32
+ # Pooling global : indépendant de la taille spatiale d'entrée
33
+ self.pool = nn.AdaptiveAvgPool2d(1)
34
+
35
+ self.classifier = nn.Sequential(
 
 
 
 
 
 
36
  nn.Dropout(dropout),
37
+ nn.Linear(in_channels, fc_dim),
38
+ nn.ReLU(inplace=True),
39
  nn.Dropout(dropout),
40
  nn.Linear(fc_dim, num_classes),
41
  )
42
 
 
 
 
 
43
  def forward(self, x):
44
+ x = self.features(x)
45
+ x = self.pool(x)
46
+ x = x.flatten(1)
47
+ return self.classifier(x)
train_utils.py CHANGED
@@ -11,7 +11,7 @@ import torch.optim as optim
11
  from config import MODEL_DIR, META_DIR, DATASET_DISPLAY_NAME
12
  from data_utils import make_loaders
13
  from metrics_utils import compute_classification_metrics, save_confusion_matrix_figure
14
- from model import ResNet18Classifier
15
 
16
 
17
  def model_weight_path(model_name: str) -> str:
@@ -64,11 +64,14 @@ def load_model(model_name: str, device: torch.device) -> Tuple[nn.Module, dict]:
64
 
65
  cfg = meta["config"]
66
 
67
- model = ResNet18Classifier(
68
  num_classes=cfg["num_classes"],
 
 
 
 
69
  dropout=cfg.get("dropout", 0.4),
70
  fc_dim=cfg.get("fc_dim", 256),
71
- fine_tune_mode=cfg.get("fine_tune_mode", "layer4"),
72
  )
73
 
74
  state_dict = torch.load(weight_file, map_location="cpu")
@@ -125,13 +128,16 @@ def collect_predictions(model, loader, device):
125
 
126
 
127
  def train_model(
 
 
 
 
128
  dropout: float = 0.4,
129
  fc_dim: int = 256,
130
- learning_rate: float = 0.00005,
131
  weight_decay: float = 0.0001,
132
- batch_size: int = 16,
133
  epochs: int = 30,
134
- fine_tune_mode: str = "layer4",
135
  model_tag: str = "",
136
  ):
137
  device = get_runtime_device()
@@ -139,11 +145,14 @@ def train_model(
139
  train_loader, val_loader, test_loader, class_names = make_loaders(batch_size)
140
  num_classes = len(class_names)
141
 
142
- model = ResNet18Classifier(
143
  num_classes=num_classes,
 
 
 
 
144
  dropout=dropout,
145
  fc_dim=fc_dim,
146
- fine_tune_mode=fine_tune_mode,
147
  ).to(device)
148
 
149
  trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
@@ -237,16 +246,19 @@ def train_model(
237
 
238
  config = {
239
  "dataset_name": DATASET_DISPLAY_NAME,
240
- "architecture": "ResNet18 pretrained + classifier head",
241
  "num_classes": num_classes,
242
  "class_names": class_names,
 
 
 
 
243
  "dropout": dropout,
244
  "fc_dim": fc_dim,
245
  "learning_rate": learning_rate,
246
  "weight_decay": weight_decay,
247
  "batch_size": batch_size,
248
  "epochs": epochs,
249
- "fine_tune_mode": fine_tune_mode,
250
  }
251
 
252
  training_summary = {
@@ -271,7 +283,7 @@ def train_model(
271
  logs.append("Entraînement terminé.")
272
  logs.append(f"Modèle sauvegardé : {model_name}")
273
  logs.append(f"Appareil utilisé : {device}")
274
- logs.append(f"Mode de fine-tuning : {fine_tune_mode}")
275
  logs.append(f"Nombre total de paramètres : {total_params}")
276
  logs.append(f"Paramètres entraînables : {trainable_params}")
277
  logs.append(f"Perte test cross-entropy : {test_loss:.4f}")
 
11
  from config import MODEL_DIR, META_DIR, DATASET_DISPLAY_NAME
12
  from data_utils import make_loaders
13
  from metrics_utils import compute_classification_metrics, save_confusion_matrix_figure
14
+ from model import SimpleCNN
15
 
16
 
17
  def model_weight_path(model_name: str) -> str:
 
64
 
65
  cfg = meta["config"]
66
 
67
+ model = SimpleCNN(
68
  num_classes=cfg["num_classes"],
69
+ num_conv_blocks=cfg.get("num_conv_blocks", 3),
70
+ base_filters=cfg.get("base_filters", 32),
71
+ kernel_size=cfg.get("kernel_size", 3),
72
+ use_batchnorm=cfg.get("use_batchnorm", True),
73
  dropout=cfg.get("dropout", 0.4),
74
  fc_dim=cfg.get("fc_dim", 256),
 
75
  )
76
 
77
  state_dict = torch.load(weight_file, map_location="cpu")
 
128
 
129
 
130
  def train_model(
131
+ num_conv_blocks: int = 3,
132
+ base_filters: int = 32,
133
+ kernel_size: int = 3,
134
+ use_batchnorm: bool = True,
135
  dropout: float = 0.4,
136
  fc_dim: int = 256,
137
+ learning_rate: float = 0.001,
138
  weight_decay: float = 0.0001,
139
+ batch_size: int = 32,
140
  epochs: int = 30,
 
141
  model_tag: str = "",
142
  ):
143
  device = get_runtime_device()
 
145
  train_loader, val_loader, test_loader, class_names = make_loaders(batch_size)
146
  num_classes = len(class_names)
147
 
148
+ model = SimpleCNN(
149
  num_classes=num_classes,
150
+ num_conv_blocks=num_conv_blocks,
151
+ base_filters=base_filters,
152
+ kernel_size=kernel_size,
153
+ use_batchnorm=use_batchnorm,
154
  dropout=dropout,
155
  fc_dim=fc_dim,
 
156
  ).to(device)
157
 
158
  trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
 
246
 
247
  config = {
248
  "dataset_name": DATASET_DISPLAY_NAME,
249
+ "architecture": "CNN simple entraîné de zéro",
250
  "num_classes": num_classes,
251
  "class_names": class_names,
252
+ "num_conv_blocks": num_conv_blocks,
253
+ "base_filters": base_filters,
254
+ "kernel_size": kernel_size,
255
+ "use_batchnorm": use_batchnorm,
256
  "dropout": dropout,
257
  "fc_dim": fc_dim,
258
  "learning_rate": learning_rate,
259
  "weight_decay": weight_decay,
260
  "batch_size": batch_size,
261
  "epochs": epochs,
 
262
  }
263
 
264
  training_summary = {
 
283
  logs.append("Entraînement terminé.")
284
  logs.append(f"Modèle sauvegardé : {model_name}")
285
  logs.append(f"Appareil utilisé : {device}")
286
+ logs.append(f"Architecture : {num_conv_blocks} blocs conv, filtres de base={base_filters}, noyau={kernel_size}x{kernel_size}, BatchNorm={use_batchnorm}")
287
  logs.append(f"Nombre total de paramètres : {total_params}")
288
  logs.append(f"Paramètres entraînables : {trainable_params}")
289
  logs.append(f"Perte test cross-entropy : {test_loss:.4f}")