Nunzio commited on
Commit
05e5639
Β·
1 Parent(s): 9fe5dbe

added more weights

Browse files
app.py CHANGED
@@ -13,8 +13,10 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
13
 
14
 
15
  MODELS = {
16
- "BISENET": loadBiSeNet(device),
17
- "BISENETV2": loadBiSeNetV2(device)
 
 
18
  }
19
 
20
  image_list = loadPreloadedImages(gta_image_dir, city_image_dir, turin_image_dir)
@@ -68,8 +70,8 @@ with gr.Blocks(title="Semantic Segmentation Predictors") as demo:
68
  with gr.Column():
69
  image_input = gr.Image(type="pil", label="Upload image")
70
  model_selector = gr.Radio(
71
- choices=["BiSeNet", "BiSeNetV2"],
72
- value="BiSeNet",
73
  label="Select the real time segmentation model"
74
  )
75
  submit_btn = gr.Button("Run prediction")
@@ -82,10 +84,14 @@ with gr.Blocks(title="Semantic Segmentation Predictors") as demo:
82
  for i in range(0, len(legend), 2):
83
  with gr.Row():
84
  with gr.Column(scale=1):
85
- gr.Markdown(f"**{legend[i][1]}** β†’ {legend[i][2]}")
 
 
86
  with gr.Column(scale=1):
87
  if i + 1 < len(legend):
88
- gr.Markdown(f"**{legend[i+1][1]}** β†’ {legend[i+1][2]}")
 
 
89
  else:
90
  gr.Markdown("") # Keeps spacing consistent if list is odd
91
 
 
13
 
14
 
15
  MODELS = {
16
+ "BISENET-BASE": loadBiSeNet(device, 'weight_Base.pth'),
17
+ "BISENET-BEST": loadBiSeNet(device, 'weight_Best.pth'),
18
+ "BISENETV2-BASE": loadBiSeNetV2(device, 'weight_Base.pth'),
19
+ "BISENETV2-BEST": loadBiSeNetV2(device, 'weight_Best.pth')
20
  }
21
 
22
  image_list = loadPreloadedImages(gta_image_dir, city_image_dir, turin_image_dir)
 
70
  with gr.Column():
71
  image_input = gr.Image(type="pil", label="Upload image")
72
  model_selector = gr.Radio(
73
+ choices=["BiSeNet-base", "BiSeNet-Best", "BiSeNetV2-base", "BiSeNetV2-Best"],
74
+ value="BiSeNet-base",
75
  label="Select the real time segmentation model"
76
  )
77
  submit_btn = gr.Button("Run prediction")
 
84
  for i in range(0, len(legend), 2):
85
  with gr.Row():
86
  with gr.Column(scale=1):
87
+ color_box0 = f"""<span style='display:inline-block; width:15px; height:15px;
88
+ background-color:rgb({legend[i][1][0]},{legend[i][1][1]},{legend[i][1][2]}); margin-left:6px; border:1px solid #000;'></span>"""
89
+ gr.Markdown(f"**{legend[i][1]}** β†’ {legend[i][2]} {color_box0}")
90
  with gr.Column(scale=1):
91
  if i + 1 < len(legend):
92
+ color_box1 = f"""<span style='display:inline-block; width:15px; height:15px;
93
+ background-color:rgb({legend[i+1][1][0]},{legend[i+1][1][1]},{legend[i+1][1][2]}); margin-left:6px; border:1px solid #000;'></span>"""
94
+ gr.Markdown(f"**{legend[i+1][1]}** β†’ {legend[i+1][2]} {color_box1}")
95
  else:
96
  gr.Markdown("") # Keeps spacing consistent if list is odd
97
 
model/modelLoading.py CHANGED
@@ -5,18 +5,19 @@ from model.BiSeNetV2.model import BiSeNetV2
5
 
6
 
7
  # BiSeNet model loading function
8
- def loadBiSeNet(device: str = 'cpu') -> BiSeNet:
9
  """
10
  Load the BiSeNet model and move it to the specified device.
11
 
12
  Args:
13
  device (str): Device to load the model onto ('cpu' or 'cuda').
 
14
 
15
  Returns:
16
  model (BiSeNet): The loaded BiSeNet model.
17
  """
