SegFormer-Model / app.py
Aurel-test's picture
Change theme
4436966
import gradio as gr
import torch
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from PIL import Image
import plotly.graph_objects as go
import numpy as np
import os
import torch.nn as nn
from sklearn.metrics import jaccard_score, accuracy_score
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn.functional as F
import seaborn as sns
from functools import partial
from pytorch_grad_cam.utils.image import (
show_cam_on_image,
preprocess_image as grad_preprocess,
)
from pytorch_grad_cam import GradCAM
import cv2
import transformers
from torchvision import transforms
import albumentations as A
device = "cuda" if torch.cuda.is_available() else "cpu"
data_folder = "data_sample"
id2label = {
0: "void",
1: "flat",
2: "construction",
3: "object",
4: "nature",
5: "sky",
6: "human",
7: "vehicle",
}
label2id = {v: k for k, v in id2label.items()}
num_labels = len(id2label)
checkpoint = "nvidia/segformer-b3-finetuned-cityscapes-1024-1024"
image_processor = SegformerImageProcessor(do_resize=False)
state_dict_path = f"runs/{checkpoint}/best_model.pt"
model = SegformerForSemanticSegmentation.from_pretrained(
checkpoint,
num_labels=num_labels,
id2label=id2label,
label2id=label2id,
ignore_mismatched_sizes=True,
)
loaded_state_dict = torch.load(
state_dict_path, map_location=torch.device("cpu"), weights_only=True
)
model.load_state_dict(loaded_state_dict)
model = model.to(device)
model.eval()
# ---- Partie Segmentation
def load_and_prepare_images(image_name, segformer=False):
"""
Charge et prépare les images, les masques et les prédictions associées pour une image donnée.
Args:
image_name (str): Le nom du fichier de l'image à charger.
segformer (bool, optional): Si True, prédit également le masque avec SegFormer. Par défaut False.
Returns:
tuple: Contient l'image originale redimensionnée, le masque réel, la prédiction FPN,
et la prédiction SegFormer si `segformer` est True.
"""
image_path = os.path.join(data_folder, "images", image_name)
mask_name = image_name.replace("_leftImg8bit.png", "_gtFine_labelIds.png")
mask_path = os.path.join(data_folder, "masks", mask_name)
fpn_pred_path = os.path.join(data_folder, "resnet101_mask", image_name)
if not os.path.exists(image_path):
raise FileNotFoundError(f"Image not found: {image_path}")
if not os.path.exists(mask_path):
raise FileNotFoundError(f"Mask not found: {mask_path}")
if not os.path.exists(fpn_pred_path):
raise FileNotFoundError(f"FPN prediction not found: {fpn_pred_path}")
original_image = Image.open(image_path).convert("RGB")
original = original_image.resize((1024, 512))
true_mask = np.array(Image.open(mask_path))
fpn_pred = np.array(Image.open(fpn_pred_path))
if segformer:
segformer_pred = predict_segmentation(original)
return original, true_mask, fpn_pred, segformer_pred
return original, true_mask, fpn_pred
def predict_segmentation(image):
"""
Prédit la segmentation d'une image donnée à l'aide d'un modèle pré-entraîné.
Args:
image (PIL.Image.Image): L'image à segmenter.
Returns:
numpy.ndarray: La carte de segmentation prédite.
"""
inputs = image_processor(images=image, return_tensors="pt")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
pixel_values = inputs.pixel_values.to(device)
with torch.no_grad():
outputs = model(pixel_values=pixel_values)
logits = outputs.logits
upsampled_logits = nn.functional.interpolate(
logits,
size=image.size[::-1], # (height, width)
mode="bilinear",
align_corners=False,
)
pred_seg = upsampled_logits.argmax(dim=1)[0].cpu().numpy()
return pred_seg
def process_image(image_name):
"""
Traite une image en chargeant l'image originale, le masque réel, et les prédictions de masques.
Envoie la liste de tuple à l'interface "Predictions" de Gradio
Args:
image_name (str): Le nom de l'image à traiter.
Returns:
list: Une liste de tuples contenant l'image et son titre associé.
"""
original, true_mask, fpn_pred, segformer_pred = load_and_prepare_images(
image_name, segformer=True
)
true_mask_colored = colorize_mask(true_mask)
true_mask_colored = Image.fromarray(true_mask_colored.astype("uint8"))
true_mask_colored = true_mask_colored.resize((1024, 512))
# fpn_pred_colored = colorize_mask(fpn_pred)
segformer_pred_colored = colorize_mask(segformer_pred)
segformer_pred_colored = Image.fromarray(segformer_pred_colored.astype("uint8"))
segformer_pred_colored = segformer_pred_colored.resize((1024, 512))
return [
(original, "Image originale"),
(true_mask_colored, "Masque réel"),
(fpn_pred, "Prédiction FPN"),
(segformer_pred_colored, "Prédiction SegFormer"),
]
def create_cityscapes_label_colormap():
"""
Crée une colormap pour les labels Cityscapes.
Returns:
numpy.ndarray: Un tableau 2D où chaque ligne représente la couleur RGB d'un label.
"""
colormap = np.zeros((256, 3), dtype=np.uint8)
colormap[0] = [78, 82, 110]
colormap[1] = [128, 64, 128]
colormap[2] = [154, 156, 153]
colormap[3] = [168, 167, 18]
colormap[4] = [80, 108, 28]
colormap[5] = [112, 164, 196]
colormap[6] = [168, 28, 52]
colormap[7] = [16, 18, 112]
return colormap
# Créer la colormap une fois
cityscapes_colormap = create_cityscapes_label_colormap()
def colorize_mask(mask):
return cityscapes_colormap[mask]
# ---- Fin Partie Segmentation
# ---- Partie EDA
def analyse_mask(real_mask, num_labels):
"""
Analyse la distribution des classes dans un masque réel.
Args:
real_mask (numpy.ndarray): Le masque de labels réels.
num_labels (int): Le nombre total de classes.
Returns:
dict: Un dictionnaire contenant les proportions des classes dans le masque.
"""
counts = np.bincount(real_mask.ravel(), minlength=num_labels)
total_pixels = real_mask.size
class_proportions = counts / total_pixels
return dict(enumerate(class_proportions))
def show_eda(image_name):
"""
Affiche une analyse exploratoire de la distribution des classes pour une image et son masque associé.
Args:
image_name (str): Le nom de l'image à analyser.
Returns:
tuple: Contient l'image originale, le masque réel coloré et une figure Plotly représentant
la distribution des classes.
"""
original_image, true_mask, _ = load_and_prepare_images(image_name)
class_proportions = analyse_mask(true_mask, num_labels)
cityscapes_colormap = create_cityscapes_label_colormap()
true_mask_colored = colorize_mask(true_mask)
true_mask_colored = Image.fromarray(true_mask_colored.astype("uint8"))
true_mask_colored = true_mask_colored.resize((1024, 512))
# Trier les classes par proportion croissante
sorted_classes = sorted(
class_proportions.keys(), key=lambda x: class_proportions[x]
)
# Préparer les données pour le barplot
categories = [id2label[i] for i in sorted_classes]
values = [class_proportions[i] for i in sorted_classes]
color_list = [
f"rgb({cityscapes_colormap[i][0]}, {cityscapes_colormap[i][1]}, {cityscapes_colormap[i][2]})"
for i in sorted_classes
]
# Distribution des classes avec la colormap personnalisée
fig = go.Figure()
fig.add_trace(
go.Bar(
x=categories,
y=values,
marker_color=color_list,
text=[f"{v:.2f}" for v in values],
textposition="outside",
)
)
# Ajouter un titre et des labels, modifier la rotation et la taille de la police
fig.update_layout(
title={"text": "Distribution des classes", "font": {"size": 24}},
xaxis_title={"text": "Catégories", "font": {"size": 18}},
yaxis_title={"text": "Proportion", "font": {"size": 18}},
xaxis_tickangle=0, # Rotation modifiée à -45 degrés
uniformtext_minsize=12,
uniformtext_mode="hide",
font=dict(size=14),
autosize=True,
bargap=0.2,
height=600,
margin=dict(l=20, r=20, t=50, b=20),
)
return original_image, true_mask_colored, fig
# ----Fin Partie EDA
# ----Partie Explication GradCam
class SegformerWrapper(nn.Module):
"""
Un wrapper pour le modèle SegFormer qui renvoie uniquement les logits en sortie.
Args:
model (torch.nn.Module): Le modèle SegFormer pré-entraîné.
"""
def __init__(self, model):
"""
Initialise le SegformerWrapper.
Args:
model (torch.nn.Module): Le modèle SegFormer pré-entraîné.
"""
super().__init__()
self.model = model
def forward(self, x):
"""
Renvoie les logits du modèle au lieu de renvoyer un dictionnaire.
Args:
x (torch.Tensor): Les entrées du modèle.
Returns:
torch.Tensor: Les logits du modèle.
"""
output = self.model(x)
return output.logits
class SemanticSegmentationTarget:
"""
Représente une classe cible pour la segmentation sémantique utilisée dans GradCAM.
Args:
category (int): L'index de la catégorie cible.
mask (numpy.ndarray): Le masque binaire indiquant les pixels d'intérêt.
"""
def __init__(self, category, mask):
"""
Initialise la cible de segmentation sémantique.
Args:
category (int): L'index de la catégorie cible.
mask (numpy.ndarray): Le masque binaire indiquant les pixels d'intérêt.
"""
self.category = category
self.mask = torch.from_numpy(mask)
if torch.cuda.is_available():
self.mask = self.mask.cuda()
def __call__(self, model_output):
if isinstance(
model_output, (dict, transformers.modeling_outputs.SemanticSegmenterOutput)
):
logits = (
model_output["logits"]
if isinstance(model_output, dict)
else model_output.logits
)
elif isinstance(model_output, torch.Tensor):
logits = model_output
else:
raise ValueError(f"Unexpected model_output type: {type(model_output)}")
if logits.dim() == 4: # [batch, classes, height, width]
return (logits[0, self.category, :, :] * self.mask).sum()
elif logits.dim() == 3: # [classes, height, width]
return (logits[self.category, :, :] * self.mask).sum()
else:
raise ValueError(f"Unexpected logits shape: {logits.shape}")
def segformer_reshape_transform_huggingface(tensor, width, height):
"""
Réorganise les dimensions du tenseur pour qu'elles correspondent au format attendu par GradCAM.
Args:
tensor (torch.Tensor): Le tenseur à réorganiser.
width (int): La nouvelle largeur.
height (int): La nouvelle hauteur.
Returns:
torch.Tensor: Le tenseur réorganisé.
"""
result = tensor.reshape(tensor.size(0), height, width, tensor.size(2))
result = result.transpose(2, 3).transpose(1, 2)
return result
def explain_model(image_name, category_name):
"""
Explique les prédictions du modèle SegFormer en utilisant GradCAM pour une image et une catégorie données.
Args:
image_name (str): Le nom de l'image à expliquer.
category_name (str): Le nom de la catégorie cible.
Returns:
matplotlib.figure.Figure: Une figure matplotlib contenant la carte de chaleur GradCAM superposée sur l'image originale.
"""
original_image, _, _ = load_and_prepare_images(image_name)
rgb_img = np.float32(original_image) / 255
img_tensor = transforms.ToTensor()(rgb_img)
input_tensor = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)(img_tensor)
input_tensor = input_tensor.unsqueeze(0).to(device)
wrapped_model = SegformerWrapper(model).to(device)
with torch.no_grad():
output = wrapped_model(input_tensor)
upsampled_logits = nn.functional.interpolate(
output, size=input_tensor.shape[-2:], mode="bilinear", align_corners=False
)
normalized_masks = torch.nn.functional.softmax(upsampled_logits, dim=1).cpu()
category = label2id[category_name]
mask = normalized_masks[0].argmax(dim=0).numpy()
mask_float = np.float32(mask == category)
reshape_transform = partial(
segformer_reshape_transform_huggingface, # réorganise les dimensions du tenseur pour qu'elles correspondent au format attendu par GradCAM.
width=img_tensor.shape[2] // 32,
height=img_tensor.shape[1] // 32,
)
target_layers = [wrapped_model.model.segformer.encoder.layer_norm[-1]]
mask_float_resized = cv2.resize(mask_float, (output.shape[3], output.shape[2]))
targets = [SemanticSegmentationTarget(category, mask_float_resized)]
cam = GradCAM(
model=wrapped_model,
target_layers=target_layers,
reshape_transform=reshape_transform,
)
grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
threshold = 0.01 # Seuil de 1% de sureté
thresholded_cam = grayscale_cam.copy()
thresholded_cam[grayscale_cam < threshold] = 0
if np.max(thresholded_cam) > 0:
thresholded_cam = thresholded_cam / np.max(thresholded_cam)
else:
thresholded_cam = grayscale_cam[0]
resized_cam = cv2.resize(
thresholded_cam[0], (input_tensor.shape[3], input_tensor.shape[2])
)
masked_cam = resized_cam * mask_float
if np.max(masked_cam) > 0:
cam_image = show_cam_on_image(rgb_img, masked_cam, use_rgb=True)
else:
cam_image = original_image
fig, ax = plt.subplots(figsize=(15, 10))
ax.imshow(cam_image)
ax.axis("off")
ax.set_title(f"Masque de chaleur GradCam pour {category_name}", color="white")
margin = 0.02 # Adjust this value to change the size of the margin
margin_color = "#0a0f1e"
fig.subplots_adjust(left=margin, right=1 - margin, top=1 - margin, bottom=margin)
fig.patch.set_facecolor(margin_color)
plt.close()
return fig
# ----Fin Partie Explication GradCam
# ----Partie Data augmentation
import random
def change_image():
"""
Sélectionne et charge aléatoirement une image depuis un dossier spécifié.
Returns:
PIL.Image.Image: L'image sélectionnée.
"""
image_dir = (
"data_sample/images" # Remplacez par le chemin de votre dossier d'images
)
image_list = [f for f in os.listdir(image_dir) if f.endswith(".png")]
random_image = random.choice(image_list)
return Image.open(os.path.join(image_dir, random_image))
def apply_augmentation(image, augmentation_names):
"""
Applique une ou plusieurs augmentations à une image.
Args:
image (PIL.Image.Image): L'image à augmenter.
augmentation_names (list of str): Les noms des augmentations à appliquer.
Returns:
PIL.Image.Image: L'image augmentée.
"""
augmentations = {
"Horizontal Flip": A.HorizontalFlip(p=1),
"Shift Scale Rotate": A.ShiftScaleRotate(p=1),
"Random Brightness Contrast": A.RandomBrightnessContrast(p=1),
"RGB Shift": A.RGBShift(p=1),
"Blur": A.Blur(blur_limit=(5, 7), p=1),
"Gaussian Noise": A.GaussNoise(p=1),
"Grid Distortion": A.GridDistortion(p=1),
"Random Sun": A.RandomSunFlare(p=1),
}
image_array = np.array(image)
if augmentation_names is not None:
selected_augs = [
augmentations[name] for name in augmentation_names if name in augmentations
]
compose = A.Compose(selected_augs)
# Appliquer la composition d'augmentations
augmented = compose(image=image_array)
return Image.fromarray(augmented["image"])
else:
return image
# ---- Fin Partie Data augmentation
image_list = [
f for f in os.listdir(os.path.join(data_folder, "images")) if f.endswith(".png")
]
category_list = list(id2label.values())
image_name = "dusseldorf_000012_000019_leftImg8bit.png"
default_image = os.path.join(data_folder, "images", image_name)
my_theme = gr.Theme.from_hub("gstaff/whiteboard")
with gr.Blocks(title="Preuve de concept", theme=my_theme) as demo:
gr.Markdown("# Projet 10 - Développer une preuve de concept")
with gr.Tab("Distribution"):
gr.Markdown("## Distribution des classes Cityscapes")
gr.Markdown(
"### Visualisation de la distribution de chaque classe selon l'image choisie."
)
eda_image_input = gr.Dropdown(
choices=image_list,
label="Sélectionnez une image",
)
with gr.Row():
original_image_output = gr.Image(type="pil", label="Image originale")
original_mask_output = gr.Image(type="pil", label="Masque original")
class_distribution_plot = gr.Plot(label="Distribution des classes")
eda_image_input.change(
fn=show_eda,
inputs=eda_image_input,
outputs=[
original_image_output,
original_mask_output,
class_distribution_plot,
],
)
with gr.Tab("Data Augmentation"):
gr.Markdown("## Visualisation de l'augmentation des données")
gr.Markdown(
"### Sélectionnez une ou plusieurs augmentations pour l'appliquer à l'image."
)
gr.Markdown("### Vous pouvez également changer d'image.")
with gr.Row():
image_display = gr.Image(
value=default_image,
label="Image",
show_download_button=False,
interactive=False,
)
augmented_image = gr.Image(label="Image Augmentée")
with gr.Row():
change_image_button = gr.Button("Changer image")
augmentation_dropdown = gr.Dropdown(
choices=[
"Horizontal Flip",
"Shift Scale Rotate",
"Random Brightness Contrast",
"RGB Shift",
"Blur",
"Gaussian Noise",
"Grid Distortion",
"Random Sun",
],
label="Sélectionnez une augmentation",
multiselect=True,
)
apply_button = gr.Button("Appliquer l'augmentation")
change_image_button.click(fn=change_image, outputs=image_display)
apply_button.click(
fn=apply_augmentation,
inputs=[image_display, augmentation_dropdown],
outputs=augmented_image,
)
with gr.Tab("Prédictions"):
gr.Markdown("## Comparaison de segmentations d'images Cityscapes")
gr.Markdown(
"### Sélectionnez une image pour voir la comparaison entre le masque réel, la prédiction FPN (pré-enregistré) et la prédiction du modèle SegFormer."
)
image_input = gr.Dropdown(choices=image_list, label="Sélectionnez une image")
gallery_output = gr.Gallery(
label="Résultats de segmentation",
show_label=True,
elem_id="gallery",
columns=[2],
rows=[2],
object_fit="contain",
height="512px",
min_width="1024px",
)
image_input.change(fn=process_image, inputs=image_input, outputs=gallery_output)
with gr.Tab("Explication SegFormer"):
gr.Markdown("## Explication du modèle SegFormer")
gr.Markdown(
"### La méthode Grad-CAM est une technique populaire de visualisation qui est utile pour comprendre comment un réseau neuronal convolutif a été conduit à prendre une décision de classification. Elle est spécifique à chaque classe, ce qui signifie qu’elle peut produire une visualisation distincte pour chaque classe présente dans l’image."
)
gr.Markdown(
"### NB: Si l'image s'affiche sans masque, c'est que le modèle ne trouve pas de zones significatives pour une catégorie donnée."
)
with gr.Row():
explain_image_input = gr.Dropdown(
choices=image_list, label="Sélectionnez une image"
)
explain_category_input = gr.Dropdown(
choices=category_list, label="Sélectionnez une catégorie"
)
explain_button = gr.Button("Expliquer")
explain_output = gr.Plot(label="Explication SegFormer", min_width=200)
explain_button.click(
fn=explain_model,
inputs=[explain_image_input, explain_category_input],
outputs=explain_output,
)
# Lancer l'application
demo.launch(favicon_path="favicon.ico")