Ricercar commited on
Commit
4eb5601
β€’
1 Parent(s): 57256d1

allow for manual reorder in summary page

Browse files
Files changed (4) hide show
  1. Home.py +1 -13
  2. pages/Gallery.py +1 -1
  3. pages/Ranking.py +6 -0
  4. pages/Summary.py +44 -61
Home.py CHANGED
@@ -48,18 +48,12 @@ def logout():
48
  st.session_state.pop('gallery_focus', None)
49
  st.session_state.pop('assigned_rank_mode', None)
50
  st.session_state.pop('show_NSFW', None)
 
51
 
52
 
53
  def info():
54
  with st.sidebar:
55
  st.write('## About')
56
- # st.write(
57
- # "This is an web application to collect personal preference to images synthesised by generative models fine-tuned on stable diffusion. \
58
- # **You might consider it as a tool for quickly digging out the most suitable text-to-image generation model for you from [civitai](https://civitai.com/).**"
59
- # )
60
- # st.write(
61
- # "After you picking images from gallery page, and ranking them in the ranking page, you will be able to see a dashboard showing your preferred models in the summary page, **with download links of the models ready to use in [Automatic1111 webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)!**"
62
- # )
63
  st.write(
64
  "**This is a web application for individual users to quickly dig out the most suitable text-to-image generation model from [civitai](https://civitai.com).** Our research aims to understand personal preference to images synthesized by generative models fine-tuned on stable diffusion and you can contribute by playing with this tool and giving us your feedback! "
65
  )
@@ -75,12 +69,6 @@ if __name__ == '__main__':
75
  info()
76
  st.write('A Research by [MAPS Lab](https://whongyi.github.io/MAPS-research), [NYU Shanghai](https://shanghai.nyu.edu)')
77
  st.title("πŸ™Œ Welcome to GEMRec Gallery!")
78
- # st.info("Getting obsessed with tons of different text-to-image generation models available online? \n \
79
- # Want to find the most suitable one for your taste? \n \
80
- # **GEMRec** is here to help you!"
81
- # )
82
- # st.write('### Getting obsessed with tons of different text-to-image generation models available online? Want to find the most suitable one for your taste?')
83
- # st.write('**GEMRec** is here to help you! Try it out now πŸ‘‡!')
84
 
85
  if 'user_id' not in st.session_state:
86
  login()
 
48
  st.session_state.pop('gallery_focus', None)
49
  st.session_state.pop('assigned_rank_mode', None)
50
  st.session_state.pop('show_NSFW', None)
51
+ st.session_state.pop('modelVersion_standings', None)
52
 
53
 
54
  def info():
55
  with st.sidebar:
56
  st.write('## About')
 
 
 
 
 
 
 
57
  st.write(
58
  "**This is a web application for individual users to quickly dig out the most suitable text-to-image generation model from [civitai](https://civitai.com).** Our research aims to understand personal preference to images synthesized by generative models fine-tuned on stable diffusion and you can contribute by playing with this tool and giving us your feedback! "
59
  )
 
69
  info()
70
  st.write('A Research by [MAPS Lab](https://whongyi.github.io/MAPS-research), [NYU Shanghai](https://shanghai.nyu.edu)')
71
  st.title("πŸ™Œ Welcome to GEMRec Gallery!")
 
 
 
 
 
 
72
 
73
  if 'user_id' not in st.session_state:
74
  login()
pages/Gallery.py CHANGED
@@ -383,7 +383,7 @@ class GalleryApp:
383
  default_expand = True
384
  else:
385
  default_expand = False
386
-
387
  with st.expander(f'**{prompt}**', expanded=default_expand):
388
  # st.caption('select info to show')
389
  checkout_panel = st.columns([5, 3])
 
383
  default_expand = True
384
  else:
385
  default_expand = False
386
+
387
  with st.expander(f'**{prompt}**', expanded=default_expand):
388
  # st.caption('select info to show')
