brad commited on
Commit
c1b01fa
1 Parent(s): 4c157d1

working transform interface

Browse files
Files changed (2) hide show
  1. helpers/listeners.py +268 -0
  2. main.py +40 -224
helpers/listeners.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import gradio as gr
4
+ import helpers.models as h_models
5
+ import lucent.optvis.param as param
6
+ import lucent.optvis.transform as tr
7
+ import helpers.manipulation as h_manip
8
+ import lucent.optvis.objectives as objs
9
+
10
+ from torch import nn
11
+ from time import sleep
12
+ from lucent.optvis import render
13
+ from lucent.modelzoo.util import get_model_layers
14
+
15
+
16
+ # Event listener functions
17
+ def on_model(model, model_layers, ft_map_sizes, evt: gr.SelectData, progress=gr.Progress()):
18
+ """
19
+ Logic flow when model is selected. Updates model, the model layers, and the
20
+ feature map sizes.
21
+ :param model: Current model (object) selected. Updated by this method
22
+ :param model_layers: List of model layers. Updated by this method
23
+ :param ft_map_sizes: List of Feature map sizes. Updated by this method
24
+ :param evt: Event data from Dropdown selection
25
+ :return: [Layer Dropdown Component, Model state, Model Layers state,
26
+ Feature Map Sizes State]
27
+ """
28
+ progress(0, desc="Setting up model...")
29
+ model = h_models.setup_model(h_models.ModelTypes[evt.value])
30
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
31
+ model.to(device).eval()
32
+
33
+ progress(0.25, desc="Getting layers names and details...")
34
+ model_layers = list(get_model_layers(model,
35
+ getLayerRepr=True).items())
36
+ choices = [f"({k}): {v.split('(')[0]}" for k, v in model_layers]
37
+
38
+
39
+ progress(0.5, desc="Getting layer objects...")
40
+ for i in range(len(model_layers)):
41
+ try:
42
+ layer = h_models.get_layer_by_name(model, model_layers[i][0])
43
+ except ValueError as e:
44
+ gr.Error(e)
45
+
46
+ model_layers[i] = (model_layers[i][0], layer)
47
+
48
+ progress(0.75, desc="Getting feature maps sizes...")
49
+ ft_map_sizes = h_models.get_feature_map_sizes(model, [v for _, v in model_layers])
50
+ progress(1, desc="Done")
51
+ sleep(0.25) # To allow for progress animation, not good practice
52
+ return [gr.update(choices=choices, value=''),
53
+ model, model_layers, ft_map_sizes]
54
+
55
+
56
+ def on_layer(selected_layer, model_layers, ft_map_sizes, evt: gr.SelectData):
57
+ """
58
+ Logic flow when a layer is selected. Updates max values of layer
59
+ specific input fields.
60
+ :param selected_layer: Current selected layer, updated by this method.
61
+ :param model_layers: All model layers
62
+ :param ft_map_sizes: Feature maps sizes for all conv layers
63
+ :param evt: Event data from Dropdown selection
64
+ :return [Layer Text Component,
65
+ Channel Number Component,
66
+ Node X Number Component,
67
+ Node Y Number Component,
68
+ Selected layer state/variable,
69
+ Channel max state/variable,
70
+ NodeX max state/variable,
71
+ NodeY max state/variable,
72
+ Node max state/variable]
73
+ """
74
+ channel_max, nodeX_max, nodeY_max, node_max = -1, -1, -1, -1
75
+ selected_layer = model_layers[evt.index]
76
+ match type(selected_layer[1]):
77
+ case nn.Conv2d:
78
+ # Calculate maxes for conv specific
79
+ channel_max = selected_layer[1].out_channels-1
80
+ nodeX_max = ft_map_sizes[evt.index][1]-1
81
+ nodeY_max = ft_map_sizes[evt.index][2]-1
82
+
83
+ return [gr.update(visible=True),
84
+ gr.Number.update(info=f"""Values between 0-{channel_max}""",
85
+ visible=True, value=None),
86
+ gr.Number.update(info=f"""Values between 0-{nodeX_max}""",
87
+ visible=True, value=None),
88
+ gr.Number.update(info=f"""Values between 0-{nodeY_max}""",
89
+ visible=True, value=None),
90
+ gr.update(visible=False, value=None),
91
+ selected_layer,
92
+ channel_max,
93
+ nodeX_max,
94
+ nodeY_max,
95
+ node_max]
96
+ case nn.Linear:
97
+ # Calculate maxes for linear specific
98
+ node_max = selected_layer[1].out_features-1
99
+ return [gr.update(visible=True),
100
+ gr.Number.update(visible=False, value=None),
101
+ gr.Number.update(visible=False, value=None),
102
+ gr.Number.update(visible=False, value=None),
103
+ gr.update(info=f"""Values between 0-{node_max}""",
104
+ maximum=node_max,
105
+ visible=True, value=None),
106
+ selected_layer,
107
+ channel_max,
108
+ nodeX_max,
109
+ nodeY_max,
110
+ node_max]
111
+ case _:
112
+ gr.Warning("Unknown layer type")
113
+ return [gr.update(visible=False),
114
+ gr.update(visible=False, value=None),
115
+ gr.update(visible=False, value=None),
116
+ gr.update(visible=False, value=None),
117
+ gr.update(visible=False, value=None),
118
+ selected_layer,
119
+ channel_max,
120
+ nodeX_max,
121
+ nodeY_max,
122
+ node_max]
123
+
124
+ # Having this many inputs is typically not good practice
125
+ def generate(lr, epochs, img_size, channel, nodeX, nodeY, node, layer_sel,
126
+ model, thresholds, chan_decor, spacial_decor, sd_num,
127
+ transforms_selected, pad_num, pad_mode, constant_num, jitter_num,
128
+ scale_num, rotate_num, ad_jitter_num,
129
+ progress=gr.Progress(track_tqdm=True)):
130
+ """
131
+ Generates the feature visualizaiton with given parameters and tuning.
132
+ Utilizes the Lucent (Pytorch Lucid library).
133
+
134
+ Inputs are different gradio components. Outputs an image component, and
135
+ their respective epoch numbers. Method tracks its own tqdm progress.
136
+ """
137
+
138
+ # Image setup
139
+ def param_f(): return param.image(img_size,
140
+ fft=spacial_decor,
141
+ decorrelate=chan_decor,
142
+ sd=sd_num)
143
+
144
+ def optimizer(params): return torch.optim.Adam(params, lr=lr)
145
+
146
+ # Tranforms setup
147
+ tr_states = {
148
+ h_models.TransformTypes.PAD.value: None,
149
+ h_models.TransformTypes.JITTER.value: None,
150
+ h_models.TransformTypes.RANDOM_SCALE.value: None,
151
+ h_models.TransformTypes.RANDOM_ROTATE.value: None,
152
+ h_models.TransformTypes.AD_JITTER.value: None
153
+ }
154
+
155
+ for tr_sel in transforms_selected:
156
+ match tr_sel:
157
+ case h_models.TransformTypes.PAD.value:
158
+ tr_states[tr_sel] = tr.pad(pad_num,
159
+ mode = "constant" if pad_mode == "Constant" else "reflect",
160
+ constant_value=constant_num)
161
+ case h_models.TransformTypes.JITTER.value:
162
+ tr_states[tr_sel] = tr.jitter(jitter_num)
163
+ case h_models.TransformTypes.RANDOM_SCALE.value:
164
+ tr_states[tr_sel] = tr.random_scale([1.0 - scale_num + i * (scale_num*2/(51-1)) for i in range(51)])
165
+ case h_models.TransformTypes.RANDOM_ROTATE.value:
166
+ tr_states[tr_sel] = tr.random_rotate([0 - rotate_num + i for i in range(rotate_num*2+1)])
167
+ case h_models.TransformTypes.AD_JITTER.value:
168
+ tr_states[tr_sel] = tr.jitter(ad_jitter_num)
169
+
170
+ transforms = [t for t in tr_states.values() if t is not None]
171
+
172
+ # Specific layer type handling
173
+ match type(layer_sel[1]):
174
+ case nn.Conv2d:
175
+ if (channel is not None and nodeX is not None and nodeY is not None):
176
+ gr.Info("Convolutional Node Specific")
177
+ obj = objs.neuron(layer_sel[0], channel, x=nodeX, y=nodeY)
178
+
179
+ elif (channel is not None):
180
+ gr.Info("Convolutional Channel Specific ")
181
+ obj = objs.channel(layer_sel[0], channel)
182
+
183
+ elif (channel is None and nodeX is None and nodeY is None):
184
+ gr.Info("Convolutional Layer Specific")
185
+ obj = lambda m: torch.mean(torch.pow(-m(layer_sel[0]).cuda(),
186
+ torch.tensor(2).cuda())).cuda()
187
+
188
+ # Unknown
189
+ else:
190
+ gr.Error("Invalid layer settings")
191
+ return None
192
+
193
+ case nn.Linear:
194
+ if (node is not None):
195
+ gr.Info("Linear Node Specific")
196
+ obj = objs.channel(layer_sel[0], node)
197
+ else:
198
+ gr.Info("Linear Layer Specific")
199
+ obj = lambda m: torch.mean(torch.pow(-m(layer_sel[0]).cuda(), torch.tensor(2).cuda())).cuda()
200
+ case _:
201
+ gr.Info("Attempting unknown Layer Specific")
202
+ transforms = [] # Just in case
203
+ obj = lambda m: torch.mean(torch.pow(-m(layer_sel[0]).cuda(), torch.tensor(2).cuda())).cuda()
204
+
205
+ thresholds = h_manip.expo_tuple(epochs, 6)
206
+
207
+ img = np.array(render.render_vis(model,
208
+ obj,
209
+ thresholds=thresholds,
210
+ show_image=False,
211
+ optimizer=optimizer,
212
+ param_f=param_f,
213
+ transforms=transforms,
214
+ verbose=True)).squeeze(1)
215
+
216
+ return gr.Gallery.update(img), thresholds
217
+
218
+
219
+ def update_img_label(epoch_nums, evt: gr.SelectData):
220
+ """
221
+ Updates the image label with its respective epoch number.
222
+ :param epoch_nums: The epoch numbers
223
+ :param evt: Event data from Gallery selection
224
+ :return: Image Gallery Component
225
+ """
226
+ return gr.Gallery.update(label='Epoch ' + str(epoch_nums[evt.index]),
227
+ show_label=True)
228
+
229
+
230
+ def check_input(curr, maxx):
231
+ """
232
+ Checks if the current input is higher then the max. Will raise if an error
233
+ if so.
234
+ :param curr: Current value
235
+ :param maxx: Max value to check against
236
+ """
237
+ if curr > maxx:
238
+ raise gr.Error(f"""Value {curr} is higher then maximum of {maxx}""")
239
+
240
+
241
+ def on_transform(transforms):
242
+ """
243
+ Logic for when a transform is selected. Controls the visbility of the
244
+ transform specific inputs/settings.
245
+ :param transforms: The transforms currently selected
246
+ :return: Column Components with modified visibility
247
+ """
248
+ transform_states = {
249
+ h_models.TransformTypes.PAD.value: False,
250
+ h_models.TransformTypes.JITTER.value: False,
251
+ h_models.TransformTypes.RANDOM_SCALE.value: False,
252
+ h_models.TransformTypes.RANDOM_ROTATE.value: False,
253
+ h_models.TransformTypes.AD_JITTER.value: False
254
+ }
255
+ for transform in transforms:
256
+ transform_states[transform] = True
257
+
258
+ return [gr.update(visible=state) for state in transform_states.values()]
259
+
260
+
261
+ def on_pad_mode (evt: gr.SelectData):
262
+ """
263
+ Hides the constant value input if the constant pad mode is not selected
264
+ :param evt: Event data from Radio selection
265
+ """
266
+ if (evt.value == "Constant"):
267
+ return gr.update(visible=True)
268
+ return gr.update(visible=False)
main.py CHANGED
@@ -1,14 +1,6 @@
1
- import torch
2
- import numpy as np
3
  import gradio as gr
