brad commited on
Commit
4c157d1
·
unverified ·
1 Parent(s): d665b68

inital transform interface

Browse files
Files changed (2) hide show
  1. helpers/models.py +9 -0
  2. main.py +173 -61
helpers/models.py CHANGED
@@ -20,6 +20,14 @@ class LayerTypes(Enum):
20
  CONVOLUTIONAL = nn.Conv2d
21
  LINEAR = nn.Linear
22
 
 
 
 
 
 
 
 
 
23
 
24
  _hook_activations = None
25
 
@@ -152,6 +160,7 @@ def get_feature_map_sizes(model, layers, img=None):
152
  """
153
  feature_map_sizes = [None] * len(layers)
154
  if img is None:
 
155
  img = h_manipulation.create_random_image((227, 227),
156
  h_manipulation.DatasetNormalizations.CIFAR10_MEAN.value,
157
  h_manipulation.DatasetNormalizations.CIFAR10_STD.value).clone().unsqueeze(0)
 
20
  CONVOLUTIONAL = nn.Conv2d
21
  LINEAR = nn.Linear
22
 
23
+ class TransformTypes(Enum):
24
+ PAD = "Pad"
25
+ JITTER = "Jitter"
26
+ RANDOM_SCALE = "Random Scale"
27
+ RANDOM_ROTATE = "Random Rotate"
28
+ AD_JITTER = "Additional Jitter"
29
+
30
+
31
 
32
  _hook_activations = None
33
 
 
160
  """
161
  feature_map_sizes = [None] * len(layers)
162
  if img is None:
163
+ # TODO Remove this and just generates a blank image of 227 by 227
164
  img = h_manipulation.create_random_image((227, 227),
165
  h_manipulation.DatasetNormalizations.CIFAR10_MEAN.value,
166
  h_manipulation.DatasetNormalizations.CIFAR10_STD.value).clone().unsqueeze(0)
main.py CHANGED
@@ -10,41 +10,32 @@ from time import sleep
10
  from lucent.optvis import render
11
  from lucent.modelzoo.util import get_model_layers
12
 
13
-
14
- # deep_orange = gr.themes.Color(c50="#FFEDE5",
15
- # c100="#FFDACC",
16
- # c200="#FFB699",
17
- # c300="#FF9166",
18
- # c400="#FF6D33",
19
- # c500="#FF4700",
20
- # c600="#CC3A00",
21
- # c700="#992B00",
22
- # c800="#661D00",
23
- # c900="#330E00",
24
- # c950="#190700")
25
- css = """
26
- div[data-testid="block-label"] {z-index: var(--layer-3)}
27
- """
28
 
29
  def main():
30
- # with gr.Blocks(theme=gr.themes.Soft(primary_hue=deep_orange,
31
- # secondary_hue=deep_orange,
32
- # neutral_hue=gr.themes.colors.zinc)) as demo:
33
- with gr.Blocks(title="Feature Visualization Generator", css=css, theme=gr.themes.Soft()) as demo:
34
- # Session states
35
- selected_layer = gr.State(None)
36
- model, model_layers = gr.State(None), gr.State(None)
37
- ft_map_sizes = gr.State(None)
38
- thresholds = gr.State(None)
39
- channel_max = gr.State(None)
40
- nodeX_max = gr.State(None)
41
- nodeY_max = gr.State(None)
42
- node_max = gr.State(None)
43
 
44
  # GUI Elements
45
  with gr.Row(): # Upper banner
46
  gr.Markdown("""# Feature Visualization Generator\n
47
- Start by selecting a model from the drop down.""")
 
 
 
 
 
 
 
48
  with gr.Row(): # Lower inputs and outputs
49
  with gr.Column(): # Inputs
50
  gr.Markdown("""## Model Settings""")
@@ -108,29 +99,102 @@ def main():
108
  precision=0,
109
  minimum=1,
110
  value=200)
111
-
112
- img_num = gr.Number(label="Image Size",
113
- info="Image is square (<value> by <value>)",
114
- precision=0,
115
- minimum=1,
116
- value=227)
117
 
118
  with gr.Accordion("Advanced Settings", open=False):
