Spaces:
Running
Running
update two-stage interface
Browse files- Home.py +2 -0
- pages/Gallery.py +13 -6
- pages/Ranking.py +19 -12
Home.py
CHANGED
@@ -30,6 +30,7 @@ def save_user_id(user_id):
|
|
30 |
if not user_id:
|
31 |
user_id = 'anonymous' + str(random.randint(0, 100000))
|
32 |
st.session_state.user_id = [user_id, datetime.now().strftime("%Y-%m-%d %H:%M:%S")]
|
|
|
33 |
|
34 |
|
35 |
def logout():
|
@@ -38,6 +39,7 @@ def logout():
|
|
38 |
st.session_state.pop('score_weights', None)
|
39 |
st.session_state.pop('gallery_state', None)
|
40 |
st.session_state.pop('progress', None)
|
|
|
41 |
|
42 |
|
43 |
def info():
|
|
|
30 |
if not user_id:
|
31 |
user_id = 'anonymous' + str(random.randint(0, 100000))
|
32 |
st.session_state.user_id = [user_id, datetime.now().strftime("%Y-%m-%d %H:%M:%S")]
|
33 |
+
st.session_state.assigned_rank_mode = random.choice(['sort', 'battle'])
|
34 |
|
35 |
|
36 |
def logout():
|
|
|
39 |
st.session_state.pop('score_weights', None)
|
40 |
st.session_state.pop('gallery_state', None)
|
41 |
st.session_state.pop('progress', None)
|
42 |
+
st.session_state.pop('gallery_focus', None)
|
43 |
|
44 |
|
45 |
def info():
|
pages/Gallery.py
CHANGED
@@ -32,6 +32,9 @@ class GalleryApp:
|
|
32 |
if 'selected_dict' not in st.session_state:
|
33 |
st.session_state['selected_dict'] = {}
|
34 |
|
|
|
|
|
|
|
35 |
def gallery_standard(self, items, col_num, info):
|
36 |
rows = len(items) // col_num + 1
|
37 |
containers = [st.container() for _ in range(rows)]
|
@@ -310,7 +313,7 @@ class GalleryApp:
|
|
310 |
# st.markdown(':orange[Please select a prompt above๐]')
|
311 |
st.write('**Feel free to navigate among tags and pages! Your selection will be saved within one log-in session.**')
|
312 |
|
313 |
-
with subset_selector[1]:
|
314 |
st.write(':orange[๐ **Please select a prompt**]')
|
315 |
|
316 |
else:
|
@@ -322,6 +325,10 @@ class GalleryApp:
|
|
322 |
if prompt_id not in st.session_state.gallery_state:
|
323 |
st.session_state.gallery_state[prompt_id] = 'graph'
|
324 |
|
|
|
|
|
|
|
|
|
325 |
# add safety check for some prompts
|
326 |
safety_check = True
|
327 |
unsafe_prompts = {}
|
@@ -348,7 +355,7 @@ class GalleryApp:
|
|
348 |
# # st.warning('No selected images found')
|
349 |
# else:
|
350 |
self.graph_mode(prompt_id, items)
|
351 |
-
with subset_selector[1]:
|
352 |
# if st.session_state.gallery_state[prompt_id] == 'graph':
|
353 |
# subset = st.selectbox('Select a subset', ['All', 'Selected Only'], index=0, key=f'subset_{tag}')
|
354 |
has_selection = False
|
@@ -359,7 +366,7 @@ class GalleryApp:
|
|
359 |
pass
|
360 |
|
361 |
if has_selection:
|
362 |
-
checkout = st.button('Check out selections', use_container_width=True, type='primary')
|
363 |
if checkout:
|
364 |
print('checkout')
|
365 |
|
@@ -367,17 +374,17 @@ class GalleryApp:
|
|
367 |
print(st.session_state.gallery_state[prompt_id])
|
368 |
st.experimental_rerun()
|
369 |
else:
|
370 |
-
st.write('Select images you like below
|
371 |
|
372 |
elif st.session_state.gallery_state[prompt_id] == 'gallery':
|
373 |
items = items[items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(
|
374 |
drop=True)
|
375 |
self.gallery_mode(prompt_id, items)
|
376 |
|
377 |
-
with subset_selector[1]:
|
378 |
state_operations = st.columns([1, 1])
|
379 |
with state_operations[0]:
|
380 |
-
back = st.button('Back', use_container_width=True)
|
381 |
if back:
|
382 |
st.session_state.gallery_state[prompt_id] = 'graph'
|
383 |
st.experimental_rerun()
|
|
|
32 |
if 'selected_dict' not in st.session_state:
|
33 |
st.session_state['selected_dict'] = {}
|
34 |
|
35 |
+
if 'gallery_focus' not in st.session_state:
|
36 |
+
st.session_state.gallery_focus = {'tag': None, 'prompt': None}
|
37 |
+
|
38 |
def gallery_standard(self, items, col_num, info):
|
39 |
rows = len(items) // col_num + 1
|
40 |
containers = [st.container() for _ in range(rows)]
|
|
|
313 |
# st.markdown(':orange[Please select a prompt above๐]')
|
314 |
st.write('**Feel free to navigate among tags and pages! Your selection will be saved within one log-in session.**')
|
315 |
|
316 |
+
with subset_selector[-1]:
|
317 |
st.write(':orange[๐ **Please select a prompt**]')
|
318 |
|
319 |
else:
|
|
|
325 |
if prompt_id not in st.session_state.gallery_state:
|
326 |
st.session_state.gallery_state[prompt_id] = 'graph'
|
327 |
|
328 |
+
# add focus to session state
|
329 |
+
st.session_state.gallery_focus['tag'] = tag
|
330 |
+
st.session_state.gallery_focus['prompt'] = selected_prompt
|
331 |
+
|
332 |
# add safety check for some prompts
|
333 |
safety_check = True
|
334 |
unsafe_prompts = {}
|
|
|
355 |
# # st.warning('No selected images found')
|
356 |
# else:
|
357 |
self.graph_mode(prompt_id, items)
|
358 |
+
with subset_selector[-1]:
|
359 |
# if st.session_state.gallery_state[prompt_id] == 'graph':
|
360 |
# subset = st.selectbox('Select a subset', ['All', 'Selected Only'], index=0, key=f'subset_{tag}')
|
361 |
has_selection = False
|
|
|
366 |
pass
|
367 |
|
368 |
if has_selection:
|
369 |
+
checkout = st.button('๐ Check out selections', use_container_width=True, type='primary')
|
370 |
if checkout:
|
371 |
print('checkout')
|
372 |
|
|
|
374 |
print(st.session_state.gallery_state[prompt_id])
|
375 |
st.experimental_rerun()
|
376 |
else:
|
377 |
+
st.write(':orange[๐ **Select images you like below**]')
|
378 |
|
379 |
elif st.session_state.gallery_state[prompt_id] == 'gallery':
|
380 |
items = items[items['modelVersion_id'].isin(st.session_state.selected_dict[prompt_id])].reset_index(
|
381 |
drop=True)
|
382 |
self.gallery_mode(prompt_id, items)
|
383 |
|
384 |
+
with subset_selector[-1]:
|
385 |
state_operations = st.columns([1, 1])
|
386 |
with state_operations[0]:
|
387 |
+
back = st.button('Back to ๐ผ๏ธ', use_container_width=True)
|
388 |
if back:
|
389 |
st.session_state.gallery_state[prompt_id] = 'graph'
|
390 |
st.experimental_rerun()
|
pages/Ranking.py
CHANGED
@@ -27,13 +27,20 @@ class RankingApp:
|
|
27 |
def sidebar(self):
|
28 |
with st.sidebar:
|
29 |
prompt_tags = self.promptBook['tag'].unique()
|
30 |
-
prompt_tags = np.sort(prompt_tags)
|
31 |
|
32 |
-
|
|
|
|
|
|
|
|
|
33 |
items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
|
34 |
-
prompts = np.sort(items['prompt'].unique())[::-1]
|
|
|
|
|
|
|
35 |
|
36 |
-
selected_prompt = st.selectbox('Select a prompt', prompts)
|
37 |
|
38 |
mode = st.radio('Select a mode', ['Drag and Sort', 'Battle'], index=1)
|
39 |
|
@@ -204,24 +211,24 @@ class RankingApp:
|
|
204 |
with left:
|
205 |
image_id = items['image_id'][st.session_state.pointer[prompt_id]['left']]
|
206 |
img_url = self.images_endpoint + str(image_id) + '.png'
|
207 |
-
st.image(img_url, use_column_width=True)
|
208 |
|
209 |
-
# write the total score of this image
|
210 |
-
total_score = items['total_score'][st.session_state.pointer[prompt_id]['left']]
|
211 |
-
st.write(f'Total Score: {total_score}')
|
212 |
|
213 |
btn_left = st.button('Left is better', key='left', on_click=self.next_battle, kwargs={'prompt_id': prompt_id, 'image_ids': items['image_id'], 'winner': 'left', 'curr_position': curr_position, 'total_num': len(items)}, use_container_width=True)
|
|
|
214 |
|
215 |
with right:
|
216 |
image_id = items['image_id'][st.session_state.pointer[prompt_id]['right']]
|
217 |
img_url = self.images_endpoint + str(image_id) + '.png'
|
218 |
-
st.image(img_url, use_column_width=True)
|
219 |
|
220 |
-
# write the total score of this image
|
221 |
-
total_score = items['total_score'][st.session_state.pointer[prompt_id]['right']]
|
222 |
-
st.write(f'Total Score: {total_score}')
|
223 |
|
224 |
btn_right = st.button('Right is better', key='right', on_click=self.next_battle, kwargs={'prompt_id': prompt_id, 'image_ids': items['image_id'], 'winner': 'right', 'curr_position': curr_position, 'total_num': len(items)}, use_container_width=True)
|
|
|
225 |
|
226 |
def next_battle(self, prompt_id, image_ids, winner, curr_position, total_num):
|
227 |
loser = 'left' if winner == 'right' else 'right'
|
|
|
27 |
def sidebar(self):
|
28 |
with st.sidebar:
|
29 |
prompt_tags = self.promptBook['tag'].unique()
|
30 |
+
prompt_tags = np.sort(prompt_tags).tolist()
|
31 |
|
32 |
+
print(st.session_state.gallery_focus)
|
33 |
+
tag_idx = prompt_tags.index(st.session_state.gallery_focus['tag']) if st.session_state.gallery_focus['tag'] in prompt_tags else 0
|
34 |
+
print(tag_idx)
|
35 |
+
|
36 |
+
tag = st.selectbox('Select a prompt tag', prompt_tags, index=tag_idx)
|
37 |
items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
|
38 |
+
prompts = np.sort(items['prompt'].unique())[::-1].tolist()
|
39 |
+
|
40 |
+
prompt_idx = prompts.index(st.session_state.gallery_focus['prompt']) if st.session_state.gallery_focus['prompt'] in prompts else 0
|
41 |
+
print(prompt_idx)
|
42 |
|
43 |
+
selected_prompt = st.selectbox('Select a prompt', prompts, index=prompt_idx)
|
44 |
|
45 |
mode = st.radio('Select a mode', ['Drag and Sort', 'Battle'], index=1)
|
46 |
|
|
|
211 |
with left:
|
212 |
image_id = items['image_id'][st.session_state.pointer[prompt_id]['left']]
|
213 |
img_url = self.images_endpoint + str(image_id) + '.png'
|
|
|
214 |
|
215 |
+
# # write the total score of this image
|
216 |
+
# total_score = items['total_score'][st.session_state.pointer[prompt_id]['left']]
|
217 |
+
# st.write(f'Total Score: {total_score}')
|
218 |
|
219 |
btn_left = st.button('Left is better', key='left', on_click=self.next_battle, kwargs={'prompt_id': prompt_id, 'image_ids': items['image_id'], 'winner': 'left', 'curr_position': curr_position, 'total_num': len(items)}, use_container_width=True)
|
220 |
+
st.image(img_url, use_column_width=True)
|
221 |
|
222 |
with right:
|
223 |
image_id = items['image_id'][st.session_state.pointer[prompt_id]['right']]
|
224 |
img_url = self.images_endpoint + str(image_id) + '.png'
|
|
|
225 |
|
226 |
+
# # write the total score of this image
|
227 |
+
# total_score = items['total_score'][st.session_state.pointer[prompt_id]['right']]
|
228 |
+
# st.write(f'Total Score: {total_score}')
|
229 |
|
230 |
btn_right = st.button('Right is better', key='right', on_click=self.next_battle, kwargs={'prompt_id': prompt_id, 'image_ids': items['image_id'], 'winner': 'right', 'curr_position': curr_position, 'total_num': len(items)}, use_container_width=True)
|
231 |
+
st.image(img_url, use_column_width=True)
|
232 |
|
233 |
def next_battle(self, prompt_id, image_ids, winner, curr_position, total_num):
|
234 |
loser = 'left' if winner == 'right' else 'right'
|