taskswithcode commited on
Commit
72935b7
1 Parent(s): 35af04c
app.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import sys
3
+ import streamlit as st
4
+ import string
5
+ import os
6
+ from io import StringIO
7
+ import pdb
8
+ import json
9
+ import torch
10
+ import requests
11
+ import socket
12
+ from streamlit_image_select import image_select
13
+
14
+
15
+
16
+
17
+
18
+ use_case = {"1":"Image background removal","2":"Masking foreground for downstream inpainting task"}
19
+ mask_types = {"blur - blurs background":"blur","map - makes the foreground white and rest black ":"map","rgba - makes background white":"rgba","green - makes the background green":"green"}
20
+
21
+
22
+
23
+
24
+ APP_NAME = "hf/salient_object_detection"
25
+ INFO_URL = "https://www.taskswithcode.com/stats/"
26
+ TMP_DIR="tmp_dir"
27
+ TMP_SEED = 1
28
+
29
+
30
+
31
+
32
+
33
+ def get_views(action):
34
+ ret_val = 0
35
+ #return "{:,}".format(ret_val)
36
+ hostname = socket.gethostname()
37
+ ip_address = socket.gethostbyname(hostname)
38
+ if ("view_count" not in st.session_state):
39
+ try:
40
+ app_info = {'name': APP_NAME,"action":action,"host":hostname,"ip":ip_address}
41
+ res = requests.post(INFO_URL, json = app_info).json()
42
+ print(res)
43
+ data = res["count"]
44
+ except:
45
+ data = 0
46
+ ret_val = data
47
+ st.session_state["view_count"] = data
48
+ else:
49
+ ret_val = st.session_state["view_count"]
50
+ if (action != "init"):
51
+ app_info = {'name': APP_NAME,"action":action,"host":hostname,"ip":ip_address}
52
+ res = requests.post(INFO_URL, json = app_info).json()
53
+ return "{:,}".format(ret_val)
54
+
55
+
56
+
57
+
58
+ def construct_model_info_for_display(model_names):
59
+ options_arr = []
60
+ #markdown_str = f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\"><br/><b>Models evaluated ({len(model_names)})</b><br/></div>"
61
+ markdown_str = f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\"><br/><b>Model evaluated </b><br/></div>"
62
+ markdown_str += f"<div style=\"font-size:2px; color: #2f2f2f; text-align: left\"><br/></div>"
63
+ for node in model_names:
64
+ options_arr .append(node["name"])
65
+ if (node["mark"] == "True"):
66
+ markdown_str += f"<div style=\"font-size:16px; color: #5f5f5f; text-align: left\">&nbsp;•&nbsp;Model:&nbsp;<a href=\'{node['paper_url']}\' target='_blank'>{node['name']}</a><br/>&nbsp;&nbsp;&nbsp;&nbsp;Code released by:&nbsp;<a href=\'{node['orig_author_url']}\' target='_blank'>{node['orig_author']}</a><br/>&nbsp;&nbsp;&nbsp;&nbsp;Model info:&nbsp;<a href=\'{node['sota_info']['sota_link']}\' target='_blank'>{node['sota_info']['task']}</a></div>"
67
+ if ("Note" in node):
68
+ markdown_str += f"<div style=\"font-size:16px; color: #a91212; text-align: left\">&nbsp;&nbsp;&nbsp;&nbsp;{node['Note']}<a href=\'{node['alt_url']}\' target='_blank'>link</a></div>"
69
+ markdown_str += "<div style=\"font-size:16px; color: #5f5f5f; text-align: left\"><br/></div>"
70
+
71
+ markdown_str += "<div style=\"font-size:12px; color: #9f9f9f; text-align: left\"><b>Note:</b><br/>•&nbsp;Uploaded files are loaded into non-persistent memory for the duration of the computation. They are not cached</div>"
72
+ markdown_str += "<div style=\"font-size:12px; color: #9f9f9f; text-align: left\"><br/><a href=\'https://github.com/taskswithcode/salient_object_detection_app.git\' target='_blank'>Github code</a> for this app</div>"
73
+ return options_arr,markdown_str
74
+
75
+
76
+ def init_page():
77
+ st.set_page_config(page_title='TWC - State-of-the-art model salient object detection (visually dominant objects in an image)', page_icon="logo.jpg", layout='centered', initial_sidebar_state='auto',
78
+ menu_items={
79
+ 'About': 'This app was created by taskswithcode. http://taskswithcode.com'
80
+
81
+ })
82
+ col,pad = st.columns([85,15])
83
+
84
+ with col:
85
+ st.image("long_form_logo_with_icon.png")
86
+
87
+
88
+ def run_test(config,input_file_name,display_area,uploaded_file,mask_type):
89
+ global TMP_SEED
90
+ display_area.text("Processing request...")
91
+ try:
92
+ if (uploaded_file is None):
93
+ file_data = open(input_file_name, "rb")
94
+ r = requests.post(config["SERVER_ADDRESS"], data={"mask":mask_type}, files={"test":file_data})
95
+ else:
96
+ file_data = uploaded_file.read()
97
+ file_name = f"{TMP_DIR}/{TMP_SEED}_{str(time.time()).replace('.','_')}_{uploaded_file.name}"
98
+ TMP_SEED += 1
99
+ with open(file_name,"wb") as fp:
100
+ fp.write(file_data)
101
+ file_data = open(file_name, "rb")
102
+ r = requests.post(config["SERVER_ADDRESS"], data={"mask":mask_type}, files={"test":file_data})
103
+ os.remove(file_name)
104
+ print("Servers response:",r.status_code,len(r.content))
105
+ if (r.status_code == 200):
106
+ size = "{:,}".format(len(r.content))
107
+ return {"response":r.content,"size":size}
108
+ else:
109
+ return {"error":f"API request failed {r.status_code}"}
110
+ except Exception as e:
111
+ st.error("Some error occurred during prediction" + str(e))
112
+ st.stop()
113
+ return {"error":f"Exception in performing salient object detection: {str(e)}"}
114
+ return {}
115
+
116
+
117
+
118
+
119
+ def display_results(results,response_info,mask):
120
+ main_sent = f"<div style=\"font-size:14px; color: #2f2f2f; text-align: left\">{response_info}<br/><br/></div>"
121
+ body_sent = []
122
+ download_data = {}
123
+ main_sent = main_sent + "\n" + '\n'.join(body_sent)
124
+ st.markdown(main_sent,unsafe_allow_html=True)
125
+ st.image(results["response"], caption=f'Output of salient object detection with mask: {mask}')
126
+ st.session_state["download_ready"] = results["response"]
127
+ get_views("submit")
128
+
129
+
130
+ def init_session():
131
+ print("Init session")
132
+ init_page()
133
+ st.session_state["model_name"] = "insprynet"
134
+ st.session_state["download_ready"] = None
135
+ st.session_state["model_name"] = "ss_test"
136
+ st.session_state["file_name"] = "default"
137
+ st.session_state["mask_type"] = "blur"
138
+
139
+ def app_main(app_mode,example_files,model_name_files,config_file):
140
+ init_session()
141
+ with open(example_files) as fp:
142
+ example_file_names = json.load(fp)
143
+ with open(model_name_files) as fp:
144
+ model_names = json.load(fp)
145
+ with open(config_file) as fp:
146
+ config = json.load(fp)
147
+ curr_use_case = use_case[app_mode].split(".")[0]
148
+ curr_use_case = use_case[app_mode].split(".")[0]
149
+ st.markdown("<h5 style='text-align: center;'>State-of-the-art model for salient object detection</h5>", unsafe_allow_html=True)
150
+ st.markdown(f"<div style='color: #4f4f4f; text-align: left'>Use cases for salient object detection<br/>&nbsp;&nbsp;&nbsp;•&nbsp;&nbsp;{use_case['1']}<br/>&nbsp;&nbsp;&nbsp;•&nbsp;&nbsp;{use_case['2']}</div>", unsafe_allow_html=True)
151
+ st.markdown(f"<div style='color: #9f9f9f; text-align: right'>views:&nbsp;{get_views('init')}</div>", unsafe_allow_html=True)
152
+
153
+
154
+ try:
155
+
156
+
157
+ with st.form('twc_form'):
158
+
159
+ step1_line = "Upload an image or choose an example image below"
160
+ uploaded_file = st.file_uploader(step1_line, type=["png","jpg","jpeg"])
161
+
162
+ selected_file_name = image_select("Select image", ["twc_samples/sample1.jpg", "twc_samples/sample2.jpg", "twc_samples/sample3.jpg", "twc_samples/sample4.jpg"])
163
+
164
+
165
+ st.write("")
166
+ mask_type = st.selectbox(label=f'Select type of masking',
167
+ options = list(dict.keys(mask_types)), index=0, key = "twc_mask_types")
168
+ mask_type = mask_types[mask_type]
169
+ st.write("")
170
+ submit_button = st.form_submit_button('Run')
171
+ options_arr,markdown_str = construct_model_info_for_display(model_names)
172
+
173
+
174
+ input_status_area = st.empty()
175
+ display_area = st.empty()
176
+ if submit_button:
177
+ start = time.time()
178
+ if uploaded_file is not None:
179
+ st.session_state["file_name"] = uploaded_file.name
180
+ else:
181
+ st.session_state["file_name"] = selected_file_name
182
+ st.session_state["mask_type"] = mask_type
183
+ display_area.empty()
184
+ results = run_test(config,st.session_state["file_name"],display_area,uploaded_file,mask_type)
185
+ with display_area.container():
186
+ if ("error" in results):
187
+ st.error(results["error"])
188
+ else:
189
+ device = 'GPU' if torch.cuda.is_available() else 'CPU'
190
+ response_info = f"Computation time on {device}: {time.time() - start:.2f} secs for image size: {results['size']} bytes"
191
+ display_results(results,response_info,mask_type)
192
+ #st.json(results)
193
+ st.download_button(
194
+ label="Download results as png",
195
+ data= st.session_state["download_ready"] if st.session_state["download_ready"] != None else "",
196
+ disabled = False if st.session_state["download_ready"] != None else True,
197
+ file_name= (st.session_state["model_name"] + "_" + st.session_state["mask_type"] + "_" + '_'.join(st.session_state["file_name"].split(".")[:-1]) + ".png").replace("/","_"),
198
+ mime='image/png',
199
+ key ="download"
200
+ )
201
+
202
+
203
+
204
+ except Exception as e:
205
+ st.error("Some error occurred during loading" + str(e))
206
+ st.stop()
207
+
208
+ st.markdown(markdown_str, unsafe_allow_html=True)
209
+
210
+
211
+
212
+ if __name__ == "__main__":
213
+ app_main("1","sod_app_examples.json","sod_app_models.json","config.json")
214
+
config.json ADDED
@@ -0,0 +1 @@
 
