Spaces:
Running
Running
add new custom weighting mode
Browse files
app.py
CHANGED
@@ -10,6 +10,8 @@ from datasets import load_dataset, Dataset, load_from_disk
|
|
10 |
from huggingface_hub import login
|
11 |
import os
|
12 |
import requests
|
|
|
|
|
13 |
|
14 |
SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'avg_rank', 'pop': 'model_download_count'}
|
15 |
|
@@ -57,7 +59,7 @@ class GalleryApp:
|
|
57 |
with cols[j]:
|
58 |
# show image
|
59 |
image = st.session_state.images[items.iloc[idx+j]['row_idx'].item()]['image']
|
60 |
-
|
61 |
st.image(image,
|
62 |
use_column_width=True,
|
63 |
)
|
@@ -75,75 +77,22 @@ class GalleryApp:
|
|
75 |
# with containers[row_idx+1]:
|
76 |
# st.image(image, use_column_width=True)
|
77 |
|
78 |
-
def
|
79 |
-
st.title('Model Coffer Gallery')
|
80 |
-
st.write('This is a gallery of images generated by the models in the Model Coffer')
|
81 |
-
|
82 |
-
with st.sidebar:
|
83 |
-
prompt_tags = self.promptBook['tag'].unique()
|
84 |
-
# sort tags by alphabetical order
|
85 |
-
prompt_tags = np.sort(prompt_tags)[::-1]
|
86 |
-
|
87 |
-
tag = st.selectbox('Select a tag', prompt_tags)
|
88 |
-
|
89 |
-
items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
|
90 |
-
|
91 |
-
original_prompts = np.sort(items['prompt'].unique())[::-1]
|
92 |
-
|
93 |
-
# remove the first four items in the prompt, which are mostly the same
|
94 |
-
if tag != 'abstract':
|
95 |
-
prompts = [', '.join(x.split(', ')[4:]) for x in original_prompts]
|
96 |
-
prompt = st.selectbox('Select prompt', prompts)
|
97 |
-
|
98 |
-
idx = prompts.index(prompt)
|
99 |
-
prompt_full = ', '.join(original_prompts[idx].split(', ')[:4]) + ', ' + prompt
|
100 |
-
else:
|
101 |
-
prompt_full = st.selectbox('Select prompt', original_prompts)
|
102 |
-
|
103 |
-
prompt_id = items[items['prompt'] == prompt_full]['prompt_id'].unique()[0]
|
104 |
-
items = items[items['prompt_id'] == prompt_id].reset_index(drop=True)
|
105 |
-
|
106 |
-
st.write('**Prompt ID**')
|
107 |
-
st.caption(f"{prompt_id}")
|
108 |
-
st.write('**Prompt**')
|
109 |
-
st.caption(f"{items['prompt'][0]}")
|
110 |
-
st.write('**Negative Prompt**')
|
111 |
-
st.caption(f"{items['negativePrompt'][0]}")
|
112 |
-
st.write('**Sampler**')
|
113 |
-
st.caption(f"{items['sampler'][0]}")
|
114 |
-
st.write('**cfgScale**')
|
115 |
-
st.caption(f"{items['cfgScale'][0]}")
|
116 |
-
st.write('**Size**')
|
117 |
-
st.caption(f"width: {items['size'][0].split('x')[0]}, height: {items['size'][0].split('x')[1]}")
|
118 |
-
st.write('**Seed**')
|
119 |
-
st.caption(f"{items['seed'][0]}")
|
120 |
-
|
121 |
-
# # for tag as civitai, add civitai reference
|
122 |
-
# if tag == 'civitai':
|
123 |
-
# st.write('**Reference**')
|
124 |
-
#
|
125 |
-
# res = requests.get(f'https://civitai.com/images', params={'post_id': prompt_id})
|
126 |
-
# st.write(res)
|
127 |
-
# image_url = res.json()['items'][0]['url']
|
128 |
-
# st.image(image_url, use_column_width=True)
|
129 |
-
|
130 |
-
# with images:
|
131 |
-
# selecters = st.columns([2, 1, 2, 0.5])
|
132 |
selecters = st.columns([4, 1, 1])
|
133 |
|
134 |
with selecters[0]:
|
135 |
-
# # sort_by = st.selectbox('Sort by', items.columns[11: -1])
|
136 |
-
# sort_by = st.selectbox('Sort by', ['model_download_count', 'clip_score', 'avg_rank', 'model_name', 'model_id',
|
137 |
-
# 'modelVersion_name', 'modelVersion_id'])
|
138 |
-
print(items.columns)
|
139 |
types = st.columns([1, 3])
|
140 |
with types[0]:
|
141 |
sort_type = st.selectbox('Sort by', ['IDs and Names', 'Scores'])
|
142 |
with types[1]:
|
143 |
if sort_type == 'IDs and Names':
|
144 |
-
sort_by = st.selectbox('Sort by',
|
|
|
|
|
145 |
elif sort_type == 'Scores':
|
146 |
-
sort_by = st.multiselect('Sort by', ['clip_score', 'avg_rank', 'popularity'],
|
|
|
|
|
147 |
# process sort_by to map to the column name
|
148 |
|
149 |
if len(sort_by) == 3:
|
@@ -172,18 +121,24 @@ class GalleryApp:
|
|
172 |
items = items.sort_values(by=[sort_by], ascending=order).reset_index(drop=True)
|
173 |
|
174 |
with selecters[2]:
|
175 |
-
filter = st.selectbox('Filter', ['
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
|
181 |
info = st.multiselect('Show Info',
|
182 |
['model_download_count', 'clip_score', 'avg_rank', 'model_name', 'model_id',
|
183 |
-
'modelVersion_name', 'modelVersion_id', 'clip+rank', 'clip+pop', 'rank+pop',
|
|
|
184 |
default=sort_by)
|
185 |
|
186 |
-
print('info', info)
|
187 |
# add one annotation
|
188 |
mentioned_scores = []
|
189 |
for i in info:
|
@@ -193,20 +148,173 @@ class GalleryApp:
|
|
193 |
if SCORE_NAME_MAPPING[m] not in mentioned_scores:
|
194 |
mentioned_scores.append(SCORE_NAME_MAPPING[m])
|
195 |
if len(mentioned_scores) > 0:
|
196 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
|
199 |
col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
|
200 |
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
210 |
|
211 |
def reset_current_prompt(self, prompt_id):
|
212 |
# reset current prompt
|
@@ -223,10 +331,6 @@ class GalleryApp:
|
|
223 |
dataset = load_dataset('NYUSHPRP/ModelCofferMetadata', split='train')
|
224 |
# get checked images
|
225 |
checked_info = self.promptBook['checked']
|
226 |
-
# print('checked_info: ', checked_info)
|
227 |
-
# for d in checked_info:
|
228 |
-
# if d is True:
|
229 |
-
# print('checked')
|
230 |
|
231 |
if 'checked' in dataset.column_names:
|
232 |
dataset = dataset.remove_columns('checked')
|
@@ -254,6 +358,10 @@ if __name__ == '__main__':
|
|
254 |
if 'checked' not in st.session_state.promptBook.columns:
|
255 |
st.session_state.promptBook.loc[:, 'checked'] = False
|
256 |
|
|
|
|
|
|
|
|
|
257 |
st.session_state.images = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
|
258 |
# st.session_state.images = load_dataset('NYUSHPRP/ModelCofferPromptBook', split='train', streaming=True)
|
259 |
print(st.session_state.images)
|
|
|
10 |
from huggingface_hub import login
|
11 |
import os
|
12 |
import requests
|
13 |
+
from bs4 import BeautifulSoup
|
14 |
+
import re
|
15 |
|
16 |
SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'avg_rank', 'pop': 'model_download_count'}
|
17 |
|
|
|
59 |
with cols[j]:
|
60 |
# show image
|
61 |
image = st.session_state.images[items.iloc[idx+j]['row_idx'].item()]['image']
|
62 |
+
|
63 |
st.image(image,
|
64 |
use_column_width=True,
|
65 |
)
|
|
|
77 |
# with containers[row_idx+1]:
|
78 |
# st.image(image, use_column_width=True)
|
79 |
|
80 |
+
def selection_panel(self, items):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
selecters = st.columns([4, 1, 1])
|
82 |
|
83 |
with selecters[0]:
|
|
|
|
|
|
|
|
|
84 |
types = st.columns([1, 3])
|
85 |
with types[0]:
|
86 |
sort_type = st.selectbox('Sort by', ['IDs and Names', 'Scores'])
|
87 |
with types[1]:
|
88 |
if sort_type == 'IDs and Names':
|
89 |
+
sort_by = st.selectbox('Sort by',
|
90 |
+
['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id'],
|
91 |
+
label_visibility='hidden')
|
92 |
elif sort_type == 'Scores':
|
93 |
+
sort_by = st.multiselect('Sort by', ['clip_score', 'avg_rank', 'popularity'],
|
94 |
+
label_visibility='hidden',
|
95 |
+
default=['clip_score', 'avg_rank', 'popularity'])
|
96 |
# process sort_by to map to the column name
|
97 |
|
98 |
if len(sort_by) == 3:
|
|
|
121 |
items = items.sort_values(by=[sort_by], ascending=order).reset_index(drop=True)
|
122 |
|
123 |
with selecters[2]:
|
124 |
+
filter = st.selectbox('Filter', ['Safe', 'All', 'Unsafe'])
|
125 |
+
print('filter', filter)
|
126 |
+
# initialize unsafe_modelVersion_ids
|
127 |
+
if filter == 'Safe':
|
128 |
+
# return checked items
|
129 |
+
items = items[items['checked'] == False].reset_index(drop=True)
|
130 |
+
|
131 |
+
elif filter == 'Unsafe':
|
132 |
+
# return unchecked items
|
133 |
+
items = items[items['checked'] == True].reset_index(drop=True)
|
134 |
+
print(items)
|
135 |
|
136 |
info = st.multiselect('Show Info',
|
137 |
['model_download_count', 'clip_score', 'avg_rank', 'model_name', 'model_id',
|
138 |
+
'modelVersion_name', 'modelVersion_id', 'clip+rank', 'clip+pop', 'rank+pop',
|
139 |
+
'clip+rank+pop'],
|
140 |
default=sort_by)
|
141 |
|
|
|
142 |
# add one annotation
|
143 |
mentioned_scores = []
|
144 |
for i in info:
|
|
|
148 |
if SCORE_NAME_MAPPING[m] not in mentioned_scores:
|
149 |
mentioned_scores.append(SCORE_NAME_MAPPING[m])
|
150 |
if len(mentioned_scores) > 0:
|
151 |
+
st.info(
|
152 |
+
f"**Note:** The scores {mentioned_scores} are normalized to [0, 1] for each score type, and then added together. The higher the score, the better the model.")
|
153 |
+
|
154 |
+
col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
|
155 |
+
|
156 |
+
return items, info, col_num
|
157 |
+
|
158 |
+
|
159 |
+
def selection_panel_2(self, items):
|
160 |
+
selecters = st.columns([1, 5])
|
161 |
+
|
162 |
+
with selecters[0]:
|
163 |
+
sort_type = st.selectbox('Sort by', ['IDs and Names', 'Scores'])
|
164 |
+
if sort_type == 'Scores':
|
165 |
+
sort_by = 'weighted_score_sum'
|
166 |
+
|
167 |
+
with selecters[1]:
|
168 |
+
if sort_type == 'IDs and Names':
|
169 |
+
sub_selecters = st.columns([3, 1, 1])
|
170 |
+
with sub_selecters[0]:
|
171 |
+
sort_by = st.selectbox('Sort by',
|
172 |
+
['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id'],
|
173 |
+
label_visibility='hidden')
|
174 |
+
|
175 |
+
continue_idx = 1
|
176 |
+
|
177 |
+
else:
|
178 |
+
sub_selecters = st.columns([1, 1, 1, 1, 1])
|
179 |
+
|
180 |
+
with sub_selecters[0]:
|
181 |
+
clip_weight = st.number_input('Clip Score Weight', min_value=-100.0, max_value=100.0, value=1.0, step=0.1)
|
182 |
+
with sub_selecters[1]:
|
183 |
+
rank_weight = st.number_input('Rank Score Weight', min_value=-100.0, max_value=100.0, value=1.0, step=0.1)
|
184 |
+
with sub_selecters[2]:
|
185 |
+
pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=1.0, step=0.1)
|
186 |
+
|
187 |
+
items.loc[:, 'weighted_score_sum'] = items['norm_clip'] * clip_weight + items['avg_rank'] * rank_weight + items[
|
188 |
+
'norm_pop'] * pop_weight
|
189 |
+
|
190 |
+
continue_idx = 3
|
191 |
+
|
192 |
+
with sub_selecters[continue_idx]:
|
193 |
+
order = st.selectbox('Order', ['Ascending', 'Descending'], index=1 if sort_type == 'Scores' else 0)
|
194 |
+
if order == 'Ascending':
|
195 |
+
order = True
|
196 |
+
else:
|
197 |
+
order = False
|
198 |
+
|
199 |
+
items = items.sort_values(by=[sort_by], ascending=order).reset_index(drop=True)
|
200 |
+
|
201 |
+
with sub_selecters[continue_idx+1]:
|
202 |
+
filter = st.selectbox('Filter', ['Safe', 'All', 'Unsafe'])
|
203 |
+
print('filter', filter)
|
204 |
+
# initialize unsafe_modelVersion_ids
|
205 |
+
if filter == 'Safe':
|
206 |
+
# return checked items
|
207 |
+
items = items[items['checked'] == False].reset_index(drop=True)
|
208 |
+
|
209 |
+
elif filter == 'Unsafe':
|
210 |
+
# return unchecked items
|
211 |
+
items = items[items['checked'] == True].reset_index(drop=True)
|
212 |
+
print(items)
|
213 |
+
|
214 |
+
info = st.multiselect('Show Info',
|
215 |
+
['model_download_count', 'clip_score', 'avg_rank', 'model_name', 'model_id',
|
216 |
+
'modelVersion_name', 'modelVersion_id', 'clip+rank', 'clip+pop', 'rank+pop',
|
217 |
+
'clip+rank+pop', 'weighted_score_sum'],
|
218 |
+
default=sort_by)
|
219 |
|
220 |
+
# add one annotation
|
221 |
+
mentioned_scores = []
|
222 |
+
for i in info:
|
223 |
+
if '+' in i:
|
224 |
+
mentioned = i.split('+')
|
225 |
+
for m in mentioned:
|
226 |
+
if SCORE_NAME_MAPPING[m] not in mentioned_scores:
|
227 |
+
mentioned_scores.append(SCORE_NAME_MAPPING[m])
|
228 |
+
if len(mentioned_scores) > 0:
|
229 |
+
st.info(
|
230 |
+
f"**Note:** The scores {mentioned_scores} are normalized to [0, 1] for each score type, and then added together. The higher the score, the better the model.")
|
231 |
|
232 |
col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
|
233 |
|
234 |
+
return items, info, col_num
|
235 |
+
|
236 |
+
def app(self):
|
237 |
+
st.title('Model Coffer Gallery')
|
238 |
+
st.write('This is a gallery of images generated by the models in the Model Coffer')
|
239 |
+
|
240 |
+
with st.sidebar:
|
241 |
+
prompt_tags = self.promptBook['tag'].unique()
|
242 |
+
# sort tags by alphabetical order
|
243 |
+
prompt_tags = np.sort(prompt_tags)[::-1]
|
244 |
+
|
245 |
+
tag = st.selectbox('Select a tag', prompt_tags)
|
246 |
+
|
247 |
+
items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
|
248 |
+
|
249 |
+
original_prompts = np.sort(items['prompt'].unique())[::-1]
|
250 |
+
|
251 |
+
# remove the first four items in the prompt, which are mostly the same
|
252 |
+
if tag != 'abstract':
|
253 |
+
prompts = [', '.join(x.split(', ')[4:]) for x in original_prompts]
|
254 |
+
prompt = st.selectbox('Select prompt', prompts)
|
255 |
+
|
256 |
+
idx = prompts.index(prompt)
|
257 |
+
prompt_full = ', '.join(original_prompts[idx].split(', ')[:4]) + ', ' + prompt
|
258 |
+
else:
|
259 |
+
prompt_full = st.selectbox('Select prompt', original_prompts)
|
260 |
+
|
261 |
+
prompt_id = items[items['prompt'] == prompt_full]['prompt_id'].unique()[0]
|
262 |
+
items = items[items['prompt_id'] == prompt_id].reset_index(drop=True)
|
263 |
+
|
264 |
+
# show image metadata
|
265 |
+
image_metadatas = ['prompt_id', 'prompt', 'negativePrompt', 'sampler', 'cfgScale', 'size', 'seed']
|
266 |
+
for key in image_metadatas:
|
267 |
+
label = ' '.join(key.split('_')).capitalize()
|
268 |
+
st.write(f"**{label}**")
|
269 |
+
if items[key][0] == ' ':
|
270 |
+
st.write('`None`')
|
271 |
+
else:
|
272 |
+
st.caption(f"{items[key][0]}")
|
273 |
+
|
274 |
+
# for tag as civitai, add civitai reference
|
275 |
+
if tag == 'civitai':
|
276 |
+
try:
|
277 |
+
st.write('**Civitai Reference**')
|
278 |
+
res = requests.get(f'https://civitai.com/images/{prompt_id.item()}')
|
279 |
+
# st.write(res.text)
|
280 |
+
soup = BeautifulSoup(res.text, 'html.parser')
|
281 |
+
image_section = soup.find('div', {'class': 'mantine-12rlksp'})
|
282 |
+
image_url = image_section.find('img')['src']
|
283 |
+
st.image(image_url, use_column_width=True)
|
284 |
+
except:
|
285 |
+
pass
|
286 |
+
|
287 |
+
|
288 |
+
# add safety check for some prompts
|
289 |
+
safety_check = True
|
290 |
+
unsafe_prompts = {}
|
291 |
+
# initialize unsafe prompts
|
292 |
+
for prompt_tag in prompt_tags:
|
293 |
+
unsafe_prompts[prompt_tag] = []
|
294 |
+
# manually add unsafe prompts
|
295 |
+
unsafe_prompts['civitai'] = [375790, 366222, 295008, 256477]
|
296 |
+
unsafe_prompts['people'] = [53]
|
297 |
+
unsafe_prompts['art'] = [23]
|
298 |
+
unsafe_prompts['abstract'] = [10, 12]
|
299 |
+
|
300 |
+
if int(prompt_id.item()) in unsafe_prompts[tag]:
|
301 |
+
st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
|
302 |
+
safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.')
|
303 |
+
|
304 |
+
if safety_check:
|
305 |
+
items, info, col_num = self.selection_panel_2(items)
|
306 |
+
# self.gallery_standard(items, col_num, info)
|
307 |
+
|
308 |
+
with st.form(key=f'{prompt_id}', clear_on_submit=False):
|
309 |
+
buttons = st.columns([1, 1, 1])
|
310 |
+
with buttons[0]:
|
311 |
+
submit = st.form_submit_button('Save selections', on_click=self.save_checked, use_container_width=True, type='primary')
|
312 |
+
with buttons[1]:
|
313 |
+
submit = st.form_submit_button('Reset current prompt', on_click=self.reset_current_prompt, kwargs={'prompt_id': prompt_id} , use_container_width=True)
|
314 |
+
with buttons[2]:
|
315 |
+
submit = st.form_submit_button('Reset all selections', on_click=self.reset_all, use_container_width=True)
|
316 |
+
|
317 |
+
self.gallery_standard(items, col_num, info)
|
318 |
|
319 |
def reset_current_prompt(self, prompt_id):
|
320 |
# reset current prompt
|
|
|
331 |
dataset = load_dataset('NYUSHPRP/ModelCofferMetadata', split='train')
|
332 |
# get checked images
|
333 |
checked_info = self.promptBook['checked']
|
|
|
|
|
|
|
|
|
334 |
|
335 |
if 'checked' in dataset.column_names:
|
336 |
dataset = dataset.remove_columns('checked')
|
|
|
358 |
if 'checked' not in st.session_state.promptBook.columns:
|
359 |
st.session_state.promptBook.loc[:, 'checked'] = False
|
360 |
|
361 |
+
# add 'custom_score_weights' column to promptBook if not exist
|
362 |
+
if 'weighted_score_sum' not in st.session_state.promptBook.columns:
|
363 |
+
st.session_state.promptBook.loc[:, 'weighted_score_sum'] = 0
|
364 |
+
|
365 |
st.session_state.images = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
|
366 |
# st.session_state.images = load_dataset('NYUSHPRP/ModelCofferPromptBook', split='train', streaming=True)
|
367 |
print(st.session_state.images)
|