File size: 2,588 Bytes
8f9088a e8c8cc8 8f9088a 4626ab4 8f9088a 432e4a1 8f9088a a604262 4626ab4 a604262 5de8b98 a604262 5de8b98 a604262 8f9088a 5de8b98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import os
import pandas as pd
import numpy as np
import torch
from PIL import Image
from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor
from torch import nn
import streamlit as st
# Streamlit front end: upload an image, run the SegFormer checkpoint on it,
# and show the predicted segmentation mask blended over the input together
# with the list of classes that appear in the prediction.
st.title('Semantic Segmentation using SegFormer')
raw_image = st.file_uploader('Raw Input Image')
if raw_image is not None:
    # Class names and colours come from the dataset's class dictionary.
    # Its header may be written as "name, r, g, b" (leading spaces), so strip
    # the column names rather than hard-coding the spaces.
    df = pd.read_csv('class_dict_seg.csv').rename(columns=str.strip)
    classes = df['name']
    palette = df[['r', 'g', 'b']].to_numpy()
    # Class ids are the DataFrame index values (presumably the default
    # 0..n-1 RangeIndex, i.e. CSV row order) — TODO confirm.
    id2label = classes.to_dict()
    label2id = {v: k for k, v in id2label.items()}

    # Force 3-channel RGB: a grayscale (2-D) or RGBA (4-channel) upload would
    # otherwise break the upsampling target size and the final 50/50 blend.
    try:
        image = np.asarray(Image.open(raw_image).convert('RGB'))
    except OSError:  # PIL's UnidentifiedImageError is a subclass of OSError
        st.error('Could not read the uploaded file as an image.')
        st.stop()

    with st.spinner('Loading Model...'):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): the model is reloaded on every Streamlit rerun; consider
        # caching it (e.g. st.cache_resource) if the installed Streamlit has it.
        # NOTE(review): with ignore_mismatched_sizes=True a class count that
        # differs from the checkpoint's head does not raise — confirm the CSV
        # matches the checkpoint's classes.
        model = SegformerForSemanticSegmentation.from_pretrained(
            "deep-learning-analytics/segformer_semantic_segmentation",
            ignore_mismatched_sizes=True,
            num_labels=len(id2label), id2label=id2label, label2id=label2id,
            reshape_last_stage=True)
        model = model.to(device)
        model.eval()

    with st.spinner('Preparing image...'):
        # Turn the image into model inputs (no random crop, no padding).
        feature_extractor_inference = SegformerFeatureExtractor(do_random_crop=False, do_pad=False)
        pixel_values = feature_extractor_inference(image, return_tensors="pt").pixel_values.to(device)

    with st.spinner('Running inference...'):
        # Inference only: no_grad avoids building an autograd graph.
        with torch.no_grad():
            # logits are of shape (batch_size, num_labels, height/4, width/4)
            outputs = model(pixel_values=pixel_values)

    with st.spinner('Postprocessing...'):
        logits = outputs.logits.cpu()
        # Rescale logits to the original (height, width) of the image.
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.shape[:2],
            mode='bilinear',
            align_corners=False)
        # Per-pixel class id for the single image in the batch.
        seg = upsampled_logits.argmax(dim=1)[0].numpy()

        # Colour every pixel with its class colour in one indexing step;
        # seg values lie in [0, num_labels) == [0, len(palette)).
        color_seg = palette[seg].astype(np.uint8)
        # Classes present in the prediction, in ascending class-id order.
        all_labels = [id2label[label] for label in np.unique(seg).tolist()]

    # Show image + mask blended 50/50. The palette is RGB (CSV columns r, g, b),
    # the image is RGB, and st.image expects RGB by default, so the mask is
    # deliberately NOT flipped to BGR.
    img = (image * 0.5 + color_seg * 0.5).astype(np.uint8)
    st.image(img, caption="Segmented Image")
    st.header("Predicted Labels")
    for idx, label in enumerate(all_labels, start=1):
        st.subheader(f'{idx}) {label}')