18
  model = BiSeNet(num_classes=19, context_path='resnet18').to(device)
19
- model.load_state_dict(torch.load('./weights/BiSeNet/weightADV.pth', map_location=device)['model_state_dict'])
20
  model.eval()
21
 
22
  return model
 
5
 
6
 
7
  # BiSeNet model loading function
8
+ def loadBiSeNet(device: str = 'cpu', weights:str='weightADV.pth') -> BiSeNet:
9
  """
10
  Load the BiSeNet model and move it to the specified device.
11
 
12
  Args:
13
  device (str): Device to load the model onto ('cpu' or 'cuda').
14
+ weights (str): weights file to be loaded
15
 
16
  Returns:
17
  model (BiSeNet): The loaded BiSeNet model.
18
  """
19
  model = BiSeNet(num_classes=19, context_path='resnet18').to(device)
20
+ model.load_state_dict(torch.load(f'./weights/BiSeNet/{weights}', map_location=device)['model_state_dict'])
21
  model.eval()
22
 
23
  return model
utils/imageHandling.py CHANGED
@@ -81,10 +81,11 @@ def legendHandling()->list[int, str, str]:
81
  - Color description (str)
82
  The list is sorted by class ID.
83
  """
84
- return sorted([[0, "road", "dark purple"], [1, "sidewalk", "light purple / pink"], [2, "building", "dark gray"], [3, "wall", "blue + grey"],
85
- [4, "fence", "beige"], [5, "pole", "grey"], [6, "traffic light", "orange"], [7, "traffic sign", "yellow"], [8, "vegetation", "dark green"],
86
- [9, "terrain", "light green"], [10, "sky", "blue"], [11, "person", "dark red"], [12, "rider", "light red"], [13, "car", "blue"],
87
- [14, "truck", "dark blue"], [15, "bus", "dark blue"], [16, "train", "blue + green"], [17, "motorcycle", "light blue"], [18, "bicycle", "velvet"]
 
88
  ], key=lambda x: x[0])
89
 
90
 
 
81
  - Color description (str)
82
  The list is sorted by class ID.
83
  """
84
+ return sorted([[0, "road", "dark purple", (128, 64, 128)], [1, "sidewalk", "light purple / pink", (244, 35, 232)], [2, "building", "dark gray", (70, 70, 70)],
85
+ [3, "wall", "blue + grey", (102, 102, 156)], [4, "fence", "beige", (190, 153, 153)], [5, "pole", "grey", (153, 153, 153)], [6, "traffic light", "orange", (250, 170, 30)],
86
+ [7, "traffic sign", "yellow", (220, 220, 0)], [8, "vegetation", "dark green", (107, 142, 35)], [9, "terrain", "light green", (152, 251, 152)], [10, "sky", "blue", (70, 130, 180)],
87
+ [11, "person", "dark red", (220, 20, 60)], [12, "rider", "light red", (255, 0, 0)], [13, "car", "blue", (0, 0, 142)], [14, "truck", "dark blue", (0, 0, 70)],
88
+ [15, "bus", "dark blue", (0, 60, 100)], [16, "train", "blue + green", (0, 80, 100)], [17, "motorcycle", "light blue", (0, 0, 230)], [18, "bicycle", "velvet", (119, 11, 32)]
89
  ], key=lambda x: x[0])
90
 
91
 
weights/BiSeNet/weight_Base.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e81645b4453ec62eda383d16e9cf279607181975dd648fdd354c78b050e7c8ca
3
+ size 50456242
weights/BiSeNet/{weightADV.pth β†’ weight_Best.pth} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:880db4160f20c87aecc13845ad691b1963fbce3d713b1dda1964457b9e0d8f0a
3
  size 121015606
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:372b7ecbd6c9e0fc6909370a58e3445552f3379db884ca984cd94d3de0cac66c
3
  size 121015606
weights/BiSeNetV2/weight_Base.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7ff01a531a19ec5d63837d48c4f8157d8b6fa0e04960035bdd26a40c39cb5b67
3
+ size 21146508
weights/BiSeNetV2/{weightADV.pth β†’ weight_Best.pth} RENAMED
File without changes