Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
import spaces | |
import torch | |
import numpy as np | |
from PIL import Image | |
from transformers import pipeline | |
import matplotlib.pyplot as plt | |
import io | |
# Image-segmentation pipeline for the SegFormer checkpoint fine-tuned on
# segments/sidewalk. It is built once at module import, and every call to
# detect() reuses this same object.
model_pipeline = pipeline(
    task="image-segmentation",
    model="sergiopaniego/segformer-b0-segments-sidewalk-finetuned",
)
# Class id -> class name for the segments/sidewalk label set.
# Ids are consecutive integers starting at 0, in list order below.
# get_output_figure uses a class's position here to pick its color from
# sidewalk_palette, so this order must stay aligned with that palette.
id2label = dict(enumerate([
    'unlabeled',
    'flat-road',
    'flat-sidewalk',
    'flat-crosswalk',
    'flat-cyclinglane',
    'flat-parkingdriveway',
    'flat-railtrack',
    'flat-curb',
    'human-person',
    'human-rider',
    'vehicle-car',
    'vehicle-truck',
    'vehicle-bus',
    'vehicle-tramtrain',
    'vehicle-motorcycle',
    'vehicle-bicycle',
    'vehicle-caravan',
    'vehicle-cartrailer',
    'construction-building',
    'construction-door',
    'construction-wall',
    'construction-fenceguardrail',
    'construction-bridge',
    'construction-tunnel',
    'construction-stairs',
    'object-pole',
    'object-trafficsign',
    'object-trafficlight',
    'nature-vegetation',
    'nature-terrain',
    'sky',
    'void-ground',
    'void-dynamic',
    'void-static',
    'void-unclear',
]))
# RGB color (0-255 per channel) used to paint each class in the overlay.
# Entry i is the color for class id i in id2label, so this list must keep the
# same length and order as id2label.
# Every entry must be unique. Otherwise two classes render identically and
# cannot be told apart in the output image.
sidewalk_palette = [
    [0, 0, 0],        # 0  unlabeled
    [216, 82, 24],    # 1  flat-road
    [255, 255, 0],    # 2  flat-sidewalk
    [125, 46, 141],   # 3  flat-crosswalk
    [118, 171, 47],   # 4  flat-cyclinglane
    [161, 19, 46],    # 5  flat-parkingdriveway
    [255, 0, 0],      # 6  flat-railtrack
    [0, 128, 128],    # 7  flat-curb
    [190, 190, 0],    # 8  human-person
    [0, 255, 0],      # 9  human-rider
    [0, 0, 255],      # 10 vehicle-car
    [170, 0, 255],    # 11 vehicle-truck
    [84, 84, 0],      # 12 vehicle-bus
    [84, 170, 0],     # 13 vehicle-tramtrain
    [84, 255, 0],     # 14 vehicle-motorcycle
    [170, 84, 0],     # 15 vehicle-bicycle
    [170, 170, 0],    # 16 vehicle-caravan
    [170, 255, 0],    # 17 vehicle-cartrailer
    [255, 84, 0],     # 18 construction-building
    [255, 170, 0],    # 19 construction-door
    [102, 102, 156],  # 20 construction-wall (was [255, 255, 0], the same yellow as flat-sidewalk)
    [33, 138, 200],   # 21 construction-fenceguardrail
    [0, 170, 127],    # 22 construction-bridge
    [0, 255, 127],    # 23 construction-tunnel
    [84, 0, 127],     # 24 construction-stairs
    [84, 84, 127],    # 25 object-pole
    [84, 170, 127],   # 26 object-trafficsign
    [84, 255, 127],   # 27 object-trafficlight
    [170, 0, 127],    # 28 nature-vegetation
    [170, 84, 127],   # 29 nature-terrain
    [170, 170, 127],  # 30 sky
    [170, 255, 127],  # 31 void-ground
    [255, 0, 127],    # 32 void-dynamic
    [255, 84, 127],   # 33 void-static
    [255, 170, 127],  # 34 void-unclear
]
def get_output_figure(pil_img, results):
    """Draw *pil_img* with the predicted segmentation overlaid at 50% opacity.

    Args:
        pil_img: The image that was passed to the segmentation pipeline
            (anything ``plt.imshow`` and ``np.asarray`` accept, typically a
            PIL image).
        results: Iterable of dicts as produced by ``model_pipeline``. Each dict
            is read for ``'label'`` (one of the names in ``id2label``) and
            ``'mask'``. The mask is converted with ``np.asarray``, and its
            nonzero pixels are painted with the label's palette color. Masks
            are assumed to have the image's height and width.

    Returns:
        The matplotlib figure that was drawn on. It is created through
        ``pyplot``, so the caller is responsible for closing it.

    Raises:
        ValueError: If a result's label is not one of ``id2label``'s values.
    """
    # Reverse lookup built once per call. This replaces rebuilding
    # list(id2label.values()) and calling .index() for every mask.
    label_to_id = {name: class_id for class_id, name in id2label.items()}

    fig = plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)

    # Build the overlay as RGB uint8 no matter what mode the input image is in.
    # The old zeros_like() on an RGBA image left the overlay's alpha channel at
    # 0, so it would draw as transparent. On a 2-D (grayscale) image, the
    # [:, :, c] indexing raised IndexError.
    height, width = np.asarray(pil_img).shape[:2]
    segmentation_map = np.zeros((height, width, 3), dtype=np.uint8)
    for result in results:
        label = result['label']
        if label not in label_to_id:
            raise ValueError(f"{label!r} is not a class name in id2label")
        mask = np.asarray(result['mask']).astype(bool)
        # Where masks overlap, later results overwrite earlier ones.
        segmentation_map[mask] = sidewalk_palette[label_to_id[label]]

    plt.imshow(segmentation_map, alpha=0.5)
    plt.axis('off')
    return fig
def detect(image):
    """Segment *image* and return a rendering of the prediction overlaid on it.

    Args:
        image: Image supplied by the ``gr.Image(type="pil")`` input. It is
            ``None`` when the user submits without uploading an image.

    Returns:
        A PIL image decoded from the figure saved by ``savefig`` (PNG under
        matplotlib's default ``savefig.format``).

    Raises:
        gr.Error: If no image was provided. Gradio shows this message to the
            user.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    results = model_pipeline(image)
    output_figure = get_output_figure(image, results)
    buf = io.BytesIO()
    try:
        output_figure.savefig(buf, bbox_inches='tight')
    finally:
        # Figures created through pyplot stay registered (and in memory)
        # until they are closed. Without this, every request leaked a figure.
        plt.close(output_figure)
    buf.seek(0)
    output_pil_img = Image.open(buf)
    # Image.open is lazy. Decode now so the result is self-contained.
    output_pil_img.load()
    return output_pil_img
# Gradio UI: a title, a short description, and an Interface that feeds the
# uploaded image to detect() and shows the returned overlay image.
with gr.Blocks() as demo:
    gr.Markdown("# Semantic segmentation with SegFormer fine-tuned on segments/sidewalk")
    gr.Markdown(
        """
        This application uses a fine-tuned SegFormer for semantic segmentation of an input image.
        This version was trained on the segments/sidewalk dataset.
        You can upload an image and see the predicted segmentation.
        """
    )
    gr.Interface(
        fn=detect,
        inputs=gr.Image(label="Input image", type="pil"),
        outputs=[
            gr.Image(label="Output prediction", type="pil"),
        ],
    )
# show_error=True passes exceptions raised in detect() to the browser.
demo.launch(show_error=True)