4
  import helpers.models as h_models
5
- import helpers.manipulation as h_manip
6
- import lucent.optvis.param as param
7
- import lucent.optvis.objectives as objectives
8
- from torch import nn
9
- from time import sleep
10
- from lucent.optvis import render
11
- from lucent.modelzoo.util import get_model_layers
12
 
13
  # Custom css
14
  css = """div[data-testid="block-label"] {z-index: var(--layer-3)}"""
@@ -31,10 +23,12 @@ def main():
31
  Feature Visualizations (FV's) answer questions
32
  about what a network—or parts of a network—are
33
  looking for by generating examples.
34
- ([Read more about it here](https://distill.pub/2017/feature-visualization/))
35
- This generator aims to make it easier to explore
36
- different concepts used in FV generation and allow
37
- for experimentation.\n\n
 
 
38
  **Start by selecting a model from the drop down.**""")
39
  with gr.Row(): # Lower inputs and outputs
40
  with gr.Column(): # Inputs
@@ -103,27 +97,33 @@ def main():
103
  with gr.Accordion("Advanced Settings", open=False):
104
  with gr.Column(variant="panel"):
105
  gr.Markdown("""## Image Settings""")
 
106
  img_num = gr.Number(label="Image Size",
107
  info="Image is square (<value> by <value>)",
108
  precision=0,
109
  minimum=1,
110
  value=227)
 
111
  chan_decor_ck = gr.Checkbox(label="Channel Decorrelation",
112
  info="Reduces channel-to-channel correlations",
113
  value=True)
 
114
  spacial_decor_ck = gr.Checkbox(label="Spacial Decorrelation (FFT)",
115
  info="Reduces pixel-to-pixel correlations",
116
  value=True)
 
117
  sd_num = gr.Number(label="Standard Deviation",
118
  info="The STD of the randomly generated starter image",
119
  value=0.01)
120
 
121
  with gr.Column(variant="panel"):
122
  gr.Markdown("""## Transform Settings (WIP)""")
 
123
  preprocess_ck = gr.Checkbox(label="Preprocess",
124
  info="Enable or disable preprocessing via transformations",
125
  value=True,
126
  interactive=True)
 
127
  transform_choices = [t.value for t in h_models.TransformTypes]
128
  transforms_dd = gr.Dropdown(label="Applied Transforms",
129
  info="Transforms to apply",
@@ -131,7 +131,7 @@ def main():
131
  multiselect=True,
132
  value=transform_choices,
133
  interactive=True)
134
-
135
  # Transform specific settings
136
  pad_col = gr.Column()
137
  with pad_col:
@@ -143,11 +143,13 @@ def main():
143
  value=12,
144
  precision=0,
145
  interactive=True)
 
146
  mode_rad = gr.Radio(label="Mode",
147
  info="Constant fills padded pixels with a value. Reflect fills with edge pixels",
148
  choices=["Constant", "Reflect"],
149
  value="Constant",
150
  interactive=True)
 
151
  constant_num = gr.Number(label="Constant Fill Value",
152
  info="Value to fill padded pixels",
153
  value=0.5,
@@ -161,6 +163,7 @@ def main():
161
  info="How much to jitter image by",
162
  minimum=1,
163
  value=8,
 
164
  interactive=True)
165
 
166
  rand_scale_col = gr.Column()
@@ -168,9 +171,9 @@ def main():
168
  gr.Markdown("""### Random Scale Settings""")
169
  with gr.Row():
170
  scale_num = gr.Number(label="Max scale",
171
- info="How much to scale in both directions (+ and -)",
172
  minimum=0,
173
- value=10,
174
  interactive=True)
175
 
176
  rand_rotate_col = gr.Column()
@@ -181,6 +184,7 @@ def main():
181
  info="How much to rotate in both directions (+ and -)",
182
  minimum=0,
183
  value=10,
 
184
  interactive=True)
185
 
186
  ad_jitter_col = gr.Column()
@@ -191,10 +195,8 @@ def main():
191
  info="How much to jitter image by",
192
  minimum=1,
193
  value=4,
 
194
  interactive=True)
