Spaces:
Build error
Build error
File size: 6,992 Bytes
18f2f54 0c1e42b 18f2f54 8a287fa 18f2f54 8a287fa 0c1e42b 8a287fa 18f2f54 8a287fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import streamlit as st
import numpy as np
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import graphviz
from backend.maximally_activating_patches import load_layer_infos, load_activation, get_receptive_field_coordinates
from frontend import on_click_graph
from backend.utils import load_dataset_dict
# Fill/stroke colour used to mark the architecture node the user clicked;
# also forwarded to the click-graph component via its props.
HIGHTLIGHT_COLOR = '#e7bcc5'

# Page configuration is applied before any other Streamlit output.
st.set_page_config(layout="wide")

# -------------------------- LOAD DATASET ---------------------------------
# NOTE(review): Streamlit re-executes the whole script on every interaction,
# so this load repeats on each rerun unless load_dataset_dict caches — confirm.
dataset_dict = load_dataset_dict()
# -------------------------- LOAD GRAPH -----------------------------------
def load_dot_to_graph(filename):
    """Read a DOT file and rebuild its statements inside a new ``graphviz.Digraph``.

    The first and last non-blank lines of the file are assumed to be the
    ``digraph ... {`` header and the closing ``}``. Both are dropped, and every
    line in between is appended verbatim to a fresh ``Digraph``'s ``body``.
    The returned Digraph can then be adjusted (e.g. via ``graph_attr``) before
    rendering. Anything written on the header or closing line itself (graph
    name, inline statements) is not carried over.

    Args:
        filename: Path to the ``.dot`` file.

    Returns:
        A ``(graph, dot)`` tuple: the rebuilt ``graphviz.Digraph`` and the
        ``graphviz.Source`` loaded from ``filename``.

    Raises:
        IndexError: if the file has fewer than two non-blank lines.
    """
    dot = graphviz.Source.from_file(filename)
    # keepends=True so each body entry carries its own line terminator.
    # Recent graphviz versions join body entries with '' rather than '\n',
    # which would otherwise glue all statements onto a single line. Older
    # versions join with '\n', which only adds harmless blank lines.
    source_lines = str(dot).splitlines(keepends=True)
    # Skip blank lines at either end so the pops below remove the actual
    # header and closing-brace lines rather than empty ones.
    while source_lines and not source_lines[-1].strip():
        source_lines.pop()
    while source_lines and not source_lines[0].strip():
        source_lines.pop(0)
    source_lines.pop(0)   # header line, e.g. 'digraph {'
    source_lines.pop(-1)  # closing '}'
    graph = graphviz.Digraph()
    graph.body += source_lines
    return graph, dot
st.title('Maximally activating image patches')
st.write('Visualize image patches that maximize the activation of layers in ConvNeXt model')

# Architecture diagram for the left-hand column. Only the rebuilt Digraph
# (first element of load_dot_to_graph's result) is kept.
convnext_dot_file = './data/dot_architectures/convnext_architecture.dot'
convnext_graph = load_dot_to_graph(convnext_dot_file)[0]
# Graphviz 'size' graph attribute ("width,height"): keeps the drawing narrow and tall.
convnext_graph.graph_attr.update(size='4,40')
# -------------------------- DISPLAY GRAPH -----------------------------------
def chosen_node_text(clicked_node_title):
    """Translate a clicked architecture-node title into UI text and an activation key.

    "stage N" and "block N" are first glued into "stage_N" / "block_N". The
    title is then split on whitespace:
    - the first token supplies the stage number (if "stage" occurs);
    - the second token supplies the block number (if "block" occurs);
    - the last token is the layer name.
    Title format is inferred from this parsing, e.g. ``'stage 1 block 2 pwconv1'``
    or ``'embeddings'``.

    Args:
        clicked_node_title: Node title reported by the graph component.

    Returns:
        ``(display_text, activation_key)``. For block layers the 1-based
        stage/block numbers from the title become 0-based indices in the key.
    """
    normalized = clicked_node_title.replace('stage ', 'stage_').replace('block ', 'block_')
    tokens = normalized.split()

    stage_id = None
    if 'stage' in normalized:
        stage_id = tokens[0].split('_')[1]

    block_id = None
    if 'block' in normalized:
        block_id = tokens[1].split('_')[1]

    layer_id = tokens[-1]

    if 'embeddings' in layer_id:
        return 'Patchify layer', 'embeddings.patch_embeddings'

    if 'downsampling' in layer_id:
        # NOTE(review): unlike the block branch below, the stage number is used
        # as-is (no -1) here. Confirm this matches how the DOT file numbers
        # downsampling nodes.
        return (
            f'Stage {stage_id} > Downsampling layer',
            f'encoder.stages[{stage_id}].downsampling_layer[1]',
        )

    stage_index = int(stage_id) - 1
    block_index = int(block_id) - 1
    return (
        f'Stage {stage_id} > Block {block_id} > {layer_id} layer',
        f'encoder.stages[{stage_index}].layers[{block_index}].{layer_id}',
    )
# Props forwarded to the on_click_graph component. The 'hightlight_color' key
# is passed through verbatim to the component, so its spelling is left as-is.
props = {
    'hightlight_color': HIGHTLIGHT_COLOR,
    'initial_state': {
        'group_1_header': 'Choose an option from group 1',
        'group_2_header': 'Choose an option from group 2',
    },
}

