joshangngoching commited on
Commit
78e99f5
1 Parent(s): c97ff5c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -10
app.py CHANGED
@@ -4,10 +4,37 @@ from transformers import pipeline
4
  import numpy as np
5
  import cv2
6
  import matplotlib.cm as cm
 
 
 
7
 
8
- semantic_segmentation = pipeline("image-segmentation", "nvidia/segformer-b5-finetuned-cityscapes-1024-1024")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png"])
11
 
12
  def draw_masks_fromDict(image, results):
13
  masked_image = image.copy()
@@ -26,16 +53,53 @@ def draw_masks_fromDict(image, results):
26
  masked_image = masked_image.astype(np.uint8)
27
  return cv2.addWeighted(image, 0.3, masked_image, 0.7, 0)
28
 
 
 
 
 
29
 
30
  if uploaded_file is not None:
31
- image = Image.open(uploaded_file)
32
- st.image(image, caption='Uploaded Image.', use_column_width=True)
33
- st.write("")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- segmentation_results = semantic_segmentation(image)
36
- st.json(segmentation_results)
37
 
38
- image_with_masks = draw_masks_fromDict(np.array(image), segmentation_results)
 
 
 
 
 
39
 
40
- image_with_masks_pil = Image.fromarray(image_with_masks, 'RGB')
41
- st.image(image_with_masks_pil, caption='Segmented Image', use_column_width=True)
 
 
4
  import numpy as np
5
  import cv2
6
  import matplotlib.cm as cm
7
+ import time
8
+ import base64
9
+ from io import BytesIO
10
 
11
+ st.set_page_config(layout="wide")
12
+
13
+ with open("styles.css") as f:
14
+ st.markdown('<style>{}</style>'.format(f.read()), unsafe_allow_html=True)
15
+
16
+
17
+ st.markdown("<h1 class='title'>Segformer Semantic Segmentation</h1>", unsafe_allow_html=True)
18
+ st.markdown("""
19
+ <div class='text-center'>
20
+ This app uses the Segformer deep learning model to perform semantic segmentation on road images. The Transformer-based model is
21
+ trained on the CityScapes dataset which contains images of urban road scenes. Upload a
22
+ road scene and the app will return the image with semantic segmentation applied.
23
+ </div>
24
+ """, unsafe_allow_html=True)
25
+
26
+ group_members = ["Ang Ngo Ching, Josh Darren W.", "Bautista, Ryan Matthew M.", "Lacuesta, Angelo Giuseppe M.", "Reyes, Kenwin Hans", "Ting, Sidney Mitchell O."]
27
+
28
+
29
+ # model_versions = ["b1", "b2", "b3", "b4", "b5"]
30
+ # selected_model_version = st.selectbox("Select a model version:", model_versions)
31
+
32
+
33
+ semantic_segmentation = pipeline("image-segmentation", f"nvidia/segformer-b1-finetuned-cityscapes-1024-1024")
34
+
35
+ new_file_uploaded = False
36
+ uploaded_file = st.file_uploader("", type=["jpg", "png"])
37
 
 
38
 
39
  def draw_masks_fromDict(image, results):
40
  masked_image = image.copy()
 
53
  masked_image = masked_image.astype(np.uint8)
54
  return cv2.addWeighted(image, 0.3, masked_image, 0.7, 0)
55
 
56
+ col1, col2 = st.columns(2)
57
+
58
+ if "uploaded_file" not in st.session_state:
59
+ st.session_state.uploaded_file = None
60
 
61
  if uploaded_file is not None:
62
+ st.session_state.uploaded_file = uploaded_file
63
+
64
+ if st.session_state.uploaded_file is not None:
65
+ image = Image.open(st.session_state.uploaded_file)
66
+ col1, col2 = st.columns(2)
67
+
68
+ with col1:
69
+ st.image(image, caption='Uploaded Image.', use_column_width=True)
70
+
71
+
72
+ while True:
73
+ with st.spinner('Processing...'):
74
+ segmentation_results = semantic_segmentation(image)
75
+ image_with_masks = draw_masks_fromDict(np.array(image), segmentation_results)
76
+ image_with_masks_pil = Image.fromarray(image_with_masks, 'RGB')
77
+
78
+ with col2:
79
+ st.image(image_with_masks_pil, caption='Segmented Image.', use_column_width=True)
80
+
81
+ buffered = BytesIO()
82
+ image_with_masks_pil.save(buffered, format="PNG")
83
+ img_str = base64.b64encode(buffered.getvalue()).decode()
84
+ href = f'<a href="data:file/png;base64,{img_str}" download="segmented_{st.session_state.uploaded_file.name}">Download Segmented Image</a>'
85
+ st.markdown(href, unsafe_allow_html=True)
86
+
87
+ new_file_uploaded = False
88
+
89
+ while not new_file_uploaded:
90
+ time.sleep(1)
91
+
92
+
93
 
94
+ pdf_url = "https://arxiv.org/pdf/2105.15203.pdf"
 
95
 
96
+ st.markdown("""
97
+ <h3 class='text-center'>
98
+ Read more about the paper below👇
99
+ </h5>
100
+ """, unsafe_allow_html=True)
101
+ st.markdown(f'<iframe class="pdf" src={pdf_url}></iframe>', unsafe_allow_html=True)
102
 
103
+ st.markdown("Group Members:")
104
+ for member in group_members:
105
+ st.markdown("- " + member)