195
-
196
-
197
-
198
 
199
  confirm_btn = gr.Button("Generate", visible=False)
200
 
@@ -208,14 +210,14 @@ def main():
208
  # Event listener binding
209
  model_dd.select(lambda: gr.Dropdown.update(visible=True),
210
  outputs=layer_dd)
211
- model_dd.select(on_model,
212
  inputs=[model, model_layers, ft_map_sizes],
213
  outputs=[layer_dd, model, model_layers, ft_map_sizes])
214
 
215
  # TODO: Make button invisible always until layer selection
216
  layer_dd.select(lambda: gr.Button.update(visible=True),
217
  outputs=confirm_btn)
218
- layer_dd.select(on_layer,
219
  inputs=[selected_layer, model_layers, ft_map_sizes],
220
  outputs=[layer_text,
221
  channel_num,
@@ -228,12 +230,12 @@ def main():
228
  nodeY_max,
229
  node_max])
230
 
231
- channel_num.blur(check_input, inputs=[channel_num, channel_max])
232
- nodeX_num.blur(check_input, inputs=[nodeX_num, nodeX_max])
233
- nodeY_num.blur(check_input, inputs=[nodeY_num, nodeY_max])
234
- node_num.blur(check_input, inputs=[node_num, node_max])
235
 
236
- images_gal.select(update_img_label,
237
  inputs=thresholds,
238
  outputs=images_gal)
