Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor | |
from PIL import Image | |
import plotly.graph_objects as go | |
import numpy as np | |
import os | |
import torch.nn as nn | |
from sklearn.metrics import jaccard_score, accuracy_score | |
from collections import Counter | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import torch.nn.functional as F | |
import seaborn as sns | |
from functools import partial | |
from pytorch_grad_cam.utils.image import ( | |
show_cam_on_image, | |
preprocess_image as grad_preprocess, | |
) | |
from pytorch_grad_cam import GradCAM | |
import cv2 | |
import transformers | |
from torchvision import transforms | |
import albumentations as A | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
data_folder = "data_sample" | |
id2label = { | |
0: "void", | |
1: "flat", | |
2: "construction", | |
3: "object", | |
4: "nature", | |
5: "sky", | |
6: "human", | |
7: "vehicle", | |
} | |
label2id = {v: k for k, v in id2label.items()} | |
num_labels = len(id2label) | |
checkpoint = "nvidia/segformer-b3-finetuned-cityscapes-1024-1024" | |
image_processor = SegformerImageProcessor(do_resize=False) | |
state_dict_path = f"runs/{checkpoint}/best_model.pt" | |
model = SegformerForSemanticSegmentation.from_pretrained( | |
checkpoint, | |
num_labels=num_labels, | |
id2label=id2label, | |
label2id=label2id, | |
ignore_mismatched_sizes=True, | |
) | |
loaded_state_dict = torch.load( | |
state_dict_path, map_location=torch.device("cpu"), weights_only=True | |
) | |
model.load_state_dict(loaded_state_dict) | |
model = model.to(device) | |
model.eval() | |
# ---- Partie Segmentation | |
def load_and_prepare_images(image_name, segformer=False): | |
""" | |
Charge et prépare les images, les masques et les prédictions associées pour une image donnée. | |
Args: | |
image_name (str): Le nom du fichier de l'image à charger. | |
segformer (bool, optional): Si True, prédit également le masque avec SegFormer. Par défaut False. | |
Returns: | |
tuple: Contient l'image originale redimensionnée, le masque réel, la prédiction FPN, | |
et la prédiction SegFormer si `segformer` est True. | |
""" | |
image_path = os.path.join(data_folder, "images", image_name) | |
mask_name = image_name.replace("_leftImg8bit.png", "_gtFine_labelIds.png") | |
mask_path = os.path.join(data_folder, "masks", mask_name) | |
fpn_pred_path = os.path.join(data_folder, "resnet101_mask", image_name) | |
if not os.path.exists(image_path): | |
raise FileNotFoundError(f"Image not found: {image_path}") | |
if not os.path.exists(mask_path): | |
raise FileNotFoundError(f"Mask not found: {mask_path}") | |
if not os.path.exists(fpn_pred_path): | |
raise FileNotFoundError(f"FPN prediction not found: {fpn_pred_path}") | |
original_image = Image.open(image_path).convert("RGB") | |
original = original_image.resize((1024, 512)) | |
true_mask = np.array(Image.open(mask_path)) | |
fpn_pred = np.array(Image.open(fpn_pred_path)) | |
if segformer: | |
segformer_pred = predict_segmentation(original) | |
return original, true_mask, fpn_pred, segformer_pred | |
return original, true_mask, fpn_pred | |
def predict_segmentation(image): | |
""" | |
Prédit la segmentation d'une image donnée à l'aide d'un modèle pré-entraîné. | |
Args: | |
image (PIL.Image.Image): L'image à segmenter. | |
Returns: | |
numpy.ndarray: La carte de segmentation prédite. | |
""" | |
inputs = image_processor(images=image, return_tensors="pt") | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
model.to(device) | |
pixel_values = inputs.pixel_values.to(device) | |
with torch.no_grad(): | |
outputs = model(pixel_values=pixel_values) | |
logits = outputs.logits | |
upsampled_logits = nn.functional.interpolate( | |
logits, | |
size=image.size[::-1], # (height, width) | |
mode="bilinear", | |
align_corners=False, | |
) | |
pred_seg = upsampled_logits.argmax(dim=1)[0].cpu().numpy() | |
return pred_seg | |
def process_image(image_name): | |
""" | |
Traite une image en chargeant l'image originale, le masque réel, et les prédictions de masques. | |
Envoie la liste de tuple à l'interface "Predictions" de Gradio | |
Args: | |
image_name (str): Le nom de l'image à traiter. | |
Returns: | |
list: Une liste de tuples contenant l'image et son titre associé. | |
""" | |
original, true_mask, fpn_pred, segformer_pred = load_and_prepare_images( | |
image_name, segformer=True | |
) | |
true_mask_colored = colorize_mask(true_mask) | |
true_mask_colored = Image.fromarray(true_mask_colored.astype("uint8")) | |
true_mask_colored = true_mask_colored.resize((1024, 512)) | |
# fpn_pred_colored = colorize_mask(fpn_pred) | |
segformer_pred_colored = colorize_mask(segformer_pred) | |
segformer_pred_colored = Image.fromarray(segformer_pred_colored.astype("uint8")) | |
segformer_pred_colored = segformer_pred_colored.resize((1024, 512)) | |
return [ | |
(original, "Image originale"), | |
(true_mask_colored, "Masque réel"), | |
(fpn_pred, "Prédiction FPN"), | |
(segformer_pred_colored, "Prédiction SegFormer"), | |
] | |
def create_cityscapes_label_colormap(): | |
""" | |
Crée une colormap pour les labels Cityscapes. | |
Returns: | |
numpy.ndarray: Un tableau 2D où chaque ligne représente la couleur RGB d'un label. | |
""" | |
colormap = np.zeros((256, 3), dtype=np.uint8) | |
colormap[0] = [78, 82, 110] | |
colormap[1] = [128, 64, 128] | |
colormap[2] = [154, 156, 153] | |
colormap[3] = [168, 167, 18] | |
colormap[4] = [80, 108, 28] | |
colormap[5] = [112, 164, 196] | |
colormap[6] = [168, 28, 52] | |
colormap[7] = [16, 18, 112] | |
return colormap | |
# Créer la colormap une fois | |
cityscapes_colormap = create_cityscapes_label_colormap() | |
def colorize_mask(mask): | |
return cityscapes_colormap[mask] | |
# ---- Fin Partie Segmentation | |
# ---- Partie EDA | |
def analyse_mask(real_mask, num_labels): | |
""" | |
Analyse la distribution des classes dans un masque réel. | |
Args: | |
real_mask (numpy.ndarray): Le masque de labels réels. | |
num_labels (int): Le nombre total de classes. | |
Returns: | |
dict: Un dictionnaire contenant les proportions des classes dans le masque. | |
""" | |
counts = np.bincount(real_mask.ravel(), minlength=num_labels) | |
total_pixels = real_mask.size | |
class_proportions = counts / total_pixels | |
return dict(enumerate(class_proportions)) | |
def show_eda(image_name): | |
""" | |
Affiche une analyse exploratoire de la distribution des classes pour une image et son masque associé. | |
Args: | |
image_name (str): Le nom de l'image à analyser. | |
Returns: | |
tuple: Contient l'image originale, le masque réel coloré et une figure Plotly représentant | |
la distribution des classes. | |
""" | |
original_image, true_mask, _ = load_and_prepare_images(image_name) | |
class_proportions = analyse_mask(true_mask, num_labels) | |
cityscapes_colormap = create_cityscapes_label_colormap() | |
true_mask_colored = colorize_mask(true_mask) | |
true_mask_colored = Image.fromarray(true_mask_colored.astype("uint8")) | |
true_mask_colored = true_mask_colored.resize((1024, 512)) | |
# Trier les classes par proportion croissante | |
sorted_classes = sorted( | |
class_proportions.keys(), key=lambda x: class_proportions[x] | |
) | |
# Préparer les données pour le barplot | |
categories = [id2label[i] for i in sorted_classes] | |
values = [class_proportions[i] for i in sorted_classes] | |
color_list = [ | |
f"rgb({cityscapes_colormap[i][0]}, {cityscapes_colormap[i][1]}, {cityscapes_colormap[i][2]})" | |
for i in sorted_classes | |
] | |
# Distribution des classes avec la colormap personnalisée | |
fig = go.Figure() | |
fig.add_trace( | |
go.Bar( | |
x=categories, | |
y=values, | |
marker_color=color_list, | |
text=[f"{v:.2f}" for v in values], | |
textposition="outside", | |
) | |
) | |
# Ajouter un titre et des labels, modifier la rotation et la taille de la police | |
fig.update_layout( | |
title={"text": "Distribution des classes", "font": {"size": 24}}, | |
xaxis_title={"text": "Catégories", "font": {"size": 18}}, | |
yaxis_title={"text": "Proportion", "font": {"size": 18}}, | |
xaxis_tickangle=0, # Rotation modifiée à -45 degrés | |
uniformtext_minsize=12, | |
uniformtext_mode="hide", | |
font=dict(size=14), | |
autosize=True, | |
bargap=0.2, | |
height=600, | |
margin=dict(l=20, r=20, t=50, b=20), | |
) | |
return original_image, true_mask_colored, fig | |
# ----Fin Partie EDA | |
# ----Partie Explication GradCam | |
class SegformerWrapper(nn.Module): | |
""" | |
Un wrapper pour le modèle SegFormer qui renvoie uniquement les logits en sortie. | |
Args: | |
model (torch.nn.Module): Le modèle SegFormer pré-entraîné. | |
""" | |
def __init__(self, model): | |
""" | |
Initialise le SegformerWrapper. | |
Args: | |
model (torch.nn.Module): Le modèle SegFormer pré-entraîné. | |
""" | |
super().__init__() | |
self.model = model | |
def forward(self, x): | |
""" | |
Renvoie les logits du modèle au lieu de renvoyer un dictionnaire. | |
Args: | |
x (torch.Tensor): Les entrées du modèle. | |
Returns: | |
torch.Tensor: Les logits du modèle. | |
""" | |
output = self.model(x) | |
return output.logits | |
class SemanticSegmentationTarget: | |
""" | |
Représente une classe cible pour la segmentation sémantique utilisée dans GradCAM. | |
Args: | |
category (int): L'index de la catégorie cible. | |
mask (numpy.ndarray): Le masque binaire indiquant les pixels d'intérêt. | |
""" | |
def __init__(self, category, mask): | |
""" | |
Initialise la cible de segmentation sémantique. | |
Args: | |
category (int): L'index de la catégorie cible. | |
mask (numpy.ndarray): Le masque binaire indiquant les pixels d'intérêt. | |
""" | |
self.category = category | |
self.mask = torch.from_numpy(mask) | |
if torch.cuda.is_available(): | |
self.mask = self.mask.cuda() | |
def __call__(self, model_output): | |
if isinstance( | |
model_output, (dict, transformers.modeling_outputs.SemanticSegmenterOutput) | |
): | |
logits = ( | |
model_output["logits"] | |
if isinstance(model_output, dict) | |
else model_output.logits | |
) | |
elif isinstance(model_output, torch.Tensor): | |
logits = model_output | |
else: | |
raise ValueError(f"Unexpected model_output type: {type(model_output)}") | |
if logits.dim() == 4: # [batch, classes, height, width] | |
return (logits[0, self.category, :, :] * self.mask).sum() | |
elif logits.dim() == 3: # [classes, height, width] | |
return (logits[self.category, :, :] * self.mask).sum() | |
else: | |
raise ValueError(f"Unexpected logits shape: {logits.shape}") | |
def segformer_reshape_transform_huggingface(tensor, width, height): | |
""" | |
Réorganise les dimensions du tenseur pour qu'elles correspondent au format attendu par GradCAM. | |
Args: | |
tensor (torch.Tensor): Le tenseur à réorganiser. | |
width (int): La nouvelle largeur. | |
height (int): La nouvelle hauteur. | |
Returns: | |
torch.Tensor: Le tenseur réorganisé. | |
""" | |
result = tensor.reshape(tensor.size(0), height, width, tensor.size(2)) | |
result = result.transpose(2, 3).transpose(1, 2) | |
return result | |
def explain_model(image_name, category_name): | |
""" | |
Explique les prédictions du modèle SegFormer en utilisant GradCAM pour une image et une catégorie données. | |
Args: | |
image_name (str): Le nom de l'image à expliquer. | |
category_name (str): Le nom de la catégorie cible. | |
Returns: | |
matplotlib.figure.Figure: Une figure matplotlib contenant la carte de chaleur GradCAM superposée sur l'image originale. | |
""" | |
original_image, _, _ = load_and_prepare_images(image_name) | |
rgb_img = np.float32(original_image) / 255 | |
img_tensor = transforms.ToTensor()(rgb_img) | |
input_tensor = transforms.Normalize( | |
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] | |
)(img_tensor) | |
input_tensor = input_tensor.unsqueeze(0).to(device) | |
wrapped_model = SegformerWrapper(model).to(device) | |
with torch.no_grad(): | |
output = wrapped_model(input_tensor) | |
upsampled_logits = nn.functional.interpolate( | |
output, size=input_tensor.shape[-2:], mode="bilinear", align_corners=False | |
) | |
normalized_masks = torch.nn.functional.softmax(upsampled_logits, dim=1).cpu() | |
category = label2id[category_name] | |
mask = normalized_masks[0].argmax(dim=0).numpy() | |
mask_float = np.float32(mask == category) | |
reshape_transform = partial( | |
segformer_reshape_transform_huggingface, # réorganise les dimensions du tenseur pour qu'elles correspondent au format attendu par GradCAM. | |
width=img_tensor.shape[2] // 32, | |
height=img_tensor.shape[1] // 32, | |
) | |
target_layers = [wrapped_model.model.segformer.encoder.layer_norm[-1]] | |
mask_float_resized = cv2.resize(mask_float, (output.shape[3], output.shape[2])) | |
targets = [SemanticSegmentationTarget(category, mask_float_resized)] | |
cam = GradCAM( | |
model=wrapped_model, | |
target_layers=target_layers, | |
reshape_transform=reshape_transform, | |
) | |
grayscale_cam = cam(input_tensor=input_tensor, targets=targets) | |
threshold = 0.01 # Seuil de 1% de sureté | |
thresholded_cam = grayscale_cam.copy() | |
thresholded_cam[grayscale_cam < threshold] = 0 | |
if np.max(thresholded_cam) > 0: | |
thresholded_cam = thresholded_cam / np.max(thresholded_cam) | |
else: | |
thresholded_cam = grayscale_cam[0] | |
resized_cam = cv2.resize( | |
thresholded_cam[0], (input_tensor.shape[3], input_tensor.shape[2]) | |
) | |
masked_cam = resized_cam * mask_float | |
if np.max(masked_cam) > 0: | |
cam_image = show_cam_on_image(rgb_img, masked_cam, use_rgb=True) | |
else: | |
cam_image = original_image | |
fig, ax = plt.subplots(figsize=(15, 10)) | |
ax.imshow(cam_image) | |
ax.axis("off") | |
ax.set_title(f"Masque de chaleur GradCam pour {category_name}", color="white") | |
margin = 0.02 # Adjust this value to change the size of the margin | |
margin_color = "#0a0f1e" | |
fig.subplots_adjust(left=margin, right=1 - margin, top=1 - margin, bottom=margin) | |
fig.patch.set_facecolor(margin_color) | |
plt.close() | |
return fig | |
# ----Fin Partie Explication GradCam | |
# ----Partie Data augmentation | |
import random | |
def change_image(): | |
""" | |
Sélectionne et charge aléatoirement une image depuis un dossier spécifié. | |
Returns: | |
PIL.Image.Image: L'image sélectionnée. | |
""" | |
image_dir = ( | |
"data_sample/images" # Remplacez par le chemin de votre dossier d'images | |
) | |
image_list = [f for f in os.listdir(image_dir) if f.endswith(".png")] | |
random_image = random.choice(image_list) | |
return Image.open(os.path.join(image_dir, random_image)) | |
def apply_augmentation(image, augmentation_names): | |
""" | |
Applique une ou plusieurs augmentations à une image. | |
Args: | |
image (PIL.Image.Image): L'image à augmenter. | |
augmentation_names (list of str): Les noms des augmentations à appliquer. | |
Returns: | |
PIL.Image.Image: L'image augmentée. | |
""" | |
augmentations = { | |
"Horizontal Flip": A.HorizontalFlip(p=1), | |
"Shift Scale Rotate": A.ShiftScaleRotate(p=1), | |
"Random Brightness Contrast": A.RandomBrightnessContrast(p=1), | |
"RGB Shift": A.RGBShift(p=1), | |
"Blur": A.Blur(blur_limit=(5, 7), p=1), | |
"Gaussian Noise": A.GaussNoise(p=1), | |
"Grid Distortion": A.GridDistortion(p=1), | |
"Random Sun": A.RandomSunFlare(p=1), | |
} | |
image_array = np.array(image) | |
if augmentation_names is not None: | |
selected_augs = [ | |
augmentations[name] for name in augmentation_names if name in augmentations | |
] | |
compose = A.Compose(selected_augs) | |
# Appliquer la composition d'augmentations | |
augmented = compose(image=image_array) | |
return Image.fromarray(augmented["image"]) | |
else: | |
return image | |
# ---- Fin Partie Data augmentation | |
image_list = [ | |
f for f in os.listdir(os.path.join(data_folder, "images")) if f.endswith(".png") | |
] | |
category_list = list(id2label.values()) | |
image_name = "dusseldorf_000012_000019_leftImg8bit.png" | |
default_image = os.path.join(data_folder, "images", image_name) | |
my_theme = gr.Theme.from_hub("gstaff/whiteboard") | |
with gr.Blocks(title="Preuve de concept", theme=my_theme) as demo: | |
gr.Markdown("# Projet 10 - Développer une preuve de concept") | |
with gr.Tab("Distribution"): | |
gr.Markdown("## Distribution des classes Cityscapes") | |
gr.Markdown( | |
"### Visualisation de la distribution de chaque classe selon l'image choisie." | |
) | |
eda_image_input = gr.Dropdown( | |
choices=image_list, | |
label="Sélectionnez une image", | |
) | |
with gr.Row(): | |
original_image_output = gr.Image(type="pil", label="Image originale") | |
original_mask_output = gr.Image(type="pil", label="Masque original") | |
class_distribution_plot = gr.Plot(label="Distribution des classes") | |
eda_image_input.change( | |
fn=show_eda, | |
inputs=eda_image_input, | |
outputs=[ | |
original_image_output, | |
original_mask_output, | |
class_distribution_plot, | |
], | |
) | |
with gr.Tab("Data Augmentation"): | |
gr.Markdown("## Visualisation de l'augmentation des données") | |
gr.Markdown( | |
"### Sélectionnez une ou plusieurs augmentations pour l'appliquer à l'image." | |
) | |
gr.Markdown("### Vous pouvez également changer d'image.") | |
with gr.Row(): | |
image_display = gr.Image( | |
value=default_image, | |
label="Image", | |
show_download_button=False, | |
interactive=False, | |
) | |
augmented_image = gr.Image(label="Image Augmentée") | |
with gr.Row(): | |
change_image_button = gr.Button("Changer image") | |
augmentation_dropdown = gr.Dropdown( | |
choices=[ | |
"Horizontal Flip", | |
"Shift Scale Rotate", | |
"Random Brightness Contrast", | |
"RGB Shift", | |
"Blur", | |
"Gaussian Noise", | |
"Grid Distortion", | |
"Random Sun", | |
], | |
label="Sélectionnez une augmentation", | |
multiselect=True, | |
) | |
apply_button = gr.Button("Appliquer l'augmentation") | |
change_image_button.click(fn=change_image, outputs=image_display) | |
apply_button.click( | |
fn=apply_augmentation, | |
inputs=[image_display, augmentation_dropdown], | |
outputs=augmented_image, | |
) | |
with gr.Tab("Prédictions"): | |
gr.Markdown("## Comparaison de segmentations d'images Cityscapes") | |
gr.Markdown( | |
"### Sélectionnez une image pour voir la comparaison entre le masque réel, la prédiction FPN (pré-enregistré) et la prédiction du modèle SegFormer." | |
) | |
image_input = gr.Dropdown(choices=image_list, label="Sélectionnez une image") | |
gallery_output = gr.Gallery( | |
label="Résultats de segmentation", | |
show_label=True, | |
elem_id="gallery", | |
columns=[2], | |
rows=[2], | |
object_fit="contain", | |
height="512px", | |
min_width="1024px", | |
) | |
image_input.change(fn=process_image, inputs=image_input, outputs=gallery_output) | |
with gr.Tab("Explication SegFormer"): | |
gr.Markdown("## Explication du modèle SegFormer") | |
gr.Markdown( | |
"### La méthode Grad-CAM est une technique populaire de visualisation qui est utile pour comprendre comment un réseau neuronal convolutif a été conduit à prendre une décision de classification. Elle est spécifique à chaque classe, ce qui signifie qu’elle peut produire une visualisation distincte pour chaque classe présente dans l’image." | |
) | |
gr.Markdown( | |
"### NB: Si l'image s'affiche sans masque, c'est que le modèle ne trouve pas de zones significatives pour une catégorie donnée." | |
) | |
with gr.Row(): | |
explain_image_input = gr.Dropdown( | |
choices=image_list, label="Sélectionnez une image" | |
) | |
explain_category_input = gr.Dropdown( | |
choices=category_list, label="Sélectionnez une catégorie" | |
) | |
explain_button = gr.Button("Expliquer") | |
explain_output = gr.Plot(label="Explication SegFormer", min_width=200) | |
explain_button.click( | |
fn=explain_model, | |
inputs=[explain_image_input, explain_category_input], | |
outputs=explain_output, | |
) | |
# Lancer l'application | |
demo.launch(favicon_path="favicon.ico") | |