"""Streamlit app that scatter-plots image embeddings of a fashion-product sample.

The user picks an image model (sidebar) and a label column; the app plots the
"x"/"y" columns returned by ``data_utils.get_embedding`` with Bokeh, coloured
by the chosen label, and shows the product image in the hover tooltip.
"""
import base64
from io import BytesIO

import streamlit as st
from bokeh.io import output_notebook, push_notebook
from bokeh.models import ColumnDataSource, Grid, LinearAxis, Plot, Scatter
from bokeh.palettes import d3
from bokeh.plotting import figure, show
from bokeh.transform import factor_cmap, factor_mark
from datasets import load_dataset

from data_utils import get_embedding

# Dataset columns the user can colour the scatter plot by.
label_columns = ["gender", "subCategory", "masterCategory"]

# Hugging Face image models offered in the sidebar.
model_interest = [
    "facebook/deit-tiny-patch16-224",  # very small model (~5M parameters)
    "microsoft/beit-base-patch16-224",  # big model
    "facebook/dino-vits8",
    "facebook/levit-128S",
]


def convert_base64(img):
    """Encode an image as a JPEG ``data:`` URI suitable for an HTML ``src``.

    Args:
        img: Image exposing PIL-style ``mode``, ``convert`` and
            ``save(fp, format=...)`` (presumably ``PIL.Image.Image`` -- confirm
            against ``data_utils.get_embedding``).

    Returns:
        A string of the form ``"data:image/jpeg;base64,<payload>"``.
    """
    # JPEG has no alpha channel or palette, and PIL refuses to write e.g.
    # RGBA / P images as JPEG, so normalise anything unusual to RGB first.
    if img.mode not in ("RGB", "L"):
        img = img.convert("RGB")
    with BytesIO() as buffered:
        img.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return "data:image/jpeg;base64," + img_str


@st.experimental_singleton
def cache_embedding(model_name):
    """Load a fixed 10% sample of the dataset and embed it with ``model_name``.

    Cached by ``st.experimental_singleton``, keyed on ``model_name``.

    Args:
        model_name: Hugging Face model id passed through to ``get_embedding``.

    Returns:
        ``(embedding, labels)`` where ``embedding`` is the table returned by
        ``get_embedding`` with its ``"image"`` column replaced by JPEG data URIs,
        and ``labels`` maps each name in ``label_columns`` to the unique values
        of that column in the sample.
    """
    dataset = load_dataset("ceyda/fashion-products-small", split="train")
    # Fixed seed so the visualised sample is identical on every run.
    dataset = dataset.shuffle(seed=100)
    # Keep a 10% slice for visualisation. The shuffle above already randomised
    # the order, so the split itself does not shuffle again.
    viz_dat = dataset.train_test_split(test_size=0.1, shuffle=False)["test"]

    embedding = get_embedding(model_name, viz_dat)
    # Data URIs let the Bokeh tooltip display each image inline.
    embedding["image"] = embedding["image"].apply(convert_base64)
    labels = {label: viz_dat.unique(label) for label in label_columns}
    return embedding, labels


@st.experimental_singleton
def cache_graph(model_name, color_column):
    """Build the Bokeh scatter plot of ``model_name``'s embedding.

    Cached by ``st.experimental_singleton``, keyed on both arguments.

    Args:
        model_name: Hugging Face model id (see ``model_interest``).
        color_column: One of ``label_columns``; points are coloured by it.

    Returns:
        A Bokeh figure with hover tooltips showing each product image.
    """
    embedding, labels = cache_embedding(model_name)
    factors = labels[color_column]
    # 60 distinct colours in total; if there are more factors than colours,
    # the extra factors presumably get factor_cmap's nan_color.
    color_palette = (
        d3["Category20"][20] + d3["Category20b"][20] + d3["Category20c"][20]
    )[: len(factors)]
    source = ColumnDataSource(data=embedding)

    TOOLS = (
        "hover,crosshair,pan,wheel_zoom,zoom_in,zoom_out,box_zoom,"
        "reset,tap,save,box_select,lasso_select"
    )
    # Render the data URI from the "image" column as an <img>; a bare
    # "@image" template would show the raw base64 text instead of the picture.
    TOOLTIPS = """
    <div>
        <img src="@image" alt="product image"
             style="width: 100px; height: 100px; object-fit: contain;">
    </div>
    """
    p = figure(tools=TOOLS, tooltips=TOOLTIPS)
    p.scatter(
        x="x",
        y="y",
        source=source,
        color=factor_cmap(color_column, color_palette, factors),
    )
    return p


st.write("It takes some time for the graph to load...wait please")
model_name = st.sidebar.selectbox("Model", model_interest)
color_column = st.selectbox("Color by", label_columns)
p = cache_graph(model_name, color_column)
st.bokeh_chart(p, use_container_width=False)