# hshetty's picture
# Update app.py
# 7623705
# raw
# history blame
# No virus
# 1.53 kB
import matplotlib.pyplot as plt
import gradio as gr
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation
import torch
import numpy as np
# Image preprocessing (resize/normalise) comes from the base SegFormer-b0
# checkpoint, while the segmentation weights come from the fine-tuned model.
# Both are loaded once at import time and shared by every request.
extractor = AutoFeatureExtractor.from_pretrained(
    "nvidia/segformer-b0-finetuned-ade-512-512"
)
model = SegformerForSemanticSegmentation.from_pretrained(
    "hshetty/my-segmentation-model"
)
# Fixed RGB colours for the first five class ids (row index == class id).
CLASS_COLORS = np.array(
    [[128, 0, 0], [128, 128, 0], [0, 0, 128], [128, 0, 128], [0, 0, 0]],
    dtype=np.uint8,
)


def class_palette(num_classes):
    """Return an RGB palette with at least ``num_classes`` rows.

    The first rows are always ``CLASS_COLORS``, so class ids 0-4 keep their
    original colours. If the model can predict more classes than that, the
    extra rows are filled with colours drawn from a fixed-seed generator, so
    a given class id gets the same colour on every call.

    Args:
        num_classes: Number of classes the model can predict.

    Returns:
        A ``(max(num_classes, 5), 3)`` ``uint8`` array.
    """
    extra = num_classes - len(CLASS_COLORS)
    if extra <= 0:
        return CLASS_COLORS
    rng = np.random.default_rng(0)  # fixed seed -> stable colours across calls
    extra_colors = rng.integers(
        0, 255, size=(extra, 3), dtype=np.uint8, endpoint=True
    )
    return np.concatenate([CLASS_COLORS, extra_colors])


def classify(im):
    """Segment an image and return a colour-coded per-pixel class mask.

    Uses the module-level ``extractor`` and ``model``.

    Args:
        im: Input image as delivered by the Gradio ``Image(type="pil")``
            input component, or ``None`` when no image is provided.

    Returns:
        ``None`` if ``im`` is ``None``. Otherwise an ``(H, W, 3)`` ``uint8``
        RGB array in which each pixel is coloured by its predicted class.
        ``H``/``W`` are the spatial size of the model's logits.
        NOTE(review): this may be smaller than the input image; confirm
        whether upsampling to the input size is wanted.
    """
    if im is None:
        return None
    inputs = extractor(images=im, return_tensors="pt")
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # Assumes logits are (batch, num_classes, H, W). Take the first batch
    # item and argmax over the class axis to get an (H, W) map of class ids.
    classes = logits[0].cpu().numpy().argmax(axis=0)
    # Size the palette to the model's class count so every id is indexable.
    palette = class_palette(logits.shape[1])
    return palette[classes]
# Gradio UI: upload (or pick an example) street-scene image and display the
# colour-coded segmentation mask returned by `classify`.
interface = gr.Interface(
    fn=classify,
    inputs=gr.inputs.Image(type="pil"),
    # The result is displayed, so it must be an *output* component.
    outputs=gr.outputs.Image(type="pil"),
    title="Self Driving Car App- Semantic Segmentation",
    description="This is a self driving car app using Semantic Segmentation as part of week 2 end to end vision application project on CoRise.",
    examples=[
        "https://datasets-server.huggingface.co/assets/segments/sidewalk-semantic/--/segments--sidewalk-semantic-2/train/3/pixel_values/image.jpg",
        "https://datasets-server.huggingface.co/assets/segments/sidewalk-semantic/--/segments--sidewalk-semantic-2/train/5/pixel_values/image.jpg",
        "https://datasets-server.huggingface.co/assets/segments/sidewalk-semantic/--/segments--sidewalk-semantic-2/train/20/pixel_values/image.jpg",
    ],
)
# Start the web server (the app is run as a script).
interface.launch()