File size: 7,063 Bytes
fcc16aa
 
 
 
 
 
 
 
7a3450e
fcc16aa
7a3450e
fcc16aa
 
 
 
 
7a3450e
 
 
 
fcc16aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import streamlit as st
import numpy as np

from plotly.subplots import make_subplots
import plotly.graph_objects as go

import graphviz

#from backend.maximally_activating_patches import load_layer_infos, load_activation, get_receptive_field_coordinates
from frontend import on_click_graph
#from backend.utils import load_dataset_dict

# Colour passed to the click-graph component and used as fill/stroke of the
# clicked node's polygon in the architecture graph.
HIGHTLIGHT_COLOR = '#e7bcc5'
# Wide layout so the two-column (architecture | output) view has room.
st.set_page_config(layout='wide')


st.title('Comparison among concept vectors')
st.write('> **How do the concept vectors relate to each other?**')
st.write('> **What is their joint impact on the image?**')
st.write("""Description to write""")

# -------------------------- LOAD DATASET ---------------------------------
# NOTE(review): the `from backend.utils import load_dataset_dict` import at the
# top of this file is commented out, so this call raises NameError unless the
# name is provided some other way -- confirm and restore the import.
dataset_dict = load_dataset_dict()

# -------------------------- LOAD GRAPH -----------------------------------

def load_dot_to_graph(filename):
    """Load a DOT file and re-host its statements in a fresh ``graphviz.Digraph``.

    The text of the loaded source is split into lines, and its first and last
    lines are discarded. Presumably these are the ``digraph ... {`` header and
    the closing ``}``; confirm against the .dot files, since a trailing blank
    line would break this. The remaining lines become the body of a new,
    empty Digraph, so callers can tweak its attributes before rendering.

    Args:
        filename: Path of the DOT file to read.

    Returns:
        A ``(graph, dot)`` tuple: the rebuilt ``graphviz.Digraph`` and the
        ``graphviz.Source`` loaded from *filename*.

    Raises:
        IndexError: if the source text has fewer than two lines.
    """
    dot = graphviz.Source.from_file(filename)
    statement_lines = str(dot).splitlines()
    # Strip the wrapper lines: header first, then the closing line.
    del statement_lines[0]
    del statement_lines[-1]
    graph = graphviz.Digraph()
    graph.body.extend(statement_lines)
    return graph, dot
    

# Pre-generated DOT description of the ConvNeXt architecture shown in the left column.
convnext_dot_file = './data/dot_architectures/convnext_architecture.dot'
# Only the rebuilt Digraph is needed; the raw graphviz.Source is discarded.
convnext_graph, _ = load_dot_to_graph(convnext_dot_file)

# Graphviz `size` attribute ("width,height" upper bound) keeps the tall graph narrow.
convnext_graph.graph_attr.update(size='4,40')

# -------------------------- DISPLAY GRAPH -----------------------------------

def chosen_node_text(clicked_node_title):
    """Translate a clicked graph-node title into a label and an activation key.

    Titles are whitespace-separated. The words ``stage N`` and ``block M`` are
    glued into ``stage_N`` / ``block_M`` so that each becomes one token, and the
    last token names the layer. For example, ``'stage 1 block 3 dwconv'`` yields
    ``('Stage 1 > Block 3 > dwconv layer', 'encoder.stages[0].layers[2].dwconv')``.

    Args:
        clicked_node_title: Title string of the clicked node.

    Returns:
        ``(display_text, activation_key)``: a human-readable layer description
        and the key used to look the layer up in the activation data.
    """
    normalized = clicked_node_title.replace('stage ', 'stage_').replace('block ', 'block_')
    tokens = normalized.split()
    stage_id = tokens[0].split('_')[1] if 'stage' in normalized else None
    block_id = tokens[1].split('_')[1] if 'block' in normalized else None
    layer_id = tokens[-1]

    if 'embeddings' in layer_id:
        return 'Patchify layer', 'embeddings.patch_embeddings'

    if 'downsampling' in layer_id:
        # NOTE(review): stage_id is used as-is here, whereas the block branch
        # below subtracts 1 -- confirm which indexing the graph titles use.
        return (
            f'Stage {stage_id} > Downsampling layer',
            f'encoder.stages[{stage_id}].downsampling_layer[1]',
        )

    # Titles number stages/blocks from 1; the activation keys index from 0.
    stage_index = int(stage_id) - 1
    block_index = int(block_id) - 1
    return (
        f'Stage {stage_id} > Block {block_id} > {layer_id} layer',
        f'encoder.stages[{stage_index}].layers[{block_index}].{layer_id}',
    )


# Props forwarded to the custom `on_click_graph` component. The key names
# (including the 'hightlight_color' spelling) are presumably read by the
# component as-is -- check the frontend before renaming any of them.
props = {
    'hightlight_color': HIGHTLIGHT_COLOR,
    'initial_state': {
        'group_1_header': 'Choose an option from group 1',
        'group_2_header': 'Choose an option from group 2'
    }
}


# Left column (narrow): architecture graph. Right column (wide): click component + output.
col1, col2 = st.columns((2, 5))
col1.markdown("#### Architecture")
col1.write('')
col1.write('Click on a layer below to generate top-k maximally activating image patches')
col1.graphviz_chart(convnext_graph)

