d4niel92's picture
Update app.py
960819e
raw
history blame
1.23 kB
import gradio as gr
import numpy as np
import torch
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation
# Hugging Face Hub repository that both the feature extractor and the model
# are loaded from.
_MODEL_REPO = "d4niel92/my-segmentation-model"

extractor = AutoFeatureExtractor.from_pretrained(_MODEL_REPO)
model = SegformerForSemanticSegmentation.from_pretrained(_MODEL_REPO)

# Display name for each predicted class index: entry i labels the pixels whose
# argmax over the model's logits is i.
class_labels = [
    'unlabeled', 'paved-area', 'dirt', 'grass', 'gravel', 'water',
    'rocks', 'pool', 'vegetation', 'roof', 'wall', 'window',
    'door', 'fence', 'fence-pole', 'person', 'dog', 'car',
    'bicycle', 'tree', 'bald-tree', 'ar-marker', 'obstacle', 'conflicting',
]
def _upscale(arr, factor):
    """Enlarge ``arr`` by an integer ``factor`` along its first two axes (nearest-neighbour)."""
    return np.repeat(np.repeat(arr, factor, axis=0), factor, axis=1)


def classify(im):
    """Segment ``im`` and return it with one binary mask per class label.

    The image is preprocessed by ``extractor`` and run through ``model``. The
    per-pixel argmax over the logits picks a class index, and every entry of
    ``class_labels`` gets a mask marking where its index was predicted. The
    image and all masks are enlarged 5x (nearest-neighbour) for display.

    Args:
        im: Image from the Gradio ``Image`` component (``type="pil"``),
            presumably a PIL image; anything ``np.asarray`` and ``extractor``
            both accept works.

    Returns:
        ``(image, annotations)`` for ``gr.AnnotatedImage``. ``image`` is the
        upscaled input as a NumPy array. ``annotations`` is a list of
        ``(mask, class_name)`` pairs, where each mask is a 0/1 int array with
        the same height and width as ``image``.
    """
    scale = 5  # display magnification applied to the image and every mask

    pixels = np.asarray(im)
    height, width = pixels.shape[:2]

    inputs = extractor(images=im, return_tensors="pt")
    # Inference only: don't build an autograd graph on every request.
    with torch.no_grad():
        logits = model(**inputs).logits
    # The logit grid's resolution is set by the extractor/model, not by the
    # input image. Resize it to the image's size so the masks line up with the
    # image pixel-for-pixel. When the sizes already match, this is a no-op.
    logits = torch.nn.functional.interpolate(
        logits, size=(height, width), mode="bilinear", align_corners=False
    )
    classes = logits[0].cpu().numpy().argmax(axis=0)

    annotations = [
        (_upscale((classes == c).astype(int), scale), class_name)
        for c, class_name in enumerate(class_labels)
    ]
    return _upscale(pixels, scale), annotations
# Web UI: a 128x128 PIL image goes in; the enlarged image and its per-class
# masks come back out for display as an annotated image.
interface = gr.Interface(
    fn=classify,
    inputs=gr.Image(type="pil", shape=(128, 128)),
    outputs=gr.AnnotatedImage(),
)
interface.launch()