File size: 11,923 Bytes
c1b01fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e25210
 
 
 
 
 
c1b01fa
 
 
 
 
 
 
 
 
 
 
 
2e25210
 
 
 
 
 
c1b01fa
 
 
2e25210
 
 
 
 
 
c1b01fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
import torch
import numpy as np
import gradio as gr
import helpers.models as h_models
import lucent.optvis.param as param
import lucent.optvis.transform as tr
import helpers.manipulation as h_manip
import lucent.optvis.objectives as objs

from torch import nn
from time import sleep
from lucent.optvis import render
from lucent.modelzoo.util import get_model_layers


# Event listener functions
def on_model(model, model_layers, ft_map_sizes, evt: gr.SelectData, progress=gr.Progress()):
    """
    Logic flow when a model is selected. Builds the selected model and
    replaces the model, model-layers and feature-map-size states.
    :param model: Current model state; replaced by the newly built model
    :param model_layers: Current list of model layers; replaced by a list of
                         (layer name, layer module) pairs
    :param ft_map_sizes: Current feature map sizes; replaced by the sizes
                         computed for every entry of the new model_layers
    :param evt: Event data from Dropdown selection; ``evt.value`` is looked up
                by name in ``h_models.ModelTypes``
    :param progress: Gradio progress tracker
    :raises gr.Error: If a layer name cannot be resolved to a module
    :return: [Layer Dropdown Component, Model state, Model Layers state,
              Feature Map Sizes State]
    """
    progress(0, desc="Setting up model...")
    model = h_models.setup_model(h_models.ModelTypes[evt.value])
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()

    progress(0.25, desc="Getting layers names and details...")
    layer_reprs = list(get_model_layers(model, getLayerRepr=True).items())
    # Dropdown label: "(name): <repr text before its first '('>" - typically
    # the module's class name.
    choices = [f"({name}): {rep.split('(')[0]}" for name, rep in layer_reprs]

    progress(0.5, desc="Getting layer objects...")
    model_layers = []
    for name, _ in layer_reprs:
        try:
            layer = h_models.get_layer_by_name(model, name)
        except ValueError as e:
            # gr.Error only reaches the UI when raised; continuing instead
            # would pair this name with an unbound or stale `layer` object.
            raise gr.Error(str(e)) from e
        model_layers.append((name, layer))

    progress(0.75, desc="Getting feature maps sizes...")
    ft_map_sizes = h_models.get_feature_map_sizes(
        model, [layer for _, layer in model_layers])
    progress(1, desc="Done")
    sleep(0.25)  # To allow for progress animation, not good practice
    return [gr.update(choices=choices, value=''),
            model, model_layers, ft_map_sizes]


def on_layer(selected_layer, model_layers, ft_map_sizes, evt: gr.SelectData):
    """
    Logic flow when a layer is selected. Updates max values of layer
    specific input fields.
    :param selected_layer: Current selected layer, updated by this method.
    :param model_layers: All model layers
    :param ft_map_sizes: Feature maps sizes for all conv layers
    :param evt: Event data from Dropdown selection
    :return [Layer Text Component,
             Channel Number Component,
             Node X Number Component,
             Node Y Number Component,
             Selected layer state/variable,
             Channel max state/variable,
             NodeX max state/variable,
             NodeY max state/variable,
             Node max state/variable]
    """
    channel_max, nodeX_max, nodeY_max, node_max = -1, -1, -1, -1
    selected_layer = model_layers[evt.index]
    match type(selected_layer[1]):
        case nn.Conv2d:
            # Calculate maxes for conv specific
            channel_max = selected_layer[1].out_channels-1
            nodeX_max = ft_map_sizes[evt.index][1]-1
            nodeY_max = ft_map_sizes[evt.index][2]-1
            
            return [gr.update(visible=True),
                    gr.Number.update(info=f"""Values between 0-{channel_max}""", 
                              visible=True, value=None),
                    gr.Number.update(info=f"""Values between 0-{nodeX_max}""", 
                              visible=True, value=None),
                    gr.Number.update(info=f"""Values between 0-{nodeY_max}""", 
                              visible=True, value=None),
                    gr.update(visible=False, value=None),
                    selected_layer,
                    channel_max,
                    nodeX_max,
                    nodeY_max,
                    node_max]
        case nn.Linear:
            # Calculate maxes for linear specific
            node_max = selected_layer[1].out_features-1
            return [gr.update(visible=True),
                    gr.Number.update(visible=False, value=None),
                    gr.Number.update(visible=False, value=None),
                    gr.Number.update(visible=False, value=None),
                    gr.update(info=f"""Values between 0-{node_max}""", 
                              maximum=node_max, 
                              visible=True, value=None),
                    selected_layer,
                    channel_max,
                    nodeX_max,
                    nodeY_max,
                    node_max]
        case _:
            gr.Warning("Unknown layer type")
            return [gr.update(visible=False),
                    gr.update(visible=False, value=None),
                    gr.update(visible=False, value=None),
                    gr.update(visible=False, value=None),
                    gr.update(visible=False, value=None),
                    selected_layer,
                    channel_max,
                    nodeX_max,
                    nodeY_max,
                    node_max]