239
 
@@ -251,7 +253,7 @@ def main():
251
  rand_rotate_col,
252
  ad_jitter_col])
253
 
254
- transforms_dd.change(on_transform,
255
  inputs=transforms_dd,
256
  outputs=[pad_col,
257
  jitter_col,
@@ -259,10 +261,10 @@ def main():
259
  rand_rotate_col,
260
  ad_jitter_col])
261
 
262
- mode_rad.select(on_pad_mode,
263
  outputs=constant_num)
264
 
265
- confirm_btn.click(generate,
266
  inputs=[lr_sl,
267
  epoch_num,
268
  img_num,
@@ -275,203 +277,17 @@ def main():
275
  thresholds,
276
  chan_decor_ck,
277
  spacial_decor_ck,
278
- sd_num],
 
 
 
 
 
 
 
 
279
  outputs=[images_gal, thresholds])
280
 
281
  demo.queue().launch()
282
 
283
-
284
- # Event listener functions
285
- def on_model(model, model_layers, ft_map_sizes, evt: gr.SelectData, progress=gr.Progress()):
286
- """
287
- Logic flow when model is selected. Updates model, the model layers, and the
288
- feature map sizes.
289
- :param model: Current model (object) selected. Updated by this method
290
- :param model_layers: List of model layers. Updated by this method
291
- :param ft_map_sizes: List of Feature map sizes. Updated by this method
292
- :param evt: Event data from Dropdown selection
293
- :return: [Layer Dropdown Component, Model state, Model Layers state,
294
- Feature Map Sizes State]
295
- """
296
- progress(0, desc="Setting up model...")
297
- model = h_models.setup_model(h_models.ModelTypes[evt.value])
298
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
299
- model.to(device).eval()
300
-
301
- progress(0.25, desc="Getting layers names and details...")
302
- model_layers = list(get_model_layers(model,
303
- getLayerRepr=True).items())
304
- choices = [f"({k}): {v.split('(')[0]}" for k, v in model_layers]
305
-
306
-
307
- progress(0.5, desc="Getting layer objects...")
308
- for i in range(len(model_layers)):
309
- try:
310
- layer = h_models.get_layer_by_name(model, model_layers[i][0])
311
- except ValueError as e:
312
- gr.Error(e)
313
-
314
- model_layers[i] = (model_layers[i][0], layer)
315
-
316
- progress(0.75, desc="Getting feature maps sizes...")
317
- ft_map_sizes = h_models.get_feature_map_sizes(model, [v for _, v in model_layers])
318
- progress(1, desc="Done")
319
- sleep(0.25) # To allow for progress animation, not good practice
320
- return [gr.update(choices=choices, value=''),
321
- model, model_layers, ft_map_sizes]
322
-
323
-
324
- def on_layer(selected_layer, model_layers, ft_map_sizes, evt: gr.SelectData):
325
- """
326
- Logic flow when a layer is selected. Updates max values of layer
327
- specific input fields.
328
- :param selected_layer: Current selected layer, updated by this method.
329
- :param model_layers: All model layers
330
- :param ft_map_sizes: Feature maps sizes for all conv layers
331
- :param evt: Event data from Dropdown selection
332
- :return [Layer Text Component,
333
- Channel Number Component,
334
- Node X Number Component,
335
- Node Y Number Component,
336
- Selected layer state/variable]
337
- """
338
- channel_max, nodeX_max, nodeY_max, node_max = -1, -1, -1, -1
339
- selected_layer = model_layers[evt.index]
340
- match type(selected_layer[1]):
341
- case nn.Conv2d:
342
- channel_max = selected_layer[1].out_channels-1
343
- nodeX_max = ft_map_sizes[evt.index][1]-1
344
- nodeY_max = ft_map_sizes[evt.index][2]-1
345
-
346
- return [gr.update(visible=True),
347
- gr.Number.update(info=f"""Values between 0-{channel_max}""",
348
- visible=True, value=None),
349
- gr.Number.update(info=f"""Values between 0-{nodeX_max}""",
350
- visible=True, value=None),
351
- gr.Number.update(info=f"""Values between 0-{nodeY_max}""",
352
- visible=True, value=None),
353
- gr.update(visible=False, value=None),
354
- selected_layer,
355
- channel_max,
356
- nodeX_max,
357
- nodeY_max,
358
- node_max]
359
- case nn.Linear:
360
- node_max = selected_layer[1].out_features-1
361
- return [gr.update(visible=True),
362
- gr.Number.update(visible=False, value=None),
363
- gr.Number.update(visible=False, value=None),
364
- gr.Number.update(visible=False, value=None),
365
- gr.update(info=f"""Values between 0-{node_max}""",
366
- maximum=node_max,
367
- visible=True, value=None),
368
- selected_layer,
369
- channel_max,
370
- nodeX_max,
371
- nodeY_max,
372
- node_max]
373
- case _:
374
- gr.Warning("Unknown layer type")
375
- return [gr.update(visible=False),
376
- gr.update(visible=False, value=None),
377
- gr.update(visible=False, value=None),
378
- gr.update(visible=False, value=None),
379
- gr.update(visible=False, value=None),
380
- selected_layer,
381
- channel_max,
382
- nodeX_max,
383
- nodeY_max,
384
- node_max]
385
-
386
-
387
- def generate(lr, epochs, img_size, channel, nodeX, nodeY, node, selected_layer,
388
- model, thresholds, chan_decor, spacial_decor,
389
- sd_num, progress=gr.Progress(track_tqdm=True)):
390
- """
391
- Generates the feature visualizaiton with given parameters and tuning.
392
- Utilizes the Lucent (Pytorch Lucid library).
393
-
394
- Inputs are different gradio components. Outputs an image component. Method
395
- tracks its own tqdm progress.
396
- """
397
-
398
- def param_f(): return param.image(img_size,
399
- fft=spacial_decor,
400
- decorrelate=chan_decor,
401
- sd=sd_num) # Image setup
402
- def optimizer(params): return torch.optim.Adam(params, lr=lr)
403
-
404
- # Specific layer type handling
405
- match type(selected_layer[1]):
406
- case nn.Conv2d:
407
- # Node specific
408
- if (channel is not None and nodeX is not None and nodeY is not None):
409
- gr.Info("Node Specific Convolution")
410
- obj = objectives.neuron(selected_layer[0],
411
- channel,
412
- x=nodeX,
413
- y=nodeY)
414
-
415
- # Channel specific
416
- elif (channel is not None):
417
- gr.Info("Channel Specific Convolution")
418
- obj = objectives.channel(selected_layer[0], channel)
419
-
420
- # Layer specific
421
- elif (channel is None and nodeX is None and nodeY is None):
422
- gr.Info("Layer Specific Convolution")
423
- obj = lambda m: torch.mean(torch.pow(-m(selected_layer[0]).cuda(), torch.tensor(2).cuda())).cuda()
424
-
425
- # Unknown
426
- else:
427
- gr.Error("Invalid layer settings")
428
- return None
429
-
430
- case nn.Linear:
431
- if (node is not None): # Node Specific
432
- obj = objectives.channel(selected_layer[0], node)
433
- else: # Layer Specific
434
- obj = lambda m: torch.mean(torch.pow(-m(selected_layer[0]).cuda(), torch.tensor(2).cuda())).cuda()
435
- thresholds = h_manip.expo_tuple(epochs, 6)
436
- print(thresholds)
437
-
438
-
439
- img = np.array(render.render_vis(model,
440
- obj,
441
- thresholds=thresholds,
442
- show_image=False,
443
- optimizer=optimizer,
444
- param_f=param_f,
445
- verbose=True)).squeeze(1)
446
-
447
- return gr.Gallery.update(img), thresholds
448
-
449
-
450
- def update_img_label(thresholds, evt: gr.SelectData):
451
- return gr.Gallery.update(label='Epoch ' + str(thresholds[evt.index]), show_label=True)
452
-
453
-
454
- def check_input(curr, maxx):
455
- if curr > maxx:
456
- raise gr.Error(f"""Value {curr} is higher then maximum of {maxx}""")
457
-
458
-
459
- def on_transform(transforms):
460
- transform_states = {
461
- h_models.TransformTypes.PAD.value: False,
462
- h_models.TransformTypes.JITTER.value: False,
463
- h_models.TransformTypes.RANDOM_SCALE.value: False,
464
- h_models.TransformTypes.RANDOM_ROTATE.value: False,
465
- h_models.TransformTypes.AD_JITTER.value: False
466
- }
467
- for transform in transforms:
468
- transform_states[transform] = True
469
-
470
- return [gr.update(visible=state) for state in transform_states.values()]
471
-
472
-
473
- def on_pad_mode (evt: gr.SelectData):
474
- if (evt.value == "Constant"):
475
- return gr.update(visible=True)
476
- return gr.update(visible=False)
477
  main()
 
 
 
