# helpers/listeners.py (commit 2e25210, author: baxtrax)
import torch
import numpy as np
import gradio as gr
import helpers.models as h_models
import lucent.optvis.param as param
import lucent.optvis.transform as tr
import helpers.manipulation as h_manip
import lucent.optvis.objectives as objs
from torch import nn
from time import sleep
from lucent.optvis import render
from lucent.modelzoo.util import get_model_layers
# Event listener functions
def on_model(model, model_layers, ft_map_sizes, evt: gr.SelectData, progress=gr.Progress()):
    """
    Logic flow when model is selected. Updates model, the model layers, and the
    feature map sizes.

    :param model: Current model (object) selected. Updated by this method
    :param model_layers: List of model layers. Updated by this method
    :param ft_map_sizes: List of Feature map sizes. Updated by this method
    :param evt: Event data from Dropdown selection
    :param progress: Gradio progress tracker for the UI progress bar
    :return: [Layer Dropdown Component, Model state, Model Layers state,
             Feature Map Sizes State]
    :raises gr.Error: If a layer name cannot be resolved to a layer object
    """
    progress(0, desc="Setting up model...")
    model = h_models.setup_model(h_models.ModelTypes[evt.value])

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()

    progress(0.25, desc="Getting layers names and details...")
    model_layers = list(get_model_layers(model,
                                         getLayerRepr=True).items())
    # Dropdown labels like "(features_0): Conv2d"
    choices = [f"({k}): {v.split('(')[0]}" for k, v in model_layers]

    progress(0.5, desc="Getting layer objects...")
    # Replace each repr string with the actual layer object
    for i, (layer_name, _) in enumerate(model_layers):
        try:
            layer = h_models.get_layer_by_name(model, layer_name)
        except ValueError as e:
            # Must be raised: a bare gr.Error(...) instantiation is silently
            # discarded and `layer` would be unbound on the next line.
            raise gr.Error(str(e)) from e
        model_layers[i] = (layer_name, layer)

    progress(0.75, desc="Getting feature maps sizes...")
    ft_map_sizes = h_models.get_feature_map_sizes(model, [v for _, v in model_layers])
    progress(1, desc="Done")
    sleep(0.25)  # To allow for progress animation, not good practice
    return [gr.update(choices=choices, value=''),
            model, model_layers, ft_map_sizes]
def on_layer(selected_layer, model_layers, ft_map_sizes, evt: gr.SelectData):
    """
    Logic flow when a layer is selected. Updates max values of layer
    specific input fields.

    :param selected_layer: Current selected layer, updated by this method.
    :param model_layers: All model layers as (name, layer object) tuples
    :param ft_map_sizes: Feature maps sizes for all conv layers
    :param evt: Event data from Dropdown selection
    :return [Layer Text Component,
             Channel Number Component,
             Node X Number Component,
             Node Y Number Component,
             Node Number Component,
             Selected layer state/variable,
             Channel max state/variable,
             NodeX max state/variable,
             NodeY max state/variable,
             Node max state/variable]
    """
    # -1 marks a max as "not applicable" for the selected layer type
    channel_max, nodeX_max, nodeY_max, node_max = -1, -1, -1, -1
    selected_layer = model_layers[evt.index]

    # Value patterns: each case compares type(...) == nn.Conv2d / nn.Linear
    match type(selected_layer[1]):
        case nn.Conv2d:
            # Calculate maxes for conv specific
            channel_max = selected_layer[1].out_channels-1
            # ft_map_sizes entries are indexed (channels, height, width);
            # [1]/[2] are the spatial extents of the feature map
            nodeX_max = ft_map_sizes[evt.index][1]-1
            nodeY_max = ft_map_sizes[evt.index][2]-1
            # Show channel / X / Y inputs, hide the linear node input
            return [gr.update(visible=True),
                    gr.Number.update(info=f"""Values between 0-{channel_max}""",
                                     visible=True, value=None),
                    gr.Number.update(info=f"""Values between 0-{nodeX_max}""",
                                     visible=True, value=None),
                    gr.Number.update(info=f"""Values between 0-{nodeY_max}""",
                                     visible=True, value=None),
                    gr.update(visible=False, value=None),
                    selected_layer,
                    channel_max,
                    nodeX_max,
                    nodeY_max,
                    node_max]
        case nn.Linear:
            # Calculate maxes for linear specific
            node_max = selected_layer[1].out_features-1
            # Hide conv-specific inputs, show only the node input
            return [gr.update(visible=True),
                    gr.Number.update(visible=False, value=None),
                    gr.Number.update(visible=False, value=None),
                    gr.Number.update(visible=False, value=None),
                    gr.update(info=f"""Values between 0-{node_max}""",
                              maximum=node_max,
                              visible=True, value=None),
                    selected_layer,
                    channel_max,
                    nodeX_max,
                    nodeY_max,
                    node_max]
        case _:
            # Unsupported layer type: hide everything, keep maxes at -1
            gr.Warning("Unknown layer type")
            return [gr.update(visible=False),
                    gr.update(visible=False, value=None),
                    gr.update(visible=False, value=None),
                    gr.update(visible=False, value=None),
                    gr.update(visible=False, value=None),
                    selected_layer,
                    channel_max,
                    nodeX_max,
                    nodeY_max,
                    node_max]