# Having this many inputs is typically not good practice
def generate(lr, epochs, img_size, channel, nodeX, nodeY, node, layer_sel, 
             model, thresholds, chan_decor, spacial_decor, sd_num, 
             transforms_selected, pad_num, pad_mode, constant_num, jitter_num, 
             scale_num, rotate_num, ad_jitter_num, 
             progress=gr.Progress(track_tqdm=True)):
    """
    Generates the feature visualizaiton with given parameters and tuning. 
    Utilizes the Lucent (Pytorch Lucid library).

    Inputs are different gradio components. Outputs an image component, and 
    their respective epoch numbers. Method tracks its own tqdm progress.
    """
    
    # Image setup
    def param_f(): return param.image(img_size, 
                                      fft=spacial_decor, 
                                      decorrelate=chan_decor,
                                      sd=sd_num)
    
    def optimizer(params): return torch.optim.Adam(params, lr=lr)
    
    # Tranforms setup
    tr_states = {
        h_models.TransformTypes.PAD.value: None,
        h_models.TransformTypes.JITTER.value: None,
        h_models.TransformTypes.RANDOM_SCALE.value: None,
        h_models.TransformTypes.RANDOM_ROTATE.value: None,
        h_models.TransformTypes.AD_JITTER.value: None
    }

    for tr_sel in transforms_selected:
        match tr_sel:
            case h_models.TransformTypes.PAD.value:
                tr_states[tr_sel] = tr.pad(pad_num, 
                                           mode = "constant" if pad_mode == "Constant" else "reflect",
                                           constant_value=constant_num)
            case h_models.TransformTypes.JITTER.value:
                tr_states[tr_sel] = tr.jitter(jitter_num)
            case h_models.TransformTypes.RANDOM_SCALE.value:
                tr_states[tr_sel] = tr.random_scale([1.0 - scale_num + i * (scale_num*2/(51-1)) for i in range(51)])
            case h_models.TransformTypes.RANDOM_ROTATE.value:
                tr_states[tr_sel] = tr.random_rotate([0 - rotate_num + i for i in range(rotate_num*2+1)])
            case h_models.TransformTypes.AD_JITTER.value:
                tr_states[tr_sel] = tr.jitter(ad_jitter_num)
    
    transforms = [t for t in tr_states.values() if t is not None]

    # Specific layer type handling
    match type(layer_sel[1]):
        case nn.Conv2d:
            if (channel is not None and nodeX is not None and nodeY is not None):
                gr.Info("Convolutional Node Specific")
                obj = objs.neuron(layer_sel[0], channel, x=nodeX, y=nodeY)

            elif (channel is not None):
                gr.Info("Convolutional Channel Specific ")
                obj = objs.channel(layer_sel[0], channel)

            elif (channel is None and nodeX is None and nodeY is None):
                gr.Info("Convolutional Layer Specific")
                if torch.cuda.is_available():
                    obj = lambda m: torch.mean(torch.pow(-m(layer_sel[0]).cuda(), 
                                                         torch.tensor(2).cuda())).cuda()
                else:
                    obj = lambda m: torch.mean(torch.pow(-m(layer_sel[0]), 
                                                         torch.tensor(2)))
            
            # Unknown
            else:
                gr.Error("Invalid layer settings")
                return None

        case nn.Linear:
            if (node is not None):
                gr.Info("Linear Node Specific")
                obj = objs.channel(layer_sel[0], node)
            else:
                gr.Info("Linear Layer Specific")
                if torch.cuda.is_available():
                    obj = lambda m: torch.mean(torch.pow(-m(layer_sel[0]).cuda(), 
                                                         torch.tensor(2).cuda())).cuda()
                else:
                    obj = lambda m: torch.mean(torch.pow(-m(layer_sel[0]), 
                                                         torch.tensor(2)))
        case _:
            gr.Info("Attempting unknown Layer Specific")
            transforms = [] # Just in case
            if torch.cuda.is_available():
                obj = lambda m: torch.mean(torch.pow(-m(layer_sel[0]).cuda(), 
                                                        torch.tensor(2).cuda())).cuda()
            else:
                obj = lambda m: torch.mean(torch.pow(-m(layer_sel[0]), 
                                                        torch.tensor(2)))

    thresholds = h_manip.expo_tuple(epochs, 6)

    img = np.array(render.render_vis(model,
                                     obj,
                                     thresholds=thresholds,
                                     show_image=False,
                                     optimizer=optimizer,
                                     param_f=param_f,
                                     transforms=transforms,
                                     verbose=True)).squeeze(1)
    
    return gr.Gallery.update(img), thresholds