389
  checkout_panel = st.columns([5, 3])
pages/Ranking.py CHANGED
@@ -190,6 +190,8 @@ class RankingApp:
190
 
191
  if progress == 'finished':
192
  st.session_state.progress[prompt_id] = 'finished'
 
 
193
  else:
194
  st.session_state.counter[prompt_id] += 1
195
 
@@ -256,6 +258,10 @@ class RankingApp:
256
 
257
  if curr_position == total_num - 1:
258
  st.session_state.progress[prompt_id] = 'finished'
 
 
 
 
259
  # st.experimental_rerun()
260
  else:
261
  st.session_state.pointer[prompt_id][loser] = curr_position + 1
 
190
 
191
  if progress == 'finished':
192
  st.session_state.progress[prompt_id] = 'finished'
193
+ # drop 'modelVersion_standings' from session state if exists
194
+ st.session_state.pop('modelVersion_standings', None)
195
  else:
196
  st.session_state.counter[prompt_id] += 1
197
 
 
258
 
259
  if curr_position == total_num - 1:
260
  st.session_state.progress[prompt_id] = 'finished'
261
+
262
+ # drop 'modelVersion_standings' from session state if exists
263
+ st.session_state.pop('modelVersion_standings', None)
264
+
265
  # st.experimental_rerun()
266
  else:
267
  st.session_state.pointer[prompt_id][loser] = curr_position + 1
pages/Summary.py CHANGED
@@ -23,6 +23,10 @@ class DashboardApp:
23
  self.promptBook = promptBook
24
  self.session_finished = session_finished
25
 
 
 
 
 
26
  def sidebar(self, tags, mode):
27
  with st.sidebar:
28
  tag = st.selectbox('Select a tag', tags, key='tag')
@@ -54,86 +58,61 @@ class DashboardApp:
54
  results = curser.fetchall()
55
  curser.close()
56
 
57
- modelVersion_standings = self.score_calculator(results, db_table)
 
58
 
59
- # sort the modelVersion_standings by value into a list of tuples in descending order
60
- modelVersion_standings = sorted(modelVersion_standings.items(), key=lambda x: x[1], reverse=True)
61
 
62
  tab1, tab2 = st.tabs(['Top Picks', 'Detailed Info'])
63
 
64
  with tab1:
65
  # self.podium(modelVersion_standings)
66
- self.podium_expander(modelVersion_standings)
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  with tab2:
69
  st.write('## Detailed information of all selected models')
70
- detailed_info = pd.merge(pd.DataFrame(modelVersion_standings, columns=['modelVersion_id', 'ranking_score']), self.roster, on='modelVersion_id')
71
- st.data_editor(detailed_info, hide_index=True, disabled=True)
72
-
73
- def podium(self, modelVersion_standings, n=3):
74
- st.write('## Top picks')
75
- metric_cols = st.columns(n)
76
- image_display = st.empty()
77
 
78
- for i in range(n):
79
- with metric_cols[i]:
80
- modelVersion_id = modelVersion_standings[i][0]
81
- winning_times = modelVersion_standings[i][1]
82
-
83
- model_id, model_name, modelVersion_name, url = self.roster[self.roster['modelVersion_id'] == modelVersion_id][['model_id', 'model_name', 'modelVersion_name', 'modelVersion_url']].values[0]
84
-
85
- metric_card = stylable_container(
86
- key="container_with_border",
87
- css_styles="""
88
- {
89
- border: 1.5px solid rgba(49, 51, 63, 0.2);
90
- border-left: 0.5rem solid gold;
91
- border-radius: 5px;
92
- padding: calc(1em + 5px);
93
- gap: 0.5em;
94
- box-shadow: 0 0 2rem rgba(0, 0, 0, 0.08);
95
- overflow-x: scroll;
96
- }
97
- """,
98
- )
99
-
100
- with metric_card:
101
- icon = 'πŸ₯‡'if i == 0 else 'πŸ₯ˆ' if i == 1 else 'πŸ₯‰'
102
- # st.write(model_id)
103
- st.write(f'### {icon} {model_name}, [{modelVersion_name}](https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id})')
104
- st.write(f'Ranking Score: {winning_times}')
105
 
