CircleStar commited on
Commit
551ca15
·
verified ·
1 Parent(s): cfe30ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +258 -54
app.py CHANGED
@@ -1,44 +1,111 @@
1
  import json
2
 
3
- import spaces
4
  import gradio as gr
 
5
 
6
- from train_utils import train_model, list_saved_models, model_meta_path
7
- from predict_utils import predict_uploaded_image, test_random_sample
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
 
 
 
 
 
9
 
10
- @spaces.GPU(duration=120)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def train_callback(
12
- conv1_channels,
13
- conv2_channels,
14
- kernel_size,
15
  dropout,
16
  fc_dim,
17
  learning_rate,
 
18
  batch_size,
19
  epochs,
 
20
  model_tag,
21
  ):
22
  try:
23
- logs, history, summary, model_name = train_model(
24
- int(conv1_channels),
25
- int(conv2_channels),
26
- int(kernel_size),
27
- float(dropout),
28
- int(fc_dim),
29
- float(learning_rate),
30
- int(batch_size),
31
- int(epochs),
32
- model_tag,
33
  )
34
 
35
  models = list_saved_models()
36
- selected = model_name if model_name in models else (models[0] if models else None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- return logs, history, summary, gr.update(choices=models, value=selected)
39
 
 
 
 
 
 
40
  except Exception as e:
41
- return f"Échec de l’entraînement :\n{str(e)}", None, None, gr.update()
42
 
43
 
44
  @spaces.GPU(duration=60)
@@ -81,75 +148,182 @@ initial_models = list_saved_models()
81
  with gr.Blocks(title="Classification d’images microscopiques") as demo:
82
  gr.Markdown("# Classification d’images microscopiques de charbons de bois")
83
  gr.Markdown(
84
- "Cette application permet d’entraîner un réseau de neurones convolutif simple "
85
- "sur un jeu de données privé Hugging Face, puis de tester les modèles sauvegardés "
86
- "sur une image importée ou sur un échantillon aléatoire."
87
  )
88
 
89
  with gr.Tabs():
90
- with gr.Tab("Entraîner"):
 
 
 
91
  with gr.Row():
92
- with gr.Column():
93
- gr.Markdown("### Paramètres d’entraînement")
94
 
95
- conv1_channels = gr.Slider(
96
- 8, 64, value=16, step=8, label="Nombre de canaux - couche convolutionnelle 1"
97
- )
98
- conv2_channels = gr.Slider(
99
- 16, 128, value=32, step=16, label="Nombre de canaux - couche convolutionnelle 2"
100
- )
101
- kernel_size = gr.Dropdown(
102
- choices=[3, 5], value=3, label="Taille du noyau"
103
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  dropout = gr.Slider(
105
- 0.0, 0.7, value=0.2, step=0.05, label="Dropout"
 
 
 
 
106
  )
107
- fc_dim = gr.Slider(
108
- 32, 256, value=128, step=32, label="Dimension de la couche cachée fully-connected"
 
 
109
  )
110
  learning_rate = gr.Number(
111
- value=0.001, label="Taux d’apprentissage"
 
 
 
 
 
112
  )
113
  batch_size = gr.Dropdown(
114
- choices=[16, 32, 64, 128], value=32, label="Taille du batch"
 
 
115
  )
116
  epochs = gr.Slider(
117
- 1, 20, value=5, step=1, label="Nombre d’époques"
 
 
 
 
 
 
 
 
118
  )
119
  model_tag = gr.Textbox(
120
  label="Nom court du modèle",
121
- placeholder="ex. charbon_cnn_test"
122
  )
123
 
124
  train_btn = gr.Button("Lancer l’entraînement", variant="primary")
125
 
126
  with gr.Column():
127
- train_status = gr.Textbox(label="Journal d’entraînement", lines=18)
 
 
 
128
  train_history = gr.JSON(label="Historique d’entraînement")
129
- train_summary = gr.JSON(label="Résumé d’entraînement")
 
 
130
 
131
- with gr.Tab("Tester"):
132
  with gr.Row():
133
- with gr.Column():
134
- gr.Markdown("### Modèle sauvegardé")
 
 
 
 
 
 
 
 
135
 
 
 
 
 
 
 
 
 
 
 
 
136
  model_selector = gr.Dropdown(
137
  choices=initial_models,
138
  value=initial_models[0] if initial_models else None,
139
- label="Sélectionner un modèle",
140
  )
141
  refresh_btn = gr.Button("Actualiser la liste des modèles")
142
  load_info_btn = gr.Button("Afficher les informations du modèle")
143
  model_info = gr.JSON(label="Métadonnées du modèle")
144
 
145
  with gr.Column():
146
- gr.Markdown("### Prédiction sur une image importée")
 
 
 
 
 
 
 
 
 
 
 
147
 
 
 
 
 
 
 
 
 
 
 
148
  upload_image = gr.Image(type="pil", label="Importer une image")
149
  predict_btn = gr.Button("Prédire la classe", variant="primary")
 
150
  predict_text = gr.Textbox(label="Résultat de la prédiction", lines=7)
151
  predict_probs = gr.Label(label="Probabilités par classe")
152
 
 
 
153
  with gr.Row():
154
  random_test_btn = gr.Button("Tester un échantillon aléatoire")
155
 
@@ -158,20 +332,39 @@ with gr.Blocks(title="Classification d’images microscopiques") as demo:
158
  random_sample_text = gr.Textbox(label="Résultat sur l’échantillon", lines=7)
159
  random_sample_probs = gr.Label(label="Probabilités par classe")
160
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  train_btn.click(
162
  fn=train_callback,
163
  inputs=[
164
- conv1_channels,
165
- conv2_channels,
166
- kernel_size,
167
  dropout,
168
  fc_dim,
169
  learning_rate,
 
170
  batch_size,
171
  epochs,
 
172
  model_tag,
173
  ],
174
- outputs=[train_status, train_history, train_summary, model_selector],
 
 
 
 
 
 
 
 
175
  )
176
 
177
  refresh_btn.click(
@@ -186,6 +379,17 @@ with gr.Blocks(title="Classification d’images microscopiques") as demo:
186
  outputs=model_info,
187
  )
188
 
 
 
 
 
 
 
 
 
 
 
 
189
  predict_btn.click(
190
  fn=predict_uploaded_image_callback,
191
  inputs=[model_selector, upload_image],
@@ -200,4 +404,4 @@ with gr.Blocks(title="Classification d’images microscopiques") as demo:
200
 
201
 
202
  if __name__ == "__main__":
203
- demo.launch()
 
1
  import json
2
 
 
3
  import gradio as gr
4
+ import spaces
5
 
6
+ from data_utils import (
7
+ dataset_overview,
8
+ get_class_names,
9
+ get_images_for_gallery,
10
+ )
11
+ from train_utils import (
12
+ train_model,
13
+ list_saved_models,
14
+ model_meta_path,
15
+ evaluate_saved_model,
16
+ )
17
+ from predict_utils import (
18
+ predict_uploaded_image,
19
+ test_random_sample,
20
+ )
21
+
22
+
23
+ def load_dataset_overview_callback():
24
+ try:
25
+ summary, distribution_df = dataset_overview()
26
+ class_names = ["Toutes les classes"] + get_class_names()
27
 
28
+ return (
29
+ summary,
30
+ distribution_df,
31
+ gr.update(choices=class_names, value="Toutes les classes"),
32
+ )
33
 
34
+ except Exception as e:
35
+ return (
36
+ {"Erreur": str(e)},
37
+ None,
38
+ gr.update(),
39
+ )
40
+
41
+
42
+ def refresh_gallery_callback(split_name, class_name, max_images):
43
+ try:
44
+ gallery = get_images_for_gallery(
45
+ split_name=split_name,
46
+ class_name=class_name,
47
+ max_images=int(max_images),
48
+ )
49
+ return gallery
50
+ except Exception as e:
51
+ return [(None, f"Erreur : {str(e)}")]
52
+
53
+
54
+ @spaces.GPU(duration=300)
55
  def train_callback(
 
 
 
56
  dropout,
57
  fc_dim,
58
  learning_rate,
59
+ weight_decay,
60
  batch_size,
61
  epochs,
62
+ freeze_backbone,
63
  model_tag,
64
  ):
65
  try:
66
+ result = train_model(
67
+ dropout=float(dropout),
68
+ fc_dim=int(fc_dim),
69
+ learning_rate=float(learning_rate),
70
+ weight_decay=float(weight_decay),
71
+ batch_size=int(batch_size),
72
+ epochs=int(epochs),
73
+ freeze_backbone=bool(freeze_backbone),
74
+ model_tag=model_tag,
 
75
  )
76
 
77
  models = list_saved_models()
78
+ selected = result["model_name"] if result["model_name"] in models else None
79
+
80
+ return (
81
+ result["logs"],
82
+ result["history"],
83
+ result["summary"],
84
+ result["classification_report"],
85
+ result["confusion_matrix"],
86
+ result["confusion_matrix_path"],
87
+ gr.update(choices=models, value=selected),
88
+ )
89
+
90
+ except Exception as e:
91
+ return (
92
+ f"Échec de l’entraînement :\n{str(e)}",
93
+ None,
94
+ None,
95
+ None,
96
+ None,
97
+ None,
98
+ gr.update(),
99
+ )
100
 
 
101
 
102
+ @spaces.GPU(duration=120)
103
+ def evaluate_saved_model_callback(model_name):
104
+ try:
105
+ summary, report_df, cm_df, cm_path = evaluate_saved_model(model_name)
106
+ return summary, report_df, cm_df, cm_path
107
  except Exception as e:
108
+ return {"Erreur": str(e)}, None, None, None
109
 
110
 
111
  @spaces.GPU(duration=60)
 
148
  with gr.Blocks(title="Classification d’images microscopiques") as demo:
149
  gr.Markdown("# Classification d’images microscopiques de charbons de bois")
150
  gr.Markdown(
151
+ "Application pédagogique pour explorer un jeu de données d’images microscopiques, "
152
+ "entraîner un modèle de classification et analyser ses performances."
 
153
  )
154
 
155
  with gr.Tabs():
156
+
157
+ with gr.Tab("1. Explorer le jeu de données"):
158
+ gr.Markdown("## Comprendre le jeu de données avant l’entraînement")
159
+
160
  with gr.Row():
161
+ load_dataset_btn = gr.Button("Charger les informations du dataset", variant="primary")
 
162
 
163
+ with gr.Row():
164
+ dataset_summary = gr.JSON(label="Résumé général du dataset")
165
+
166
+ with gr.Row():
167
+ class_distribution = gr.Dataframe(
168
+ label="Distribution des images par split et par classe",
169
+ interactive=False,
170
+ )
171
+
172
+ gr.Markdown("## Visualisation des images")
173
+
174
+ with gr.Row():
175
+ split_selector = gr.Dropdown(
176
+ choices=["train", "validation", "test"],
177
+ value="train",
178
+ label="Split",
179
+ )
180
+ class_selector = gr.Dropdown(
181
+ choices=["Toutes les classes"],
182
+ value="Toutes les classes",
183
+ label="Classe",
184
+ )
185
+ max_images = gr.Slider(
186
+ minimum=4,
187
+ maximum=48,
188
+ value=24,
189
+ step=4,
190
+ label="Nombre d’images à afficher",
191
+ )
192
+ refresh_gallery_btn = gr.Button("Afficher des exemples")
193
+
194
+ image_gallery = gr.Gallery(
195
+ label="Exemples d’images",
196
+ columns=4,
197
+ height=600,
198
+ )
199
+
200
+ with gr.Tab("2. Entraîner un modèle"):
201
+ gr.Markdown("## Entraînement avec ResNet18 pré-entraîné")
202
+ gr.Markdown(
203
+ "Le modèle utilise un backbone ResNet18 pré-entraîné sur ImageNet. "
204
+ "Pour limiter le surapprentissage sur un petit dataset, il est recommandé de commencer "
205
+ "avec le backbone gelé."
206
+ )
207
+
208
+ with gr.Row():
209
+ with gr.Column():
210
  dropout = gr.Slider(
211
+ 0.0,
212
+ 0.8,
213
+ value=0.5,
214
+ step=0.05,
215
+ label="Dropout",
216
  )
217
+ fc_dim = gr.Dropdown(
218
+ choices=[64, 128, 256, 512],
219
+ value=256,
220
+ label="Dimension de la couche cachée",
221
  )
222
  learning_rate = gr.Number(
223
+ value=0.0001,
224
+ label="Taux d’apprentissage",
225
+ )
226
+ weight_decay = gr.Number(
227
+ value=0.0001,
228
+ label="Weight decay",
229
  )
230
  batch_size = gr.Dropdown(
231
+ choices=[8, 16, 32, 64],
232
+ value=16,
233
+ label="Taille du batch",
234
  )
235
  epochs = gr.Slider(
236
+ 1,
237
+ 50,
238
+ value=10,
239
+ step=1,
240
+ label="Nombre d’époques",
241
+ )
242
+ freeze_backbone = gr.Checkbox(
243
+ value=True,
244
+ label="Geler le backbone ResNet18",
245
  )
246
  model_tag = gr.Textbox(
247
  label="Nom court du modèle",
248
+ placeholder="ex. charbon_resnet18_test",
249
  )
250
 
251
  train_btn = gr.Button("Lancer l’entraînement", variant="primary")
252
 
253
  with gr.Column():
254
+ train_status = gr.Textbox(
255
+ label="Journal d’entraînement",
256
+ lines=18,
257
+ )
258
  train_history = gr.JSON(label="Historique d’entraînement")
259
+ train_summary = gr.JSON(label="Résumé final")
260
+
261
+ gr.Markdown("## Résultats sur le test set")
262
 
 
263
  with gr.Row():
264
+ train_report = gr.Dataframe(
265
+ label="Rapport de classification",
266
+ interactive=False,
267
+ )
268
+
269
+ with gr.Row():
270
+ train_confusion_matrix = gr.Dataframe(
271
+ label="Matrice de confusion",
272
+ interactive=False,
273
+ )
274
 
275
+ with gr.Row():
276
+ train_confusion_matrix_image = gr.Image(
277
+ label="Matrice de confusion - figure",
278
+ type="filepath",
279
+ )
280
+
281
+ with gr.Tab("3. Tester et analyser un modèle"):
282
+ gr.Markdown("## Sélectionner un modèle sauvegardé")
283
+
284
+ with gr.Row():
285
+ with gr.Column():
286
  model_selector = gr.Dropdown(
287
  choices=initial_models,
288
  value=initial_models[0] if initial_models else None,
289
+ label="Modèle sauvegardé",
290
  )
291
  refresh_btn = gr.Button("Actualiser la liste des modèles")
292
  load_info_btn = gr.Button("Afficher les informations du modèle")
293
  model_info = gr.JSON(label="Métadonnées du modèle")
294
 
295
  with gr.Column():
296
+ evaluate_btn = gr.Button("Évaluer le modèle sur le test set", variant="primary")
297
+ eval_summary = gr.JSON(label="Résumé des métriques")
298
+ eval_report = gr.Dataframe(
299
+ label="Rapport de classification",
300
+ interactive=False,
301
+ )
302
+
303
+ with gr.Row():
304
+ eval_confusion_matrix = gr.Dataframe(
305
+ label="Matrice de confusion",
306
+ interactive=False,
307
+ )
308
 
309
+ with gr.Row():
310
+ eval_confusion_matrix_image = gr.Image(
311
+ label="Matrice de confusion - figure",
312
+ type="filepath",
313
+ )
314
+
315
+ gr.Markdown("## Prédiction sur une image importée")
316
+
317
+ with gr.Row():
318
+ with gr.Column():
319
  upload_image = gr.Image(type="pil", label="Importer une image")
320
  predict_btn = gr.Button("Prédire la classe", variant="primary")
321
+ with gr.Column():
322
  predict_text = gr.Textbox(label="Résultat de la prédiction", lines=7)
323
  predict_probs = gr.Label(label="Probabilités par classe")
324
 
325
+ gr.Markdown("## Test sur un échantillon aléatoire du test set")
326
+
327
  with gr.Row():
328
  random_test_btn = gr.Button("Tester un échantillon aléatoire")
329
 
 
332
  random_sample_text = gr.Textbox(label="Résultat sur l’échantillon", lines=7)
333
  random_sample_probs = gr.Label(label="Probabilités par classe")
334
 
335
+ load_dataset_btn.click(
336
+ fn=load_dataset_overview_callback,
337
+ inputs=None,
338
+ outputs=[dataset_summary, class_distribution, class_selector],
339
+ )
340
+
341
+ refresh_gallery_btn.click(
342
+ fn=refresh_gallery_callback,
343
+ inputs=[split_selector, class_selector, max_images],
344
+ outputs=image_gallery,
345
+ )
346
+
347
  train_btn.click(
348
  fn=train_callback,
349
  inputs=[
 
 
 
350
  dropout,
351
  fc_dim,
352
  learning_rate,
353
+ weight_decay,
354
  batch_size,
355
  epochs,
356
+ freeze_backbone,
357
  model_tag,
358
  ],
359
+ outputs=[
360
+ train_status,
361
+ train_history,
362
+ train_summary,
363
+ train_report,
364
+ train_confusion_matrix,
365
+ train_confusion_matrix_image,
366
+ model_selector,
367
+ ],
368
  )
369
 
370
  refresh_btn.click(
 
379
  outputs=model_info,
380
  )
381
 
382
+ evaluate_btn.click(
383
+ fn=evaluate_saved_model_callback,
384
+ inputs=model_selector,
385
+ outputs=[
386
+ eval_summary,
387
+ eval_report,
388
+ eval_confusion_matrix,
389
+ eval_confusion_matrix_image,
390
+ ],
391
+ )
392
+
393
  predict_btn.click(
394
  fn=predict_uploaded_image_callback,
395
  inputs=[model_selector, upload_image],
 
404
 
405
 
406
  if __name__ == "__main__":
407
+ demo.launch(ssr_mode=False)