1
  import gradio as gr
2
  import helpers.models as h_models
3
+ import helpers.listeners as listeners
 
 
 
 
 
 
4
 
5
  # Custom css
6
  css = """div[data-testid="block-label"] {z-index: var(--layer-3)}"""
 
23
  Feature Visualizations (FV's) answer questions
24
  about what a network—or parts of a network—are
25
  looking for by generating examples.
26
+ ([Read more about it here](https://distill.pub/2017/feature-visualization/)
27
+ FVs are a part of a wider field called Explainable
28
+ Artificial Intelligence (XAI) This generator aims
29
+ to make it easier to explore different concepts
30
+ used in FV generation and allow for experimentation.
31
+ Currently Convolutional and Linear layers were tested.\n\n
32
  **Start by selecting a model from the drop down.**""")
33
  with gr.Row(): # Lower inputs and outputs
34
  with gr.Column(): # Inputs
 
97
  with gr.Accordion("Advanced Settings", open=False):
98
  with gr.Column(variant="panel"):
99
  gr.Markdown("""## Image Settings""")
100
+
101
  img_num = gr.Number(label="Image Size",
102
  info="Image is square (<value> by <value>)",
103
  precision=0,
104
  minimum=1,
105
  value=227)
106
+
107
  chan_decor_ck = gr.Checkbox(label="Channel Decorrelation",
108
  info="Reduces channel-to-channel correlations",
109
  value=True)
110
+
111
  spacial_decor_ck = gr.Checkbox(label="Spacial Decorrelation (FFT)",
112
  info="Reduces pixel-to-pixel correlations",
113
  value=True)
114
+
115
  sd_num = gr.Number(label="Standard Deviation",
116
  info="The STD of the randomly generated starter image",
117
  value=0.01)
118
 
119
  with gr.Column(variant="panel"):
120
  gr.Markdown("""## Transform Settings (WIP)""")
121
+
122
  preprocess_ck = gr.Checkbox(label="Preprocess",
123
  info="Enable or disable preprocessing via transformations",
124
  value=True,
125
  interactive=True)
126
+
127
  transform_choices = [t.value for t in h_models.TransformTypes]
128
  transforms_dd = gr.Dropdown(label="Applied Transforms",
129
  info="Transforms to apply",
 
131
  multiselect=True,
132
  value=transform_choices,
133
  interactive=True)
134
+
135
  # Transform specific settings
136
  pad_col = gr.Column()
137
  with pad_col:
 
143
  value=12,
144
  precision=0,
145
  interactive=True)
146
+
147
  mode_rad = gr.Radio(label="Mode",
148
  info="Constant fills padded pixels with a value. Reflect fills with edge pixels",
149
  choices=["Constant", "Reflect"],
150
  value="Constant",
151
  interactive=True)
152
+
153
  constant_num = gr.Number(label="Constant Fill Value",
154
  info="Value to fill padded pixels",
155
  value=0.5,
 
163
  info="How much to jitter image by",
164
  minimum=1,
165
  value=8,
166
+ precision=0,
167
  interactive=True)
168
 
169
  rand_scale_col = gr.Column()
 
171
  gr.Markdown("""### Random Scale Settings""")
172
  with gr.Row():
173
  scale_num = gr.Number(label="Max scale",
174
+ info="How much to scale (from 1.0) in both directions (+ and -)",
175
  minimum=0,
176
+ value=0.1,
177
  interactive=True)
178
 
179
  rand_rotate_col = gr.Column()
 
184
  info="How much to rotate in both directions (+ and -)",
185
  minimum=0,
186
  value=10,
187
+ precision=0,
188
  interactive=True)
189
 
190
  ad_jitter_col = gr.Column()
 
195
  info="How much to jitter image by",
196
  minimum=1,
197
  value=4,
198
+ precision=0,
199
  interactive=True)
 
 
 
