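# Gradio Space: semantic segmentation with a SegFormer model fine-tuned on the
# segments/sidewalk dataset. Upload an image to see the predicted segmentation
# overlaid on the input.
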
import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
from transformers import pipeline
import matplotlib.pyplot as plt
import io
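
# Image-segmentation pipeline backed by the fine-tuned SegFormer checkpoint.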
model_pipeline = pipeline("image-segmentation", model="sergiopaniego/segformer-b0-segments-sidewalk-finetuned")
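
# Class-index-to-label mapping for the 35 classes of the segments/sidewalk dataset.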
id2label = {0: 'unlabeled', 1: 'flat-road', 2: 'flat-sidewalk', 3: 'flat-crosswalk', 4: 'flat-cyclinglane', 5: 'flat-parkingdriveway', 6: 'flat-railtrack', 7: 'flat-curb', 8: 'human-person', 9: 'human-rider', 10: 'vehicle-car', 11: 'vehicle-truck', 12: 'vehicle-bus', 13: 'vehicle-tramtrain', 14: 'vehicle-motorcycle', 15: 'vehicle-bicycle', 16: 'vehicle-caravan', 17: 'vehicle-cartrailer', 18: 'construction-building', 19: 'construction-door', 20: 'construction-wall', 21: 'construction-fenceguardrail', 22: 'construction-bridge', 23: 'construction-tunnel', 24: 'construction-stairs', 25: 'object-pole', 26: 'object-trafficsign', 27: 'object-trafficlight', 28: 'nature-vegetation', 29: 'nature-terrain', 30: 'sky', 31: 'void-ground', 32: 'void-dynamic', 33: 'void-static', 34: 'void-unclear'}
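
# One RGB color per class, positionally aligned with id2label.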
sidewalk_palette = [
    [0, 0, 0],        # unlabeled
    [216, 82, 24],    # flat-road
    [255, 255, 0],    # flat-sidewalk
    [125, 46, 141],   # flat-crosswalk
    [118, 171, 47],   # flat-cyclinglane
    [161, 19, 46],    # flat-parkingdriveway
    [255, 0, 0],      # flat-railtrack
    [0, 128, 128],    # flat-curb
    [190, 190, 0],    # human-person
    [0, 255, 0],      # human-rider
    [0, 0, 255],      # vehicle-car
    [170, 0, 255],    # vehicle-truck
    [84, 84, 0],      # vehicle-bus
    [84, 170, 0],     # vehicle-tramtrain
    [84, 255, 0],     # vehicle-motorcycle
    [170, 84, 0],     # vehicle-bicycle
    [170, 170, 0],    # vehicle-caravan
    [170, 255, 0],    # vehicle-cartrailer
    [255, 84, 0],     # construction-building
    [255, 170, 0],    # construction-door
    [255, 255, 0],    # construction-wall
    [33, 138, 200],   # construction-fenceguardrail
    [0, 170, 127],    # construction-bridge
    [0, 255, 127],    # construction-tunnel
    [84, 0, 127],     # construction-stairs
    [84, 84, 127],    # object-pole
    [84, 170, 127],   # object-trafficsign
    [84, 255, 127],   # object-trafficlight
    [170, 0, 127],    # nature-vegetation
    [170, 84, 127],   # nature-terrain
    [170, 170, 127],  # sky
    [170, 255, 127],  # void-ground
    [255, 0, 127],    # void-dynamic
    [255, 84, 127],   # void-static
    [255, 170, 127],  # void-unclear
]
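
# Build a matplotlib figure showing the input image with a color-coded
# segmentation overlay (one color per predicted class).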
def get_output_figure(pil_img, results):
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)

    # Paint each predicted mask with its class color on a blank canvas.
    image_array = np.array(pil_img)
    segmentation_map = np.zeros_like(image_array)
    for result in results:
        mask = np.array(result['mask'])
        label = result['label']
        label_index = list(id2label.values()).index(label)
        color = sidewalk_palette[label_index]
        for c in range(3):
            segmentation_map[:, :, c] = np.where(mask, color[c], segmentation_map[:, :, c])

    # Blend the segmentation map over the original image.
    plt.imshow(segmentation_map, alpha=0.5)
    plt.axis('off')
    return plt.gcf()
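
# @spaces.GPU requests GPU time (ZeroGPU) for each call when running on Spaces.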
@spaces.GPU
def detect(image):
    results = model_pipeline(image)
    print(results)

    # Render the overlay figure and convert it back to a PIL image.
    output_figure = get_output_figure(image, results)
    buf = io.BytesIO()
    output_figure.savefig(buf, bbox_inches='tight')
    buf.seek(0)
    output_pil_img = Image.open(buf)
    plt.close(output_figure)  # close the figure so repeated calls don't accumulate open figures
    return output_pil_img
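
# Gradio UI: a short description plus an image-in / image-out interface backed by detect().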
with gr.Blocks() as demo:
    gr.Markdown("# Semantic segmentation with SegFormer fine-tuned on segments/sidewalk")
    gr.Markdown(
        """
        This application uses a fine-tuned SegFormer to perform semantic segmentation on an input image.
        This version was trained on the segments/sidewalk dataset.
        You can load an image and see the predicted segmentation.
        """
    )

    gr.Interface(
        fn=detect,
        inputs=gr.Image(label="Input image", type="pil"),
        outputs=[
            gr.Image(label="Output prediction", type="pil")
        ],
    )

demo.launch(show_error=True)