Nunzio commited on
Commit
e749740
·
1 Parent(s): 53e85ab

added legend

Browse files
Files changed (2) hide show
  1. app.py +10 -1
  2. utils/imageHandling.py +18 -0
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import torch
2
  import gradio as gr
3
 
4
- from utils.imageHandling import hfImageToTensor, preprocessing, postprocessing, loadPreloadedImages
5
  from model.modelLoading import loadBiSeNet, loadBiSeNetV2
6
 
7
 
@@ -76,6 +76,15 @@ with gr.Blocks(title="Semantic Segmentation Predictors") as demo:
76
  with gr.Column():
77
  result_display = gr.Image(label="Model prediction", visible=True)
78
  error_text = gr.Markdown("", visible=False)
 
 
 
 
 
 
 
 
 
79
 
80
  with gr.Row():
81
  gr.Markdown("## Preloaded images to be used for testing the model")
 
1
  import torch
2
  import gradio as gr
3
 
4
+ from utils.imageHandling import hfImageToTensor, preprocessing, postprocessing, loadPreloadedImages, legendHandling
5
  from model.modelLoading import loadBiSeNet, loadBiSeNetV2
6
 
7
 
 
76
  with gr.Column():
77
  result_display = gr.Image(label="Model prediction", visible=True)
78
  error_text = gr.Markdown("", visible=False)
79
+ with gr.Row():
80
+ gr.Markdown("The legend of the classes is the following (format **ID**: **name** - Color: **color**)")
81
+ string = ""
82
+ for i, (id, name, color) in enumerate(legendHandling()):
83
+ if i and not i % 3:
84
+ gr.Markdown(string)
85
+ string = ""
86
+
87
+ string += f"**{id}**: {name} - Color: {color}"
88
 
89
  with gr.Row():
90
  gr.Markdown("## Preloaded images to be used for testing the model")
utils/imageHandling.py CHANGED
@@ -70,6 +70,24 @@ def print_mask(mask:torch.Tensor, numClasses:int=19)->None:
70
  new_mask[mask == i] = torch.tensor(colors[i][:3], dtype=torch.uint8)
71
  return new_mask.permute(2,0,1)
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  # %% postprocessing
74
  def postprocessing(pred: torch.Tensor) -> torch.Tensor:
75
  """
 
70
  new_mask[mask == i] = torch.tensor(colors[i][:3], dtype=torch.uint8)
71
  return new_mask.permute(2,0,1)
72
 
73
+
74
+ def legendHandling()->list[int, str, str]:
75
+ """
76
+ Returns a sorted list of tuples containing class IDs, names, and colors for semantic segmentation.
77
+
78
+ Each tuple contains:
79
+ - Class ID (int)
80
+ - Class name (str)
81
+ - Color description (str)
82
+ The list is sorted by class ID.
83
+ """
84
+ return sorted([[0, "road", "dark purple"], [1, "sidewalk", "light purple / pink"], [2, "building", "dark gray"], [3, "wall", "blue + grey"],
85
+ [4, "fence", "beige"], [5, "pole", "grey"], [6, "traffic light", "orange"], [7, "traffic sign", "yellow"], [8, "vegetation", "dark green"],
86
+ [9, "terrain", "light green"], [10, "sky", "blue"], [11, "person", "dark red"], [12, "rider", "light red"], [13, "car", "blue"],
87
+ [14, "truck", "dark blue"], [15, "bus", "dark blue"], [16, "train", "blue + green"], [17, "motorcycle", "light blue"], [18, "bicycle", "velvet"]
88
+ ], key=lambda x: x[0])
89
+
90
+
91
  # %% postprocessing
92
  def postprocessing(pred: torch.Tensor) -> torch.Tensor:
93
  """