Spaces:
Runtime error
Runtime error
File size: 7,063 Bytes
fcc16aa 7a3450e fcc16aa 7a3450e fcc16aa 7a3450e fcc16aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import graphviz
import numpy as np
import plotly.graph_objects as go
import streamlit as st
from plotly.subplots import make_subplots

from backend.maximally_activating_patches import load_layer_infos, load_activation, get_receptive_field_coordinates
from backend.utils import load_dataset_dict
from frontend import on_click_graph
# Fill/stroke color applied to the clicked node of the architecture graph and
# forwarded to the on_click_graph component (identifier spelling kept because
# other code in this file refers to it by this name).
HIGHTLIGHT_COLOR = '#e7bcc5'

# Keep this before any other st.* call (older Streamlit versions require
# set_page_config to be the first Streamlit command of the script).
st.set_page_config(layout='wide')

st.title('Comparison among concept vectors')
st.write('> **How do the concept vectors relate to each other?**')
st.write('> **What is their joint impact on the image?**')
st.write("""Description to write""")

# -------------------------- LOAD DATASET ---------------------------------
# Indexed below as dataset_dict[img_index // 10_000][img_index % 10_000].
dataset_dict = load_dataset_dict()
# -------------------------- LOAD GRAPH -----------------------------------
def load_dot_to_graph(filename):
    """Read a Graphviz ``.dot`` file and rebuild it as an editable Digraph.

    The first and last lines of the file's source are discarded (presumably
    the ``digraph ... {`` opener and the closing ``}``) and every remaining
    line is copied into the body of a fresh ``graphviz.Digraph``, so callers
    can set graph attributes on it before rendering.

    Args:
        filename: Path to the ``.dot`` file.

    Returns:
        A ``(graph, dot)`` tuple: the newly built ``Digraph`` and the
        ``graphviz.Source`` object loaded from ``filename``.
    """
    dot = graphviz.Source.from_file(filename)
    statements = str(dot).splitlines()
    # Strip the wrapper lines; raises IndexError if there are too few lines.
    del statements[0]
    del statements[-1]

    rebuilt = graphviz.Digraph()
    rebuilt.body.extend(statements)
    return rebuilt, dot
# Architecture diagram shown in the left column; users click its nodes.
convnext_dot_file = './data/dot_architectures/convnext_architecture.dot'
convnext_graph = load_dot_to_graph(convnext_dot_file)[0]
# Graphviz "size" is "width,height" in inches: keep the diagram narrow and tall.
convnext_graph.graph_attr.update(size='4,40')
# -------------------------- DISPLAY GRAPH -----------------------------------
def chosen_node_text(clicked_node_title):
    """Translate a clicked graph-node title into a label and an activation key.

    The title is tokenized on whitespace after gluing ``"stage N"`` and
    ``"block N"`` into ``"stage_N"`` / ``"block_N"``. When present, the
    stage token is the first word and the block token is the second; the
    last word names the layer.

    For regular block layers, the stage and block numbers are shifted down
    by one to build the indices in the key. The downsampling key uses the
    stage number unchanged.

    Args:
        clicked_node_title: Title of the node clicked in the architecture
            graph, e.g. ``"stage 1 block 3 dwconv"``.

    Returns:
        ``(display_text, activation_key)``.

    Raises:
        IndexError: for an empty title or a malformed stage/block token.
        TypeError: for a non-embeddings, non-downsampling title that lacks
            a stage or block number.
    """
    title = clicked_node_title.replace('stage ', 'stage_').replace('block ', 'block_')
    words = title.split()
    stage_id = words[0].split('_')[1] if 'stage' in title else None
    block_id = words[1].split('_')[1] if 'block' in title else None
    layer_id = words[-1]

    if 'embeddings' in layer_id:
        return 'Patchify layer', 'embeddings.patch_embeddings'

    if 'downsampling' in layer_id:
        # NOTE(review): no -1 shift here, unlike block layers below — confirm
        # this matches how the graph numbers downsampling nodes.
        return (
            f'Stage {stage_id} > Downsampling layer',
            f'encoder.stages[{stage_id}].downsampling_layer[1]',
        )

    label = f'Stage {stage_id} > Block {block_id} > {layer_id} layer'
    stage_index = int(stage_id) - 1
    block_index = int(block_id) - 1
    return label, f'encoder.stages[{stage_index}].layers[{block_index}].{layer_id}'
# Properties forwarded to the custom on_click_graph component.
props = dict(
    hightlight_color=HIGHTLIGHT_COLOR,
    initial_state=dict(
        group_1_header='Choose an option from group 1',
        group_2_header='Choose an option from group 2',
    ),
)

# Left (narrow) column: architecture diagram. Right (wide) column: output.
col1, col2 = st.columns((2, 5))

with col1:
    st.markdown("#### Architecture")
    st.write('')
    st.write('Click on a layer below to generate top-k maximally activating image patches')
    st.graphviz_chart(convnext_graph)

