zenes commited on
Commit
035e155
1 Parent(s): fa8a637

Add streamlit application

Browse files

Signed-off-by: airh4ck <dudnikoff98@gmail.com>

Files changed (3) hide show
  1. app.py +5 -0
  2. segmentation.py +44 -0
  3. streamlit_config.py +66 -0
app.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import streamlit_config as st
2
+
3
+ if __name__ == "__main__":
4
+ st.init()
5
+ st.run()
segmentation.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
2
+ from transformers.modeling_outputs import SemanticSegmenterOutput
3
+ from transformers.feature_extraction_utils import BatchFeature
4
+ from PIL import Image
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ import seaborn as sns
9
+ import itertools
10
+
11
+
12
+ def create_model():
13
+ return SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
14
+
15
+
16
+ def create_feature_extractor():
17
+ return SegformerFeatureExtractor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
18
+
19
+
20
+ def postprocess(masks, height, width):
21
+ masks = F.interpolate(masks, (height, width))
22
+
23
+ label_per_pixel = torch.argmax(
24
+ masks.squeeze(), dim=0).detach().numpy()
25
+ color_mask = np.zeros(label_per_pixel.shape + (3,))
26
+ palette = itertools.cycle(sns.color_palette())
27
+
28
+ for lbl in np.unique(label_per_pixel):
29
+ color_mask[label_per_pixel == lbl, :] = np.asarray(next(palette)) * 255
30
+
31
+ return color_mask
32
+
33
+
34
+ def segment(image: Image, model, feature_extractor) -> torch.Tensor:
35
+ inputs = feature_extractor(
36
+ images=image, return_tensors="pt")
37
+ outputs = model(**inputs)
38
+ masks = outputs.logits
39
+
40
+ color_mask = postprocess(masks, image.height, image.width)
41
+ pred_img = np.array(image.convert('RGB')) * 0.25 + color_mask * 0.75
42
+ pred_img = pred_img.astype(np.uint8)
43
+
44
+ return pred_img
streamlit_config.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import re
4
+ from io import BytesIO
5
+
6
+ import segmentation
7
+
8
+
9
+ def init():
10
+ st.set_page_config(page_title="Semantic image segmentation")
11
+ st.session_state["model"] = segmentation.create_model()
12
+ st.session_state["feature_extractor"] = segmentation.create_feature_extractor()
13
+
14
+
15
+ @st.experimental_memo(show_spinner=False)
16
+ def process_file(file):
17
+ return segmentation.segment(
18
+ Image.open(file),
19
+ st.session_state["model"],
20
+ st.session_state["feature_extractor"]
21
+ )
22
+
23
+
24
+ def get_uploaded_file():
25
+ return st.file_uploader(
26
+ label="Choose a file",
27
+ type=["png", "jpg", "jpeg"],
28
+ )
29
+
30
+
31
+ def download_button(file, name, format):
32
+ st.download_button(
33
+ label="Download processed image",
34
+ data=file,
35
+ file_name=name,
36
+ mime="image/" + format
37
+ )
38
+
39
+
40
+ def run():
41
+ st.title("Semantic image segmentation")
42
+ st.subheader("Upload your image and get an image with segmentation")
43
+
44
+ file = get_uploaded_file()
45
+ if not file:
46
+ return
47
+
48
+ placeholder = st.empty()
49
+ placeholder.info(
50
+ "Processing..."
51
+ )
52
+
53
+ image = process_file(file)
54
+ placeholder.empty()
55
+ placeholder.image(image)
56
+
57
+ filename = file.name
58
+ format = re.findall("\..*$", filename)[0][1:]
59
+
60
+ image = Image.fromarray(image)
61
+
62
+ buf = BytesIO()
63
+ image.save(buf, format="JPEG")
64
+ byte_image = buf.getvalue()
65
+
66
+ download_button(byte_image, filename, format)