SegFormer / app.py
karolmajek's picture
nonsense
9dcc578
raw history blame
No virus
2.95 kB
from matplotlib.pyplot import axis
import gradio as gr
import requests
import numpy as np
from torch import nn
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
import requests
# Sample images offered as one-click examples in the demo UI.  They are
# fetched at start-up and stored next to the script under the file names
# the `examples=` list of the interface refers to.
url1 = 'https://cdn.pixabay.com/photo/2014/09/07/21/52/city-438393_1280.jpg'
url2 = 'https://cdn.pixabay.com/photo/2016/02/19/11/36/canal-1209808_1280.jpg'
for _url, _filename in ((url1, "city1.jpg"), (url2, "city2.jpg")):
    # A timeout keeps a stalled CDN connection from hanging start-up forever.
    r = requests.get(_url, allow_redirects=True, timeout=60)
    # The context manager closes the handle deterministically instead of
    # leaving it to the garbage collector.
    with open(_filename, 'wb') as _fh:
        _fh.write(r.content)
def cityscapes_palette():
    """Return the 19 RGB colours used to paint the segmentation mask.

    The result is a list of ``[R, G, B]`` lists indexed by class id: entry
    ``i`` is the colour for pixels labelled ``i``.  A new list is built on
    every call, so callers may modify what they receive.

    NOTE(review): the class names in the comments below follow the usual
    Cityscapes ordering and are an assumption -- confirm against the
    model's label mapping.
    """
    colours = (
        (128, 64, 128),   # 0  road
        (244, 35, 232),   # 1  sidewalk
        (70, 70, 70),     # 2  building
        (102, 102, 156),  # 3  wall
        (190, 153, 153),  # 4  fence
        (153, 153, 153),  # 5  pole
        (250, 170, 30),   # 6  traffic light
        (220, 220, 0),    # 7  traffic sign
        (107, 142, 35),   # 8  vegetation
        (152, 251, 152),  # 9  terrain
        (70, 130, 180),   # 10 sky
        (220, 20, 60),    # 11 person
        (255, 0, 0),      # 12 rider
        (0, 0, 142),      # 13 car
        (0, 0, 70),       # 14 truck
        (0, 60, 100),     # 15 bus
        (0, 80, 100),     # 16 train
        (0, 0, 230),      # 17 motorcycle
        (119, 11, 32),    # 18 bicycle
    )
    return [list(rgb) for rgb in colours]
# Checkpoint identifier passed to both `from_pretrained` calls below.
model_name = "nvidia/segformer-b5-finetuned-cityscapes-1024-1024"

# Both objects are created once at module level and reused by every call
# to `inference`.
model = SegformerForSemanticSegmentation.from_pretrained(model_name)
feature_extractor = SegformerFeatureExtractor.from_pretrained(model_name)
def _colorize_labels(labels, palette):
    """Turn an (H, W) array of class ids into an (H, W, 3) uint8 colour image.

    Ids without a palette entry are painted black.
    """
    palette = np.asarray(palette, dtype=np.uint8)
    # Pad the lookup table with black rows so an id beyond the palette
    # cannot raise IndexError.
    table_size = max(int(labels.max()) + 1, len(palette))
    lut = np.zeros((table_size, 3), dtype=np.uint8)
    lut[:len(palette)] = palette
    # One fancy-index lookup replaces a Python loop over every class.
    return lut[labels]


def _compose_mosaic(image_arr, color_seg):
    """Tile input, mask, a black filler and the 50/50 blend into a 2x2 grid."""
    blended = (image_arr * 0.5 + color_seg * 0.5).astype(np.uint8)
    top_row = np.concatenate((image_arr, color_seg), axis=1)
    bottom_row = np.concatenate((np.zeros_like(image_arr), blended), axis=1)
    return np.concatenate((top_row, bottom_row), axis=0)


def inference(image):
    """Segment *image* and return a 2x2 comparison mosaic.

    Args:
        image: PIL image supplied by the UI.  It is converted to RGB and
            resized to 1024x1024 before being fed to the model.

    Returns:
        A uint8 numpy array of shape (2048, 2048, 3) laid out as::

            resized input | colour mask
            black filler  | 50/50 blend of input and mask
    """
    # Local import: the module only imports `nn` from torch at file level.
    import torch

    # Force three channels so RGBA / greyscale uploads still line up with
    # the RGB colour mask when blending and concatenating.
    image = image.convert("RGB").resize((1024, 1024))
    image_arr = np.array(image)

    inputs = feature_extractor(images=image, return_tensors="pt")
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # First, rescale logits to the size of the (resized) image.
    logits = nn.functional.interpolate(
        outputs.logits.cpu(),
        size=image.size[::-1],  # PIL reports (width, height); flip to (height, width)
        mode='bilinear',
        align_corners=False,
    )

    # Second, apply argmax on the class dimension.
    seg = logits.argmax(dim=1)[0].numpy()

    color_seg = _colorize_labels(seg, cityscapes_palette())

    # Show image + mask.
    return _compose_mosaic(image_arr, color_seg)
# Static texts shown by the web UI: heading, blurb above the widgets, and
# an HTML footer linking to the paper and the model documentation.
title = "Transformers - SegFormer B5 @ 1024px"
description = (
    "demo for SegFormer. To use it, simply upload your image, or click one of "
    "the examples to load them. Read more at the links below.\n"
    "Model: nvidia/segformer-b5-finetuned-cityscapes-1024-1024"
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2105.15203'>"
    "SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers"
    "</a> | "
    "<a href='https://huggingface.co/transformers/model_doc/segformer.html#segformerforsemanticsegmentation'>"
    "Segformer page</a></p>"
)
# UI components: one image input requested in "pil" form and one image
# output declared as "numpy", matching what `inference` takes and returns.
input_image = gr.inputs.Image(type="pil", label="Input")
output_image = gr.outputs.Image(type="numpy", label="Output")

# One row per example; each row holds the value for the single input.
example_images = [
    ["city1.jpg"],
    ["city2.jpg"],
]

# Assemble the interface and start serving it.
demo = gr.Interface(
    fn=inference,
    inputs=[input_image],
    outputs=output_image,
    title=title,
    description=description,
    article=article,
    examples=example_images,
)
demo.launch()