# Having this many inputs is typically not good practice
def generate(lr, epochs, img_size, channel, nodeX, nodeY, node, layer_sel,
model, thresholds, chan_decor, spacial_decor, sd_num,
transforms_selected, pad_num, pad_mode, constant_num, jitter_num,
scale_num, rotate_num, ad_jitter_num,
progress=gr.Progress(track_tqdm=True)):
"""
Generates the feature visualizaiton with given parameters and tuning.
Utilizes the Lucent (Pytorch Lucid library).
Inputs are different gradio components. Outputs an image component, and
their respective epoch numbers. Method tracks its own tqdm progress.
"""
# Image setup
def param_f(): return param.image(img_size,
fft=spacial_decor,
decorrelate=chan_decor,
sd=sd_num)
def optimizer(params): return torch.optim.Adam(params, lr=lr)
# Tranforms setup
tr_states = {
h_models.TransformTypes.PAD.value: None,
h_models.TransformTypes.JITTER.value: None,
h_models.TransformTypes.RANDOM_SCALE.value: None,
h_models.TransformTypes.RANDOM_ROTATE.value: None,
h_models.TransformTypes.AD_JITTER.value: None
}
for tr_sel in transforms_selected:
match tr_sel:
case h_models.TransformTypes.PAD.value:
tr_states[tr_sel] = tr.pad(pad_num,
mode = "constant" if pad_mode == "Constant" else "reflect",
constant_value=constant_num)
case h_models.TransformTypes.JITTER.value:
tr_states[tr_sel] = tr.jitter(jitter_num)
case h_models.TransformTypes.RANDOM_SCALE.value:
tr_states[tr_sel] = tr.random_scale([1.0 - scale_num + i * (scale_num*2/(51-1)) for i in range(51)])
case h_models.TransformTypes.RANDOM_ROTATE.value:
tr_states[tr_sel] = tr.random_rotate([0 - rotate_num + i for i in range(rotate_num*2+1)])
case h_models.TransformTypes.AD_JITTER.value:
tr_states[tr_sel] = tr.jitter(ad_jitter_num)
transforms = [t for t in tr_states.values() if t is not None]
# Specific layer type handling
match type(layer_sel[1]):
case nn.Conv2d:
if (channel is not None and nodeX is not None and nodeY is not None):
gr.Info("Convolutional Node Specific")
obj = objs.neuron(layer_sel[0], channel, x=nodeX, y=nodeY)
elif (channel is not None):
gr.Info("Convolutional Channel Specific ")
obj = objs.channel(layer_sel[0], channel)
elif (channel is None and nodeX is None and nodeY is None):
gr.Info("Convolutional Layer Specific")
if torch.cuda.is_available():
obj = lambda m: torch.mean(torch.pow(-m(layer_sel[0]).cuda(),
torch.tensor(2).cuda())).cuda()
else:
obj = lambda m: torch.mean(torch.pow(-m(layer_sel[0]),
torch.tensor(2)))
# Unknown
else:
gr.Error("Invalid layer settings")
return None
case nn.Linear:
if (node is not None):
gr.Info("Linear Node Specific")
obj = objs.channel(layer_sel[0], node)
else:
gr.Info("Linear Layer Specific")
if torch.cuda.is_available():
obj = lambda m: torch.mean(torch.pow(-m(layer_sel[0]).cuda(),
torch.tensor(2).cuda())).cuda()
else:
obj = lambda m: torch.mean(torch.pow(-m(layer_sel[0]),
torch.tensor(2)))
case _:
gr.Info("Attempting unknown Layer Specific")
transforms = [] # Just in case
if torch.cuda.is_available():
obj = lambda m: torch.mean(torch.pow(-m(layer_sel[0]).cuda(),
torch.tensor(2).cuda())).cuda()
else:
obj = lambda m: torch.mean(torch.pow(-m(layer_sel[0]),
torch.tensor(2)))
thresholds = h_manip.expo_tuple(epochs, 6)
img = np.array(render.render_vis(model,
obj,
thresholds=thresholds,
show_image=False,
optimizer=optimizer,
param_f=param_f,
transforms=transforms,
verbose=True)).squeeze(1)
return gr.Gallery.update(img), thresholds
def update_img_label(epoch_nums, evt: gr.SelectData):
    """
    Updates the image label with its respective epoch number.

    :param epoch_nums: The epoch numbers
    :param evt: Event data from Gallery selection
    :return: Image Gallery Component
    """
    epoch_label = 'Epoch ' + str(epoch_nums[evt.index])
    return gr.Gallery.update(label=epoch_label, show_label=True)
def check_input(curr, maxx):
    """
    Checks whether the current input is higher than the max. Raises an error
    if so.

    :param curr: Current value
    :param maxx: Max value to check against
    :raises gr.Error: If curr exceeds maxx
    """
    if curr > maxx:
        # "then" -> "than": this message is shown directly to the user
        raise gr.Error(f"""Value {curr} is higher than maximum of {maxx}""")
def on_transform(transforms):
    """
    Logic for when a transform is selected. Controls the visbility of the
    transform specific inputs/settings.

    :param transforms: The transforms currently selected
    :return: Column Components with modified visibility
    """
    # Every known transform starts hidden; selection flips it visible.
    visibility = dict.fromkeys(
        (h_models.TransformTypes.PAD.value,
         h_models.TransformTypes.JITTER.value,
         h_models.TransformTypes.RANDOM_SCALE.value,
         h_models.TransformTypes.RANDOM_ROTATE.value,
         h_models.TransformTypes.AD_JITTER.value),
        False,
    )
    for chosen in transforms:
        visibility[chosen] = True
    return [gr.update(visible=shown) for shown in visibility.values()]
def on_pad_mode(evt: gr.SelectData):
    """
    Hides the constant value input if the constant pad mode is not selected

    :param evt: Event data from Radio selection
    :return: Number Component with updated visibility
    """
    is_constant_mode = evt.value == "Constant"
    return gr.update(visible=is_constant_mode)