Spaces:
Running
Running
gallery 2.0
Browse filesgraph view added
- pages/Gallery.py +165 -43
pages/Gallery.py
CHANGED
@@ -9,6 +9,7 @@ import streamlit as st
|
|
9 |
from bs4 import BeautifulSoup
|
10 |
from datasets import load_dataset, Dataset, load_from_disk
|
11 |
from huggingface_hub import login
|
|
|
12 |
from streamlit_extras.switch_page_button import switch_page
|
13 |
from sklearn.svm import LinearSVC
|
14 |
|
@@ -50,6 +51,55 @@ class GalleryApp:
|
|
50 |
for key in info:
|
51 |
st.write(f"**{key}**: {items.iloc[idx + j][key]}")
|
52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
def selection_panel(self, items):
|
54 |
# temperal function
|
55 |
|
@@ -170,6 +220,8 @@ class GalleryApp:
|
|
170 |
|
171 |
selected_prompt = st.selectbox('Select prompt', prompts)
|
172 |
|
|
|
|
|
173 |
items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
|
174 |
prompt_id = items['prompt_id'].unique()[0]
|
175 |
note = items['note'].unique()[0]
|
@@ -206,14 +258,14 @@ class GalleryApp:
|
|
206 |
except:
|
207 |
pass
|
208 |
|
209 |
-
return prompt_tags, tag, prompt_id, items
|
210 |
|
211 |
def app(self):
|
212 |
st.title('Model Visualization and Retrieval')
|
213 |
st.write('This is a gallery of images generated by the models')
|
214 |
|
215 |
-
prompt_tags, tag, prompt_id, items = self.sidebar()
|
216 |
-
items, info, col_num = self.selection_panel(items)
|
217 |
|
218 |
# add safety check for some prompts
|
219 |
safety_check = True
|
@@ -230,57 +282,115 @@ class GalleryApp:
|
|
230 |
safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'safety_{prompt_id}')
|
231 |
|
232 |
if safety_check:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
|
234 |
-
if 'selected_dict' in st.session_state:
|
235 |
-
# st.write('checked: ', str(st.session_state.selected_dict.get(prompt_id, [])))
|
236 |
-
dynamic_weight_options = ['Grid Search', 'SVM', 'Greedy']
|
237 |
-
dynamic_weight_panel = st.columns(len(dynamic_weight_options))
|
238 |
-
|
239 |
-
if len(st.session_state.selected_dict.get(prompt_id, [])) > 0:
|
240 |
-
btn_disable = False
|
241 |
else:
|
242 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
243 |
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
248 |
|
249 |
-
|
250 |
-
|
251 |
-
|
|
|
252 |
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
gallery_space = st.empty()
|
257 |
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
|
273 |
-
|
274 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
275 |
|
276 |
-
with gallery_space.container():
|
277 |
-
with st.spinner('Loading images...'):
|
278 |
-
self.gallery_standard(items, col_num, info)
|
279 |
|
280 |
-
st.info("Don't forget to scroll back to top and click the 'Confirm Selection' button to save your selection!!!")
|
281 |
-
# prompt = st.chat_input(f"checked: {str(st.session_state.selected_dict.get(prompt_id, []))}", disabled=False, key=f'{prompt_id}')
|
282 |
-
# if prompt:
|
283 |
-
# switch_page("ranking")
|
284 |
|
285 |
def submit_actions(self, status, prompt_id):
|
286 |
# remove counter from session state
|
@@ -429,6 +539,18 @@ def load_hf_dataset():
|
|
429 |
|
430 |
return roster, promptBook, images_ds
|
431 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
432 |
|
433 |
if __name__ == "__main__":
|
434 |
st.set_page_config(page_title="Model Coffer Gallery", page_icon="🖼️", layout="wide")
|
|
|
9 |
from bs4 import BeautifulSoup
|
10 |
from datasets import load_dataset, Dataset, load_from_disk
|
11 |
from huggingface_hub import login
|
12 |
+
from streamlit_agraph import agraph, Node, Edge, Config
|
13 |
from streamlit_extras.switch_page_button import switch_page
|
14 |
from sklearn.svm import LinearSVC
|
15 |
|
|
|
51 |
for key in info:
|
52 |
st.write(f"**{key}**: {items.iloc[idx + j][key]}")
|
53 |
|
54 |
+
def gallery_graph(self, items):
|
55 |
+
items = load_tsne_coordinates(items)
|
56 |
+
|
57 |
+
scale = 50
|
58 |
+
items.loc[:, 'x'] = items['x'] * scale
|
59 |
+
items.loc[:, 'y'] = items['y'] * scale
|
60 |
+
|
61 |
+
nodes = []
|
62 |
+
edges = []
|
63 |
+
|
64 |
+
for idx in items.index:
|
65 |
+
# if items.loc[idx, 'modelVersion_id'] in st.session_state.selected_dict.get(items.loc[idx, 'prompt_id'], 0):
|
66 |
+
# opacity = 0.2
|
67 |
+
# else:
|
68 |
+
# opacity = 1.0
|
69 |
+
|
70 |
+
nodes.append(Node(id=items.loc[idx, 'image_id'],
|
71 |
+
# label=str(items.loc[idx, 'model_name']),
|
72 |
+
size=20,
|
73 |
+
shape='image',
|
74 |
+
image=f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{items.loc[idx, 'image_id']}.png",
|
75 |
+
x=items.loc[idx, 'x'].item(),
|
76 |
+
y=items.loc[idx, 'y'].item(),
|
77 |
+
fixed=True,
|
78 |
+
color={'background': '#00000', 'border': '#ffffff'},
|
79 |
+
# opacity=opacity,
|
80 |
+
shadow={'enabled': True, 'color': 'rgba(0,0,0,0.4)', 'size': 10, 'x': 1, 'y': 1},
|
81 |
+
# borderWidth=1,
|
82 |
+
# shapeProperties={'useBorderWithImage': True},
|
83 |
+
)
|
84 |
+
)
|
85 |
+
|
86 |
+
config = Config(width='100%',
|
87 |
+
height=800,
|
88 |
+
directed=True,
|
89 |
+
physics=False,
|
90 |
+
hierarchical=False,
|
91 |
+
# **kwargs
|
92 |
+
)
|
93 |
+
|
94 |
+
return agraph(nodes=nodes,
|
95 |
+
edges=edges,
|
96 |
+
config=config
|
97 |
+
)
|
98 |
+
|
99 |
+
|
100 |
+
|
101 |
+
|
102 |
+
|
103 |
def selection_panel(self, items):
|
104 |
# temperal function
|
105 |
|
|
|
220 |
|
221 |
selected_prompt = st.selectbox('Select prompt', prompts)
|
222 |
|
223 |
+
mode = st.radio('Select a mode', ['Gallery', 'Graph'], horizontal=True, index=1)
|
224 |
+
|
225 |
items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
|
226 |
prompt_id = items['prompt_id'].unique()[0]
|
227 |
note = items['note'].unique()[0]
|
|
|
258 |
except:
|
259 |
pass
|
260 |
|
261 |
+
return prompt_tags, tag, prompt_id, items, mode
|
262 |
|
263 |
def app(self):
|
264 |
st.title('Model Visualization and Retrieval')
|
265 |
st.write('This is a gallery of images generated by the models')
|
266 |
|
267 |
+
prompt_tags, tag, prompt_id, items, mode = self.sidebar()
|
268 |
+
# items, info, col_num = self.selection_panel(items)
|
269 |
|
270 |
# add safety check for some prompts
|
271 |
safety_check = True
|
|
|
282 |
safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'safety_{prompt_id}')
|
283 |
|
284 |
if safety_check:
|
285 |
+
if mode == 'Gallery':
|
286 |
+
self.gallery_mode(prompt_id, items)
|
287 |
+
elif mode == 'Graph':
|
288 |
+
self.graph_mode(prompt_id, items)
|
289 |
+
|
290 |
+
|
291 |
+
def graph_mode(self, prompt_id, items):
|
292 |
+
graph_cols = st.columns([3, 1])
|
293 |
+
prompt = st.chat_input(f"Selected model version ids: {str(st.session_state.selected_dict.get(prompt_id, []))}",
|
294 |
+
disabled=False, key=f'{prompt_id}')
|
295 |
+
if prompt:
|
296 |
+
switch_page("ranking")
|
297 |
+
|
298 |
+
with graph_cols[0]:
|
299 |
+
return_value = self.gallery_graph(items)
|
300 |
+
with graph_cols[1]:
|
301 |
+
if return_value:
|
302 |
+
image_url = f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{return_value}.png"
|
303 |
+
|
304 |
+
st.image(image_url)
|
305 |
+
|
306 |
+
item = items[items['image_id'] == return_value].reset_index(drop=True).iloc[0]
|
307 |
+
modelVersion_id = item['modelVersion_id']
|
308 |
+
|
309 |
+
# handle selection
|
310 |
+
if 'selected_dict' in st.session_state:
|
311 |
+
if item['prompt_id'] not in st.session_state.selected_dict:
|
312 |
+
st.session_state.selected_dict[item['prompt_id']] = []
|
313 |
+
|
314 |
+
if modelVersion_id in st.session_state.selected_dict[item['prompt_id']]:
|
315 |
+
checked = True
|
316 |
+
else:
|
317 |
+
checked = False
|
318 |
+
|
319 |
+
if checked:
|
320 |
+
deselect = st.button('Deselect', key=f'select_{item["prompt_id"]}_{item["modelVersion_id"]}', use_container_width=True)
|
321 |
+
if deselect:
|
322 |
+
st.session_state.selected_dict[item['prompt_id']].remove(item['modelVersion_id'])
|
323 |
+
st.experimental_rerun()
|
324 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
325 |
else:
|
326 |
+
select = st.button('Select', key=f'select_{item["prompt_id"]}_{item["modelVersion_id"]}', use_container_width=True, type='primary')
|
327 |
+
if select:
|
328 |
+
st.session_state.selected_dict[item['prompt_id']].append(item['modelVersion_id'])
|
329 |
+
st.experimental_rerun()
|
330 |
+
|
331 |
+
# st.write(item)
|
332 |
+
infos = ['model_name', 'modelVersion_name', 'model_download_count', 'clip_score', 'mcos_score',
|
333 |
+
'nsfw_score']
|
334 |
+
for info in infos:
|
335 |
+
st.write(f"**{info}**:")
|
336 |
+
st.write(item[info])
|
337 |
|
338 |
+
else:
|
339 |
+
st.info('Please click on an image to show')
|
340 |
+
|
341 |
+
|
342 |
+
def gallery_mode(self, prompt_id, items):
|
343 |
+
items, info, col_num = self.selection_panel(items)
|
344 |
+
|
345 |
+
if 'selected_dict' in st.session_state:
|
346 |
+
# st.write('checked: ', str(st.session_state.selected_dict.get(prompt_id, [])))
|
347 |
+
dynamic_weight_options = ['Grid Search', 'SVM', 'Greedy']
|
348 |
+
dynamic_weight_panel = st.columns(len(dynamic_weight_options))
|
349 |
+
|
350 |
+
if len(st.session_state.selected_dict.get(prompt_id, [])) > 0:
|
351 |
+
btn_disable = False
|
352 |
+
else:
|
353 |
+
btn_disable = True
|
354 |
|
355 |
+
for i in range(len(dynamic_weight_options)):
|
356 |
+
method = dynamic_weight_options[i]
|
357 |
+
with dynamic_weight_panel[i]:
|
358 |
+
btn = st.button(method, use_container_width=True, disabled=btn_disable, on_click=self.dynamic_weight, args=(prompt_id, items, method))
|
359 |
|
360 |
+
prompt = st.chat_input(f"Selected model version ids: {str(st.session_state.selected_dict.get(prompt_id, []))}", disabled=False, key=f'{prompt_id}')
|
361 |
+
if prompt:
|
362 |
+
switch_page("ranking")
|
|
|
363 |
|
364 |
+
with st.form(key=f'{prompt_id}'):
|
365 |
+
# buttons = st.columns([1, 1, 1])
|
366 |
+
buttons_space = st.columns([1, 1, 1, 1])
|
367 |
+
gallery_space = st.empty()
|
368 |
|
369 |
+
with buttons_space[0]:
|
370 |
+
continue_btn = st.form_submit_button('Confirm Selection', use_container_width=True, type='primary')
|
371 |
+
if continue_btn:
|
372 |
+
self.submit_actions('Continue', prompt_id)
|
373 |
|
374 |
+
with buttons_space[1]:
|
375 |
+
select_btn = st.form_submit_button('Select All', use_container_width=True)
|
376 |
+
if select_btn:
|
377 |
+
self.submit_actions('Select', prompt_id)
|
378 |
|
379 |
+
with buttons_space[2]:
|
380 |
+
deselect_btn = st.form_submit_button('Deselect All', use_container_width=True)
|
381 |
+
if deselect_btn:
|
382 |
+
self.submit_actions('Deselect', prompt_id)
|
383 |
+
|
384 |
+
with buttons_space[3]:
|
385 |
+
refresh_btn = st.form_submit_button('Refresh', on_click=gallery_space.empty, use_container_width=True)
|
386 |
+
|
387 |
+
with gallery_space.container():
|
388 |
+
with st.spinner('Loading images...'):
|
389 |
+
self.gallery_standard(items, col_num, info)
|
390 |
+
|
391 |
+
st.info("Don't forget to scroll back to top and click the 'Confirm Selection' button to save your selection!!!")
|
392 |
|
|
|
|
|
|
|
393 |
|
|
|
|
|
|
|
|
|
394 |
|
395 |
def submit_actions(self, status, prompt_id):
|
396 |
# remove counter from session state
|
|
|
539 |
|
540 |
return roster, promptBook, images_ds
|
541 |
|
542 |
+
@st.cache_data
|
543 |
+
def load_tsne_coordinates(items):
|
544 |
+
# load tsne coordinates
|
545 |
+
tsne_df = pd.read_parquet('./data/feats_tsne.parquet')
|
546 |
+
|
547 |
+
# print(tsne_df['modelVersion_id'].dtype)
|
548 |
+
|
549 |
+
print('before merge:', items)
|
550 |
+
items = items.merge(tsne_df, on=['modelVersion_id', 'prompt_id'], how='left')
|
551 |
+
print('after merge:', items)
|
552 |
+
return items
|
553 |
+
|
554 |
|
555 |
if __name__ == "__main__":
|
556 |
st.set_page_config(page_title="Model Coffer Gallery", page_icon="🖼️", layout="wide")
|