import os

import datasets
import numpy as np
import pandas as pd
import pymysql.cursors
import streamlit as st
from streamlit_elements import elements, mui, html, dashboard, nivo
from streamlit_extras.switch_page_button import switch_page

from pages.Gallery import load_hf_dataset


class RankingApp:
    """Streamlit page that lets a user rank their selected images batch by batch.

    Images of one batch are shown in a draggable grid; the row-by-row order of
    the cards defines the ranking, kept in ``st.session_state.ranking`` as
    ``{prompt_id: {batch_idx: {str(image_id): rank}}}`` (rank is 0-based).
    Rankings are inserted into the ``rankings`` MySQL table when the user moves
    past a batch.
    """

    def __init__(self, promptBook, images_endpoint, batch_size=4):
        """
        Args:
            promptBook: DataFrame of candidate images. The code reads its
                'tag', 'prompt', 'prompt_id', 'negativePrompt', 'image_id',
                'modelVersion_id' and 'size' columns.
            images_endpoint: URL prefix; an image is loaded from
                ``images_endpoint + str(image_id) + '.png'``.
            batch_size: number of images ranked together on one screen.
        """
        self.promptBook = promptBook
        self.images_endpoint = images_endpoint
        self.batch_size = batch_size

        # Per-prompt index of the batch currently being ranked.
        if 'counter' not in st.session_state:
            st.session_state.counter = {}

    def sidebar(self):
        """Render the tag / prompt pickers in the sidebar.

        Returns:
            ``(prompt_tags, tag, prompt_id, items)`` where ``items`` are the rows
            of ``promptBook`` for the selected prompt, re-indexed from 0.
        """
        with st.sidebar:
            prompt_tags = np.sort(self.promptBook['tag'].unique())
            tag = st.selectbox('Select a prompt tag', prompt_tags)
            items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)

            prompts = np.sort(items['prompt'].unique())[::-1]
            selected_prompt = st.selectbox('Select a prompt', prompts)
            items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
            prompt_id = items['prompt_id'].unique()[0]

            with st.form(key='prompt_form'):
                # Read-only display of the selected prompt's metadata.
                st.text_area('Prompt', selected_prompt, height=150, key='prompt', disabled=True)
                st.text_area('Negative Prompt', items['negativePrompt'].unique()[0], height=150,
                             key='negative_prompt', disabled=True)
                st.form_submit_button('Generate Images [Coming Soon]', type='primary',
                                      use_container_width=True, disabled=True)

        return prompt_tags, tag, prompt_id, items

    def draggable_images(self, items, prompt_id, layout='portrait'):
        """Show one batch of images in a draggable grid ordered by current ranking.

        On the first visit to a batch the ranking is initialised from the row
        order of ``items``; on later visits the saved ranking decides the order.

        Args:
            items: rows of the current batch, indexed ``0..len(items)-1``.
                Its index is overwritten with the saved ranks on revisits.
            prompt_id: prompt the batch belongs to.
            layout: ``'portrait'`` (4 columns of 1x2 cards) or
                ``'landscape'`` (2 columns of 2x1.6 cards).

        Raises:
            ValueError: if ``layout`` is neither of the supported values.
        """
        if layout not in ('portrait', 'landscape'):
            raise ValueError(f"unsupported layout: {layout!r}")

        if 'ranking' not in st.session_state:
            st.session_state.ranking = {}
        prompt_rankings = st.session_state.ranking.setdefault(prompt_id, {})
        batch_idx = st.session_state.counter[prompt_id]

        if batch_idx not in prompt_rankings:
            # First visit: rank images in the order they were given.
            prompt_rankings[batch_idx] = {
                str(image_id): i for i, image_id in enumerate(items['image_id'])
            }
        else:
            # Re-label each row with its saved rank so label i is the image ranked i.
            items.index = items['image_id'].apply(lambda x: prompt_rankings[batch_idx][str(x)])
        batch_ranking = prompt_rankings[batch_idx]

        # Image ids ordered by rank (label-based lookup on the rank index).
        image_ids = [str(items['image_id'].loc[i]) for i in range(len(items))]

        with elements('dashboard'):
            if layout == 'portrait':
                col_num = 4
                grid_layout = [
                    dashboard.Item(image_id, i % col_num, i // col_num, 1, 2, isResizable=False)
                    for i, image_id in enumerate(image_ids)
                ]
            else:
                col_num = 2
                grid_layout = [
                    dashboard.Item(image_id, i % col_num * 2, i // col_num, 2, 1.6, isResizable=False)
                    for i, image_id in enumerate(image_ids)
                ]

            with dashboard.Grid(grid_layout, cols={'lg': 4, 'md': 4, 'sm': 4, 'xs': 4, 'xxs': 2},
                                onLayoutChange=self.handle_layout_change, margin=[18, 18],
                                containerPadding=[0, 0]):
                for image_id in image_ids:
                    with mui.Card(key=image_id, variant="outlined"):
                        rank = batch_ranking[image_id] + 1  # 1-based for display
                        mui.Chip(label=rank,
                                 color="primary" if rank == 1 else "warning" if rank == 2 else "info",
                                 size="small",
                                 sx={"position": "absolute", "left": "-0.3rem", "top": "-0.3rem"})
                        mui.CardMedia(
                            component="img",
                            image=self.images_endpoint + image_id + '.png',
                            alt="There should be an image",
                            sx={"height": "100%", "object-fit": "contain", 'bgcolor': 'black'},
                        )

    def handle_layout_change(self, updated_layout):
        """Grid callback: convert dragged card positions into a ranking.

        Cards are ordered row by row (``y``, then ``x``); each image's rank is
        its position in that order. Updates the ranking of the current batch of
        ``st.session_state.prompt_id_tmp`` in place.
        """
        ordered_ids = [
            str(item['i']) for item in sorted(updated_layout, key=lambda item: (item['y'], item['x']))
        ]
        position = {image_id: pos for pos, image_id in enumerate(ordered_ids)}

        prompt_id = st.session_state.prompt_id_tmp
        batch_ranking = st.session_state.ranking[prompt_id][st.session_state.counter[prompt_id]]
        for image_id in batch_ranking:
            batch_ranking[image_id] = position[image_id]

    def app(self):
        """Render the ranking page for the prompt chosen in the sidebar."""
        st.title('Personal Image Ranking')
        st.write('Here you can test out your selected images with any prompt you like.')

        # Per-prompt status: 'ranking' while batches remain, 'finished' afterwards.
        if 'progress' not in st.session_state:
            st.session_state.progress = {}
        print('current progress: ', st.session_state.progress)

        _, _, prompt_id, items = self.sidebar()
        # Ceiling division: a partial last batch still counts as a batch.
        batch_num = (len(items) + self.batch_size - 1) // self.batch_size

        st.session_state.counter.setdefault(prompt_id, 0)
        # Layout-change callbacks read the active prompt from here.
        st.session_state.prompt_id_tmp = prompt_id
        st.session_state.progress.setdefault(prompt_id, 'ranking')

        if st.session_state.progress[prompt_id] == 'ranking':
            self._render_ranking(items, prompt_id, batch_num)
        elif st.session_state.progress[prompt_id] == 'finished':
            self._render_finished(prompt_id)

    def _render_ranking(self, items, prompt_id, batch_num):
        """Show the current batch with a progress bar and next/prev/finish buttons."""
        batch_idx = st.session_state.counter[prompt_id]
        sorting, control = st.columns((11, 1), gap='large')

        with sorting:
            st.progress((batch_idx + 1) / batch_num, text=f"Batch {batch_idx + 1} / {batch_num}")
            # Orientation is taken from the first image's 'WIDTHxHEIGHT' size string.
            width, height = items.loc[0, 'size'].split('x')
            layout = 'portrait' if int(height) >= int(width) else 'landscape'
            start = self.batch_size * batch_idx
            batch = items.iloc[start: start + self.batch_size].reset_index(drop=True)
            self.draggable_images(batch, prompt_id=prompt_id, layout=layout)

        with control:
            if batch_idx < batch_num - 1:
                st.button(":arrow_right:", key='next', on_click=self.next_batch, help='Next Batch',
                          kwargs={'prompt_id': prompt_id}, use_container_width=True)
            else:
                st.button(":heavy_check_mark:", key='finished', on_click=self.next_batch, help='Finished',
                          kwargs={'prompt_id': prompt_id, 'progress': 'finished'},
                          use_container_width=True)
            if batch_idx > 0:
                st.button(":arrow_left:", key='prev', on_click=self.prev_batch, help='Previous Batch',
                          kwargs={'prompt_id': prompt_id}, use_container_width=True)

    def _render_finished(self, prompt_id):
        """Show the completion message with 'back to gallery' and 'rank again' actions."""
        st.write('## You have ranked all models for this tag!')
        st.write('Thank you for your participation! Feel free to do the following things:')
        st.write('* Rank for other tags and prompts.')
        st.write('* Back to the gallery page to see more images.')
        st.write('* Rank again for this tag and prompt.')
        st.write('*More functions are coming soon... Please stay tuned*')

        if st.button('🖼️ Back to Gallery'):
            switch_page('gallery')

        if st.button('🎖️ Rank Again'):
            st.session_state.progress[prompt_id] = 'ranking'
            st.session_state.counter[prompt_id] = 0
            st.experimental_rerun()

    def next_batch(self, prompt_id, progress=None):
        """Button callback: store the current batch's ranking, then advance.

        Inserts one row per image into the ``rankings`` table through the
        module-level ``RANKING_CONN`` and commits. Afterwards either marks the
        prompt as finished (``progress == 'finished'``) or moves the counter of
        ``prompt_id`` to the next batch.
        """
        batch_ranking = st.session_state.ranking[prompt_id][st.session_state.counter[prompt_id]]
        # Ranking keys are str(image_id); compare as strings so the lookup also
        # works when the 'image_id' column is not of string dtype.
        image_id_str = self.promptBook['image_id'].astype(str)
        user_name = st.session_state.user_id[0]
        timestamp = st.session_state.user_id[1]
        query = ("INSERT INTO rankings (image_id, modelVersion_id, ranking, user_name, timestamp) "
                 "VALUES (%s, %s, %s, %s, %s)")

        # The context manager closes the cursor even if an INSERT fails.
        with RANKING_CONN.cursor() as cursor:
            for image_id, ranking in batch_ranking.items():
                modelVersion_id = self.promptBook.loc[image_id_str == image_id, 'modelVersion_id'].values[0]
                cursor.execute(query, (image_id, modelVersion_id, ranking, user_name, timestamp))
        RANKING_CONN.commit()

        if progress == 'finished':
            st.session_state.progress[prompt_id] = 'finished'
        else:
            st.session_state.counter[prompt_id] += 1

    def prev_batch(self, prompt_id):
        """Button callback: go back one batch for ``prompt_id``."""
        st.session_state.counter[prompt_id] -= 1


def connect_to_db():
    """Open a pymysql connection to the ``myRanking`` database.

    Host, user and password come from the ``RANKING_DB_HOST``,
    ``RANKING_DB_USER`` and ``RANKING_DB_PASSWORD`` environment variables;
    the port from ``RANKING_DB_PORT`` (default 3306). Rows are returned as
    dicts (``DictCursor``).
    """
    conn = pymysql.connect(
        host=os.environ.get('RANKING_DB_HOST'),
        port=int(os.environ.get('RANKING_DB_PORT', 3306)),
        database='myRanking',
        user=os.environ.get('RANKING_DB_USER'),
        password=os.environ.get('RANKING_DB_PASSWORD'),
        charset='utf8mb4',
        cursorclass=pymysql.cursors.DictCursor,
    )
    return conn


if __name__ == "__main__":
    st.set_page_config(page_title="Personal Image Ranking", page_icon="🎖️️", layout="wide")

    if 'user_id' not in st.session_state:
        st.warning('Please log in first.')
        home_btn = st.button('Go to Home Page')
        if home_btn:
            switch_page("home")
    else:
        # Maps prompt_id -> selected modelVersion ids (presumably filled in on the
        # gallery page); may be missing if the user has not visited it yet.
        selected_dict = st.session_state.get('selected_dict', {})
        has_selection = any(v for value in selected_dict.values() for v in value)

        if not has_selection:
            st.info('You have not checked any image yet. Please go back to the gallery page and check some images.')
            gallery_btn = st.button('Go to Gallery')
            if gallery_btn:
                switch_page('gallery')
        else:
            roster, promptBook, images_ds = load_hf_dataset()
            print(selected_dict)

            # Module-level so RankingApp.next_batch can use it.
            RANKING_CONN = connect_to_db()

            # Keep only promptBook rows whose prompt_id is a key of selected_dict
            # and whose modelVersion_id is among that key's selected values.
            promptBook_selected = pd.concat(
                [
                    promptBook[(promptBook['prompt_id'] == key) & (promptBook['modelVersion_id'].isin(value))]
                    for key, value in selected_dict.items()
                ]
            ).reset_index(drop=True)

            images_endpoint = "https://modelcofferbucket.s3-accelerate.amazonaws.com/"
            app = RankingApp(promptBook_selected, images_endpoint, batch_size=4)
            app.app()