import gradio as gr import torch from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor from PIL import Image import plotly.graph_objects as go import numpy as np import os import torch.nn as nn from sklearn.metrics import jaccard_score, accuracy_score from collections import Counter import matplotlib.pyplot as plt import seaborn as sns import torch.nn.functional as F import seaborn as sns from functools import partial from pytorch_grad_cam.utils.image import ( show_cam_on_image, preprocess_image as grad_preprocess, ) from pytorch_grad_cam import GradCAM import cv2 import transformers from torchvision import transforms import albumentations as A device = "cuda" if torch.cuda.is_available() else "cpu" data_folder = "data_sample" id2label = { 0: "void", 1: "flat", 2: "construction", 3: "object", 4: "nature", 5: "sky", 6: "human", 7: "vehicle", } label2id = {v: k for k, v in id2label.items()} num_labels = len(id2label) checkpoint = "nvidia/segformer-b3-finetuned-cityscapes-1024-1024" image_processor = SegformerImageProcessor(do_resize=False) state_dict_path = f"runs/{checkpoint}/best_model.pt" model = SegformerForSemanticSegmentation.from_pretrained( checkpoint, num_labels=num_labels, id2label=id2label, label2id=label2id, ignore_mismatched_sizes=True, ) loaded_state_dict = torch.load( state_dict_path, map_location=torch.device("cpu"), weights_only=True ) model.load_state_dict(loaded_state_dict) model = model.to(device) model.eval() # ---- Partie Segmentation def load_and_prepare_images(image_name, segformer=False): """ Charge et prépare les images, les masques et les prédictions associées pour une image donnée. Args: image_name (str): Le nom du fichier de l'image à charger. segformer (bool, optional): Si True, prédit également le masque avec SegFormer. Par défaut False. Returns: tuple: Contient l'image originale redimensionnée, le masque réel, la prédiction FPN, et la prédiction SegFormer si `segformer` est True. """ image_path = os.path.join(data_folder, "images", image_name) mask_name = image_name.replace("_leftImg8bit.png", "_gtFine_labelIds.png") mask_path = os.path.join(data_folder, "masks", mask_name) fpn_pred_path = os.path.join(data_folder, "resnet101_mask", image_name) if not os.path.exists(image_path): raise FileNotFoundError(f"Image not found: {image_path}") if not os.path.exists(mask_path): raise FileNotFoundError(f"Mask not found: {mask_path}") if not os.path.exists(fpn_pred_path): raise FileNotFoundError(f"FPN prediction not found: {fpn_pred_path}") original_image = Image.open(image_path).convert("RGB") original = original_image.resize((1024, 512)) true_mask = np.array(Image.open(mask_path)) fpn_pred = np.array(Image.open(fpn_pred_path)) if segformer: segformer_pred = predict_segmentation(original) return original, true_mask, fpn_pred, segformer_pred return original, true_mask, fpn_pred def predict_segmentation(image): """ Prédit la segmentation d'une image donnée à l'aide d'un modèle pré-entraîné. Args: image (PIL.Image.Image): L'image à segmenter. Returns: numpy.ndarray: La carte de segmentation prédite. """ inputs = image_processor(images=image, return_tensors="pt") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) pixel_values = inputs.pixel_values.to(device) with torch.no_grad(): outputs = model(pixel_values=pixel_values) logits = outputs.logits upsampled_logits = nn.functional.interpolate( logits, size=image.size[::-1], # (height, width) mode="bilinear", align_corners=False, ) pred_seg = upsampled_logits.argmax(dim=1)[0].cpu().numpy() return pred_seg def process_image(image_name): """ Traite une image en chargeant l'image originale, le masque réel, et les prédictions de masques. Envoie la liste de tuple à l'interface "Predictions" de Gradio Args: image_name (str): Le nom de l'image à traiter. Returns: list: Une liste de tuples contenant l'image et son titre associé. """ original, true_mask, fpn_pred, segformer_pred = load_and_prepare_images( image_name, segformer=True ) true_mask_colored = colorize_mask(true_mask) true_mask_colored = Image.fromarray(true_mask_colored.astype("uint8")) true_mask_colored = true_mask_colored.resize((1024, 512)) # fpn_pred_colored = colorize_mask(fpn_pred) segformer_pred_colored = colorize_mask(segformer_pred) segformer_pred_colored = Image.fromarray(segformer_pred_colored.astype("uint8")) segformer_pred_colored = segformer_pred_colored.resize((1024, 512)) return [ (original, "Image originale"), (true_mask_colored, "Masque réel"), (fpn_pred, "Prédiction FPN"), (segformer_pred_colored, "Prédiction SegFormer"), ] def create_cityscapes_label_colormap(): """ Crée une colormap pour les labels Cityscapes. Returns: numpy.ndarray: Un tableau 2D où chaque ligne représente la couleur RGB d'un label. """ colormap = np.zeros((256, 3), dtype=np.uint8) colormap[0] = [78, 82, 110] colormap[1] = [128, 64, 128] colormap[2] = [154, 156, 153] colormap[3] = [168, 167, 18] colormap[4] = [80, 108, 28] colormap[5] = [112, 164, 196] colormap[6] = [168, 28, 52] colormap[7] = [16, 18, 112] return colormap # Créer la colormap une fois cityscapes_colormap = create_cityscapes_label_colormap() def colorize_mask(mask): return cityscapes_colormap[mask] # ---- Fin Partie Segmentation # ---- Partie EDA def analyse_mask(real_mask, num_labels): """ Analyse la distribution des classes dans un masque réel. Args: real_mask (numpy.ndarray): Le masque de labels réels. num_labels (int): Le nombre total de classes. Returns: dict: Un dictionnaire contenant les proportions des classes dans le masque. """ counts = np.bincount(real_mask.ravel(), minlength=num_labels) total_pixels = real_mask.size class_proportions = counts / total_pixels return dict(enumerate(class_proportions)) def show_eda(image_name): """ Affiche une analyse exploratoire de la distribution des classes pour une image et son masque associé. Args: image_name (str): Le nom de l'image à analyser. Returns: tuple: Contient l'image originale, le masque réel coloré et une figure Plotly représentant la distribution des classes. """ original_image, true_mask, _ = load_and_prepare_images(image_name) class_proportions = analyse_mask(true_mask, num_labels) cityscapes_colormap = create_cityscapes_label_colormap() true_mask_colored = colorize_mask(true_mask) true_mask_colored = Image.fromarray(true_mask_colored.astype("uint8")) true_mask_colored = true_mask_colored.resize((1024, 512)) # Trier les classes par proportion croissante sorted_classes = sorted( class_proportions.keys(), key=lambda x: class_proportions[x] ) # Préparer les données pour le barplot categories = [id2label[i] for i in sorted_classes] values = [class_proportions[i] for i in sorted_classes] color_list = [ f"rgb({cityscapes_colormap[i][0]}, {cityscapes_colormap[i][1]}, {cityscapes_colormap[i][2]})" for i in sorted_classes ] # Distribution des classes avec la colormap personnalisée fig = go.Figure() fig.add_trace( go.Bar( x=categories, y=values, marker_color=color_list, text=[f"{v:.2f}" for v in values], textposition="outside", ) ) # Ajouter un titre et des labels, modifier la rotation et la taille de la police fig.update_layout( title={"text": "Distribution des classes", "font": {"size": 24}}, xaxis_title={"text": "Catégories", "font": {"size": 18}}, yaxis_title={"text": "Proportion", "font": {"size": 18}}, xaxis_tickangle=0, # Rotation modifiée à -45 degrés uniformtext_minsize=12, uniformtext_mode="hide", font=dict(size=14), autosize=True, bargap=0.2, height=600, margin=dict(l=20, r=20, t=50, b=20), ) return original_image, true_mask_colored, fig # ----Fin Partie EDA # ----Partie Explication GradCam class SegformerWrapper(nn.Module): """ Un wrapper pour le modèle SegFormer qui renvoie uniquement les logits en sortie. Args: model (torch.nn.Module): Le modèle SegFormer pré-entraîné. """ def __init__(self, model): """ Initialise le SegformerWrapper. Args: model (torch.nn.Module): Le modèle SegFormer pré-entraîné. """ super().__init__() self.model = model def forward(self, x): """ Renvoie les logits du modèle au lieu de renvoyer un dictionnaire. Args: x (torch.Tensor): Les entrées du modèle. Returns: torch.Tensor: Les logits du modèle. """ output = self.model(x) return output.logits class SemanticSegmentationTarget: """ Représente une classe cible pour la segmentation sémantique utilisée dans GradCAM. Args: category (int): L'index de la catégorie cible. mask (numpy.ndarray): Le masque binaire indiquant les pixels d'intérêt. """ def __init__(self, category, mask): """ Initialise la cible de segmentation sémantique. Args: category (int): L'index de la catégorie cible. mask (numpy.ndarray): Le masque binaire indiquant les pixels d'intérêt. """ self.category = category self.mask = torch.from_numpy(mask) if torch.cuda.is_available(): self.mask = self.mask.cuda() def __call__(self, model_output): if isinstance( model_output, (dict, transformers.modeling_outputs.SemanticSegmenterOutput) ): logits = ( model_output["logits"] if isinstance(model_output, dict) else model_output.logits ) elif isinstance(model_output, torch.Tensor): logits = model_output else: raise ValueError(f"Unexpected model_output type: {type(model_output)}") if logits.dim() == 4: # [batch, classes, height, width] return (logits[0, self.category, :, :] * self.mask).sum() elif logits.dim() == 3: # [classes, height, width] return (logits[self.category, :, :] * self.mask).sum() else: raise ValueError(f"Unexpected logits shape: {logits.shape}") def segformer_reshape_transform_huggingface(tensor, width, height): """ Réorganise les dimensions du tenseur pour qu'elles correspondent au format attendu par GradCAM. Args: tensor (torch.Tensor): Le tenseur à réorganiser. width (int): La nouvelle largeur. height (int): La nouvelle hauteur. Returns: torch.Tensor: Le tenseur réorganisé. """ result = tensor.reshape(tensor.size(0), height, width, tensor.size(2)) result = result.transpose(2, 3).transpose(1, 2) return result def explain_model(image_name, category_name): """ Explique les prédictions du modèle SegFormer en utilisant GradCAM pour une image et une catégorie données. Args: image_name (str): Le nom de l'image à expliquer. category_name (str): Le nom de la catégorie cible. Returns: matplotlib.figure.Figure: Une figure matplotlib contenant la carte de chaleur GradCAM superposée sur l'image originale. """ original_image, _, _ = load_and_prepare_images(image_name) rgb_img = np.float32(original_image) / 255 img_tensor = transforms.ToTensor()(rgb_img) input_tensor = transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] )(img_tensor) input_tensor = input_tensor.unsqueeze(0).to(device) wrapped_model = SegformerWrapper(model).to(device) with torch.no_grad(): output = wrapped_model(input_tensor) upsampled_logits = nn.functional.interpolate( output, size=input_tensor.shape[-2:], mode="bilinear", align_corners=False ) normalized_masks = torch.nn.functional.softmax(upsampled_logits, dim=1).cpu() category = label2id[category_name] mask = normalized_masks[0].argmax(dim=0).numpy() mask_float = np.float32(mask == category) reshape_transform = partial( segformer_reshape_transform_huggingface, # réorganise les dimensions du tenseur pour qu'elles correspondent au format attendu par GradCAM. width=img_tensor.shape[2] // 32, height=img_tensor.shape[1] // 32, ) target_layers = [wrapped_model.model.segformer.encoder.layer_norm[-1]] mask_float_resized = cv2.resize(mask_float, (output.shape[3], output.shape[2])) targets = [SemanticSegmentationTarget(category, mask_float_resized)] cam = GradCAM( model=wrapped_model, target_layers=target_layers, reshape_transform=reshape_transform, ) grayscale_cam = cam(input_tensor=input_tensor, targets=targets) threshold = 0.01 # Seuil de 1% de sureté thresholded_cam = grayscale_cam.copy() thresholded_cam[grayscale_cam < threshold] = 0 if np.max(thresholded_cam) > 0: thresholded_cam = thresholded_cam / np.max(thresholded_cam) else: thresholded_cam = grayscale_cam[0] resized_cam = cv2.resize( thresholded_cam[0], (input_tensor.shape[3], input_tensor.shape[2]) ) masked_cam = resized_cam * mask_float if np.max(masked_cam) > 0: cam_image = show_cam_on_image(rgb_img, masked_cam, use_rgb=True) else: cam_image = original_image fig, ax = plt.subplots(figsize=(15, 10)) ax.imshow(cam_image) ax.axis("off") ax.set_title(f"Masque de chaleur GradCam pour {category_name}", color="white") margin = 0.02 # Adjust this value to change the size of the margin margin_color = "#0a0f1e" fig.subplots_adjust(left=margin, right=1 - margin, top=1 - margin, bottom=margin) fig.patch.set_facecolor(margin_color) plt.close() return fig # ----Fin Partie Explication GradCam # ----Partie Data augmentation import random def change_image(): """ Sélectionne et charge aléatoirement une image depuis un dossier spécifié. Returns: PIL.Image.Image: L'image sélectionnée. """ image_dir = ( "data_sample/images" # Remplacez par le chemin de votre dossier d'images ) image_list = [f for f in os.listdir(image_dir) if f.endswith(".png")] random_image = random.choice(image_list) return Image.open(os.path.join(image_dir, random_image)) def apply_augmentation(image, augmentation_names): """ Applique une ou plusieurs augmentations à une image. Args: image (PIL.Image.Image): L'image à augmenter. augmentation_names (list of str): Les noms des augmentations à appliquer. Returns: PIL.Image.Image: L'image augmentée. """ augmentations = { "Horizontal Flip": A.HorizontalFlip(p=1), "Shift Scale Rotate": A.ShiftScaleRotate(p=1), "Random Brightness Contrast": A.RandomBrightnessContrast(p=1), "RGB Shift": A.RGBShift(p=1), "Blur": A.Blur(blur_limit=(5, 7), p=1), "Gaussian Noise": A.GaussNoise(p=1), "Grid Distortion": A.GridDistortion(p=1), "Random Sun": A.RandomSunFlare(p=1), } image_array = np.array(image) if augmentation_names is not None: selected_augs = [ augmentations[name] for name in augmentation_names if name in augmentations ] compose = A.Compose(selected_augs) # Appliquer la composition d'augmentations augmented = compose(image=image_array) return Image.fromarray(augmented["image"]) else: return image # ---- Fin Partie Data augmentation image_list = [ f for f in os.listdir(os.path.join(data_folder, "images")) if f.endswith(".png") ] category_list = list(id2label.values()) image_name = "dusseldorf_000012_000019_leftImg8bit.png" default_image = os.path.join(data_folder, "images", image_name) my_theme = gr.Theme.from_hub("gstaff/whiteboard") with gr.Blocks(title="Preuve de concept", theme=my_theme) as demo: gr.Markdown("# Projet 10 - Développer une preuve de concept") with gr.Tab("Distribution"): gr.Markdown("## Distribution des classes Cityscapes") gr.Markdown( "### Visualisation de la distribution de chaque classe selon l'image choisie." ) eda_image_input = gr.Dropdown( choices=image_list, label="Sélectionnez une image", ) with gr.Row(): original_image_output = gr.Image(type="pil", label="Image originale") original_mask_output = gr.Image(type="pil", label="Masque original") class_distribution_plot = gr.Plot(label="Distribution des classes") eda_image_input.change( fn=show_eda, inputs=eda_image_input, outputs=[ original_image_output, original_mask_output, class_distribution_plot, ], ) with gr.Tab("Data Augmentation"): gr.Markdown("## Visualisation de l'augmentation des données") gr.Markdown( "### Sélectionnez une ou plusieurs augmentations pour l'appliquer à l'image." ) gr.Markdown("### Vous pouvez également changer d'image.") with gr.Row(): image_display = gr.Image( value=default_image, label="Image", show_download_button=False, interactive=False, ) augmented_image = gr.Image(label="Image Augmentée") with gr.Row(): change_image_button = gr.Button("Changer image") augmentation_dropdown = gr.Dropdown( choices=[ "Horizontal Flip", "Shift Scale Rotate", "Random Brightness Contrast", "RGB Shift", "Blur", "Gaussian Noise", "Grid Distortion", "Random Sun", ], label="Sélectionnez une augmentation", multiselect=True, ) apply_button = gr.Button("Appliquer l'augmentation") change_image_button.click(fn=change_image, outputs=image_display) apply_button.click( fn=apply_augmentation, inputs=[image_display, augmentation_dropdown], outputs=augmented_image, ) with gr.Tab("Prédictions"): gr.Markdown("## Comparaison de segmentations d'images Cityscapes") gr.Markdown( "### Sélectionnez une image pour voir la comparaison entre le masque réel, la prédiction FPN (pré-enregistré) et la prédiction du modèle SegFormer." ) image_input = gr.Dropdown(choices=image_list, label="Sélectionnez une image") gallery_output = gr.Gallery( label="Résultats de segmentation", show_label=True, elem_id="gallery", columns=[2], rows=[2], object_fit="contain", height="512px", min_width="1024px", ) image_input.change(fn=process_image, inputs=image_input, outputs=gallery_output) with gr.Tab("Explication SegFormer"): gr.Markdown("## Explication du modèle SegFormer") gr.Markdown( "### La méthode Grad-CAM est une technique populaire de visualisation qui est utile pour comprendre comment un réseau neuronal convolutif a été conduit à prendre une décision de classification. Elle est spécifique à chaque classe, ce qui signifie qu’elle peut produire une visualisation distincte pour chaque classe présente dans l’image." ) gr.Markdown( "### NB: Si l'image s'affiche sans masque, c'est que le modèle ne trouve pas de zones significatives pour une catégorie donnée." ) with gr.Row(): explain_image_input = gr.Dropdown( choices=image_list, label="Sélectionnez une image" ) explain_category_input = gr.Dropdown( choices=category_list, label="Sélectionnez une catégorie" ) explain_button = gr.Button("Expliquer") explain_output = gr.Plot(label="Explication SegFormer", min_width=200) explain_button.click( fn=explain_model, inputs=[explain_image_input, explain_category_input], outputs=explain_output, ) # Lancer l'application demo.launch(favicon_path="favicon.ico")