Aurel-test commited on
Commit
1f5d8b4
1 Parent(s): ede1f67

Add DocString

Browse files
Files changed (1) hide show
  1. app.py +136 -58
app.py CHANGED
@@ -37,9 +37,9 @@ id2label = {
37
  }
38
  label2id = {v: k for k, v in id2label.items()}
39
  num_labels = len(id2label)
40
- checkpoint = "nvidia/segformer-b4-finetuned-cityscapes-1024-1024"
41
- image_processor = SegformerImageProcessor()
42
- state_dict_path = f"runs/{checkpoint}_v1/best_model.pt"
43
  model = SegformerForSemanticSegmentation.from_pretrained(
44
  checkpoint,
45
  num_labels=num_labels,
@@ -58,6 +58,17 @@ model.eval()
58
 
59
 
60
  def load_and_prepare_images(image_name, segformer=False):
 
 
 
 
 
 
 
 
 
 
 
61
  image_path = os.path.join(data_folder, "images", image_name)
62
  mask_name = image_name.replace("_leftImg8bit.png", "_gtFine_labelIds.png")
63
  mask_path = os.path.join(data_folder, "masks", mask_name)
@@ -82,35 +93,47 @@ def load_and_prepare_images(image_name, segformer=False):
82
 
83
 
84
  def predict_segmentation(image):
85
- # Charger et préparer l'image
86
- inputs = image_processor(images=image, return_tensors="pt")
87
 
88
- # Utiliser GPU si disponible
 
 
 
 
 
 
 
89
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
90
  model.to(device)
91
-
92
- # Déplacer les inputs sur le bon device et faire la prédiction
93
  pixel_values = inputs.pixel_values.to(device)
94
 
95
- with torch.no_grad(): # Désactiver le calcul des gradients pour l'inférence
96
  outputs = model(pixel_values=pixel_values)
97
  logits = outputs.logits
98
 
99
- # Redimensionner les logits à la taille de l'image d'origine
100
  upsampled_logits = nn.functional.interpolate(
101
  logits,
102
  size=image.size[::-1], # (height, width)
103
  mode="bilinear",
104
  align_corners=False,
105
  )
106
-
107
- # Obtenir la prédiction finale
108
  pred_seg = upsampled_logits.argmax(dim=1)[0].cpu().numpy()
109
 
110
  return pred_seg
111
 
112
 
113
  def process_image(image_name):
 
 
 
 
 
 
 
 
 
 
114
  original, true_mask, fpn_pred, segformer_pred = load_and_prepare_images(
115
  image_name, segformer=True
116
  )
@@ -131,6 +154,12 @@ def process_image(image_name):
131
 
132
 
133
  def create_cityscapes_label_colormap():
 
 
 
 
 
 
134
  colormap = np.zeros((256, 3), dtype=np.uint8)
135
  colormap[0] = [78, 82, 110]
136
  colormap[1] = [128, 64, 128]
@@ -147,68 +176,43 @@ def create_cityscapes_label_colormap():
147
  cityscapes_colormap = create_cityscapes_label_colormap()
148
 
149
 
150
- def blend_images(original_image, colored_segmentation, alpha=0.6):
151
- blended_image = Image.blend(original_image, colored_segmentation, alpha)
152
- return blended_image
153
-
154
-
155
  def colorize_mask(mask):
156
  return cityscapes_colormap[mask]
157
 
158
 
159
  # ---- Fin Partie Segmentation
160
 
161
- # def compare_masks(real_mask, fpn_mask, segformer_mask):
162
- # """
163
- # Compare les masques prédits par FPN et SegFormer avec le masque réel.
164
- # Retourne un score IoU et une précision pixel par pixel pour chaque modèle.
165
-
166
- # Args:
167
- # real_mask (np.array): Le masque réel de référence
168
- # fpn_mask (np.array): Le masque prédit par le modèle FPN
169
- # segformer_mask (np.array): Le masque prédit par le modèle SegFormer
170
-
171
- # Returns:
172
- # dict: Dictionnaire contenant les scores IoU et les précisions pour chaque modèle
173
- # """
174
-
175
- # assert real_mask.shape == fpn_mask.shape == segformer_mask.shape, "Les masques doivent avoir la même forme"
176
-
177
- # real_flat = real_mask.flatten()
178
- # fpn_flat = fpn_mask.flatten()
179
- # segformer_flat = segformer_mask.flatten()
180
-
181
- # # Calcul du score de Jaccard (IoU)
182
- # iou_fpn = jaccard_score(real_flat, fpn_flat, average='weighted')
183
- # iou_segformer = jaccard_score(real_flat, segformer_flat, average='weighted')
184
-
185
- # # Calcul de la précision pixel par pixel
186
- # accuracy_fpn = accuracy_score(real_flat, fpn_flat)
187
- # accuracy_segformer = accuracy_score(real_flat, segformer_flat)
188
-
189
- # return {
190
- # 'FPN': {'IoU': iou_fpn, 'Precision': accuracy_fpn},
191
- # 'SegFormer': {'IoU': iou_segformer, 'Precision': accuracy_segformer}
192
- # }
193
-
194
  # ---- Partie EDA
195
 
196
 
197
  def analyse_mask(real_mask, num_labels):
198
- # Compter les occurrences de chaque classe
199
- counts = np.bincount(real_mask.ravel(), minlength=num_labels)
200
 
201
- # Calculer le nombre total de pixels
202
- total_pixels = real_mask.size
 
203
 
204
- # Calculer les proportions
 
 
 
 
205
  class_proportions = counts / total_pixels
206
-
207
- # Créer un dictionnaire avec les proportions
208
  return dict(enumerate(class_proportions))
209
 
210
 
211
  def show_eda(image_name):
 
 
 
 
 
 
 
 
 
 
212
  original_image, true_mask, _ = load_and_prepare_images(image_name)
213
  class_proportions = analyse_mask(true_mask, num_labels)
214
  cityscapes_colormap = create_cityscapes_label_colormap()
@@ -266,17 +270,54 @@ def show_eda(image_name):
266
 
267
 
268
  class SegformerWrapper(nn.Module):
 
 
 
 
 
 
 
269
  def __init__(self, model):
 
 
 
 
 
 
270
  super().__init__()
271
  self.model = model
272
 
273
  def forward(self, x):
 
 
 
 
 
 
 
 
 
274
  output = self.model(x)
275
  return output.logits
276
 
277
 
278
  class SemanticSegmentationTarget:
 
 
 
 
 
 
 
 
279
  def __init__(self, category, mask):
 
 
 
 
 
 
 
280
  self.category = category
281
  self.mask = torch.from_numpy(mask)
282
  if torch.cuda.is_available():
@@ -305,12 +346,33 @@ class SemanticSegmentationTarget:
305
 
306
 
307
  def segformer_reshape_transform_huggingface(tensor, width, height):
 
 
 
 
 
 
 
 
 
 
 
308
  result = tensor.reshape(tensor.size(0), height, width, tensor.size(2))
309
  result = result.transpose(2, 3).transpose(1, 2)
310
  return result
311
 
312
 
313
  def explain_model(image_name, category_name):
 
 
 
 
 
 
 
 
 
 
314
  original_image, _, _ = load_and_prepare_images(image_name)
315
  rgb_img = np.float32(original_image) / 255
316
  img_tensor = transforms.ToTensor()(rgb_img)
@@ -379,6 +441,12 @@ import random
379
 
380
 
381
  def change_image():
 
 
 
 
 
 
382
  image_dir = (
383
  "data_sample/images" # Remplacez par le chemin de votre dossier d'images
384
  )
@@ -388,6 +456,16 @@ def change_image():
388
 
389
 
390
  def apply_augmentation(image, augmentation_names):
 
 
 
 
 
 
 
 
 
 
391
  augmentations = {
392
  "Horizontal Flip": A.HorizontalFlip(p=1),
393
  "Shift Scale Rotate": A.ShiftScaleRotate(p=1),
@@ -541,4 +619,4 @@ with gr.Blocks(title="Preuve de concept", theme=my_theme) as demo:
541
 
542
 
543
  # Lancer l'application
544
- demo.launch(favicon_path="favicon.ico", share=True)
 
37
  }
38
  label2id = {v: k for k, v in id2label.items()}
39
  num_labels = len(id2label)
40
+ checkpoint = "nvidia/segformer-b3-finetuned-cityscapes-1024-1024"
41
+ image_processor = SegformerImageProcessor(do_resize=False)
42
+ state_dict_path = f"runs/{checkpoint}/best_model.pt"
43
  model = SegformerForSemanticSegmentation.from_pretrained(
44
  checkpoint,
45
  num_labels=num_labels,
 
58
 
59
 
60
  def load_and_prepare_images(image_name, segformer=False):
61
+ """
62
+ Charge et prépare les images, les masques et les prédictions associées pour une image donnée.
63
+
64
+ Args:
65
+ image_name (str): Le nom du fichier de l'image à charger.
66
+ segformer (bool, optional): Si True, prédit également le masque avec SegFormer. Par défaut False.
67
+
68
+ Returns:
69
+ tuple: Contient l'image originale redimensionnée, le masque réel, la prédiction FPN,
70
+ et la prédiction SegFormer si `segformer` est True.
71
+ """
72
  image_path = os.path.join(data_folder, "images", image_name)
73
  mask_name = image_name.replace("_leftImg8bit.png", "_gtFine_labelIds.png")
74
  mask_path = os.path.join(data_folder, "masks", mask_name)
 
93
 
94
 
95
  def predict_segmentation(image):
96
+ """
97
+ Prédit la segmentation d'une image donnée à l'aide d'un modèle pré-entraîné.
98
 
99
+ Args:
100
+ image (PIL.Image.Image): L'image à segmenter.
101
+
102
+ Returns:
103
+ numpy.ndarray: La carte de segmentation prédite.
104
+ """
105
+
106
+ inputs = image_processor(images=image, return_tensors="pt")
107
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
108
  model.to(device)
 
 
109
  pixel_values = inputs.pixel_values.to(device)
110
 
111
+ with torch.no_grad():
112
  outputs = model(pixel_values=pixel_values)
113
  logits = outputs.logits
114
 
 
115
  upsampled_logits = nn.functional.interpolate(
116
  logits,
117
  size=image.size[::-1], # (height, width)
118
  mode="bilinear",
119
  align_corners=False,
120
  )
 
 
121
  pred_seg = upsampled_logits.argmax(dim=1)[0].cpu().numpy()
122
 
123
  return pred_seg
124
 
125
 
126
  def process_image(image_name):
127
+ """
128
+ Traite une image en chargeant l'image originale, le masque réel, et les prédictions de masques.
129
+ Envoie la liste de tuple à l'interface "Predictions" de Gradio
130
+
131
+ Args:
132
+ image_name (str): Le nom de l'image à traiter.
133
+
134
+ Returns:
135
+ list: Une liste de tuples contenant l'image et son titre associé.
136
+ """
137
  original, true_mask, fpn_pred, segformer_pred = load_and_prepare_images(
138
  image_name, segformer=True
139
  )
 
154
 
155
 
156
  def create_cityscapes_label_colormap():
157
+ """
158
+ Crée une colormap pour les labels Cityscapes.
159
+
160
+ Returns:
161
+ numpy.ndarray: Un tableau 2D où chaque ligne représente la couleur RGB d'un label.
162
+ """
163
  colormap = np.zeros((256, 3), dtype=np.uint8)
164
  colormap[0] = [78, 82, 110]
165
  colormap[1] = [128, 64, 128]
 
176
  cityscapes_colormap = create_cityscapes_label_colormap()
177
 
178
 
 
 
 
 
 
179
  def colorize_mask(mask):
180
  return cityscapes_colormap[mask]
181
 
182
 
183
  # ---- Fin Partie Segmentation
184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  # ---- Partie EDA
186
 
187
 
188
  def analyse_mask(real_mask, num_labels):
189
+ """
190
+ Analyse la distribution des classes dans un masque réel.
191
 
192
+ Args:
193
+ real_mask (numpy.ndarray): Le masque de labels réels.
194
+ num_labels (int): Le nombre total de classes.
195
 
196
+ Returns:
197
+ dict: Un dictionnaire contenant les proportions des classes dans le masque.
198
+ """
199
+ counts = np.bincount(real_mask.ravel(), minlength=num_labels)
200
+ total_pixels = real_mask.size
201
  class_proportions = counts / total_pixels
 
 
202
  return dict(enumerate(class_proportions))
203
 
204
 
205
  def show_eda(image_name):
206
+ """
207
+ Affiche une analyse exploratoire de la distribution des classes pour une image et son masque associé.
208
+
209
+ Args:
210
+ image_name (str): Le nom de l'image à analyser.
211
+
212
+ Returns:
213
+ tuple: Contient l'image originale, le masque réel coloré et une figure Plotly représentant
214
+ la distribution des classes.
215
+ """
216
  original_image, true_mask, _ = load_and_prepare_images(image_name)
217
  class_proportions = analyse_mask(true_mask, num_labels)
218
  cityscapes_colormap = create_cityscapes_label_colormap()
 
270
 
271
 
272
  class SegformerWrapper(nn.Module):
273
+ """
274
+ Un wrapper pour le modèle SegFormer qui renvoie uniquement les logits en sortie.
275
+
276
+ Args:
277
+ model (torch.nn.Module): Le modèle SegFormer pré-entraîné.
278
+ """
279
+
280
  def __init__(self, model):
281
+ """
282
+ Initialise le SegformerWrapper.
283
+
284
+ Args:
285
+ model (torch.nn.Module): Le modèle SegFormer pré-entraîné.
286
+ """
287
  super().__init__()
288
  self.model = model
289
 
290
  def forward(self, x):
291
+ """
292
+ Renvoie les logits du modèle au lieu de renvoyer un dictionnaire.
293
+
294
+ Args:
295
+ x (torch.Tensor): Les entrées du modèle.
296
+
297
+ Returns:
298
+ torch.Tensor: Les logits du modèle.
299
+ """
300
  output = self.model(x)
301
  return output.logits
302
 
303
 
304
  class SemanticSegmentationTarget:
305
+ """
306
+ Représente une classe cible pour la segmentation sémantique utilisée dans GradCAM.
307
+
308
+ Args:
309
+ category (int): L'index de la catégorie cible.
310
+ mask (numpy.ndarray): Le masque binaire indiquant les pixels d'intérêt.
311
+ """
312
+
313
  def __init__(self, category, mask):
314
+ """
315
+ Initialise la cible de segmentation sémantique.
316
+
317
+ Args:
318
+ category (int): L'index de la catégorie cible.
319
+ mask (numpy.ndarray): Le masque binaire indiquant les pixels d'intérêt.
320
+ """
321
  self.category = category
322
  self.mask = torch.from_numpy(mask)
323
  if torch.cuda.is_available():
 
346
 
347
 
348
  def segformer_reshape_transform_huggingface(tensor, width, height):
349
+ """
350
+ Réorganise les dimensions du tenseur pour qu'elles correspondent au format attendu par GradCAM.
351
+
352
+ Args:
353
+ tensor (torch.Tensor): Le tenseur à réorganiser.
354
+ width (int): La nouvelle largeur.
355
+ height (int): La nouvelle hauteur.
356
+
357
+ Returns:
358
+ torch.Tensor: Le tenseur réorganisé.
359
+ """
360
  result = tensor.reshape(tensor.size(0), height, width, tensor.size(2))
361
  result = result.transpose(2, 3).transpose(1, 2)
362
  return result
363
 
364
 
365
  def explain_model(image_name, category_name):
366
+ """
367
+ Explique les prédictions du modèle SegFormer en utilisant GradCAM pour une image et une catégorie données.
368
+
369
+ Args:
370
+ image_name (str): Le nom de l'image à expliquer.
371
+ category_name (str): Le nom de la catégorie cible.
372
+
373
+ Returns:
374
+ matplotlib.figure.Figure: Une figure matplotlib contenant la carte de chaleur GradCAM superposée sur l'image originale.
375
+ """
376
  original_image, _, _ = load_and_prepare_images(image_name)
377
  rgb_img = np.float32(original_image) / 255
378
  img_tensor = transforms.ToTensor()(rgb_img)
 
441
 
442
 
443
  def change_image():
444
+ """
445
+ Sélectionne et charge aléatoirement une image depuis un dossier spécifié.
446
+
447
+ Returns:
448
+ PIL.Image.Image: L'image sélectionnée.
449
+ """
450
  image_dir = (
451
  "data_sample/images" # Remplacez par le chemin de votre dossier d'images
452
  )
 
456
 
457
 
458
  def apply_augmentation(image, augmentation_names):
459
+ """
460
+ Applique une ou plusieurs augmentations à une image.
461
+
462
+ Args:
463
+ image (PIL.Image.Image): L'image à augmenter.
464
+ augmentation_names (list of str): Les noms des augmentations à appliquer.
465
+
466
+ Returns:
467
+ PIL.Image.Image: L'image augmentée.
468
+ """
469
  augmentations = {
470
  "Horizontal Flip": A.HorizontalFlip(p=1),
471
  "Shift Scale Rotate": A.ShiftScaleRotate(p=1),
 
619
 
620
 
621
  # Lancer l'application
622
+ demo.launch(favicon_path="favicon.ico")