# fonctions_api.py (lightweight version for the FastAPI API with SegFormer)

import torch
import torch.nn.functional as F
import numpy as np
from pathlib import Path
from transformers import SegformerForSemanticSegmentation

# -------- Palette (optional, for colorization) --------
PALETTE = {
    0: (0, 0, 0),        # void
    1: (50, 50, 150),    # flat
    2: (102, 0, 204),    # construction
    3: (255, 85, 0),     # object
    4: (255, 255, 0),    # nature
    5: (0, 255, 255),    # sky
    6: (255, 0, 255),    # human
    7: (255, 255, 255),  # vehicle
}


# -------- Main function: load SegFormer --------
def charger_segformer(num_classes=8):
    """Load a SegFormer-B5 checkpoint and reinitialize its head for `num_classes` outputs."""
    model = SegformerForSemanticSegmentation.from_pretrained(
        "nvidia/segformer-b5-finetuned-ade-640-640",
        num_labels=num_classes,
        ignore_mismatched_sizes=True,
    )
    model.config.num_labels = num_classes
    model.config.output_hidden_states = False
    return model


# -------- Remap Cityscapes labelIds to 8 classes --------
def remap_classes(mask: np.ndarray) -> np.ndarray:
    """Map the 34 Cityscapes labelIds onto the 8 main classes used by the API."""
    labelIds_to_main_classes = {
        # void
        0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0, 9: 0, 10: 0, 14: 0, 15: 0, 16: 0,
        # flat
        7: 1, 8: 1,
        # construction
        11: 2, 12: 2, 13: 2,
        # object
        17: 3, 18: 3, 19: 3, 20: 3,
        # nature
        21: 4, 22: 4,
        # sky
        23: 5,
        # human
        24: 6, 25: 6,
        # vehicle
        26: 7, 27: 7, 28: 7, 29: 7, 30: 7, 31: 7, 32: 7, 33: 7,
    }
    remapped_mask = np.copy(mask)
    for original_class, new_class in labelIds_to_main_classes.items():
        remapped_mask[mask == original_class] = new_class
    # Any unexpected label id falls back to void
    remapped_mask[mask > 33] = 0
    return remapped_mask.astype(np.uint8)


# -------- Convert a 2D mask to an RGB image (optional) --------
def decode_cityscapes_mask(mask):
    """Colorize a (H, W) class-id mask into a (H, W, 3) RGB image using PALETTE."""
    h, w = mask.shape
    mask_rgb = np.zeros((h, w, 3), dtype=np.uint8)
    for class_id, color in PALETTE.items():
        mask_rgb[mask == class_id] = color
    return mask_rgb
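

# ---------------------------------------------------------------
# Illustrative usage sketch (not part of the original module).
# It assumes preprocessing with SegformerImageProcessor and a
# hypothetical local image "exemple.png"; the real API may prepare
# its inputs differently and would typically load fine-tuned
# 8-class weights into the model before inference.
# ---------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image
    from transformers import SegformerImageProcessor

    processor = SegformerImageProcessor.from_pretrained(
        "nvidia/segformer-b5-finetuned-ade-640-640"
    )
    model = charger_segformer(num_classes=8)
    model.eval()

    image = Image.open("exemple.png").convert("RGB")
    inputs = processor(images=image, return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits  # shape: (1, num_classes, H/4, W/4)

    # Upsample logits to the original image size, then take the per-pixel argmax
    logits = F.interpolate(
        logits, size=image.size[::-1], mode="bilinear", align_corners=False
    )
    pred_mask = logits.argmax(dim=1).squeeze(0).cpu().numpy().astype(np.uint8)

    # Colorize the prediction with the 8-class palette and save it
    pred_rgb = decode_cityscapes_mask(pred_mask)
    Image.fromarray(pred_rgb).save("prediction_rgb.png")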