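"""
Event-listener functions for a Gradio Space that renders feature
visualizations of model layers, channels, and neurons with Lucent
(the PyTorch port of Lucid).
"""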
import torch
import numpy as np
import gradio as gr
import helpers.models as h_models
import lucent.optvis.param as param
import lucent.optvis.transform as tr
import helpers.manipulation as h_manip
import lucent.optvis.objectives as objs
from torch import nn
from time import sleep
from lucent.optvis import render
from lucent.modelzoo.util import get_model_layers
# Event listener functions
def on_model(model, model_layers, ft_map_sizes, evt: gr.SelectData, progress=gr.Progress()):
"""
Logic flow when model is selected. Updates model, the model layers, and the
feature map sizes.
:param model: Current model (object) selected. Updated by this method
:param model_layers: List of model layers. Updated by this method
:param ft_map_sizes: List of Feature map sizes. Updated by this method
:param evt: Event data from Dropdown selection
:return: [Layer Dropdown Component, Model state, Model Layers state,
Feature Map Sizes State]
"""
progress(0, desc="Setting up model...")
model = h_models.setup_model(h_models.ModelTypes[evt.value])
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device).eval()
    progress(0.25, desc="Getting layer names and details...")
model_layers = list(get_model_layers(model,
getLayerRepr=True).items())
choices = [f"({k}): {v.split('(')[0]}" for k, v in model_layers]
progress(0.5, desc="Getting layer objects...")
for i in range(len(model_layers)):
try:
layer = h_models.get_layer_by_name(model, model_layers[i][0])
        except ValueError as e:
            # gr.Error is an exception; it must be raised to surface in the UI
            raise gr.Error(str(e))
        model_layers[i] = (model_layers[i][0], layer)
    progress(0.75, desc="Getting feature map sizes...")
ft_map_sizes = h_models.get_feature_map_sizes(model, [v for _, v in model_layers])
progress(1, desc="Done")
sleep(0.25) # To allow for progress animation, not good practice
return [gr.update(choices=choices, value=''),
model, model_layers, ft_map_sizes]
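# Example wiring for on_model (a sketch; component and state names are
# hypothetical, defined in the main app file):
#   model_dd.select(on_model,
#                   inputs=[model, model_layers, ft_map_sizes],
#                   outputs=[layer_dd, model, model_layers, ft_map_sizes])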
def on_layer(selected_layer, model_layers, ft_map_sizes, evt: gr.SelectData):
"""
Logic flow when a layer is selected. Updates max values of layer
specific input fields.
:param selected_layer: Current selected layer, updated by this method.
:param model_layers: All model layers
:param ft_map_sizes: Feature maps sizes for all conv layers
:param evt: Event data from Dropdown selection
    :return: [Layer Text Component,
Channel Number Component,
Node X Number Component,
Node Y Number Component,
Selected layer state/variable,
Channel max state/variable,
NodeX max state/variable,
NodeY max state/variable,
Node max state/variable]
"""
channel_max, nodeX_max, nodeY_max, node_max = -1, -1, -1, -1
selected_layer = model_layers[evt.index]
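    # Dotted names in case patterns are value patterns: each case compares
    # type(selected_layer[1]) against the class object with ==.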
match type(selected_layer[1]):
case nn.Conv2d:
# Calculate maxes for conv specific
channel_max = selected_layer[1].out_channels-1
nodeX_max = ft_map_sizes[evt.index][1]-1
nodeY_max = ft_map_sizes[evt.index][2]-1
return [gr.update(visible=True),
gr.Number.update(info=f"""Values between 0-{channel_max}""",
visible=True, value=None),
gr.Number.update(info=f"""Values between 0-{nodeX_max}""",
visible=True, value=None),
gr.Number.update(info=f"""Values between 0-{nodeY_max}""",
visible=True, value=None),
gr.update(visible=False, value=None),
selected_layer,
channel_max,
nodeX_max,
nodeY_max,
node_max]
case nn.Linear:
# Calculate maxes for linear specific
node_max = selected_layer[1].out_features-1
return [gr.update(visible=True),
gr.Number.update(visible=False, value=None),
gr.Number.update(visible=False, value=None),
gr.Number.update(visible=False, value=None),
gr.update(info=f"""Values between 0-{node_max}""",
maximum=node_max,
visible=True, value=None),
selected_layer,
channel_max,
nodeX_max,
nodeY_max,
node_max]
case _:
gr.Warning("Unknown layer type")
return [gr.update(visible=False),
gr.update(visible=False, value=None),
gr.update(visible=False, value=None),
gr.update(visible=False, value=None),
gr.update(visible=False, value=None),
selected_layer,
channel_max,
nodeX_max,
nodeY_max,
node_max]
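# Example wiring for on_layer (a sketch; hypothetical names), mapping the ten
# return values above to their components and gr.State holders:
#   layer_dd.select(on_layer,
#                   inputs=[selected_layer, model_layers, ft_map_sizes],
#                   outputs=[layer_text, channel_num, nodeX_num, nodeY_num,
#                            node_num, selected_layer, channel_max, nodeX_max,
#                            nodeY_max, node_max])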
# Having this many inputs is typically not good practice
def generate(lr, epochs, img_size, channel, nodeX, nodeY, node, layer_sel,
model, thresholds, chan_decor, spacial_decor, sd_num,
transforms_selected, pad_num, pad_mode, constant_num, jitter_num,
scale_num, rotate_num, ad_jitter_num,
progress=gr.Progress(track_tqdm=True)):
"""
    Generates the feature visualization with the given parameters and tuning.
    Utilizes Lucent (the PyTorch port of the Lucid library).
    Inputs are different Gradio components. Outputs an image gallery component
    and the respective epoch numbers. The method tracks its own tqdm progress.
"""
# Image setup
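    # param.image parameterizes the image in Fourier space when fft=True
    # (spatial decorrelation) and in a decorrelated color space when
    # decorrelate=True; sd sets the standard deviation of the random init.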
    def param_f():
        return param.image(img_size,
                           fft=spacial_decor,
                           decorrelate=chan_decor,
                           sd=sd_num)
    def optimizer(params):
        return torch.optim.Adam(params, lr=lr)
    # Transforms setup
tr_states = {
h_models.TransformTypes.PAD.value: None,
h_models.TransformTypes.JITTER.value: None,
h_models.TransformTypes.RANDOM_SCALE.value: None,
h_models.TransformTypes.RANDOM_ROTATE.value: None,
h_models.TransformTypes.AD_JITTER.value: None
}
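    # Dicts preserve insertion order (Python 3.7+), so the selected transforms
    # are always applied in the canonical order defined above, regardless of
    # the order they were checked in the UI.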
for tr_sel in transforms_selected:
match tr_sel:
case h_models.TransformTypes.PAD.value:
                tr_states[tr_sel] = tr.pad(pad_num,
                                           mode="constant" if pad_mode == "Constant" else "reflect",
                                           constant_value=constant_num)
case h_models.TransformTypes.JITTER.value:
tr_states[tr_sel] = tr.jitter(jitter_num)
            case h_models.TransformTypes.RANDOM_SCALE.value:
                # 51 evenly spaced scale factors in [1 - scale_num, 1 + scale_num]
                tr_states[tr_sel] = tr.random_scale(
                    [1.0 - scale_num + i * (2 * scale_num / 50) for i in range(51)])
            case h_models.TransformTypes.RANDOM_ROTATE.value:
                # Integer angles from -rotate_num to +rotate_num degrees
                tr_states[tr_sel] = tr.random_rotate(
                    list(range(-rotate_num, rotate_num + 1)))
case h_models.TransformTypes.AD_JITTER.value:
tr_states[tr_sel] = tr.jitter(ad_jitter_num)
transforms = [t for t in tr_states.values() if t is not None]
# Specific layer type handling
match type(layer_sel[1]):
case nn.Conv2d:
if (channel is not None and nodeX is not None and nodeY is not None):
gr.Info("Convolutional Node Specific")
obj = objs.neuron(layer_sel[0], channel, x=nodeX, y=nodeY)
elif (channel is not None):
gr.Info("Convolutional Channel Specific ")
obj = objs.channel(layer_sel[0], channel)
elif (channel is None and nodeX is None and nodeY is None):
gr.Info("Convolutional Layer Specific")
                # Deepdream-style layer objective: maximize the mean squared
                # activation. render_vis minimizes the objective, hence the
                # leading negation. Activations already live on the model's
                # device, so no explicit .cuda() handling is needed.
                obj = lambda m: -torch.mean(torch.pow(m(layer_sel[0]), 2))
            # Partial/invalid input combination
            else:
                # gr.Error is an exception; raise it to surface in the UI
                raise gr.Error("Invalid layer settings")
case nn.Linear:
if (node is not None):
gr.Info("Linear Node Specific")
obj = objs.channel(layer_sel[0], node)
else:
gr.Info("Linear Layer Specific")
                # Same deepdream-style layer objective as the Conv2d case
                obj = lambda m: -torch.mean(torch.pow(m(layer_sel[0]), 2))
case _:
gr.Info("Attempting unknown Layer Specific")
transforms = [] # Just in case
if torch.cuda.is_available():
obj = lambda m: torch.mean(torch.pow(-m(layer_sel[0]).cuda(),
torch.tensor(2).cuda())).cuda()
else:
obj = lambda m: torch.mean(torch.pow(-m(layer_sel[0]),
torch.tensor(2)))
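    # h_manip.expo_tuple is assumed (from its name and usage) to return about
    # 6 exponentially spaced checkpoint epochs up to `epochs`, overriding the
    # thresholds passed in from the UI.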
thresholds = h_manip.expo_tuple(epochs, 6)
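    # render_vis returns one (batch, H, W, C) float array per threshold;
    # squeeze(1) drops the singleton batch axis so the Gallery gets HWC images.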
img = np.array(render.render_vis(model,
obj,
thresholds=thresholds,
show_image=False,
optimizer=optimizer,
param_f=param_f,
transforms=transforms,
verbose=True)).squeeze(1)
return gr.Gallery.update(img), thresholds
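# Example wiring for generate (a sketch; hypothetical names): the generate
# button passes every tuning component plus the gr.State holders, matching
# the long signature above:
#   generate_btn.click(generate,
#                      inputs=[lr_num, epochs_num, img_size_num, ...],
#                      outputs=[gallery, epoch_nums_state])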
def update_img_label(epoch_nums, evt: gr.SelectData):
"""
Updates the image label with its respective epoch number.
:param epoch_nums: The epoch numbers
:param evt: Event data from Gallery selection
:return: Image Gallery Component
"""
return gr.Gallery.update(label='Epoch ' + str(epoch_nums[evt.index]),
show_label=True)
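# Example wiring for update_img_label (a sketch; hypothetical names):
#   gallery.select(update_img_label, inputs=epoch_nums_state, outputs=gallery)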
def check_input(curr, maxx):
"""
    Checks if the current input is higher than the max. Raises an error
    if so.
:param curr: Current value
:param maxx: Max value to check against
"""
if curr > maxx:
raise gr.Error(f"""Value {curr} is higher then maximum of {maxx}""")
def on_transform(transforms):
"""
    Logic for when a transform is selected. Controls the visibility of the
transform specific inputs/settings.
:param transforms: The transforms currently selected
:return: Column Components with modified visibility
"""
transform_states = {
h_models.TransformTypes.PAD.value: False,
h_models.TransformTypes.JITTER.value: False,
h_models.TransformTypes.RANDOM_SCALE.value: False,
h_models.TransformTypes.RANDOM_ROTATE.value: False,
h_models.TransformTypes.AD_JITTER.value: False
}
for transform in transforms:
transform_states[transform] = True
return [gr.update(visible=state) for state in transform_states.values()]
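# Example wiring for on_transform (a sketch; hypothetical names): one settings
# Column per transform, in the same order as the dict above:
#   transforms_cbg.change(on_transform, inputs=transforms_cbg,
#                         outputs=[pad_col, jitter_col, scale_col,
#                                  rotate_col, ad_jitter_col])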
def on_pad_mode(evt: gr.SelectData):
"""
    Hides the constant value input if the "Constant" pad mode is not selected.
    :param evt: Event data from Radio selection
    :return: Constant value Number component with modified visibility
"""
    if evt.value == "Constant":
        return gr.update(visible=True)
    return gr.update(visible=False)
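# Example wiring for on_pad_mode (a sketch; hypothetical names):
#   pad_mode_radio.select(on_pad_mode, outputs=constant_num)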