# NOTE(review): the three lines below are Hugging Face Spaces page text that was
# accidentally pasted into the source file; commented out so the module imports.
# jeffliulab's picture
# Initial deploy
# 6cf5f32 verified
"""
Semantic Segmentation — Pixel-level classification with DeepLabV3
Courses: 100 ch3, 360 ch4
"""
import numpy as np
import torch
import torchvision.models.segmentation as seg_models
import torchvision.transforms as T
import gradio as gr
from PIL import Image
# CPU-only inference — presumably chosen for a lightweight/free hosting tier.
device = torch.device("cpu")

# Load DeepLabV3 with MobileNetV3 backbone (lightweight).
# NOTE: first call downloads the pretrained weights; eval() disables
# dropout/batch-norm updates for inference.
model = seg_models.deeplabv3_mobilenet_v3_large(
    weights=seg_models.DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT
).eval().to(device)

# ImageNet mean/std normalization matching the torchvision pretrained weights.
preprocess = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# PASCAL VOC class names (21 classes).
# Index in this list == class id produced by the model's argmax output.
CLASS_NAMES = [
    "background", "aeroplane", "bicycle", "bird", "boat",
    "bottle", "bus", "car", "cat", "chair",
    "cow", "dining table", "dog", "horse", "motorbike",
    "person", "potted plant", "sheep", "sofa", "train",
    "tv/monitor",
]

# Color palette for each class: row i is the RGB color used to paint pixels
# predicted as CLASS_NAMES[i]. uint8 so it can be indexed straight into an image.
PALETTE = np.array([
    [0, 0, 0],        # background
    [128, 0, 0],      # aeroplane
    [0, 128, 0],      # bicycle
    [128, 128, 0],    # bird
    [0, 0, 128],      # boat
    [128, 0, 128],    # bottle
    [0, 128, 128],    # bus
    [128, 128, 128],  # car
    [64, 0, 0],       # cat
    [192, 0, 0],      # chair
    [64, 128, 0],     # cow
    [192, 128, 0],    # dining table
    [64, 0, 128],     # dog
    [192, 0, 128],    # horse
    [64, 128, 128],   # motorbike
    [192, 128, 128],  # person
    [0, 64, 0],       # potted plant
    [128, 64, 0],     # sheep
    [0, 192, 0],      # sofa
    [128, 192, 0],    # train
    [0, 64, 128],     # tv/monitor
], dtype=np.uint8)
def segment(image: Image.Image, display_mode: str):
    """Run DeepLabV3 semantic segmentation on an uploaded image.

    Args:
        image: Input PIL image (any mode; converted to RGB), or None when
            the user has not uploaded anything.
        display_mode: One of "Overlay", "Segmentation Only", "Side by Side".

    Returns:
        (result_image, segmentation_map, legend_markdown); all-empty
        (None, None, "") when *image* is None.
    """
    if image is None:
        return None, None, ""

    img = image.convert("RGB")
    w, h = img.size

    # Inference: model emits per-pixel class logits under key "out";
    # argmax over the class dimension yields a (H, W) map of class ids.
    inp = preprocess(img).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(inp)["out"]
    # squeeze(0) (not squeeze()) so a degenerate 1-pixel dimension in the
    # prediction can never be collapsed by accident.
    pred = output.argmax(1).squeeze(0).cpu().numpy()

    # Resize prediction back to the original resolution. NEAREST keeps class
    # ids intact — any interpolation would blend ids into invalid values.
    pred_resized = np.array(
        Image.fromarray(pred.astype(np.uint8)).resize((w, h), Image.NEAREST)
    )

    # Map class ids to RGB colors and blend 50/50 with the input for overlay.
    seg_color = PALETTE[pred_resized]
    img_np = np.array(img)
    overlay = (img_np * 0.5 + seg_color * 0.5).astype(np.uint8)

    # Build the markdown legend of detected (non-background) classes with
    # their palette color and pixel coverage percentage.
    legend = "**Detected classes:**\n\n"
    found_any = False
    for c in np.unique(pred_resized):
        if c == 0:  # background — excluded from the legend
            continue
        found_any = True
        color = PALETTE[c]
        pixel_pct = np.sum(pred_resized == c) / pred_resized.size * 100
        color_hex = f"#{color[0]:02x}{color[1]:02x}{color[2]:02x}"
        legend += f"- <span style='color:{color_hex};font-weight:bold;'>██</span> {CLASS_NAMES[c]} ({pixel_pct:.1f}%)\n"
    if not found_any:
        legend += "- No objects detected (background only)"

    if display_mode == "Segmentation Only":
        return seg_color, seg_color, legend
    if display_mode == "Side by Side":
        # Bug fix: this branch previously returned the same images as
        # "Overlay", so the mode did nothing. Now the original image and the
        # colored segmentation map are concatenated horizontally.
        side_by_side = np.concatenate([img_np, seg_color], axis=1)
        return side_by_side, seg_color, legend
    return overlay, seg_color, legend  # "Overlay" (default)
# --- Gradio UI ---
# NOTE(review): indentation was mangled in the original paste; the layout below
# (legend under the image row, click/examples at Blocks level) is the most
# plausible reconstruction — confirm against the deployed Space.
with gr.Blocks(title="Semantic Segmentation") as demo:
    gr.Markdown(
        "# Semantic Segmentation\n"
        "Upload an image to see pixel-level classification (21 PASCAL VOC classes).\n"
        "*Courses: 100 Deep Learning ch3, 360 Autonomous Driving ch4*"
    )
    with gr.Row():
        with gr.Column(scale=1):
            # Input side: image upload, display-mode selector, run button.
            input_image = gr.Image(type="pil", label="Upload Image")
            mode = gr.Radio(
                ["Overlay", "Segmentation Only", "Side by Side"],
                value="Overlay",
                label="Display Mode",
            )
            btn = gr.Button("Segment", variant="primary")
        with gr.Column(scale=2):
            with gr.Row():
                # Output side: main result plus raw colored segmentation map.
                overlay_out = gr.Image(label="Result")
                seg_out = gr.Image(label="Segmentation Map")
            # Markdown legend of detected classes, filled in by segment().
            legend_md = gr.Markdown()
    # Wire the button to segment(): inputs (image, mode) -> three outputs.
    btn.click(segment, [input_image, mode], [overlay_out, seg_out, legend_md])
    gr.Examples(
        examples=[
            ["examples/street.jpg", "Overlay"],
            ["examples/room.jpg", "Side by Side"],
        ],
        inputs=[input_image, mode],
    )

if __name__ == "__main__":
    demo.launch()