joshangngoching commited on
Commit
0664e3a
1 Parent(s): ad646b5

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +113 -0
main.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ from transformers import pipeline
4
+ import numpy as np
5
+ import cv2
6
+ import matplotlib.cm as cm
7
+ import time
8
+ import base64
9
+ from io import BytesIO
10
+
11
+ st.set_page_config(layout="wide")
12
+
13
+ with open("styles.css") as f:
14
+ st.markdown('<style>{}</style>'.format(f.read()), unsafe_allow_html=True)
15
+
16
+
17
+ st.markdown("<h1 class='title'>Segformer Semantic Segmentation</h1>", unsafe_allow_html=True)
18
+ st.markdown("""
19
+ <div class='text-center'>
20
+ This app uses the Segformer deep learning model to perform semantic segmentation on <b style='color: red; font-weight: 40px;'>road images</b>. The Transformer-based model is
21
+ trained on the CityScapes dataset which contains images of urban road scenes. Upload a
22
+ road scene and the app will return the image with semantic segmentation applied.
23
+ </div>
24
+ """, unsafe_allow_html=True)
25
+
26
+ group_members = ["Ang Ngo Ching, Josh Darren W.", "Bautista, Ryan Matthew M.", "Lacuesta, Angelo Giuseppe M.", "Reyes, Kenwin Hans", "Ting, Sidney Mitchell O."]
27
+
28
+
29
+ # model_versions = ["b1", "b2", "b3", "b4", "b5"]
30
+ # selected_model_version = st.selectbox("Select a model version:", model_versions)
31
+
32
+ st.markdown("""
33
+ <h3 class='text-center' style='margin-top: 0.5rem;'>
34
+ ℹ️ You can get sample images of road scenes in this <a href='https://drive.google.com/drive/folders/1202EMeXAHnN18NuhJKWWme34vg0V-svY?fbclid=IwAR3kyjGS895nOBKi9aGT_P4gLX9jvSNrV5b5y3GH49t2Pvg2sZSRA58LLxs' target='_blank'>link</a>.
35
+ </h3>""", unsafe_allow_html=True)
36
+ semantic_segmentation = pipeline("image-segmentation", f"nvidia/segformer-b1-finetuned-cityscapes-1024-1024")
37
+
38
+ new_file_uploaded = False
39
+ uploaded_file = st.file_uploader("", type=["jpg", "png"])
40
+ label_colors = {}
41
+
42
+ def draw_masks_fromDict(image, results):
43
+ masked_image = image.copy()
44
+
45
+ colormap = cm.get_cmap('nipy_spectral')
46
+
47
+ for i, result in enumerate(results):
48
+ mask = np.array(result['mask'])
49
+ mask = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
50
+
51
+ color = colormap(i / len(results))[:3]
52
+ color = tuple(int(c * 255) for c in color)
53
+
54
+ masked_image = np.where(mask, color, masked_image)
55
+
56
+ label_colors[color] = result['label']
57
+
58
+ masked_image = masked_image.astype(np.uint8)
59
+ return cv2.addWeighted(image, 0.3, masked_image, 0.7, 0)
60
+
61
+ col1, col2 = st.columns(2)
62
+
63
+ if "uploaded_file" not in st.session_state:
64
+ st.session_state.uploaded_file = None
65
+
66
+ if uploaded_file is not None:
67
+ st.session_state.uploaded_file = uploaded_file
68
+
69
+ if st.session_state.uploaded_file is not None:
70
+ image = Image.open(st.session_state.uploaded_file)
71
+ col1, col2 = st.columns(2)
72
+
73
+ with col1:
74
+ st.image(image, caption='Uploaded Image.', use_column_width=True)
75
+
76
+
77
+ while True:
78
+ with st.spinner('Processing...'):
79
+ segmentation_results = semantic_segmentation(image)
80
+ image_with_masks = draw_masks_fromDict(np.array(image)[:, :, :3], segmentation_results)
81
+ image_with_masks_pil = Image.fromarray(image_with_masks, 'RGB')
82
+
83
+ with col2:
84
+ st.image(image_with_masks_pil, caption='Segmented Image.', use_column_width=True)
85
+
86
+ st.markdown("**Labels:**")
87
+ for color, label in label_colors.items():
88
+ st.markdown(f"<div style='display: flex; align-items: center; margin-bottom: 0.5rem;'><span style='display: inline-block; width: 20px; height: 20px; background-color: rgb{color}; margin-right: 1rem; border-radius: 10px;'></span><p style='margin: 0;'>{label}</p></div>", unsafe_allow_html=True)
89
+
90
+ buffered = BytesIO()
91
+ image_with_masks_pil.save(buffered, format="PNG")
92
+ img_str = base64.b64encode(buffered.getvalue()).decode()
93
+ href = f'<a href="data:file/png;base64,{img_str}" download="segmented_{st.session_state.uploaded_file.name}">Download Segmented Image</a>'
94
+ st.markdown(href, unsafe_allow_html=True)
95
+
96
+ new_file_uploaded = False
97
+
98
+ while not new_file_uploaded:
99
+ time.sleep(1)
100
+
101
+
102
+ pdf_url = "https://arxiv.org/pdf/2105.15203.pdf"
103
+
104
+ st.markdown("""
105
+ <h3 style='text-align: center; margin-top: 2rem;'>
106
+ Read more about the paper below👇
107
+ </h5>
108
+ """, unsafe_allow_html=True)
109
+ st.markdown(f'<iframe class="pdf" src={pdf_url}></iframe>', unsafe_allow_html=True)
110
+
111
+ st.markdown("Group Members:")
112
+ for member in group_members:
113
+ st.markdown("- " + member)