with col2:
    st.markdown("#### Output")
    # `nodes` is None when there is no selection (presumably before the first click).
    nodes = on_click_graph(key='toggle_buttons', **props)

# -------------------------- DISPLAY OUTPUT -----------------------------------

if nodes is not None:
    clicked_node_title = nodes["choice"]["node_title"]
    clicked_node_id = nodes["choice"]["node_id"]
    display_text, activation_key = chosen_node_text(clicked_node_title)
    col2.write(f'**Chosen layer:** {display_text}')

    # Collapse the component's iframe container and paint the clicked node's
    # polygon (selected by its SVG element id) with the highlight colour.
    highlight_style = f'''
        <style>
            div[data-stale]:has(iframe) {{
                height: 0;
            }}
            #{clicked_node_id}>polygon {{
                fill: {HIGHTLIGHT_COLOR};
                stroke: {HIGHTLIGHT_COLOR};
            }}
        </style>
    '''
    col2.markdown(highlight_style, unsafe_allow_html=True)

    with col2:
        # Stays None until the form is submitted; gates the plotting below.
        layer_infos = None
        with st.form('top_k_form'):
            # NOTE(review): load_activation, load_layer_infos and
            # get_receptive_field_coordinates come from the
            # `backend.maximally_activating_patches` import that is commented
            # out at the top of this file; restore it or this raises NameError.
            activation_path = './data/activation/convnext_activation.json'
            activation = load_activation(activation_path)
            num_channels = activation[activation_key].shape[1]

            top_k = st.slider('Choose K for top-K maximally activating patches', 1, 20, value=10)
            # Clamp the default so it never exceeds max_value on narrow layers
            # (Streamlit rejects a default outside [min_value, max_value]).
            channel_start, channel_end = st.slider(
                'Choose channel range of this layer (recommend to choose small range less than 30)',
                1, num_channels, value=(1, min(30, num_channels)))
            submit_button = st.form_submit_button('Generate image patches')
            if submit_button:
                activation = activation[activation_key][:top_k, :, :]
                layer_infos = load_layer_infos('./data/layer_infos/convnext_layer_infos.json')

        if layer_infos is not None:
            num_cols, num_rows = top_k, channel_end - channel_start + 1
            top_k_coor_max_ = activation
            st.markdown(f"#### Top-{top_k} maximally activating image patches of {num_rows} channels ({channel_start}-{channel_end})")

            # One single-row figure per channel; columns are the top-k patches.
            for row in range(channel_start, channel_end + 1):
                # Only the first row carries the "#1..#K" column titles, so only it
                # gets a top margin.
                if row == channel_start:
                    top_margin = 50
                    fig = make_subplots(
                        rows=1, cols=num_cols,
                        subplot_titles=tuple(f"#{i+1}" for i in range(top_k)), shared_yaxes=True)
                else:
                    top_margin = 0
                    fig = make_subplots(rows=1, cols=num_cols, shared_yaxes=True)

                c = row - 1  # slider channels are 1-based; array channel axis is 0-based
                for col in range(1, num_cols + 1):
                    k = col - 1
                    # Last-axis fields as used here: [0] activation value,
                    # [1]/[2] spatial indices, [3] image index.
                    # NOTE(review): confirm this layout against load_activation.
                    img_index = int(top_k_coor_max_[k, c, 3])
                    activation_value = top_k_coor_max_[k, c, 0]
                    # The image index is split into (outer key, offset) in chunks
                    # of 10,000 -- confirm against load_dataset_dict. Fetch the
                    # record once rather than re-indexing it for every field.
                    sample = dataset_dict[img_index // 10_000][img_index % 10_000]
                    img = sample['image']
                    class_label = sample['label']
                    class_id = sample['id']

                    idx_x, idx_y = top_k_coor_max_[k, c, 1], top_k_coor_max_[k, c, 2]
                    x1, x2, y1, y2 = get_receptive_field_coordinates(layer_infos, activation_key, idx_x, idx_y)
                    # Crop the receptive field; assumes an H x W x C image array.
                    img = np.array(img)[y1:y2, x1:x2, :]

                    hovertemplate = f"""Top-{col}<br>Activation value: {activation_value:.5f}<br>Class Label: {class_label}<br>Class id: {class_id}<br>Image id: {img_index}"""
                    fig.add_trace(go.Image(z=img, hovertemplate=hovertemplate), row=1, col=col)

                # Axis and layout settings apply to the whole figure, so set them
                # once after all traces are added instead of once per trace.
                fig.update_xaxes(showticklabels=False, showgrid=False)
                fig.update_yaxes(showticklabels=False, showgrid=False)
                fig.update_layout(
                    margin={'b': 0, 't': top_margin, 'r': 0, 'l': 0},
                    showlegend=False,
                    yaxis_title=row,
                    height=100,
                    plot_bgcolor='rgba(0,0,0,0)',
                    paper_bgcolor='rgba(0,0,0,0)',
                    hoverlabel=dict(bgcolor="#e9f2f7"),
                )
                st.plotly_chart(fig, use_container_width=True)


else:
    col2.markdown('Chosen layer: <code>None</code>', unsafe_allow_html=True)
    # Same iframe-collapsing rule as above, as a well-formed <style> element.
    col2.markdown("<style>div[data-stale]:has(iframe) {height: 0;}</style>", unsafe_allow_html=True)