Spaces:
Running
Running
beta testing for new gallery page
Browse files- Home.py +11 -3
- pages/Gallery.py +44 -32
Home.py
CHANGED
@@ -42,12 +42,20 @@ def info():
|
|
42 |
with st.sidebar:
|
43 |
st.write('## About')
|
44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
st.write(
|
46 |
-
"This is
|
47 |
-
**You might consider it as a tool for quickly digging out the most suitable text-to-image generation model for you from [civitai](https://civitai.com/).**"
|
48 |
)
|
|
|
49 |
st.write(
|
50 |
-
"After
|
51 |
)
|
52 |
|
53 |
|
|
|
42 |
with st.sidebar:
|
43 |
st.write('## About')
|
44 |
|
45 |
+
# st.write(
|
46 |
+
# "This is an web application to collect personal preference to images synthesised by generative models fine-tuned on stable diffusion. \
|
47 |
+
# **You might consider it as a tool for quickly digging out the most suitable text-to-image generation model for you from [civitai](https://civitai.com/).**"
|
48 |
+
# )
|
49 |
+
# st.write(
|
50 |
+
# "After you picking images from gallery page, and ranking them in the ranking page, you will be able to see a dashboard showing your preferred models in the summary page, **with download links of the models ready to use in [Automatic1111 webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)!**"
|
51 |
+
# )
|
52 |
+
|
53 |
st.write(
|
54 |
+
"This is a web application for individual users to quickly dig out the most suitable text-to-image generation model from civitai. Our research aims to understand personal preference to images synthesized by generative models fine-tuned on stable diffusion and you can contribute by playing with this tool and giving us your feedback! "
|
|
|
55 |
)
|
56 |
+
|
57 |
st.write(
|
58 |
+
"After picking images you liked from Gallery and a battle-mode Ranking Contest, a summary dashboard will be presented indicating your preferred models with download links ready to be deployed in Webui !"
|
59 |
)
|
60 |
|
61 |
|
pages/Gallery.py
CHANGED
@@ -12,6 +12,7 @@ from datasets import load_dataset, Dataset, load_from_disk
|
|
12 |
from huggingface_hub import login
|
13 |
from streamlit_agraph import agraph, Node, Edge, Config
|
14 |
from streamlit_extras.switch_page_button import switch_page
|
|
|
15 |
from sklearn.svm import LinearSVC
|
16 |
|
17 |
SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'msq_score', 'pop': 'model_download_count'}
|
@@ -226,6 +227,7 @@ class GalleryApp:
|
|
226 |
|
227 |
# items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
|
228 |
|
|
|
229 |
|
230 |
# show source
|
231 |
if isinstance(note, str):
|
@@ -282,36 +284,46 @@ class GalleryApp:
|
|
282 |
|
283 |
subset_selector = st.columns([3, 1])
|
284 |
with subset_selector[0]:
|
285 |
-
selected_prompt = st.selectbox('Select prompt', prompts, index=3)
|
|
|
286 |
with subset_selector[1]:
|
287 |
-
subset = st.selectbox('Select a subset', ['All', 'Selected Only'], index=0, key=f'subset_{
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
315 |
|
316 |
def graph_mode(self, prompt_id, items):
|
317 |
graph_cols = st.columns([3, 1])
|
@@ -397,9 +409,9 @@ class GalleryApp:
|
|
397 |
# with dynamic_weight_panel[i]:
|
398 |
# btn = st.button(method, use_container_width=True, disabled=btn_disable, on_click=self.dynamic_weight, args=(prompt_id, items, method))
|
399 |
|
400 |
-
prompt = st.chat_input(f"Selected model version ids: {str(st.session_state.selected_dict.get(prompt_id, []))}", disabled=False, key=f'{prompt_id}')
|
401 |
-
if prompt:
|
402 |
-
|
403 |
|
404 |
with st.form(key=f'{prompt_id}'):
|
405 |
# buttons = st.columns([1, 1, 1])
|
|
|
12 |
from huggingface_hub import login
|
13 |
from streamlit_agraph import agraph, Node, Edge, Config
|
14 |
from streamlit_extras.switch_page_button import switch_page
|
15 |
+
from streamlit_extras.no_default_selectbox import selectbox
|
16 |
from sklearn.svm import LinearSVC
|
17 |
|
18 |
SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'msq_score', 'pop': 'model_download_count'}
|
|
|
227 |
|
228 |
# items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
|
229 |
|
230 |
+
st.title('Model Visualization and Retrieval')
|
231 |
|
232 |
# show source
|
233 |
if isinstance(note, str):
|
|
|
284 |
|
285 |
subset_selector = st.columns([3, 1])
|
286 |
with subset_selector[0]:
|
287 |
+
# selected_prompt = st.selectbox('Select prompt', prompts, index=3)
|
288 |
+
selected_prompt = selectbox('Select prompt', prompts, key=f'prompt_{tag}', no_selection_label='---')
|
289 |
with subset_selector[1]:
|
290 |
+
subset = st.selectbox('Select a subset', ['All', 'Selected Only'], index=0, key=f'subset_{tag}')
|
291 |
+
|
292 |
+
if selected_prompt is None:
|
293 |
+
st.markdown(':orange[Please select a prompt above👆]')
|
294 |
+
else:
|
295 |
+
items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
|
296 |
+
prompt_id = items['prompt_id'].unique()[0]
|
297 |
+
note = items['note'].unique()[0]
|
298 |
+
|
299 |
+
# add safety check for some prompts
|
300 |
+
safety_check = True
|
301 |
+
unsafe_prompts = {}
|
302 |
+
# initialize unsafe prompts
|
303 |
+
for prompt_tag in prompt_tags:
|
304 |
+
unsafe_prompts[prompt_tag] = []
|
305 |
+
# manually add unsafe prompts
|
306 |
+
unsafe_prompts['world knowledge'] = [83]
|
307 |
+
unsafe_prompts['abstract'] = [1, 3]
|
308 |
+
|
309 |
+
if int(prompt_id.item()) in unsafe_prompts[tag]:
|
310 |
+
st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
|
311 |
+
safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'safety_{prompt_id}')
|
312 |
+
|
313 |
+
if safety_check:
|
314 |
+
|
315 |
+
if subset == 'Selected Only' and 'selected_dict' in st.session_state:
|
316 |
+
# try:
|
317 |
+
items = items[items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(drop=True)
|
318 |
+
self.gallery_mode(prompt_id, items)
|
319 |
+
# except:
|
320 |
+
# st.warning('No selected images found')
|
321 |
+
else:
|
322 |
+
self.graph_mode(prompt_id, items)
|
323 |
+
try:
|
324 |
+
self.sidebar(items, prompt_id, note)
|
325 |
+
except:
|
326 |
+
pass
|
327 |
|
328 |
def graph_mode(self, prompt_id, items):
|
329 |
graph_cols = st.columns([3, 1])
|
|
|
409 |
# with dynamic_weight_panel[i]:
|
410 |
# btn = st.button(method, use_container_width=True, disabled=btn_disable, on_click=self.dynamic_weight, args=(prompt_id, items, method))
|
411 |
|
412 |
+
# prompt = st.chat_input(f"Selected model version ids: {str(st.session_state.selected_dict.get(prompt_id, []))}", disabled=False, key=f'{prompt_id}')
|
413 |
+
# if prompt:
|
414 |
+
# switch_page("ranking")
|
415 |
|
416 |
with st.form(key=f'{prompt_id}'):
|
417 |
# buttons = st.columns([1, 1, 1])
|