import streamlit as st import numpy as np from plotly.subplots import make_subplots import plotly.graph_objects as go import graphviz from backend.maximally_activating_patches import load_layer_infos, load_activation, get_receptive_field_coordinates from frontend import on_click_graph from backend.utils import load_dataset_dict HIGHTLIGHT_COLOR = '#e7bcc5' st.set_page_config(layout='wide') st.title('Maximally activating image patches') st.write('> **What patterns maximally activate this channel in ConvNeXt model?**') st.write("""The maximally activating image patches method is a technique used in visualizing the interpretation of convolutional neural networks. It works by identifying the regions of the input image that activate a particular neuron in the convolutional layer, thus revealing the features that the neuron is detecting. To achieve this, the method generates image patches then feeds into the model while monitoring the neuron's activation. The algorithm then selects the patch that produces the highest activation and overlays it on the original image to visualize the features that the neuron is responding to. """) # -------------------------- LOAD DATASET --------------------------------- dataset_dict = load_dataset_dict() # -------------------------- LOAD GRAPH ----------------------------------- def load_dot_to_graph(filename): dot = graphviz.Source.from_file(filename) source_lines = str(dot).splitlines() source_lines.pop(0) source_lines.pop(-1) graph = graphviz.Digraph() graph.body += source_lines return graph, dot # st.header('ConvNeXt') convnext_dot_file = './data/dot_architectures/convnext_architecture.dot' convnext_graph = load_dot_to_graph(convnext_dot_file)[0] convnext_graph.graph_attr['size'] = '4,40' # -------------------------- DISPLAY GRAPH ----------------------------------- def chosen_node_text(clicked_node_title): clicked_node_title = clicked_node_title.replace('stage ', 'stage_').replace('block ', 'block_') stage_id = clicked_node_title.split()[0].split('_')[1] if 'stage' in clicked_node_title else None block_id = clicked_node_title.split()[1].split('_')[1] if 'block' in clicked_node_title else None layer_id = clicked_node_title.split()[-1] if 'embeddings' in layer_id: display_text = 'Patchify layer' activation_key = 'embeddings.patch_embeddings' elif 'downsampling' in layer_id: display_text = f'Stage {stage_id} > Downsampling layer' activation_key = f'encoder.stages[{stage_id}].downsampling_layer[1]' else: display_text = f'Stage {stage_id} > Block {block_id} > {layer_id} layer' activation_key = f'encoder.stages[{int(stage_id)-1}].layers[{int(block_id)-1}].{layer_id}' return display_text, activation_key props = { 'hightlight_color': HIGHTLIGHT_COLOR, 'initial_state': { 'group_1_header': 'Choose an option from group 1', 'group_2_header': 'Choose an option from group 2' } } col1, col2 = st.columns((2,5)) col1.markdown("#### Architecture") col1.write('') col1.write('Click on a layer below to generate top-k maximally activating image patches') col1.graphviz_chart(convnext_graph) with col2: st.markdown("#### Output") nodes = on_click_graph(key='toggle_buttons', **props) # -------------------------- DISPLAY OUTPUT ----------------------------------- if nodes != None: clicked_node_title = nodes["choice"]["node_title"] clicked_node_id = nodes["choice"]["node_id"] display_text, activation_key = chosen_node_text(clicked_node_title) col2.write(f'**Chosen layer:** {display_text}') # col2.write(f'**Activation key:** {activation_key}') hightlight_syle = f''' ''' col2.markdown(hightlight_syle, unsafe_allow_html=True) with col2: layer_infos = None with st.form('top_k_form'): activation_path = './data/activation/convnext_activation.json' activation = load_activation(activation_path) num_channels = activation[activation_key].shape[1] top_k = st.slider('Choose K for top-K maximally activating patches', 1,20, value=10) channel_start, channel_end = st.slider( 'Choose channel range of this layer (recommend to choose small range less than 30)', 1, num_channels, value=(1, 30)) summit_button = st.form_submit_button('Generate image patches') if summit_button: activation = activation[activation_key][:top_k,:,:] layer_infos = load_layer_infos('./data/layer_infos/convnext_layer_infos.json') # st.write(channel_start, channel_end) # st.write(activation.shape, activation.shape[1]) if layer_infos != None: num_cols, num_rows = top_k, channel_end - channel_start + 1 # num_rows = activation.shape[1] top_k_coor_max_ = activation st.markdown(f"#### Top-{top_k} maximally activating image patches of {num_rows} channels ({channel_start}-{channel_end})") for row in range(channel_start, channel_end+1): if row == channel_start: top_margin = 50 fig = make_subplots( rows=1, cols=num_cols, subplot_titles=tuple([f"#{i+1}" for i in range(top_k)]), shared_yaxes=True) else: top_margin = 0 fig = make_subplots(rows=1, cols=num_cols, shared_yaxes=True) for col in range(1, num_cols+1): k, c = col-1, row-1 img_index = int(top_k_coor_max_[k, c, 3]) activation_value = top_k_coor_max_[k, c, 0] img = dataset_dict[img_index//10_000][img_index%10_000]['image'] class_label = dataset_dict[img_index//10_000][img_index%10_000]['label'] class_id = dataset_dict[img_index//10_000][img_index%10_000]['id'] idx_x, idx_y = top_k_coor_max_[k, c, 1], top_k_coor_max_[k, c, 2] x1, x2, y1, y2 = get_receptive_field_coordinates(layer_infos, activation_key, idx_x, idx_y) img = np.array(img)[y1:y2, x1:x2, :] hovertemplate = f"""Top-{col}
Activation value: {activation_value:.5f}
Class Label: {class_label}
Class id: {class_id}
Image id: {img_index}""" fig.add_trace(go.Image(z=img, hovertemplate=hovertemplate), row=1, col=col) fig.update_xaxes(showticklabels=False, showgrid=False) fig.update_yaxes(showticklabels=False, showgrid=False) fig.update_layout(margin={'b':0, 't':top_margin, 'r':0, 'l':0}) fig.update_layout(showlegend=False, yaxis_title=row) fig.update_layout(height=100, plot_bgcolor='rgba(0,0,0,0)', paper_bgcolor='rgba(0,0,0,0)') fig.update_layout(hoverlabel=dict(bgcolor="#e9f2f7")) st.plotly_chart(fig, use_container_width=True) else: col2.markdown(f'Chosen layer: None', unsafe_allow_html=True) col2.markdown("""