File size: 2,588 Bytes
8f9088a
 
 
 
e8c8cc8
8f9088a
 
 
 
 
4626ab4
8f9088a
 
 
 
 
 
 
 
432e4a1
 
8f9088a
a604262
 
 
 
 
 
 
 
 
 
 
 
 
 
4626ab4
 
 
a604262
 
 
 
 
 
 
 
 
 
 
5de8b98
a604262
 
5de8b98
 
a604262
 
8f9088a
 
 
5de8b98
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
import pandas as pd
import numpy as np
import torch
from PIL import Image
from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor
from torch import nn
import streamlit as st


# Streamlit re-executes this whole script on every interaction. Wrap the model
# loader in st.cache_resource (when this Streamlit version provides it) so the
# checkpoint is loaded once per server process rather than on every rerun;
# on older Streamlit versions the loader simply runs uncached.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_model(id2label):
    """Load the fine-tuned SegFormer checkpoint with one output per class.

    Args:
        id2label: mapping of class id -> class name; its size fixes the
            number of output channels of the segmentation head.

    Returns:
        ``(model, device)``: the model already moved to ``device`` (CUDA when
        available, else CPU) and switched to eval mode.
    """
    label2id = {v: k for k, v in id2label.items()}
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SegformerForSemanticSegmentation.from_pretrained(
        "deep-learning-analytics/segformer_semantic_segmentation",
        # The head is resized to len(id2label) classes, so checkpoint head
        # weights of a different shape must be allowed to mismatch.
        ignore_mismatched_sizes=True,
        num_labels=len(id2label),
        id2label=id2label,
        label2id=label2id,
        reshape_last_stage=True,
    )
    model = model.to(device)
    model.eval()
    return model, device


st.title('Semantic Segmentation using SegFormer')
raw_image = st.file_uploader('Raw Input Image')
if raw_image is not None:
    # Class table: row index is the class id; 'name' is the label and the
    # ' r', ' g', ' b' columns (leading spaces are part of the CSV headers)
    # give its RGB display colour.
    df = pd.read_csv('class_dict_seg.csv')
    id2label = df['name'].to_dict()
    palette = df[[' r', ' g', ' b']].to_numpy()

    # Force 3-channel RGB so grayscale / RGBA / palette uploads all get the
    # (H, W, 3) layout that the resize and the overlay below rely on.
    image = np.asarray(Image.open(raw_image).convert('RGB'))

    with st.spinner('Loading Model...'):
        model, device = _load_model(id2label)

    with st.spinner('Preparing image...'):
        # Prepare the image for the model (resize + normalise).
        # NOTE(review): do_random_crop / do_pad are options of older
        # SegformerFeatureExtractor releases — confirm against the pinned
        # transformers version.
        feature_extractor_inference = SegformerFeatureExtractor(do_random_crop=False, do_pad=False)
        pixel_values = feature_extractor_inference(image, return_tensors="pt").pixel_values.to(device)

    with st.spinner('Running inference...'):
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            outputs = model(pixel_values=pixel_values)

    with st.spinner('Postprocessing...'):
        # Logits are (batch_size, num_labels, height/4, width/4); upsample
        # them back to the original (height, width) before taking argmax.
        logits = outputs.logits.cpu()
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.shape[:2],
            mode='bilinear',
            align_corners=False,
        )
        seg = upsampled_logits.argmax(dim=1)[0].numpy()
        # Per-pixel colour lookup. The head has len(id2label) == len(palette)
        # outputs, so every predicted id is a valid palette row.
        color_seg = palette[seg].astype(np.uint8)
        # Classes present in the mask, in ascending class-id order.
        all_labels = [id2label[int(label)] for label in np.unique(seg)]

    # Show image + mask blended 50/50. Both arrays are RGB, which is what
    # st.image expects by default, so no channel reordering is applied.
    img = (image * 0.5 + color_seg * 0.5).astype(np.uint8)
    st.image(img, caption="Segmented Image")
    st.header("Predicted Labels")
    for idx, label in enumerate(all_labels, start=1):
        st.subheader(f'{idx}) {label}')