import streamlit as st import numpy as np from plotly.subplots import make_subplots import plotly.graph_objects as go import graphviz from backend.maximally_activating_patches import load_layer_infos, load_activation, get_receptive_field_coordinates from frontend import on_click_graph from backend.utils import load_dataset_dict HIGHTLIGHT_COLOR = '#e7bcc5' st.set_page_config(layout='wide') # -------------------------- LOAD DATASET --------------------------------- dataset_dict = load_dataset_dict() # -------------------------- LOAD GRAPH ----------------------------------- def load_dot_to_graph(filename): dot = graphviz.Source.from_file(filename) source_lines = str(dot).splitlines() source_lines.pop(0) source_lines.pop(-1) graph = graphviz.Digraph() graph.body += source_lines return graph, dot st.title('Maximally activating image patches') st.write('Visualize image patches that maximize the activation of layers in ConvNeXt model') # st.header('ConvNeXt') convnext_dot_file = './data/dot_architectures/convnext_architecture.dot' convnext_graph = load_dot_to_graph(convnext_dot_file)[0] convnext_graph.graph_attr['size'] = '4,40' # -------------------------- DISPLAY GRAPH ----------------------------------- def chosen_node_text(clicked_node_title): clicked_node_title = clicked_node_title.replace('stage ', 'stage_').replace('block ', 'block_') stage_id = clicked_node_title.split()[0].split('_')[1] if 'stage' in clicked_node_title else None block_id = clicked_node_title.split()[1].split('_')[1] if 'block' in clicked_node_title else None layer_id = clicked_node_title.split()[-1] if 'embeddings' in layer_id: display_text = 'Patchify layer' activation_key = 'embeddings.patch_embeddings' elif 'downsampling' in layer_id: display_text = f'Stage {stage_id} > Downsampling layer' activation_key = f'encoder.stages[{stage_id}].downsampling_layer[1]' else: display_text = f'Stage {stage_id} > Block {block_id} > {layer_id} layer' activation_key = f'encoder.stages[{int(stage_id)-1}].layers[{int(block_id)-1}].{layer_id}' return display_text, activation_key props = { 'hightlight_color': HIGHTLIGHT_COLOR, 'initial_state': { 'group_1_header': 'Choose an option from group 1', 'group_2_header': 'Choose an option from group 2' } } col1, col2 = st.columns((2,5)) col1.markdown("#### Architecture") col1.write('') col1.write('Click on a layer below to generate top-k maximally activating image patches') col1.graphviz_chart(convnext_graph) with col2: st.markdown("#### Output") nodes = on_click_graph(key='toggle_buttons', **props) # -------------------------- DISPLAY OUTPUT ----------------------------------- if nodes != None: clicked_node_title = nodes["choice"]["node_title"] clicked_node_id = nodes["choice"]["node_id"] display_text, activation_key = chosen_node_text(clicked_node_title) col2.write(f'**Chosen layer:** {display_text}') # col2.write(f'**Activation key:** {activation_key}') hightlight_syle = f''' ''' col2.markdown(hightlight_syle, unsafe_allow_html=True) with col2: layer_infos = None with st.form('top_k_form'): activation_path = './data/activation/convnext_activation.json' activation = load_activation(activation_path) num_channels = activation[activation_key].shape[1] top_k = st.slider('Choose K for top-K maximally activating patches', 1,20, value=10) channel_start, channel_end = st.slider( 'Choose channel range of this layer (recommend to choose small range less than 30)', 1, num_channels, value=(1, 30)) summit_button = st.form_submit_button('Generate image patches') if summit_button: activation = activation[activation_key][:top_k,:,:] layer_infos = load_layer_infos('./data/layer_infos/convnext_layer_infos.json') # st.write(channel_start, channel_end) # st.write(activation.shape, activation.shape[1]) if layer_infos != None: num_cols, num_rows = top_k, channel_end - channel_start + 1 # num_rows = activation.shape[1] top_k_coor_max_ = activation st.markdown(f"#### Top-{top_k} maximally activating image patches of {num_rows} channels ({channel_start}-{channel_end})") for row in range(channel_start, channel_end+1): if row == channel_start: top_margin = 50 fig = make_subplots( rows=1, cols=num_cols, subplot_titles=tuple([f"#{i+1}" for i in range(top_k)]), shared_yaxes=True) else: top_margin = 0 fig = make_subplots(rows=1, cols=num_cols, shared_yaxes=True) for col in range(1, num_cols+1): k, c = col-1, row-1 img_index = int(top_k_coor_max_[k, c, 3]) activation_value = top_k_coor_max_[k, c, 0] img = dataset_dict[img_index//10_000][img_index%10_000]['image'] class_label = dataset_dict[img_index//10_000][img_index%10_000]['label'] class_id = dataset_dict[img_index//10_000][img_index%10_000]['id'] idx_x, idx_y = top_k_coor_max_[k, c, 1], top_k_coor_max_[k, c, 2] x1, x2, y1, y2 = get_receptive_field_coordinates(layer_infos, activation_key, idx_x, idx_y) img = np.array(img)[y1:y2, x1:x2, :] hovertemplate = f"""Top-{col}
Activation value: {activation_value:.5f}
Class Label: {class_label}
Class id: {class_id}
Image id: {img_index}""" fig.add_trace(go.Image(z=img, hovertemplate=hovertemplate), row=1, col=col) fig.update_xaxes(showticklabels=False, showgrid=False) fig.update_yaxes(showticklabels=False, showgrid=False) fig.update_layout(margin={'b':0, 't':top_margin, 'r':0, 'l':0}) fig.update_layout(showlegend=False, yaxis_title=row) fig.update_layout(height=100, plot_bgcolor='rgba(0,0,0,0)', paper_bgcolor='rgba(0,0,0,0)') fig.update_layout(hoverlabel=dict(bgcolor="#e9f2f7")) st.plotly_chart(fig, use_container_width=True) else: col2.markdown(f'Chosen layer: None', unsafe_allow_html=True) col2.markdown("""