def update_img_label(epoch_nums, evt: gr.SelectData):
    """
    Relabel the gallery with the epoch of the image that was clicked.
    :param epoch_nums: Epoch numbers, indexed in parallel with the gallery images
    :param evt: Event data from Gallery selection; ``evt.index`` picks the epoch
    :return: Image Gallery Component update
    """
    epoch = epoch_nums[evt.index]
    label = f"Epoch {epoch!s}"
    return gr.Gallery.update(show_label=True, label=label)


def check_input(curr, maxx):
    """
    Checks that the current input does not exceed the max, raising a Gradio
    error if it does.

    An empty input (``None``) is accepted. Empty channel/node fields mean
    "not specified" in this module (``on_layer`` resets them to ``None`` and
    ``generate`` treats ``None`` as unset), and comparing ``None > int`` would
    otherwise raise TypeError.
    :param curr: Current value, or None when the field is empty
    :param maxx: Max value to check against (inclusive)
    :raises gr.Error: If ``curr`` is greater than ``maxx``
    """
    if curr is None:
        return
    if curr > maxx:
        raise gr.Error(f"""Value {curr} is higher then maximum of {maxx}""")


def on_transform(transforms):
    """
    Toggle visibility of the per-transform settings columns.

    Columns are returned in the fixed order pad, jitter, random scale, random
    rotate, additional jitter; a column is visible iff its transform is among
    the selected ones.
    :param transforms: The transform names currently selected
    :return: Column Components with modified visibility
    """
    column_order = (
        h_models.TransformTypes.PAD.value,
        h_models.TransformTypes.JITTER.value,
        h_models.TransformTypes.RANDOM_SCALE.value,
        h_models.TransformTypes.RANDOM_ROTATE.value,
        h_models.TransformTypes.AD_JITTER.value,
    )
    visibility = dict.fromkeys(column_order, False)
    visibility.update(dict.fromkeys(transforms, True))
    return [gr.update(visible=flag) for flag in visibility.values()]


def on_pad_mode(evt: gr.SelectData):
    """
    Show the constant-value input only while the "Constant" pad mode is chosen.
    :param evt: Event data from Radio selection; ``evt.value`` is the chosen mode
    :return: Update toggling the constant value input's visibility
    """
    show_constant = True if evt.value == "Constant" else False
    return gr.update(visible=show_constant)