Ricercar commited on
Commit
b21aab5
·
1 Parent(s): 6bdebc7

update theshold

Browse files
app.py CHANGED
@@ -26,7 +26,7 @@ def altair_histogram(hist_data, sort_by, mini, maxi):
26
  chart = (
27
  alt.Chart(hist_data)
28
  .mark_bar(opacity=0.7, cornerRadius=2)
29
- .encode(alt.X(f"{sort_by}:Q", bin=alt.Bin(maxbins=20)), y="count()")
30
  # .add_selection(brushed)
31
  # .properties(width=800, height=300)
32
  )
@@ -84,28 +84,31 @@ class GalleryApp:
84
 
85
  def gallery_standard(self, items, col_num, info):
86
  rows = len(items) // col_num + 1
87
- containers = [st.container() for _ in range(rows*2)]
 
88
  for idx in range(0, len(items), col_num):
89
  # assign one container for each row
90
- row_idx = (idx // col_num) * 2
 
91
  with containers[row_idx]:
92
  cols = st.columns(col_num)
93
  for j in range(col_num):
94
  if idx + j < len(items):
95
  with cols[j]:
96
  # show image
97
- image = self.images_ds[items.iloc[idx+j]['row_idx'].item()]['image']
98
 
99
- st.image(image,
100
- use_column_width=True,
101
- )
102
 
103
- # show checkbox
104
- self.promptBook.loc[items.iloc[idx+j]['row_idx'].item(), 'checked'] = st.checkbox('Select', value=self.promptBook.loc[items.iloc[idx+j]['row_idx'].item(), 'checked'], key=f'select_{idx+j}')
 
 
105
 
 
106
  # show selected info
107
  for key in info:
108
- st.write(f"**{key}**: {items.iloc[idx+j][key]}")
109
 
110
  # st.write(row_idx/2, idx+j, rows)
111
  # extra_info = st.checkbox('Extra Info', key=f'extra_info_{idx+j}')
@@ -192,16 +195,19 @@ class GalleryApp:
192
  return items, info, col_num
193
 
194
  def selection_panel_2(self, items):
195
- selecters = st.columns([1, 5])
196
 
 
197
  with selecters[0]:
198
- sort_type = st.selectbox('Sort by', ['IDs and Names', 'Scores'])
199
  if sort_type == 'Scores':
200
  sort_by = 'weighted_score_sum'
201
 
 
202
  with selecters[1]:
203
  if sort_type == 'IDs and Names':
204
- sub_selecters = st.columns([3, 1, 1])
 
205
  with sub_selecters[0]:
206
  sort_by = st.selectbox('Sort by',
207
  ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id'],
@@ -210,81 +216,89 @@ class GalleryApp:
210
  continue_idx = 1
211
 
212
  else:
213
- sub_selecters = st.columns([1, 1, 1, 1, 1])
 
 
 
 
214
 
215
  with sub_selecters[0]:
216
- clip_weight = st.number_input('Clip Score Weight', min_value=-100.0, max_value=100.0, value=1.0, step=0.1)
217
  with sub_selecters[1]:
218
- rank_weight = st.number_input('Rank Score Weight', min_value=-100.0, max_value=100.0, value=1.0, step=0.1)
219
  with sub_selecters[2]:
220
- pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=1.0, step=0.1)
 
 
221
 
222
  items.loc[:, 'weighted_score_sum'] = round(items['norm_clip'] * clip_weight + items['avg_rank'] * rank_weight + items[
223
  'norm_pop'] * pop_weight, 4)
224
 
225
  continue_idx = 3
226
 
227
-
228
  with sub_selecters[continue_idx]:
229
- order = st.selectbox('Order', ['Ascending', 'Descending'], index=1 if sort_type == 'Scores' else 0)
230
- if order == 'Ascending':
231
- order = True
232
- else:
233
- order = False
234
-
235
- items = items.sort_values(by=[sort_by], ascending=order).reset_index(drop=True)
236
-
237
- with sub_selecters[continue_idx+1]:
238
- filter = st.selectbox('Filter', ['Safe', 'All', 'Unsafe'])
239
- print('filter', filter)
240
- # initialize unsafe_modelVersion_ids
241
- if filter == 'Safe':
242
- # return checked items
243
- items = items[items['checked'] == False].reset_index(drop=True)
244
-
245
- elif filter == 'Unsafe':
246
- # return unchecked items
247
- items = items[items['checked'] == True].reset_index(drop=True)
248
- print(items)
249
 
250
  # draw a distribution histogram
251
  if sort_type == 'Scores':
252
- with st.expander('Show score distribution histogram and select score range'):
253
- st.write('**Score distribution histogram**')
254
- chart_space = st.container()
255
- # st.write('Select the range of scores to show')
256
- hist_data = pd.DataFrame(items[sort_by])
257
- mini = hist_data[sort_by].min().item()
258
- maxi = hist_data[sort_by].max().item()
259
- st.write('**Select the range of scores to show**')
260
- r = st.slider('Select the range of scores to show', min_value=mini, max_value=maxi, value=(mini, maxi), label_visibility='collapsed')
261
- with chart_space:
262
- st.altair_chart(altair_histogram(hist_data, sort_by, r[0], r[1]), use_container_width=True)
263
- # event_dict = altair_component(altair_chart=altair_histogram(hist_data, sort_by))
264
- # r = event_dict.get(sort_by)
265
- if r:
266
- items = items[(items[sort_by] >= r[0]) & (items[sort_by] <= r[1])].reset_index(drop=True)
267
- # st.write(r)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
 
 
269
 
270
- info = st.multiselect('Show Info',
271
- ['model_download_count', 'clip_score', 'avg_rank', 'model_name', 'model_id',
272
- 'modelVersion_name', 'modelVersion_id', 'clip+rank', 'clip+pop', 'rank+pop',
273
- 'clip+rank+pop', 'weighted_score_sum'],
274
- default=sort_by)
 
275
 
276
- # add one annotation
277
- mentioned_scores = []
278
- for i in info:
279
- if '+' in i:
280
- mentioned = i.split('+')
281
- for m in mentioned:
282
- if SCORE_NAME_MAPPING[m] not in mentioned_scores:
283
- mentioned_scores.append(SCORE_NAME_MAPPING[m])
284
- if len(mentioned_scores) > 0:
285
- st.info(
286
- f"**Note:** The scores {mentioned_scores} are normalized to [0, 1] for each score type, and then added together. The higher the score, the better the model.")
287
 
 
288
  col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
289
 
290
  return items, info, col_num
@@ -351,6 +365,7 @@ class GalleryApp:
351
  unsafe_prompts['people'] = [53]
352
  unsafe_prompts['art'] = [23]
353
  unsafe_prompts['abstract'] = [10, 12]
 
354
 
355
  if int(prompt_id.item()) in unsafe_prompts[tag]:
356
  st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
@@ -358,19 +373,18 @@ class GalleryApp:
358
 
359
  if safety_check:
360
  items, info, col_num = self.selection_panel_2(items)
361
-
362
- # self.gallery_standard(items, col_num, info)
363
-
364
- with st.form(key=f'{prompt_id}', clear_on_submit=False):
365
- buttons = st.columns([1, 1, 1])
366
- with buttons[0]:
367
- submit = st.form_submit_button('Save selections', on_click=self.save_checked, use_container_width=True, type='primary')
368
- with buttons[1]:
369
- submit = st.form_submit_button('Reset current prompt', on_click=self.reset_current_prompt, kwargs={'prompt_id': prompt_id} , use_container_width=True)
370
- with buttons[2]:
371
- submit = st.form_submit_button('Reset all selections', on_click=self.reset_all, use_container_width=True)
372
-
373
- self.gallery_standard(items, col_num, info)
374
 
375
  def reset_current_prompt(self, prompt_id):
376
  # reset current prompt
@@ -393,11 +407,15 @@ class GalleryApp:
393
  dataset = dataset.add_column('checked', checked_info)
394
 
395
  # print('metadata dataset: ', dataset)
 
396
  dataset.push_to_hub('NYUSHPRP/ModelCofferMetadata', split='train')
397
 
398
 
399
  @st.cache_data
400
  def load_hf_dataset():
 
 
 
401
  # load from huggingface
402
  roster = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferRoster', split='train'))
403
  promptBook = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferMetadata', split='train'))
@@ -426,7 +444,6 @@ def load_hf_dataset():
426
 
427
 
428
  if __name__ == '__main__':
429
- login(token=os.environ.get("HF_TOKEN"))
430
  st.set_page_config(layout="wide")
431
 
432
  roster, promptBook, images_ds = load_hf_dataset()
 
26
  chart = (
27
  alt.Chart(hist_data)
28
  .mark_bar(opacity=0.7, cornerRadius=2)
29
+ .encode(alt.X(f"{sort_by}:Q", bin=alt.Bin(maxbins=25)), y="count()")
30
  # .add_selection(brushed)
31
  # .properties(width=800, height=300)
32
  )
 
84
 
85
  def gallery_standard(self, items, col_num, info):
86
  rows = len(items) // col_num + 1
87
+ # containers = [st.container() for _ in range(rows * 2)]
88
+ containers = [st.container() for _ in range(rows)]
89
  for idx in range(0, len(items), col_num):
90
  # assign one container for each row
91
+ # row_idx = (idx // col_num) * 2
92
+ row_idx = idx // col_num
93
  with containers[row_idx]:
94
  cols = st.columns(col_num)
95
  for j in range(col_num):
96
  if idx + j < len(items):
97
  with cols[j]:
98
  # show image
99
+ image = self.images_ds[items.iloc[idx + j]['row_idx'].item()]['image']
100
 
101
+ st.image(image, use_column_width=True)
 
 
102
 
103
+ # # show checkbox
104
+ # self.promptBook.loc[items.iloc[idx + j]['row_idx'].item(), 'checked'] = st.checkbox(
105
+ # 'Select', value=self.promptBook.loc[items.iloc[idx + j]['row_idx'].item(), 'checked'],
106
+ # key=f'select_{idx + j}')
107
 
108
+ st.write(idx+j)
109
  # show selected info
110
  for key in info:
111
+ st.write(f"**{key}**: {items.iloc[idx + j][key]}")
112
 
113
  # st.write(row_idx/2, idx+j, rows)
114
  # extra_info = st.checkbox('Extra Info', key=f'extra_info_{idx+j}')
 
195
  return items, info, col_num
196
 
197
  def selection_panel_2(self, items):
198
+ selecters = st.columns([1, 4])
199
 
200
+ # select sort type
201
  with selecters[0]:
202
+ sort_type = st.selectbox('Sort by', ['Scores', 'IDs and Names'])
203
  if sort_type == 'Scores':
204
  sort_by = 'weighted_score_sum'
205
 
206
+ # select other options
207
  with selecters[1]:
208
  if sort_type == 'IDs and Names':
209
+ sub_selecters = st.columns([3, 1])
210
+ # select sort by
211
  with sub_selecters[0]:
212
  sort_by = st.selectbox('Sort by',
213
  ['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id'],
 
216
  continue_idx = 1
217
 
218
  else:
219
+ # add custom weights
220
+ sub_selecters = st.columns([1, 1, 1, 1])
221
+
222
+ if 'default_weights' not in st.session_state:
223
+ st.session_state.default_weights = [1.0, 1.0, 1.0]
224
 
225
  with sub_selecters[0]:
226
+ clip_weight = st.number_input('Clip Score Weight', min_value=-100.0, max_value=100.0, value=st.session_state.default_weights[0], step=0.1, help='the weight for normalized clip score')
227
  with sub_selecters[1]:
228
+ rank_weight = st.number_input('Distinctiveness Weight', min_value=-100.0, max_value=100.0, value=st.session_state.default_weights[1], step=0.1, help='the weight for average rank')
229
  with sub_selecters[2]:
230
+ pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=st.session_state.default_weights[2], step=0.1, help='the weight for normalized popularity score')
231
+
232
+ st.session_state.default_weights = [clip_weight, rank_weight, pop_weight]
233
 
234
  items.loc[:, 'weighted_score_sum'] = round(items['norm_clip'] * clip_weight + items['avg_rank'] * rank_weight + items[
235
  'norm_pop'] * pop_weight, 4)
236
 
237
  continue_idx = 3
238
 
239
+ # select threshold
240
  with sub_selecters[continue_idx]:
241
+ dist_threshold = st.number_input('Distinctiveness Threshold', min_value=0.0, max_value=1.0, value=0.84, step=0.01, help='Only show models with distinctiveness score lower than this threshold, set 1.0 to show all images')
242
+ items = items[items['avg_rank'] < dist_threshold].reset_index(drop=True)
243
+
244
+ # filter = st.selectbox('Filter', ['Safe', 'All', 'Unsafe'])
245
+ # print('filter', filter)
246
+ # # initialize unsafe_modelVersion_ids
247
+ # if filter == 'Safe':
248
+ # # return unchecked items
249
+ # items = items[items['checked'] == False].reset_index(drop=True)
250
+ #
251
+ # elif filter == 'Unsafe':
252
+ # # return checked items
253
+ # items = items[items['checked'] == True].reset_index(drop=True)
 
 
 
 
 
 
 
254
 
255
  # draw a distribution histogram
256
  if sort_type == 'Scores':
257
+ try:
258
+ with st.expander('Show score distribution histogram and select score range'):
259
+ st.write('**Score distribution histogram**')
260
+ chart_space = st.container()
261
+ # st.write('Select the range of scores to show')
262
+ hist_data = pd.DataFrame(items[sort_by])
263
+ mini = hist_data[sort_by].min().item()
264
+ mini = mini//0.1 * 0.1
265
+ maxi = hist_data[sort_by].max().item()
266
+ maxi = maxi//0.1 * 0.1 + 0.1
267
+ st.write('**Select the range of scores to show**')
268
+ r = st.slider('Select the range of scores to show', min_value=mini, max_value=maxi, value=(mini, maxi), step=0.05, label_visibility='collapsed')
269
+ with chart_space:
270
+ st.altair_chart(altair_histogram(hist_data, sort_by, r[0], r[1]), use_container_width=True)
271
+ # event_dict = altair_component(altair_chart=altair_histogram(hist_data, sort_by))
272
+ # r = event_dict.get(sort_by)
273
+ if r:
274
+ items = items[(items[sort_by] >= r[0]) & (items[sort_by] <= r[1])].reset_index(drop=True)
275
+ # st.write(r)
276
+ except:
277
+ pass
278
+
279
+ display_options = st.columns([1, 4])
280
+
281
+ with display_options[0]:
282
+ # select order
283
+ order = st.selectbox('Order', ['Ascending', 'Descending'], index=1 if sort_type == 'Scores' else 0)
284
+ if order == 'Ascending':
285
+ order = True
286
+ else:
287
+ order = False
288
 
289
+ with display_options[1]:
290
 
291
+ # select info to show
292
+ info = st.multiselect('Show Info',
293
+ ['model_download_count', 'clip_score', 'avg_rank', 'model_name', 'model_id',
294
+ 'modelVersion_name', 'modelVersion_id', 'clip+rank', 'clip+pop', 'rank+pop',
295
+ 'clip+rank+pop', 'weighted_score_sum'],
296
+ default=sort_by)
297
 
298
+ # apply sorting to dataframe
299
+ items = items.sort_values(by=[sort_by], ascending=order).reset_index(drop=True)
 
 
 
 
 
 
 
 
 
300
 
301
+ # select number of columns
302
  col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
303
 
304
  return items, info, col_num
 
365
  unsafe_prompts['people'] = [53]
366
  unsafe_prompts['art'] = [23]
367
  unsafe_prompts['abstract'] = [10, 12]
368
+ unsafe_prompts['food'] = [34]
369
 
370
  if int(prompt_id.item()) in unsafe_prompts[tag]:
371
  st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
 
373
 
374
  if safety_check:
375
  items, info, col_num = self.selection_panel_2(items)
376
+ self.gallery_standard(items, col_num, info)
377
+
378
+ # with st.form(key=f'{prompt_id}', clear_on_submit=True):
379
+ # buttons = st.columns([1, 1, 1])
380
+ # with buttons[0]:
381
+ # submit = st.form_submit_button('Save selections', on_click=self.save_checked, use_container_width=True, type='primary')
382
+ # with buttons[1]:
383
+ # submit = st.form_submit_button('Reset current prompt', on_click=self.reset_current_prompt, kwargs={'prompt_id': prompt_id} , use_container_width=True)
384
+ # with buttons[2]:
385
+ # submit = st.form_submit_button('Reset all selections', on_click=self.reset_all, use_container_width=True)
386
+ #
387
+ # self.gallery_standard(items, col_num, info)
 
388
 
389
  def reset_current_prompt(self, prompt_id):
390
  # reset current prompt
 
407
  dataset = dataset.add_column('checked', checked_info)
408
 
409
  # print('metadata dataset: ', dataset)
410
+ st.cache_data.clear()
411
  dataset.push_to_hub('NYUSHPRP/ModelCofferMetadata', split='train')
412
 
413
 
414
  @st.cache_data
415
  def load_hf_dataset():
416
+ # login to huggingface
417
+ login(token=os.environ.get("HF_TOKEN"))
418
+
419
  # load from huggingface
420
  roster = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferRoster', split='train'))
421
  promptBook = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferMetadata', split='train'))
 
444
 
445
 
446
  if __name__ == '__main__':
 
447
  st.set_page_config(layout="wide")
448
 
449
  roster, promptBook, images_ds = load_hf_dataset()
data/download_script.py CHANGED
@@ -5,9 +5,9 @@ def main():
5
  promptbook = load_dataset('NYUSHPRP/ModelCofferPromptBook', split='train')
6
  print(promptbook)
7
  promptbook.save_to_disk('./promptbook')
8
-
9
- roster = load_dataset('NYUSHPRP/ModelCofferRoster', split='train')
10
- roster.save_to_disk('./roster')
11
 
12
 
13
  def load():
 
5
  promptbook = load_dataset('NYUSHPRP/ModelCofferPromptBook', split='train')
6
  print(promptbook)
7
  promptbook.save_to_disk('./promptbook')
8
+ #
9
+ # roster = load_dataset('NYUSHPRP/ModelCofferRoster', split='train')
10
+ # roster.save_to_disk('./roster')
11
 
12
 
13
  def load():
data/roster/data-00000-of-00001.arrow DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:0d92d1f86f02823ca64b7d88ffbb4c03a1ca8fe9990a54e37b9a1d9171782fca
3
- size 147952
 
 
 
 
data/roster/dataset_info.json DELETED
@@ -1,57 +0,0 @@
1
- {
2
- "citation": "",
3
- "dataset_size": 145934,
4
- "description": "",
5
- "download_checksums": {
6
- "https://huggingface.co/datasets/NYUSHPRP/ModelCofferRoster/resolve/ca9efb0b73c3383dfb5bc9fff380b068d468bfde/data/train-00000-of-00001-0fd3ef44b360ac99.parquet": {
7
- "num_bytes": 27979,
8
- "checksum": null
9
- }
10
- },
11
- "download_size": 27979,
12
- "features": {
13
- "tag": {
14
- "dtype": "string",
15
- "_type": "Value"
16
- },
17
- "model_name": {
18
- "dtype": "string",
19
- "_type": "Value"
20
- },
21
- "model_id": {
22
- "dtype": "int64",
23
- "_type": "Value"
24
- },
25
- "modelVersion_name": {
26
- "dtype": "string",
27
- "_type": "Value"
28
- },
29
- "modelVersion_id": {
30
- "dtype": "int64",
31
- "_type": "Value"
32
- },
33
- "modelVersion_url": {
34
- "dtype": "string",
35
- "_type": "Value"
36
- },
37
- "modelVersion_trainedWords": {
38
- "dtype": "string",
39
- "_type": "Value"
40
- },
41
- "model_download_count": {
42
- "dtype": "int64",
43
- "_type": "Value"
44
- }
45
- },
46
- "homepage": "",
47
- "license": "",
48
- "size_in_bytes": 173913,
49
- "splits": {
50
- "train": {
51
- "name": "train",
52
- "num_bytes": 145934,
53
- "num_examples": 1059,
54
- "dataset_name": "parquet"
55
- }
56
- }
57
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data/roster/state.json DELETED
@@ -1,13 +0,0 @@
1
- {
2
- "_data_files": [
3
- {
4
- "filename": "data-00000-of-00001.arrow"
5
- }
6
- ],
7
- "_fingerprint": "9508df8b007debc4",
8
- "_format_columns": null,
9
- "_format_kwargs": {},
10
- "_format_type": null,
11
- "_output_all_columns": false,
12
- "_split": "train"
13
- }