106
- show_image = st.button('Show Image', key=modelVersion_id, use_container_width=True)
107
- if show_image:
108
 
109
- images = self.promptBook[self.promptBook['modelVersion_id'] == modelVersion_id]['image_id'].values
110
- with image_display.container():
111
- st.write('---')
112
- st.write(f'### Images generated with {icon} {model_name}, {modelVersion_name}')
113
- col_num = 4
114
- image_cols = st.columns(col_num)
115
- for i in range(len(images)):
116
- with image_cols[i % col_num]:
117
- image = f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{images[i]}.png"
118
- st.image(image, use_column_width=True)
119
-
120
- def podium_expander(self, modelVersion_standings, n=3):
121
- # st.write('## Top picks')
122
- # metric_cols = st.columns(n)
123
  for i in range(n):
124
- # with metric_cols[i]:
125
- modelVersion_id = modelVersion_standings[i][0]
126
- winning_times = modelVersion_standings[i][1]
127
 
128
  model_id, model_name, modelVersion_name, url = self.roster[self.roster['modelVersion_id'] == modelVersion_id][['model_id', 'model_name', 'modelVersion_name', 'modelVersion_url']].values[0]
129
 
130
- icon = 'πŸ₯‡'if i == 0 else 'πŸ₯ˆ' if i == 1 else 'πŸ₯‰'
131
  podium_display = st.columns([1, 14])
132
  with podium_display[0]:
133
- st.title(f'{icon}')
 
 
 
 
134
  with podium_display[1]:
135
- st.write(f'##### {model_name}, {modelVersion_name}')
136
- st.write(f'[Civitai Page](https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id}), [Model Download Link]({url}), Ranking Score: {winning_times}')
 
 
 
 
 
 
 
137
  # with st.expander(f'**{icon} {model_name}, [{modelVersion_name}](https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id})**, Ranking Score: {winning_times}'):
138
  with st.expander(f'Show Images'):
139
  images = self.promptBook[self.promptBook['modelVersion_id'] == modelVersion_id]['image_id'].values
@@ -159,6 +138,10 @@ class DashboardApp:
159
  if i != n - 1:
160
  st.write('---')
161
 
 
 
 
 
162
  def score_calculator(self, results, db_table):
163
  modelVersion_standings = {}
164
  if db_table == 'battle_results':
 
23
  self.promptBook = promptBook
24
  self.session_finished = session_finished
25
 
26
+ # init modelVersion_standings
27
+ if 'modelVersion_standings' not in st.session_state:
28
+ st.session_state.modelVersion_standings = {}
29
+
30
  def sidebar(self, tags, mode):
31
  with st.sidebar:
32
  tag = st.selectbox('Select a tag', tags, key='tag')
 
58
  results = curser.fetchall()
59
  curser.close()
60
 
61
+ if tag not in st.session_state.modelVersion_standings:
62
+ st.session_state.modelVersion_standings[tag] = self.score_calculator(results, db_table)
63
 
64
+ # sort the modelVersion_standings by value into a list of tuples in descending order
65
+ st.session_state.modelVersion_standings[tag] = sorted(st.session_state.modelVersion_standings[tag].items(), key=lambda x: x[1], reverse=True)
66
 
67
  tab1, tab2 = st.tabs(['Top Picks', 'Detailed Info'])
68
 
69
  with tab1:
70
  # self.podium(modelVersion_standings)