119
- gr.Markdown("""## Image Settings (WIP)""")
120
- chan_decor_ck = gr.Checkbox(label="Channel Decorrelation",
121
- info="Only works if 3 channels",
122
- value=True)
123
- spacial_decor_ck = gr.Checkbox(label="Spacial Decorrelation (FFT)",
124
- value=True)
125
- batch_num = gr.Number(label="Batch",
126
- value=1,
127
- precision=0)
128
- sd_num = gr.Number(label="Standard Deviation",
129
- value=0.01)
130
-
131
- gr.Markdown("""## Transform Settings (WIP)""")
132
- gr.Checkbox(label="Preprocess", info="Enable or disable preprocessing via transformations")
133
- gr.Dropdown(label="Applied Transforms", info="Transforms to apply", multiselect=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  confirm_btn = gr.Button("Generate", visible=False)
136
 
@@ -169,6 +233,35 @@ def main():
169
  nodeY_num.blur(check_input, inputs=[nodeY_num, nodeY_max])
170
  node_num.blur(check_input, inputs=[node_num, node_max])
171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  confirm_btn.click(generate,
173
  inputs=[lr_sl,
174
  epoch_num,
@@ -182,12 +275,9 @@ def main():
182
  thresholds,
183
  chan_decor_ck,
184
  spacial_decor_ck,
185
- batch_num,
186
  sd_num],
187
  outputs=[images_gal, thresholds])
188
- images_gal.select(update_img_label,
189
- inputs=thresholds,
190
- outputs=images_gal)
191
  demo.queue().launch()
192
 
193
 
@@ -295,7 +385,7 @@ def on_layer(selected_layer, model_layers, ft_map_sizes, evt: gr.SelectData):
295
 
296
 
297
  def generate(lr, epochs, img_size, channel, nodeX, nodeY, node, selected_layer,
298
- model, thresholds, chan_decor, spacial_decor, batch_num,
299
  sd_num, progress=gr.Progress(track_tqdm=True)):
300
  """
301
  Generates the feature visualizaiton with given parameters and tuning.
@@ -308,7 +398,6 @@ def generate(lr, epochs, img_size, channel, nodeX, nodeY, node, selected_layer,
308
  def param_f(): return param.image(img_size,
309
  fft=spacial_decor,
310
  decorrelate=chan_decor,
311
- batch=batch_num,
312
  sd=sd_num) # Image setup
313
  def optimizer(params): return torch.optim.Adam(params, lr=lr)
314
 
@@ -344,12 +433,16 @@ def generate(lr, epochs, img_size, channel, nodeX, nodeY, node, selected_layer,
344
  else: # Layer Specific
345
  obj = lambda m: torch.mean(torch.pow(-m(selected_layer[0]).cuda(), torch.tensor(2).cuda())).cuda()
346
  thresholds = h_manip.expo_tuple(epochs, 6)
 
 
 
347
  img = np.array(render.render_vis(model,
348
- obj,
349
- thresholds=thresholds,
350
- show_image=False,
351
- optimizer=optimizer,
352
- param_f=param_f)).squeeze(1)
 
353
 
354
  return gr.Gallery.update(img), thresholds
355
 
@@ -357,9 +450,28 @@ def generate(lr, epochs, img_size, channel, nodeX, nodeY, node, selected_layer,
357
  def update_img_label(thresholds, evt: gr.SelectData):
358
  return gr.Gallery.update(label='Epoch ' + str(thresholds[evt.index]), show_label=True)
359
 
 
360
  def check_input(curr, maxx):
361
  if curr > maxx:
362
  raise gr.Error(f"""Value {curr} is higher then maximum of {maxx}""")
363
 
364
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
  main()
 
10
  from lucent.optvis import render
11
  from lucent.modelzoo.util import get_model_layers
12
 
13
+ # Custom css
14
+ css = """div[data-testid="block-label"] {z-index: var(--layer-3)}"""
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  def main():
17
+ with gr.Blocks(title="Feature Visualization Generator",
18
+ css=css,
19
+ theme=gr.themes.Soft(primary_hue="blue",
20
+ secondary_hue="blue",
21
+ )) as demo:
22
+
23
+ # Session state init
24
+ model, model_layers, selected_layer, ft_map_sizes, \
25
+ thresholds, channel_max, nodeX_max, nodeY_max, \
26
+ node_max = (gr.State(None) for _ in range(9))
 
 
 
27
 
28
  # GUI Elements
29
  with gr.Row(): # Upper banner
30
  gr.Markdown("""# Feature Visualization Generator\n
31
+ Feature Visualizations (FV's) answer questions
32
+ about what a network—or parts of a network—are
33
+ looking for by generating examples.
34
+ ([Read more about it here](https://distill.pub/2017/feature-visualization/))
35
+ This generator aims to make it easier to explore
36
+ different concepts used in FV generation and allow
37
+ for experimentation.\n\n
38
+ **Start by selecting a model from the drop down.**""")
39
  with gr.Row(): # Lower inputs and outputs
40
  with gr.Column(): # Inputs
41
  gr.Markdown("""## Model Settings""")
 
99
  precision=0,
100
  minimum=1,
101
  value=200)
 
 
 
 
 
 
102
 
103
  with gr.Accordion("Advanced Settings", open=False):
104
+ with gr.Column(variant="panel"):
105
+ gr.Markdown("""## Image Settings""")
106
+ img_num = gr.Number(label="Image Size",
107
+ info="Image is square (<value> by <value>)",
108
+ precision=0,
109
+ minimum=1,
110
+ value=227)
111
+ chan_decor_ck = gr.Checkbox(label="Channel Decorrelation",
112
+ info="Reduces channel-to-channel correlations",
113
+ value=True)
114
+ spacial_decor_ck = gr.Checkbox(label="Spacial Decorrelation (FFT)",
115
+ info="Reduces pixel-to-pixel correlations",
116
+ value=True)
117
+ sd_num = gr.Number(label="Standard Deviation",
118
+ info="The STD of the randomly generated starter image",
119
+ value=0.01)
120
+
121
+ with gr.Column(variant="panel"):
122
+ gr.Markdown("""## Transform Settings (WIP)""")
123
+ preprocess_ck = gr.Checkbox(label="Preprocess",
124
+ info="Enable or disable preprocessing via transformations",
125
+ value=True,
126
+ interactive=True)
127
+ transform_choices = [t.value for t in h_models.TransformTypes]
128
+ transforms_dd = gr.Dropdown(label="Applied Transforms",
129
+ info="Transforms to apply",
130
+ choices=transform_choices,
131
+ multiselect=True,
132
+ value=transform_choices,
133
+ interactive=True)
134
+
135
+ # Transform specific settings
136
+ pad_col = gr.Column()
137
+ with pad_col:
138
+ gr.Markdown("""### Pad Settings""")
139
+ with gr.Row():
140
+ pad_num = gr.Number(label="Padding",
141
+ info="How many pixels of padding",
142
+ minimum=0,
143
+ value=12,
144
+ precision=0,
145
+ interactive=True)
146
+ mode_rad = gr.Radio(label="Mode",
147
+ info="Constant fills padded pixels with a value. Reflect fills with edge pixels",
148
+ choices=["Constant", "Reflect"],
149
+ value="Constant",
150
+ interactive=True)
151
+ constant_num = gr.Number(label="Constant Fill Value",
152
+ info="Value to fill padded pixels",
153
+ value=0.5,
154
+ interactive=True)
155
+
156
+ jitter_col = gr.Column()
157
+ with jitter_col:
158
+ gr.Markdown("""### Jitter Settings""")
159
+ with gr.Row():
160
+ jitter_num = gr.Number(label="Jitter",
161
+ info="How much to jitter image by",
162
+ minimum=1,
163
+ value=8,
164
+ interactive=True)
165
+
166
+ rand_scale_col = gr.Column()
167
+ with rand_scale_col:
168
+ gr.Markdown("""### Random Scale Settings""")
169
+ with gr.Row():
170
+ scale_num = gr.Number(label="Max scale",
171
+ info="How much to scale in both directions (+ and -)",
172
+ minimum=0,
173
+ value=10,
174
+ interactive=True)
175
+
176
+ rand_rotate_col = gr.Column()
177
+ with rand_rotate_col:
178
+ gr.Markdown("""### Random Rotate Settings""")
179
+ with gr.Row():
180
+ rotate_num = gr.Number(label="Max angle",
181
+ info="How much to rotate in both directions (+ and -)",
182
+ minimum=0,
183
+ value=10,
184
+ interactive=True)
185
+
186
+ ad_jitter_col = gr.Column()
187
+ with ad_jitter_col:
188
+ gr.Markdown("""### Additional Jitter Settings""")
189
+ with gr.Row():
190
+ ad_jitter_num = gr.Number(label="Jitter",
191
+ info="How much to jitter image by",
192
+ minimum=1,
193
+ value=4,
194
+ interactive=True)
195
+
196
+
197
+
198
 
199
  confirm_btn = gr.Button("Generate", visible=False)
200
 
 
233
  nodeY_num.blur(check_input, inputs=[nodeY_num, nodeY_max])
234
  node_num.blur(check_input, inputs=[node_num, node_max])
235
 
236
+ images_gal.select(update_img_label,
237
+ inputs=thresholds,
238
+ outputs=images_gal)
239
+
240
+ preprocess_ck.select(lambda status: (gr.update(visible=status),
241
+ gr.update(visible=status),
242
+ gr.update(visible=status),
243
+ gr.update(visible=status),
244
+ gr.update(visible=status),
245
+ gr.update(visible=status)),
246
+ inputs=preprocess_ck,
247
+ outputs=[transforms_dd,
248
+ pad_col,
249
+ jitter_col,
250
+ rand_scale_col,
251
+ rand_rotate_col,
252
+ ad_jitter_col])
253
+
254
+ transforms_dd.change(on_transform,
255
+ inputs=transforms_dd,
256
+ outputs=[pad_col,
257
+ jitter_col,
258
+ rand_scale_col,
259
+ rand_rotate_col,
260
+ ad_jitter_col])
261
+
262
+ mode_rad.select(on_pad_mode,
263
+ outputs=constant_num)
264
+
265
  confirm_btn.click(generate,
266
  inputs=[lr_sl,
267
  epoch_num,
 
275
  thresholds,
276
  chan_decor_ck,
277
  spacial_decor_ck,
 
278
  sd_num],
279
  outputs=[images_gal, thresholds])
280
+
 
 
281
  demo.queue().launch()
282
 
283
 
 
385
 
386
 
387
  def generate(lr, epochs, img_size, channel, nodeX, nodeY, node, selected_layer,
388
+ model, thresholds, chan_decor, spacial_decor,
389
  sd_num, progress=gr.Progress(track_tqdm=True)):
390
  """
391
  Generates the feature visualizaiton with given parameters and tuning.
 
398
  def param_f(): return param.image(img_size,
399
  fft=spacial_decor,
400
  decorrelate=chan_decor,
 
401
  sd=sd_num) # Image setup
402
  def optimizer(params): return torch.optim.Adam(params, lr=lr)
403
 
 
433
  else: # Layer Specific
434
  obj = lambda m: torch.mean(torch.pow(-m(selected_layer[0]).cuda(), torch.tensor(2).cuda())).cuda()
435
  thresholds = h_manip.expo_tuple(epochs, 6)
436
+ print(thresholds)
437
+
438
+
439
  img = np.array(render.render_vis(model,
440
+ obj,
441
+ thresholds=thresholds,
442
+ show_image=False,
443
+ optimizer=optimizer,
444
+ param_f=param_f,
445
+ verbose=True)).squeeze(1)
446
 
447
  return gr.Gallery.update(img), thresholds
448
 
 
450
  def update_img_label(thresholds, evt: gr.SelectData):
451
  return gr.Gallery.update(label='Epoch ' + str(thresholds[evt.index]), show_label=True)
452
 
453
+
454
  def check_input(curr, maxx):
455
  if curr > maxx:
456
  raise gr.Error(f"""Value {curr} is higher then maximum of {maxx}""")
457
 
458
 
459
+ def on_transform(transforms):
460
+ transform_states = {
461
+ h_models.TransformTypes.PAD.value: False,
462
+ h_models.TransformTypes.JITTER.value: False,
463
+ h_models.TransformTypes.RANDOM_SCALE.value: False,
464
+ h_models.TransformTypes.RANDOM_ROTATE.value: False,
465
+ h_models.TransformTypes.AD_JITTER.value: False
466
+ }
467
+ for transform in transforms:
468
+ transform_states[transform] = True
469
+
470
+ return [gr.update(visible=state) for state in transform_states.values()]
471
+
472
+
473
+ def on_pad_mode (evt: gr.SelectData):
474
+ if (evt.value == "Constant"):
475
+ return gr.update(visible=True)
476
+ return gr.update(visible=False)
477
  main()