with col2:
    st.markdown("#### Output")
    # Component value; the code below treats None as "nothing chosen yet".
    nodes = on_click_graph(key='toggle_buttons', **props)
# -------------------------- DISPLAY OUTPUT -----------------------------------
if nodes is not None:
    clicked_node_title = nodes["choice"]["node_title"]
    clicked_node_id = nodes["choice"]["node_id"]
    display_text, activation_key = chosen_node_text(clicked_node_title)
    col2.write(f'**Chosen layer:** {display_text}')

    # Collapse Streamlit containers that wrap an iframe, and paint the clicked
    # node's polygon in the rendered architecture graph (selected by its id).
    highlight_style = f'''
<style>
div[data-stale]:has(iframe) {{
    height: 0;
}}
#{clicked_node_id}>polygon {{
    fill: {HIGHTLIGHT_COLOR};
    stroke: {HIGHTLIGHT_COLOR};
}}
</style>
'''
    col2.markdown(highlight_style, unsafe_allow_html=True)

    with col2:
        # Stays None until the form is submitted; gates the rendering below.
        layer_infos = None
        with st.form('top_k_form'):
            activation_path = './data/activation/convnext_activation.json'
            activation = load_activation(activation_path)
            num_channels = activation[activation_key].shape[1]
            top_k = st.slider('Choose K for top-K maximally activating patches', 1, 20, value=10)
            # Clamp the default range so it never exceeds the layer's channel
            # count; a default outside [min, max] makes st.slider raise.
            channel_start, channel_end = st.slider(
                'Choose channel range of this layer (recommend to choose small range less than 30)',
                1, num_channels, value=(1, min(30, num_channels)))
            submit_button = st.form_submit_button('Generate image patches')
            if submit_button:
                # Indexed below as [k, channel, field] with field 0 = activation
                # value, 1/2 = spatial x/y index, 3 = dataset image index.
                top_k_coor_max = activation[activation_key][:top_k, :, :]
                layer_infos = load_layer_infos('./data/layer_infos/convnext_layer_infos.json')

        if layer_infos is not None:
            num_cols, num_rows = top_k, channel_end - channel_start + 1
            st.markdown(f"#### Top-{top_k} maximally activating image patches of {num_rows} channels ({channel_start}-{channel_end})")

            # One single-row figure per channel; the K patches are its columns.
            for row in range(channel_start, channel_end + 1):
                if row == channel_start:
                    # Only the first figure carries the "#1..#K" column titles
                    # and the top margin needed to show them.
                    top_margin = 50
                    fig = make_subplots(
                        rows=1, cols=num_cols,
                        subplot_titles=tuple(f"#{i + 1}" for i in range(top_k)),
                        shared_yaxes=True)
                else:
                    top_margin = 0
                    fig = make_subplots(rows=1, cols=num_cols, shared_yaxes=True)

                c = row - 1  # slider channels are 1-based, array indices 0-based
                for col in range(1, num_cols + 1):
                    k = col - 1
                    img_index = int(top_k_coor_max[k, c, 3])
                    activation_value = top_k_coor_max[k, c, 0]
                    # Look the sample up once instead of once per field.
                    sample = dataset_dict[img_index // 10_000][img_index % 10_000]
                    img = sample['image']
                    class_label = sample['label']
                    class_id = sample['id']

                    idx_x, idx_y = top_k_coor_max[k, c, 1], top_k_coor_max[k, c, 2]
                    x1, x2, y1, y2 = get_receptive_field_coordinates(layer_infos, activation_key, idx_x, idx_y)
                    patch = np.array(img)[y1:y2, x1:x2, :]

                    hovertemplate = (
                        f"Top-{col}<br>Activation value: {activation_value:.5f}<br>"
                        f"Class Label: {class_label}<br>Class id: {class_id}<br>Image id: {img_index}"
                    )
                    fig.add_trace(go.Image(z=patch, hovertemplate=hovertemplate), row=1, col=col)

                fig.update_xaxes(showticklabels=False, showgrid=False)
                fig.update_yaxes(showticklabels=False, showgrid=False)
                fig.update_layout(
                    margin={'b': 0, 't': top_margin, 'r': 0, 'l': 0},
                    showlegend=False,
                    yaxis_title=row,
                    height=100,
                    plot_bgcolor='rgba(0,0,0,0)',
                    paper_bgcolor='rgba(0,0,0,0)',
                    hoverlabel=dict(bgcolor="#e9f2f7"),
                )
                st.plotly_chart(fig, use_container_width=True)
else:
    col2.markdown('Chosen layer: <code>None</code>', unsafe_allow_html=True)
    # Properly closed <style> element (the declaration's ';' belongs inside the braces).
    col2.markdown("""<style>div[data-stale]:has(iframe) {height: 0;}</style>""", unsafe_allow_html=True)
|