71
+ switch_stage = st.toggle('Manual Reorder', key='switch_stage')
72
+ if switch_stage:
73
+ self.podium_expander(tag, n=len(st.session_state.modelVersion_standings[tag]), summary_mode='edit')
74
+ else:
75
+ self.podium_expander(tag, n=3, summary_mode='display')
76
+ # if st.session_state.summary_mode == 'display':
77
+ # switch_stage = st.button('Manual Reorder', key='switch_stage_edit', on_click=lambda: st.session_state.__setitem__('summary_mode', 'edit'))
78
+ # self.podium_expander(tag, n=3, summary_mode='display')
79
+ #
80
+ # elif st.session_state.summary_mode == 'edit':
81
+ # switch_stage = st.button('Done', key='switch_stage_done', type='primary', on_click=lambda: st.session_state.__setitem__('summary_mode', 'display'))
82
+ # self.podium_expander(tag, n=len(st.session_state.modelVersion_standings[tag]), summary_mode='edit')
83
 
84
  with tab2:
85
  st.write('## Detailed information of all selected models')
86
+ detailed_info = pd.merge(pd.DataFrame(st.session_state.modelVersion_standings[tag], columns=['modelVersion_id', 'ranking_score']), self.roster, on='modelVersion_id')
 
 
 
 
 
 
87
 
88
+ st.data_editor(detailed_info, hide_index=False, disabled=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
+ def podium_expander(self, tag, n=3, summary_mode: ['display', 'edit'] = 'display'):
 
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  for i in range(n):
93
+ modelVersion_id = st.session_state.modelVersion_standings[tag][i][0]
94
+ winning_times = st.session_state.modelVersion_standings[tag][i][1]
 
95
 
96
  model_id, model_name, modelVersion_name, url = self.roster[self.roster['modelVersion_id'] == modelVersion_id][['model_id', 'model_name', 'modelVersion_name', 'modelVersion_url']].values[0]
97
 
98
+ icon = 'πŸ₯‡'if i == 0 else 'πŸ₯ˆ' if i == 1 else 'πŸ₯‰' if i == 2 else '🎈'
99
  podium_display = st.columns([1, 14])
100
  with podium_display[0]:
101
+ if summary_mode == 'display':
102
+ st.title(f'{icon}')
103
+ elif summary_mode == 'edit':
104
+ moveup = st.button('⬆', key=f'moveup_{modelVersion_id}', help='Move this model up', disabled=i == 0, on_click=self.switch_order, args=(tag, i, i - 1))
105
+ movedown = st.button('⬇', key=f'movedown_{modelVersion_id}', help='Move this model down', disabled=i == n - 1, on_click=self.switch_order, args=(tag, i, i + 1))
106
  with podium_display[1]:
107
+ title_display = st.columns([4, 1, 1])
108
+ with title_display[0]:
109
+ st.write(f'##### {model_name}, {modelVersion_name}')
110
+ st.write(f'Ranking Score: {winning_times}')
111
+ with title_display[1]:
112
+ st.link_button('Download Model', url, use_container_width=True)
113
+ with title_display[2]:
114
+ st.link_button('Civitai Page', f'https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id}', use_container_width=True, type='primary')
115
+ # st.write(f'[Civitai Page](https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id}), [Model Download Link]({url}), Ranking Score: {winning_times}')
116
  # with st.expander(f'**{icon} {model_name}, [{modelVersion_name}](https://civitai.com/models/{model_id}?modelVersionId={modelVersion_id})**, Ranking Score: {winning_times}'):
117
  with st.expander(f'Show Images'):
118
  images = self.promptBook[self.promptBook['modelVersion_id'] == modelVersion_id]['image_id'].values
 
138
  if i != n - 1:
139
  st.write('---')
140
 
141
+ def switch_order(self, tag, current, target):
142
+ st.session_state.modelVersion_standings[tag][current], st.session_state.modelVersion_standings[tag][target] = st.session_state.modelVersion_standings[tag][target], st.session_state.modelVersion_standings[tag][current]
143
+
144
+
145
  def score_calculator(self, results, db_table):
146
  modelVersion_standings = {}
147
  if db_table == 'battle_results':