Spaces:
Sleeping
Sleeping
Nunzio commited on
Commit Β·
05e5639
1
Parent(s): 9fe5dbe
added more weights
Browse files
app.py
CHANGED
|
@@ -13,8 +13,10 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
| 13 |
|
| 14 |
|
| 15 |
MODELS = {
|
| 16 |
-
"BISENET": loadBiSeNet(device),
|
| 17 |
-
"
|
|
|
|
|
|
|
| 18 |
}
|
| 19 |
|
| 20 |
image_list = loadPreloadedImages(gta_image_dir, city_image_dir, turin_image_dir)
|
|
@@ -68,8 +70,8 @@ with gr.Blocks(title="Semantic Segmentation Predictors") as demo:
|
|
| 68 |
with gr.Column():
|
| 69 |
image_input = gr.Image(type="pil", label="Upload image")
|
| 70 |
model_selector = gr.Radio(
|
| 71 |
-
choices=["BiSeNet", "BiSeNetV2"],
|
| 72 |
-
value="BiSeNet",
|
| 73 |
label="Select the real time segmentation model"
|
| 74 |
)
|
| 75 |
submit_btn = gr.Button("Run prediction")
|
|
@@ -82,10 +84,14 @@ with gr.Blocks(title="Semantic Segmentation Predictors") as demo:
|
|
| 82 |
for i in range(0, len(legend), 2):
|
| 83 |
with gr.Row():
|
| 84 |
with gr.Column(scale=1):
|
| 85 |
-
|
|
|
|
|
|
|
| 86 |
with gr.Column(scale=1):
|
| 87 |
if i + 1 < len(legend):
|
| 88 |
-
|
|
|
|
|
|
|
| 89 |
else:
|
| 90 |
gr.Markdown("") # Keeps spacing consistent if list is odd
|
| 91 |
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
MODELS = {
|
| 16 |
+
"BISENET-BASE": loadBiSeNet(device, 'weight_Base.pth'),
|
| 17 |
+
"BISENET-BEST": loadBiSeNet(device, 'weight_Best.pth'),
|
| 18 |
+
"BISENETV2-BASE": loadBiSeNetV2(device, 'weight_Base.pth'),
|
| 19 |
+
"BISENETV2-BEST": loadBiSeNetV2(device, 'weight_Best.pth')
|
| 20 |
}
|
| 21 |
|
| 22 |
image_list = loadPreloadedImages(gta_image_dir, city_image_dir, turin_image_dir)
|
|
|
|
| 70 |
with gr.Column():
|
| 71 |
image_input = gr.Image(type="pil", label="Upload image")
|
| 72 |
model_selector = gr.Radio(
|
| 73 |
+
choices=["BiSeNet-base", "BiSeNet-Best", "BiSeNetV2-base", "BiSeNetV2-Best"],
|
| 74 |
+
value="BiSeNet-base",
|
| 75 |
label="Select the real time segmentation model"
|
| 76 |
)
|
| 77 |
submit_btn = gr.Button("Run prediction")
|
|
|
|
| 84 |
for i in range(0, len(legend), 2):
|
| 85 |
with gr.Row():
|
| 86 |
with gr.Column(scale=1):
|
| 87 |
+
color_box0 = f"""<span style='display:inline-block; width:15px; height:15px;
|
| 88 |
+
background-color:rgb({legend[i][1][0]},{legend[i][1][1]},{legend[i][1][2]}); margin-left:6px; border:1px solid #000;'></span>"""
|
| 89 |
+
gr.Markdown(f"**{legend[i][1]}** β {legend[i][2]} {color_box0}")
|
| 90 |
with gr.Column(scale=1):
|
| 91 |
if i + 1 < len(legend):
|
| 92 |
+
color_box1 = f"""<span style='display:inline-block; width:15px; height:15px;
|
| 93 |
+
background-color:rgb({legend[i+1][1][0]},{legend[i+1][1][1]},{legend[i+1][1][2]}); margin-left:6px; border:1px solid #000;'></span>"""
|
| 94 |
+
gr.Markdown(f"**{legend[i+1][1]}** β {legend[i+1][2]} {color_box1}")
|
| 95 |
else:
|
| 96 |
gr.Markdown("") # Keeps spacing consistent if list is odd
|
| 97 |
|
model/modelLoading.py
CHANGED
|
@@ -5,18 +5,19 @@ from model.BiSeNetV2.model import BiSeNetV2
|
|
| 5 |
|
| 6 |
|
| 7 |
# BiSeNet model loading function
|
| 8 |
-
def loadBiSeNet(device: str = 'cpu') -> BiSeNet:
|
| 9 |
"""
|
| 10 |
Load the BiSeNet model and move it to the specified device.
|
| 11 |
|
| 12 |
Args:
|
| 13 |
device (str): Device to load the model onto ('cpu' or 'cuda').
|
|
|
|
| 14 |
|
| 15 |
Returns:
|
| 16 |
model (BiSeNet): The loaded BiSeNet model.
|
| 17 |
"""
|
| 18 |
model = BiSeNet(num_classes=19, context_path='resnet18').to(device)
|
| 19 |
-
model.load_state_dict(torch.load('./weights/BiSeNet/
|
| 20 |
model.eval()
|
| 21 |
|
| 22 |
return model
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
# BiSeNet model loading function
|
| 8 |
+
def loadBiSeNet(device: str = 'cpu', weights:str='weightADV.pth') -> BiSeNet:
|
| 9 |
"""
|
| 10 |
Load the BiSeNet model and move it to the specified device.
|
| 11 |
|
| 12 |
Args:
|
| 13 |
device (str): Device to load the model onto ('cpu' or 'cuda').
|
| 14 |
+
weights (str): weights file to be loaded
|
| 15 |
|
| 16 |
Returns:
|
| 17 |
model (BiSeNet): The loaded BiSeNet model.
|
| 18 |
"""
|
| 19 |
model = BiSeNet(num_classes=19, context_path='resnet18').to(device)
|
| 20 |
+
model.load_state_dict(torch.load(f'./weights/BiSeNet/{weights}', map_location=device)['model_state_dict'])
|
| 21 |
model.eval()
|
| 22 |
|
| 23 |
return model
|
utils/imageHandling.py
CHANGED
|
@@ -81,10 +81,11 @@ def legendHandling()->list[int, str, str]:
|
|
| 81 |
- Color description (str)
|
| 82 |
The list is sorted by class ID.
|
| 83 |
"""
|
| 84 |
-
return sorted([[0, "road", "dark purple"], [1, "sidewalk", "light purple / pink"], [2, "building", "dark gray"
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
|
|
|
| 88 |
], key=lambda x: x[0])
|
| 89 |
|
| 90 |
|
|
|
|
| 81 |
- Color description (str)
|
| 82 |
The list is sorted by class ID.
|
| 83 |
"""
|
| 84 |
+
return sorted([[0, "road", "dark purple", (128, 64, 128)], [1, "sidewalk", "light purple / pink", (244, 35, 232)], [2, "building", "dark gray", (70, 70, 70)],
|
| 85 |
+
[3, "wall", "blue + grey", (102, 102, 156)], [4, "fence", "beige", (190, 153, 153)], [5, "pole", "grey", (153, 153, 153)], [6, "traffic light", "orange", (250, 170, 30)],
|
| 86 |
+
[7, "traffic sign", "yellow", (220, 220, 0)], [8, "vegetation", "dark green", (107, 142, 35)], [9, "terrain", "light green", (152, 251, 152)], [10, "sky", "blue", (70, 130, 180)],
|
| 87 |
+
[11, "person", "dark red", (220, 20, 60)], [12, "rider", "light red", (255, 0, 0)], [13, "car", "blue", (0, 0, 142)], [14, "truck", "dark blue", (0, 0, 70)],
|
| 88 |
+
[15, "bus", "dark blue", (0, 60, 100)], [16, "train", "blue + green", (0, 80, 100)], [17, "motorcycle", "light blue", (0, 0, 230)], [18, "bicycle", "velvet", (119, 11, 32)]
|
| 89 |
], key=lambda x: x[0])
|
| 90 |
|
| 91 |
|
weights/BiSeNet/weight_Base.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e81645b4453ec62eda383d16e9cf279607181975dd648fdd354c78b050e7c8ca
|
| 3 |
+
size 50456242
|
weights/BiSeNet/{weightADV.pth β weight_Best.pth}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 121015606
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:372b7ecbd6c9e0fc6909370a58e3445552f3379db884ca984cd94d3de0cac66c
|
| 3 |
size 121015606
|
weights/BiSeNetV2/weight_Base.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7ff01a531a19ec5d63837d48c4f8157d8b6fa0e04960035bdd26a40c39cb5b67
|
| 3 |
+
size 21146508
|
weights/BiSeNetV2/{weightADV.pth β weight_Best.pth}
RENAMED
|
File without changes
|