1
+ {"SERVER_ADDRESS":"https://www.taskswithcode.com/salient_object_detection_api/"}
long_form_logo_with_icon.png ADDED
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
1
+ streamlit
2
+ streamlit-image-select
run.sh ADDED
@@ -0,0 +1,2 @@
 
 
1
+ streamlit run app.py --server.port 80 "1" "sod_app_examples.json" "sod_app_models.json"
2
+
sod_app_examples.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
1
+ {
2
+ "Machine learning terms (phrases test)": {"name":"tests/small_test.txt"},
3
+ "Customer feedback mixed with noise":{"name":"tests/larger_test.txt"},
4
+ "Movie reviews": {"name":"tests/imdb_sent.txt"}
5
+ }
sod_app_models.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+
3
+ { "name":"InSPyReNet" ,
4
+ "model":"InSPyReNet",
5
+ "fork_url":"https://github.com/taskswithcode/InSPyReNet.git",
6
+ "orig_author_url":"https://github.com/plemeri/inspyrenet",
7
+ "orig_author":"Taehun Kim",
8
+ "sota_info": {
9
+ "task":"This model was #1 SOTA on 8+ datasets",
10
+ "sota_link":"https://paperswithcode.com/paper/revisiting-image-pyramid-structure-for-high"
11
+ },
12
+ "paper_url":"https://arxiv.org/abs/2209.09475v1",
13
+ "mark":"True",
14
+ "Note":"This model can perform SOD on video too. For SOD on videos use the Colab link ",
15
+ "alt_url":"https://github.com/taskswithcode/InSPyReNet/blob/main/TWCSOD.ipynb",
16
+ "class":"SODModel","sota_link":"https://arxiv.org/abs/2209.09475v1"}
17
+
18
+
19
+ ]
twc_samples/sample1.jpg ADDED
twc_samples/sample2.jpg ADDED
twc_samples/sample3.jpg ADDED
twc_samples/sample4.jpg ADDED