brad committed
Commit • c1b01fa
Parent(s): 4c157d1

working transform interface

Files changed:
- helpers/listeners.py +268 -0
- main.py +40 -224
helpers/listeners.py
ADDED
@@ -0,0 +1,268 @@
import torch
import numpy as np
import gradio as gr
import helpers.models as h_models
import lucent.optvis.param as param
import lucent.optvis.transform as tr
import helpers.manipulation as h_manip
import lucent.optvis.objectives as objs

from torch import nn
from time import sleep
from lucent.optvis import render
from lucent.modelzoo.util import get_model_layers


# Event listener functions
def on_model(model, model_layers, ft_map_sizes, evt: gr.SelectData, progress=gr.Progress()):
    """
    Logic flow when a model is selected. Updates the model, the model layers, and the
    feature map sizes.

    :param model: Current model (object) selected. Updated by this method
    :param model_layers: List of model layers. Updated by this method
    :param ft_map_sizes: List of feature map sizes. Updated by this method
    :param evt: Event data from Dropdown selection
    :return: [Layer Dropdown Component, Model state, Model Layers state,
              Feature Map Sizes state]
    """
    progress(0, desc="Setting up model...")
    model = h_models.setup_model(h_models.ModelTypes[evt.value])
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()

    progress(0.25, desc="Getting layer names and details...")
    model_layers = list(get_model_layers(model,
                                         getLayerRepr=True).items())
    choices = [f"({k}): {v.split('(')[0]}" for k, v in model_layers]

    progress(0.5, desc="Getting layer objects...")
    for i in range(len(model_layers)):
        try:
            layer = h_models.get_layer_by_name(model, model_layers[i][0])
        except ValueError as e:
            gr.Error(e)

        model_layers[i] = (model_layers[i][0], layer)

    progress(0.75, desc="Getting feature map sizes...")
    ft_map_sizes = h_models.get_feature_map_sizes(model, [v for _, v in model_layers])
    progress(1, desc="Done")
    sleep(0.25)  # To allow for the progress animation, not good practice
    return [gr.update(choices=choices, value=''),
            model, model_layers, ft_map_sizes]


def on_layer(selected_layer, model_layers, ft_map_sizes, evt: gr.SelectData):
    """
    Logic flow when a layer is selected. Updates the max values of the
    layer-specific input fields.

    :param selected_layer: Currently selected layer, updated by this method.
    :param model_layers: All model layers
    :param ft_map_sizes: Feature map sizes for all conv layers
    :param evt: Event data from Dropdown selection
    :return: [Layer Text Component,
              Channel Number Component,
              Node X Number Component,
              Node Y Number Component,
              Selected layer state/variable,
              Channel max state/variable,
              NodeX max state/variable,
              NodeY max state/variable,
              Node max state/variable]
    """
    channel_max, nodeX_max, nodeY_max, node_max = -1, -1, -1, -1
    selected_layer = model_layers[evt.index]
    match type(selected_layer[1]):
        case nn.Conv2d:
            # Calculate maxes for conv-specific fields
            channel_max = selected_layer[1].out_channels - 1
            nodeX_max = ft_map_sizes[evt.index][1] - 1
            nodeY_max = ft_map_sizes[evt.index][2] - 1

            return [gr.update(visible=True),
                    gr.Number.update(info=f"""Values between 0-{channel_max}""",
                                     visible=True, value=None),
                    gr.Number.update(info=f"""Values between 0-{nodeX_max}""",
                                     visible=True, value=None),
                    gr.Number.update(info=f"""Values between 0-{nodeY_max}""",
                                     visible=True, value=None),
                    gr.update(visible=False, value=None),
                    selected_layer,
                    channel_max,
                    nodeX_max,
                    nodeY_max,
                    node_max]
        case nn.Linear:
            # Calculate maxes for linear-specific fields
            node_max = selected_layer[1].out_features - 1
            return [gr.update(visible=True),
                    gr.Number.update(visible=False, value=None),
                    gr.Number.update(visible=False, value=None),
                    gr.Number.update(visible=False, value=None),
                    gr.update(info=f"""Values between 0-{node_max}""",
                              maximum=node_max,
                              visible=True, value=None),
                    selected_layer,
                    channel_max,
                    nodeX_max,
                    nodeY_max,
                    node_max]
        case _:
            gr.Warning("Unknown layer type")
            return [gr.update(visible=False),
                    gr.update(visible=False, value=None),
                    gr.update(visible=False, value=None),
                    gr.update(visible=False, value=None),
                    gr.update(visible=False, value=None),
                    selected_layer,
                    channel_max,
                    nodeX_max,
                    nodeY_max,
                    node_max]


# Having this many inputs is typically not good practice
def generate(lr, epochs, img_size, channel, nodeX, nodeY, node, layer_sel,
             model, thresholds, chan_decor, spacial_decor, sd_num,
             transforms_selected, pad_num, pad_mode, constant_num, jitter_num,
             scale_num, rotate_num, ad_jitter_num,
             progress=gr.Progress(track_tqdm=True)):
    """
    Generates the feature visualization with the given parameters and tuning.
    Utilizes Lucent (the PyTorch Lucid library).

    Inputs are different Gradio components. Outputs an image component and
    the respective epoch numbers. Method tracks its own tqdm progress.
    """

    # Image setup
    def param_f(): return param.image(img_size,
                                      fft=spacial_decor,
                                      decorrelate=chan_decor,
                                      sd=sd_num)

    def optimizer(params): return torch.optim.Adam(params, lr=lr)

    # Transforms setup
    tr_states = {
        h_models.TransformTypes.PAD.value: None,
        h_models.TransformTypes.JITTER.value: None,
        h_models.TransformTypes.RANDOM_SCALE.value: None,
        h_models.TransformTypes.RANDOM_ROTATE.value: None,
        h_models.TransformTypes.AD_JITTER.value: None
    }

    for tr_sel in transforms_selected:
        match tr_sel:
            case h_models.TransformTypes.PAD.value:
                tr_states[tr_sel] = tr.pad(pad_num,
                                           mode="constant" if pad_mode == "Constant" else "reflect",
                                           constant_value=constant_num)
            case h_models.TransformTypes.JITTER.value:
                tr_states[tr_sel] = tr.jitter(jitter_num)
            case h_models.TransformTypes.RANDOM_SCALE.value:
                tr_states[tr_sel] = tr.random_scale([1.0 - scale_num + i * (scale_num*2/(51-1)) for i in range(51)])
            case h_models.TransformTypes.RANDOM_ROTATE.value:
                tr_states[tr_sel] = tr.random_rotate([0 - rotate_num + i for i in range(rotate_num*2+1)])
            case h_models.TransformTypes.AD_JITTER.value:
                tr_states[tr_sel] = tr.jitter(ad_jitter_num)

    transforms = [t for t in tr_states.values() if t is not None]

    # Specific layer type handling
    match type(layer_sel[1]):
        case nn.Conv2d:
            if (channel is not None and nodeX is not None and nodeY is not None):
                gr.Info("Convolutional Node Specific")
                obj = objs.neuron(layer_sel[0], channel, x=nodeX, y=nodeY)

            elif (channel is not None):
                gr.Info("Convolutional Channel Specific")
                obj = objs.channel(layer_sel[0], channel)

            elif (channel is None and nodeX is None and nodeY is None):
                gr.Info("Convolutional Layer Specific")
                obj = lambda m: torch.mean(torch.pow(-m(layer_sel[0]).cuda(),
                                                     torch.tensor(2).cuda())).cuda()

            # Unknown
            else:
                gr.Error("Invalid layer settings")
                return None

        case nn.Linear:
            if (node is not None):
                gr.Info("Linear Node Specific")
                obj = objs.channel(layer_sel[0], node)
            else:
                gr.Info("Linear Layer Specific")
                obj = lambda m: torch.mean(torch.pow(-m(layer_sel[0]).cuda(), torch.tensor(2).cuda())).cuda()
        case _:
            gr.Info("Attempting unknown Layer Specific")
            transforms = []  # Just in case
            obj = lambda m: torch.mean(torch.pow(-m(layer_sel[0]).cuda(), torch.tensor(2).cuda())).cuda()

    thresholds = h_manip.expo_tuple(epochs, 6)

    img = np.array(render.render_vis(model,
                                     obj,
                                     thresholds=thresholds,
                                     show_image=False,
                                     optimizer=optimizer,
                                     param_f=param_f,
                                     transforms=transforms,
                                     verbose=True)).squeeze(1)

    return gr.Gallery.update(img), thresholds


def update_img_label(epoch_nums, evt: gr.SelectData):
    """
    Updates the image label with its respective epoch number.

    :param epoch_nums: The epoch numbers
    :param evt: Event data from Gallery selection
    :return: Image Gallery Component
    """
    return gr.Gallery.update(label='Epoch ' + str(epoch_nums[evt.index]),
                             show_label=True)


def check_input(curr, maxx):
    """
    Checks whether the current input is higher than the max. Raises an error
    if so.

    :param curr: Current value
    :param maxx: Max value to check against
    """
    if curr > maxx:
        raise gr.Error(f"""Value {curr} is higher than the maximum of {maxx}""")


def on_transform(transforms):
    """
    Logic for when a transform is selected. Controls the visibility of the
    transform-specific inputs/settings.

    :param transforms: The transforms currently selected
    :return: Column Components with modified visibility
    """
    transform_states = {
        h_models.TransformTypes.PAD.value: False,
        h_models.TransformTypes.JITTER.value: False,
        h_models.TransformTypes.RANDOM_SCALE.value: False,
        h_models.TransformTypes.RANDOM_ROTATE.value: False,
        h_models.TransformTypes.AD_JITTER.value: False
    }
    for transform in transforms:
        transform_states[transform] = True

    return [gr.update(visible=state) for state in transform_states.values()]


def on_pad_mode(evt: gr.SelectData):
    """
    Hides the constant value input if the constant pad mode is not selected.

    :param evt: Event data from Radio selection
    """
    if (evt.value == "Constant"):
        return gr.update(visible=True)
    return gr.update(visible=False)
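A note on the random-scale transform above: the list comprehension handed to tr.random_scale builds 51 evenly spaced scale factors centred on 1.0. The following is a minimal sketch of that arithmetic only, not part of the commit; scale_num = 0.1 is simply the "Max scale" default set in main.py.

    import numpy as np

    scale_num = 0.1  # default "Max scale" value from main.py

    # The comprehension used in generate() above:
    scales = [1.0 - scale_num + i * (scale_num * 2 / (51 - 1)) for i in range(51)]

    # Equivalent even spacing from (1 - scale_num) to (1 + scale_num):
    assert np.allclose(scales, np.linspace(1.0 - scale_num, 1.0 + scale_num, 51))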
main.py
CHANGED
@@ -1,14 +1,6 @@
-import torch
-import numpy as np
 import gradio as gr
 import helpers.models as h_models
-import helpers.
-import lucent.optvis.param as param
-import lucent.optvis.objectives as objectives
-from torch import nn
-from time import sleep
-from lucent.optvis import render
-from lucent.modelzoo.util import get_model_layers
+import helpers.listeners as listeners
 
 # Custom css
 css = """div[data-testid="block-label"] {z-index: var(--layer-3)}"""
@@ -31,10 +23,12 @@ def main():
 Feature Visualizations (FV's) answer questions
 about what a network—or parts of a network—are
 looking for by generating examples.
-([Read more about it here](https://distill.pub/2017/feature-visualization/)
-
-
-
+([Read more about it here](https://distill.pub/2017/feature-visualization/)
+FVs are a part of a wider field called Explainable
+Artificial Intelligence (XAI). This generator aims
+to make it easier to explore different concepts
+used in FV generation and allow for experimentation.
+Currently Convolutional and Linear layers were tested.\n\n
 **Start by selecting a model from the drop down.**""")
 with gr.Row(): # Lower inputs and outputs
 with gr.Column(): # Inputs
@@ -103,27 +97,33 @@ def main():
 with gr.Accordion("Advanced Settings", open=False):
 with gr.Column(variant="panel"):
 gr.Markdown("""## Image Settings""")
+
 img_num = gr.Number(label="Image Size",
 info="Image is square (<value> by <value>)",
 precision=0,
 minimum=1,
 value=227)
+
 chan_decor_ck = gr.Checkbox(label="Channel Decorrelation",
 info="Reduces channel-to-channel correlations",
 value=True)
+
 spacial_decor_ck = gr.Checkbox(label="Spacial Decorrelation (FFT)",
 info="Reduces pixel-to-pixel correlations",
 value=True)
+
 sd_num = gr.Number(label="Standard Deviation",
 info="The STD of the randomly generated starter image",
 value=0.01)
 
 with gr.Column(variant="panel"):
 gr.Markdown("""## Transform Settings (WIP)""")
+
 preprocess_ck = gr.Checkbox(label="Preprocess",
 info="Enable or disable preprocessing via transformations",
 value=True,
 interactive=True)
+
 transform_choices = [t.value for t in h_models.TransformTypes]
 transforms_dd = gr.Dropdown(label="Applied Transforms",
 info="Transforms to apply",
@@ -131,7 +131,7 @@ def main():
 multiselect=True,
 value=transform_choices,
 interactive=True)
-
+
 # Transform specific settings
 pad_col = gr.Column()
 with pad_col:
@@ -143,11 +143,13 @@ def main():
 value=12,
 precision=0,
 interactive=True)
+
 mode_rad = gr.Radio(label="Mode",
 info="Constant fills padded pixels with a value. Reflect fills with edge pixels",
 choices=["Constant", "Reflect"],
 value="Constant",
 interactive=True)
+
 constant_num = gr.Number(label="Constant Fill Value",
 info="Value to fill padded pixels",
 value=0.5,
@@ -161,6 +163,7 @@ def main():
 info="How much to jitter image by",
 minimum=1,
 value=8,
+precision=0,
 interactive=True)
 
 rand_scale_col = gr.Column()
@@ -168,9 +171,9 @@ def main():
 gr.Markdown("""### Random Scale Settings""")
 with gr.Row():
 scale_num = gr.Number(label="Max scale",
-info="How much to scale in both directions (+ and -)",
+info="How much to scale (from 1.0) in both directions (+ and -)",
 minimum=0,
-value=
+value=0.1,
 interactive=True)
 
 rand_rotate_col = gr.Column()
@@ -181,6 +184,7 @@ def main():
 info="How much to rotate in both directions (+ and -)",
 minimum=0,
 value=10,
+precision=0,
 interactive=True)
 
 ad_jitter_col = gr.Column()
@@ -191,10 +195,8 @@ def main():
 info="How much to jitter image by",
 minimum=1,
 value=4,
+precision=0,
 interactive=True)
-
-
-
 
 confirm_btn = gr.Button("Generate", visible=False)
 
@@ -208,14 +210,14 @@ def main():
 # Event listener binding
 model_dd.select(lambda: gr.Dropdown.update(visible=True),
 outputs=layer_dd)
-model_dd.select(on_model,
+model_dd.select(listeners.on_model,
 inputs=[model, model_layers, ft_map_sizes],
 outputs=[layer_dd, model, model_layers, ft_map_sizes])
 
 # TODO: Make button invisible always until layer selection
 layer_dd.select(lambda: gr.Button.update(visible=True),
 outputs=confirm_btn)
-layer_dd.select(on_layer,
+layer_dd.select(listeners.on_layer,
 inputs=[selected_layer, model_layers, ft_map_sizes],
 outputs=[layer_text,
 channel_num,
@@ -228,12 +230,12 @@ def main():
 nodeY_max,
 node_max])
 
-channel_num.blur(check_input, inputs=[channel_num, channel_max])
-nodeX_num.blur(check_input, inputs=[nodeX_num, nodeX_max])
-nodeY_num.blur(check_input, inputs=[nodeY_num, nodeY_max])
-node_num.blur(check_input, inputs=[node_num, node_max])
+channel_num.blur(listeners.check_input, inputs=[channel_num, channel_max])
+nodeX_num.blur(listeners.check_input, inputs=[nodeX_num, nodeX_max])
+nodeY_num.blur(listeners.check_input, inputs=[nodeY_num, nodeY_max])
+node_num.blur(listeners.check_input, inputs=[node_num, node_max])
 
-images_gal.select(update_img_label,
+images_gal.select(listeners.update_img_label,
 inputs=thresholds,
 outputs=images_gal)
 
@@ -251,7 +253,7 @@ def main():
 rand_rotate_col,
 ad_jitter_col])
 
-transforms_dd.change(on_transform,
+transforms_dd.change(listeners.on_transform,
 inputs=transforms_dd,
 outputs=[pad_col,
 jitter_col,
@@ -259,10 +261,10 @@ def main():
 rand_rotate_col,
 ad_jitter_col])
 
-mode_rad.select(on_pad_mode,
+mode_rad.select(listeners.on_pad_mode,
 outputs=constant_num)
 
-confirm_btn.click(generate,
+confirm_btn.click(listeners.generate,
 inputs=[lr_sl,
 epoch_num,
 img_num,
@@ -275,203 +277,17 @@ def main():
 thresholds,
 chan_decor_ck,
 spacial_decor_ck,
-sd_num
+sd_num,
+transforms_dd,
+pad_num,
+mode_rad,
+constant_num,
+jitter_num,
+scale_num,
+rotate_num,
+ad_jitter_num],
 outputs=[images_gal, thresholds])
 
 demo.queue().launch()
 
-
-# Event listener functions
-def on_model(model, model_layers, ft_map_sizes, evt: gr.SelectData, progress=gr.Progress()):
-    """
-    Logic flow when model is selected. Updates model, the model layers, and the
-    feature map sizes.
-    :param model: Current model (object) selected. Updated by this method
-    :param model_layers: List of model layers. Updated by this method
-    :param ft_map_sizes: List of Feature map sizes. Updated by this method
-    :param evt: Event data from Dropdown selection
-    :return: [Layer Dropdown Component, Model state, Model Layers state,
-              Feature Map Sizes State]
-    """
-    progress(0, desc="Setting up model...")
-    model = h_models.setup_model(h_models.ModelTypes[evt.value])
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    model.to(device).eval()
-
-    progress(0.25, desc="Getting layers names and details...")
-    model_layers = list(get_model_layers(model,
-                                         getLayerRepr=True).items())
-    choices = [f"({k}): {v.split('(')[0]}" for k, v in model_layers]
-
-
-    progress(0.5, desc="Getting layer objects...")
-    for i in range(len(model_layers)):
-        try:
-            layer = h_models.get_layer_by_name(model, model_layers[i][0])
-        except ValueError as e:
-            gr.Error(e)
-
-        model_layers[i] = (model_layers[i][0], layer)
-
-    progress(0.75, desc="Getting feature maps sizes...")
-    ft_map_sizes = h_models.get_feature_map_sizes(model, [v for _, v in model_layers])
-    progress(1, desc="Done")
-    sleep(0.25) # To allow for progress animation, not good practice
-    return [gr.update(choices=choices, value=''),
-            model, model_layers, ft_map_sizes]
-
-
-def on_layer(selected_layer, model_layers, ft_map_sizes, evt: gr.SelectData):
-    """
-    Logic flow when a layer is selected. Updates max values of layer
-    specific input fields.
-    :param selected_layer: Current selected layer, updated by this method.
-    :param model_layers: All model layers
-    :param ft_map_sizes: Feature maps sizes for all conv layers
-    :param evt: Event data from Dropdown selection
-    :return [Layer Text Component,
-             Channel Number Component,
-             Node X Number Component,
-             Node Y Number Component,
-             Selected layer state/variable]
-    """
-    channel_max, nodeX_max, nodeY_max, node_max = -1, -1, -1, -1
-    selected_layer = model_layers[evt.index]
-    match type(selected_layer[1]):
-        case nn.Conv2d:
-            channel_max = selected_layer[1].out_channels-1
-            nodeX_max = ft_map_sizes[evt.index][1]-1
-            nodeY_max = ft_map_sizes[evt.index][2]-1
-
-            return [gr.update(visible=True),
-                    gr.Number.update(info=f"""Values between 0-{channel_max}""",
-                                     visible=True, value=None),
-                    gr.Number.update(info=f"""Values between 0-{nodeX_max}""",
-                                     visible=True, value=None),
-                    gr.Number.update(info=f"""Values between 0-{nodeY_max}""",
-                                     visible=True, value=None),
-                    gr.update(visible=False, value=None),
-                    selected_layer,
-                    channel_max,
-                    nodeX_max,
-                    nodeY_max,
-                    node_max]
-        case nn.Linear:
-            node_max = selected_layer[1].out_features-1
-            return [gr.update(visible=True),
-                    gr.Number.update(visible=False, value=None),
-                    gr.Number.update(visible=False, value=None),
-                    gr.Number.update(visible=False, value=None),
-                    gr.update(info=f"""Values between 0-{node_max}""",
-                              maximum=node_max,
-                              visible=True, value=None),
-                    selected_layer,
-                    channel_max,
-                    nodeX_max,
-                    nodeY_max,
-                    node_max]
-        case _:
-            gr.Warning("Unknown layer type")
-            return [gr.update(visible=False),
-                    gr.update(visible=False, value=None),
-                    gr.update(visible=False, value=None),
-                    gr.update(visible=False, value=None),
-                    gr.update(visible=False, value=None),
-                    selected_layer,
-                    channel_max,
-                    nodeX_max,
-                    nodeY_max,
-                    node_max]
-
-
-def generate(lr, epochs, img_size, channel, nodeX, nodeY, node, selected_layer,
-             model, thresholds, chan_decor, spacial_decor,
-             sd_num, progress=gr.Progress(track_tqdm=True)):
-    """
-    Generates the feature visualizaiton with given parameters and tuning.
-    Utilizes the Lucent (Pytorch Lucid library).
-
-    Inputs are different gradio components. Outputs an image component. Method
-    tracks its own tqdm progress.
-    """
-
-    def param_f(): return param.image(img_size,
-                                      fft=spacial_decor,
-                                      decorrelate=chan_decor,
-                                      sd=sd_num) # Image setup
-    def optimizer(params): return torch.optim.Adam(params, lr=lr)
-
-    # Specific layer type handling
-    match type(selected_layer[1]):
-        case nn.Conv2d:
-            # Node specific
-            if (channel is not None and nodeX is not None and nodeY is not None):
-                gr.Info("Node Specific Convolution")
-                obj = objectives.neuron(selected_layer[0],
-                                        channel,
-                                        x=nodeX,
-                                        y=nodeY)
-
-            # Channel specific
-            elif (channel is not None):
-                gr.Info("Channel Specific Convolution")
-                obj = objectives.channel(selected_layer[0], channel)
-
-            # Layer specific
-            elif (channel is None and nodeX is None and nodeY is None):
-                gr.Info("Layer Specific Convolution")
-                obj = lambda m: torch.mean(torch.pow(-m(selected_layer[0]).cuda(), torch.tensor(2).cuda())).cuda()
-
-            # Unknown
-            else:
-                gr.Error("Invalid layer settings")
-                return None
-
-        case nn.Linear:
-            if (node is not None): # Node Specific
-                obj = objectives.channel(selected_layer[0], node)
-            else: # Layer Specific
-                obj = lambda m: torch.mean(torch.pow(-m(selected_layer[0]).cuda(), torch.tensor(2).cuda())).cuda()
-    thresholds = h_manip.expo_tuple(epochs, 6)
-    print(thresholds)
-
-
-    img = np.array(render.render_vis(model,
-                                     obj,
-                                     thresholds=thresholds,
-                                     show_image=False,
-                                     optimizer=optimizer,
-                                     param_f=param_f,
-                                     verbose=True)).squeeze(1)
-
-    return gr.Gallery.update(img), thresholds
-
-
-def update_img_label(thresholds, evt: gr.SelectData):
-    return gr.Gallery.update(label='Epoch ' + str(thresholds[evt.index]), show_label=True)
-
-
-def check_input(curr, maxx):
-    if curr > maxx:
-        raise gr.Error(f"""Value {curr} is higher then maximum of {maxx}""")
-
-
-def on_transform(transforms):
-    transform_states = {
-        h_models.TransformTypes.PAD.value: False,
-        h_models.TransformTypes.JITTER.value: False,
-        h_models.TransformTypes.RANDOM_SCALE.value: False,
-        h_models.TransformTypes.RANDOM_ROTATE.value: False,
-        h_models.TransformTypes.AD_JITTER.value: False
-    }
-    for transform in transforms:
-        transform_states[transform] = True
-
-    return [gr.update(visible=state) for state in transform_states.values()]
-
-
-def on_pad_mode (evt: gr.SelectData):
-    if (evt.value == "Constant"):
-        return gr.update(visible=True)
-    return gr.update(visible=False)
 main()
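The structural change in this commit is that main.py no longer defines its callbacks inline; it imports them from helpers/listeners.py and passes them by reference to the Gradio event bindings. Below is a minimal, self-contained sketch of that pattern only, with hypothetical component names rather than the app's actual layout, using the same Gradio 3.x calls the Space relies on.

    import gradio as gr

    # In the app, functions like this live in helpers/listeners.py.
    def on_pick(evt: gr.SelectData):
        # evt.value carries the selected dropdown entry
        return gr.update(value=f"You picked {evt.value}", visible=True)

    with gr.Blocks() as demo:
        dd = gr.Dropdown(choices=["AlexNet", "VGG"], label="Model")  # hypothetical choices
        out = gr.Textbox(visible=False)
        # Bind by reference, the same way main.py binds listeners.on_model, listeners.generate, etc.
        dd.select(on_pick, outputs=out)

    demo.queue().launch()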