col1, col2 = st.columns((2, 5))
col1.markdown("#### Architecture")
col1.write('')
col1.write('Click on a layer below to generate top-k maximally activating image patches')
col1.graphviz_chart(convnext_graph)

with col2:
    st.markdown("#### Output")
    # None means no node is selected yet; otherwise nodes["choice"] holds the
    # clicked node's title and id.
    nodes = on_click_graph(key='toggle_buttons', **props)

# -------------------------- DISPLAY OUTPUT -----------------------------------
if nodes is not None:
    clicked_node_title = nodes["choice"]["node_title"]
    clicked_node_id = nodes["choice"]["node_id"]
    display_text, activation_key = chosen_node_text(clicked_node_title)
    col2.write(f'**Chosen layer:** {display_text}')

    # Zero the height of iframe wrappers (presumably the click component's)
    # and recolour the clicked node's polygon in the rendered architecture SVG.
    highlight_style = f'''
<style>
div[data-stale]:has(iframe) {{
height: 0;
}}
#{clicked_node_id}>polygon {{
fill: {HIGHTLIGHT_COLOR};
stroke: {HIGHTLIGHT_COLOR};
}}
</style>
'''
    col2.markdown(highlight_style, unsafe_allow_html=True)

    with col2:
        layer_infos = None
        with st.form('top_k_form'):
            activation_path = './data/activation/convnext_activation.json'
            # NOTE(review): Streamlit reruns the script on every interaction,
            # so this reloads the file each time unless load_activation caches.
            activation = load_activation(activation_path)
            num_channels = activation[activation_key].shape[1]
            # NOTE(review): assumes each layer stores at least 20 records along
            # axis 0; a K larger than what is stored would fail when indexing below.
            top_k = st.slider('Choose K for top-K maximally activating patches', 1, 20, value=10)
            # Clamp the default range so it never exceeds the layer's channel count.
            channel_start, channel_end = st.slider(
                'Choose channel range of this layer (recommend to choose small range less than 30)',
                1, num_channels, value=(1, min(30, num_channels)))
            submit_button = st.form_submit_button('Generate image patches')
            if submit_button:
                # First top_k records (axis 0) for every channel (axis 1) of this layer.
                top_k_records = activation[activation_key][:top_k, :, :]
                layer_infos = load_layer_infos('./data/layer_infos/convnext_layer_infos.json')

        if layer_infos is not None:
            num_cols = top_k
            num_rows = channel_end - channel_start + 1
            st.markdown(f"#### Top-{top_k} maximally activating image patches of {num_rows} channels ({channel_start}-{channel_end})")
            # One figure per channel (row); one image subplot per top-k rank (column).
            for row in range(channel_start, channel_end + 1):
                if row == channel_start:
                    # Only the first row shows the "#1".."#K" rank titles, so it
                    # reserves a top margin for them.
                    top_margin = 50
                    fig = make_subplots(
                        rows=1, cols=num_cols,
                        subplot_titles=tuple(f"#{i + 1}" for i in range(top_k)),
                        shared_yaxes=True)
                else:
                    top_margin = 0
                    fig = make_subplots(rows=1, cols=num_cols, shared_yaxes=True)

                channel = row - 1  # slider channels are 1-based; array axis 1 is 0-based
                for col in range(1, num_cols + 1):
                    k = col - 1
                    # Record layout as indexed here: [activation value, x index, y index, image index].
                    activation_value = top_k_records[k, channel, 0]
                    idx_x, idx_y = top_k_records[k, channel, 1], top_k_records[k, channel, 2]
                    img_index = int(top_k_records[k, channel, 3])

                    # Fetch the dataset row once and read every field from it.
                    sample = dataset_dict[img_index // 10_000][img_index % 10_000]
                    class_label = sample['label']
                    class_id = sample['id']

                    x1, x2, y1, y2 = get_receptive_field_coordinates(layer_infos, activation_key, idx_x, idx_y)
                    img = np.array(sample['image'])
                    if img.ndim == 2:
                        # A single-channel image gives a 2-D array, which the
                        # 3-axis crop below cannot index; replicate it to 3 channels.
                        img = np.stack([img] * 3, axis=-1)
                    img = img[y1:y2, x1:x2, :]

                    hovertemplate = (
                        f"Top-{col}<br>Activation value: {activation_value:.5f}<br>"
                        f"Class Label: {class_label}<br>Class id: {class_id}<br>Image id: {img_index}"
                    )
                    fig.add_trace(go.Image(z=img, hovertemplate=hovertemplate), row=1, col=col)

                fig.update_xaxes(showticklabels=False, showgrid=False)
                fig.update_yaxes(showticklabels=False, showgrid=False)
                fig.update_layout(
                    margin={'b': 0, 't': top_margin, 'r': 0, 'l': 0},
                    showlegend=False,
                    yaxis_title=row,  # label the row with its 1-based channel number
                    height=100,
                    plot_bgcolor='rgba(0,0,0,0)',
                    paper_bgcolor='rgba(0,0,0,0)',
                    hoverlabel=dict(bgcolor="#e9f2f7"),
                )
                st.plotly_chart(fig, use_container_width=True)
else:
    col2.markdown('Chosen layer: <code>None</code>', unsafe_allow_html=True)
    # Same iframe-collapsing rule as in the branch above, written as
    # well-formed CSS with a closing </style> tag.
    col2.markdown("""<style>div[data-stale]:has(iframe) {height: 0;}</style>""", unsafe_allow_html=True)
|