200
 
201
  confirm_btn = gr.Button("Generate", visible=False)
202
 
 
210
  # Event listener binding
211
  model_dd.select(lambda: gr.Dropdown.update(visible=True),
212
  outputs=layer_dd)
213
+ model_dd.select(listeners.on_model,
214
  inputs=[model, model_layers, ft_map_sizes],
215
  outputs=[layer_dd, model, model_layers, ft_map_sizes])
216
 
217
  # TODO: Make button invisible always until layer selection
218
  layer_dd.select(lambda: gr.Button.update(visible=True),
219
  outputs=confirm_btn)
220
+ layer_dd.select(listeners.on_layer,
221
  inputs=[selected_layer, model_layers, ft_map_sizes],
222
  outputs=[layer_text,
223
  channel_num,
 
230
  nodeY_max,
231
  node_max])
232
 
233
+ channel_num.blur(listeners.check_input, inputs=[channel_num, channel_max])
234
+ nodeX_num.blur(listeners.check_input, inputs=[nodeX_num, nodeX_max])
235
+ nodeY_num.blur(listeners.check_input, inputs=[nodeY_num, nodeY_max])
236
+ node_num.blur(listeners.check_input, inputs=[node_num, node_max])
237
 
238
+ images_gal.select(listeners.update_img_label,
239
  inputs=thresholds,
240
  outputs=images_gal)
241
 
 
253
  rand_rotate_col,
254
  ad_jitter_col])
255
 
256
+ transforms_dd.change(listeners.on_transform,
257
  inputs=transforms_dd,
258
  outputs=[pad_col,
259
  jitter_col,
 
261
  rand_rotate_col,
262
  ad_jitter_col])
263
 
264
+ mode_rad.select(listeners.on_pad_mode,
265
  outputs=constant_num)
266
 
267
+ confirm_btn.click(listeners.generate,
268
  inputs=[lr_sl,
269
  epoch_num,
270
  img_num,
 
277
  thresholds,
278
  chan_decor_ck,
279
  spacial_decor_ck,
280
+ sd_num,
281
+ transforms_dd,
282
+ pad_num,
283
+ mode_rad,
284
+ constant_num,
285
+ jitter_num,
286
+ scale_num,
287
+ rotate_num,
288
+ ad_jitter_num],
289
  outputs=[images_gal, thresholds])
290
 
291
  demo.queue().launch()
292
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
  main()